GRPCClientStreamHandler.swift 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. /*
  2. * Copyright 2024, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import GRPCCore
  17. import NIOCore
  18. import NIOHTTP2
  19. @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  20. final class GRPCClientStreamHandler: ChannelDuplexHandler {
  21. typealias InboundIn = HTTP2Frame.FramePayload
  22. typealias InboundOut = RPCResponsePart
  23. typealias OutboundIn = RPCRequestPart
  24. typealias OutboundOut = HTTP2Frame.FramePayload
  25. private var stateMachine: GRPCStreamStateMachine
  26. private var isReading = false
  27. private var flushPending = false
  28. init(
  29. methodDescriptor: MethodDescriptor,
  30. scheme: Scheme,
  31. outboundEncoding: CompressionAlgorithm,
  32. acceptedEncodings: [CompressionAlgorithm],
  33. maximumPayloadSize: Int,
  34. skipStateMachineAssertions: Bool = false
  35. ) {
  36. self.stateMachine = .init(
  37. configuration: .client(
  38. .init(
  39. methodDescriptor: methodDescriptor,
  40. scheme: scheme,
  41. outboundEncoding: outboundEncoding,
  42. acceptedEncodings: acceptedEncodings
  43. )
  44. ),
  45. maximumPayloadSize: maximumPayloadSize,
  46. skipAssertions: skipStateMachineAssertions
  47. )
  48. }
  49. }
// MARK: - ChannelInboundHandler
  51. @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  52. extension GRPCClientStreamHandler {
  53. func channelRead(context: ChannelHandlerContext, data: NIOAny) {
  54. self.isReading = true
  55. let frame = self.unwrapInboundIn(data)
  56. switch frame {
  57. case .data(let frameData):
  58. let endStream = frameData.endStream
  59. switch frameData.data {
  60. case .byteBuffer(let buffer):
  61. do {
  62. try self.stateMachine.receive(buffer: buffer, endStream: endStream)
  63. loop: while true {
  64. switch self.stateMachine.nextInboundMessage() {
  65. case .receiveMessage(let message):
  66. context.fireChannelRead(self.wrapInboundOut(.message(message)))
  67. case .awaitMoreMessages:
  68. break loop
  69. case .noMoreMessages:
  70. context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)
  71. break loop
  72. }
  73. }
  74. } catch {
  75. context.fireErrorCaught(error)
  76. }
  77. case .fileRegion:
  78. preconditionFailure("Unexpected IOData.fileRegion")
  79. }
  80. case .headers(let headers):
  81. do {
  82. let action = try self.stateMachine.receive(
  83. headers: headers.headers,
  84. endStream: headers.endStream
  85. )
  86. switch action {
  87. case .receivedMetadata(let metadata):
  88. context.fireChannelRead(self.wrapInboundOut(.metadata(metadata)))
  89. case .rejectRPC:
  90. throw RPCError(
  91. code: .internalError,
  92. message: "Client cannot get rejectRPC."
  93. )
  94. case .receivedStatusAndMetadata(let status, let metadata):
  95. context.fireChannelRead(self.wrapInboundOut(.status(status, metadata)))
  96. case .doNothing:
  97. ()
  98. }
  99. } catch {
  100. context.fireErrorCaught(error)
  101. }
  102. case .ping, .goAway, .priority, .rstStream, .settings, .pushPromise, .windowUpdate,
  103. .alternativeService, .origin:
  104. ()
  105. }
  106. }
  107. func channelReadComplete(context: ChannelHandlerContext) {
  108. self.isReading = false
  109. if self.flushPending {
  110. self.flushPending = false
  111. self.flush(context: context)
  112. }
  113. context.fireChannelReadComplete()
  114. }
  115. func handlerRemoved(context: ChannelHandlerContext) {
  116. self.stateMachine.tearDown()
  117. }
  118. }
// MARK: - ChannelOutboundHandler
  120. @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  121. extension GRPCClientStreamHandler {
  122. func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
  123. switch self.unwrapOutboundIn(data) {
  124. case .metadata(let metadata):
  125. do {
  126. self.flushPending = true
  127. let headers = try self.stateMachine.send(metadata: metadata)
  128. context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: promise)
  129. } catch {
  130. promise?.fail(error)
  131. context.fireErrorCaught(error)
  132. }
  133. case .message(let message):
  134. do {
  135. try self.stateMachine.send(message: message, promise: promise)
  136. } catch {
  137. promise?.fail(error)
  138. context.fireErrorCaught(error)
  139. }
  140. }
  141. }
  142. func close(context: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise<Void>?) {
  143. switch mode {
  144. case .output, .all:
  145. do {
  146. try self.stateMachine.closeOutbound()
  147. // Force a flush by calling _flush
  148. // (otherwise, we'd skip flushing if we're in a read loop)
  149. self._flush(context: context)
  150. context.close(mode: mode, promise: promise)
  151. } catch {
  152. promise?.fail(error)
  153. context.fireErrorCaught(error)
  154. }
  155. case .input:
  156. context.close(mode: .input, promise: promise)
  157. }
  158. }
  159. func flush(context: ChannelHandlerContext) {
  160. if self.isReading {
  161. // We don't want to flush yet if we're still in a read loop.
  162. self.flushPending = true
  163. return
  164. }
  165. self._flush(context: context)
  166. }
  167. private func _flush(context: ChannelHandlerContext) {
  168. do {
  169. loop: while true {
  170. switch try self.stateMachine.nextOutboundFrame() {
  171. case .sendFrame(let byteBuffer, let promise):
  172. self.flushPending = true
  173. context.write(
  174. self.wrapOutboundOut(.data(.init(data: .byteBuffer(byteBuffer)))),
  175. promise: promise
  176. )
  177. case .noMoreMessages:
  178. // Write an empty data frame with the EOS flag set, to signal the RPC
  179. // request is now finished.
  180. context.write(
  181. self.wrapOutboundOut(
  182. HTTP2Frame.FramePayload.data(
  183. .init(
  184. data: .byteBuffer(.init()),
  185. endStream: true
  186. )
  187. )
  188. ),
  189. promise: nil
  190. )
  191. context.flush()
  192. break loop
  193. case .awaitMoreMessages:
  194. if self.flushPending {
  195. self.flushPending = false
  196. context.flush()
  197. }
  198. break loop
  199. }
  200. }
  201. } catch {
  202. context.fireErrorCaught(error)
  203. }
  204. }
  205. }