Browse Source

Scaffolding for better server codegen (#1093)

Motivation:

Most of the RPC handling code on the server is generic but not
specialized; as a result we pay a fairly large performance cost. This is
the first in a series of PRs which changes how the shape of the
generated server code.

The plan is to have a new protocol between the generated code and gRPC,
'GRPCServerHandlerProtocol', which operates on concrete types
('ByteBuffer' as opposed to typed request/responses). As such the
protocol for asking for a handler ('CallHandlerProvider') will also
change to return an instance of this new protocol. In generated code the
returned handler will be generic over the request deserializer and
response serializer. The responsibility of the call handler will not
differ much from now (it will receive headers, serialized bytes, etc.)
but will be implemented as a class per call type (this is close to the
current situation where the same is true but a common base class is
shared). In addition the new handlers will not channel handlers: they
will eventually be held by the routing handler.

Modifications:

- Add 'GRPCServerHandlerProtocol', the inbound protocol that the new
  handlers will conform to
- Add 'GRPCServerResponseWriter', the internal outbound counterpart to
  'GRPCServerHandlerProtocol' used by the handlers to communicate back
  to gRPC and add an unimplemented conformance of this to
  'HTTP2ToRawGRPCServerCodec'.
- Add 'UnaryServerHandler' (and tests), a handler for unary requests
- Add 'ServerErrorProcessor'; cargo culted from '_BaseCallHandler'
- Modify the 'CallHandlerProvider' to add a 'handle(method:context:)'

Result:

We have a rough outline of how the new server handlers will look.
George Barnett 5 years ago
parent
commit
a45db1b439

+ 69 - 0
Sources/GRPC/CallHandlers/ServerHandlerProtocol.swift

@@ -0,0 +1,69 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import NIO
+import NIOHPACK
+
+/// This protocol lays out the inbound interface between the gRPC module and generated server code.
+/// On receiving a new RPC, gRPC will ask all available service providers for an instance of this
+/// protocol in order to handle the RPC.
+///
+/// See also: `CallHandlerProvider.handle(method:context:)`
+public protocol GRPCServerHandlerProtocol {
+  /// Called when request headers have been received at the start of an RPC.
+  /// - Parameter metadata: The request headers.
+  func receiveMetadata(_ metadata: HPACKHeaders)
+
+  /// Called when request message has been received.
+  /// - Parameter bytes: The bytes of the serialized request.
+  func receiveMessage(_ bytes: ByteBuffer)
+
+  /// Called at the end of the request stream.
+  func receiveEnd()
+
+  /// Called when an error has been encountered. The handler should be torn down on receiving an
+  /// error.
+  /// - Parameter error: The error which has been encountered.
+  func receiveError(_ error: Error)
+
+  /// Called when the RPC handler should be torn down.
+  func finish()
+}
+
+/// This protocol defines the outbound interface between the gRPC module and generated server code.
+/// It is used by server handlers in order to send responses back to gRPC.
+@usableFromInline
+internal protocol GRPCServerResponseWriter {
+  /// Send the initial response metadata.
+  /// - Parameters:
+  ///   - metadata: The user-provided metadata to send to the client.
+  ///   - promise: A promise to complete once the metadata has been handled.
+  func sendMetadata(_ metadata: HPACKHeaders, promise: EventLoopPromise<Void>?)
+
+  /// Send the serialized bytes of a response message.
+  /// - Parameters:
+  ///   - bytes: The serialized bytes to send to the client.
+  ///   - metadata: Metadata associated with sending the response, such as whether it should be
+  ///     compressed.
+  ///   - promise: A promise to complete once the message as been handled.
+  func sendMessage(_ bytes: ByteBuffer, metadata: MessageMetadata, promise: EventLoopPromise<Void>?)
+
+  /// Ends the response stream.
+  /// - Parameters:
+  ///   - status: The final status of the RPC.
+  ///   - trailers: Any user-provided trailers to send back to the client with the status.
+  ///   - promise: A promise to complete once the status and trailers have been handled.
+  func sendEnd(status: GRPCStatus, trailers: HPACKHeaders, promise: EventLoopPromise<Void>?)
+}

+ 300 - 0
Sources/GRPC/CallHandlers/UnaryServerHandler.swift

@@ -0,0 +1,300 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import NIO
+import NIOHPACK
+
+public final class UnaryServerHandler<
+  Serializer: MessageSerializer,
+  Deserializer: MessageDeserializer
+>: GRPCServerHandlerProtocol {
+  public typealias Request = Deserializer.Output
+  public typealias Response = Serializer.Input
+
+  /// A response serializer.
+  @usableFromInline
+  internal let serializer: Serializer
+
+  /// A request deserializer.
+  @usableFromInline
+  internal let deserializer: Deserializer
+
+  /// A pipeline of user provided interceptors.
+  @usableFromInline
+  internal var interceptors: ServerInterceptorPipeline<Request, Response>!
+
+  /// The context required in order create the function.
+  @usableFromInline
+  internal let context: CallHandlerContext
+
+  /// A reference to a `UserInfo`.
+  @usableFromInline
+  internal let userInfoRef: Ref<UserInfo>
+
+  /// The user provided function to execute.
+  @usableFromInline
+  internal let userFunction: (Request, StatusOnlyCallContext) -> EventLoopFuture<Response>
+
+  /// The state of the function invocation.
+  @usableFromInline
+  internal var state: State = .idle
+
+  @usableFromInline
+  internal enum State {
+    // Initial state. Nothing has happened yet.
+    case idle
+    // Headers have been received and now we're holding a context with which to invoke the user
+    // function when we receive a message.
+    case createdContext(UnaryResponseCallContext<Response>)
+    // The user function has been invoked, we're waiting for the response.
+    case invokedFunction(UnaryResponseCallContext<Response>)
+    // The function has completed or we are no longer proceeding with execution (because of an error
+    // or unexpected closure).
+    case completed
+  }
+
+  @inlinable
+  public init(
+    context: CallHandlerContext,
+    requestDeserializer: Deserializer,
+    responseSerializer: Serializer,
+    interceptors: [ServerInterceptor<Request, Response>],
+    userFunction: @escaping (Request, StatusOnlyCallContext) -> EventLoopFuture<Response>
+  ) {
+    self.userFunction = userFunction
+    self.serializer = responseSerializer
+    self.deserializer = requestDeserializer
+    self.context = context
+
+    let userInfoRef = Ref(UserInfo())
+    self.userInfoRef = userInfoRef
+    self.interceptors = ServerInterceptorPipeline(
+      logger: context.logger,
+      eventLoop: context.eventLoop,
+      path: context.path,
+      callType: .unary,
+      remoteAddress: context.remoteAddress,
+      userInfoRef: userInfoRef,
+      interceptors: interceptors,
+      onRequestPart: self.receiveInterceptedPart(_:),
+      onResponsePart: self.sendInterceptedPart(_:promise:)
+    )
+  }
+
+  // MARK: - Public API: gRPC to Interceptors
+
+  @inlinable
+  public func receiveMetadata(_ metadata: HPACKHeaders) {
+    self.interceptors.receive(.metadata(metadata))
+  }
+
+  @inlinable
+  public func receiveMessage(_ bytes: ByteBuffer) {
+    do {
+      let message = try self.deserializer.deserialize(byteBuffer: bytes)
+      self.interceptors.receive(.message(message))
+    } catch {
+      self.handleError(error)
+    }
+  }
+
+  @inlinable
+  public func receiveEnd() {
+    self.interceptors.receive(.end)
+  }
+
+  @inlinable
+  public func receiveError(_ error: Error) {
+    self._finish(error: error)
+  }
+
+  @inlinable
+  public func finish() {
+    self._finish(error: nil)
+  }
+
+  @inlinable
+  internal func _finish(error: Error?) {
+    switch self.state {
+    case .idle:
+      self.interceptors = nil
+      self.state = .completed
+
+    case let .createdContext(context),
+         let .invokedFunction(context):
+      let error = error ?? GRPCStatus(code: .unavailable, message: nil)
+      context.responsePromise.fail(error)
+
+    case .completed:
+      self.interceptors = nil
+    }
+  }
+
+  // MARK: - Interceptors to User Function
+
+  @inlinable
+  internal func receiveInterceptedPart(_ part: GRPCServerRequestPart<Request>) {
+    switch part {
+    case let .metadata(headers):
+      self.receiveInterceptedMetadata(headers)
+    case let .message(message):
+      self.receiveInterceptedMessage(message)
+    case .end:
+      // Ignored.
+      ()
+    }
+  }
+
+  @inlinable
+  internal func receiveInterceptedMetadata(_ headers: HPACKHeaders) {
+    switch self.state {
+    case .idle:
+      // Make a context to invoke the user function with.
+      let context = UnaryResponseCallContext<Response>(
+        eventLoop: self.context.eventLoop,
+        headers: headers,
+        logger: self.context.logger,
+        userInfoRef: self.userInfoRef
+      )
+
+      // Move to the next state.
+      self.state = .createdContext(context)
+
+      // Register a callback on the response future. The user function will complete this promise.
+      context.responsePromise.futureResult.whenComplete(self.userFunctionCompletedWithResult(_:))
+
+      // Send back response headers.
+      self.interceptors.send(.metadata([:]), promise: nil)
+
+    case .createdContext, .invokedFunction:
+      self.handleError(GRPCError.InvalidState("Protocol violation: already received headers"))
+
+    case .completed:
+      // We may receive headers from the interceptor pipeline if we have already finished (i.e. due
+      // to an error or otherwise) and an interceptor doing some async work later emitting headers.
+      // Dropping them is fine.
+      ()
+    }
+  }
+
+  @inlinable
+  internal func receiveInterceptedMessage(_ request: Request) {
+    switch self.state {
+    case .idle:
+      self
+        .handleError(GRPCError.InvalidState("Protocol violation: message received before headers"))
+
+    case let .createdContext(context):
+      // Happy path: execute the function; complete the promise with the result.
+      self.state = .invokedFunction(context)
+      context.responsePromise.completeWith(self.userFunction(request, context))
+
+    case .invokedFunction:
+      // The function's already been invoked with a message.
+      self.handleError(GRPCError.InvalidState("Protocol violation: already received message"))
+
+    case .completed:
+      // We received a message but we're already done: this may happen if we terminate the RPC
+      // due to a channel error, for example.
+      ()
+    }
+  }
+
+  // MARK: - User Function To Interceptors
+
+  @inlinable
+  internal func userFunctionCompletedWithResult(_ result: Result<Response, Error>) {
+    switch self.state {
+    case .idle:
+      // Invalid state: the user function can only complete if it was executed.
+      preconditionFailure()
+
+    // 'created' is allowed here: we may have to (and tear down) after receiving headers
+    // but before receiving a message.
+    case let .createdContext(context),
+         let .invokedFunction(context):
+      self.state = .completed
+
+      switch result {
+      case let .success(response):
+        let metadata = MessageMetadata(compress: false, flush: false)
+        self.interceptors.send(.message(response, metadata), promise: nil)
+        self.interceptors.send(.end(context.responseStatus, context.trailers), promise: nil)
+
+      case let .failure(error):
+        let (status, trailers) = ServerErrorProcessor.processObserverError(
+          error,
+          headers: context.headers,
+          trailers: context.trailers,
+          delegate: self.context.errorDelegate
+        )
+        self.interceptors.send(.end(status, trailers), promise: nil)
+      }
+
+    case .completed:
+      // We've already failed. Ignore this.
+      ()
+    }
+  }
+
+  @inlinable
+  internal func sendInterceptedPart(
+    _ part: GRPCServerResponsePart<Response>,
+    promise: EventLoopPromise<Void>?
+  ) {
+    switch part {
+    case let .metadata(headers):
+      self.context.responseWriter.sendMetadata(headers, promise: promise)
+
+    case let .message(message, metadata):
+      do {
+        let bytes = try self.serializer.serialize(message, allocator: self.context.allocator)
+        self.context.responseWriter.sendMessage(bytes, metadata: metadata, promise: promise)
+      } catch {
+        // Serialization failed: fail the promise and send end.
+        promise?.fail(error)
+        let (status, trailers) = ServerErrorProcessor.processLibraryError(
+          error,
+          delegate: self.context.errorDelegate
+        )
+        // Loop back via the interceptors.
+        self.interceptors.send(.end(status, trailers), promise: nil)
+      }
+
+    case let .end(status, trailers):
+      self.context.responseWriter.sendEnd(status: status, trailers: trailers, promise: promise)
+    }
+  }
+
+  @inlinable
+  internal func handleError(_ error: Error) {
+    switch self.state {
+    case .idle:
+      // We don't have a promise to fail. Just send end back.
+      let (status, trailers) = ServerErrorProcessor.processLibraryError(
+        error,
+        delegate: self.context.errorDelegate
+      )
+      self.interceptors.send(.end(status, trailers), promise: nil)
+
+    case let .createdContext(context),
+         let .invokedFunction(context):
+      context.responsePromise.fail(error)
+
+    case .completed:
+      ()
+    }
+  }
+}

+ 24 - 1
Sources/GRPC/GRPCServerRequestRoutingHandler.swift

@@ -15,7 +15,9 @@
  */
 import Logging
 import NIO
+import NIOHPACK
 import NIOHTTP1
+import NIOHTTP2
 import SwiftProtobuf
 
 /// Processes individual gRPC messages and stream-close events on an HTTP2 channel.
@@ -34,6 +36,23 @@ public protocol CallHandlerProvider: AnyObject {
   /// method. Returns nil for methods not handled by this service.
   func handleMethod(_ methodName: Substring, callHandlerContext: CallHandlerContext)
     -> GRPCCallHandler?
+
+  /// Returns a call handler for the method with the given name, if this service provider implements
+  /// the given method. Returns `nil` if the method is not handled by this provider.
+  /// - Parameters:
+  ///   - name: The name of the method to handle.
+  ///   - context: An opaque context providing components to construct the handler with.
+  func handle(method name: Substring, context: CallHandlerContext) -> GRPCServerHandlerProtocol?
+}
+
+extension CallHandlerProvider {
+  // TODO: remove this once we've removed 'handleMethod(_:callHandlerContext:)'.
+  public func handle(
+    method name: Substring,
+    context: CallHandlerContext
+  ) -> GRPCServerHandlerProtocol? {
+    return nil
+  }
 }
 
 // This is public because it will be passed into generated code, all members are `internal` because
@@ -52,6 +71,10 @@ public struct CallHandlerContext {
   internal var path: String
   @usableFromInline
   internal var remoteAddress: SocketAddress?
+  @usableFromInline
+  internal var responseWriter: GRPCServerResponseWriter
+  @usableFromInline
+  internal var allocator: ByteBufferAllocator
 }
 
 /// A call URI split into components.
@@ -61,7 +84,7 @@ struct CallPath {
   /// The name of the method to call.
   var method: String.UTF8View.SubSequence
 
-  /// Charater used to split the path into components.
+  /// Character used to split the path into components.
   private let pathSplitDelimiter = UInt8(ascii: "/")
 
   /// Split a path into service and method.

+ 27 - 2
Sources/GRPC/HTTP2ToRawGRPCServerCodec.swift

@@ -18,7 +18,7 @@ import NIO
 import NIOHPACK
 import NIOHTTP2
 
-internal final class HTTP2ToRawGRPCServerCodec: ChannelDuplexHandler {
+internal final class HTTP2ToRawGRPCServerCodec: ChannelDuplexHandler, GRPCServerResponseWriter {
   typealias InboundIn = HTTP2Frame.FramePayload
   typealias InboundOut = GRPCServerRequestPart<ByteBuffer>
 
@@ -112,7 +112,9 @@ internal final class HTTP2ToRawGRPCServerCodec: ChannelDuplexHandler {
         eventLoop: context.eventLoop,
         errorDelegate: self.errorDelegate,
         remoteAddress: context.channel.remoteAddress,
-        logger: self.logger
+        logger: self.logger,
+        allocator: context.channel.allocator,
+        responseWriter: self
       )
       self.act(on: action, with: context)
 
@@ -162,4 +164,27 @@ internal final class HTTP2ToRawGRPCServerCodec: ChannelDuplexHandler {
 
     self.act(on: action, with: context)
   }
+
+  internal func sendMetadata(
+    _ metadata: HPACKHeaders,
+    promise: EventLoopPromise<Void>?
+  ) {
+    fatalError("TODO: not used yet")
+  }
+
+  internal func sendMessage(
+    _ bytes: ByteBuffer,
+    metadata: MessageMetadata,
+    promise: EventLoopPromise<Void>?
+  ) {
+    fatalError("TODO: not used yet")
+  }
+
+  internal func sendEnd(
+    status: GRPCStatus,
+    trailers: HPACKHeaders,
+    promise: EventLoopPromise<Void>?
+  ) {
+    fatalError("TODO: not used yet")
+  }
 }

+ 18 - 6
Sources/GRPC/HTTP2ToRawGRPCStateMachine.swift

@@ -279,7 +279,9 @@ extension HTTP2ToRawGRPCStateMachine.RequestIdleResponseIdleState {
     eventLoop: EventLoop,
     errorDelegate: ServerErrorDelegate?,
     remoteAddress: SocketAddress?,
-    logger: Logger
+    logger: Logger,
+    allocator: ByteBufferAllocator,
+    responseWriter: GRPCServerResponseWriter
   ) -> HTTP2ToRawGRPCStateMachine.StateAndAction {
     // Extract and validate the content type. If it's nil we need to close.
     guard let contentType = self.extractContentType(from: headers) else {
@@ -325,7 +327,9 @@ extension HTTP2ToRawGRPCStateMachine.RequestIdleResponseIdleState {
       encoding: self.encoding,
       eventLoop: eventLoop,
       path: path,
-      remoteAddress: remoteAddress
+      remoteAddress: remoteAddress,
+      responseWriter: responseWriter,
+      allocator: allocator
     )
 
     // We have a matching service, hopefully we have a provider for the method too.
@@ -886,7 +890,9 @@ extension HTTP2ToRawGRPCStateMachine {
     eventLoop: EventLoop,
     errorDelegate: ServerErrorDelegate?,
     remoteAddress: SocketAddress?,
-    logger: Logger
+    logger: Logger,
+    allocator: ByteBufferAllocator,
+    responseWriter: GRPCServerResponseWriter
   ) -> Action {
     return self.withStateAvoidingCoWs { state in
       state.receive(
@@ -894,7 +900,9 @@ extension HTTP2ToRawGRPCStateMachine {
         eventLoop: eventLoop,
         errorDelegate: errorDelegate,
         remoteAddress: remoteAddress,
-        logger: logger
+        logger: logger,
+        allocator: allocator,
+        responseWriter: responseWriter
       )
     }
   }
@@ -982,7 +990,9 @@ extension HTTP2ToRawGRPCStateMachine.State {
     eventLoop: EventLoop,
     errorDelegate: ServerErrorDelegate?,
     remoteAddress: SocketAddress?,
-    logger: Logger
+    logger: Logger,
+    allocator: ByteBufferAllocator,
+    responseWriter: GRPCServerResponseWriter
   ) -> HTTP2ToRawGRPCStateMachine.Action {
     switch self {
     // This is the only state in which we can receive headers. Everything else is invalid.
@@ -992,7 +1002,9 @@ extension HTTP2ToRawGRPCStateMachine.State {
         eventLoop: eventLoop,
         errorDelegate: errorDelegate,
         remoteAddress: remoteAddress,
-        logger: logger
+        logger: logger,
+        allocator: allocator,
+        responseWriter: responseWriter
       )
       self = stateAndAction.state
       return stateAndAction.action

+ 3 - 1
Sources/GRPC/ServerCallContexts/ServerCallContext.swift

@@ -61,7 +61,8 @@ open class ServerCallContextBase: ServerCallContext {
   }
 
   /// A reference to an underlying `UserInfo`. We share this with the interceptors.
-  private let userInfoRef: Ref<UserInfo>
+  @usableFromInline
+  internal let userInfoRef: Ref<UserInfo>
 
   /// Metadata to return at the end of the RPC. If this is required it should be updated before
   /// the `responsePromise` or `statusPromise` is fulfilled.
@@ -76,6 +77,7 @@ open class ServerCallContextBase: ServerCallContext {
     self.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: .init(userInfo))
   }
 
+  @inlinable
   internal init(
     eventLoop: EventLoop,
     headers: HPACKHeaders,

+ 1 - 0
Sources/GRPC/ServerCallContexts/UnaryResponseCallContext.swift

@@ -44,6 +44,7 @@ open class UnaryResponseCallContext<ResponsePayload>: ServerCallContextBase, Sta
     self.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: .init(userInfo))
   }
 
+  @inlinable
   override internal init(
     eventLoop: EventLoop,
     headers: HPACKHeaders,

+ 93 - 0
Sources/GRPC/ServerErrorProcessor.swift

@@ -0,0 +1,93 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import NIOHPACK
+
+@usableFromInline
+internal enum ServerErrorProcessor {
+  /// Processes a library error to form a `GRPCStatus` and trailers to send back to the client.
+  /// - Parameter error: The error to process.
+  /// - Returns: The status and trailers to send to the client.
+  @usableFromInline
+  internal static func processLibraryError(
+    _ error: Error,
+    delegate: ServerErrorDelegate?
+  ) -> (GRPCStatus, HPACKHeaders) {
+    // Observe the error if we have a delegate.
+    delegate?.observeLibraryError(error)
+
+    // What status are we terminating this RPC with?
+    // - If we have a delegate, try transforming the error. If the delegate returns trailers, merge
+    //   them with any on the call context.
+    // - If we don't have a delegate, then try to transform the error to a status.
+    // - Fallback to a generic error.
+    let status: GRPCStatus
+    let trailers: HPACKHeaders
+
+    if let transformed = delegate?.transformLibraryError(error) {
+      status = transformed.status
+      trailers = transformed.trailers ?? [:]
+    } else if let grpcStatusTransformable = error as? GRPCStatusTransformable {
+      status = grpcStatusTransformable.makeGRPCStatus()
+      trailers = [:]
+    } else {
+      // Eh... well, we don't what status to use. Use a generic one.
+      status = .processingError
+      trailers = [:]
+    }
+
+    return (status, trailers)
+  }
+
+  /// Processes an error, transforming it into a 'GRPCStatus' and any trailers to send to the peer.
+  @usableFromInline
+  internal static func processObserverError(
+    _ error: Error,
+    headers: HPACKHeaders,
+    trailers: HPACKHeaders,
+    delegate: ServerErrorDelegate?
+  ) -> (GRPCStatus, HPACKHeaders) {
+    // Observe the error if we have a delegate.
+    delegate?.observeRequestHandlerError(error, headers: headers)
+
+    // What status are we terminating this RPC with?
+    // - If we have a delegate, try transforming the error. If the delegate returns trailers, merge
+    //   them with any on the call context.
+    // - If we don't have a delegate, then try to transform the error to a status.
+    // - Fallback to a generic error.
+    let status: GRPCStatus
+    let mergedTrailers: HPACKHeaders
+
+    if let transformed = delegate?.transformRequestHandlerError(error, headers: headers) {
+      status = transformed.status
+      if var transformedTrailers = transformed.trailers {
+        // The delegate returned trailers: merge in those from the context as well.
+        transformedTrailers.add(contentsOf: trailers)
+        mergedTrailers = transformedTrailers
+      } else {
+        mergedTrailers = trailers
+      }
+    } else if let grpcStatusTransformable = error as? GRPCStatusTransformable {
+      status = grpcStatusTransformable.makeGRPCStatus()
+      mergedTrailers = trailers
+    } else {
+      // Eh... well, we don't what status to use. Use a generic one.
+      status = .processingError
+      mergedTrailers = trailers
+    }
+
+    return (status, mergedTrailers)
+  }
+}

+ 18 - 2
Tests/GRPCTests/ClientTransportTests.swift

@@ -346,7 +346,7 @@ class WriteRecorder<Write>: ChannelOutboundHandler {
 
 private struct DummyError: Error {}
 
-private struct StringSerializer: MessageSerializer {
+internal struct StringSerializer: MessageSerializer {
   typealias Input = String
 
   func serialize(_ input: String, allocator: ByteBufferAllocator) throws -> ByteBuffer {
@@ -354,7 +354,7 @@ private struct StringSerializer: MessageSerializer {
   }
 }
 
-private struct StringDeserializer: MessageDeserializer {
+internal struct StringDeserializer: MessageDeserializer {
   typealias Output = String
 
   func deserialize(byteBuffer: ByteBuffer) throws -> String {
@@ -362,3 +362,19 @@ private struct StringDeserializer: MessageDeserializer {
     return buffer.readString(length: buffer.readableBytes)!
   }
 }
+
+internal struct ThrowingStringSerializer: MessageSerializer {
+  typealias Input = String
+
+  func serialize(_ input: String, allocator: ByteBufferAllocator) throws -> ByteBuffer {
+    throw DummyError()
+  }
+}
+
+internal struct ThrowingStringDeserializer: MessageDeserializer {
+  typealias Output = String
+
+  func deserialize(byteBuffer: ByteBuffer) throws -> String {
+    throw DummyError()
+  }
+}

+ 54 - 12
Tests/GRPCTests/HTTP2ToRawGRPCStateMachineTests.swift

@@ -101,7 +101,9 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       eventLoop: self.eventLoop,
       errorDelegate: nil,
       remoteAddress: nil,
-      logger: self.logger
+      logger: self.logger,
+      allocator: ByteBufferAllocator(),
+      responseWriter: NoOpResponseWriter()
     )
 
     assertThat(receiveHeadersAction, .is(.configure()))
@@ -171,7 +173,9 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       eventLoop: self.eventLoop,
       errorDelegate: nil,
       remoteAddress: nil,
-      logger: self.logger
+      logger: self.logger,
+      allocator: ByteBufferAllocator(),
+      responseWriter: NoOpResponseWriter()
     )
     assertThat(action, .is(.configure()))
   }
@@ -183,7 +187,9 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       eventLoop: self.eventLoop,
       errorDelegate: nil,
       remoteAddress: nil,
-      logger: self.logger
+      logger: self.logger,
+      allocator: ByteBufferAllocator(),
+      responseWriter: NoOpResponseWriter()
     )
     assertThat(
       action,
@@ -198,7 +204,9 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       eventLoop: self.eventLoop,
       errorDelegate: nil,
       remoteAddress: nil,
-      logger: self.logger
+      logger: self.logger,
+      allocator: ByteBufferAllocator(),
+      responseWriter: NoOpResponseWriter()
     )
     assertThat(action, .is(.write(.trailersOnly(code: .unimplemented), flush: true)))
   }
@@ -210,7 +218,9 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       eventLoop: self.eventLoop,
       errorDelegate: nil,
       remoteAddress: nil,
-      logger: self.logger
+      logger: self.logger,
+      allocator: ByteBufferAllocator(),
+      responseWriter: NoOpResponseWriter()
     )
     assertThat(action, .is(.write(.trailersOnly(code: .unimplemented), flush: true)))
   }
@@ -222,7 +232,9 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       eventLoop: self.eventLoop,
       errorDelegate: nil,
       remoteAddress: nil,
-      logger: self.logger
+      logger: self.logger,
+      allocator: ByteBufferAllocator(),
+      responseWriter: NoOpResponseWriter()
     )
     assertThat(action, .is(.write(.trailersOnly(code: .unimplemented), flush: true)))
   }
@@ -234,7 +246,9 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       eventLoop: self.eventLoop,
       errorDelegate: nil,
       remoteAddress: nil,
-      logger: self.logger
+      logger: self.logger,
+      allocator: ByteBufferAllocator(),
+      responseWriter: NoOpResponseWriter()
     )
     assertThat(action, .is(.write(.trailersOnly(code: .unimplemented), flush: true)))
   }
@@ -247,7 +261,9 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       eventLoop: self.eventLoop,
       errorDelegate: nil,
       remoteAddress: nil,
-      logger: self.logger
+      logger: self.logger,
+      allocator: ByteBufferAllocator(),
+      responseWriter: NoOpResponseWriter()
     )
     assertThat(action, .is(.write(.trailersOnly(code: .invalidArgument), flush: true)))
   }
@@ -260,7 +276,9 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       eventLoop: self.eventLoop,
       errorDelegate: nil,
       remoteAddress: nil,
-      logger: self.logger
+      logger: self.logger,
+      allocator: ByteBufferAllocator(),
+      responseWriter: NoOpResponseWriter()
     )
 
     assertThat(action, .is(.write(.trailersOnly(code: .unimplemented), flush: true)))
@@ -279,7 +297,9 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       eventLoop: self.eventLoop,
       errorDelegate: nil,
       remoteAddress: nil,
-      logger: self.logger
+      logger: self.logger,
+      allocator: ByteBufferAllocator(),
+      responseWriter: NoOpResponseWriter()
     )
 
     // This is expected: however, we also expect 'grpc-accept-encoding' to be in the response
@@ -301,7 +321,9 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       eventLoop: self.eventLoop,
       errorDelegate: nil,
       remoteAddress: nil,
-      logger: self.logger
+      logger: self.logger,
+      allocator: ByteBufferAllocator(),
+      responseWriter: NoOpResponseWriter()
     )
 
     assertThat(action, .is(.configure()))
@@ -315,7 +337,9 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       eventLoop: self.eventLoop,
       errorDelegate: nil,
       remoteAddress: nil,
-      logger: self.logger
+      logger: self.logger,
+      allocator: ByteBufferAllocator(),
+      responseWriter: NoOpResponseWriter()
     )
 
     // This is expected, but we need to check the value of 'grpc-encoding' in the response headers.
@@ -640,3 +664,21 @@ extension ServerMessageEncoding {
     return .enabled(.init(enabledAlgorithms: algorithms, decompressionLimit: .absolute(.max)))
   }
 }
+
+class NoOpResponseWriter: GRPCServerResponseWriter {
+  func sendMetadata(_ metadata: HPACKHeaders, promise: EventLoopPromise<Void>?) {
+    promise?.succeed(())
+  }
+
+  func sendMessage(
+    _ bytes: ByteBuffer,
+    metadata: MessageMetadata,
+    promise: EventLoopPromise<Void>?
+  ) {
+    promise?.succeed(())
+  }
+
+  func sendEnd(status: GRPCStatus, trailers: HPACKHeaders, promise: EventLoopPromise<Void>?) {
+    promise?.succeed(())
+  }
+}

+ 3 - 1
Tests/GRPCTests/ServerInterceptorTests.swift

@@ -46,7 +46,9 @@ class ServerInterceptorTests: GRPCTestCase {
       logger: self.serverLogger,
       encoding: .disabled,
       eventLoop: self.channel.eventLoop,
-      path: path
+      path: path,
+      responseWriter: NoOpResponseWriter(),
+      allocator: ByteBufferAllocator()
     )
   }
 

+ 253 - 0
Tests/GRPCTests/UnaryServerHandlerTests.swift

@@ -0,0 +1,253 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+@testable import GRPC
+import NIO
+import NIOHPACK
+import XCTest
+
+final class ResponseRecorder: GRPCServerResponseWriter {
+  var metadata: HPACKHeaders?
+  var messages: [ByteBuffer] = []
+  var status: GRPCStatus?
+  var trailers: HPACKHeaders?
+
+  func sendMetadata(_ metadata: HPACKHeaders, promise: EventLoopPromise<Void>?) {
+    XCTAssertNil(self.metadata)
+    self.metadata = metadata
+    promise?.succeed(())
+  }
+
+  func sendMessage(
+    _ bytes: ByteBuffer,
+    metadata: MessageMetadata,
+    promise: EventLoopPromise<Void>?
+  ) {
+    self.messages.append(bytes)
+    promise?.succeed(())
+  }
+
+  func sendEnd(status: GRPCStatus, trailers: HPACKHeaders, promise: EventLoopPromise<Void>?) {
+    XCTAssertNil(self.status)
+    XCTAssertNil(self.trailers)
+    self.status = status
+    self.trailers = trailers
+    promise?.succeed(())
+  }
+}
+
+class UnaryServerHandlerTests: GRPCTestCase {
+  let eventLoop = EmbeddedEventLoop()
+  let allocator = ByteBufferAllocator()
+
+  private func makeCallHandlerContext(writer: GRPCServerResponseWriter) -> CallHandlerContext {
+    return CallHandlerContext(
+      errorDelegate: nil,
+      logger: self.logger,
+      encoding: .disabled,
+      eventLoop: self.eventLoop,
+      path: "/ignored",
+      remoteAddress: nil,
+      responseWriter: writer,
+      allocator: self.allocator
+    )
+  }
+
+  private func makeHandler(
+    writer: GRPCServerResponseWriter,
+    function: @escaping (String, StatusOnlyCallContext) -> EventLoopFuture<String>
+  ) -> UnaryServerHandler<StringSerializer, StringDeserializer> {
+    return UnaryServerHandler(
+      context: self.makeCallHandlerContext(writer: writer),
+      requestDeserializer: StringDeserializer(),
+      responseSerializer: StringSerializer(),
+      interceptors: [],
+      userFunction: function
+    )
+  }
+
+  private func echo(_ request: String, context: StatusOnlyCallContext) -> EventLoopFuture<String> {
+    return context.eventLoop.makeSucceededFuture(request)
+  }
+
+  private func neverComplete(
+    _ request: String,
+    context: StatusOnlyCallContext
+  ) -> EventLoopFuture<String> {
+    let scheduled = context.eventLoop.scheduleTask(deadline: .distantFuture) {
+      return request
+    }
+    return scheduled.futureResult
+  }
+
+  private func neverCalled(
+    _ request: String,
+    context: StatusOnlyCallContext
+  ) -> EventLoopFuture<String> {
+    XCTFail("Unexpected function invocation")
+    return context.eventLoop.makeFailedFuture(GRPCError.InvalidState(""))
+  }
+
+  func testHappyPath() {
+    let recorder = ResponseRecorder()
+    let handler = self.makeHandler(writer: recorder, function: self.echo(_:context:))
+
+    handler.receiveMetadata([:])
+    assertThat(recorder.metadata, .is([:]))
+
+    let buffer = ByteBuffer(string: "hello")
+    handler.receiveMessage(buffer)
+    handler.receiveEnd()
+    handler.finish()
+
+    assertThat(recorder.messages.first, .is(buffer))
+    assertThat(recorder.status, .notNil(.hasCode(.ok)))
+    assertThat(recorder.trailers, .is([:]))
+  }
+
+  func testThrowingDeserializer() {
+    let recorder = ResponseRecorder()
+    let handler = UnaryServerHandler(
+      context: self.makeCallHandlerContext(writer: recorder),
+      requestDeserializer: ThrowingStringDeserializer(),
+      responseSerializer: StringSerializer(),
+      interceptors: [],
+      userFunction: self.neverCalled(_:context:)
+    )
+
+    handler.receiveMetadata([:])
+    assertThat(recorder.metadata, .is([:]))
+
+    let buffer = ByteBuffer(string: "hello")
+    handler.receiveMessage(buffer)
+
+    assertThat(recorder.messages, .isEmpty())
+    assertThat(recorder.status, .notNil(.hasCode(.internalError)))
+  }
+
+  func testThrowingSerializer() {
+    let recorder = ResponseRecorder()
+    let handler = UnaryServerHandler(
+      context: self.makeCallHandlerContext(writer: recorder),
+      requestDeserializer: StringDeserializer(),
+      responseSerializer: ThrowingStringSerializer(),
+      interceptors: [],
+      userFunction: self.echo(_:context:)
+    )
+
+    handler.receiveMetadata([:])
+    assertThat(recorder.metadata, .is([:]))
+
+    let buffer = ByteBuffer(string: "hello")
+    handler.receiveMessage(buffer)
+    handler.receiveEnd()
+
+    assertThat(recorder.messages, .isEmpty())
+    assertThat(recorder.status, .notNil(.hasCode(.internalError)))
+  }
+
+  func testUserFunctionReturnsFailedFuture() {
+    let recorder = ResponseRecorder()
+    let handler = self.makeHandler(writer: recorder) { _, context in
+      return context.eventLoop.makeFailedFuture(GRPCStatus(code: .unavailable, message: ":("))
+    }
+
+    handler.receiveMetadata([:])
+    assertThat(recorder.metadata, .is([:]))
+
+    let buffer = ByteBuffer(string: "hello")
+    handler.receiveMessage(buffer)
+
+    assertThat(recorder.messages, .isEmpty())
+    assertThat(recorder.status, .notNil(.hasCode(.unavailable)))
+    assertThat(recorder.status?.message, .is(":("))
+  }
+
+  func testReceiveMessageBeforeHeaders() {
+    let recorder = ResponseRecorder()
+    let handler = self.makeHandler(writer: recorder, function: self.neverCalled(_:context:))
+
+    handler.receiveMessage(ByteBuffer(string: "foo"))
+    assertThat(recorder.metadata, .is(.nil()))
+    assertThat(recorder.messages, .isEmpty())
+    assertThat(recorder.status, .notNil(.hasCode(.internalError)))
+  }
+
+  func testReceiveMultipleHeaders() {
+    let recorder = ResponseRecorder()
+    let handler = self.makeHandler(writer: recorder, function: self.neverCalled(_:context:))
+
+    handler.receiveMetadata([:])
+    assertThat(recorder.metadata, .is([:]))
+
+    handler.receiveMetadata([:])
+    assertThat(recorder.messages, .isEmpty())
+    assertThat(recorder.status, .notNil(.hasCode(.internalError)))
+  }
+
+  func testReceiveMultipleMessages() {
+    let recorder = ResponseRecorder()
+    let handler = self.makeHandler(writer: recorder, function: self.neverComplete(_:context:))
+
+    handler.receiveMetadata([:])
+    assertThat(recorder.metadata, .is([:]))
+
+    let buffer = ByteBuffer(string: "hello")
+    handler.receiveMessage(buffer)
+    handler.receiveEnd()
+    // Send another message before the function completes.
+    handler.receiveMessage(buffer)
+
+    assertThat(recorder.messages, .isEmpty())
+    assertThat(recorder.status, .notNil(.hasCode(.internalError)))
+  }
+
+  func testFinishBeforeStarting() {
+    let recorder = ResponseRecorder()
+    let handler = self.makeHandler(writer: recorder, function: self.neverCalled(_:context:))
+
+    handler.finish()
+    assertThat(recorder.metadata, .is(.nil()))
+    assertThat(recorder.messages, .isEmpty())
+    assertThat(recorder.status, .is(.nil()))
+    assertThat(recorder.trailers, .is(.nil()))
+  }
+
+  func testFinishAfterHeaders() {
+    let recorder = ResponseRecorder()
+    let handler = self.makeHandler(writer: recorder, function: self.neverCalled(_:context:))
+    handler.receiveMetadata([:])
+    assertThat(recorder.metadata, .is([:]))
+
+    handler.finish()
+
+    assertThat(recorder.messages, .isEmpty())
+    assertThat(recorder.status, .notNil(.hasCode(.unavailable)))
+    assertThat(recorder.trailers, .is([:]))
+  }
+
+  func testFinishAfterMessage() {
+    let recorder = ResponseRecorder()
+    let handler = self.makeHandler(writer: recorder, function: self.neverComplete(_:context:))
+
+    handler.receiveMetadata([:])
+    handler.receiveMessage(ByteBuffer(string: "hello"))
+    handler.finish()
+
+    assertThat(recorder.messages, .isEmpty())
+    assertThat(recorder.status, .notNil(.hasCode(.unavailable)))
+    assertThat(recorder.trailers, .is([:]))
+  }
+}