Browse Source

Move client serialization into the transport handler (#1083)

Motivation:

Moving the client serialization into transport handler allows us to
get rid of an unspecialized generic handler.

Modifications:

- Add an 'AnySerializer' and 'AnyDeserializer', this is unfortunate but
  necessary since 'ClientTransport' is generic over 'Request' and
  'Response' rather than their respective (de/)serializer. Changing this
  would involve changing the generic constraints on all of the client call
  objects.
- Move (de/)serialization into the 'ClientTransport'
- Add a reverse codec for 'fake' transport

Result:

3.5% fewer instructions in the unary_10k_small_requests benchmark.
George Barnett 5 years ago
parent
commit
31dffb8dc1

+ 19 - 1
Sources/GRPC/FakeChannel.swift

@@ -99,7 +99,25 @@ public class FakeChannel: GRPCChannel {
     )
   }
 
-  private func _makeCall<Request, Response>(
+  private func _makeCall<Request: Message, Response: Message>(
+    path: String,
+    type: GRPCCallType,
+    callOptions: CallOptions,
+    interceptors: [ClientInterceptor<Request, Response>]
+  ) -> Call<Request, Response> {
+    let stream: _FakeResponseStream<Request, Response>? = self.dequeueResponseStream(forPath: path)
+    let eventLoop = stream?.channel.eventLoop ?? EmbeddedEventLoop()
+    return Call(
+      path: path,
+      type: type,
+      eventLoop: eventLoop,
+      options: callOptions,
+      interceptors: interceptors,
+      transportFactory: .fake(stream, on: eventLoop)
+    )
+  }
+
+  private func _makeCall<Request: GRPCPayload, Response: GRPCPayload>(
     path: String,
     type: GRPCCallType,
     callOptions: CallOptions,

+ 45 - 30
Sources/GRPC/Interceptor/ClientTransport.swift

@@ -52,6 +52,12 @@ internal final class ClientTransport<Request, Response> {
   /// A buffer to store request parts and promises in before the channel has become active.
   private var writeBuffer = MarkedCircularBuffer<RequestAndPromise>(initialCapacity: 4)
 
+  /// The request serializer.
+  private let serializer: AnySerializer<Request>
+
+  /// The response deserializer.
+  private let deserializer: AnyDeserializer<Response>
+
   /// A request part and a promise.
   private struct RequestAndPromise {
     var request: GRPCClientRequestPart<Request>
@@ -102,12 +108,16 @@ internal final class ClientTransport<Request, Response> {
     details: CallDetails,
     eventLoop: EventLoop,
     interceptors: [ClientInterceptor<Request, Response>],
+    serializer: AnySerializer<Request>,
+    deserializer: AnyDeserializer<Response>,
     errorDelegate: ClientErrorDelegate?,
     onError: @escaping (Error) -> Void,
     onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
   ) {
     self.eventLoop = eventLoop
     self.callDetails = details
+    self.serializer = serializer
+    self.deserializer = deserializer
     self._pipeline = ClientInterceptorPipeline(
       eventLoop: eventLoop,
       details: details,
@@ -236,10 +246,10 @@ extension ClientTransport {
 
 extension ClientTransport: ChannelInboundHandler {
   @usableFromInline
-  typealias InboundIn = _GRPCClientResponsePart<Response>
+  typealias InboundIn = _RawGRPCClientResponsePart
 
   @usableFromInline
-  typealias OutboundOut = _GRPCClientRequestPart<Request>
+  typealias OutboundOut = _RawGRPCClientRequestPart
 
   @usableFromInline
   func handlerAdded(context: ChannelHandlerContext) {
@@ -311,16 +321,32 @@ extension ClientTransport: ChannelInboundHandler {
     self.eventLoop.assertInEventLoop()
     let part = self.unwrapInboundIn(data)
 
-    let isEnd: Bool
     switch part {
-    case .initialMetadata, .message, .trailingMetadata:
-      isEnd = false
-    case .status:
-      isEnd = true
-    }
+    case let .initialMetadata(headers):
+      if self.state.channelRead(isEnd: false) {
+        self.forwardToInterceptors(.metadata(headers))
+      }
 
-    if self.state.channelRead(isEnd: isEnd) {
-      self.forwardToInterceptors(part)
+    case let .message(context):
+      do {
+        let message = try self.deserializer.deserialize(byteBuffer: context.message)
+        if self.state.channelRead(isEnd: false) {
+          self.forwardToInterceptors(.message(message))
+        }
+      } catch {
+        self.channelError(error)
+      }
+
+    case let .trailingMetadata(trailers):
+      // The `Channel` delivers trailers and `GRPCStatus` separately, we want to emit them together
+      // in the interceptor pipeline.
+      self.trailers = trailers
+
+    case let .status(status):
+      if self.state.channelRead(isEnd: true) {
+        self.forwardToInterceptors(.end(status, self.trailers ?? [:]))
+        self.trailers = nil
+      }
     }
 
     // (We're the end of the channel. No need to forward anything.)
@@ -769,8 +795,13 @@ extension ClientTransport {
       context.channel.write(self.wrapOutboundOut(.head(head)), promise: promise)
 
     case let .message(request, metadata):
-      let message = _MessageContext<Request>(request, compressed: metadata.compress)
-      context.channel.write(self.wrapOutboundOut(.message(message)), promise: promise)
+      do {
+        let bytes = try self.serializer.serialize(request, allocator: context.channel.allocator)
+        let message = _MessageContext<ByteBuffer>(bytes, compressed: metadata.compress)
+        context.channel.write(self.wrapOutboundOut(.message(message)), promise: promise)
+      } catch {
+        self.channelError(error)
+      }
 
     case .end:
       context.channel.write(self.wrapOutboundOut(.end), promise: promise)
@@ -783,24 +814,8 @@ extension ClientTransport {
 
   /// Forward the response part to the interceptor pipeline.
   /// - Parameter part: The response part to forward.
-  private func forwardToInterceptors(_ part: _GRPCClientResponsePart<Response>) {
-    switch part {
-    case let .initialMetadata(metadata):
-      self._pipeline?.receive(.metadata(metadata))
-
-    case let .message(context):
-      self._pipeline?.receive(.message(context.message))
-
-    case let .trailingMetadata(trailers):
-      // The `Channel` delivers trailers and `GRPCStatus`, we want to emit them together in the
-      // interceptor pipeline.
-      self.trailers = trailers
-
-    case let .status(status):
-      let trailers = self.trailers ?? [:]
-      self.trailers = nil
-      self._pipeline?.receive(.end(status, trailers))
-    }
+  private func forwardToInterceptors(_ part: GRPCClientResponsePart<Response>) {
+    self._pipeline?.receive(part)
   }
 
   /// Forward the error to the interceptor pipeline.

+ 76 - 15
Sources/GRPC/Interceptor/ClientTransportFactory.swift

@@ -77,8 +77,8 @@ internal struct ClientTransportFactory<Request, Response> {
       multiplexer: multiplexer,
       scheme: scheme,
       authority: authority,
-      serializer: GRPCPayloadSerializer(),
-      deserializer: GRPCPayloadDeserializer(),
+      serializer: AnySerializer(wrapping: GRPCPayloadSerializer()),
+      deserializer: AnyDeserializer(wrapping: GRPCPayloadDeserializer()),
       errorDelegate: errorDelegate
     )
     return .init(http2)
@@ -87,11 +87,37 @@ internal struct ClientTransportFactory<Request, Response> {
   /// Make a factory for 'fake' transport.
   /// - Parameter fakeResponse: The fake response stream.
   /// - Returns: A factory for making and configuring fake transport.
-  internal static func fake(
+  internal static func fake<Request: SwiftProtobuf.Message, Response: SwiftProtobuf.Message>(
     _ fakeResponse: _FakeResponseStream<Request, Response>?,
     on eventLoop: EventLoop
   ) -> ClientTransportFactory<Request, Response> {
-    return .init(FakeClientTransportFactory(fakeResponse, on: eventLoop))
+    let factory = FakeClientTransportFactory(
+      fakeResponse,
+      on: eventLoop,
+      requestSerializer: ProtobufSerializer(),
+      requestDeserializer: ProtobufDeserializer(),
+      responseSerializer: ProtobufSerializer(),
+      responseDeserializer: ProtobufDeserializer()
+    )
+    return .init(factory)
+  }
+
+  /// Make a factory for 'fake' transport.
+  /// - Parameter fakeResponse: The fake response stream.
+  /// - Returns: A factory for making and configuring fake transport.
+  internal static func fake<Request: GRPCPayload, Response: GRPCPayload>(
+    _ fakeResponse: _FakeResponseStream<Request, Response>?,
+    on eventLoop: EventLoop
+  ) -> ClientTransportFactory<Request, Response> {
+    let factory = FakeClientTransportFactory(
+      fakeResponse,
+      on: eventLoop,
+      requestSerializer: GRPCPayloadSerializer(),
+      requestDeserializer: GRPCPayloadDeserializer(),
+      responseSerializer: GRPCPayloadSerializer(),
+      responseDeserializer: GRPCPayloadDeserializer()
+    )
+    return .init(factory)
   }
 
   /// Makes a configured `ClientTransport`.
@@ -103,7 +129,7 @@ internal struct ClientTransportFactory<Request, Response> {
   ///   - onError: A callback invoked when an error is received.
   ///   - onResponsePart: A closure called for each response part received.
   /// - Returns: A configured transport.
-  internal func makeConfiguredTransport<Request, Response>(
+  internal func makeConfiguredTransport(
     to path: String,
     for type: GRPCCallType,
     withOptions options: CallOptions,
@@ -151,8 +177,11 @@ private struct HTTP2ClientTransportFactory<Request, Response> {
   /// An error delegate.
   private var errorDelegate: ClientErrorDelegate?
 
-  /// A codec for serializing request messages and deserializing response parts.
-  private var codec: ChannelHandler
+  /// The request serializer.
+  private let serializer: AnySerializer<Request>
+
+  /// The response deserializer.
+  private let deserializer: AnyDeserializer<Response>
 
   fileprivate init<Serializer: MessageSerializer, Deserializer: MessageDeserializer>(
     multiplexer: EventLoopFuture<HTTP2StreamMultiplexer>,
@@ -165,11 +194,12 @@ private struct HTTP2ClientTransportFactory<Request, Response> {
     self.multiplexer = multiplexer
     self.scheme = scheme
     self.authority = authority
-    self.codec = GRPCClientCodecHandler(serializer: serializer, deserializer: deserializer)
+    self.serializer = AnySerializer(wrapping: serializer)
+    self.deserializer = AnyDeserializer(wrapping: deserializer)
     self.errorDelegate = errorDelegate
   }
 
-  fileprivate func makeTransport<Request, Response>(
+  fileprivate func makeTransport(
     to path: String,
     for type: GRPCCallType,
     withOptions options: CallOptions,
@@ -181,6 +211,8 @@ private struct HTTP2ClientTransportFactory<Request, Response> {
       details: self.makeCallDetails(type: type, path: path, options: options),
       eventLoop: self.multiplexer.eventLoop,
       interceptors: interceptors,
+      serializer: self.serializer,
+      deserializer: self.deserializer,
       errorDelegate: self.errorDelegate,
       onError: onError,
       onResponsePart: onResponsePart
@@ -198,7 +230,6 @@ private struct HTTP2ClientTransportFactory<Request, Response> {
               callType: transport.callDetails.type,
               logger: transport.logger
             ),
-            self.codec,
             transport,
           ])
         }
@@ -233,15 +264,43 @@ private struct FakeClientTransportFactory<Request, Response> {
   /// stream be `nil`.
   private var eventLoop: EventLoop
 
-  fileprivate init(
+  /// The request serializer.
+  private let requestSerializer: AnySerializer<Request>
+
+  /// The response deserializer.
+  private let responseDeserializer: AnyDeserializer<Response>
+
+  /// A codec for deserializing requests and serializing responses.
+  private let codec: ChannelHandler
+
+  fileprivate init<
+    RequestSerializer: MessageSerializer,
+    RequestDeserializer: MessageDeserializer,
+    ResponseSerializer: MessageSerializer,
+    ResponseDeserializer: MessageDeserializer
+  >(
     _ fakeResponseStream: _FakeResponseStream<Request, Response>?,
-    on eventLoop: EventLoop
-  ) {
+    on eventLoop: EventLoop,
+    requestSerializer: RequestSerializer,
+    requestDeserializer: RequestDeserializer,
+    responseSerializer: ResponseSerializer,
+    responseDeserializer: ResponseDeserializer
+  ) where RequestSerializer.Input == Request,
+    RequestDeserializer.Output == Request,
+    ResponseSerializer.Input == Response,
+    ResponseDeserializer.Output == Response
+  {
     self.fakeResponseStream = fakeResponseStream
     self.eventLoop = eventLoop
+    self.requestSerializer = AnySerializer(wrapping: requestSerializer)
+    self.responseDeserializer = AnyDeserializer(wrapping: responseDeserializer)
+    self.codec = GRPCClientReverseCodecHandler(
+      serializer: responseSerializer,
+      deserializer: requestDeserializer
+    )
   }
 
-  fileprivate func makeTransport<Request, Response>(
+  fileprivate func makeTransport(
     to path: String,
     for type: GRPCCallType,
     withOptions options: CallOptions,
@@ -259,6 +318,8 @@ private struct FakeClientTransportFactory<Request, Response> {
       ),
       eventLoop: self.eventLoop,
       interceptors: interceptors,
+      serializer: self.requestSerializer,
+      deserializer: self.responseDeserializer,
       errorDelegate: nil,
       onError: onError,
       onResponsePart: onResponsePart
@@ -268,7 +329,7 @@ private struct FakeClientTransportFactory<Request, Response> {
   fileprivate func configure<Request, Response>(_ transport: ClientTransport<Request, Response>) {
     transport.configure { handler in
       if let fakeResponse = self.fakeResponseStream {
-        return fakeResponse.channel.pipeline.addHandler(handler).always { result in
+        return fakeResponse.channel.pipeline.addHandlers(self.codec, handler).always { result in
           switch result {
           case .success:
             fakeResponse.activate()

+ 28 - 0
Sources/GRPC/Serialization.swift

@@ -125,3 +125,31 @@ public struct GRPCPayloadDeserializer<Message: GRPCPayload>: MessageDeserializer
     return try Message(serializedByteBuffer: &buffer)
   }
 }
+
+// MARK: - Any Serializer/Deserializer
+
+internal struct AnySerializer<Input>: MessageSerializer {
+  private let _serialize: (Input, ByteBufferAllocator) throws -> ByteBuffer
+
+  init<Serializer: MessageSerializer>(wrapping other: Serializer) where Serializer.Input == Input {
+    self._serialize = other.serialize(_:allocator:)
+  }
+
+  internal func serialize(_ input: Input, allocator: ByteBufferAllocator) throws -> ByteBuffer {
+    return try self._serialize(input, allocator)
+  }
+}
+
+internal struct AnyDeserializer<Output>: MessageDeserializer {
+  private let _deserialize: (ByteBuffer) throws -> Output
+
+  init<Deserializer: MessageDeserializer>(
+    wrapping other: Deserializer
+  ) where Deserializer.Output == Output {
+    self._deserialize = other.deserialize(byteBuffer:)
+  }
+
+  internal func deserialize(byteBuffer: ByteBuffer) throws -> Output {
+    return try self._deserialize(byteBuffer)
+  }
+}

+ 80 - 0
Sources/GRPC/_GRPCClientCodecHandler.swift

@@ -94,3 +94,83 @@ extension GRPCClientCodecHandler: ChannelOutboundHandler {
     }
   }
 }
+
+// MARK: Reverse Codec
+
+internal class GRPCClientReverseCodecHandler<
+  Serializer: MessageSerializer,
+  Deserializer: MessageDeserializer
+> {
+  /// The request serializer.
+  private let serializer: Serializer
+
+  /// The response deserializer.
+  private let deserializer: Deserializer
+
+  internal init(serializer: Serializer, deserializer: Deserializer) {
+    self.serializer = serializer
+    self.deserializer = deserializer
+  }
+}
+
+extension GRPCClientReverseCodecHandler: ChannelInboundHandler {
+  typealias InboundIn = _GRPCClientResponsePart<Serializer.Input>
+  typealias InboundOut = _RawGRPCClientResponsePart
+
+  internal func channelRead(context: ChannelHandlerContext, data: NIOAny) {
+    switch self.unwrapInboundIn(data) {
+    case let .initialMetadata(headers):
+      context.fireChannelRead(self.wrapInboundOut(.initialMetadata(headers)))
+
+    case let .message(messageContext):
+      do {
+        let response = try self.serializer.serialize(
+          messageContext.message,
+          allocator: context.channel.allocator
+        )
+        context.fireChannelRead(
+          self.wrapInboundOut(.message(.init(response, compressed: messageContext.compressed)))
+        )
+      } catch {
+        context.fireErrorCaught(error)
+      }
+
+    case let .trailingMetadata(trailers):
+      context.fireChannelRead(self.wrapInboundOut(.trailingMetadata(trailers)))
+
+    case let .status(status):
+      context.fireChannelRead(self.wrapInboundOut(.status(status)))
+    }
+  }
+}
+
+extension GRPCClientReverseCodecHandler: ChannelOutboundHandler {
+  typealias OutboundIn = _RawGRPCClientRequestPart
+  typealias OutboundOut = _GRPCClientRequestPart<Deserializer.Output>
+
+  internal func write(
+    context: ChannelHandlerContext,
+    data: NIOAny,
+    promise: EventLoopPromise<Void>?
+  ) {
+    switch self.unwrapOutboundIn(data) {
+    case let .head(head):
+      context.write(self.wrapOutboundOut(.head(head)), promise: promise)
+
+    case let .message(message):
+      do {
+        let deserialized = try self.deserializer.deserialize(byteBuffer: message.message)
+        context.write(
+          self.wrapOutboundOut(.message(.init(deserialized, compressed: message.compressed))),
+          promise: promise
+        )
+      } catch {
+        promise?.fail(error)
+        context.fireErrorCaught(error)
+      }
+
+    case .end:
+      context.write(self.wrapOutboundOut(.end), promise: promise)
+    }
+  }
+}

+ 0 - 1
Sources/GRPCPerformanceTests/Benchmarks/UnaryThroughput.swift

@@ -54,7 +54,6 @@ class Unary: ServerProvidingBenchmark {
       let requests = (lowerBound ..< upperBound).map { _ in
         client.get(Echo_EchoRequest.with { $0.text = self.requestText }).response
       }
-
       try EventLoopFuture.andAllSucceed(requests, on: self.group.next()).wait()
     }
   }

+ 25 - 0
Tests/GRPCTests/ClientTransportTests.swift

@@ -52,6 +52,8 @@ class ClientTransportTests: GRPCTestCase {
       details: details ?? self.makeDetails(),
       eventLoop: self.eventLoop,
       interceptors: interceptors,
+      serializer: AnySerializer(wrapping: StringSerializer()),
+      deserializer: AnyDeserializer(wrapping: StringDeserializer()),
       errorDelegate: nil,
       onError: onError,
       onResponsePart: onResponsePart
@@ -61,6 +63,12 @@ class ClientTransportTests: GRPCTestCase {
   private func configureTransport(additionalHandlers handlers: [ChannelHandler] = []) {
     self.transport.configure {
       var handlers = handlers
+      handlers.append(
+        GRPCClientReverseCodecHandler(
+          serializer: StringSerializer(),
+          deserializer: StringDeserializer()
+        )
+      )
       handlers.append($0)
       return self.channel.pipeline.addHandlers(handlers)
     }
@@ -337,3 +345,20 @@ class WriteRecorder<Write>: ChannelOutboundHandler {
 }
 
 private struct DummyError: Error {}
+
+private struct StringSerializer: MessageSerializer {
+  typealias Input = String
+
+  func serialize(_ input: String, allocator: ByteBufferAllocator) throws -> ByteBuffer {
+    return allocator.buffer(string: input)
+  }
+}
+
+private struct StringDeserializer: MessageDeserializer {
+  typealias Output = String
+
+  func deserialize(byteBuffer: ByteBuffer) throws -> String {
+    var buffer = byteBuffer
+    return buffer.readString(length: buffer.readableBytes)!
+  }
+}