| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344 |
- /*
- * Copyright 2024, gRPC Authors All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- package import GRPCCore
- package import NIOCore
- package import NIOHTTP2
@available(gRPCSwiftNIOTransport 2.0, *)
package final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChannelHandler {
  package typealias InboundIn = HTTP2Frame.FramePayload
  package typealias InboundOut = RPCRequestPart<GRPCNIOTransportBytes>
  package typealias OutboundIn = RPCResponsePart<GRPCNIOTransportBytes>
  package typealias OutboundOut = HTTP2Frame.FramePayload

  /// Drives the gRPC protocol for this stream; it also buffers inbound messages until
  /// they are read out and outbound messages until they are drained by a flush.
  private var stateMachine: GRPCStreamStateMachine
  /// The event loop that owns this handler's mutable state.
  private let eventLoop: any EventLoop
  /// `true` from `channelRead` until `channelReadComplete`. While set, `flush(context:)`
  /// does not flush the context itself.
  private var isReading = false
  /// Whether a `context.flush()` is still owed, e.g. for frames this handler has written
  /// to the context but not yet flushed.
  private var flushPending = false
  /// Set when the RPC is cancelled before a cancellation handle has been provided, so the
  /// handle is cancelled as soon as it arrives.
  private var isCancelled = false
  // We buffer the final status + trailers to avoid reordering issues (i.e., if there are
  // messages still not written into the channel because flush has not been called, but
  // the server sends back trailers).
  private var pendingTrailers:
    (trailers: HTTP2Frame.FramePayload, promise: EventLoopPromise<Void>?)?
  /// Succeeded with the RPC's method once request headers are received; failed if the RPC
  /// is rejected or the handler is removed before that happens.
  private let methodDescriptorPromise: EventLoopPromise<MethodDescriptor>
  /// Handle used to cancel the RPC: `nil` until provided, and `nil` again once used.
  private var cancellationHandle: ServerContext.RPCCancellationHandle?

  // Existential errors unconditionally allocate; avoid this per-use allocation by
  // creating the error once, statically.
  private static let handlerRemovedBeforeDescriptorResolved: any Error = RPCError(
    code: .unavailable,
    message: "RPC stream was closed before we got any Metadata."
  )

  /// Creates a handler for a single server-side HTTP/2 stream.
  ///
  /// - Parameters:
  ///   - scheme: Passed to the state machine's server configuration.
  ///   - acceptedEncodings: Compression algorithms accepted for requests; passed to the
  ///     state machine's server configuration.
  ///   - maxPayloadSize: Passed to the state machine as its maximum payload size.
  ///   - methodDescriptorPromise: Succeeded with the RPC's method once known.
  ///   - eventLoop: The event loop this handler's state belongs to.
  ///   - cancellationHandler: An optional handle used to cancel the RPC.
  ///   - skipStateMachineAssertions: Passed to the state machine as `skipAssertions`.
  package init(
    scheme: Scheme,
    acceptedEncodings: CompressionAlgorithmSet,
    maxPayloadSize: Int,
    methodDescriptorPromise: EventLoopPromise<MethodDescriptor>,
    eventLoop: any EventLoop,
    cancellationHandler: ServerContext.RPCCancellationHandle? = nil,
    skipStateMachineAssertions: Bool = false
  ) {
    self.stateMachine = .init(
      configuration: .server(.init(scheme: scheme, acceptedEncodings: acceptedEncodings)),
      maxPayloadSize: maxPayloadSize,
      skipAssertions: skipStateMachineAssertions
    )
    self.methodDescriptorPromise = methodDescriptorPromise
    self.cancellationHandle = cancellationHandler
    self.eventLoop = eventLoop
  }

  /// Provides the handle used to cancel the RPC. Must be called at most once.
  ///
  /// May be called from any thread: when not on `eventLoop`, the handle is installed
  /// asynchronously on `eventLoop`. If the RPC was already cancelled, the handle is
  /// cancelled immediately instead of being stored.
  package func setCancellationHandle(_ handle: ServerContext.RPCCancellationHandle) {
    if self.eventLoop.inEventLoop {
      self.syncSetCancellationHandle(handle)
    } else {
      // `NIOLoopBound.init` preconditions that it is called *on* its event loop, so it
      // cannot carry `self` over from another thread. Instead, use an unchecked box:
      // `self` is only dereferenced inside the closure, which runs on `self.eventLoop`.
      let transfer = EventLoopTransfer(self)
      self.eventLoop.execute {
        transfer.value.syncSetCancellationHandle(handle)
      }
    }
  }

  /// Carries a non-`Sendable` value to its event loop. Only sound when `value` is
  /// accessed exclusively on that event loop.
  private struct EventLoopTransfer<Value>: @unchecked Sendable {
    let value: Value

    init(_ value: Value) {
      self.value = value
    }
  }

  /// Installs `handle`, or cancels it right away if the RPC was already cancelled.
  /// Must run on `eventLoop`.
  private func syncSetCancellationHandle(_ handle: ServerContext.RPCCancellationHandle) {
    assert(self.cancellationHandle == nil, "\(#function) must only be called once")
    if self.isCancelled {
      handle.cancel()
    } else {
      self.cancellationHandle = handle
    }
  }

  /// Cancels the RPC through the stored handle, consuming it. If no handle has been set
  /// yet, records the cancellation so a later handle is cancelled on arrival.
  private func cancelRPC() {
    if let handle = self.cancellationHandle.take() {
      handle.cancel()
    } else {
      self.isCancelled = true
    }
  }
}
// MARK: - ChannelInboundHandler
@available(gRPCSwiftNIOTransport 2.0, *)
extension GRPCServerStreamHandler {
  /// Cancels the RPC when the channel is asked to quiesce, and forwards every event
  /// (including the quiesce event itself) to the next inbound handler.
  package func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
    if event is ChannelShouldQuiesceEvent {
      self.cancelRPC()
    }
    context.fireUserInboundEventTriggered(event)
  }

  /// Feeds one inbound HTTP/2 frame into the stream state machine and acts on the result.
  ///
  /// `isReading` is set here and cleared in `channelReadComplete`.
  package func channelRead(context: ChannelHandlerContext, data: NIOAny) {
    self.isReading = true

    switch self.unwrapInboundIn(data) {
    case .data(let dataFrame):
      self.readDataFrame(dataFrame, context: context)
    case .headers(let headersFrame):
      self.readHeadersFrame(headersFrame, context: context)
    case .rstStream:
      self.handleUnexpectedInboundClose(context: context, reason: .streamReset)
    case .ping, .goAway, .priority, .settings, .pushPromise, .windowUpdate,
      .alternativeService, .origin:
      // Ignored by this handler.
      break
    }
  }

  /// Hands a DATA frame's bytes to the state machine and reacts to its decision: forward
  /// any now-complete messages, or report the error and close the channel.
  private func readDataFrame(
    _ dataFrame: HTTP2Frame.FramePayload.Data,
    context: ChannelHandlerContext
  ) {
    guard case .byteBuffer(let bytes) = dataFrame.data else {
      preconditionFailure("Unexpected IOData.fileRegion")
    }

    do {
      switch try self.stateMachine.receive(buffer: bytes, endStream: dataFrame.endStream) {
      case .doNothing:
        break
      case .readInbound:
        self.forwardBufferedMessages(context: context)
      case .forwardErrorAndClose_serverOnly(let error):
        context.fireErrorCaught(error)
        context.close(mode: .all, promise: nil)
      case .endRPCAndForwardErrorStatus_clientOnly:
        preconditionFailure(
          "OnBufferReceivedAction.endRPCAndForwardErrorStatus should never be returned for the server."
        )
      }
    } catch let invalidState {
      context.fireErrorCaught(RPCError(invalidState))
    }
  }

  /// Pulls every complete message out of the state machine and forwards each one. When the
  /// state machine reports there will be no more messages, fires `ChannelEvent.inputClosed`.
  private func forwardBufferedMessages(context: ChannelHandlerContext) {
    while true {
      switch self.stateMachine.nextInboundMessage() {
      case .receiveMessage(let message):
        let bytes = GRPCNIOTransportBytes(message)
        context.fireChannelRead(self.wrapInboundOut(.message(bytes)))
      case .awaitMoreMessages:
        return
      case .noMoreMessages:
        context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)
        return
      }
    }
  }

  /// Hands a HEADERS frame to the state machine and reacts to its decision.
  ///
  /// Depending on that decision, this either:
  /// - resolves the method descriptor and forwards the metadata,
  /// - writes rejection trailers (flushed on read completion), or
  /// - resets the stream on a protocol violation.
  private func readHeadersFrame(
    _ headersFrame: HTTP2Frame.FramePayload.Headers,
    context: ChannelHandlerContext
  ) {
    do {
      let action = try self.stateMachine.receive(
        headers: headersFrame.headers,
        endStream: headersFrame.endStream
      )

      switch action {
      case .doNothing:
        break

      case .receivedMetadata(let metadata, let methodDescriptor):
        guard let methodDescriptor else {
          assertionFailure("Method descriptor should have been present if we received metadata.")
          return
        }
        self.methodDescriptorPromise.succeed(methodDescriptor)
        context.fireChannelRead(self.wrapInboundOut(.metadata(metadata)))

      case .rejectRPC_serverOnly(let trailers):
        // The trailers are written now; `channelReadComplete` flushes them.
        self.flushPending = true
        let rejection = RPCError(code: .unavailable, message: "RPC was rejected.")
        self.methodDescriptorPromise.fail(rejection)
        context.write(
          self.wrapOutboundOut(.headers(.init(headers: trailers, endStream: true))),
          promise: nil
        )

      case .protocolViolation_serverOnly:
        context.writeAndFlush(self.wrapOutboundOut(.rstStream(.protocolError)), promise: nil)
        context.close(promise: nil)

      case .receivedStatusAndMetadata_clientOnly:
        assertionFailure("Unexpected action")
      }
    } catch let invalidState {
      context.fireErrorCaught(RPCError(invalidState))
    }
  }

  /// Ends the current read. If `flushPending` was set, clears it and flushes the context
  /// once, then forwards the read-complete event.
  package func channelReadComplete(context: ChannelHandlerContext) {
    self.isReading = false
    guard self.flushPending else {
      context.fireChannelReadComplete()
      return
    }
    self.flushPending = false
    context.flush()
    context.fireChannelReadComplete()
  }

  /// Tears down the state machine and fails the method-descriptor promise. If the
  /// descriptor was already resolved, this failure has no effect on the promise's result.
  package func handlerRemoved(context: ChannelHandlerContext) {
    self.stateMachine.tearDown()
    self.methodDescriptorPromise.fail(Self.handlerRemovedBeforeDescriptorResolved)
  }

  /// Treats the channel going inactive as an unexpected close, then forwards the event.
  package func channelInactive(context: ChannelHandlerContext) {
    self.handleUnexpectedInboundClose(context: context, reason: .channelInactive)
    context.fireChannelInactive()
  }

  /// Treats an error as an unexpected close. The state machine decides whether an error
  /// is forwarded; `error` itself is not forwarded directly.
  package func errorCaught(context: ChannelHandlerContext, error: any Error) {
    self.handleUnexpectedInboundClose(context: context, reason: .errorThrown(error))
  }

  /// Informs the state machine that the stream closed unexpectedly. If it asks for an
  /// error to be fired, cancels the RPC first and then fires that error.
  private func handleUnexpectedInboundClose(
    context: ChannelHandlerContext,
    reason: GRPCStreamStateMachine.UnexpectedInboundCloseReason
  ) {
    switch self.stateMachine.unexpectedInboundClose(reason: reason) {
    case .doNothing:
      break
    case .fireError_serverOnly(let wrappedError):
      self.cancelRPC()
      context.fireErrorCaught(wrappedError)
    case .forwardStatus_clientOnly:
      assertionFailure(
        "`forwardStatus` should only happen on the client side, never on the server."
      )
    }
  }
}
// MARK: - ChannelOutboundHandler
@available(gRPCSwiftNIOTransport 2.0, *)
extension GRPCServerStreamHandler {
  /// Converts an outbound gRPC response part into HTTP/2 frames.
  ///
  /// - Metadata is encoded and written to the context immediately; it goes out on the next
  ///   flush.
  /// - Messages are handed to the state machine and only written when `flush(context:)`
  ///   drains it.
  /// - The status and trailers are held in `pendingTrailers` until every buffered message
  ///   has been written, so they can never overtake data.
  ///
  /// If the state machine rejects the part, `promise` is failed and the same error is
  /// fired down the pipeline.
  package func write(
    context: ChannelHandlerContext,
    data: NIOAny,
    promise: EventLoopPromise<Void>?
  ) {
    switch self.unwrapOutboundIn(data) {
    case .metadata(let metadata):
      do {
        let headers = try self.stateMachine.send(metadata: metadata)
        // Only owe a flush once something has actually been written.
        self.flushPending = true
        context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: promise)
      } catch let invalidState {
        self.failOutboundWrite(RPCError(invalidState), promise: promise, context: context)
      }

    case .message(let message):
      do {
        // Buffered by the state machine; written to the context by `flush(context:)`.
        try self.stateMachine.send(message: message.buffer, promise: promise)
      } catch let invalidState {
        self.failOutboundWrite(RPCError(invalidState), promise: promise, context: context)
      }

    case .status(let status, let metadata):
      do {
        let headers = try self.stateMachine.send(status: status, metadata: metadata)
        let response = HTTP2Frame.FramePayload.headers(.init(headers: headers, endStream: true))
        self.pendingTrailers = (response, promise)
      } catch let invalidState {
        self.failOutboundWrite(RPCError(invalidState), promise: promise, context: context)
      }
    }
  }

  /// Fails `promise` with `error` and reports the same error down the pipeline.
  private func failOutboundWrite(
    _ error: RPCError,
    promise: EventLoopPromise<Void>?,
    context: ChannelHandlerContext
  ) {
    promise?.fail(error)
    context.fireErrorCaught(error)
  }

  /// Writes all frames buffered in the state machine to the context, followed by the
  /// pending trailers once no messages remain, and then flushes.
  ///
  /// While a read is in progress the frames are still written, but the actual
  /// `context.flush()` is deferred: `flushPending` is set, so `channelReadComplete`
  /// issues a single flush for everything written during the read. Previously this method
  /// returned early during reads without recording anything. That left buffered messages
  /// and trailers stuck until some unrelated later flush, because `channelReadComplete`
  /// flushes the context without draining the state machine.
  ///
  /// If the state machine reports an invalid state, the error is fired down the pipeline
  /// and the context is not flushed.
  package func flush(context: ChannelHandlerContext) {
    do {
      loop: while true {
        switch try self.stateMachine.nextOutboundFrame() {
        case .sendFrame(let byteBuffer, let promise):
          self.flushPending = true
          context.write(
            self.wrapOutboundOut(.data(.init(data: .byteBuffer(byteBuffer)))),
            promise: promise
          )

        case .noMoreMessages:
          // Every message has been written, so the trailers can follow without reordering.
          if let pendingTrailers = self.pendingTrailers {
            self.flushPending = true
            self.pendingTrailers = nil
            context.write(
              self.wrapOutboundOut(pendingTrailers.trailers),
              promise: pendingTrailers.promise
            )
          }
          break loop

        case .awaitMoreMessages:
          break loop

        case .closeAndFailPromise(let promise, let error):
          context.close(mode: .all, promise: nil)
          promise?.fail(error)
        }
      }
    } catch let invalidState {
      context.fireErrorCaught(RPCError(invalidState))
      return
    }

    if self.isReading {
      // Don't flush mid-read; record that a flush is owed so `channelReadComplete`
      // performs it once the read loop finishes.
      self.flushPending = true
    } else if self.flushPending {
      self.flushPending = false
      context.flush()
    }
  }
}
|