Browse Source

Update the codegen to use the new handlers.

Motivation:

We have new server handlers; all we need to do is generate the code to
use them.

Modifications:

- Update codegen
- Fixup interceptor tests and remove tests already covered by
  {Unary,...}HandlerTests
- Deprecate `handleMethod` with a message to re-generate server code
- Regenerate our own code

Result:

Instruction count reductions:
- unary_10k_small_requests: -4.7%
- embedded_server_unary_10k_small_requests: -25.4%
- embedded_server_client_streaming_1_rpc_10k_small_requests: -35.6%
- embedded_server_client_streaming_10k_rpcs_1_small_requests: -25.1%
- embedded_server_server_streaming_1_rpc_10k_small_responses: -14.1%
- embedded_server_server_streaming_10k_rpcs_1_small_response: -21.1%
- embedded_server_bidi_1_rpc_10k_small_requests: -29.2%
- embedded_server_bidi_10k_rpcs_1_small_request: -23.9%
George Barnett 5 years ago
parent
commit
56ec6db08d

+ 8 - 0
Sources/GRPC/CallHandlers/CallHandlerFactory.swift

@@ -22,6 +22,7 @@ public enum CallHandlerFactory {
   public typealias UnaryContext<Response> = UnaryResponseCallContext<Response>
   public typealias UnaryEventObserver<Request, Response> = (Request) -> EventLoopFuture<Response>
 
+  @available(*, deprecated, message: "Please regenerate your server code.")
   @inlinable
   public static func makeUnary<Request: Message, Response: Message>(
     callHandlerContext: CallHandlerContext,
@@ -38,6 +39,7 @@ public enum CallHandlerFactory {
     )
   }
 
+  @available(*, deprecated, message: "Please regenerate your server code.")
   @inlinable
   public static func makeUnary<Request: GRPCPayload, Response: GRPCPayload>(
     callHandlerContext: CallHandlerContext,
@@ -58,6 +60,7 @@ public enum CallHandlerFactory {
   public typealias ClientStreamingEventObserver<Request> =
     EventLoopFuture<(StreamEvent<Request>) -> Void>
 
+  @available(*, deprecated, message: "Please regenerate your server code.")
   @inlinable
   public static func makeClientStreaming<Request: Message, Response: Message>(
     callHandlerContext: CallHandlerContext,
@@ -74,6 +77,7 @@ public enum CallHandlerFactory {
     )
   }
 
+  @available(*, deprecated, message: "Please regenerate your server code.")
   @inlinable
   public static func makeClientStreaming<Request: GRPCPayload, Response: GRPCPayload>(
     callHandlerContext: CallHandlerContext,
@@ -96,6 +100,7 @@ public enum CallHandlerFactory {
   public typealias ServerStreamingContext<Response> = StreamingResponseCallContext<Response>
   public typealias ServerStreamingEventObserver<Request> = (Request) -> EventLoopFuture<GRPCStatus>
 
+  @available(*, deprecated, message: "Please regenerate your server code.")
   @inlinable
   public static func makeServerStreaming<Request: Message, Response: Message>(
     callHandlerContext: CallHandlerContext,
@@ -112,6 +117,7 @@ public enum CallHandlerFactory {
     )
   }
 
+  @available(*, deprecated, message: "Please regenerate your server code.")
   @inlinable
   public static func makeServerStreaming<Request: GRPCPayload, Response: GRPCPayload>(
     callHandlerContext: CallHandlerContext,
@@ -135,6 +141,7 @@ public enum CallHandlerFactory {
   public typealias BidirectionalStreamingEventObserver<Request> =
     EventLoopFuture<(StreamEvent<Request>) -> Void>
 
+  @available(*, deprecated, message: "Please regenerate your server code.")
   @inlinable
   public static func makeBidirectionalStreaming<Request: Message, Response: Message>(
     callHandlerContext: CallHandlerContext,
@@ -154,6 +161,7 @@ public enum CallHandlerFactory {
     )
   }
 
+  @available(*, deprecated, message: "Please regenerate your server code.")
   @inlinable
   public static func makeBidirectionalStreaming<Request: GRPCPayload, Response: GRPCPayload>(
     callHandlerContext: CallHandlerContext,

+ 8 - 0
Sources/GRPC/GRPCServerRequestRoutingHandler.swift

@@ -53,6 +53,14 @@ extension CallHandlerProvider {
   ) -> GRPCServerHandlerProtocol? {
     return nil
   }
+
+  // TODO: remove this once we've removed 'handleMethod(_:callHandlerContext:)'.
+  public func handleMethod(
+    _ methodName: Substring,
+    callHandlerContext: CallHandlerContext
+  ) -> GRPCCallHandler? {
+    return nil
+  }
 }
 
 // This is public because it will be passed into generated code, all members are `internal` because

+ 17 - 24
Sources/protoc-gen-grpc-swift/Generator-Server.swift

@@ -78,15 +78,15 @@ extension Generator {
       )
       self.println("/// Returns nil for methods not handled by this service.")
       self.printFunction(
-        name: "handleMethod",
+        name: "handle",
         arguments: [
-          "_ methodName: Substring",
-          "callHandlerContext: CallHandlerContext",
+          "method name: Substring",
+          "context: CallHandlerContext",
         ],
-        returnType: "GRPCCallHandler?",
+        returnType: "GRPCServerHandlerProtocol?",
         access: self.access
       ) {
-        self.println("switch methodName {")
+        self.println("switch name {")
         for method in self.service.methods {
           self.method = method
           self.println("case \"\(method.name)\":")
@@ -95,38 +95,31 @@ extension Generator {
             let callHandlerType: String
             switch streamingType(method) {
             case .unary:
-              callHandlerType = "CallHandlerFactory.makeUnary"
+              callHandlerType = "UnaryServerHandler"
             case .serverStreaming:
-              callHandlerType = "CallHandlerFactory.makeServerStreaming"
+              callHandlerType = "ServerStreamingServerHandler"
             case .clientStreaming:
-              callHandlerType = "CallHandlerFactory.makeClientStreaming"
+              callHandlerType = "ClientStreamingServerHandler"
             case .bidirectionalStreaming:
-              callHandlerType = "CallHandlerFactory.makeBidirectionalStreaming"
+              callHandlerType = "BidirectionalStreamingServerHandler"
             }
 
             self.println("return \(callHandlerType)(")
             self.withIndentation {
-              self.println("callHandlerContext: callHandlerContext,")
+              self.println("context: context,")
+              self.println("requestDeserializer: ProtobufDeserializer<\(self.methodInputName)>(),")
+              self.println("responseSerializer: ProtobufSerializer<\(self.methodOutputName)>(),")
               self.println(
-                "interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? []"
+                "interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? [],"
               )
-            }
-            self.println(") { context in")
-            self.withIndentation {
-              switch streamingType(self.method) {
+              switch streamingType(method) {
               case .unary, .serverStreaming:
-                self.println("return { request in")
-                self.withIndentation {
-                  self.println(
-                    "self.\(self.methodFunctionName)(request: request, context: context)"
-                  )
-                }
-                self.println("}")
+                self.println("userFunction: self.\(self.methodFunctionName)(request:context:)")
               case .clientStreaming, .bidirectionalStreaming:
-                self.println("self.\(self.methodFunctionName)(context: context)")
+                self.println("observerFactory: self.\(self.methodFunctionName)(context:)")
               }
             }
-            self.println("}")
+            self.println(")")
           }
           self.println()
         }

+ 62 - 165
Tests/GRPCTests/ServerInterceptorTests.swift

@@ -22,15 +22,24 @@ import NIOHTTP1
 import SwiftProtobuf
 import XCTest
 
-class ServerInterceptorTests: GRPCTestCase {
-  private var channel: EmbeddedChannel!
-
-  override func setUp() {
-    super.setUp()
-    self.channel = EmbeddedChannel()
+extension GRPCServerHandlerProtocol {
+  fileprivate func receiveRequest(_ request: Echo_EchoRequest) {
+    let serializer = ProtobufSerializer<Echo_EchoRequest>()
+    do {
+      let buffer = try serializer.serialize(request, allocator: ByteBufferAllocator())
+      self.receiveMessage(buffer)
+    } catch {
+      XCTFail("Unexpected error: \(error)")
+    }
   }
+}
+
+class ServerInterceptorTests: GRPCTestCase {
+  private let eventLoop = EmbeddedEventLoop()
+  private let recorder = ResponseRecorder()
 
-  private func makeRecorder() -> RecordingServerInterceptor<Echo_EchoRequest, Echo_EchoResponse> {
+  private func makeRecordingInterceptor()
+    -> RecordingServerInterceptor<Echo_EchoRequest, Echo_EchoResponse> {
     return .init()
   }
 
@@ -45,9 +54,9 @@ class ServerInterceptorTests: GRPCTestCase {
       errorDelegate: nil,
       logger: self.serverLogger,
       encoding: .disabled,
-      eventLoop: self.channel.eventLoop,
+      eventLoop: self.eventLoop,
       path: path,
-      responseWriter: NoOpResponseWriter(),
+      responseWriter: self.recorder,
       allocator: ByteBufferAllocator()
     )
   }
@@ -62,215 +71,103 @@ class ServerInterceptorTests: GRPCTestCase {
   private func handleMethod(
     _ method: Substring,
     using provider: CallHandlerProvider
-  ) -> GRPCCallHandler? {
+  ) -> GRPCServerHandlerProtocol? {
     let path = "/\(provider.serviceName)/\(method)"
     let context = self.makeHandlerContext(for: path)
-    return provider.handleMethod(method, callHandlerContext: context)
+    return provider.handle(method: method, context: context)
   }
 
   fileprivate typealias ResponsePart = GRPCServerResponsePart<Echo_EchoResponse>
 
   func testPassThroughInterceptor() throws {
-    let recorder = self.makeRecorder()
-    let provider = self.echoProvider(interceptedBy: recorder)
+    let recordingInterceptor = self.makeRecordingInterceptor()
+    let provider = self.echoProvider(interceptedBy: recordingInterceptor)
 
     let handler = try assertNotNil(self.handleMethod("Get", using: provider))
-    assertThat(try self.channel.pipeline.addHandlers([Codec(), handler]).wait(), .doesNotThrow())
 
     // Send requests.
-    assertThat(try self.channel.writeInbound(self.request(.metadata([:]))), .doesNotThrow())
-    assertThat(
-      try self.channel.writeInbound(self.request(.message(.with { $0.text = "" }))),
-      .doesNotThrow()
-    )
-    assertThat(try self.channel.writeInbound(self.request(.end)), .doesNotThrow())
+    handler.receiveMetadata([:])
+    handler.receiveRequest(.with { $0.text = "" })
+    handler.receiveEnd()
 
     // Expect responses.
-    assertThat(try self.channel.readOutbound(as: ResponsePart.self), .notNil(.metadata()))
-    assertThat(try self.channel.readOutbound(as: ResponsePart.self), .notNil(.message()))
-    assertThat(try self.channel.readOutbound(as: ResponsePart.self), .notNil(.end()))
+    assertThat(self.recorder.metadata, .is(.notNil()))
+    assertThat(self.recorder.messages.count, .is(1))
+    assertThat(self.recorder.status, .is(.notNil()))
 
     // We expect 2 request parts: the provider responds before it sees end, that's fine.
-    assertThat(recorder.requestParts, .hasCount(2))
-    assertThat(recorder.requestParts[0], .is(.metadata()))
-    assertThat(recorder.requestParts[1], .is(.message()))
-
-    assertThat(recorder.responseParts, .hasCount(3))
-    assertThat(recorder.responseParts[0], .is(.metadata()))
-    assertThat(recorder.responseParts[1], .is(.message()))
-    assertThat(recorder.responseParts[2], .is(.end(status: .is(.ok))))
-  }
-
-  func _testExtraRequestPartsAreIgnored(
-    part: ExtraRequestPartEmitter.Part,
-    callType: GRPCCallType
-  ) throws {
-    let interceptor = ExtraRequestPartEmitter(repeat: part, times: 3)
-    let provider = self.echoProvider(interceptedBy: interceptor)
-
-    let method: Substring
-
-    switch callType {
-    case .unary:
-      method = "Get"
-    case .clientStreaming:
-      method = "Collect"
-    case .serverStreaming:
-      method = "Expand"
-    case .bidirectionalStreaming:
-      method = "Update"
-    }
-
-    let handler = try assertNotNil(self.handleMethod(method, using: provider))
-    assertThat(try self.channel.pipeline.addHandlers([Codec(), handler]).wait(), .doesNotThrow())
-
-    // Send the requests.
-    assertThat(try self.channel.writeInbound(self.request(.metadata([:]))), .doesNotThrow())
-    assertThat(try self.channel.writeInbound(self.request(.message(.init()))), .doesNotThrow())
-    assertThat(try self.channel.writeInbound(self.request(.end)), .doesNotThrow())
-
-    // Expect the responses.
-    assertThat(try self.channel.readOutbound(as: ResponsePart.self), .notNil(.metadata()))
-    assertThat(try self.channel.readOutbound(as: ResponsePart.self), .notNil(.message()))
-    assertThat(try self.channel.readOutbound(as: ResponsePart.self), .notNil(.end()))
-    // No more response parts.
-    assertThat(try self.channel.readOutbound(as: ResponsePart.self), .is(.nil()))
-  }
-
-  func testExtraRequestMetadataIsIgnoredForUnary() throws {
-    try self._testExtraRequestPartsAreIgnored(part: .metadata, callType: .unary)
-  }
-
-  func testExtraRequestMessageIsIgnoredForUnary() throws {
-    try self._testExtraRequestPartsAreIgnored(part: .message, callType: .unary)
-  }
-
-  func testExtraRequestEndIsIgnoredForUnary() throws {
-    try self._testExtraRequestPartsAreIgnored(part: .end, callType: .unary)
-  }
-
-  func testExtraRequestMetadataIsIgnoredForClientStreaming() throws {
-    try self._testExtraRequestPartsAreIgnored(part: .metadata, callType: .clientStreaming)
-  }
+    assertThat(recordingInterceptor.requestParts, .hasCount(2))
+    assertThat(recordingInterceptor.requestParts[0], .is(.metadata()))
+    assertThat(recordingInterceptor.requestParts[1], .is(.message()))
 
-  func testExtraRequestEndIsIgnoredForClientStreaming() throws {
-    try self._testExtraRequestPartsAreIgnored(part: .end, callType: .clientStreaming)
-  }
-
-  func testExtraRequestMetadataIsIgnoredForServerStreaming() throws {
-    try self._testExtraRequestPartsAreIgnored(part: .metadata, callType: .serverStreaming)
-  }
-
-  func testExtraRequestMessageIsIgnoredForServerStreaming() throws {
-    try self._testExtraRequestPartsAreIgnored(part: .message, callType: .serverStreaming)
-  }
-
-  func testExtraRequestEndIsIgnoredForServerStreaming() throws {
-    try self._testExtraRequestPartsAreIgnored(part: .end, callType: .serverStreaming)
-  }
-
-  func testExtraRequestMetadataIsIgnoredForBidirectionalStreaming() throws {
-    try self._testExtraRequestPartsAreIgnored(part: .metadata, callType: .bidirectionalStreaming)
-  }
-
-  func testExtraRequestEndIsIgnoredForBidirectionalStreaming() throws {
-    try self._testExtraRequestPartsAreIgnored(part: .end, callType: .bidirectionalStreaming)
+    assertThat(recordingInterceptor.responseParts, .hasCount(3))
+    assertThat(recordingInterceptor.responseParts[0], .is(.metadata()))
+    assertThat(recordingInterceptor.responseParts[1], .is(.message()))
+    assertThat(recordingInterceptor.responseParts[2], .is(.end(status: .is(.ok))))
   }
 
   func testUnaryFromInterceptor() throws {
     let provider = EchoFromInterceptor()
     let handler = try assertNotNil(self.handleMethod("Get", using: provider))
-    assertThat(try self.channel.pipeline.addHandlers([Codec(), handler]).wait(), .doesNotThrow())
 
     // Send the requests.
-    assertThat(try self.channel.writeInbound(self.request(.metadata([:]))), .doesNotThrow())
-    assertThat(
-      try self.channel.writeInbound(self.request(.message(.init(text: "foo")))),
-      .doesNotThrow()
-    )
-    assertThat(try self.channel.writeInbound(self.request(.end)), .doesNotThrow())
+    handler.receiveMetadata([:])
+    handler.receiveRequest(.with { $0.text = "foo" })
+    handler.receiveEnd()
 
     // Get the responses.
-    assertThat(try self.channel.readOutbound(as: ResponsePart.self), .notNil(.metadata()))
-    assertThat(
-      try self.channel.readOutbound(as: ResponsePart.self),
-      .notNil(.message(.equalTo(.with { $0.text = "echo: foo" })))
-    )
-    assertThat(try self.channel.readOutbound(as: ResponsePart.self), .notNil(.end()))
+    assertThat(self.recorder.metadata, .is(.notNil()))
+    assertThat(self.recorder.messages.count, .is(1))
+    assertThat(self.recorder.status, .is(.notNil()))
   }
 
   func testClientStreamingFromInterceptor() throws {
     let provider = EchoFromInterceptor()
     let handler = try assertNotNil(self.handleMethod("Collect", using: provider))
-    assertThat(try self.channel.pipeline.addHandlers([Codec(), handler]).wait(), .doesNotThrow())
 
     // Send the requests.
-    assertThat(try self.channel.writeInbound(self.request(.metadata([:]))), .doesNotThrow())
+    handler.receiveMetadata([:])
     for text in ["a", "b", "c"] {
-      let message = self.request(.message(.init(text: text)))
-      assertThat(try self.channel.writeInbound(message), .doesNotThrow())
+      handler.receiveRequest(.with { $0.text = text })
     }
-    assertThat(try self.channel.writeInbound(self.request(.end)), .doesNotThrow())
+    handler.receiveEnd()
 
-    // Receive responses.
-    assertThat(try self.channel.readOutbound(as: ResponsePart.self), .notNil(.metadata()))
-    assertThat(
-      try self.channel.readOutbound(as: ResponsePart.self),
-      .notNil(.message(.equalTo(.with { $0.text = "echo: a b c" })))
-    )
-    assertThat(try self.channel.readOutbound(as: ResponsePart.self), .notNil(.end()))
+    // Get the responses.
+    assertThat(self.recorder.metadata, .is(.notNil()))
+    assertThat(self.recorder.messages.count, .is(1))
+    assertThat(self.recorder.status, .is(.notNil()))
   }
 
   func testServerStreamingFromInterceptor() throws {
     let provider = EchoFromInterceptor()
     let handler = try assertNotNil(self.handleMethod("Expand", using: provider))
-    assertThat(try self.channel.pipeline.addHandlers([Codec(), handler]).wait(), .doesNotThrow())
 
     // Send the requests.
-    assertThat(try self.channel.writeInbound(self.request(.metadata([:]))), .doesNotThrow())
-    assertThat(
-      try self.channel.writeInbound(self.request(.message(.with { $0.text = "a b c" }))),
-      .doesNotThrow()
-    )
-    assertThat(try self.channel.writeInbound(self.request(.end)), .doesNotThrow())
+    handler.receiveMetadata([:])
+    handler.receiveRequest(.with { $0.text = "a b c" })
+    handler.receiveEnd()
 
-    // Receive responses.
-    assertThat(try self.channel.readOutbound(as: ResponsePart.self), .notNil(.metadata()))
-    for text in ["a", "b", "c"] {
-      let expected = Echo_EchoResponse(text: "echo: " + text)
-      assertThat(
-        try self.channel.readOutbound(as: ResponsePart.self),
-        .notNil(.message(.equalTo(expected)))
-      )
-    }
-    assertThat(try self.channel.readOutbound(as: ResponsePart.self), .notNil(.end()))
+    // Get the responses.
+    assertThat(self.recorder.metadata, .is(.notNil()))
+    assertThat(self.recorder.messages.count, .is(3))
+    assertThat(self.recorder.status, .is(.notNil()))
   }
 
   func testBidirectionalStreamingFromInterceptor() throws {
     let provider = EchoFromInterceptor()
     let handler = try assertNotNil(self.handleMethod("Update", using: provider))
-    assertThat(try self.channel.pipeline.addHandlers([Codec(), handler]).wait(), .doesNotThrow())
 
     // Send the requests.
-    assertThat(try self.channel.writeInbound(self.request(.metadata([:]))), .doesNotThrow())
+    handler.receiveMetadata([:])
     for text in ["a", "b", "c"] {
-      assertThat(
-        try self.channel.writeInbound(self.request(.message(.init(text: text)))),
-        .doesNotThrow()
-      )
+      handler.receiveRequest(.with { $0.text = text })
     }
-    assertThat(try self.channel.writeInbound(self.request(.end)), .doesNotThrow())
+    handler.receiveEnd()
 
-    // Receive responses.
-    assertThat(try self.channel.readOutbound(as: ResponsePart.self), .notNil(.metadata()))
-    for text in ["a", "b", "c"] {
-      let expected = Echo_EchoResponse(text: "echo: " + text)
-      assertThat(
-        try self.channel.readOutbound(as: ResponsePart.self),
-        .notNil(.message(.equalTo(expected)))
-      )
-    }
-    assertThat(try self.channel.readOutbound(as: ResponsePart.self), .notNil(.end()))
+    // Get the responses.
+    assertThat(self.recorder.metadata, .is(.notNil()))
+    assertThat(self.recorder.messages.count, .is(3))
+    assertThat(self.recorder.status, .is(.notNil()))
   }
 }
 

+ 1 - 1
Tests/GRPCTests/ServerWebTests.swift

@@ -96,7 +96,7 @@ extension ServerWebTests {
   func testUnaryWithoutRequestMessage() {
     let expectedData = self.gRPCWebTrailers(
       status: 13,
-      message: "Request stream cardinality violation"
+      message: "Protocol violation: End received before message"
     )
 
     let expectedResponse = expectedData.base64EncodedString()