Преглед изворни кода

A more general channel provider (#1162)

Motivation:

In #1158 we pulled connection creation of the connection manager into a
channel provider in order to loosen the coupling between the connection
manager and `ClientConnection`. This change further decouples the
`ConnectionManager` from the channel provider pulling out the relevant
configuration into a `DefaultChannelProvider`.

Modifications:

- Refactor `ClientConnection.ChannelProvider` to rely on the bits of
  configuration it actually requires rather than `ClientConnection.Configuration`
- Rename to `DefaultChannelProvider`

Result:

We can configure channels for the `ConnectionManager` without being tied
to `ClientConnection`.
George Barnett пре 4 година
родитељ
комит
baa4c3cd1c

+ 3 - 0
.swiftformat

@@ -19,6 +19,9 @@
 # Don't indent #if blocks
 --ifdef no-indent
 
+# Don't turn Optional<Foo> into Foo?
+--shortoptionals except-properties
+
 # This rule doesn't always work as we'd expect: specifically when we return a
 # succeeded future whose type is a closure then that closure is incorrectly
 # treated as a trailing closure. This is relevant because the service provider

+ 2 - 2
Sources/GRPC/ConnectionManager.swift

@@ -251,7 +251,7 @@ internal final class ConnectionManager {
   internal convenience init(configuration: ClientConnection.Configuration, logger: Logger) {
     self.init(
       configuration: configuration,
-      channelProvider: ClientConnection.ChannelProvider(configuration: configuration),
+      channelProvider: DefaultChannelProvider(configuration: configuration),
       logger: logger
     )
   }
@@ -297,7 +297,7 @@ internal final class ConnectionManager {
     )
   }
 
-  private init(
+  internal init(
     eventLoop: EventLoop,
     channelProvider: ConnectionManagerChannelProvider,
     callStartBehavior: CallStartBehavior.Behavior,

+ 76 - 34
Sources/GRPC/ConnectionManagerChannelProvider.swift

@@ -15,6 +15,7 @@
  */
 import Logging
 import NIO
+import NIOSSL
 
 internal protocol ConnectionManagerChannelProvider {
   /// Make an `EventLoopFuture<Channel>`.
@@ -32,36 +33,80 @@ internal protocol ConnectionManagerChannelProvider {
   ) -> EventLoopFuture<Channel>
 }
 
-extension ClientConnection {
-  internal struct ChannelProvider {
-    private var configuration: Configuration
+internal struct DefaultChannelProvider: ConnectionManagerChannelProvider {
+  internal var connectionTarget: ConnectionTarget
+  internal var connectionKeepalive: ClientConnectionKeepalive
+  internal var connectionIdleTimeout: TimeAmount
 
-    internal init(configuration: Configuration) {
-      self.configuration = configuration
-    }
+  internal var tlsConfiguration: Optional<TLSConfiguration>
+  internal var tlsHostnameOverride: Optional<String>
+  internal var tlsCustomVerificationCallback: Optional<NIOSSLCustomVerificationCallback>
+
+  internal var httpTargetWindowSize: Int
+
+  internal var errorDelegate: Optional<ClientErrorDelegate>
+  internal var debugChannelInitializer: Optional<(Channel) -> EventLoopFuture<Void>>
+
+  internal init(
+    connectionTarget: ConnectionTarget,
+    connectionKeepalive: ClientConnectionKeepalive,
+    connectionIdleTimeout: TimeAmount,
+    tlsConfiguration: TLSConfiguration?,
+    tlsHostnameOverride: String?,
+    tlsCustomVerificationCallback: NIOSSLCustomVerificationCallback?,
+    httpTargetWindowSize: Int,
+    errorDelegate: ClientErrorDelegate?,
+    debugChannelInitializer: ((Channel) -> EventLoopFuture<Void>)?
+  ) {
+    self.connectionTarget = connectionTarget
+    self.connectionKeepalive = connectionKeepalive
+    self.connectionIdleTimeout = connectionIdleTimeout
+
+    self.tlsConfiguration = tlsConfiguration
+    self.tlsHostnameOverride = tlsHostnameOverride
+    self.tlsCustomVerificationCallback = tlsCustomVerificationCallback
+
+    self.httpTargetWindowSize = httpTargetWindowSize
+
+    self.errorDelegate = errorDelegate
+    self.debugChannelInitializer = debugChannelInitializer
+  }
+
+  internal init(configuration: ClientConnection.Configuration) {
+    self.init(
+      connectionTarget: configuration.target,
+      connectionKeepalive: configuration.connectionKeepalive,
+      connectionIdleTimeout: configuration.connectionIdleTimeout,
+      tlsConfiguration: configuration.tls?.configuration,
+      tlsHostnameOverride: configuration.tls?.hostnameOverride,
+      tlsCustomVerificationCallback: configuration.tls?.customVerificationCallback,
+      httpTargetWindowSize: configuration.httpTargetWindowSize,
+      errorDelegate: configuration.errorDelegate,
+      debugChannelInitializer: configuration.debugChannelInitializer
+    )
+  }
+
+  private var serverHostname: String? {
+    let hostname = self.tlsHostnameOverride ?? self.connectionTarget.host
+    return hostname.isIPAddress ? nil : hostname
+  }
+
+  private var hasTLS: Bool {
+    return self.tlsConfiguration != nil
+  }
+
+  private func requiresZeroLengthWorkaround(eventLoop: EventLoop) -> Bool {
+    return PlatformSupport.requiresZeroLengthWriteWorkaround(group: eventLoop, hasTLS: self.hasTLS)
   }
-}
 
-extension ClientConnection.ChannelProvider: ConnectionManagerChannelProvider {
   internal func makeChannel(
     managedBy connectionManager: ConnectionManager,
     onEventLoop eventLoop: EventLoop,
     connectTimeout: TimeAmount?,
     logger: Logger
   ) -> EventLoopFuture<Channel> {
-    let serverHostname: String? = self.configuration.tls.flatMap { tls -> String? in
-      if let hostnameOverride = tls.hostnameOverride {
-        return hostnameOverride
-      } else {
-        return self.configuration.target.host
-      }
-    }.flatMap { hostname in
-      if hostname.isIPAddress {
-        return nil
-      } else {
-        return hostname
-      }
-    }
+    let hostname = self.serverHostname
+    let needsZeroLengthWriteWorkaround = self.requiresZeroLengthWorkaround(eventLoop: eventLoop)
 
     let bootstrap = PlatformSupport.makeClientBootstrap(group: eventLoop, logger: logger)
       .channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
@@ -72,26 +117,23 @@ extension ClientConnection.ChannelProvider: ConnectionManagerChannelProvider {
         do {
           try sync.configureGRPCClient(
             channel: channel,
-            httpTargetWindowSize: self.configuration.httpTargetWindowSize,
-            tlsConfiguration: self.configuration.tls?.configuration,
-            tlsServerHostname: serverHostname,
+            httpTargetWindowSize: self.httpTargetWindowSize,
+            tlsConfiguration: self.tlsConfiguration,
+            tlsServerHostname: hostname,
             connectionManager: connectionManager,
-            connectionKeepalive: self.configuration.connectionKeepalive,
-            connectionIdleTimeout: self.configuration.connectionIdleTimeout,
-            errorDelegate: self.configuration.errorDelegate,
-            requiresZeroLengthWriteWorkaround: PlatformSupport.requiresZeroLengthWriteWorkaround(
-              group: eventLoop,
-              hasTLS: self.configuration.tls != nil
-            ),
+            connectionKeepalive: self.connectionKeepalive,
+            connectionIdleTimeout: self.connectionIdleTimeout,
+            errorDelegate: self.errorDelegate,
+            requiresZeroLengthWriteWorkaround: needsZeroLengthWriteWorkaround,
             logger: logger,
-            customVerificationCallback: self.configuration.tls?.customVerificationCallback
+            customVerificationCallback: self.tlsCustomVerificationCallback
           )
         } catch {
           return channel.eventLoop.makeFailedFuture(error)
         }
 
         // Run the debug initializer, if there is one.
-        if let debugInitializer = self.configuration.debugChannelInitializer {
+        if let debugInitializer = self.debugChannelInitializer {
           return debugInitializer(channel)
         } else {
           return channel.eventLoop.makeSucceededVoidFuture()
@@ -102,6 +144,6 @@ extension ClientConnection.ChannelProvider: ConnectionManagerChannelProvider {
       _ = bootstrap.connectTimeout(connectTimeout)
     }
 
-    return bootstrap.connect(to: self.configuration.target)
+    return bootstrap.connect(to: self.connectionTarget)
   }
 }