CommonHTTP2ServerTransport.swift 8.7 KB


  1. /*
  2. * Copyright 2024, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. package import GRPCCore
  17. package import NIOCore
  18. package import NIOExtras
  19. private import NIOHTTP2
  20. private import Synchronization
  /// Provides the common functionality for a `NIO`-based server transport.
  ///
  /// The transport is configured with an address, an event loop group, a quiescing helper and a
  /// listener factory. Calling ``listen(streamHandler:)`` asks the factory for a listening channel,
  /// publishes the bound address via ``listeningAddress`` and then serves every accepted connection
  /// until the listening channel finishes.
  ///
  /// - SeeAlso: ``HTTP2ListenerFactory``.
  @available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *)
  package final class CommonHTTP2ServerTransport<
    ListenerFactory: HTTP2ListenerFactory
  >: ServerTransport, ListeningServerTransport {
    private let eventLoopGroup: any EventLoopGroup
    private let address: SocketAddress
    private let listeningAddressState: Mutex<State>
    private let serverQuiescingHelper: ServerQuiescingHelper
    private let factory: ListenerFactory

    // MARK: - Listening address state machine

    /// Tracks the lifecycle of the listening address.
    ///
    /// The mutating methods never complete a promise themselves; they return an action describing
    /// what the caller should do with the promise.
    private enum State {
      /// No address has been bound yet; the promise is completed once binding succeeds or fails.
      case idle(EventLoopPromise<SocketAddress>)
      /// An address has been bound; the future is the one backing the original promise.
      case listening(EventLoopFuture<SocketAddress>)
      /// The transport has closed, or the bound address could not be interpreted.
      case closedOrInvalidAddress(RuntimeError)

      /// The future which resolves to the listening address.
      ///
      /// - Throws: The stored `RuntimeError` when the state is `closedOrInvalidAddress`.
      var listeningAddressFuture: EventLoopFuture<SocketAddress> {
        get throws {
          switch self {
          case .idle(let pendingPromise):
            return pendingPromise.futureResult
          case .listening(let boundFuture):
            return boundFuture
          case .closedOrInvalidAddress(let storedError):
            throw storedError
          }
        }
      }

      /// The action the caller must take after ``addressBound(_:userProvidedAddress:)``.
      enum OnBound {
        case succeedPromise(_ promise: EventLoopPromise<SocketAddress>, address: SocketAddress)
        case failPromise(_ promise: EventLoopPromise<SocketAddress>, error: RuntimeError)
      }

      /// Records the outcome of binding the listening channel.
      ///
      /// - Parameters:
      ///   - address: The local address reported by the bound channel, if any.
      ///   - userProvidedAddress: The address the transport was configured with. It is used as the
      ///     listening address when `address` is `nil` and this address is a virtual socket.
      /// - Returns: Whether the pending promise should be succeeded or failed, and with what.
      /// - Important: Traps if the state is anything other than `idle`.
      mutating func addressBound(
        _ address: NIOCore.SocketAddress?,
        userProvidedAddress: SocketAddress
      ) -> OnBound {
        guard case .idle(let pendingPromise) = self else {
          fatalError("Invalid state: addressBound should only be called once and when in idle state")
        }

        // Prefer the address reported by the channel; otherwise fall back to the configured address
        // if (and only if) it is a virtual socket.
        let resolvedAddress: SocketAddress?
        if let address {
          resolvedAddress = SocketAddress(address)
        } else if userProvidedAddress.virtualSocket != nil {
          resolvedAddress = userProvidedAddress
        } else {
          resolvedAddress = nil
        }

        guard let resolvedAddress else {
          assertionFailure("Unknown address type")
          let unknownAddressError = RuntimeError(
            code: .transportError,
            message: "Unknown address type returned by transport."
          )
          self = .closedOrInvalidAddress(unknownAddressError)
          return .failPromise(pendingPromise, error: unknownAddressError)
        }

        self = .listening(pendingPromise.futureResult)
        return .succeedPromise(pendingPromise, address: resolvedAddress)
      }

      /// The action the caller must take after ``close()``.
      enum OnClose {
        case failPromise(EventLoopPromise<SocketAddress>, error: RuntimeError)
        case doNothing
      }

      /// Moves the state to `closedOrInvalidAddress`.
      ///
      /// If the state is already `closedOrInvalidAddress` it is left untouched, so the error stored
      /// first is the one which is kept.
      ///
      /// - Returns: `failPromise` when a promise was still pending (state was `idle`), otherwise
      ///   `doNothing`.
      mutating func close() -> OnClose {
        let stoppedError = RuntimeError(
          code: .serverIsStopped,
          message: """
            There is no listening address bound for this server: there may have been \
            an error which caused the transport to close, or it may have shut down.
            """
        )

        // Work out whether a promise is still outstanding before transitioning.
        let outstandingPromise: EventLoopPromise<SocketAddress>?
        switch self {
        case .closedOrInvalidAddress:
          return .doNothing
        case .idle(let pendingPromise):
          outstandingPromise = pendingPromise
        case .listening:
          outstandingPromise = nil
        }

        self = .closedOrInvalidAddress(stoppedError)

        guard let outstandingPromise else {
          return .doNothing
        }
        return .failPromise(outstandingPromise, error: stoppedError)
      }
    }

    // MARK: - ListeningServerTransport

    /// The listening address for this server transport.
    ///
    /// It is an `async` property because it only returns once the address has been bound.
    ///
    /// - Throws: A runtime error if the address could not be bound, if the transport is no longer
    ///   listening, or if the transport returned an invalid address.
    package var listeningAddress: SocketAddress {
      get async throws {
        // Only the future is fetched under the lock; awaiting it happens outside.
        let addressFuture = try self.listeningAddressState.withLock { state in
          try state.listeningAddressFuture
        }
        return try await addressFuture.get()
      }
    }

    // MARK: - Lifecycle

    /// Creates a new transport.
    ///
    /// - Parameters:
    ///   - address: The address to pass to the listener factory when listening starts.
    ///   - eventLoopGroup: The event loop group passed to the listener factory; one of its event
    ///     loops also creates the promise backing ``listeningAddress``.
    ///   - quiescingHelper: The helper used to initiate graceful shutdown.
    ///   - listenerFactory: The factory which creates the listening channel.
    package init(
      address: SocketAddress,
      eventLoopGroup: any EventLoopGroup,
      quiescingHelper: ServerQuiescingHelper,
      listenerFactory: ListenerFactory
    ) {
      self.address = address
      self.eventLoopGroup = eventLoopGroup
      self.factory = listenerFactory
      self.serverQuiescingHelper = quiescingHelper

      let addressPromise: EventLoopPromise<SocketAddress> = eventLoopGroup.any().makePromise()
      self.listeningAddressState = Mutex(.idle(addressPromise))
    }

    /// Starts listening and handles connections until the listening channel finishes.
    ///
    /// - Parameter streamHandler: Invoked for every accepted RPC stream together with its context.
    /// - Throws: Any error thrown while creating the listening channel, accepting connections, or
    ///   handling a connection.
    package func listen(
      streamHandler: @escaping @Sendable (
        _ stream: RPCStream<Inbound, Outbound>,
        _ context: ServerContext
      ) async -> Void
    ) async throws {
      // Whatever way this function exits, mark the address as closed. If nobody has completed the
      // address promise yet, fail it here (outside of the lock).
      defer {
        let onClose = self.listeningAddressState.withLock { state in state.close() }
        if case .failPromise(let promise, let error) = onClose {
          promise.fail(error)
        }
      }

      let listener = try await self.factory.makeListeningChannel(
        eventLoopGroup: self.eventLoopGroup,
        address: self.address,
        serverQuiescingHelper: self.serverQuiescingHelper
      )

      let boundAddress = listener.channel.localAddress
      let onBound = self.listeningAddressState.withLock { state in
        state.addressBound(boundAddress, userProvidedAddress: self.address)
      }

      // Complete the promise outside of the lock.
      switch onBound {
      case .succeedPromise(let promise, let address):
        promise.succeed(address)
      case .failPromise(let promise, let error):
        promise.fail(error)
      }

      try await listener.executeThenClose { acceptedConnections in
        try await withThrowingDiscardingTaskGroup { connectionTasks in
          // One child task per accepted connection.
          for try await (connection, multiplexer) in acceptedConnections {
            connectionTasks.addTask {
              try await self.handleConnection(
                connection,
                multiplexer: multiplexer,
                streamHandler: streamHandler
              )
            }
          }
        }
      }
    }

    // MARK: - Connection and stream handling

    /// Handles a single accepted connection.
    ///
    /// One task drains the connection's inbound frames while the calling task accepts streams from
    /// the multiplexer and starts a task to handle each of them.
    ///
    /// - Parameters:
    ///   - connection: The channel for the accepted connection.
    ///   - multiplexer: The HTTP/2 stream multiplexer for the connection.
    ///   - streamHandler: Invoked for every stream accepted on the connection.
    /// - Throws: Any error thrown by `executeThenClose` on the connection channel.
    private func handleConnection(
      _ connection: NIOAsyncChannel<HTTP2Frame, HTTP2Frame>,
      multiplexer: ChannelPipeline.SynchronousOperations.HTTP2StreamMultiplexer,
      streamHandler: @escaping @Sendable (
        _ stream: RPCStream<Inbound, Outbound>,
        _ context: ServerContext
      ) async -> Void
    ) async throws {
      try await connection.executeThenClose { connectionFrames, _ in
        await withDiscardingTaskGroup { streamTasks in
          streamTasks.addTask {
            do {
              for try await _ in connectionFrames {}
            } catch {
              // We don't want to close the channel if one connection throws, so the error is
              // dropped and this task simply ends.
            }
          }

          do {
            for try await (http2Stream, methodDescriptor) in multiplexer.inbound {
              streamTasks.addTask {
                await self.handleStream(
                  http2Stream,
                  handler: streamHandler,
                  descriptor: methodDescriptor
                )
              }
            }
          } catch {
            // An error while accepting streams ends the accept loop; no further streams are
            // accepted on this connection.
          }
        }
      }
    }

    /// Handles a single HTTP/2 stream by wrapping it in an `RPCStream` and invoking the handler.
    ///
    /// - Parameters:
    ///   - stream: The channel for the stream.
    ///   - streamHandler: The handler to invoke with the RPC stream and its context.
    ///   - descriptor: A future for the descriptor of the method being called. If it fails the
    ///     handler is never invoked.
    private func handleStream(
      _ stream: NIOAsyncChannel<RPCRequestPart, RPCResponsePart>,
      handler streamHandler: @escaping @Sendable (
        _ stream: RPCStream<Inbound, Outbound>,
        _ context: ServerContext
      ) async -> Void,
      descriptor: EventLoopFuture<MethodDescriptor>
    ) async {
      do {
        try await stream.executeThenClose { requestParts, responseWriter in
          let method: MethodDescriptor
          do {
            method = try await descriptor.get()
          } catch {
            // Without a method descriptor we never got the RPC metadata, so there is nothing to
            // dispatch to: just let the stream be closed.
            return
          }

          let rpcStream = RPCStream(
            descriptor: method,
            inbound: RPCAsyncSequence(wrapping: requestParts),
            outbound: RPCWriter.Closable(
              wrapping: ServerConnection.Stream.Outbound(
                responseWriter: responseWriter,
                http2Stream: stream
              )
            )
          )

          await streamHandler(rpcStream, ServerContext(descriptor: method))
        }
      } catch {
        // It's okay to ignore this error: the closure above doesn't throw, so the error can only
        // come from the stream failing to close, and there's nothing we can do about that.
      }
    }

    /// Begins graceful shutdown by asking the quiescing helper to initiate shutdown.
    ///
    /// No promise is passed to the helper, so completion of the shutdown is not observed here.
    package func beginGracefulShutdown() {
      self.serverQuiescingHelper.initiateShutdown(promise: nil)
    }
  }