Browse Source

Add a server debug channel initializer (#911)

Motivation:

Sometimes it's helpful to be able to add additional handlers to the
channel pipeline. We let clients to do this, so we should let servers do
it too.

Modifications:

- Add a `debugChannelInitializer` to server configuration and the server
  builder

Result:

- Users can add additional channel handlers to servers
George Barnett 5 years ago
parent
commit
72a6aca3e6

+ 38 - 14
Sources/GRPC/Server.swift

@@ -115,12 +115,23 @@ public final class Server {
           return channel.pipeline.addHandler(handler)
         }
 
+        let configured: EventLoopFuture<Void>
+
         if let tls = configuration.tls {
-          return channel.configureTLS(configuration: tls).flatMap {
+          configured = channel.configureTLS(configuration: tls).flatMap {
             channel.pipeline.addHandler(protocolSwitcher)
           }
         } else {
-          return channel.pipeline.addHandler(protocolSwitcher)
+          configured = channel.pipeline.addHandler(protocolSwitcher)
+        }
+
+        // Add the debug initializer, if there is one.
+        if let debugAcceptedChannelInitializer = configuration.debugChannelInitializer {
+          return configured.flatMap {
+            debugAcceptedChannelInitializer(channel)
+          }
+        } else {
+          return configured
         }
       }
 
@@ -216,19 +227,30 @@ extension Server {
     /// available to service providers via `context`. Defaults to a no-op logger.
     public var logger: Logger
 
+    /// A channel initializer which will be run after gRPC has initialized each accepted channel.
+    /// This may be used to add additional handlers to the pipeline and is intended for debugging.
+    /// This is analogous to `NIO.ServerBootstrap.childChannelInitializer`.
+    ///
+    /// - Warning: The initializer closure may be invoked *multiple times*. More precisely: it will
+    ///   be invoked at most once per accepted connection.
+    public var debugChannelInitializer: ((Channel) -> EventLoopFuture<Void>)?
+
     /// Create a `Configuration` with some pre-defined defaults.
     ///
-    /// - Parameter target: The target to bind to.
-    /// - Parameter eventLoopGroup: The event loop group to run the server on.
-    /// - Parameter serviceProviders: An array of `CallHandlerProvider`s which the server should use
-    ///     to handle requests.
-    /// - Parameter errorDelegate: The error delegate, defaulting to a logging delegate.
-    /// - Parameter tls: TLS configuration, defaulting to `nil`.
-    /// - Parameter connectionKeepalive: The keepalive configuration to use.
-    /// - Parameter connectionIdleTimeout: The amount of time to wait before closing the connection, defaulting to 5 minutes.
-    /// - Parameter messageEncoding: Message compression configuration, defaulting to no compression.
-    /// - Parameter httpTargetWindowSize: The HTTP/2 flow control target window size.
-    /// - Parameter logger: A logger. Defaults to a no-op logger.
+    /// - Parameters:
+    ///   - target: The target to bind to.
+    ///   -  eventLoopGroup: The event loop group to run the server on.
+    ///   - serviceProviders: An array of `CallHandlerProvider`s which the server should use
+    ///       to handle requests.
+    ///   - errorDelegate: The error delegate, defaulting to a logging delegate.
+    ///   - tls: TLS configuration, defaulting to `nil`.
+    ///   - connectionKeepalive: The keepalive configuration to use.
+    ///   - connectionIdleTimeout: The amount of time to wait before closing the connection, defaulting to 5 minutes.
+    ///   - messageEncoding: Message compression configuration, defaulting to no compression.
+    ///   - httpTargetWindowSize: The HTTP/2 flow control target window size.
+    ///   - logger: A logger. Defaults to a no-op logger.
+    ///   - debugChannelInitializer: A channel initializer which will be called for each connection
+    ///     the server accepts after gRPC has initialized the channel. Defaults to `nil`.
     public init(
       target: BindTarget,
       eventLoopGroup: EventLoopGroup,
@@ -239,7 +261,8 @@ extension Server {
       connectionIdleTimeout: TimeAmount = .minutes(5),
       messageEncoding: ServerMessageEncoding = .disabled,
       httpTargetWindowSize: Int = 65535,
-      logger: Logger = Logger(label: "io.grpc", factory: { _ in SwiftLogNoOpLogHandler() })
+      logger: Logger = Logger(label: "io.grpc", factory: { _ in SwiftLogNoOpLogHandler() }),
+      debugChannelInitializer: ((Channel) -> EventLoopFuture<Void>)? = nil
     ) {
       self.target = target
       self.eventLoopGroup = eventLoopGroup
@@ -251,6 +274,7 @@ extension Server {
       self.messageEncoding = messageEncoding
       self.httpTargetWindowSize = httpTargetWindowSize
       self.logger = logger
+      self.debugChannelInitializer = debugChannelInitializer
     }
   }
 }

+ 16 - 0
Sources/GRPC/ServerBuilder.swift

@@ -142,6 +142,22 @@ extension Server.Builder {
   }
 }
 
+extension Server.Builder {
+  /// A channel initializer which will be run after gRPC has initialized each accepted channel.
+  /// This may be used to add additional handlers to the pipeline and is intended for debugging.
+  /// This is analogous to `NIO.ServerBootstrap.childChannelInitializer`.
+  ///
+  /// - Warning: The initializer closure may be invoked *multiple times*. More precisely: it will
+  ///   be invoked at most once per accepted connection.
+  @discardableResult
+  public func withDebugChannelInitializer(
+    _ debugChannelInitializer: @escaping (Channel) -> EventLoopFuture<Void>
+  ) -> Self {
+    self.configuration.debugChannelInitializer = debugChannelInitializer
+    return self
+  }
+}
+
 extension Server {
   /// Returns an insecure `Server` builder which is *not configured with TLS*.
   public static func insecure(group: EventLoopGroup) -> Builder {

+ 11 - 5
Tests/GRPCTests/ClientDebugChannelInitializerTests.swift → Tests/GRPCTests/DebugChannelInitializerTests.swift

@@ -20,26 +20,31 @@ import NIO
 import NIOConcurrencyHelpers
 import XCTest
 
-class ClientDebugChannelInitializerTests: GRPCTestCase {
+class DebugChannelInitializerTests: GRPCTestCase {
   func testDebugChannelInitializerIsCalled() throws {
     let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
     defer {
       XCTAssertNoThrow(try group.syncShutdownGracefully())
     }
 
+    let serverDebugInitializerCalled = group.next().makePromise(of: Void.self)
     let server = try Server.insecure(group: group)
       .withServiceProviders([EchoProvider()])
+      .withDebugChannelInitializer { channel in
+        serverDebugInitializerCalled.succeed(())
+        return channel.eventLoop.makeSucceededFuture(())
+      }
       .bind(host: "localhost", port: 0)
       .wait()
     defer {
       XCTAssertNoThrow(try server.close().wait())
     }
 
-    let debugInitializerCalled = group.next().makePromise(of: Void.self)
+    let clientDebugInitializerCalled = group.next().makePromise(of: Void.self)
     let connection = ClientConnection.insecure(group: group)
       .withBackgroundActivityLogger(self.clientLogger)
       .withDebugChannelInitializer { channel in
-        debugInitializerCalled.succeed(())
+        clientDebugInitializerCalled.succeed(())
         return channel.eventLoop.makeSucceededFuture(())
       }
       .connect(host: "localhost", port: server.channel.localAddress!.port!)
@@ -52,7 +57,8 @@ class ClientDebugChannelInitializerTests: GRPCTestCase {
     let get = echo.get(.with { $0.text = "Hello!" })
     XCTAssertTrue(try get.status.map { $0.isOk }.wait())
 
-    // Check the initializer was called.
-    XCTAssertNoThrow(try debugInitializerCalled.futureResult.wait())
+    // Check the initializers were called.
+    XCTAssertNoThrow(try clientDebugInitializerCalled.futureResult.wait())
+    XCTAssertNoThrow(try serverDebugInitializerCalled.futureResult.wait())
   }
 }

+ 10 - 10
Tests/GRPCTests/XCTestManifests.swift

@@ -80,15 +80,6 @@ extension ClientConnectionBackoffTests {
     ]
 }
 
-extension ClientDebugChannelInitializerTests {
-    // DO NOT MODIFY: This is autogenerated, use:
-    //   `swift test --generate-linuxmain`
-    // to regenerate.
-    static let __allTests__ClientDebugChannelInitializerTests = [
-        ("testDebugChannelInitializerIsCalled", testDebugChannelInitializerIsCalled),
-    ]
-}
-
 extension ClientTLSFailureTests {
     // DO NOT MODIFY: This is autogenerated, use:
     //   `swift test --generate-linuxmain`
@@ -187,6 +178,15 @@ extension ConnectivityStateMonitorTests {
     ]
 }
 
+extension DebugChannelInitializerTests {
+    // DO NOT MODIFY: This is autogenerated, use:
+    //   `swift test --generate-linuxmain`
+    // to regenerate.
+    static let __allTests__DebugChannelInitializerTests = [
+        ("testDebugChannelInitializerIsCalled", testDebugChannelInitializerIsCalled),
+    ]
+}
+
 extension DelegatingErrorHandlerTests {
     // DO NOT MODIFY: This is autogenerated, use:
     //   `swift test --generate-linuxmain`
@@ -921,7 +921,6 @@ public func __allTests() -> [XCTestCaseEntry] {
         testCase(ClientCancellingTests.__allTests__ClientCancellingTests),
         testCase(ClientClosedChannelTests.__allTests__ClientClosedChannelTests),
         testCase(ClientConnectionBackoffTests.__allTests__ClientConnectionBackoffTests),
-        testCase(ClientDebugChannelInitializerTests.__allTests__ClientDebugChannelInitializerTests),
         testCase(ClientTLSFailureTests.__allTests__ClientTLSFailureTests),
         testCase(ClientTLSHostnameOverrideTests.__allTests__ClientTLSHostnameOverrideTests),
         testCase(ClientThrowingWhenServerReturningErrorTests.__allTests__ClientThrowingWhenServerReturningErrorTests),
@@ -929,6 +928,7 @@ public func __allTests() -> [XCTestCaseEntry] {
         testCase(ConnectionBackoffTests.__allTests__ConnectionBackoffTests),
         testCase(ConnectionManagerTests.__allTests__ConnectionManagerTests),
         testCase(ConnectivityStateMonitorTests.__allTests__ConnectivityStateMonitorTests),
+        testCase(DebugChannelInitializerTests.__allTests__DebugChannelInitializerTests),
         testCase(DelegatingErrorHandlerTests.__allTests__DelegatingErrorHandlerTests),
         testCase(EchoTestClientTests.__allTests__EchoTestClientTests),
         testCase(FakeChannelTests.__allTests__FakeChannelTests),