GRPCServerStreamHandler.swift 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. /*
  2. * Copyright 2024, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. package import GRPCCore
  17. package import NIOCore
  18. package import NIOHTTP2
  /// Drives the server side of a single gRPC stream carried over HTTP/2.
  ///
  /// Inbound `HTTP2Frame.FramePayload`s are fed through a `GRPCStreamStateMachine` and forwarded
  /// as `RPCRequestPart`s; outbound `RPCResponsePart`s are encoded by the same state machine into
  /// `HTTP2Frame.FramePayload`s.
  package final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChannelHandler {
    package typealias InboundIn = HTTP2Frame.FramePayload
    package typealias InboundOut = RPCRequestPart

    package typealias OutboundIn = RPCResponsePart
    package typealias OutboundOut = HTTP2Frame.FramePayload

    /// Tracks the state of the gRPC stream and frames/deframes messages.
    private var stateMachine: GRPCStreamStateMachine
    /// Event loop that `setCancellationHandle(_:)` hops onto when called from elsewhere.
    private let eventLoop: any EventLoop

    /// `true` from `channelRead` until `channelReadComplete`; used to hold back flushes mid-read.
    private var isReading = false
    /// Set when frames have been written to the context but `context.flush()` has not yet been called.
    private var flushPending = false
    /// Records a cancellation that arrived before any cancellation handle was installed.
    private var isCancelled = false

    // We buffer the final status + trailers to avoid reordering issues (i.e.,
    // if there are messages still not written into the channel because flush has
    // not been called, but the server sends back trailers).
    private var pendingTrailers:
      (trailers: HTTP2Frame.FramePayload, promise: EventLoopPromise<Void>?)?

    /// Succeeded with the RPC's method descriptor once request metadata is received; failed if
    /// the RPC is rejected or the handler is removed before that happens.
    private let methodDescriptorPromise: EventLoopPromise<MethodDescriptor>

    /// Handle used to cancel the RPC. `nil` until one is supplied (via `init` or
    /// `setCancellationHandle(_:)`) and again after it has been used by `cancelRPC()`.
    private var cancellationHandle: ServerContext.RPCCancellationHandle?

    // Existential errors unconditionally allocate, avoid this per-use allocation by doing it
    // statically.
    private static let handlerRemovedBeforeDescriptorResolved: any Error = RPCError(
      code: .unavailable,
      message: "RPC stream was closed before we got any Metadata."
    )

    /// Creates a server stream handler.
    ///
    /// - Parameters:
    ///   - scheme: Scheme passed to the server configuration of the state machine.
    ///   - acceptedEncodings: Compression algorithms passed to the server configuration of the
    ///     state machine (presumably the encodings this server accepts).
    ///   - maxPayloadSize: Forwarded to the state machine; presumably the largest message payload
    ///     it will accept — TODO confirm against `GRPCStreamStateMachine`.
    ///   - methodDescriptorPromise: Completed with the method descriptor once request metadata
    ///     arrives, or failed if the RPC never gets that far.
    ///   - eventLoop: The event loop this handler's state is confined to.
    ///   - cancellationHandler: An optional cancellation handle to install up front.
    ///   - skipStateMachineAssertions: Forwarded to the state machine as `skipAssertions`.
    package init(
      scheme: Scheme,
      acceptedEncodings: CompressionAlgorithmSet,
      maxPayloadSize: Int,
      methodDescriptorPromise: EventLoopPromise<MethodDescriptor>,
      eventLoop: any EventLoop,
      cancellationHandler: ServerContext.RPCCancellationHandle? = nil,
      skipStateMachineAssertions: Bool = false
    ) {
      self.stateMachine = .init(
        configuration: .server(.init(scheme: scheme, acceptedEncodings: acceptedEncodings)),
        maxPayloadSize: maxPayloadSize,
        skipAssertions: skipStateMachineAssertions
      )
      self.methodDescriptorPromise = methodDescriptorPromise
      self.cancellationHandle = cancellationHandler
      self.eventLoop = eventLoop
    }

    /// Installs the handle used to cancel the RPC.
    ///
    /// May be called from any thread: when not already on `eventLoop`, the work is submitted to
    /// it. Must only be called once. If the RPC was already cancelled, `handle` is cancelled
    /// immediately instead of being stored.
    package func setCancellationHandle(_ handle: ServerContext.RPCCancellationHandle) {
      if self.eventLoop.inEventLoop {
        self.syncSetCancellationHandle(handle)
      } else {
        let loopBoundSelf = NIOLoopBound(self, eventLoop: self.eventLoop)
        self.eventLoop.execute {
          loopBoundSelf.value.syncSetCancellationHandle(handle)
        }
      }
    }

    /// Event-loop-confined body of `setCancellationHandle(_:)`.
    private func syncSetCancellationHandle(_ handle: ServerContext.RPCCancellationHandle) {
      assert(self.cancellationHandle == nil, "\(#function) must only be called once")

      if self.isCancelled {
        // Cancellation was requested before we had a handle: honour it now.
        handle.cancel()
      } else {
        self.cancellationHandle = handle
      }
    }

    /// Cancels the RPC through the installed handle (consuming it), or records the request so
    /// that a handle installed later is cancelled straight away.
    private func cancelRPC() {
      if let handle = self.cancellationHandle.take() {
        handle.cancel()
      } else {
        self.isCancelled = true
      }
    }
  }
  // MARK: - ChannelInboundHandler

  extension GRPCServerStreamHandler {
    /// Cancels the RPC when a `ChannelShouldQuiesceEvent` arrives. Every event, including that
    /// one, is forwarded to the next handler.
    package func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
      if event is ChannelShouldQuiesceEvent {
        self.cancelRPC()
      }
      context.fireUserInboundEventTriggered(event)
    }

    /// Dispatches an inbound HTTP/2 frame payload to the matching handler.
    ///
    /// `isReading` is raised here and lowered in `channelReadComplete(context:)`, so outbound
    /// flushes can be held back while a read is in progress.
    package func channelRead(context: ChannelHandlerContext, data: NIOAny) {
      self.isReading = true

      switch self.unwrapInboundIn(data) {
      case .data(let payload):
        self.handleDataPayload(payload, context: context)
      case .headers(let payload):
        self.handleHeadersPayload(payload, context: context)
      case .rstStream:
        self.handleUnexpectedInboundClose(context: context, reason: .streamReset)
      case .ping, .goAway, .priority, .settings, .pushPromise, .windowUpdate,
        .alternativeService, .origin:
        // These frame types are ignored by this handler.
        break
      }
    }

    /// Feeds a DATA payload into the state machine and acts on the action it returns.
    private func handleDataPayload(
      _ payload: HTTP2Frame.FramePayload.Data,
      context: ChannelHandlerContext
    ) {
      guard case .byteBuffer(let buffer) = payload.data else {
        preconditionFailure("Unexpected IOData.fileRegion")
      }

      do {
        switch try self.stateMachine.receive(buffer: buffer, endStream: payload.endStream) {
        case .readInbound:
          self.forwardInboundMessages(context: context)

        case .forwardErrorAndClose_serverOnly(let error):
          context.fireErrorCaught(error)
          context.close(mode: .all, promise: nil)

        case .doNothing:
          break

        case .endRPCAndForwardErrorStatus_clientOnly:
          preconditionFailure(
            "OnBufferReceivedAction.endRPCAndForwardErrorStatus should never be returned for the server."
          )
        }
      } catch let invalidState {
        context.fireErrorCaught(RPCError(invalidState))
      }
    }

    /// Forwards every message the state machine currently has available, then fires
    /// `ChannelEvent.inputClosed` if the state machine reports that no more will arrive.
    private func forwardInboundMessages(context: ChannelHandlerContext) {
      while true {
        switch self.stateMachine.nextInboundMessage() {
        case .receiveMessage(let message):
          context.fireChannelRead(self.wrapInboundOut(.message(message)))
        case .awaitMoreMessages:
          return
        case .noMoreMessages:
          context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)
          return
        }
      }
    }

    /// Feeds a HEADERS payload into the state machine and acts on the action it returns.
    private func handleHeadersPayload(
      _ payload: HTTP2Frame.FramePayload.Headers,
      context: ChannelHandlerContext
    ) {
      do {
        let action = try self.stateMachine.receive(
          headers: payload.headers,
          endStream: payload.endStream
        )

        switch action {
        case .receivedMetadata(let metadata, let methodDescriptor):
          guard let methodDescriptor else {
            assertionFailure("Method descriptor should have been present if we received metadata.")
            return
          }
          self.methodDescriptorPromise.succeed(methodDescriptor)
          context.fireChannelRead(self.wrapInboundOut(.metadata(metadata)))

        case .rejectRPC_serverOnly(let trailers):
          // The trailers are written now; `channelReadComplete` flushes them.
          self.flushPending = true
          let rejection = RPCError(
            code: .unavailable,
            message: "RPC was rejected."
          )
          self.methodDescriptorPromise.fail(rejection)
          let trailersFrame = HTTP2Frame.FramePayload.headers(.init(headers: trailers, endStream: true))
          context.write(self.wrapOutboundOut(trailersFrame), promise: nil)

        case .protocolViolation_serverOnly:
          context.writeAndFlush(self.wrapOutboundOut(.rstStream(.protocolError)), promise: nil)
          context.close(promise: nil)

        case .receivedStatusAndMetadata_clientOnly:
          assertionFailure("Unexpected action")

        case .doNothing:
          break
        }
      } catch let invalidState {
        context.fireErrorCaught(RPCError(invalidState))
      }
    }

    /// Ends the read cycle and performs any flush that was deferred while reading.
    package func channelReadComplete(context: ChannelHandlerContext) {
      defer { context.fireChannelReadComplete() }

      self.isReading = false
      guard self.flushPending else { return }
      self.flushPending = false
      context.flush()
    }

    /// Tears down the state machine and fails the method descriptor promise if it is still open.
    package func handlerRemoved(context: ChannelHandlerContext) {
      self.stateMachine.tearDown()
      self.methodDescriptorPromise.fail(Self.handlerRemovedBeforeDescriptorResolved)
    }

    /// Treats an inactive channel as an unexpected close, then forwards the event.
    package func channelInactive(context: ChannelHandlerContext) {
      self.handleUnexpectedInboundClose(context: context, reason: .channelInactive)
      context.fireChannelInactive()
    }

    /// Treats a caught error as an unexpected close. The original error is not forwarded here;
    /// the state machine decides whether a wrapped error is fired.
    package func errorCaught(context: ChannelHandlerContext, error: any Error) {
      self.handleUnexpectedInboundClose(context: context, reason: .errorThrown(error))
    }

    /// Asks the state machine how to react to the stream closing unexpectedly. When told to fire
    /// an error, the RPC is cancelled first.
    private func handleUnexpectedInboundClose(
      context: ChannelHandlerContext,
      reason: GRPCStreamStateMachine.UnexpectedInboundCloseReason
    ) {
      switch self.stateMachine.unexpectedInboundClose(reason: reason) {
      case .doNothing:
        break

      case .fireError_serverOnly(let wrappedError):
        self.cancelRPC()
        context.fireErrorCaught(wrappedError)

      case .forwardStatus_clientOnly:
        assertionFailure(
          "`forwardStatus` should only happen on the client side, never on the server."
        )
      }
    }
  }
  // MARK: - ChannelOutboundHandler

  extension GRPCServerStreamHandler {
    /// Encodes an outbound response part.
    ///
    /// - Metadata is encoded into a HEADERS frame and written immediately.
    /// - Messages are handed to the state machine and only written out by `flush(context:)`.
    /// - Status and trailing metadata are encoded into an end-of-stream HEADERS frame and
    ///   buffered in `pendingTrailers`. `flush(context:)` writes them once the state machine
    ///   reports no more messages, so trailers never overtake unflushed messages.
    ///
    /// If the state machine rejects the part, `promise` is failed and the error is fired down the
    /// pipeline.
    package func write(
      context: ChannelHandlerContext,
      data: NIOAny,
      promise: EventLoopPromise<Void>?
    ) {
      let frame = self.unwrapOutboundIn(data)
      switch frame {
      case .metadata(let metadata):
        do {
          let headers = try self.stateMachine.send(metadata: metadata)
          // Only mark a flush as pending once a frame has actually been written.
          self.flushPending = true
          context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: promise)
        } catch let invalidState {
          let error = RPCError(invalidState)
          promise?.fail(error)
          context.fireErrorCaught(error)
        }

      case .message(let message):
        do {
          try self.stateMachine.send(message: message, promise: promise)
        } catch let invalidState {
          let error = RPCError(invalidState)
          promise?.fail(error)
          context.fireErrorCaught(error)
        }

      case .status(let status, let metadata):
        do {
          let headers = try self.stateMachine.send(status: status, metadata: metadata)
          let response = HTTP2Frame.FramePayload.headers(.init(headers: headers, endStream: true))
          self.pendingTrailers = (response, promise)
        } catch let invalidState {
          let error = RPCError(invalidState)
          promise?.fail(error)
          context.fireErrorCaught(error)
        }
      }
    }

    /// Writes every frame the state machine has ready, followed by any buffered trailers, and then
    /// flushes.
    ///
    /// Messages queued by `write(context:data:promise:)` are only pulled from the state machine
    /// here, so this always drains them. While a read is in progress the downstream flush is
    /// deferred rather than dropped: `flushPending` is left set and `channelReadComplete(context:)`
    /// issues the flush. Returning early mid-read would lose the request and could strand queued
    /// messages or trailers until some later flush.
    package func flush(context: ChannelHandlerContext) {
      do {
        loop: while true {
          switch try self.stateMachine.nextOutboundFrame() {
          case .sendFrame(let byteBuffer, let promise):
            self.flushPending = true
            context.write(
              self.wrapOutboundOut(.data(.init(data: .byteBuffer(byteBuffer)))),
              promise: promise
            )

          case .noMoreMessages:
            // All messages are out: the buffered status/trailers can now be written safely.
            if let pendingTrailers = self.pendingTrailers {
              self.flushPending = true
              self.pendingTrailers = nil
              context.write(
                self.wrapOutboundOut(pendingTrailers.trailers),
                promise: pendingTrailers.promise
              )
            }
            break loop

          case .awaitMoreMessages:
            break loop

          case .closeAndFailPromise(let promise, let error):
            context.close(mode: .all, promise: nil)
            promise?.fail(error)
          }
        }

        // We don't want to flush yet if we're still in a read loop; leaving `flushPending` set
        // makes `channelReadComplete(context:)` perform it instead.
        if self.flushPending, !self.isReading {
          self.flushPending = false
          context.flush()
        }
      } catch let invalidState {
        let error = RPCError(invalidState)
        context.fireErrorCaught(error)
      }
    }
  }