Browse Source

Move server (de)serialization to the base call handler (#1046)

Motivation:

Recently, when experimenting I found that found that doing
serialization/deserializion in the server call handler (rather than a
separate codec) led to a non-trivial performance gain.

Modifications:

- Make the `_BaseCallHandler` and its subclasses generic over the
  serializer and deserializer and do (de)serialization there rather than
  in a separate codec handler.
- Update the call handler factory - which provides convenience methods
  used by generated code. While this is a breaking change, it remains
  source compatible, so code is not required to be regenerated.
- Remove the requirement that a 'GRPCCallHandler' provide a codec
- Remove the server codec handler and its tests
- Make MessageSerializer and MessageDeserializer and their
  implementations public

Result:

A 6% gain in QPS benchmarks (5.8% for unary, 6.4% for bidirectional
streaming, average across 3 runs)
George Barnett 5 years ago
parent
commit
2ce52e04a4

+ 10 - 9
Sources/GRPC/CallHandlers/BidirectionalStreamingCallHandler.swift

@@ -26,9 +26,9 @@ import SwiftProtobuf
 ///   they can fail the observer block future.
 /// - To close the call and send the status, complete `context.statusPromise`.
 public class BidirectionalStreamingCallHandler<
-  RequestPayload,
-  ResponsePayload
->: _BaseCallHandler<RequestPayload, ResponsePayload> {
+  RequestDeserializer: MessageDeserializer,
+  ResponseSerializer: MessageSerializer
+>: _BaseCallHandler<RequestDeserializer, ResponseSerializer> {
   private typealias Context = StreamingResponseCallContext<ResponsePayload>
   private typealias Observer = EventLoopFuture<(StreamEvent<RequestPayload>) -> Void>
 
@@ -44,18 +44,19 @@ public class BidirectionalStreamingCallHandler<
 
   // We ask for a future of type `EventObserver` to allow the framework user to e.g. asynchronously authenticate a call.
   // If authentication fails, they can simply fail the observer future, which causes the call to be terminated.
-  internal init<Serializer: MessageSerializer, Deserializer: MessageDeserializer>(
-    serializer: Serializer,
-    deserializer: Deserializer,
+  internal init(
+    serializer: ResponseSerializer,
+    deserializer: RequestDeserializer,
     callHandlerContext: CallHandlerContext,
-    interceptors: [ServerInterceptor<Deserializer.Output, Serializer.Input>],
+    interceptors: [ServerInterceptor<RequestDeserializer.Output, ResponseSerializer.Input>],
     eventObserverFactory: @escaping (StreamingResponseCallContext<ResponsePayload>)
       -> EventLoopFuture<(StreamEvent<RequestPayload>) -> Void>
-  ) where Serializer.Input == ResponsePayload, Deserializer.Output == RequestPayload {
+  ) {
     self.state = .requestIdleResponseIdle(eventObserverFactory)
     super.init(
       callHandlerContext: callHandlerContext,
-      codec: GRPCServerCodecHandler(serializer: serializer, deserializer: deserializer),
+      requestDeserializr: deserializer,
+      responseSerializer: serializer,
       callType: .bidirectionalStreaming,
       interceptors: interceptors
     )

+ 20 - 8
Sources/GRPC/CallHandlers/CallHandlerFactory.swift

@@ -27,7 +27,7 @@ public enum CallHandlerFactory {
     interceptors: [ServerInterceptor<Request, Response>] = [],
     eventObserverFactory: @escaping (UnaryContext<Response>)
       -> UnaryEventObserver<Request, Response>
-  ) -> UnaryCallHandler<Request, Response> {
+  ) -> UnaryCallHandler<ProtobufDeserializer<Request>, ProtobufSerializer<Response>> {
     return UnaryCallHandler(
       serializer: ProtobufSerializer(),
       deserializer: ProtobufDeserializer(),
@@ -42,7 +42,7 @@ public enum CallHandlerFactory {
     interceptors: [ServerInterceptor<Request, Response>] = [],
     eventObserverFactory: @escaping (UnaryContext<Response>)
       -> UnaryEventObserver<Request, Response>
-  ) -> UnaryCallHandler<Request, Response> {
+  ) -> UnaryCallHandler<GRPCPayloadDeserializer<Request>, GRPCPayloadSerializer<Response>> {
     return UnaryCallHandler(
       serializer: GRPCPayloadSerializer(),
       deserializer: GRPCPayloadDeserializer(),
@@ -61,7 +61,7 @@ public enum CallHandlerFactory {
     interceptors: [ServerInterceptor<Request, Response>] = [],
     eventObserverFactory: @escaping (ClientStreamingContext<Response>)
       -> ClientStreamingEventObserver<Request>
-  ) -> ClientStreamingCallHandler<Request, Response> {
+  ) -> ClientStreamingCallHandler<ProtobufDeserializer<Request>, ProtobufSerializer<Response>> {
     return ClientStreamingCallHandler(
       serializer: ProtobufSerializer(),
       deserializer: ProtobufDeserializer(),
@@ -76,7 +76,10 @@ public enum CallHandlerFactory {
     interceptors: [ServerInterceptor<Request, Response>] = [],
     eventObserverFactory: @escaping (ClientStreamingContext<Response>)
       -> ClientStreamingEventObserver<Request>
-  ) -> ClientStreamingCallHandler<Request, Response> {
+  ) -> ClientStreamingCallHandler<
+    GRPCPayloadDeserializer<Request>,
+    GRPCPayloadSerializer<Response>
+  > {
     return ClientStreamingCallHandler(
       serializer: GRPCPayloadSerializer(),
       deserializer: GRPCPayloadDeserializer(),
@@ -94,7 +97,7 @@ public enum CallHandlerFactory {
     interceptors: [ServerInterceptor<Request, Response>] = [],
     eventObserverFactory: @escaping (ServerStreamingContext<Response>)
       -> ServerStreamingEventObserver<Request>
-  ) -> ServerStreamingCallHandler<Request, Response> {
+  ) -> ServerStreamingCallHandler<ProtobufDeserializer<Request>, ProtobufSerializer<Response>> {
     return ServerStreamingCallHandler(
       serializer: ProtobufSerializer(),
       deserializer: ProtobufDeserializer(),
@@ -109,7 +112,10 @@ public enum CallHandlerFactory {
     interceptors: [ServerInterceptor<Request, Response>] = [],
     eventObserverFactory: @escaping (ServerStreamingContext<Response>)
       -> ServerStreamingEventObserver<Request>
-  ) -> ServerStreamingCallHandler<Request, Response> {
+  ) -> ServerStreamingCallHandler<
+    GRPCPayloadDeserializer<Request>,
+    GRPCPayloadSerializer<Response>
+  > {
     return ServerStreamingCallHandler(
       serializer: GRPCPayloadSerializer(),
       deserializer: GRPCPayloadDeserializer(),
@@ -128,7 +134,10 @@ public enum CallHandlerFactory {
     interceptors: [ServerInterceptor<Request, Response>] = [],
     eventObserverFactory: @escaping (BidirectionalStreamingContext<Response>)
       -> BidirectionalStreamingEventObserver<Request>
-  ) -> BidirectionalStreamingCallHandler<Request, Response> {
+  ) -> BidirectionalStreamingCallHandler<
+    ProtobufDeserializer<Request>,
+    ProtobufSerializer<Response>
+  > {
     return BidirectionalStreamingCallHandler(
       serializer: ProtobufSerializer(),
       deserializer: ProtobufDeserializer(),
@@ -143,7 +152,10 @@ public enum CallHandlerFactory {
     interceptors: [ServerInterceptor<Request, Response>] = [],
     eventObserverFactory: @escaping (BidirectionalStreamingContext<Response>)
       -> BidirectionalStreamingEventObserver<Request>
-  ) -> BidirectionalStreamingCallHandler<Request, Response> {
+  ) -> BidirectionalStreamingCallHandler<
+    GRPCPayloadDeserializer<Request>,
+    GRPCPayloadSerializer<Response>
+  > {
     return BidirectionalStreamingCallHandler(
       serializer: GRPCPayloadSerializer(),
       deserializer: GRPCPayloadDeserializer(),

+ 10 - 9
Sources/GRPC/CallHandlers/ClientStreamingCallHandler.swift

@@ -27,9 +27,9 @@ import SwiftProtobuf
 ///   they can fail the observer block future.
 /// - To close the call and send the response, complete `context.responsePromise`.
 public final class ClientStreamingCallHandler<
-  RequestPayload,
-  ResponsePayload
->: _BaseCallHandler<RequestPayload, ResponsePayload> {
+  RequestDeserializer: MessageDeserializer,
+  ResponseSerializer: MessageSerializer
+>: _BaseCallHandler<RequestDeserializer, ResponseSerializer> {
   private typealias Context = UnaryResponseCallContext<ResponsePayload>
   private typealias Observer = EventLoopFuture<(StreamEvent<RequestPayload>) -> Void>
 
@@ -45,18 +45,19 @@ public final class ClientStreamingCallHandler<
 
   // We ask for a future of type `EventObserver` to allow the framework user to e.g. asynchronously authenticate a call.
   // If authentication fails, they can simply fail the observer future, which causes the call to be terminated.
-  internal init<Serializer: MessageSerializer, Deserializer: MessageDeserializer>(
-    serializer: Serializer,
-    deserializer: Deserializer,
+  internal init(
+    serializer: ResponseSerializer,
+    deserializer: RequestDeserializer,
     callHandlerContext: CallHandlerContext,
-    interceptors: [ServerInterceptor<Deserializer.Output, Serializer.Input>],
+    interceptors: [ServerInterceptor<RequestDeserializer.Output, ResponseSerializer.Input>],
     eventObserverFactory: @escaping (UnaryResponseCallContext<ResponsePayload>)
       -> EventLoopFuture<(StreamEvent<RequestPayload>) -> Void>
-  ) where Serializer.Input == ResponsePayload, Deserializer.Output == RequestPayload {
+  ) {
     self.state = .requestIdleResponseIdle(eventObserverFactory)
     super.init(
       callHandlerContext: callHandlerContext,
-      codec: GRPCServerCodecHandler(serializer: serializer, deserializer: deserializer),
+      requestDeserializr: deserializer,
+      responseSerializer: serializer,
       callType: .clientStreaming,
       interceptors: interceptors
     )

+ 10 - 9
Sources/GRPC/CallHandlers/ServerStreamingCallHandler.swift

@@ -25,9 +25,9 @@ import SwiftProtobuf
 /// - The observer block is implemented by the framework user and calls `context.sendResponse` as needed.
 /// - To close the call and send the status, complete the status future returned by the observer block.
 public final class ServerStreamingCallHandler<
-  RequestPayload,
-  ResponsePayload
->: _BaseCallHandler<RequestPayload, ResponsePayload> {
+  RequestDeserializer: MessageDeserializer,
+  ResponseSerializer: MessageSerializer
+>: _BaseCallHandler<RequestDeserializer, ResponseSerializer> {
   private typealias Context = StreamingResponseCallContext<ResponsePayload>
   private typealias Observer = (RequestPayload) -> EventLoopFuture<GRPCStatus>
 
@@ -46,18 +46,19 @@ public final class ServerStreamingCallHandler<
     }
   }
 
-  internal init<Serializer: MessageSerializer, Deserializer: MessageDeserializer>(
-    serializer: Serializer,
-    deserializer: Deserializer,
+  internal init(
+    serializer: ResponseSerializer,
+    deserializer: RequestDeserializer,
     callHandlerContext: CallHandlerContext,
-    interceptors: [ServerInterceptor<Deserializer.Output, Serializer.Input>],
+    interceptors: [ServerInterceptor<RequestDeserializer.Output, ResponseSerializer.Input>],
     eventObserverFactory: @escaping (StreamingResponseCallContext<ResponsePayload>)
       -> (RequestPayload) -> EventLoopFuture<GRPCStatus>
-  ) where Serializer.Input == ResponsePayload, Deserializer.Output == RequestPayload {
+  ) {
     self.state = .requestIdleResponseIdle(eventObserverFactory)
     super.init(
       callHandlerContext: callHandlerContext,
-      codec: GRPCServerCodecHandler(serializer: serializer, deserializer: deserializer),
+      requestDeserializr: deserializer,
+      responseSerializer: serializer,
       callType: .serverStreaming,
       interceptors: interceptors
     )

+ 10 - 9
Sources/GRPC/CallHandlers/UnaryCallHandler.swift

@@ -26,9 +26,9 @@ import SwiftProtobuf
 /// - To return a response to the client, the framework user should complete that future
 ///   (similar to e.g. serving regular HTTP requests in frameworks such as Vapor).
 public final class UnaryCallHandler<
-  RequestPayload,
-  ResponsePayload
->: _BaseCallHandler<RequestPayload, ResponsePayload> {
+  RequestDeserializer: MessageDeserializer,
+  ResponseSerializer: MessageSerializer
+>: _BaseCallHandler<RequestDeserializer, ResponseSerializer> {
   private typealias Context = UnaryResponseCallContext<ResponsePayload>
   private typealias Observer = (RequestPayload) -> EventLoopFuture<ResponsePayload>
 
@@ -69,18 +69,19 @@ public final class UnaryCallHandler<
     }
   }
 
-  internal init<Serializer: MessageSerializer, Deserializer: MessageDeserializer>(
-    serializer: Serializer,
-    deserializer: Deserializer,
+  internal init(
+    serializer: ResponseSerializer,
+    deserializer: RequestDeserializer,
     callHandlerContext: CallHandlerContext,
-    interceptors: [ServerInterceptor<Deserializer.Output, Serializer.Input>],
+    interceptors: [ServerInterceptor<RequestDeserializer.Output, ResponseSerializer.Input>],
     eventObserverFactory: @escaping (UnaryResponseCallContext<ResponsePayload>)
       -> (RequestPayload) -> EventLoopFuture<ResponsePayload>
-  ) where Serializer.Input == ResponsePayload, Deserializer.Output == RequestPayload {
+  ) {
     self.state = .requestIdleResponseIdle(eventObserverFactory)
     super.init(
       callHandlerContext: callHandlerContext,
-      codec: GRPCServerCodecHandler(serializer: serializer, deserializer: deserializer),
+      requestDeserializr: deserializer,
+      responseSerializer: serializer,
       callType: .unary,
       interceptors: interceptors
     )

+ 72 - 33
Sources/GRPC/CallHandlers/_BaseCallHandler.swift

@@ -23,14 +23,18 @@ import SwiftProtobuf
 ///
 /// Calls through to `processMessage` for individual messages it receives, which needs to be implemented by subclasses.
 /// - Important: This is **NOT** part of the public API.
-public class _BaseCallHandler<Request, Response>: GRPCCallHandler, ChannelInboundHandler {
-  public typealias InboundIn = _GRPCServerRequestPart<Request>
-  public typealias OutboundOut = _GRPCServerResponsePart<Response>
+public class _BaseCallHandler<
+  RequestDeserializer: MessageDeserializer,
+  ResponseSerializer: MessageSerializer
+>: GRPCCallHandler, ChannelInboundHandler {
+  public typealias RequestPayload = RequestDeserializer.Output
+  public typealias ResponsePayload = ResponseSerializer.Input
 
-  public let _codec: ChannelHandler
+  public typealias InboundIn = _RawGRPCServerRequestPart
+  public typealias OutboundOut = _RawGRPCServerResponsePart
 
   /// An interceptor pipeline.
-  private var pipeline: ServerInterceptorPipeline<Request, Response>?
+  private var pipeline: ServerInterceptorPipeline<RequestPayload, ResponsePayload>?
 
   /// Our current state.
   private var state: State = .idle
@@ -41,6 +45,12 @@ public class _BaseCallHandler<Request, Response>: GRPCCallHandler, ChannelInboun
   /// Some context provided to us from the routing handler.
   private let callHandlerContext: CallHandlerContext
 
+  /// A request deserializer.
+  private let requestDeserializer: RequestDeserializer
+
+  /// A response serializer.
+  private let responseSerializer: ResponseSerializer
+
   /// The event loop this call is being handled on.
   internal var eventLoop: EventLoop {
     return self.callHandlerContext.eventLoop
@@ -61,14 +71,15 @@ public class _BaseCallHandler<Request, Response>: GRPCCallHandler, ChannelInboun
 
   internal init(
     callHandlerContext: CallHandlerContext,
-    codec: ChannelHandler,
+    requestDeserializr: RequestDeserializer,
+    responseSerializer: ResponseSerializer,
     callType: GRPCCallType,
-    interceptors: [ServerInterceptor<Request, Response>]
+    interceptors: [ServerInterceptor<RequestPayload, ResponsePayload>]
   ) {
     let userInfoRef = Ref(UserInfo())
-
+    self.requestDeserializer = requestDeserializr
+    self.responseSerializer = responseSerializer
     self.callHandlerContext = callHandlerContext
-    self._codec = codec
     self.callType = callType
     self.userInfoRef = userInfoRef
     self.pipeline = ServerInterceptorPipeline(
@@ -104,7 +115,20 @@ public class _BaseCallHandler<Request, Response>: GRPCCallHandler, ChannelInboun
 
   public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
     let part = self.unwrapInboundIn(data)
-    self.act(on: self.state.channelRead(part))
+
+    switch part {
+    case let .headers(headers):
+      self.act(on: self.state.channelRead(.headers(headers)))
+    case let .message(buffer):
+      do {
+        let request = try self.requestDeserializer.deserialize(byteBuffer: buffer)
+        self.act(on: self.state.channelRead(.message(request)))
+      } catch {
+        self.errorCaught(context: context, error: error)
+      }
+    case .end:
+      self.act(on: self.state.channelRead(.end))
+    }
     // We're the last handler. We don't have anything to forward.
   }
 
@@ -114,7 +138,7 @@ public class _BaseCallHandler<Request, Response>: GRPCCallHandler, ChannelInboun
     fatalError("must be overridden by subclasses")
   }
 
-  internal func observeRequest(_ message: Request) {
+  internal func observeRequest(_ message: RequestPayload) {
     fatalError("must be overridden by subclasses")
   }
 
@@ -131,7 +155,7 @@ public class _BaseCallHandler<Request, Response>: GRPCCallHandler, ChannelInboun
   ///   - part: The response part to send.
   ///   - promise: A promise to complete once the response part has been written.
   internal func sendResponsePartFromObserver(
-    _ part: GRPCServerResponsePart<Response>,
+    _ part: GRPCServerResponsePart<ResponsePayload>,
     promise: EventLoopPromise<Void>?
   ) {
     self.act(on: self.state.sendResponsePartFromObserver(part, promise: promise))
@@ -211,7 +235,7 @@ public class _BaseCallHandler<Request, Response>: GRPCCallHandler, ChannelInboun
 extension _BaseCallHandler {
   /// Receive a request part from the interceptors pipeline to forward to the event observer.
   /// - Parameter part: The request part to forward.
-  private func receiveRequestPartFromInterceptors(_ part: GRPCServerRequestPart<Request>) {
+  private func receiveRequestPartFromInterceptors(_ part: GRPCServerRequestPart<RequestPayload>) {
     self.act(on: self.state.receiveRequestPartFromInterceptors(part))
   }
 
@@ -221,7 +245,7 @@ extension _BaseCallHandler {
   ///   - part: The response part to send.
   ///   - promise: A promise to complete once the response part has been written.
   private func sendResponsePartFromInterceptors(
-    _ part: GRPCServerResponsePart<Response>,
+    _ part: GRPCServerResponsePart<ResponsePayload>,
     promise: EventLoopPromise<Void>?
   ) {
     self.act(on: self.state.sendResponsePartFromInterceptors(part, promise: promise))
@@ -381,21 +405,24 @@ extension _BaseCallHandler.State {
     case none
 
     /// Receive the request part in the interceptor pipeline.
-    case receiveRequestPartInInterceptors(GRPCServerRequestPart<Request>)
+    case receiveRequestPartInInterceptors(GRPCServerRequestPart<_BaseCallHandler.RequestPayload>)
 
     /// Receive the request part in the observer.
-    case receiveRequestPartInObserver(GRPCServerRequestPart<Request>)
+    case receiveRequestPartInObserver(GRPCServerRequestPart<_BaseCallHandler.RequestPayload>)
 
     /// Receive an error in the observer.
     case receiveLibraryErrorInObserver(Error)
 
     /// Send a response part to the interceptor pipeline.
-    case sendResponsePartToInterceptors(GRPCServerResponsePart<Response>, EventLoopPromise<Void>?)
+    case sendResponsePartToInterceptors(
+      GRPCServerResponsePart<_BaseCallHandler.ResponsePayload>,
+      EventLoopPromise<Void>?
+    )
 
     /// Write the response part to the `Channel`.
     case writeResponsePartToChannel(
       ChannelHandlerContext,
-      GRPCServerResponsePart<Response>,
+      GRPCServerResponsePart<_BaseCallHandler.ResponsePayload>,
       promise: EventLoopPromise<Void>?
     )
 
@@ -438,7 +465,9 @@ extension _BaseCallHandler.State {
 
   /// Receive a request part from the `Channel`. If we're active we just forward these through the
   /// pipeline. We validate at the other end.
-  internal mutating func channelRead(_ requestPart: _GRPCServerRequestPart<Request>) -> Action {
+  internal mutating func channelRead(
+    _ requestPart: _GRPCServerRequestPart<_BaseCallHandler.RequestPayload>
+  ) -> Action {
     switch self {
     case .idle:
       preconditionFailure("Invalid state: the handler isn't in the pipeline yet")
@@ -448,7 +477,7 @@ extension _BaseCallHandler.State {
       self = .idle
 
       let filter: StreamState.Filter
-      let part: GRPCServerRequestPart<Request>
+      let part: GRPCServerRequestPart<_BaseCallHandler.RequestPayload>
 
       switch requestPart {
       case let .headers(headers):
@@ -479,7 +508,7 @@ extension _BaseCallHandler.State {
 
   /// Send a response part from the observer to the interceptors.
   internal mutating func sendResponsePartFromObserver(
-    _ part: GRPCServerResponsePart<Response>,
+    _ part: GRPCServerResponsePart<_BaseCallHandler.ResponsePayload>,
     promise: EventLoopPromise<Void>?
   ) -> Action {
     switch self {
@@ -518,7 +547,7 @@ extension _BaseCallHandler.State {
 
   /// Send a response part from the interceptors to the `Channel`.
   internal mutating func sendResponsePartFromInterceptors(
-    _ part: GRPCServerResponsePart<Response>,
+    _ part: GRPCServerResponsePart<_BaseCallHandler.ResponsePayload>,
     promise: EventLoopPromise<Void>?
   ) -> Action {
     switch self {
@@ -559,7 +588,7 @@ extension _BaseCallHandler.State {
 
   /// A request part has traversed the interceptor pipeline, now send it to the observer.
   internal mutating func receiveRequestPartFromInterceptors(
-    _ part: GRPCServerRequestPart<Request>
+    _ part: GRPCServerRequestPart<_BaseCallHandler.RequestPayload>
   ) -> Action {
     switch self {
     case .idle:
@@ -632,13 +661,13 @@ extension _BaseCallHandler {
   }
 
   /// Receives a request part in the interceptor pipeline.
-  private func receiveRequestPartInInterceptors(_ part: GRPCServerRequestPart<Request>) {
+  private func receiveRequestPartInInterceptors(_ part: GRPCServerRequestPart<RequestPayload>) {
     self.pipeline?.receive(part)
   }
 
   /// Observe a request part. This just farms out to the subclass implementation for the
   /// appropriate part.
-  private func receiveRequestPartInObserver(_ part: GRPCServerRequestPart<Request>) {
+  private func receiveRequestPartInObserver(_ part: GRPCServerRequestPart<RequestPayload>) {
     switch part {
     case let .metadata(headers):
       self.observeHeaders(headers)
@@ -651,7 +680,7 @@ extension _BaseCallHandler {
 
   /// Sends a response part into the interceptor pipeline.
   private func sendResponsePartToInterceptors(
-    _ part: GRPCServerResponsePart<Response>,
+    _ part: GRPCServerResponsePart<ResponsePayload>,
     promise: EventLoopPromise<Void>?
   ) {
     if let pipeline = self.pipeline {
@@ -664,7 +693,7 @@ extension _BaseCallHandler {
   /// Writes a response part to the `Channel`.
   private func writeResponsePartToChannel(
     context: ChannelHandlerContext,
-    part: GRPCServerResponsePart<Response>,
+    part: GRPCServerResponsePart<ResponsePayload>,
     promise: EventLoopPromise<Void>?
   ) {
     let flush: Bool
@@ -677,12 +706,22 @@ extension _BaseCallHandler {
       context.write(self.wrapOutboundOut(.headers(headers)), promise: promise)
 
     case let .message(message, metadata):
-      context.write(
-        self.wrapOutboundOut(.message(.init(message, compressed: metadata.compress))),
-        promise: promise
-      )
-      // Flush if we've been told to flush.
-      flush = metadata.flush
+      do {
+        let serializedResponse = try self.responseSerializer.serialize(
+          message,
+          allocator: context.channel.allocator
+        )
+        context.write(
+          self.wrapOutboundOut(.message(.init(serializedResponse, compressed: metadata.compress))),
+          promise: promise
+        )
+        // Flush if we've been told to flush.
+        flush = metadata.flush
+      } catch {
+        self.errorCaught(context: context, error: error)
+        promise?.fail(error)
+        return
+      }
 
     case let .end(status, trailers):
       context.write(self.wrapOutboundOut(.statusAndTrailers(status, trailers)), promise: promise)

+ 0 - 87
Sources/GRPC/GRPCServerCodecHandler.swift

@@ -1,87 +0,0 @@
-/*
- * Copyright 2020, gRPC Authors All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-import NIO
-
-class GRPCServerCodecHandler<Serializer: MessageSerializer, Deserializer: MessageDeserializer> {
-  /// The response serializer.
-  private let serializer: Serializer
-
-  /// The request deserializer.
-  private let deserializer: Deserializer
-
-  internal init(serializer: Serializer, deserializer: Deserializer) {
-    self.serializer = serializer
-    self.deserializer = deserializer
-  }
-}
-
-extension GRPCServerCodecHandler: ChannelInboundHandler {
-  typealias InboundIn = _RawGRPCServerRequestPart
-  typealias InboundOut = _GRPCServerRequestPart<Deserializer.Output>
-
-  internal func channelRead(context: ChannelHandlerContext, data: NIOAny) {
-    switch self.unwrapInboundIn(data) {
-    case let .headers(head):
-      context.fireChannelRead(self.wrapInboundOut(.headers(head)))
-
-    case let .message(buffer):
-      do {
-        let deserialized = try self.deserializer.deserialize(byteBuffer: buffer)
-        context.fireChannelRead(self.wrapInboundOut(.message(deserialized)))
-      } catch {
-        context.fireErrorCaught(error)
-      }
-
-    case .end:
-      context.fireChannelRead(self.wrapInboundOut(.end))
-    }
-  }
-}
-
-extension GRPCServerCodecHandler: ChannelOutboundHandler {
-  typealias OutboundIn = _GRPCServerResponsePart<Serializer.Input>
-  typealias OutboundOut = _RawGRPCServerResponsePart
-
-  internal func write(
-    context: ChannelHandlerContext,
-    data: NIOAny,
-    promise: EventLoopPromise<Void>?
-  ) {
-    switch self.unwrapOutboundIn(data) {
-    case let .headers(headers):
-      context.write(self.wrapOutboundOut(.headers(headers)), promise: promise)
-
-    case let .message(messageContext):
-      do {
-        let buffer = try self.serializer.serialize(
-          messageContext.message,
-          allocator: context.channel.allocator
-        )
-        context.write(
-          self.wrapOutboundOut(.message(.init(buffer, compressed: messageContext.compressed))),
-          promise: promise
-        )
-      } catch {
-        let error = GRPCError.SerializationFailure().captureContext()
-        promise?.fail(error)
-        context.fireErrorCaught(error)
-      }
-
-    case let .statusAndTrailers(status, trailers):
-      context.write(self.wrapOutboundOut(.statusAndTrailers(status, trailers)), promise: promise)
-    }
-  }
-}

+ 2 - 5
Sources/GRPC/GRPCServerRequestRoutingHandler.swift

@@ -19,9 +19,7 @@ import NIOHTTP1
 import SwiftProtobuf
 
 /// Processes individual gRPC messages and stream-close events on an HTTP2 channel.
-public protocol GRPCCallHandler: ChannelHandler {
-  var _codec: ChannelHandler { get }
-}
+public protocol GRPCCallHandler: ChannelHandler {}
 
 /// Provides `GRPCCallHandler` objects for the methods on a particular service name.
 ///
@@ -184,8 +182,7 @@ extension GRPCServerRequestRoutingHandler: ChannelInboundHandler, RemovableChann
 
       // Configure the rest of the pipeline to serve the RPC.
       let httpToGRPC = HTTP1ToGRPCServerCodec(encoding: self.encoding, logger: self.logger)
-      let codec = callHandler._codec
-      context.pipeline.addHandlers([httpToGRPC, codec, callHandler], position: .after(self))
+      context.pipeline.addHandlers([httpToGRPC, callHandler], position: .after(self))
         .whenSuccess {
           context.pipeline.removeHandler(self, promise: nil)
         }

+ 10 - 10
Sources/GRPC/Serialization.swift

@@ -17,7 +17,7 @@ import NIO
 import NIOFoundationCompat
 import SwiftProtobuf
 
-internal protocol MessageSerializer {
+public protocol MessageSerializer {
   associatedtype Input
 
   /// Serializes `input` into a `ByteBuffer` allocated using the provided `allocator`.
@@ -28,7 +28,7 @@ internal protocol MessageSerializer {
   func serialize(_ input: Input, allocator: ByteBufferAllocator) throws -> ByteBuffer
 }
 
-internal protocol MessageDeserializer {
+public protocol MessageDeserializer {
   associatedtype Output
 
   /// Deserializes `byteBuffer` to produce a single `Output`.
@@ -39,8 +39,8 @@ internal protocol MessageDeserializer {
 
 // MARK: Protobuf
 
-internal struct ProtobufSerializer<Message: SwiftProtobuf.Message>: MessageSerializer {
-  internal func serialize(_ message: Message, allocator: ByteBufferAllocator) throws -> ByteBuffer {
+public struct ProtobufSerializer<Message: SwiftProtobuf.Message>: MessageSerializer {
+  public func serialize(_ message: Message, allocator: ByteBufferAllocator) throws -> ByteBuffer {
     // Serialize the message.
     let serialized = try message.serializedData()
 
@@ -58,8 +58,8 @@ internal struct ProtobufSerializer<Message: SwiftProtobuf.Message>: MessageSeria
   }
 }
 
-internal struct ProtobufDeserializer<Message: SwiftProtobuf.Message>: MessageDeserializer {
-  internal func deserialize(byteBuffer: ByteBuffer) throws -> Message {
+public struct ProtobufDeserializer<Message: SwiftProtobuf.Message>: MessageDeserializer {
+  public func deserialize(byteBuffer: ByteBuffer) throws -> Message {
     var buffer = byteBuffer
     // '!' is okay; we can always read 'readableBytes'.
     let data = buffer.readData(length: buffer.readableBytes)!
@@ -69,8 +69,8 @@ internal struct ProtobufDeserializer<Message: SwiftProtobuf.Message>: MessageDes
 
 // MARK: GRPCPayload
 
-internal struct GRPCPayloadSerializer<Message: GRPCPayload>: MessageSerializer {
-  internal func serialize(_ message: Message, allocator: ByteBufferAllocator) throws -> ByteBuffer {
+public struct GRPCPayloadSerializer<Message: GRPCPayload>: MessageSerializer {
+  public func serialize(_ message: Message, allocator: ByteBufferAllocator) throws -> ByteBuffer {
     // Reserve 5 leading bytes. This a minor optimisation win: the length prefixed message writer
     // can re-use the leading 5 bytes without needing to allocate a new buffer and copy over the
     // serialized message.
@@ -101,8 +101,8 @@ internal struct GRPCPayloadSerializer<Message: GRPCPayload>: MessageSerializer {
   }
 }
 
-internal struct GRPCPayloadDeserializer<Message: GRPCPayload>: MessageDeserializer {
-  internal func deserialize(byteBuffer: ByteBuffer) throws -> Message {
+public struct GRPCPayloadDeserializer<Message: GRPCPayload>: MessageDeserializer {
+  public func deserialize(byteBuffer: ByteBuffer) throws -> Message {
     var buffer = byteBuffer
     return try Message(serializedByteBuffer: &buffer)
   }

+ 0 - 61
Tests/GRPCTests/GRPCServerCodecHandlerTests.swift

@@ -1,61 +0,0 @@
-/*
- * Copyright 2020, gRPC Authors All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-@testable import GRPC
-import NIO
-import XCTest
-
-class GRPCServerCodecHandlerTests: GRPCTestCase {
-  struct BlowUpError: Error {}
-
-  struct BlowUpSerializer: MessageSerializer {
-    typealias Input = Any
-
-    func serialize(_ input: Any, allocator: ByteBufferAllocator) throws -> ByteBuffer {
-      throw BlowUpError()
-    }
-  }
-
-  struct BlowUpDeserializer: MessageDeserializer {
-    typealias Output = Any
-
-    func deserialize(byteBuffer: ByteBuffer) throws -> Any {
-      throw BlowUpError()
-    }
-  }
-
-  func testSerializationFailure() throws {
-    let handler = GRPCServerCodecHandler(
-      serializer: BlowUpSerializer(),
-      deserializer: BlowUpDeserializer()
-    )
-    let channel = EmbeddedChannel(handler: handler)
-    XCTAssertThrowsError(try channel.writeInbound(_RawGRPCServerRequestPart.message(ByteBuffer())))
-    XCTAssertNil(try channel.readInbound(as: Any.self))
-  }
-
-  func testDeserializationFailure() throws {
-    let handler = GRPCServerCodecHandler(
-      serializer: BlowUpSerializer(),
-      deserializer: BlowUpDeserializer()
-    )
-    let channel = EmbeddedChannel(handler: handler)
-    XCTAssertThrowsError(
-      try channel
-        .writeOutbound(_GRPCServerResponsePart<Any>.message(.init(ByteBuffer(), compressed: false)))
-    )
-    XCTAssertNil(try channel.readOutbound(as: Any.self))
-  }
-}

+ 6 - 2
Tests/GRPCTests/GRPCServerRequestRoutingHandlerTests.swift

@@ -117,8 +117,12 @@ class GRPCServerRequestRoutingHandlerTests: GRPCTestCase {
     XCTAssertThrowsError(try router.wait())
 
     // There should now be a unary call handler.
-    let unary = self.channel.pipeline
-      .handler(type: UnaryCallHandler<Echo_EchoRequest, Echo_EchoResponse>.self)
+    let unary = self.channel.pipeline.handler(
+      type: UnaryCallHandler<
+        ProtobufDeserializer<Echo_EchoRequest>,
+        ProtobufSerializer<Echo_EchoResponse>
+      >.self
+    )
     XCTAssertNoThrow(try unary.wait())
   }
 

+ 49 - 6
Tests/GRPCTests/ServerInterceptorTests.swift

@@ -73,7 +73,7 @@ class ServerInterceptorTests: GRPCTestCase {
     let provider = self.echoProvider(interceptedBy: recorder)
 
     let handler = try assertNotNil(self.handleMethod("Get", using: provider))
-    assertThat(try self.channel.pipeline.addHandler(handler).wait(), .doesNotThrow())
+    assertThat(try self.channel.pipeline.addHandlers([Codec(), handler]).wait(), .doesNotThrow())
 
     // Send requests.
     assertThat(try self.channel.writeInbound(self.request(.headers([:]))), .doesNotThrow())
@@ -120,7 +120,7 @@ class ServerInterceptorTests: GRPCTestCase {
     }
 
     let handler = try assertNotNil(self.handleMethod(method, using: provider))
-    assertThat(try self.channel.pipeline.addHandler(handler).wait(), .doesNotThrow())
+    assertThat(try self.channel.pipeline.addHandlers([Codec(), handler]).wait(), .doesNotThrow())
 
     // Send the requests.
     assertThat(try self.channel.writeInbound(self.request(.headers([:]))), .doesNotThrow())
@@ -178,7 +178,7 @@ class ServerInterceptorTests: GRPCTestCase {
   func testUnaryFromInterceptor() throws {
     let provider = EchoFromInterceptor()
     let handler = try assertNotNil(self.handleMethod("Get", using: provider))
-    assertThat(try self.channel.pipeline.addHandler(handler).wait(), .doesNotThrow())
+    assertThat(try self.channel.pipeline.addHandlers([Codec(), handler]).wait(), .doesNotThrow())
 
     // Send the requests.
     assertThat(try self.channel.writeInbound(self.request(.headers([:]))), .doesNotThrow())
@@ -200,7 +200,7 @@ class ServerInterceptorTests: GRPCTestCase {
   func testClientStreamingFromInterceptor() throws {
     let provider = EchoFromInterceptor()
     let handler = try assertNotNil(self.handleMethod("Collect", using: provider))
-    assertThat(try self.channel.pipeline.addHandler(handler).wait(), .doesNotThrow())
+    assertThat(try self.channel.pipeline.addHandlers([Codec(), handler]).wait(), .doesNotThrow())
 
     // Send the requests.
     assertThat(try self.channel.writeInbound(self.request(.headers([:]))), .doesNotThrow())
@@ -222,7 +222,7 @@ class ServerInterceptorTests: GRPCTestCase {
   func testServerStreamingFromInterceptor() throws {
     let provider = EchoFromInterceptor()
     let handler = try assertNotNil(self.handleMethod("Expand", using: provider))
-    assertThat(try self.channel.pipeline.addHandler(handler).wait(), .doesNotThrow())
+    assertThat(try self.channel.pipeline.addHandlers([Codec(), handler]).wait(), .doesNotThrow())
 
     // Send the requests.
     assertThat(try self.channel.writeInbound(self.request(.headers([:]))), .doesNotThrow())
@@ -247,7 +247,7 @@ class ServerInterceptorTests: GRPCTestCase {
   func testBidirectionalStreamingFromInterceptor() throws {
     let provider = EchoFromInterceptor()
     let handler = try assertNotNil(self.handleMethod("Update", using: provider))
-    assertThat(try self.channel.pipeline.addHandler(handler).wait(), .doesNotThrow())
+    assertThat(try self.channel.pipeline.addHandlers([Codec(), handler]).wait(), .doesNotThrow())
 
     // Send the requests.
     assertThat(try self.channel.writeInbound(self.request(.headers([:]))), .doesNotThrow())
@@ -437,3 +437,46 @@ class EchoFromInterceptor: Echo_EchoProvider {
     }
   }
 }
+
+// Avoid having to serialize/deserialize messages in test cases.
+private class Codec: ChannelDuplexHandler {
+  typealias InboundIn = _GRPCServerRequestPart<Echo_EchoRequest>
+  typealias InboundOut = _RawGRPCServerRequestPart
+
+  typealias OutboundIn = _RawGRPCServerResponsePart
+  typealias OutboundOut = _GRPCServerResponsePart<Echo_EchoResponse>
+
+  private let serializer = ProtobufSerializer<Echo_EchoRequest>()
+  private let deserializer = ProtobufDeserializer<Echo_EchoResponse>()
+
+  func channelRead(context: ChannelHandlerContext, data: NIOAny) {
+    switch self.unwrapInboundIn(data) {
+    case let .headers(headers):
+      context.fireChannelRead(self.wrapInboundOut(.headers(headers)))
+
+    case let .message(message):
+      let serialized = try! self.serializer.serialize(message, allocator: context.channel.allocator)
+      context.fireChannelRead(self.wrapInboundOut(.message(serialized)))
+
+    case .end:
+      context.fireChannelRead(self.wrapInboundOut(.end))
+    }
+  }
+
+  func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
+    switch self.unwrapOutboundIn(data) {
+    case let .headers(headers):
+      context.write(self.wrapOutboundOut(.headers(headers)), promise: promise)
+
+    case let .message(message):
+      let deserialzed = try! self.deserializer.deserialize(byteBuffer: message.message)
+      context.write(
+        self.wrapOutboundOut(.message(.init(deserialzed, compressed: message.compressed))),
+        promise: promise
+      )
+
+    case let .statusAndTrailers(status, trailers):
+      context.write(self.wrapOutboundOut(.statusAndTrailers(status, trailers)), promise: promise)
+    }
+  }
+}

+ 0 - 11
Tests/GRPCTests/XCTestManifests.swift

@@ -634,16 +634,6 @@ extension GRPCSecureInteroperabilityTests {
     ]
 }
 
-extension GRPCServerCodecHandlerTests {
-    // DO NOT MODIFY: This is autogenerated, use:
-    //   `swift test --generate-linuxmain`
-    // to regenerate.
-    static let __allTests__GRPCServerCodecHandlerTests = [
-        ("testDeserializationFailure", testDeserializationFailure),
-        ("testSerializationFailure", testSerializationFailure),
-    ]
-}
-
 extension GRPCServerRequestRoutingHandlerTests {
     // DO NOT MODIFY: This is autogenerated, use:
     //   `swift test --generate-linuxmain`
@@ -1154,7 +1144,6 @@ public func __allTests() -> [XCTestCaseEntry] {
         testCase(GRPCInsecureInteroperabilityTests.__allTests__GRPCInsecureInteroperabilityTests),
         testCase(GRPCPingHandlerTests.__allTests__GRPCPingHandlerTests),
         testCase(GRPCSecureInteroperabilityTests.__allTests__GRPCSecureInteroperabilityTests),
-        testCase(GRPCServerCodecHandlerTests.__allTests__GRPCServerCodecHandlerTests),
         testCase(GRPCServerRequestRoutingHandlerTests.__allTests__GRPCServerRequestRoutingHandlerTests),
         testCase(GRPCStatusCodeTests.__allTests__GRPCStatusCodeTests),
         testCase(GRPCStatusMessageMarshallerTests.__allTests__GRPCStatusMessageMarshallerTests),