GRPCClientStreamHandler.swift 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  /*
   * Copyright 2024, gRPC Authors All rights reserved.
   *
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   * http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   * See the License for the specific language governing permissions and
   * limitations under the License.
   */
  16. internal import GRPCCore
  17. internal import NIOCore
  18. internal import NIOHTTP2
  /// Channel handler that drives one client-side gRPC RPC over an HTTP/2 stream channel.
  ///
  /// Inbound, it turns HTTP/2 frame payloads into gRPC response parts. Outbound, it turns
  /// gRPC request parts into HTTP/2 frame payloads. All protocol state lives in a
  /// `GRPCStreamStateMachine` configured for the client side.
  @available(gRPCSwiftNIOTransport 2.0, *)
  final class GRPCClientStreamHandler: ChannelDuplexHandler {
    // Inbound direction: HTTP/2 payloads in, gRPC response parts out.
    typealias InboundIn = HTTP2Frame.FramePayload
    typealias InboundOut = RPCResponsePart<GRPCNIOTransportBytes>

    // Outbound direction: gRPC request parts in, HTTP/2 payloads out.
    typealias OutboundIn = RPCRequestPart<GRPCNIOTransportBytes>
    typealias OutboundOut = HTTP2Frame.FramePayload

    /// Protocol state for this RPC; every inbound frame and outbound part goes through it.
    private var stateMachine: GRPCStreamStateMachine

    // Flags coordinating flushing with the read loop and the end of the request stream.
    /// `true` between `channelRead` and `channelReadComplete`.
    private var isReading = false
    /// `true` when something has been written (or a flush was requested) but not yet flushed.
    private var flushPending = false
    /// `true` once the empty end-of-stream DATA frame has been written.
    private var requestFinished = false

    /// Creates a handler for a single client RPC.
    ///
    /// - Parameters:
    ///   - methodDescriptor: Forwarded to the state machine's client configuration.
    ///   - scheme: Forwarded to the state machine's client configuration.
    ///   - authority: Forwarded to the state machine's client configuration.
    ///   - outboundEncoding: Forwarded to the state machine's client configuration.
    ///   - acceptedEncodings: Forwarded to the state machine's client configuration.
    ///   - maxPayloadSize: Forwarded to the state machine.
    ///   - skipStateMachineAssertions: Forwarded to the state machine as `skipAssertions`.
    init(
      methodDescriptor: MethodDescriptor,
      scheme: Scheme,
      authority: String?,
      outboundEncoding: CompressionAlgorithm,
      acceptedEncodings: CompressionAlgorithmSet,
      maxPayloadSize: Int,
      skipStateMachineAssertions: Bool = false
    ) {
      self.stateMachine = GRPCStreamStateMachine(
        configuration: .client(
          .init(
            methodDescriptor: methodDescriptor,
            scheme: scheme,
            authority: authority,
            outboundEncoding: outboundEncoding,
            acceptedEncodings: acceptedEncodings
          )
        ),
        maxPayloadSize: maxPayloadSize,
        skipAssertions: skipStateMachineAssertions
      )
    }
  }
  // MARK: - ChannelInboundHandler
  @available(gRPCSwiftNIOTransport 2.0, *)
  extension GRPCClientStreamHandler {
    /// Handles one inbound HTTP/2 frame payload for this stream.
    ///
    /// How each frame type is handled:
    /// - DATA and HEADERS frames are fed to the state machine, and the resulting response
    ///   parts are forwarded down the pipeline.
    /// - RST_STREAM is treated as an unexpected close.
    /// - Every other frame type is ignored.
    ///
    /// Sets `isReading` so that `flush(context:)` defers flushing until `channelReadComplete`.
    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
      self.isReading = true

      switch self.unwrapInboundIn(data) {
      case .data(let dataFrame):
        self.handleInboundData(dataFrame, context: context)

      case .headers(let headersFrame):
        self.handleInboundHeaders(headersFrame, context: context)

      case .rstStream(let errorCode):
        self.handleUnexpectedInboundClose(context: context, reason: .streamReset(errorCode))

      case .ping, .goAway, .priority, .settings, .pushPromise, .windowUpdate,
        .alternativeService, .origin:
        // Not relevant to this handler; ignored.
        ()
      }
    }

    /// Passes a DATA frame's bytes to the state machine and acts on the returned action.
    ///
    /// Only `.byteBuffer` payloads are expected; a file region is a programmer error.
    /// State-machine errors are wrapped in `RPCError` and fired down the pipeline.
    private func handleInboundData(
      _ dataFrame: HTTP2Frame.FramePayload.Data,
      context: ChannelHandlerContext
    ) {
      guard case .byteBuffer(let buffer) = dataFrame.data else {
        preconditionFailure("Unexpected IOData.fileRegion")
      }

      do {
        let action = try self.stateMachine.receive(buffer: buffer, endStream: dataFrame.endStream)
        switch action {
        case .readInbound:
          self.forwardBufferedMessages(context: context)

        case .endRPCAndForwardErrorStatus_clientOnly(let status):
          // Hand the error status to the application, then close the stream.
          context.fireChannelRead(self.wrapInboundOut(.status(status, [:])))
          context.close(promise: nil)

        case .forwardErrorAndClose_serverOnly:
          assertionFailure("Unexpected client action")

        case .doNothing:
          ()
        }
      } catch let invalidState {
        context.fireErrorCaught(RPCError(invalidState))
      }
    }

    /// Pulls messages from the state machine and forwards each one until it stops returning
    /// `.receiveMessage`.
    ///
    /// The loop ends on `.awaitMoreMessages` or `.noMoreMessages`. The latter could only
    /// happen if the server sent a DATA frame with EOS set but no status and trailers. In
    /// that case an error status should already have been forwarded, so nothing more is
    /// done here.
    private func forwardBufferedMessages(context: ChannelHandlerContext) {
      while case .receiveMessage(let message) = self.stateMachine.nextInboundMessage() {
        let bytes = GRPCNIOTransportBytes(message)
        context.fireChannelRead(self.wrapInboundOut(.message(bytes)))
      }
    }

    /// Passes a HEADERS frame to the state machine and forwards the resulting metadata or status.
    ///
    /// When a status arrives, `ChannelEvent.inputClosed` is also fired. State-machine errors
    /// are wrapped in `RPCError` and fired down the pipeline.
    private func handleInboundHeaders(
      _ headersFrame: HTTP2Frame.FramePayload.Headers,
      context: ChannelHandlerContext
    ) {
      do {
        let action = try self.stateMachine.receive(
          headers: headersFrame.headers,
          endStream: headersFrame.endStream
        )
        switch action {
        case .receivedStatusAndMetadata_clientOnly(let status, let metadata):
          context.fireChannelRead(self.wrapInboundOut(.status(status, metadata)))
          context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)

        case .receivedMetadata(let metadata, _):
          context.fireChannelRead(self.wrapInboundOut(.metadata(metadata)))

        case .rejectRPC_serverOnly, .protocolViolation_serverOnly:
          assertionFailure("Unexpected action '\(action)'")

        case .doNothing:
          ()
        }
      } catch let invalidState {
        context.fireErrorCaught(RPCError(invalidState))
      }
    }

    /// Ends the read loop and performs any flush that was deferred while reading.
    func channelReadComplete(context: ChannelHandlerContext) {
      self.isReading = false
      defer { context.fireChannelReadComplete() }

      guard self.flushPending else { return }
      self.flushPending = false
      self.flush(context: context)
    }

    /// Tears down the state machine when the handler leaves the pipeline.
    func handlerRemoved(context: ChannelHandlerContext) {
      self.stateMachine.tearDown()
    }

    /// Treats channel inactivity as an unexpected close, then propagates the event.
    func channelInactive(context: ChannelHandlerContext) {
      self.handleUnexpectedInboundClose(context: context, reason: .channelInactive)
      context.fireChannelInactive()
    }

    /// Treats a caught error as an unexpected close. The error itself is not re-fired.
    func errorCaught(context: ChannelHandlerContext, error: any Error) {
      self.handleUnexpectedInboundClose(context: context, reason: .errorThrown(error))
    }

    /// Tells the state machine the stream closed unexpectedly and forwards any status it
    /// produces.
    private func handleUnexpectedInboundClose(
      context: ChannelHandlerContext,
      reason: GRPCStreamStateMachine.UnexpectedInboundCloseReason
    ) {
      let action = self.stateMachine.unexpectedClose(reason: reason)
      switch action {
      case .doNothing:
        ()

      case .forwardStatus_clientOnly(let status):
        context.fireChannelRead(self.wrapInboundOut(.status(status, [:])))

      case .fireError_serverOnly:
        assertionFailure("`fireError` should only happen on the server side, never on the client.")
      }
    }
  }
  // MARK: - ChannelOutboundHandler
  @available(gRPCSwiftNIOTransport 2.0, *)
  extension GRPCClientStreamHandler {
    /// Converts an outbound request part into HTTP/2 frames.
    ///
    /// - Metadata is turned into a HEADERS frame by the state machine and written immediately.
    /// - Messages are handed to the state machine together with `promise`. The DATA frames
    ///   that carry them are written later, when `_flush` drains `nextOutboundFrame()`.
    ///
    /// If the state machine rejects the part, `promise` is failed and the error is fired
    /// down the pipeline.
    func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
      switch self.unwrapOutboundIn(data) {
      case .metadata(let metadata):
        do {
          let headers = try self.stateMachine.send(metadata: metadata)
          // Only record a pending flush once the state machine has accepted the metadata
          // and there is actually a frame to flush.
          self.flushPending = true
          context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: promise)
        } catch let invalidState {
          self.reportFailure(RPCError(invalidState), promise: promise, context: context)
        }

      case .message(let message):
        do {
          try self.stateMachine.send(message: message.buffer, promise: promise)
        } catch let invalidState {
          self.reportFailure(RPCError(invalidState), promise: promise, context: context)
        }
      }
    }

    /// Closes the requested side(s) of the stream.
    ///
    /// - `.input`: fires `ChannelEvent.inputClosed` and succeeds `promise`.
    /// - `.output`: closes the outbound side of the state machine and force-flushes what it
    ///   still has to send, without forwarding the close (see the comment below).
    /// - `.all`: like `.output`, then forwards the close down the pipeline. If the state
    ///   machine rejects `closeOutbound()`, `promise` is failed and the error is fired.
    ///   The close is still forwarded, so the channel is not left open after a full close
    ///   was requested.
    func close(context: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise<Void>?) {
      switch mode {
      case .input:
        context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)
        promise?.succeed()

      case .output:
        // We flush all pending messages and update the internal state machine's
        // state, but we don't close the outbound end of the channel, because
        // forwarding the close in this case would cause the HTTP2 stream handler
        // to close the whole channel (as the mode is ignored in its implementation).
        do {
          try self.stateMachine.closeOutbound()
          // Force a flush by calling _flush instead of flush
          // (otherwise, we'd skip flushing if we're in a read loop)
          self._flush(context: context)
          promise?.succeed()
        } catch let invalidState {
          self.reportFailure(RPCError(invalidState), promise: promise, context: context)
        }

      case .all:
        // Since we're closing the whole channel here, we *do* forward the close
        // down the pipeline.
        do {
          try self.stateMachine.closeOutbound()
          // Force a flush by calling _flush
          // (otherwise, we'd skip flushing if we're in a read loop)
          self._flush(context: context)
          context.close(mode: mode, promise: promise)
        } catch let invalidState {
          // The state machine refused to close the outbound side, but the caller asked
          // for the whole channel to go away. Report the failure exactly as before, and
          // still forward the close. Previously the close was dropped here, leaving the
          // stream channel open.
          self.reportFailure(RPCError(invalidState), promise: promise, context: context)
          context.close(mode: mode, promise: nil)
        }
      }
    }

    /// Flushes outbound frames, unless we are inside a read loop.
    ///
    /// During a read loop the flush is recorded in `flushPending` and performed by
    /// `channelReadComplete`, which batches writes made while reading.
    func flush(context: ChannelHandlerContext) {
      if self.isReading {
        // We don't want to flush yet if we're still in a read loop.
        self.flushPending = true
        return
      }

      self._flush(context: context)
    }

    /// Writes every frame the state machine has ready, then flushes if anything was written.
    ///
    /// Once the state machine reports `.noMoreMessages`, an empty DATA frame with END_STREAM
    /// is written, at most once per RPC (guarded by `requestFinished`). State-machine errors
    /// are wrapped in `RPCError` and fired down the pipeline.
    private func _flush(context: ChannelHandlerContext) {
      do {
        loop: while true {
          switch try self.stateMachine.nextOutboundFrame() {
          case .sendFrame(let byteBuffer, let promise):
            self.flushPending = true
            context.write(
              self.wrapOutboundOut(.data(.init(data: .byteBuffer(byteBuffer)))),
              promise: promise
            )

          case .noMoreMessages:
            if !self.requestFinished {
              self.requestFinished = true
              // Write an empty data frame with the EOS flag set, to signal the RPC
              // request is now finished.
              context.write(
                self.wrapOutboundOut(
                  HTTP2Frame.FramePayload.data(
                    .init(
                      data: .byteBuffer(.init()),
                      endStream: true
                    )
                  )
                ),
                promise: nil
              )
              self.flushPending = true
            }
            // Flush anything written in this pass (including the EOS frame above) and clear
            // the pending flag so later read-completes don't issue a redundant flush.
            self.flushIfPending(context: context)
            break loop

          case .awaitMoreMessages:
            self.flushIfPending(context: context)
            break loop

          case .closeAndFailPromise(let promise, let error):
            context.close(mode: .all, promise: nil)
            promise?.fail(error)
            break loop
          }
        }
      } catch let invalidState {
        context.fireErrorCaught(RPCError(invalidState))
      }
    }

    /// Calls `context.flush()` only if something is pending, and clears the pending flag.
    private func flushIfPending(context: ChannelHandlerContext) {
      guard self.flushPending else { return }
      self.flushPending = false
      context.flush()
    }

    /// Fails `promise` with `error` and fires `error` down the pipeline.
    private func reportFailure(
      _ error: RPCError,
      promise: EventLoopPromise<Void>?,
      context: ChannelHandlerContext
    ) {
      promise?.fail(error)
      context.fireErrorCaught(error)
    }
  }