GRPCServerStreamHandler.swift 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. /*
  2. * Copyright 2024, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import GRPCCore
  17. import NIOCore
  18. import NIOHTTP2
  19. @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  20. final class GRPCServerStreamHandler: ChannelDuplexHandler {
  21. typealias InboundIn = HTTP2Frame.FramePayload
  22. typealias InboundOut = RPCRequestPart
  23. typealias OutboundIn = RPCResponsePart
  24. typealias OutboundOut = HTTP2Frame.FramePayload
  25. private var stateMachine: GRPCStreamStateMachine
  26. private var isReading = false
  27. private var flushPending = false
  28. // We buffer the final status + trailers to avoid reordering issues (i.e.,
  29. // if there are messages still not written into the channel because flush has
  30. // not been called, but the server sends back trailers).
  31. private var pendingTrailers: HTTP2Frame.FramePayload?
  32. init(
  33. scheme: Scheme,
  34. acceptedEncodings: [CompressionAlgorithm],
  35. maximumPayloadSize: Int,
  36. skipStateMachineAssertions: Bool = false
  37. ) {
  38. self.stateMachine = .init(
  39. configuration: .server(.init(scheme: scheme, acceptedEncodings: acceptedEncodings)),
  40. maximumPayloadSize: maximumPayloadSize,
  41. skipAssertions: skipStateMachineAssertions
  42. )
  43. }
  44. }
  45. // - MARK: ChannelInboundHandler
  46. @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  47. extension GRPCServerStreamHandler {
  48. func channelRead(context: ChannelHandlerContext, data: NIOAny) {
  49. self.isReading = true
  50. let frame = self.unwrapInboundIn(data)
  51. switch frame {
  52. case .data(let frameData):
  53. let endStream = frameData.endStream
  54. switch frameData.data {
  55. case .byteBuffer(let buffer):
  56. do {
  57. try self.stateMachine.receive(buffer: buffer, endStream: endStream)
  58. loop: while true {
  59. switch self.stateMachine.nextInboundMessage() {
  60. case .receiveMessage(let message):
  61. context.fireChannelRead(self.wrapInboundOut(.message(message)))
  62. case .awaitMoreMessages:
  63. break loop
  64. case .noMoreMessages:
  65. context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)
  66. break loop
  67. }
  68. }
  69. } catch {
  70. context.fireErrorCaught(error)
  71. }
  72. case .fileRegion:
  73. preconditionFailure("Unexpected IOData.fileRegion")
  74. }
  75. case .headers(let headers):
  76. do {
  77. let action = try self.stateMachine.receive(
  78. headers: headers.headers,
  79. endStream: headers.endStream
  80. )
  81. switch action {
  82. case .receivedMetadata(let metadata):
  83. context.fireChannelRead(self.wrapInboundOut(.metadata(metadata)))
  84. case .rejectRPC(let trailers):
  85. self.flushPending = true
  86. let response = HTTP2Frame.FramePayload.headers(.init(headers: trailers, endStream: true))
  87. context.write(self.wrapOutboundOut(response), promise: nil)
  88. case .receivedStatusAndMetadata:
  89. throw RPCError(
  90. code: .internalError,
  91. message: "Server cannot get receivedStatusAndMetadata."
  92. )
  93. case .doNothing:
  94. throw RPCError(code: .internalError, message: "Server cannot receive doNothing.")
  95. }
  96. } catch {
  97. context.fireErrorCaught(error)
  98. }
  99. case .ping, .goAway, .priority, .rstStream, .settings, .pushPromise, .windowUpdate,
  100. .alternativeService, .origin:
  101. ()
  102. }
  103. }
  104. func channelReadComplete(context: ChannelHandlerContext) {
  105. self.isReading = false
  106. if self.flushPending {
  107. self.flushPending = false
  108. context.flush()
  109. }
  110. context.fireChannelReadComplete()
  111. }
  112. func handlerRemoved(context: ChannelHandlerContext) {
  113. self.stateMachine.tearDown()
  114. }
  115. }
  116. // - MARK: ChannelOutboundHandler
  117. @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  118. extension GRPCServerStreamHandler {
  119. func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
  120. let frame = self.unwrapOutboundIn(data)
  121. switch frame {
  122. case .metadata(let metadata):
  123. do {
  124. self.flushPending = true
  125. let headers = try self.stateMachine.send(metadata: metadata)
  126. context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: nil)
  127. // TODO: move the promise handling into the state machine
  128. promise?.succeed()
  129. } catch {
  130. context.fireErrorCaught(error)
  131. // TODO: move the promise handling into the state machine
  132. promise?.fail(error)
  133. }
  134. case .message(let message):
  135. do {
  136. try self.stateMachine.send(message: message, endStream: false)
  137. // TODO: move the promise handling into the state machine
  138. promise?.succeed()
  139. } catch {
  140. context.fireErrorCaught(error)
  141. // TODO: move the promise handling into the state machine
  142. promise?.fail(error)
  143. }
  144. case .status(let status, let metadata):
  145. do {
  146. let headers = try self.stateMachine.send(status: status, metadata: metadata)
  147. let response = HTTP2Frame.FramePayload.headers(.init(headers: headers, endStream: true))
  148. self.pendingTrailers = response
  149. // TODO: move the promise handling into the state machine
  150. promise?.succeed()
  151. } catch {
  152. context.fireErrorCaught(error)
  153. // TODO: move the promise handling into the state machine
  154. promise?.fail(error)
  155. }
  156. }
  157. }
  158. func flush(context: ChannelHandlerContext) {
  159. if self.isReading {
  160. // We don't want to flush yet if we're still in a read loop.
  161. return
  162. }
  163. do {
  164. loop: while true {
  165. switch try self.stateMachine.nextOutboundMessage() {
  166. case .sendMessage(let byteBuffer):
  167. self.flushPending = true
  168. context.write(
  169. self.wrapOutboundOut(.data(.init(data: .byteBuffer(byteBuffer)))),
  170. promise: nil
  171. )
  172. case .noMoreMessages:
  173. if let pendingTrailers = self.pendingTrailers {
  174. self.flushPending = true
  175. self.pendingTrailers = nil
  176. context.write(self.wrapOutboundOut(pendingTrailers), promise: nil)
  177. }
  178. break loop
  179. case .awaitMoreMessages:
  180. break loop
  181. }
  182. }
  183. if self.flushPending {
  184. self.flushPending = false
  185. context.flush()
  186. }
  187. } catch {
  188. context.fireErrorCaught(error)
  189. }
  190. }
  191. }