TestServer.swift 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. /*
  2. * Copyright 2024, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import GRPCCore
  17. import NIOConcurrencyHelpers
  18. import NIOCore
  19. import NIOHTTP2
  20. import NIOPosix
  21. @testable import GRPCHTTP2Core
  /// A minimal HTTP/2 gRPC server for tests.
  ///
  /// Call ``bind(to:)`` once to open a listening socket, then ``run(_:)`` to accept
  /// connections. Every accepted HTTP/2 stream is configured with a `GRPCServerStreamHandler`,
  /// wrapped in a `NIOAsyncChannel` of `RPCRequestPart` (inbound) / `RPCResponsePart`
  /// (outbound), and handed to the handler passed to ``run(_:)``. Accepted connection channels
  /// are tracked in ``clients`` until they close.
  @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  final class TestServer: Sendable {
    private let eventLoopGroup: any EventLoopGroup

    private typealias Stream = NIOAsyncChannel<RPCRequestPart, RPCResponsePart>
    private typealias Multiplexer = NIOHTTP2AsyncSequence<Stream>

    /// Connection channels accepted by the server that have not yet closed.
    private let connected: NIOLockedValueBox<[Channel]>

    typealias Inbound = NIOAsyncChannelInboundStream<RPCRequestPart>
    typealias Outbound = NIOAsyncChannelOutboundWriter<RPCResponsePart>

    /// The listening channel; `nil` until ``bind(to:)`` has succeeded.
    private let server: NIOLockedValueBox<NIOAsyncChannel<Multiplexer, Never>?>

    /// Creates an unbound server whose channels run on `eventLoopGroup`.
    init(eventLoopGroup: any EventLoopGroup) {
      self.eventLoopGroup = eventLoopGroup
      self.server = NIOLockedValueBox(nil)
      self.connected = NIOLockedValueBox([])
    }

    /// Where the server listens.
    enum Target {
      /// IPv4 loopback (`127.0.0.1`) on port 0, i.e. a port chosen by the OS.
      case localhost
      /// A Unix domain socket at the given path; an existing socket file there is cleaned up first.
      case uds(String)
    }

    /// A snapshot of the connection channels that are currently open.
    var clients: [Channel] {
      self.connected.withLockedValue { $0 }
    }

    /// Binds the server to `target` and returns the address it is listening on.
    ///
    /// Must be called exactly once, before ``run(_:)``; a second call traps.
    ///
    /// - Parameter target: Where to listen; defaults to ``Target/localhost``.
    /// - Returns: The bound address. For ``Target/localhost`` it carries the OS-assigned port.
    /// - Throws: Any error raised while binding the listening socket.
    func bind(to target: Target = .localhost) async throws -> GRPCHTTP2Core.SocketAddress {
      // Fail fast, before opening a socket, if the server is already bound.
      precondition(self.server.withLockedValue { $0 } == nil, "bind(to:) must only be called once")

      /// Sets up each accepted connection: records it in `connected` (removing it again when it
      /// closes) and installs an async HTTP/2 server pipeline whose streams are gRPC-handled.
      @Sendable
      func configure(_ channel: Channel) -> EventLoopFuture<Multiplexer> {
        self.connected.withLockedValue { $0.append(channel) }
        channel.closeFuture.whenSuccess {
          self.connected.withLockedValue { connected in
            connected.removeAll { $0 === channel }
          }
        }

        return channel.eventLoop.makeCompletedFuture {
          let sync = channel.pipeline.syncOperations
          let multiplexer = try sync.configureAsyncHTTP2Pipeline(mode: .server) { stream in
            stream.eventLoop.makeCompletedFuture {
              // The handler lives in the stream's pipeline, so its promise is created on the
              // stream's event loop.
              let handler = GRPCServerStreamHandler(
                scheme: .http,
                acceptedEncodings: .all,
                maximumPayloadSize: .max,
                methodDescriptorPromise: stream.eventLoop.makePromise(of: MethodDescriptor.self)
              )
              try stream.pipeline.syncOperations.addHandlers(handler)
              return try NIOAsyncChannel(
                wrappingChannelSynchronously: stream,
                configuration: .init(
                  inboundType: RPCRequestPart.self,
                  outboundType: RPCResponsePart.self
                )
              )
            }
          }
          return multiplexer.inbound
        }
      }

      let bootstrap = ServerBootstrap(group: self.eventLoopGroup)
      let server: NIOAsyncChannel<Multiplexer, Never>
      let address: GRPCHTTP2Core.SocketAddress

      switch target {
      case .localhost:
        server = try await bootstrap.bind(host: "127.0.0.1", port: 0) { channel in
          configure(channel)
        }
        // A successfully bound IP socket always has a local address with a port.
        address = .ipv4(host: "127.0.0.1", port: server.channel.localAddress!.port!)

      case .uds(let path):
        server = try await bootstrap.bind(
          unixDomainSocketPath: path,
          cleanupExistingSocketFile: true
        ) { channel in
          configure(channel)
        }
        // A successfully bound UDS always has a local address with a path.
        address = .unixDomainSocket(path: server.channel.localAddress!.pathname!)
      }

      // Check-and-set atomically: a concurrent `bind` may have passed the early check while this
      // call was suspended in the socket bind above; without this, it would be silently overwritten.
      self.server.withLockedValue { current in
        precondition(current == nil, "bind(to:) must only be called once")
        current = server
      }
      return address
    }

    /// Accepts connections and their HTTP/2 streams, calling `handle` once per stream.
    ///
    /// Each connection is serviced in its own child task, and each stream of a connection in a
    /// further child task, so streams are handled concurrently. `run` returns once the listening
    /// channel's inbound sequence finishes and all child tasks have completed.
    ///
    /// - Note: Child task results are never consumed with `next()`. Per Swift task-group
    ///   semantics, errors thrown by `handle` (or while iterating a connection's streams) are
    ///   therefore discarded and do not propagate out of `run`; only errors from accepting
    ///   connections do.
    ///
    /// - Parameter handle: Called with the inbound request parts and outbound response writer of
    ///   each accepted stream; the stream is closed once `handle` returns or throws.
    /// - Precondition: ``bind(to:)`` must have been called first.
    func run(_ handle: @Sendable @escaping (Inbound, Outbound) async throws -> Void) async throws {
      guard let server = self.server.withLockedValue({ $0 }) else {
        fatalError("bind() must be called first")
      }

      try await server.executeThenClose { connections, _ in
        try await withThrowingTaskGroup(of: Void.self) { connectionGroup in
          for try await streams in connections {
            connectionGroup.addTask {
              try await withThrowingTaskGroup(of: Void.self) { streamGroup in
                for try await stream in streams {
                  streamGroup.addTask {
                    try await stream.executeThenClose { inbound, outbound in
                      try await handle(inbound, outbound)
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
  }