GRPCClientStreamHandler.swift 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
/*
 * Copyright 2024, gRPC Authors All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
  16. internal import GRPCCore
  17. internal import NIOCore
  18. internal import NIOHTTP2
  /// A client-side gRPC handler for a single HTTP/2 stream.
  ///
  /// Inbound, it converts `HTTP2Frame.FramePayload`s into `RPCResponsePart`s; outbound, it
  /// converts `RPCRequestPart`s into `HTTP2Frame.FramePayload`s. Validation and framing are
  /// delegated to a `GRPCStreamStateMachine` built with a `.client` configuration.
  final class GRPCClientStreamHandler: ChannelDuplexHandler {
    typealias InboundIn = HTTP2Frame.FramePayload
    typealias InboundOut = RPCResponsePart<GRPCNIOTransportBytes>

    typealias OutboundIn = RPCRequestPart<GRPCNIOTransportBytes>
    typealias OutboundOut = HTTP2Frame.FramePayload

    /// Protocol state for this stream; consulted for every inbound frame and outbound part.
    private var stateMachine: GRPCStreamStateMachine

    /// Set in `channelRead` and cleared in `channelReadComplete`; while it is set,
    /// `flush` defers its work instead of flushing immediately.
    private var isReading: Bool = false

    /// Records that a flush is owed: either a flush was deferred during a read, or frames
    /// were written to the context without having been flushed yet.
    private var flushPending: Bool = false

    /// Set once the empty end-of-stream DATA frame for the request has been written.
    private var requestFinished: Bool = false

    /// Creates a handler for one client RPC.
    ///
    /// - Parameters:
    ///   - methodDescriptor: The method being called; passed to the state machine's client configuration.
    ///   - scheme: Passed to the state machine's client configuration.
    ///   - authority: Passed to the state machine's client configuration; may be `nil`.
    ///   - outboundEncoding: Passed to the client configuration as `outboundEncoding`.
    ///   - acceptedEncodings: Passed to the client configuration as `acceptedEncodings`.
    ///   - maxPayloadSize: Passed to the state machine as `maxPayloadSize`.
    ///   - skipStateMachineAssertions: Passed to the state machine as `skipAssertions`.
    ///     Defaults to `false`.
    init(
      methodDescriptor: MethodDescriptor,
      scheme: Scheme,
      authority: String?,
      outboundEncoding: CompressionAlgorithm,
      acceptedEncodings: CompressionAlgorithmSet,
      maxPayloadSize: Int,
      skipStateMachineAssertions: Bool = false
    ) {
      self.stateMachine = GRPCStreamStateMachine(
        configuration: .client(
          .init(
            methodDescriptor: methodDescriptor,
            scheme: scheme,
            authority: authority,
            outboundEncoding: outboundEncoding,
            acceptedEncodings: acceptedEncodings
          )
        ),
        maxPayloadSize: maxPayloadSize,
        skipAssertions: skipStateMachineAssertions
      )
    }
  }
  // MARK: - ChannelInboundHandler

  extension GRPCClientStreamHandler {
    /// Feeds an inbound HTTP/2 frame payload into the state machine and forwards the
    /// resulting response parts (metadata, messages, status) down the pipeline.
    ///
    /// Marks the handler as reading so that flushes requested until `channelReadComplete`
    /// are deferred.
    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
      self.isReading = true

      switch self.unwrapInboundIn(data) {
      case .data(let payload):
        self.handleDataPayload(payload, context: context)

      case .headers(let payload):
        self.handleHeadersPayload(payload, context: context)

      case .rstStream:
        self.handleUnexpectedInboundClose(context: context, reason: .streamReset)

      case .ping, .goAway, .priority, .settings, .pushPromise, .windowUpdate,
        .alternativeService, .origin:
        // These frame types are ignored by this handler.
        ()
      }
    }

    /// Passes a DATA payload's bytes to the state machine and acts on its verdict.
    ///
    /// A `fileRegion` payload is a programmer error and traps.
    private func handleDataPayload(
      _ payload: HTTP2Frame.FramePayload.Data,
      context: ChannelHandlerContext
    ) {
      guard case .byteBuffer(let buffer) = payload.data else {
        preconditionFailure("Unexpected IOData.fileRegion")
      }

      do {
        switch try self.stateMachine.receive(buffer: buffer, endStream: payload.endStream) {
        case .readInbound:
          self.forwardBufferedMessages(context: context)

        case .endRPCAndForwardErrorStatus_clientOnly(let status):
          context.fireChannelRead(self.wrapInboundOut(.status(status, [:])))
          context.close(promise: nil)

        case .forwardErrorAndClose_serverOnly:
          assertionFailure("Unexpected client action")

        case .doNothing:
          ()
        }
      } catch let stateError {
        context.fireErrorCaught(RPCError(stateError))
      }
    }

    /// Forwards every complete message the state machine has buffered, stopping at the
    /// first answer that is not `.receiveMessage`.
    ///
    /// `.noMoreMessages` could only be reported if the server sent a DATA frame with EOS set
    /// but no status and trailers. In that case an error status has already been forwarded,
    /// so it is not expected here and is treated the same as `.awaitMoreMessages`.
    private func forwardBufferedMessages(context: ChannelHandlerContext) {
      while case .receiveMessage(let message) = self.stateMachine.nextInboundMessage() {
        let bytes = GRPCNIOTransportBytes(message)
        context.fireChannelRead(self.wrapInboundOut(.message(bytes)))
      }
    }

    /// Passes a HEADERS payload to the state machine. Depending on the result, it forwards
    /// initial metadata, or forwards the status plus trailing metadata and signals
    /// `ChannelEvent.inputClosed`.
    private func handleHeadersPayload(
      _ payload: HTTP2Frame.FramePayload.Headers,
      context: ChannelHandlerContext
    ) {
      do {
        let action = try self.stateMachine.receive(
          headers: payload.headers,
          endStream: payload.endStream
        )

        switch action {
        case .receivedMetadata(let metadata, _):
          context.fireChannelRead(self.wrapInboundOut(.metadata(metadata)))

        case .receivedStatusAndMetadata_clientOnly(let status, let metadata):
          context.fireChannelRead(self.wrapInboundOut(.status(status, metadata)))
          context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)

        case .rejectRPC_serverOnly, .protocolViolation_serverOnly:
          assertionFailure("Unexpected action '\(action)'")

        case .doNothing:
          ()
        }
      } catch let stateError {
        context.fireErrorCaught(RPCError(stateError))
      }
    }

    /// Ends the read burst. If a flush was deferred while reading, it is performed now,
    /// before the read-complete event is propagated.
    func channelReadComplete(context: ChannelHandlerContext) {
      defer { context.fireChannelReadComplete() }

      self.isReading = false
      guard self.flushPending else { return }
      self.flushPending = false
      self.flush(context: context)
    }

    /// Tears down the state machine when the handler leaves the pipeline.
    func handlerRemoved(context: ChannelHandlerContext) {
      self.stateMachine.tearDown()
    }

    /// Lets the state machine react to the channel going inactive (possibly forwarding a
    /// status), then propagates the inactive event.
    func channelInactive(context: ChannelHandlerContext) {
      self.handleUnexpectedInboundClose(context: context, reason: .channelInactive)
      context.fireChannelInactive()
    }

    /// Reports a pipeline error to the state machine. The error itself is not re-fired;
    /// the state machine may instead produce a status to forward.
    func errorCaught(context: ChannelHandlerContext, error: any Error) {
      self.handleUnexpectedInboundClose(context: context, reason: .errorThrown(error))
    }

    /// Informs the state machine that the inbound side ended unexpectedly and, when it
    /// produces one, forwards the resulting status with empty metadata.
    private func handleUnexpectedInboundClose(
      context: ChannelHandlerContext,
      reason: GRPCStreamStateMachine.UnexpectedInboundCloseReason
    ) {
      let action = self.stateMachine.unexpectedInboundClose(reason: reason)
      switch action {
      case .forwardStatus_clientOnly(let status):
        context.fireChannelRead(self.wrapInboundOut(.status(status, [:])))

      case .fireError_serverOnly:
        assertionFailure("`fireError` should only happen on the server side, never on the client.")

      case .doNothing:
        ()
      }
    }
  }
  // MARK: - ChannelOutboundHandler

  extension GRPCClientStreamHandler {
    /// Handles an outbound request part.
    ///
    /// Metadata is converted to HTTP/2 headers by the state machine and written (not flushed)
    /// immediately. Messages are handed to the state machine together with their promise and
    /// are only written when a later flush drains them in `_flush`. If the state machine
    /// rejects the part, the promise is failed and the error is passed to
    /// `context.fireErrorCaught`.
    func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
      switch self.unwrapOutboundIn(data) {
      case .metadata(let metadata):
        do {
          let headers = try self.stateMachine.send(metadata: metadata)
          // Only record an owed flush once the headers were accepted; a rejected write
          // must not leave a stale flag that triggers a spurious flush later.
          self.flushPending = true
          context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: promise)
        } catch let invalidState {
          let error = RPCError(invalidState)
          promise?.fail(error)
          context.fireErrorCaught(error)
        }

      case .message(let message):
        do {
          try self.stateMachine.send(message: message.buffer, promise: promise)
        } catch let invalidState {
          let error = RPCError(invalidState)
          promise?.fail(error)
          context.fireErrorCaught(error)
        }
      }
    }

    /// Closes the requested side(s) of the stream.
    ///
    /// - `.input`: fires `ChannelEvent.inputClosed` and succeeds the promise.
    /// - `.output`: closes the state machine's outbound side and force-flushes, but does not
    ///   forward the close.
    /// - `.all`: closes the outbound side, force-flushes, then forwards the close.
    func close(context: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise<Void>?) {
      switch mode {
      case .input:
        context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)
        promise?.succeed()

      case .output:
        // We flush all pending messages and update the internal state machine's
        // state, but we don't close the outbound end of the channel, because
        // forwarding the close in this case would cause the HTTP2 stream handler
        // to close the whole channel (as the mode is ignored in its implementation).
        do {
          try self.stateMachine.closeOutbound()
          // Force a flush by calling _flush instead of flush
          // (otherwise, we'd skip flushing if we're in a read loop).
          self._flush(context: context)
          promise?.succeed()
        } catch let invalidState {
          let error = RPCError(invalidState)
          promise?.fail(error)
          context.fireErrorCaught(error)
        }

      case .all:
        // Since we're closing the whole channel here, we *do* forward the close
        // down the pipeline.
        do {
          try self.stateMachine.closeOutbound()
          // Force a flush by calling _flush instead of flush
          // (otherwise, we'd skip flushing if we're in a read loop).
          self._flush(context: context)
          context.close(mode: mode, promise: promise)
        } catch let invalidState {
          let error = RPCError(invalidState)
          promise?.fail(error)
          context.fireErrorCaught(error)
        }
      }
    }

    /// Drains and flushes outbound frames, unless we are inside a read burst, in which case
    /// the flush is deferred until `channelReadComplete`.
    func flush(context: ChannelHandlerContext) {
      if self.isReading {
        // We don't want to flush yet if we're still in a read loop.
        self.flushPending = true
        return
      }

      self._flush(context: context)
    }

    /// Writes every frame the state machine has ready, then flushes the context if anything
    /// is owed.
    ///
    /// Invariant: each path that calls `context.flush()` also clears `flushPending`, so a
    /// stale flag never causes redundant flushes from `channelReadComplete`.
    private func _flush(context: ChannelHandlerContext) {
      do {
        loop: while true {
          switch try self.stateMachine.nextOutboundFrame() {
          case .sendFrame(let byteBuffer, let promise):
            self.flushPending = true
            context.write(
              self.wrapOutboundOut(.data(.init(data: .byteBuffer(byteBuffer)))),
              promise: promise
            )

          case .noMoreMessages:
            // `.noMoreMessages` may be observed again by later flushes. Only the first
            // occurrence writes the end-of-stream frame: HTTP/2 forbids sending further
            // DATA on a stream after END_STREAM.
            if !self.requestFinished {
              self.requestFinished = true
              // Write an empty data frame with the EOS flag set, to signal the RPC
              // request is now finished.
              context.write(
                self.wrapOutboundOut(
                  HTTP2Frame.FramePayload.data(
                    .init(
                      data: .byteBuffer(.init()),
                      endStream: true
                    )
                  )
                ),
                promise: nil
              )
              self.flushPending = true
            }

            if self.flushPending {
              self.flushPending = false
              context.flush()
            }
            break loop

          case .awaitMoreMessages:
            if self.flushPending {
              self.flushPending = false
              context.flush()
            }
            break loop

          case .closeAndFailPromise(let promise, let error):
            context.close(mode: .all, promise: nil)
            promise?.fail(error)
            break loop
          }
        }
      } catch let invalidState {
        context.fireErrorCaught(RPCError(invalidState))
      }
    }
  }