GRPCServerStreamHandler.swift 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  /*
   * Copyright 2024, gRPC Authors All rights reserved.
   *
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   *     http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   * See the License for the specific language governing permissions and
   * limitations under the License.
   */
  16. import GRPCCore
  17. import NIOCore
  18. import NIOHTTP2
  /// A duplex channel handler that translates between HTTP/2 frame payloads and gRPC
  /// request/response parts on the server side of an RPC.
  ///
  /// The handler owns a `GRPCStreamStateMachine` created with a `.server` configuration; the
  /// inbound and outbound halves of the handler drive that state machine.
  @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  final class GRPCServerStreamHandler: ChannelDuplexHandler {
    // Inbound direction: HTTP/2 payloads in, gRPC request parts out.
    typealias InboundIn = HTTP2Frame.FramePayload
    typealias InboundOut = RPCRequestPart

    // Outbound direction: gRPC response parts in, HTTP/2 payloads out.
    typealias OutboundIn = RPCResponsePart
    typealias OutboundOut = HTTP2Frame.FramePayload

    /// State machine configured for the server role; all frame encoding/decoding goes through it.
    private var stateMachine: GRPCStreamStateMachine

    /// `true` between `channelRead` and `channelReadComplete`; outbound flushes are held back
    /// while it is set.
    private var isReading: Bool = false

    /// Whether a flush of outbound frames is currently outstanding.
    private var flushPending: Bool = false

    /// The final status + trailers, held back so they cannot overtake messages that have been
    /// handed to the state machine but not yet written into the channel (that only happens on
    /// flush). Stored together with the promise supplied to the originating `write`.
    private var pendingTrailers: (trailers: HTTP2Frame.FramePayload, promise: EventLoopPromise<Void>?)?

    /// Creates a handler for a single server-side RPC.
    ///
    /// - Parameters:
    ///   - scheme: Passed into the server configuration of the state machine.
    ///   - acceptedEncodings: Compression algorithms passed into the server configuration as the
    ///     accepted encodings.
    ///   - maximumPayloadSize: Maximum payload size forwarded to the state machine.
    ///   - skipStateMachineAssertions: Forwarded to the state machine as `skipAssertions`;
    ///     defaults to `false`.
    init(
      scheme: Scheme,
      acceptedEncodings: [CompressionAlgorithm],
      maximumPayloadSize: Int,
      skipStateMachineAssertions: Bool = false
    ) {
      self.stateMachine = GRPCStreamStateMachine(
        configuration: .server(
          .init(scheme: scheme, acceptedEncodings: acceptedEncodings)
        ),
        maximumPayloadSize: maximumPayloadSize,
        skipAssertions: skipStateMachineAssertions
      )
    }
  }
  // MARK: - ChannelInboundHandler
  @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  extension GRPCServerStreamHandler {
    /// Feeds one inbound HTTP/2 frame payload into the state machine and forwards the resulting
    /// request parts (metadata, messages, input-closed event) down the pipeline.
    ///
    /// Errors thrown by the state machine are not rethrown; they are reported through
    /// `context.fireErrorCaught(_:)`.
    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
      // Outbound flushes are held back while this is set (see `flush(context:)`), so writes made
      // in response to this read can be batched.
      self.isReading = true
      let frame = self.unwrapInboundIn(data)
      switch frame {
      case .data(let frameData):
        let endStream = frameData.endStream
        switch frameData.data {
        case .byteBuffer(let buffer):
          do {
            try self.stateMachine.receive(buffer: buffer, endStream: endStream)
            // Forward every message the state machine can now produce.
            loop: while true {
              switch self.stateMachine.nextInboundMessage() {
              case .receiveMessage(let message):
                context.fireChannelRead(self.wrapInboundOut(.message(message)))
              case .awaitMoreMessages:
                break loop
              case .noMoreMessages:
                // No further inbound messages will be produced; tell downstream input is closed.
                context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)
                break loop
              }
            }
          } catch {
            context.fireErrorCaught(error)
          }
        case .fileRegion:
          preconditionFailure("Unexpected IOData.fileRegion")
        }

      case .headers(let headers):
        do {
          let action = try self.stateMachine.receive(
            headers: headers.headers,
            endStream: headers.endStream
          )
          switch action {
          case .receivedMetadata(let metadata):
            context.fireChannelRead(self.wrapInboundOut(.metadata(metadata)))
          case .rejectRPC(let trailers):
            // Reply with end-of-stream trailers immediately; the flush is issued once the read
            // completes.
            self.flushPending = true
            let response = HTTP2Frame.FramePayload.headers(.init(headers: trailers, endStream: true))
            context.write(self.wrapOutboundOut(response), promise: nil)
          case .receivedStatusAndMetadata:
            throw RPCError(
              code: .internalError,
              message: "Server cannot get receivedStatusAndMetadata."
            )
          case .doNothing:
            throw RPCError(code: .internalError, message: "Server cannot receive doNothing.")
          }
        } catch {
          context.fireErrorCaught(error)
        }

      case .ping, .goAway, .priority, .rstStream, .settings, .pushPromise, .windowUpdate,
        .alternativeService, .origin:
        // Ignored by this handler.
        ()
      }
    }

    /// Ends the read loop and performs any flush that was owed during it.
    ///
    /// The owed flush goes through this handler's own `flush(context:)` rather than
    /// `context.flush()`: `context.flush()` is delivered to the *next* outbound handler and would
    /// skip draining the messages and trailers buffered in the state machine.
    func channelReadComplete(context: ChannelHandlerContext) {
      self.isReading = false
      if self.flushPending {
        self.flush(context: context)
        // `flush(context:)` clears `flushPending` once it has flushed. If it did not (e.g. the
        // state machine threw while draining), still flush whatever was already written.
        if self.flushPending {
          self.flushPending = false
          context.flush()
        }
      }
      context.fireChannelReadComplete()
    }

    /// Tears down the state machine and fails the promise of any trailers that were buffered by
    /// `write` but never handed to the channel, so that promise is not left uncompleted.
    func handlerRemoved(context: ChannelHandlerContext) {
      self.stateMachine.tearDown()
      if let pendingTrailers = self.pendingTrailers {
        self.pendingTrailers = nil
        pendingTrailers.promise?.fail(
          RPCError(
            code: .unavailable,
            message: "Stream handler was removed before the RPC's status could be written."
          )
        )
      }
    }
  }
  // MARK: - ChannelOutboundHandler
  @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  extension GRPCServerStreamHandler {
    /// Encodes one outbound response part.
    ///
    /// - `.metadata`: converted by the state machine into HEADERS and written immediately (not
    ///   flushed).
    /// - `.message`: handed to the state machine together with `promise`; its bytes are written
    ///   as DATA frames later, by `flush(context:)`.
    /// - `.status`: converted into an end-of-stream HEADERS frame and stored in `pendingTrailers`
    ///   so it cannot overtake buffered messages; `flush(context:)` writes it.
    ///
    /// If the state machine throws, `promise` is failed and the error is fired down the pipeline.
    func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
      let frame = self.unwrapOutboundIn(data)
      switch frame {
      case .metadata(let metadata):
        do {
          let headers = try self.stateMachine.send(metadata: metadata)
          // Only owe a flush once something is actually going to be written.
          self.flushPending = true
          context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: promise)
        } catch {
          promise?.fail(error)
          context.fireErrorCaught(error)
        }

      case .message(let message):
        do {
          try self.stateMachine.send(message: message, promise: promise)
        } catch {
          promise?.fail(error)
          context.fireErrorCaught(error)
        }

      case .status(let status, let metadata):
        do {
          let headers = try self.stateMachine.send(status: status, metadata: metadata)
          let response = HTTP2Frame.FramePayload.headers(.init(headers: headers, endStream: true))
          self.pendingTrailers = (response, promise)
        } catch {
          promise?.fail(error)
          context.fireErrorCaught(error)
        }
      }
    }

    /// Drains frames buffered in the state machine into the channel, followed by any pending
    /// trailers once no messages remain, then flushes.
    ///
    /// While a read is in progress the flush is not performed; instead it is recorded in
    /// `flushPending` so the read-complete path knows a flush is owed.
    func flush(context: ChannelHandlerContext) {
      guard !self.isReading else {
        // We don't want to flush yet if we're still in a read loop, but the request must not be
        // lost: remember that a flush is owed.
        self.flushPending = true
        return
      }

      do {
        drain: while true {
          switch try self.stateMachine.nextOutboundFrame() {
          case .sendFrame(let byteBuffer, let promise):
            self.flushPending = true
            context.write(
              self.wrapOutboundOut(.data(.init(data: .byteBuffer(byteBuffer)))),
              promise: promise
            )
          case .noMoreMessages:
            // Every buffered message has been written, so the trailers can go out without being
            // reordered ahead of them.
            if let pendingTrailers = self.pendingTrailers {
              self.flushPending = true
              self.pendingTrailers = nil
              context.write(
                self.wrapOutboundOut(pendingTrailers.trailers),
                promise: pendingTrailers.promise
              )
            }
            break drain
          case .awaitMoreMessages:
            break drain
          }
        }
      } catch {
        context.fireErrorCaught(error)
      }

      // Flush whatever made it into the channel, even if the state machine threw part-way
      // through draining.
      if self.flushPending {
        self.flushPending = false
        context.flush()
      }
    }
  }