GRPCServerStreamHandler.swift 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. /*
  2. * Copyright 2024, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import GRPCCore
  17. import NIOCore
  18. import NIOHTTP2
  @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  final class GRPCServerStreamHandler: ChannelDuplexHandler {
    typealias InboundIn = HTTP2Frame.FramePayload
    typealias InboundOut = RPCRequestPart

    typealias OutboundIn = RPCResponsePart
    typealias OutboundOut = HTTP2Frame.FramePayload

    /// Tracks the gRPC stream's state. Every inbound frame and outbound response part handled by
    /// this class is routed through it.
    private var stateMachine: GRPCStreamStateMachine

    /// Set by `channelRead` and cleared by `channelReadComplete`. While set, `flush(context:)` is
    /// deferred so that writes made during the read can be batched.
    private var isReading = false

    /// Whether frames have been written to `context` without a subsequent `context.flush()`.
    private var flushPending = false

    /// Whether `flush(context:)` was called (and skipped) while `isReading` was set. The skipped
    /// flush is replayed from `channelReadComplete`.
    private var flushRequestedWhileReading = false

    // We buffer the final status + trailers to avoid reordering issues (i.e.,
    // if there are messages still not written into the channel because flush has
    // not been called, but the server sends back trailers).
    private var pendingTrailers:
      (trailers: HTTP2Frame.FramePayload, promise: EventLoopPromise<Void>?)?

    /// Creates a handler driven by a server-configured stream state machine.
    ///
    /// - Parameters:
    ///   - scheme: Passed to the state machine's server configuration.
    ///   - acceptedEncodings: Compression algorithms passed to the state machine's server
    ///     configuration.
    ///   - maximumPayloadSize: Passed through to the state machine.
    ///   - skipStateMachineAssertions: Forwarded as the state machine's `skipAssertions` flag.
    init(
      scheme: Scheme,
      acceptedEncodings: [CompressionAlgorithm],
      maximumPayloadSize: Int,
      skipStateMachineAssertions: Bool = false
    ) {
      self.stateMachine = .init(
        configuration: .server(.init(scheme: scheme, acceptedEncodings: acceptedEncodings)),
        maximumPayloadSize: maximumPayloadSize,
        skipAssertions: skipStateMachineAssertions
      )
    }
  }

  // - MARK: ChannelInboundHandler

  @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  extension GRPCServerStreamHandler {
    /// Feeds an inbound HTTP/2 frame payload to the state machine and forwards the result.
    ///
    /// - DATA frames: every complete message the state machine yields is forwarded as `.message`.
    ///   Once it reports no more messages, `ChannelEvent.inputClosed` is fired.
    /// - HEADERS frames: forwarded as `.metadata`, or answered by writing rejection trailers
    ///   (flushed at the end of the read cycle).
    /// - All other frame types are ignored here.
    ///
    /// Errors thrown by the state machine are reported with `fireErrorCaught`.
    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
      // Flushes requested from here until `channelReadComplete` are deferred.
      self.isReading = true
      let frame = self.unwrapInboundIn(data)
      switch frame {
      case .data(let frameData):
        let endStream = frameData.endStream
        switch frameData.data {
        case .byteBuffer(let buffer):
          do {
            switch try self.stateMachine.receive(buffer: buffer, endStream: endStream) {
            case .endRPCAndForwardErrorStatus:
              preconditionFailure(
                "OnBufferReceivedAction.endRPCAndForwardErrorStatus should never be returned for the server."
              )
            case .readInbound:
              loop: while true {
                switch self.stateMachine.nextInboundMessage() {
                case .receiveMessage(let message):
                  context.fireChannelRead(self.wrapInboundOut(.message(message)))
                case .awaitMoreMessages:
                  break loop
                case .noMoreMessages:
                  context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)
                  break loop
                }
              }
            }
          } catch {
            context.fireErrorCaught(error)
          }
        case .fileRegion:
          preconditionFailure("Unexpected IOData.fileRegion")
        }
      case .headers(let headers):
        do {
          let action = try self.stateMachine.receive(
            headers: headers.headers,
            endStream: headers.endStream
          )
          switch action {
          case .receivedMetadata(let metadata):
            context.fireChannelRead(self.wrapInboundOut(.metadata(metadata)))
          case .rejectRPC(let trailers):
            // Written now, flushed in `channelReadComplete`.
            self.flushPending = true
            let response = HTTP2Frame.FramePayload.headers(.init(headers: trailers, endStream: true))
            context.write(self.wrapOutboundOut(response), promise: nil)
          case .receivedStatusAndMetadata:
            throw RPCError(
              code: .internalError,
              message: "Server cannot get receivedStatusAndMetadata."
            )
          case .doNothing:
            throw RPCError(code: .internalError, message: "Server cannot receive doNothing.")
          }
        } catch {
          context.fireErrorCaught(error)
        }
      case .ping, .goAway, .priority, .rstStream, .settings, .pushPromise, .windowUpdate,
        .alternativeService, .origin:
        ()
      }
    }

    /// Ends the read cycle and performs whatever flushing was deferred during it.
    func channelReadComplete(context: ChannelHandlerContext) {
      self.isReading = false
      if self.flushRequestedWhileReading {
        self.flushRequestedWhileReading = false
        // Replay the skipped flush through our own `flush(context:)`, not `context.flush()`.
        // Messages and trailers written during the read are still buffered in the state machine /
        // `pendingTrailers`, and only `flush(context:)` turns them into frames. It also flushes
        // anything tracked by `flushPending`.
        self.flush(context: context)
      } else if self.flushPending {
        // Frames were written directly during the read (e.g. metadata or rejection trailers).
        self.flushPending = false
        context.flush()
      }
      context.fireChannelReadComplete()
    }

    /// Tears down the state machine and fails any status trailers that were never written.
    func handlerRemoved(context: ChannelHandlerContext) {
      self.stateMachine.tearDown()
      // Trailers buffered by `write(.status)` can no longer be written. Complete their promise
      // so whoever is waiting on the status write is notified rather than left hanging.
      if let pendingTrailers = self.pendingTrailers {
        self.pendingTrailers = nil
        pendingTrailers.promise?.fail(
          RPCError(
            code: .unavailable,
            message: "Stream handler was removed before the status could be written."
          )
        )
      }
    }
  }

  // - MARK: ChannelOutboundHandler

  @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  extension GRPCServerStreamHandler {
    /// Accepts an outbound response part.
    ///
    /// - Metadata is encoded by the state machine and written as a HEADERS frame immediately.
    /// - Messages are handed to the state machine along with their promise. They become DATA
    ///   frames when `flush(context:)` drains it.
    /// - The status is encoded as end-of-stream trailers and held in `pendingTrailers`.
    ///   `flush(context:)` writes it only after all buffered messages have been drained.
    ///
    /// On a state machine error, `promise` is failed and the error is fired inbound.
    func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
      let frame = self.unwrapOutboundIn(data)
      switch frame {
      case .metadata(let metadata):
        do {
          let headers = try self.stateMachine.send(metadata: metadata)
          // Only flag a flush once something has actually been written.
          self.flushPending = true
          context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: promise)
        } catch {
          promise?.fail(error)
          context.fireErrorCaught(error)
        }
      case .message(let message):
        do {
          try self.stateMachine.send(message: message, promise: promise)
        } catch {
          promise?.fail(error)
          context.fireErrorCaught(error)
        }
      case .status(let status, let metadata):
        do {
          let headers = try self.stateMachine.send(status: status, metadata: metadata)
          let response = HTTP2Frame.FramePayload.headers(.init(headers: headers, endStream: true))
          self.pendingTrailers = (response, promise)
        } catch {
          promise?.fail(error)
          context.fireErrorCaught(error)
        }
      }
    }

    /// Drains the state machine's outbound frames (then any pending trailers) into the channel and
    /// flushes.
    ///
    /// While a read cycle is in progress, the flush is recorded and replayed from
    /// `channelReadComplete` instead.
    func flush(context: ChannelHandlerContext) {
      if self.isReading {
        // We don't want to flush yet if we're still in a read loop. Remember the request;
        // otherwise messages/trailers buffered during the read would never be drained by
        // `channelReadComplete`.
        self.flushRequestedWhileReading = true
        return
      }

      do {
        loop: while true {
          switch try self.stateMachine.nextOutboundFrame() {
          case .sendFrame(let byteBuffer, let promise):
            self.flushPending = true
            context.write(
              self.wrapOutboundOut(.data(.init(data: .byteBuffer(byteBuffer)))),
              promise: promise
            )
          case .noMoreMessages:
            // All messages are out, so the trailers can no longer overtake any of them.
            if let pendingTrailers = self.pendingTrailers {
              self.flushPending = true
              self.pendingTrailers = nil
              context.write(
                self.wrapOutboundOut(pendingTrailers.trailers),
                promise: pendingTrailers.promise
              )
            }
            break loop
          case .awaitMoreMessages:
            break loop
          }
        }
      } catch {
        context.fireErrorCaught(error)
      }

      // Flush whatever was written, even if the state machine failed part-way through.
      if self.flushPending {
        self.flushPending = false
        context.flush()
      }
    }
  }