GRPCServerStreamHandler.swift 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. /*
  2. * Copyright 2024, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import GRPCCore
  17. import NIOCore
  18. import NIOHTTP2
  @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChannelHandler {
    typealias InboundIn = HTTP2Frame.FramePayload
    typealias InboundOut = RPCRequestPart
    typealias OutboundIn = RPCResponsePart
    typealias OutboundOut = HTTP2Frame.FramePayload

    /// Server-side gRPC stream state.
    ///
    /// It encodes outbound metadata and status and queues outbound messages until `flush(context:)`
    /// drains them.
    private var stateMachine: GRPCStreamStateMachine

    /// Set on every `channelRead` and cleared in `channelReadComplete`.
    ///
    /// Flushing is held back while this is `true`.
    private var isReading = false

    /// Whether frames have been written to the context since the last `context.flush()`.
    private var flushPending = false

    /// The final status/trailers frame, together with the promise of the write that carried it.
    ///
    /// It is held back instead of being written immediately. Messages may still be queued (written but
    /// not yet flushed), and emitting the trailers first would deliver them ahead of those messages.
    private var pendingTrailers:
      (trailers: HTTP2Frame.FramePayload, promise: EventLoopPromise<Void>?)?

    /// Succeeded with the RPC's method descriptor when the request headers are accepted.
    ///
    /// Failed if the RPC is rejected or the handler is removed first.
    private let methodDescriptorPromise: EventLoopPromise<MethodDescriptor>

    /// Creates a handler for a single server-side HTTP/2 stream.
    ///
    /// - Parameters:
    ///   - scheme: Scheme used to build the server state-machine configuration.
    ///   - acceptedEncodings: Compression algorithms the server accepts.
    ///   - maximumPayloadSize: Maximum payload size handed to the state machine.
    ///   - methodDescriptorPromise: Completed once the RPC's method is known.
    ///   - skipStateMachineAssertions: Forwarded to the state machine's `skipAssertions`;
    ///     defaults to `false`.
    init(
      scheme: GRPCStreamStateMachineConfiguration.Scheme,
      acceptedEncodings: CompressionAlgorithmSet,
      maximumPayloadSize: Int,
      methodDescriptorPromise: EventLoopPromise<MethodDescriptor>,
      skipStateMachineAssertions: Bool = false
    ) {
      self.methodDescriptorPromise = methodDescriptorPromise
      self.stateMachine = GRPCStreamStateMachine(
        configuration: .server(
          .init(scheme: scheme, acceptedEncodings: acceptedEncodings)
        ),
        maximumPayloadSize: maximumPayloadSize,
        skipAssertions: skipStateMachineAssertions
      )
    }
  }
  // MARK: - ChannelInboundHandler
  @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  extension GRPCServerStreamHandler {
    /// Processes one inbound HTTP/2 frame for this stream.
    ///
    /// - DATA: the bytes are fed to the state machine. Each complete message it yields is fired as
    ///   `.message`, and `ChannelEvent.inputClosed` is fired once it reports no more messages.
    /// - HEADERS: accepted headers succeed `methodDescriptorPromise` and are fired as `.metadata`.
    ///   A rejected RPC fails the promise and writes the rejection trailers, which are flushed in
    ///   `channelReadComplete`.
    /// - All other frame types are ignored.
    ///
    /// Errors thrown by the state machine are forwarded with `fireErrorCaught`.
    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
      // Marks a read loop as in progress; flushing is deferred until `channelReadComplete`.
      self.isReading = true
      let frame = self.unwrapInboundIn(data)
      switch frame {
      case .data(let frameData):
        let endStream = frameData.endStream
        switch frameData.data {
        case .byteBuffer(let buffer):
          do {
            switch try self.stateMachine.receive(buffer: buffer, endStream: endStream) {
            case .endRPCAndForwardErrorStatus:
              preconditionFailure(
                "OnBufferReceivedAction.endRPCAndForwardErrorStatus should never be returned for the server."
              )
            case .readInbound:
              // Forward every message that is now complete, stopping when the state machine
              // needs more bytes or the request stream has ended.
              loop: while true {
                switch self.stateMachine.nextInboundMessage() {
                case .receiveMessage(let message):
                  context.fireChannelRead(self.wrapInboundOut(.message(message)))
                case .awaitMoreMessages:
                  break loop
                case .noMoreMessages:
                  context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)
                  break loop
                }
              }
            }
          } catch {
            context.fireErrorCaught(error)
          }
        case .fileRegion:
          preconditionFailure("Unexpected IOData.fileRegion")
        }

      case .headers(let headers):
        do {
          let action = try self.stateMachine.receive(
            headers: headers.headers,
            endStream: headers.endStream
          )
          switch action {
          case .receivedMetadata(let metadata, let methodDescriptor):
            if let methodDescriptor = methodDescriptor {
              self.methodDescriptorPromise.succeed(methodDescriptor)
              context.fireChannelRead(self.wrapInboundOut(.metadata(metadata)))
            } else {
              assertionFailure("Method descriptor should have been present if we received metadata.")
            }

          case .rejectRPC(let trailers):
            // The rejection trailers are written now; `channelReadComplete` flushes them.
            self.flushPending = true
            self.methodDescriptorPromise.fail(
              RPCError(
                code: .unavailable,
                message: "RPC was rejected."
              )
            )
            let response = HTTP2Frame.FramePayload.headers(.init(headers: trailers, endStream: true))
            context.write(self.wrapOutboundOut(response), promise: nil)

          case .receivedStatusAndMetadata:
            throw RPCError(
              code: .internalError,
              message: "Server cannot get receivedStatusAndMetadata."
            )

          case .doNothing:
            throw RPCError(code: .internalError, message: "Server cannot receive doNothing.")
          }
        } catch {
          context.fireErrorCaught(error)
        }

      case .ping, .goAway, .priority, .rstStream, .settings, .pushPromise, .windowUpdate,
        .alternativeService, .origin:
        ()
      }
    }

    /// Ends the current read loop and flushes anything written to the context while it was running.
    func channelReadComplete(context: ChannelHandlerContext) {
      self.isReading = false
      if self.flushPending {
        self.flushPending = false
        context.flush()
      }
      context.fireChannelReadComplete()
    }

    /// Tears down the stream state and completes every promise this handler still holds.
    ///
    /// `methodDescriptorPromise` is failed; this has no effect if it was already completed. The
    /// write promise of any trailers still waiting in `pendingTrailers` is also failed. Those
    /// trailers are only written once the state machine reports no more messages, so without this
    /// they could be dropped with their promise never completed.
    func handlerRemoved(context: ChannelHandlerContext) {
      self.stateMachine.tearDown()
      self.methodDescriptorPromise.fail(
        RPCError(
          code: .unavailable,
          message: "RPC stream was closed before we got any Metadata."
        )
      )

      if let pendingTrailers = self.pendingTrailers {
        self.pendingTrailers = nil
        pendingTrailers.promise?.fail(
          RPCError(
            code: .unavailable,
            message: "RPC stream was closed before the status and trailers could be written."
          )
        )
      }
    }
  }
  // MARK: - ChannelOutboundHandler
  @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  extension GRPCServerStreamHandler {
    /// Accepts an outbound response part for this stream.
    ///
    /// - `.metadata` is encoded by the state machine and written right away as a HEADERS frame.
    /// - `.message` is handed to the state machine together with its promise. The resulting DATA
    ///   frames are written when `flush(context:)` drains the state machine.
    /// - `.status` is encoded as end-of-stream trailers and parked in `pendingTrailers`, so it is
    ///   only written after every queued message.
    ///
    /// If the state machine throws, the write's promise is failed and the error is fired through
    /// the pipeline.
    func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
      let frame = self.unwrapOutboundIn(data)
      switch frame {
      case .metadata(let metadata):
        do {
          let headers = try self.stateMachine.send(metadata: metadata)
          context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: promise)
          // Only request a flush once a frame has actually been written.
          self.flushPending = true
        } catch {
          promise?.fail(error)
          context.fireErrorCaught(error)
        }

      case .message(let message):
        do {
          try self.stateMachine.send(message: message, promise: promise)
        } catch {
          promise?.fail(error)
          context.fireErrorCaught(error)
        }

      case .status(let status, let metadata):
        do {
          let headers = try self.stateMachine.send(status: status, metadata: metadata)
          let response = HTTP2Frame.FramePayload.headers(.init(headers: headers, endStream: true))
          self.pendingTrailers = (response, promise)
        } catch {
          promise?.fail(error)
          context.fireErrorCaught(error)
        }
      }
    }

    /// Drains the state machine's outbound frames into the channel and flushes them.
    ///
    /// Queued DATA frames are written first. Once the state machine reports no more messages, any
    /// pending status trailers follow.
    ///
    /// While a read loop is in progress, the frames are still written but `context.flush()` is not
    /// called. `channelReadComplete` performs that flush because `flushPending` is set. Skipping the
    /// drain entirely mid-read would leave the messages (and final trailers) stuck in the state
    /// machine until some later, unrelated flush.
    ///
    /// Errors from the state machine are fired through the pipeline.
    func flush(context: ChannelHandlerContext) {
      do {
        loop: while true {
          switch try self.stateMachine.nextOutboundFrame() {
          case .sendFrame(let byteBuffer, let promise):
            self.flushPending = true
            context.write(
              self.wrapOutboundOut(.data(.init(data: .byteBuffer(byteBuffer)))),
              promise: promise
            )

          case .noMoreMessages:
            // Nothing else is queued, so the trailers can no longer overtake a message.
            if let pendingTrailers = self.pendingTrailers {
              self.flushPending = true
              self.pendingTrailers = nil
              context.write(
                self.wrapOutboundOut(pendingTrailers.trailers),
                promise: pendingTrailers.promise
              )
            }
            break loop

          case .awaitMoreMessages:
            break loop
          }
        }

        // We don't want to flush yet if we're still in a read loop: channelReadComplete flushes
        // everything written above because flushPending stays set.
        if self.flushPending && !self.isReading {
          self.flushPending = false
          context.flush()
        }
      } catch {
        context.fireErrorCaught(error)
      }
    }
  }