2
0

TestServer.swift 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. /*
  2. * Copyright 2024, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import GRPCCore
  17. import NIOConcurrencyHelpers
  18. import NIOCore
  19. import NIOHTTP2
  20. import NIOPosix
  21. import XCTest
  22. @testable import GRPCHTTP2Core
  23. @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  24. final class TestServer: Sendable {
  25. private let eventLoopGroup: any EventLoopGroup
  26. private typealias Stream = NIOAsyncChannel<RPCRequestPart, RPCResponsePart>
  27. private typealias Multiplexer = NIOHTTP2AsyncSequence<Stream>
  28. private let connected: NIOLockedValueBox<[any Channel]>
  29. typealias Inbound = NIOAsyncChannelInboundStream<RPCRequestPart>
  30. typealias Outbound = NIOAsyncChannelOutboundWriter<RPCResponsePart>
  31. private let server: NIOLockedValueBox<NIOAsyncChannel<Multiplexer, Never>?>
  32. init(eventLoopGroup: any EventLoopGroup) {
  33. self.eventLoopGroup = eventLoopGroup
  34. self.server = NIOLockedValueBox(nil)
  35. self.connected = NIOLockedValueBox([])
  36. }
  37. enum Target {
  38. case localhost
  39. case uds(String)
  40. }
  41. var clients: [any Channel] {
  42. return self.connected.withLockedValue { $0 }
  43. }
  44. func bind(to target: Target = .localhost) async throws -> GRPCHTTP2Core.SocketAddress {
  45. precondition(self.server.withLockedValue { $0 } == nil)
  46. @Sendable
  47. func configure(_ channel: any Channel) -> EventLoopFuture<Multiplexer> {
  48. self.connected.withLockedValue {
  49. $0.append(channel)
  50. }
  51. channel.closeFuture.whenSuccess {
  52. self.connected.withLockedValue { connected in
  53. guard let index = connected.firstIndex(where: { $0 === channel }) else { return }
  54. connected.remove(at: index)
  55. }
  56. }
  57. return channel.eventLoop.makeCompletedFuture {
  58. let sync = channel.pipeline.syncOperations
  59. let multiplexer = try sync.configureAsyncHTTP2Pipeline(mode: .server) { stream in
  60. stream.eventLoop.makeCompletedFuture {
  61. let handler = GRPCServerStreamHandler(
  62. scheme: .http,
  63. acceptedEncodings: .all,
  64. maximumPayloadSize: .max,
  65. methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self)
  66. )
  67. try stream.pipeline.syncOperations.addHandlers(handler)
  68. return try NIOAsyncChannel(
  69. wrappingChannelSynchronously: stream,
  70. configuration: .init(
  71. inboundType: RPCRequestPart.self,
  72. outboundType: RPCResponsePart.self
  73. )
  74. )
  75. }
  76. }
  77. return multiplexer.inbound
  78. }
  79. }
  80. let bootstrap = ServerBootstrap(group: self.eventLoopGroup)
  81. let server: NIOAsyncChannel<Multiplexer, Never>
  82. let address: GRPCHTTP2Core.SocketAddress
  83. switch target {
  84. case .localhost:
  85. server = try await bootstrap.bind(host: "127.0.0.1", port: 0) { channel in
  86. configure(channel)
  87. }
  88. address = .ipv4(host: "127.0.0.1", port: server.channel.localAddress!.port!)
  89. case .uds(let path):
  90. server = try await bootstrap.bind(unixDomainSocketPath: path, cleanupExistingSocketFile: true)
  91. { channel in
  92. configure(channel)
  93. }
  94. address = .unixDomainSocket(path: server.channel.localAddress!.pathname!)
  95. }
  96. self.server.withLockedValue { $0 = server }
  97. return address
  98. }
  99. func run(_ handle: @Sendable @escaping (Inbound, Outbound) async throws -> Void) async throws {
  100. guard let server = self.server.withLockedValue({ $0 }) else {
  101. fatalError("bind() must be called first")
  102. }
  103. do {
  104. try await server.executeThenClose { inbound, _ in
  105. try await withThrowingTaskGroup(of: Void.self) { multiplexerGroup in
  106. for try await multiplexer in inbound {
  107. multiplexerGroup.addTask {
  108. try await withThrowingTaskGroup(of: Void.self) { streamGroup in
  109. for try await stream in multiplexer {
  110. streamGroup.addTask {
  111. try await stream.executeThenClose { inbound, outbound in
  112. try await handle(inbound, outbound)
  113. }
  114. }
  115. }
  116. }
  117. }
  118. }
  119. }
  120. }
  121. } catch is CancellationError {
  122. ()
  123. }
  124. }
  125. }
  126. @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  127. extension TestServer {
  128. enum RunHandler {
  129. case echo
  130. case never
  131. }
  132. func run(_ handler: RunHandler) async throws {
  133. switch handler {
  134. case .echo:
  135. try await self.run { inbound, outbound in
  136. for try await part in inbound {
  137. switch part {
  138. case .metadata:
  139. try await outbound.write(.metadata([:]))
  140. case .message(let bytes):
  141. try await outbound.write(.message(bytes))
  142. }
  143. }
  144. try await outbound.write(.status(Status(code: .ok, message: ""), [:]))
  145. }
  146. case .never:
  147. try await self.run { inbound, outbound in
  148. XCTFail("Unexpected stream")
  149. try await outbound.write(.status(Status(code: .unavailable, message: ""), [:]))
  150. outbound.finish()
  151. }
  152. }
  153. }
  154. }