GRPCServerStreamHandler.swift 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. /*
  2. * Copyright 2024, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. package import GRPCCore
  17. package import NIOCore
  18. package import NIOHTTP2
  /// Server-side handler for a single HTTP/2 stream carrying one RPC.
  ///
  /// Inbound `HTTP2Frame.FramePayload`s are turned into `RPCRequestPart`s, and outbound
  /// `RPCResponsePart`s are turned into `HTTP2Frame.FramePayload`s. A `GRPCStreamStateMachine`
  /// validates every transition.
  @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  package final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChannelHandler {
    package typealias InboundIn = HTTP2Frame.FramePayload
    package typealias InboundOut = RPCRequestPart
    package typealias OutboundIn = RPCResponsePart
    package typealias OutboundOut = HTTP2Frame.FramePayload

    /// Tracks the state of the RPC. Outbound messages are queued here until
    /// `flush(context:)` drains them.
    private var stateMachine: GRPCStreamStateMachine

    /// `true` between `channelRead` and `channelReadComplete`.
    private var isReading = false

    /// Whether frames have been written by this handler that still need a `context.flush()`.
    private var flushPending = false

    /// The final status and trailers, held back until flush time.
    ///
    /// Buffering them avoids reordering: messages that have been written but not yet flushed
    /// must reach the channel before the trailers that end the stream.
    private var pendingTrailers:
      (trailers: HTTP2Frame.FramePayload, promise: EventLoopPromise<Void>?)?

    /// Completed with the RPC's method descriptor once request headers are accepted.
    private let methodDescriptorPromise: EventLoopPromise<MethodDescriptor>

    /// Built once up front: boxing an error in an existential always allocates, so this
    /// avoids paying that cost every time the handler is removed.
    private static let handlerRemovedBeforeDescriptorResolved: any Error =
      RPCError(code: .unavailable, message: "RPC stream was closed before we got any Metadata.")

    /// Creates a handler for one server-side stream.
    ///
    /// - Parameters:
    ///   - scheme: The scheme the server is serving.
    ///   - acceptedEncodings: Compression algorithms the server accepts.
    ///   - maximumPayloadSize: Largest message payload the state machine will accept.
    ///   - methodDescriptorPromise: Completed with the method descriptor from the request headers.
    ///   - skipStateMachineAssertions: Passed to the state machine as `skipAssertions`.
    package init(
      scheme: Scheme,
      acceptedEncodings: CompressionAlgorithmSet,
      maximumPayloadSize: Int,
      methodDescriptorPromise: EventLoopPromise<MethodDescriptor>,
      skipStateMachineAssertions: Bool = false
    ) {
      self.methodDescriptorPromise = methodDescriptorPromise
      self.stateMachine = GRPCStreamStateMachine(
        configuration: .server(
          .init(scheme: scheme, acceptedEncodings: acceptedEncodings)
        ),
        maximumPayloadSize: maximumPayloadSize,
        skipAssertions: skipStateMachineAssertions
      )
    }
  }
  // MARK: - ChannelInboundHandler
  @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  extension GRPCServerStreamHandler {
    /// Handles one inbound HTTP/2 frame payload for this stream.
    ///
    /// - DATA: the bytes go into the state machine, and every complete message is forwarded
    ///   as a `.message` part.
    /// - HEADERS: accepted headers resolve the method descriptor promise and are forwarded as
    ///   `.metadata`. Rejected headers are answered with the trailers the state machine produces.
    /// - RST_STREAM: treated as an unexpected close.
    /// - All other frame types are ignored.
    ///
    /// Flushes requested while a read is in progress are deferred to `channelReadComplete`.
    package func channelRead(context: ChannelHandlerContext, data: NIOAny) {
      self.isReading = true
      let frame = self.unwrapInboundIn(data)
      switch frame {
      case .data(let frameData):
        let endStream = frameData.endStream
        switch frameData.data {
        case .byteBuffer(let buffer):
          do {
            switch try self.stateMachine.receive(buffer: buffer, endStream: endStream) {
            case .endRPCAndForwardErrorStatus_clientOnly:
              preconditionFailure(
                "OnBufferReceivedAction.endRPCAndForwardErrorStatus should never be returned for the server."
              )
            case .forwardErrorAndClose_serverOnly(let error):
              context.fireErrorCaught(error)
              context.close(mode: .all, promise: nil)
            case .readInbound:
              // Forward every complete message the state machine can currently produce.
              loop: while true {
                switch self.stateMachine.nextInboundMessage() {
                case .receiveMessage(let message):
                  context.fireChannelRead(self.wrapInboundOut(.message(message)))
                case .awaitMoreMessages:
                  break loop
                case .noMoreMessages:
                  // No further messages will arrive on this stream.
                  context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)
                  break loop
                }
              }
            case .doNothing:
              ()
            }
          } catch let invalidState {
            let error = RPCError(invalidState)
            context.fireErrorCaught(error)
          }
        case .fileRegion:
          preconditionFailure("Unexpected IOData.fileRegion")
        }
      case .headers(let headers):
        do {
          let action = try self.stateMachine.receive(
            headers: headers.headers,
            endStream: headers.endStream
          )
          switch action {
          case .receivedMetadata(let metadata, let methodDescriptor):
            if let methodDescriptor = methodDescriptor {
              self.methodDescriptorPromise.succeed(methodDescriptor)
              context.fireChannelRead(self.wrapInboundOut(.metadata(metadata)))
            } else {
              assertionFailure("Method descriptor should have been present if we received metadata.")
            }
          case .rejectRPC_serverOnly(let trailers):
            // Reply with the end-of-stream headers from the state machine. The flush is
            // performed in `channelReadComplete`.
            self.flushPending = true
            self.methodDescriptorPromise.fail(
              RPCError(
                code: .unavailable,
                message: "RPC was rejected."
              )
            )
            let response = HTTP2Frame.FramePayload.headers(.init(headers: trailers, endStream: true))
            context.write(self.wrapOutboundOut(response), promise: nil)
          case .receivedStatusAndMetadata_clientOnly:
            assertionFailure("Unexpected action")
          case .protocolViolation_serverOnly:
            context.writeAndFlush(self.wrapOutboundOut(.rstStream(.protocolError)), promise: nil)
            context.close(promise: nil)
          case .doNothing:
            ()
          }
        } catch let invalidState {
          let error = RPCError(invalidState)
          context.fireErrorCaught(error)
        }
      case .rstStream:
        self.handleUnexpectedInboundClose(context: context, reason: .streamReset)
      case .ping, .goAway, .priority, .settings, .pushPromise, .windowUpdate,
        .alternativeService, .origin:
        ()
      }
    }

    /// Ends the current read cycle and performs any flush that was deferred while reading.
    package func channelReadComplete(context: ChannelHandlerContext) {
      self.isReading = false
      if self.flushPending {
        // Go through `flush(context:)` rather than `context.flush()`. Messages queued in the
        // state machine and the buffered trailers are only written when `flush(context:)`
        // drains them; `context.flush()` alone would skip them. `flush(context:)` clears
        // `flushPending` once it has flushed.
        self.flush(context: context)
      }
      context.fireChannelReadComplete()
    }

    /// Tears down the state machine and completes any promises this handler still holds.
    package func handlerRemoved(context: ChannelHandlerContext) {
      self.stateMachine.tearDown()
      // NOTE(review): relies on NIO ignoring completion of an already-completed promise when
      // the descriptor was resolved earlier.
      self.methodDescriptorPromise.fail(Self.handlerRemovedBeforeDescriptorResolved)

      // Trailers are only written when `flush(context:)` runs. If the handler goes away
      // first, fail their promise rather than leaving it uncompleted.
      if let pendingTrailers = self.pendingTrailers {
        self.pendingTrailers = nil
        pendingTrailers.promise?.fail(
          RPCError(
            code: .unavailable,
            message: "RPC stream was closed before the status and trailers were written."
          )
        )
      }
    }

    /// Reports the channel going inactive to the state machine, then forwards the event.
    package func channelInactive(context: ChannelHandlerContext) {
      self.handleUnexpectedInboundClose(context: context, reason: .channelInactive)
      context.fireChannelInactive()
    }

    /// Routes the error through the state machine, which decides whether to surface it.
    package func errorCaught(context: ChannelHandlerContext, error: any Error) {
      self.handleUnexpectedInboundClose(context: context, reason: .errorThrown(error))
    }

    /// Lets the state machine react to an unexpected close (reset, inactive, or error).
    /// Any error it produces is forwarded down the pipeline.
    private func handleUnexpectedInboundClose(
      context: ChannelHandlerContext,
      reason: GRPCStreamStateMachine.UnexpectedInboundCloseReason
    ) {
      switch self.stateMachine.unexpectedInboundClose(reason: reason) {
      case .fireError_serverOnly(let wrappedError):
        context.fireErrorCaught(wrappedError)
      case .doNothing:
        ()
      case .forwardStatus_clientOnly:
        assertionFailure(
          "`forwardStatus` should only happen on the client side, never on the server."
        )
      }
    }
  }

  // MARK: - ChannelOutboundHandler
  @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  extension GRPCServerStreamHandler {
    /// Encodes an outbound RPC response part.
    ///
    /// - Metadata is written immediately as a HEADERS frame.
    /// - Messages are queued in the state machine and written when `flush(context:)` drains it.
    /// - Status and trailers are held in `pendingTrailers` and written by `flush(context:)` once
    ///   no more messages remain, so they cannot overtake messages that have not been flushed.
    ///
    /// If the state machine rejects a part, `promise` is failed and the error is fired inbound.
    package func write(
      context: ChannelHandlerContext,
      data: NIOAny,
      promise: EventLoopPromise<Void>?
    ) {
      let frame = self.unwrapOutboundIn(data)
      switch frame {
      case .metadata(let metadata):
        do {
          let headers = try self.stateMachine.send(metadata: metadata)
          // Only owe a flush once something has actually been written.
          self.flushPending = true
          context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: promise)
        } catch let invalidState {
          let error = RPCError(invalidState)
          promise?.fail(error)
          context.fireErrorCaught(error)
        }
      case .message(let message):
        do {
          try self.stateMachine.send(message: message, promise: promise)
        } catch let invalidState {
          let error = RPCError(invalidState)
          promise?.fail(error)
          context.fireErrorCaught(error)
        }
      case .status(let status, let metadata):
        do {
          let headers = try self.stateMachine.send(status: status, metadata: metadata)
          let response = HTTP2Frame.FramePayload.headers(.init(headers: headers, endStream: true))
          self.pendingTrailers = (response, promise)
        } catch let invalidState {
          let error = RPCError(invalidState)
          promise?.fail(error)
          context.fireErrorCaught(error)
        }
      }
    }

    /// Writes every frame queued in the state machine, then the buffered trailers once no
    /// more messages remain, and flushes if anything needs flushing.
    ///
    /// During a read the flush is deferred. It is recorded in `flushPending` and performed by
    /// `channelReadComplete`.
    package func flush(context: ChannelHandlerContext) {
      if self.isReading {
        // We don't want to flush yet if we're still in a read loop. Remember that a flush
        // was requested so `channelReadComplete` performs it instead of dropping it.
        self.flushPending = true
        return
      }

      do {
        loop: while true {
          switch try self.stateMachine.nextOutboundFrame() {
          case .sendFrame(let byteBuffer, let promise):
            self.flushPending = true
            context.write(
              self.wrapOutboundOut(.data(.init(data: .byteBuffer(byteBuffer)))),
              promise: promise
            )
          case .noMoreMessages:
            if let pendingTrailers = self.pendingTrailers {
              self.flushPending = true
              self.pendingTrailers = nil
              context.write(
                self.wrapOutboundOut(pendingTrailers.trailers),
                promise: pendingTrailers.promise
              )
            }
            break loop
          case .awaitMoreMessages:
            break loop
          case .closeAndFailPromise(let promise, let error):
            context.close(mode: .all, promise: nil)
            promise?.fail(error)
          }
        }
      } catch let invalidState {
        let error = RPCError(invalidState)
        context.fireErrorCaught(error)
      }

      // Flush whatever was written, both above and directly by this handler (e.g. response
      // headers), even if draining the state machine failed part-way.
      if self.flushPending {
        self.flushPending = false
        context.flush()
      }
    }
  }