GRPCClientStreamHandler.swift 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. /*
  2. * Copyright 2024, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. internal import GRPCCore
  17. internal import NIOCore
  18. internal import NIOHTTP2
  @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  final class GRPCClientStreamHandler: ChannelDuplexHandler {
    // Reading side: HTTP/2 frame payloads come in and are surfaced to the
    // next handler as gRPC response parts.
    typealias InboundIn = HTTP2Frame.FramePayload
    typealias InboundOut = RPCResponsePart

    // Writing side: gRPC request parts go in and leave as HTTP/2 frame payloads.
    typealias OutboundIn = RPCRequestPart
    typealias OutboundOut = HTTP2Frame.FramePayload

    /// State machine configured for the client role. Every inbound frame and
    /// outbound request part handled by this type is routed through it.
    private var stateMachine: GRPCStreamStateMachine

    /// `true` while inside a `channelRead` burst (reset in `channelReadComplete`);
    /// non-forced flushes are deferred while it is set.
    private var isReading: Bool = false

    /// Records that something was written (or a flush was deferred) and a
    /// `context.flush()` is still owed.
    private var flushPending: Bool = false

    /// Creates a handler for a single client-side RPC stream.
    ///
    /// - Parameters:
    ///   - methodDescriptor: The RPC method; passed to the client configuration.
    ///   - scheme: The URI scheme; passed to the client configuration.
    ///   - outboundEncoding: Compression algorithm for outbound messages.
    ///   - acceptedEncodings: Compression algorithms accepted for inbound messages.
    ///   - maximumPayloadSize: Passed to the state machine as its maximum payload size.
    ///   - skipStateMachineAssertions: Passed to the state machine as `skipAssertions`.
    init(
      methodDescriptor: MethodDescriptor,
      scheme: Scheme,
      outboundEncoding: CompressionAlgorithm,
      acceptedEncodings: CompressionAlgorithmSet,
      maximumPayloadSize: Int,
      skipStateMachineAssertions: Bool = false
    ) {
      self.stateMachine = .init(
        configuration: .client(
          .init(
            methodDescriptor: methodDescriptor,
            scheme: scheme,
            outboundEncoding: outboundEncoding,
            acceptedEncodings: acceptedEncodings
          )
        ),
        maximumPayloadSize: maximumPayloadSize,
        skipAssertions: skipStateMachineAssertions
      )
    }
  }
  // MARK: - ChannelInboundHandler
  @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  extension GRPCClientStreamHandler {
    /// Routes one inbound HTTP/2 frame payload to the matching handler.
    ///
    /// - DATA and HEADERS frames are fed through the state machine.
    /// - RST_STREAM is treated as an unexpected inbound close.
    /// - Every other frame kind is ignored.
    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
      // Flushes requested from here until `channelReadComplete` are deferred.
      self.isReading = true

      switch self.unwrapInboundIn(data) {
      case .data(let dataPayload):
        self.handleDataPayload(dataPayload, context: context)
      case .headers(let headersPayload):
        self.handleHeadersPayload(headersPayload, context: context)
      case .rstStream:
        self.handleUnexpectedInboundClose(context: context, reason: .streamReset)
      case .ping, .goAway, .priority, .settings, .pushPromise, .windowUpdate,
        .alternativeService, .origin:
        break
      }
    }

    /// Feeds a DATA frame's bytes to the state machine and acts on its verdict.
    ///
    /// State-machine failures are wrapped in `RPCError` and fired down the pipeline.
    private func handleDataPayload(
      _ payload: HTTP2Frame.FramePayload.Data,
      context: ChannelHandlerContext
    ) {
      guard case .byteBuffer(let bytes) = payload.data else {
        preconditionFailure("Unexpected IOData.fileRegion")
      }

      do {
        let action = try self.stateMachine.receive(buffer: bytes, endStream: payload.endStream)
        switch action {
        case .readInbound:
          self.forwardBufferedMessages(context: context)
        case .endRPCAndForwardErrorStatus_clientOnly(let status):
          // Surface the error status (with empty metadata), then tear the stream down.
          context.fireChannelRead(self.wrapInboundOut(.status(status, [:])))
          context.close(promise: nil)
        case .forwardErrorAndClose_serverOnly:
          assertionFailure("Unexpected client action")
        case .doNothing:
          break
        }
      } catch let failure {
        context.fireErrorCaught(RPCError(failure))
      }
    }

    /// Drains every message the state machine has decoded so far and forwards
    /// each one as a `.message` response part.
    private func forwardBufferedMessages(context: ChannelHandlerContext) {
      while true {
        switch self.stateMachine.nextInboundMessage() {
        case .receiveMessage(let message):
          context.fireChannelRead(self.wrapInboundOut(.message(message)))
        case .awaitMoreMessages:
          return
        case .noMoreMessages:
          // Only reachable if the server sent a DATA frame with EOS but no
          // status/trailers. An error status should already have been forwarded
          // in that case, so there is nothing left to do.
          return
        }
      }
    }

    /// Feeds a HEADERS frame to the state machine.
    ///
    /// - Initial metadata is forwarded as a `.metadata` part.
    /// - A status with trailers is forwarded as a `.status` part, followed by an
    ///   `inputClosed` event.
    /// - State-machine failures are wrapped in `RPCError` and fired down the pipeline.
    private func handleHeadersPayload(
      _ payload: HTTP2Frame.FramePayload.Headers,
      context: ChannelHandlerContext
    ) {
      do {
        let action = try self.stateMachine.receive(
          headers: payload.headers,
          endStream: payload.endStream
        )
        switch action {
        case .receivedMetadata(let metadata, _):
          context.fireChannelRead(self.wrapInboundOut(.metadata(metadata)))
        case .receivedStatusAndMetadata_clientOnly(let status, let metadata):
          context.fireChannelRead(self.wrapInboundOut(.status(status, metadata)))
          context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)
        case .rejectRPC_serverOnly, .protocolViolation_serverOnly:
          assertionFailure("Unexpected action '\(action)'")
        case .doNothing:
          break
        }
      } catch let failure {
        context.fireErrorCaught(RPCError(failure))
      }
    }

    /// Ends the read burst, performs any flush deferred during it, and forwards
    /// the event.
    func channelReadComplete(context: ChannelHandlerContext) {
      self.isReading = false

      if self.flushPending {
        self.flushPending = false
        self.flush(context: context)
      }

      context.fireChannelReadComplete()
    }

    /// Lets the state machine release its resources when the handler leaves the pipeline.
    func handlerRemoved(context: ChannelHandlerContext) {
      self.stateMachine.tearDown()
    }

    /// Reports the loss of the channel to the state machine, then forwards the event.
    func channelInactive(context: ChannelHandlerContext) {
      self.handleUnexpectedInboundClose(context: context, reason: .channelInactive)
      context.fireChannelInactive()
    }

    /// Converts a caught error into an unexpected inbound close.
    ///
    /// The error itself is not re-fired down the pipeline; whatever the state
    /// machine decides (for example, forwarding a status) replaces it.
    func errorCaught(context: ChannelHandlerContext, error: any Error) {
      self.handleUnexpectedInboundClose(context: context, reason: .errorThrown(error))
    }

    /// Informs the state machine that the inbound side ended abnormally.
    ///
    /// If the state machine produces a status, it is forwarded with empty metadata.
    private func handleUnexpectedInboundClose(
      context: ChannelHandlerContext,
      reason: GRPCStreamStateMachine.UnexpectedInboundCloseReason
    ) {
      switch self.stateMachine.unexpectedInboundClose(reason: reason) {
      case .forwardStatus_clientOnly(let status):
        context.fireChannelRead(self.wrapInboundOut(.status(status, [:])))
      case .fireError_serverOnly:
        assertionFailure("`fireError` should only happen on the server side, never on the client.")
      case .doNothing:
        break
      }
    }
  }
  // MARK: - ChannelOutboundHandler
  @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  extension GRPCClientStreamHandler {
    /// Translates an outbound gRPC request part into HTTP/2 traffic.
    ///
    /// - Metadata is encoded by the state machine and written immediately as a
    ///   HEADERS frame.
    /// - Messages are handed to the state machine together with their promise.
    ///   DATA frames are only written when `_flush` pulls them out via
    ///   `nextOutboundFrame()`.
    /// - State-machine failures fail `promise` and are fired as `RPCError`.
    func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
      switch self.unwrapOutboundIn(data) {
      case .metadata(let metadata):
        do {
          let headers = try self.stateMachine.send(metadata: metadata)
          // A HEADERS frame is about to be written, so the next flush must reach
          // the channel even when no DATA frames are pending. Set only once the
          // state machine has accepted the metadata (a rejected send writes
          // nothing), but before `context.write` so any flush triggered
          // re-entrantly by the write still sees it.
          self.flushPending = true
          context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: promise)
        } catch let invalidState {
          let error = RPCError(invalidState)
          promise?.fail(error)
          context.fireErrorCaught(error)
        }
      case .message(let message):
        do {
          try self.stateMachine.send(message: message, promise: promise)
        } catch let invalidState {
          let error = RPCError(invalidState)
          promise?.fail(error)
          context.fireErrorCaught(error)
        }
      }
    }

    /// Handles close requests for the stream.
    ///
    /// - `.input`: fires `inputClosed` and succeeds `promise` without forwarding.
    /// - `.output`: closes the outbound side in the state machine and force-flushes,
    ///   but does NOT forward the close.
    /// - `.all`: same as `.output`, then forwards the close down the pipeline.
    func close(context: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise<Void>?) {
      switch mode {
      case .input:
        context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)
        promise?.succeed()
      case .output:
        // We flush all pending messages and update the internal state machine's
        // state, but we don't close the outbound end of the channel, because
        // forwarding the close in this case would cause the HTTP2 stream handler
        // to close the whole channel (as the mode is ignored in its implementation).
        do {
          try self.stateMachine.closeOutbound()
          // Force a flush by calling _flush instead of flush
          // (otherwise, we'd skip flushing if we're in a read loop)
          self._flush(context: context)
          promise?.succeed()
        } catch let invalidState {
          let error = RPCError(invalidState)
          promise?.fail(error)
          context.fireErrorCaught(error)
        }
      case .all:
        // Since we're closing the whole channel here, we *do* forward the close
        // down the pipeline.
        do {
          try self.stateMachine.closeOutbound()
          // Force a flush by calling _flush
          // (otherwise, we'd skip flushing if we're in a read loop)
          self._flush(context: context)
          context.close(mode: mode, promise: promise)
        } catch let invalidState {
          let error = RPCError(invalidState)
          promise?.fail(error)
          context.fireErrorCaught(error)
        }
      }
    }

    /// Flushes outbound frames.
    ///
    /// While inside a read burst, the flush is only recorded; `channelReadComplete`
    /// performs it later.
    func flush(context: ChannelHandlerContext) {
      if self.isReading {
        // We don't want to flush yet if we're still in a read loop.
        self.flushPending = true
        return
      }
      self._flush(context: context)
    }

    /// Unconditionally drains the state machine's outbound frames into the channel.
    ///
    /// This writes every ready DATA frame, then flushes if anything is pending.
    /// When the request stream is finished, it writes an empty DATA frame with
    /// END_STREAM set. State-machine failures are fired as `RPCError`.
    private func _flush(context: ChannelHandlerContext) {
      do {
        loop: while true {
          switch try self.stateMachine.nextOutboundFrame() {
          case .sendFrame(let byteBuffer, let promise):
            self.flushPending = true
            context.write(
              self.wrapOutboundOut(.data(.init(data: .byteBuffer(byteBuffer)))),
              promise: promise
            )
          case .noMoreMessages:
            // Write an empty data frame with the EOS flag set, to signal the RPC
            // request is now finished.
            context.write(
              self.wrapOutboundOut(
                HTTP2Frame.FramePayload.data(
                  .init(
                    data: .byteBuffer(.init()),
                    endStream: true
                  )
                )
              ),
              promise: nil
            )
            // The flush below covers every write so far, including the EOS frame.
            // Clearing the flag stops `channelReadComplete` from issuing a
            // redundant flush, which would run `nextOutboundFrame()` again after
            // the stream has ended.
            self.flushPending = false
            context.flush()
            break loop
          case .awaitMoreMessages:
            if self.flushPending {
              self.flushPending = false
              context.flush()
            }
            break loop
          case .closeAndFailPromise(let promise, let error):
            // The whole channel is being closed, so no flush is owed any more.
            self.flushPending = false
            context.close(mode: .all, promise: nil)
            promise?.fail(error)
            break loop
          }
        }
      } catch let invalidState {
        context.fireErrorCaught(RPCError(invalidState))
      }
    }
  }
  263. }
  264. }
  265. }