GRPCServerStreamHandler.swift 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. /*
  2. * Copyright 2024, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. package import GRPCCore
  17. package import NIOCore
  18. package import NIOHTTP2
  /// Channel handler that converts inbound HTTP/2 frame payloads into `RPCRequestPart`s and
  /// outbound `RPCResponsePart`s into HTTP/2 frame payloads.
  ///
  /// All protocol decisions are delegated to a `GRPCStreamStateMachine` configured for the
  /// server role; this type only executes the actions the state machine returns.
  @available(gRPCSwiftNIOTransport 1.0, *)
  package final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChannelHandler {
    package typealias InboundIn = HTTP2Frame.FramePayload
    package typealias InboundOut = RPCRequestPart<GRPCNIOTransportBytes>

    package typealias OutboundIn = RPCResponsePart<GRPCNIOTransportBytes>
    package typealias OutboundOut = HTTP2Frame.FramePayload

    /// Drives every inbound and outbound transition of the stream.
    private var stateMachine: GRPCStreamStateMachine
    /// The event loop on which all mutable state of this handler is accessed.
    private let eventLoop: any EventLoop

    /// `true` between `channelRead` and the following `channelReadComplete`; `flush` returns
    /// early while this is set.
    private var isReading = false
    /// `true` when something was written to the channel that still needs to be flushed.
    private var flushPending = false
    /// Set when cancellation is requested while no cancellation handle is stored, so that a
    /// handle set afterwards is cancelled as soon as it arrives.
    private var isCancelled = false

    // We buffer the final status + trailers to avoid reordering issues (i.e.,
    // if there are messages still not written into the channel because flush has
    // not been called, but the server sends back trailers).
    private var pendingTrailers:
      (trailers: HTTP2Frame.FramePayload, promise: EventLoopPromise<Void>?)?

    /// Succeeded with the method descriptor once request metadata has been received; failed
    /// if the RPC is rejected or the handler is removed.
    private let methodDescriptorPromise: EventLoopPromise<MethodDescriptor>

    /// Handle used to cancel the RPC; consumed (set back to `nil`) when cancellation happens.
    private var cancellationHandle: Optional<ServerContext.RPCCancellationHandle>

    /// Notified whenever a headers frame or a message is written through this handler.
    package let connectionManagementHandler: ServerConnectionManagementHandler.SyncView

    // Existential errors unconditionally allocate, avoid this per-use allocation by doing it
    // statically.
    private static let handlerRemovedBeforeDescriptorResolved: any Error = RPCError(
      code: .unavailable,
      message: "RPC stream was closed before we got any Metadata."
    )

    /// Creates a handler for the server side of a stream.
    ///
    /// - Parameters:
    ///   - scheme: Scheme passed to the state machine's server configuration.
    ///   - acceptedEncodings: Compression algorithms passed to the state machine's server
    ///     configuration.
    ///   - maxPayloadSize: Maximum payload size passed to the state machine.
    ///   - methodDescriptorPromise: Promise completed with the method descriptor of the RPC.
    ///   - eventLoop: Event loop on which the handler's state is accessed.
    ///   - connectionManagementHandler: View notified when headers or messages are written.
    ///   - cancellationHandler: Optional handle to cancel the RPC; it may also be provided
    ///     later via `setCancellationHandle(_:)`.
    ///   - skipStateMachineAssertions: Forwarded to the state machine as `skipAssertions`.
    package init(
      scheme: Scheme,
      acceptedEncodings: CompressionAlgorithmSet,
      maxPayloadSize: Int,
      methodDescriptorPromise: EventLoopPromise<MethodDescriptor>,
      eventLoop: any EventLoop,
      connectionManagementHandler: ServerConnectionManagementHandler.SyncView,
      cancellationHandler: ServerContext.RPCCancellationHandle? = nil,
      skipStateMachineAssertions: Bool = false
    ) {
      self.stateMachine = .init(
        configuration: .server(.init(scheme: scheme, acceptedEncodings: acceptedEncodings)),
        maxPayloadSize: maxPayloadSize,
        skipAssertions: skipStateMachineAssertions
      )
      self.methodDescriptorPromise = methodDescriptorPromise
      self.cancellationHandle = cancellationHandler
      self.eventLoop = eventLoop
      self.connectionManagementHandler = connectionManagementHandler
    }

    /// Provides the handle used to cancel the RPC.
    ///
    /// The handle is applied immediately when called on the handler's event loop; otherwise
    /// the work is scheduled onto that event loop.
    ///
    /// - Parameter handle: The handle to store, or to cancel straight away if cancellation
    ///   was already requested.
    package func setCancellationHandle(_ handle: ServerContext.RPCCancellationHandle) {
      if self.eventLoop.inEventLoop {
        self.syncSetCancellationHandle(handle)
      } else {
        // `NIOLoopBound.init` preconditions that it runs on the event loop, so it can't be
        // used to carry `self` across from another thread: doing so would trap on this path.
        // The handler is only touched again from inside the event loop callback below.
        nonisolated(unsafe) let handler = self
        self.eventLoop.execute {
          handler.syncSetCancellationHandle(handle)
        }
      }
    }

    /// Stores `handle`, or cancels it right away if cancellation was already requested.
    ///
    /// Must be called on the handler's event loop.
    private func syncSetCancellationHandle(_ handle: ServerContext.RPCCancellationHandle) {
      assert(self.cancellationHandle == nil, "\(#function) must only be called once")

      if self.isCancelled {
        handle.cancel()
      } else {
        self.cancellationHandle = handle
      }
    }

    /// Cancels the RPC through the stored handle if there is one; otherwise records the
    /// request so a handle set later is cancelled on arrival.
    private func cancelRPC() {
      if let handle = self.cancellationHandle.take() {
        handle.cancel()
      } else {
        self.isCancelled = true
      }
    }
  }
  // MARK: - ChannelInboundHandler

  @available(gRPCSwiftNIOTransport 1.0, *)
  extension GRPCServerStreamHandler {
    /// Cancels the RPC when the channel is asked to quiesce, then forwards the event.
    package func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
      if event is ChannelShouldQuiesceEvent {
        self.cancelRPC()
      }
      context.fireUserInboundEventTriggered(event)
    }

    /// Dispatches an inbound HTTP/2 frame payload to the matching handling routine.
    ///
    /// Marks the handler as reading so that `flush` is deferred until `channelReadComplete`.
    package func channelRead(context: ChannelHandlerContext, data: NIOAny) {
      self.isReading = true

      switch self.unwrapInboundIn(data) {
      case .data(let dataFrame):
        self.receiveDataFrame(dataFrame, context: context)

      case .headers(let headersFrame):
        self.receiveHeadersFrame(headersFrame, context: context)

      case .rstStream:
        self.handleUnexpectedInboundClose(context: context, reason: .streamReset)

      case .ping, .goAway, .priority, .settings, .pushPromise, .windowUpdate,
        .alternativeService, .origin:
        // These payloads are ignored by this handler.
        ()
      }
    }

    /// Feeds the bytes of a DATA frame into the state machine and acts on the outcome.
    private func receiveDataFrame(
      _ dataFrame: HTTP2Frame.FramePayload.Data,
      context: ChannelHandlerContext
    ) {
      guard case .byteBuffer(let bytes) = dataFrame.data else {
        preconditionFailure("Unexpected IOData.fileRegion")
      }

      do {
        let action = try self.stateMachine.receive(buffer: bytes, endStream: dataFrame.endStream)
        switch action {
        case .readInbound:
          self.forwardBufferedInboundMessages(context: context)

        case .forwardErrorAndClose_serverOnly(let error):
          context.fireErrorCaught(error)
          context.close(mode: .all, promise: nil)

        case .doNothing:
          ()

        case .endRPCAndForwardErrorStatus_clientOnly:
          preconditionFailure(
            "OnBufferReceivedAction.endRPCAndForwardErrorStatus should never be returned for the server."
          )
        }
      } catch let stateError {
        context.fireErrorCaught(RPCError(stateError))
      }
    }

    /// Forwards every message the state machine currently has available, and signals
    /// `ChannelEvent.inputClosed` once it reports that no more messages will arrive.
    private func forwardBufferedInboundMessages(context: ChannelHandlerContext) {
      while true {
        switch self.stateMachine.nextInboundMessage() {
        case .receiveMessage(let bytes):
          let part: InboundOut = .message(GRPCNIOTransportBytes(bytes))
          context.fireChannelRead(self.wrapInboundOut(part))

        case .awaitMoreMessages:
          return

        case .noMoreMessages:
          context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)
          return
        }
      }
    }

    /// Feeds a HEADERS frame into the state machine and acts on the outcome.
    private func receiveHeadersFrame(
      _ headersFrame: HTTP2Frame.FramePayload.Headers,
      context: ChannelHandlerContext
    ) {
      do {
        switch try self.stateMachine.receive(
          headers: headersFrame.headers,
          endStream: headersFrame.endStream
        ) {
        case .receivedMetadata(let metadata, let descriptor):
          guard let descriptor else {
            assertionFailure("Method descriptor should have been present if we received metadata.")
            return
          }
          self.methodDescriptorPromise.succeed(descriptor)
          context.fireChannelRead(self.wrapInboundOut(.metadata(metadata)))

        case .rejectRPC_serverOnly(let trailers):
          // The write below is flushed from `channelReadComplete`.
          self.flushPending = true
          let rejection = RPCError(
            code: .unavailable,
            message: "RPC was rejected."
          )
          self.methodDescriptorPromise.fail(rejection)
          let trailersOnly = HTTP2Frame.FramePayload.headers(
            .init(headers: trailers, endStream: true)
          )
          context.write(self.wrapOutboundOut(trailersOnly), promise: nil)

        case .protocolViolation_serverOnly:
          context.writeAndFlush(self.wrapOutboundOut(.rstStream(.protocolError)), promise: nil)
          context.close(promise: nil)

        case .receivedStatusAndMetadata_clientOnly:
          assertionFailure("Unexpected action")

        case .doNothing:
          ()
        }
      } catch let stateError {
        context.fireErrorCaught(RPCError(stateError))
      }
    }

    /// Ends the read cycle, flushing if a write was made while reading.
    package func channelReadComplete(context: ChannelHandlerContext) {
      let needsFlush = self.flushPending
      self.isReading = false
      self.flushPending = false

      if needsFlush {
        context.flush()
      }
      context.fireChannelReadComplete()
    }

    /// Tears down the state machine and fails the method descriptor promise with a
    /// statically allocated error.
    package func handlerRemoved(context: ChannelHandlerContext) {
      self.stateMachine.tearDown()
      let removedError = Self.handlerRemovedBeforeDescriptorResolved
      self.methodDescriptorPromise.fail(removedError)
    }

    /// Treats the channel becoming inactive as an unexpected close, then forwards the event.
    package func channelInactive(context: ChannelHandlerContext) {
      self.handleUnexpectedInboundClose(context: context, reason: .channelInactive)
      context.fireChannelInactive()
    }

    /// Treats an error caught in the pipeline as an unexpected close of the stream.
    package func errorCaught(context: ChannelHandlerContext, error: any Error) {
      self.handleUnexpectedInboundClose(context: context, reason: .errorThrown(error))
    }

    /// Reports an unexpected inbound close to the state machine and, if it asks for it,
    /// cancels the RPC and fires the error it returns down the pipeline.
    private func handleUnexpectedInboundClose(
      context: ChannelHandlerContext,
      reason: GRPCStreamStateMachine.UnexpectedInboundCloseReason
    ) {
      let action = self.stateMachine.unexpectedInboundClose(reason: reason)

      switch action {
      case .doNothing:
        ()

      case .fireError_serverOnly(let error):
        self.cancelRPC()
        context.fireErrorCaught(error)

      case .forwardStatus_clientOnly:
        assertionFailure(
          "`forwardStatus` should only happen on the client side, never on the server."
        )
      }
    }
  }
  // MARK: - ChannelOutboundHandler

  @available(gRPCSwiftNIOTransport 1.0, *)
  extension GRPCServerStreamHandler {
    /// Handles an outbound response part.
    ///
    /// - Metadata is turned into a headers frame and written immediately.
    /// - Messages are handed to the state machine together with their promise; they are
    ///   written to the channel from `flush(context:)`.
    /// - Status is turned into an end-of-stream headers frame and held in `pendingTrailers`
    ///   until `flush(context:)` writes it.
    ///
    /// If the state machine throws, `promise` is failed and the error is fired down the
    /// pipeline.
    package func write(
      context: ChannelHandlerContext,
      data: NIOAny,
      promise: EventLoopPromise<Void>?
    ) {
      switch self.unwrapOutboundIn(data) {
      case .metadata(let metadata):
        do {
          self.flushPending = true
          let responseHeaders = try self.stateMachine.send(metadata: metadata)
          let headersFrame: OutboundOut = .headers(.init(headers: responseHeaders))
          context.write(self.wrapOutboundOut(headersFrame), promise: promise)
          self.connectionManagementHandler.wroteHeadersFrame()
        } catch let stateError {
          self.failWrite(with: RPCError(stateError), promise: promise, context: context)
        }

      case .message(let message):
        do {
          try self.stateMachine.send(message: message.buffer, promise: promise)
          self.connectionManagementHandler.wroteDataFrame()
        } catch let stateError {
          self.failWrite(with: RPCError(stateError), promise: promise, context: context)
        }

      case .status(let status, let metadata):
        do {
          let trailers = try self.stateMachine.send(status: status, metadata: metadata)
          let trailersFrame = HTTP2Frame.FramePayload.headers(
            .init(headers: trailers, endStream: true)
          )
          self.pendingTrailers = (trailersFrame, promise)
        } catch let stateError {
          self.failWrite(with: RPCError(stateError), promise: promise, context: context)
        }
      }
    }

    /// Fails the write's promise (if any) and fires the same error down the pipeline.
    private func failWrite(
      with error: RPCError,
      promise: EventLoopPromise<Void>?,
      context: ChannelHandlerContext
    ) {
      promise?.fail(error)
      context.fireErrorCaught(error)
    }

    /// Writes every outbound frame the state machine has ready, followed by the pending
    /// trailers once it reports no more messages, then flushes the channel if anything
    /// was written.
    ///
    /// Does nothing while a read cycle is in progress.
    package func flush(context: ChannelHandlerContext) {
      // We don't want to flush yet if we're still in a read loop.
      guard !self.isReading else {
        return
      }

      do {
        drain: while true {
          switch try self.stateMachine.nextOutboundFrame() {
          case .sendFrame(let bytes, let framePromise):
            self.flushPending = true
            let dataFrame: OutboundOut = .data(.init(data: .byteBuffer(bytes)))
            context.write(self.wrapOutboundOut(dataFrame), promise: framePromise)

          case .awaitMoreMessages:
            break drain

          case .noMoreMessages:
            guard let buffered = self.pendingTrailers else {
              break drain
            }
            self.flushPending = true
            self.pendingTrailers = nil
            context.write(
              self.wrapOutboundOut(buffered.trailers),
              promise: buffered.promise
            )
            break drain

          case .closeAndFailPromise(let framePromise, let error):
            context.close(mode: .all, promise: nil)
            framePromise?.fail(error)
          }
        }

        if self.flushPending {
          self.flushPending = false
          context.flush()
        }
      } catch let stateError {
        context.fireErrorCaught(RPCError(stateError))
      }
    }
  }