2
0

GRPCServerStreamHandler.swift 8.9 KB


  1. /*
  2. * Copyright 2024, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import GRPCCore
  17. import NIOCore
  18. import NIOHTTP2
  19. @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  20. final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChannelHandler {
  21. typealias InboundIn = HTTP2Frame.FramePayload
  22. typealias InboundOut = RPCRequestPart
  23. typealias OutboundIn = RPCResponsePart
  24. typealias OutboundOut = HTTP2Frame.FramePayload
  25. private var stateMachine: GRPCStreamStateMachine
  26. private var isReading = false
  27. private var flushPending = false
  28. // We buffer the final status + trailers to avoid reordering issues (i.e.,
  29. // if there are messages still not written into the channel because flush has
  30. // not been called, but the server sends back trailers).
  31. private var pendingTrailers:
  32. (trailers: HTTP2Frame.FramePayload, promise: EventLoopPromise<Void>?)?
  33. private let methodDescriptorPromise: EventLoopPromise<MethodDescriptor>
  34. // Existential errors unconditionally allocate, avoid this per-use allocation by doing it
  35. // statically.
  36. private static let handlerRemovedBeforeDescriptorResolved: any Error = RPCError(
  37. code: .unavailable,
  38. message: "RPC stream was closed before we got any Metadata."
  39. )
  40. init(
  41. scheme: GRPCStreamStateMachineConfiguration.Scheme,
  42. acceptedEncodings: CompressionAlgorithmSet,
  43. maximumPayloadSize: Int,
  44. methodDescriptorPromise: EventLoopPromise<MethodDescriptor>,
  45. skipStateMachineAssertions: Bool = false
  46. ) {
  47. self.stateMachine = .init(
  48. configuration: .server(.init(scheme: scheme, acceptedEncodings: acceptedEncodings)),
  49. maximumPayloadSize: maximumPayloadSize,
  50. skipAssertions: skipStateMachineAssertions
  51. )
  52. self.methodDescriptorPromise = methodDescriptorPromise
  53. }
  54. }
// MARK: - ChannelInboundHandler
  56. @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  57. extension GRPCServerStreamHandler {
  58. func channelRead(context: ChannelHandlerContext, data: NIOAny) {
  59. self.isReading = true
  60. let frame = self.unwrapInboundIn(data)
  61. switch frame {
  62. case .data(let frameData):
  63. let endStream = frameData.endStream
  64. switch frameData.data {
  65. case .byteBuffer(let buffer):
  66. do {
  67. switch try self.stateMachine.receive(buffer: buffer, endStream: endStream) {
  68. case .endRPCAndForwardErrorStatus:
  69. preconditionFailure(
  70. "OnBufferReceivedAction.endRPCAndForwardErrorStatus should never be returned for the server."
  71. )
  72. case .readInbound:
  73. loop: while true {
  74. switch self.stateMachine.nextInboundMessage() {
  75. case .receiveMessage(let message):
  76. context.fireChannelRead(self.wrapInboundOut(.message(message)))
  77. case .awaitMoreMessages:
  78. break loop
  79. case .noMoreMessages:
  80. context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)
  81. break loop
  82. }
  83. }
  84. case .doNothing:
  85. ()
  86. }
  87. } catch {
  88. context.fireErrorCaught(error)
  89. }
  90. case .fileRegion:
  91. preconditionFailure("Unexpected IOData.fileRegion")
  92. }
  93. case .headers(let headers):
  94. do {
  95. let action = try self.stateMachine.receive(
  96. headers: headers.headers,
  97. endStream: headers.endStream
  98. )
  99. switch action {
  100. case .receivedMetadata(let metadata, let methodDescriptor):
  101. if let methodDescriptor = methodDescriptor {
  102. self.methodDescriptorPromise.succeed(methodDescriptor)
  103. context.fireChannelRead(self.wrapInboundOut(.metadata(metadata)))
  104. } else {
  105. assertionFailure("Method descriptor should have been present if we received metadata.")
  106. }
  107. case .rejectRPC(let trailers):
  108. self.flushPending = true
  109. self.methodDescriptorPromise.fail(
  110. RPCError(
  111. code: .unavailable,
  112. message: "RPC was rejected."
  113. )
  114. )
  115. let response = HTTP2Frame.FramePayload.headers(.init(headers: trailers, endStream: true))
  116. context.write(self.wrapOutboundOut(response), promise: nil)
  117. case .receivedStatusAndMetadata:
  118. throw RPCError(
  119. code: .internalError,
  120. message: "Server cannot get receivedStatusAndMetadata."
  121. )
  122. case .protocolViolation:
  123. context.writeAndFlush(self.wrapOutboundOut(.rstStream(.protocolError)), promise: nil)
  124. context.close(promise: nil)
  125. case .doNothing:
  126. throw RPCError(code: .internalError, message: "Server cannot receive doNothing.")
  127. }
  128. } catch {
  129. context.fireErrorCaught(error)
  130. }
  131. case .rstStream:
  132. self.handleUnexpectedInboundClose(context: context, reason: .streamReset)
  133. case .ping, .goAway, .priority, .settings, .pushPromise, .windowUpdate,
  134. .alternativeService, .origin:
  135. ()
  136. }
  137. }
  138. func channelReadComplete(context: ChannelHandlerContext) {
  139. self.isReading = false
  140. if self.flushPending {
  141. self.flushPending = false
  142. context.flush()
  143. }
  144. context.fireChannelReadComplete()
  145. }
  146. func handlerRemoved(context: ChannelHandlerContext) {
  147. self.stateMachine.tearDown()
  148. self.methodDescriptorPromise.fail(Self.handlerRemovedBeforeDescriptorResolved)
  149. }
  150. func channelInactive(context: ChannelHandlerContext) {
  151. self.handleUnexpectedInboundClose(context: context, reason: .channelInactive)
  152. context.fireChannelInactive()
  153. }
  154. func errorCaught(context: ChannelHandlerContext, error: any Error) {
  155. self.handleUnexpectedInboundClose(context: context, reason: .errorThrown(error))
  156. }
  157. private func handleUnexpectedInboundClose(
  158. context: ChannelHandlerContext,
  159. reason: GRPCStreamStateMachine.UnexpectedInboundCloseReason
  160. ) {
  161. switch self.stateMachine.unexpectedInboundClose(reason: reason) {
  162. case .fireError_serverOnly(let wrappedError):
  163. context.fireErrorCaught(wrappedError)
  164. case .doNothing:
  165. ()
  166. case .forwardStatus_clientOnly:
  167. assertionFailure(
  168. "`forwardStatus` should only happen on the client side, never on the server."
  169. )
  170. }
  171. }
  172. }
// MARK: - ChannelOutboundHandler
  174. @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  175. extension GRPCServerStreamHandler {
  176. func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
  177. let frame = self.unwrapOutboundIn(data)
  178. switch frame {
  179. case .metadata(let metadata):
  180. do {
  181. self.flushPending = true
  182. let headers = try self.stateMachine.send(metadata: metadata)
  183. context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: promise)
  184. } catch {
  185. promise?.fail(error)
  186. context.fireErrorCaught(error)
  187. }
  188. case .message(let message):
  189. do {
  190. try self.stateMachine.send(message: message, promise: promise)
  191. } catch {
  192. promise?.fail(error)
  193. context.fireErrorCaught(error)
  194. }
  195. case .status(let status, let metadata):
  196. do {
  197. let headers = try self.stateMachine.send(status: status, metadata: metadata)
  198. let response = HTTP2Frame.FramePayload.headers(.init(headers: headers, endStream: true))
  199. self.pendingTrailers = (response, promise)
  200. } catch {
  201. promise?.fail(error)
  202. context.fireErrorCaught(error)
  203. }
  204. }
  205. }
  206. func flush(context: ChannelHandlerContext) {
  207. if self.isReading {
  208. // We don't want to flush yet if we're still in a read loop.
  209. return
  210. }
  211. do {
  212. loop: while true {
  213. switch try self.stateMachine.nextOutboundFrame() {
  214. case .sendFrame(let byteBuffer, let promise):
  215. self.flushPending = true
  216. context.write(
  217. self.wrapOutboundOut(.data(.init(data: .byteBuffer(byteBuffer)))),
  218. promise: promise
  219. )
  220. case .noMoreMessages:
  221. if let pendingTrailers = self.pendingTrailers {
  222. self.flushPending = true
  223. self.pendingTrailers = nil
  224. context.write(
  225. self.wrapOutboundOut(pendingTrailers.trailers),
  226. promise: pendingTrailers.promise
  227. )
  228. }
  229. break loop
  230. case .awaitMoreMessages:
  231. break loop
  232. }
  233. }
  234. if self.flushPending {
  235. self.flushPending = false
  236. context.flush()
  237. }
  238. } catch {
  239. context.fireErrorCaught(error)
  240. }
  241. }
  242. }