GRPCServerStreamHandler.swift 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. /*
  2. * Copyright 2024, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import GRPCCore
  17. import NIOCore
  18. import NIOHTTP2
  @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChannelHandler {
    // Inbound direction: HTTP/2 frame payloads arrive, gRPC request parts are emitted.
    typealias InboundIn = HTTP2Frame.FramePayload
    typealias InboundOut = RPCRequestPart

    // Outbound direction: gRPC response parts arrive, HTTP/2 frame payloads are emitted.
    typealias OutboundIn = RPCResponsePart
    typealias OutboundOut = HTTP2Frame.FramePayload

    /// The end-of-stream trailers frame together with the promise supplied when the status was written.
    private typealias PendingTrailers = (
      trailers: HTTP2Frame.FramePayload,
      promise: EventLoopPromise<Void>?
    )

    /// Completed with the RPC's method descriptor once request metadata has been received.
    private let methodDescriptorPromise: EventLoopPromise<MethodDescriptor>

    /// gRPC stream state; configured for the server side in `init`.
    private var stateMachine: GRPCStreamStateMachine

    /// `true` between a `channelRead` and the following `channelReadComplete`.
    private var isReading: Bool = false

    /// Whether frames have been written to the context that still need a `flush()`.
    private var flushPending: Bool = false

    /// The final status and trailers are held back here rather than written straight away.
    /// Messages may still be sitting unwritten because nobody has flushed yet; writing the
    /// trailers first would let them overtake those messages on the wire.
    private var pendingTrailers: PendingTrailers?

    /// Creates a server-side stream handler.
    ///
    /// - Parameters:
    ///   - scheme: The scheme used to configure the server state machine.
    ///   - acceptedEncodings: The compression algorithms the server state machine is configured with.
    ///   - maximumPayloadSize: Forwarded to the state machine as its maximum payload size.
    ///   - methodDescriptorPromise: Promise completed (or failed) once the RPC's method is known.
    ///   - skipStateMachineAssertions: Forwarded to the state machine as `skipAssertions`.
    init(
      scheme: GRPCStreamStateMachineConfiguration.Scheme,
      acceptedEncodings: CompressionAlgorithmSet,
      maximumPayloadSize: Int,
      methodDescriptorPromise: EventLoopPromise<MethodDescriptor>,
      skipStateMachineAssertions: Bool = false
    ) {
      self.methodDescriptorPromise = methodDescriptorPromise
      self.stateMachine = GRPCStreamStateMachine(
        configuration: .server(
          .init(scheme: scheme, acceptedEncodings: acceptedEncodings)
        ),
        maximumPayloadSize: maximumPayloadSize,
        skipAssertions: skipStateMachineAssertions
      )
    }
  }
  49. // - MARK: ChannelInboundHandler
  @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  extension GRPCServerStreamHandler {
    /// Processes one inbound HTTP/2 frame payload for this stream.
    ///
    /// - DATA: the bytes are fed to the state machine. Every complete message it yields is
    ///   forwarded as `.message`. Once it reports no more messages, `ChannelEvent.inputClosed`
    ///   is fired.
    /// - HEADERS: the headers are fed to the state machine. Depending on the resulting action,
    ///   the handler does one of the following:
    ///   - forwards them as `.metadata` (after completing the method-descriptor promise);
    ///   - writes rejection trailers;
    ///   - writes and flushes `RST_STREAM(PROTOCOL_ERROR)` and then closes the channel.
    /// - RST_STREAM: reported to the state machine as an unexpected inbound close.
    /// - Any other frame type is ignored.
    ///
    /// Sets `isReading`. Flushes requested before `channelReadComplete` are deferred until then.
    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
      self.isReading = true
      let frame = self.unwrapInboundIn(data)
      switch frame {
      case .data(let frameData):
        let endStream = frameData.endStream
        switch frameData.data {
        case .byteBuffer(let buffer):
          do {
            switch try self.stateMachine.receive(buffer: buffer, endStream: endStream) {
            case .endRPCAndForwardErrorStatus:
              preconditionFailure(
                "OnBufferReceivedAction.endRPCAndForwardErrorStatus should never be returned for the server."
              )
            case .readInbound:
              // Forward every message that is now complete.
              loop: while true {
                switch self.stateMachine.nextInboundMessage() {
                case .receiveMessage(let message):
                  context.fireChannelRead(self.wrapInboundOut(.message(message)))
                case .awaitMoreMessages:
                  break loop
                case .noMoreMessages:
                  context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)
                  break loop
                }
              }
            case .doNothing:
              ()
            }
          } catch {
            context.fireErrorCaught(error)
          }
        case .fileRegion:
          preconditionFailure("Unexpected IOData.fileRegion")
        }

      case .headers(let headers):
        do {
          let action = try self.stateMachine.receive(
            headers: headers.headers,
            endStream: headers.endStream
          )
          switch action {
          case .receivedMetadata(let metadata, let methodDescriptor):
            if let methodDescriptor = methodDescriptor {
              self.methodDescriptorPromise.succeed(methodDescriptor)
              context.fireChannelRead(self.wrapInboundOut(.metadata(metadata)))
            } else {
              assertionFailure("Method descriptor should have been present if we received metadata.")
            }

          case .rejectRPC(let trailers):
            // The rejection trailers are written now. Setting `flushPending` makes
            // `channelReadComplete` flush them once the read loop ends.
            self.flushPending = true
            self.methodDescriptorPromise.fail(
              RPCError(
                code: .unavailable,
                message: "RPC was rejected."
              )
            )
            let response = HTTP2Frame.FramePayload.headers(.init(headers: trailers, endStream: true))
            context.write(self.wrapOutboundOut(response), promise: nil)

          case .receivedStatusAndMetadata:
            throw RPCError(
              code: .internalError,
              message: "Server cannot get receivedStatusAndMetadata."
            )

          case .protocolViolation:
            context.writeAndFlush(self.wrapOutboundOut(.rstStream(.protocolError)), promise: nil)
            context.close(promise: nil)

          case .doNothing:
            throw RPCError(code: .internalError, message: "Server cannot receive doNothing.")
          }
        } catch {
          context.fireErrorCaught(error)
        }

      case .rstStream:
        self.handleUnexpectedInboundClose(context: context, reason: .streamReset)

      case .ping, .goAway, .priority, .settings, .pushPromise, .windowUpdate,
        .alternativeService, .origin:
        ()
      }
    }

    /// Ends the read loop. If a flush became pending while reading, it is performed now, and
    /// then `channelReadComplete` is forwarded.
    ///
    /// A flush can become pending in three ways: a write that goes straight to the context
    /// (metadata or rejection trailers), a `flush` call that was deferred because
    /// `isReading` was set, or a frame written during an earlier drain.
    /// The handler's own `flush(context:)` is used here, not `context.flush()`. That way, messages
    /// and trailers still buffered in the state machine are written too, instead of being left
    /// behind until some later flush.
    func channelReadComplete(context: ChannelHandlerContext) {
      self.isReading = false
      if self.flushPending {
        self.flush(context: context)
      }
      context.fireChannelReadComplete()
    }

    /// Tears down the state machine and fails `methodDescriptorPromise` with an `.unavailable`
    /// error.
    func handlerRemoved(context: ChannelHandlerContext) {
      self.stateMachine.tearDown()
      self.methodDescriptorPromise.fail(
        RPCError(
          code: .unavailable,
          message: "RPC stream was closed before we got any Metadata."
        )
      )
    }

    /// Reports the inactive channel to the state machine as an unexpected inbound close, then
    /// forwards `channelInactive`.
    func channelInactive(context: ChannelHandlerContext) {
      self.handleUnexpectedInboundClose(context: context, reason: .channelInactive)
      context.fireChannelInactive()
    }

    /// Reports `error` to the state machine as an unexpected inbound close.
    ///
    /// The state machine decides what happens next. An error is passed along only when it
    /// returns `.fireError_serverOnly`, and what gets fired is the wrapped error it supplies.
    func errorCaught(context: ChannelHandlerContext, error: any Error) {
      self.handleUnexpectedInboundClose(context: context, reason: .errorThrown(error))
    }

    /// Informs the state machine that the inbound side closed unexpectedly and applies its decision:
    /// fire the wrapped error, or do nothing. The client-only `forwardStatus` action trips an
    /// assertion.
    private func handleUnexpectedInboundClose(
      context: ChannelHandlerContext,
      reason: GRPCStreamStateMachine.UnexpectedInboundCloseReason
    ) {
      switch self.stateMachine.unexpectedInboundClose(reason: reason) {
      case .fireError_serverOnly(let wrappedError):
        context.fireErrorCaught(wrappedError)
      case .doNothing:
        ()
      case .forwardStatus_clientOnly:
        assertionFailure(
          "`forwardStatus` should only happen on the client side, never on the server."
        )
      }
    }
  }

  // MARK: - ChannelOutboundHandler

  @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  extension GRPCServerStreamHandler {
    /// Accepts an outbound gRPC response part.
    ///
    /// - `.metadata`: encoded by the state machine and written to the context right away as a
    ///   HEADERS frame, with a flush marked as pending.
    /// - `.message`: handed to the state machine. It is written out as DATA frames by
    ///   `flush(context:)`.
    /// - `.status`: encoded as an end-of-stream HEADERS frame and held in `pendingTrailers`.
    ///   `flush(context:)` writes it only after the state machine reports no more messages.
    ///
    /// If the state machine rejects the part, `promise` is failed and the error is fired down
    /// the pipeline.
    func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
      let frame = self.unwrapOutboundIn(data)
      switch frame {
      case .metadata(let metadata):
        do {
          let headers = try self.stateMachine.send(metadata: metadata)
          // Mark the flush only once something has actually been written. If `send` throws,
          // there is nothing to flush.
          self.flushPending = true
          context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: promise)
        } catch {
          promise?.fail(error)
          context.fireErrorCaught(error)
        }

      case .message(let message):
        do {
          try self.stateMachine.send(message: message, promise: promise)
        } catch {
          promise?.fail(error)
          context.fireErrorCaught(error)
        }

      case .status(let status, let metadata):
        do {
          let headers = try self.stateMachine.send(status: status, metadata: metadata)
          let response = HTTP2Frame.FramePayload.headers(.init(headers: headers, endStream: true))
          self.pendingTrailers = (response, promise)
        } catch {
          promise?.fail(error)
          context.fireErrorCaught(error)
        }
      }
    }

    /// Writes everything buffered in the state machine (and the trailers, once due) and then
    /// flushes the context.
    ///
    /// Inside a read loop, the flush is only recorded in `flushPending`.
    /// `channelReadComplete` performs it when the read loop ends, so a flush requested mid-read
    /// is never lost. If the state machine throws while frames are being drained, anything
    /// already written is still flushed before the error is fired.
    func flush(context: ChannelHandlerContext) {
      if self.isReading {
        // We don't want to flush yet if we're still in a read loop; remember the request instead.
        self.flushPending = true
        return
      }

      var drainError: (any Error)? = nil
      do {
        try self.writeBufferedFrames(context: context)
      } catch {
        drainError = error
      }

      // This covers frames written just now by `writeBufferedFrames` and frames written
      // straight to the context by `write` or `channelRead`.
      if self.flushPending {
        self.flushPending = false
        context.flush()
      }

      if let error = drainError {
        context.fireErrorCaught(error)
      }
    }

    /// Writes every DATA frame the state machine has ready. When the state machine reports
    /// that no more messages will come, it also writes the buffered status trailers, if any.
    ///
    /// Sets `flushPending` whenever a frame is written. Errors thrown by the state machine
    /// are propagated to the caller. Frames written before such an error stay written.
    private func writeBufferedFrames(context: ChannelHandlerContext) throws {
      while true {
        switch try self.stateMachine.nextOutboundFrame() {
        case .sendFrame(let byteBuffer, let promise):
          self.flushPending = true
          context.write(
            self.wrapOutboundOut(.data(.init(data: .byteBuffer(byteBuffer)))),
            promise: promise
          )

        case .noMoreMessages:
          if let pendingTrailers = self.pendingTrailers {
            self.flushPending = true
            self.pendingTrailers = nil
            context.write(
              self.wrapOutboundOut(pendingTrailers.trailers),
              promise: pendingTrailers.promise
            )
          }
          return

        case .awaitMoreMessages:
          return
        }
      }
    }
  }