Jelajahi Sumber

Pull channel creation out of `ConnectionManager` (#1158)

Motivation:

The `ConnectionManager` is initialized with a
`ClientConnection.Configuration` struct because it needs enough "stuff"
to make `NIO.Channel`s with. That's not really information it needs to
hold on to, it just needs a way to make channels. Moreover, it limits
the connection manager to using only that configuration object. That's a
bit of a pain if we want to add a connection pool with a different
configuration object.

Modifications:

- Pull out the channel creation into a `ConnectionManagerChannelProvider`
  protocol.
- Move the `ClientConnection.Configuration` based channel creation into
  an object conforming to the above.

Result:

Looser coupling.
George Barnett 4 tahun lalu
induk
melakukan
ab9165ea92

+ 76 - 92
Sources/GRPC/ConnectionManager.swift

@@ -19,7 +19,7 @@ import NIO
 import NIOConcurrencyHelpers
 import NIOHTTP2
 
-internal class ConnectionManager {
+internal final class ConnectionManager {
   internal enum Reconnect {
     case none
     case after(TimeInterval)
@@ -203,10 +203,25 @@ internal class ConnectionManager {
     }
   }
 
+  /// The `EventLoop` that the managed connection will run on.
   internal let eventLoop: EventLoop
+
+  /// A connectivity state monitor.
   internal let monitor: ConnectivityStateMonitor
+
+  /// An `EventLoopFuture<Channel>` provider.
+  private let channelProvider: ConnectionManagerChannelProvider
+
+  /// The behavior for starting a call, i.e. how patient is the caller when asking for a
+  /// multiplexer.
+  private let callStartBehavior: CallStartBehavior.Behavior
+
+  /// The configuration to use when backing off between connection attempts, if reconnection
+  /// attempts should be made at all.
+  private let connectionBackoff: ConnectionBackoff?
+
+  /// A logger.
   internal var logger: Logger
-  private let configuration: ClientConnection.Configuration
 
   private let connectionID: String
   private var channelNumber: UInt64
@@ -233,11 +248,12 @@ internal class ConnectionManager {
     logger[metadataKey: MetadataKey.connectionID] = "\(self.connectionIDAndNumber)"
   }
 
-  // Only used for testing.
-  private var channelProvider: (() -> EventLoopFuture<Channel>)?
-
   internal convenience init(configuration: ClientConnection.Configuration, logger: Logger) {
-    self.init(configuration: configuration, logger: logger, channelProvider: nil)
+    self.init(
+      configuration: configuration,
+      channelProvider: ClientConnection.ChannelProvider(configuration: configuration),
+      logger: logger
+    )
   }
 
   /// Create a `ConnectionManager` for testing: uses the given `channelProvider` to create channels.
@@ -246,17 +262,49 @@ internal class ConnectionManager {
     logger: Logger,
     channelProvider: @escaping () -> EventLoopFuture<Channel>
   ) -> ConnectionManager {
+    struct Wrapper: ConnectionManagerChannelProvider {
+      var callback: () -> EventLoopFuture<Channel>
+      func makeChannel(
+        managedBy connectionManager: ConnectionManager,
+        onEventLoop eventLoop: EventLoop,
+        connectTimeout: TimeAmount?,
+        logger: Logger
+      ) -> EventLoopFuture<Channel> {
+        return self.callback().hop(to: eventLoop)
+      }
+    }
+
     return ConnectionManager(
       configuration: configuration,
-      logger: logger,
-      channelProvider: channelProvider
+      channelProvider: Wrapper(callback: channelProvider),
+      logger: logger
     )
   }
 
-  private init(
+  private convenience init(
     configuration: ClientConnection.Configuration,
-    logger: Logger,
-    channelProvider: (() -> EventLoopFuture<Channel>)?
+    channelProvider: ConnectionManagerChannelProvider,
+    logger: Logger
+  ) {
+    self.init(
+      eventLoop: configuration.eventLoopGroup.next(),
+      channelProvider: channelProvider,
+      callStartBehavior: configuration.callStartBehavior.wrapped,
+      connectionBackoff: configuration.connectionBackoff,
+      connectivityStateDelegate: configuration.connectivityStateDelegate,
+      connectivityStateDelegateQueue: configuration.connectivityStateDelegateQueue,
+      logger: logger
+    )
+  }
+
+  private init(
+    eventLoop: EventLoop,
+    channelProvider: ConnectionManagerChannelProvider,
+    callStartBehavior: CallStartBehavior.Behavior,
+    connectionBackoff: ConnectionBackoff?,
+    connectivityStateDelegate: ConnectivityStateDelegate?,
+    connectivityStateDelegateQueue: DispatchQueue?,
+    logger: Logger
   ) {
     // Setup the logger.
     var logger = logger
@@ -264,16 +312,16 @@ internal class ConnectionManager {
     let channelNumber: UInt64 = 0
     logger[metadataKey: MetadataKey.connectionID] = "\(connectionID)/\(channelNumber)"
 
-    let eventLoop = configuration.eventLoopGroup.next()
     self.eventLoop = eventLoop
     self.state = .idle
-    self.monitor = ConnectivityStateMonitor(
-      delegate: configuration.connectivityStateDelegate,
-      queue: configuration.connectivityStateDelegateQueue
-    )
-    self.configuration = configuration
 
     self.channelProvider = channelProvider
+    self.callStartBehavior = callStartBehavior
+    self.connectionBackoff = connectionBackoff
+    self.monitor = ConnectivityStateMonitor(
+      delegate: connectivityStateDelegate,
+      queue: connectivityStateDelegateQueue
+    )
 
     self.connectionID = connectionID
     self.channelNumber = channelNumber
@@ -285,7 +333,7 @@ internal class ConnectionManager {
   /// one chance to connect - if not reconnections are managed here.
   internal func getHTTP2Multiplexer() -> EventLoopFuture<HTTP2StreamMultiplexer> {
     func getHTTP2Multiplexer0() -> EventLoopFuture<HTTP2StreamMultiplexer> {
-      switch self.configuration.callStartBehavior.wrapped {
+      switch self.callStartBehavior {
       case .waitsForConnectivity:
         return self.getHTTP2MultiplexerPatient()
       case .fastFailure:
@@ -564,7 +612,7 @@ internal class ConnectionManager {
     // the channel?
     case let .ready(ready):
       // No, no backoff is configured.
-      if self.configuration.connectionBackoff == nil {
+      if self.connectionBackoff == nil {
         self.logger.debug("shutting down connection, no reconnect configured/remaining")
         self.state = .shutdown(
           ShutdownState(
@@ -581,7 +629,7 @@ internal class ConnectionManager {
           self.startConnecting()
         }
         self.logger.debug("scheduling connection attempt", metadata: ["delay": "0"])
-        let backoffIterator = self.configuration.connectionBackoff?.makeIterator()
+        let backoffIterator = self.connectionBackoff?.makeIterator()
         self.state = .transientFailure(TransientFailureState(
           from: ready,
           scheduled: scheduled,
@@ -747,7 +795,7 @@ extension ConnectionManager {
   private func startConnecting() {
     switch self.state {
     case .idle:
-      let iterator = self.configuration.connectionBackoff?.makeIterator()
+      let iterator = self.connectionBackoff?.makeIterator()
       self.startConnecting(
         backoffIterator: iterator,
         muxPromise: self.eventLoop.makePromise()
@@ -788,12 +836,17 @@ extension ConnectionManager {
     self.eventLoop.assertInEventLoop()
 
     let candidate: EventLoopFuture<Channel> = self.eventLoop.flatSubmit {
-      let channel = self.makeChannel(
-        connectTimeout: timeoutAndBackoff?.timeout
+      let channel: EventLoopFuture<Channel> = self.channelProvider.makeChannel(
+        managedBy: self,
+        onEventLoop: self.eventLoop,
+        connectTimeout: timeoutAndBackoff.map { .seconds(timeInterval: $0.timeout) },
+        logger: self.logger
       )
+
       channel.whenFailure { error in
         self.connectionFailed(withError: error)
       }
+
       return channel
     }
 
@@ -820,72 +873,3 @@ extension ConnectionManager {
     preconditionFailure("Invalid state \(self.state) for \(function)", file: file, line: line)
   }
 }
-
-extension ConnectionManager {
-  private func makeBootstrap(
-    connectTimeout: TimeInterval?
-  ) -> ClientBootstrapProtocol {
-    let serverHostname: String? = self.configuration.tls.flatMap { tls -> String? in
-      if let hostnameOverride = tls.hostnameOverride {
-        return hostnameOverride
-      } else {
-        return configuration.target.host
-      }
-    }.flatMap { hostname in
-      if hostname.isIPAddress {
-        return nil
-      } else {
-        return hostname
-      }
-    }
-
-    let bootstrap = PlatformSupport.makeClientBootstrap(group: self.eventLoop, logger: self.logger)
-      .channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
-      .channelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1)
-      .channelInitializer { channel in
-        let initialized = channel.configureGRPCClient(
-          httpTargetWindowSize: self.configuration.httpTargetWindowSize,
-          tlsConfiguration: self.configuration.tls?.configuration,
-          tlsServerHostname: serverHostname,
-          connectionManager: self,
-          connectionKeepalive: self.configuration.connectionKeepalive,
-          connectionIdleTimeout: self.configuration.connectionIdleTimeout,
-          errorDelegate: self.configuration.errorDelegate,
-          requiresZeroLengthWriteWorkaround: PlatformSupport.requiresZeroLengthWriteWorkaround(
-            group: self.eventLoop,
-            hasTLS: self.configuration.tls != nil
-          ),
-          logger: self.logger,
-          customVerificationCallback: self.configuration.tls?.customVerificationCallback
-        )
-
-        // Run the debug initializer, if there is one.
-        if let debugInitializer = self.configuration.debugChannelInitializer {
-          return initialized.flatMap {
-            debugInitializer(channel)
-          }
-        } else {
-          return initialized
-        }
-      }
-
-    if let connectTimeout = connectTimeout {
-      return bootstrap.connectTimeout(.seconds(timeInterval: connectTimeout))
-    } else {
-      return bootstrap
-    }
-  }
-
-  private func makeChannel(
-    connectTimeout: TimeInterval?
-  ) -> EventLoopFuture<Channel> {
-    if let provider = self.channelProvider {
-      return provider()
-    } else {
-      let bootstrap = self.makeBootstrap(
-        connectTimeout: connectTimeout
-      )
-      return bootstrap.connect(to: self.configuration.target)
-    }
-  }
-}

+ 102 - 0
Sources/GRPC/ConnectionManagerChannelProvider.swift

@@ -0,0 +1,102 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import Logging
+import NIO
+
+internal protocol ConnectionManagerChannelProvider {
+  /// Make an `EventLoopFuture<Channel>`.
+  ///
+  /// - Parameters:
+  ///   - connectionManager: The `ConnectionManager` requesting the `Channel`.
+  ///   - eventLoop: The `EventLoop` to use for the`Channel`.
+  ///   - connectTimeout: Optional connection timeout when starting the connection.
+  ///   - logger: A logger.
+  func makeChannel(
+    managedBy connectionManager: ConnectionManager,
+    onEventLoop eventLoop: EventLoop,
+    connectTimeout: TimeAmount?,
+    logger: Logger
+  ) -> EventLoopFuture<Channel>
+}
+
+extension ClientConnection {
+  internal struct ChannelProvider {
+    private var configuration: Configuration
+
+    internal init(configuration: Configuration) {
+      self.configuration = configuration
+    }
+  }
+}
+
+extension ClientConnection.ChannelProvider: ConnectionManagerChannelProvider {
+  internal func makeChannel(
+    managedBy connectionManager: ConnectionManager,
+    onEventLoop eventLoop: EventLoop,
+    connectTimeout: TimeAmount?,
+    logger: Logger
+  ) -> EventLoopFuture<Channel> {
+    let serverHostname: String? = self.configuration.tls.flatMap { tls -> String? in
+      if let hostnameOverride = tls.hostnameOverride {
+        return hostnameOverride
+      } else {
+        return self.configuration.target.host
+      }
+    }.flatMap { hostname in
+      if hostname.isIPAddress {
+        return nil
+      } else {
+        return hostname
+      }
+    }
+
+    let bootstrap = PlatformSupport.makeClientBootstrap(group: eventLoop, logger: logger)
+      .channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
+      .channelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1)
+      .channelInitializer { channel in
+        let initialized = channel.configureGRPCClient(
+          httpTargetWindowSize: self.configuration.httpTargetWindowSize,
+          tlsConfiguration: self.configuration.tls?.configuration,
+          tlsServerHostname: serverHostname,
+          connectionManager: connectionManager,
+          connectionKeepalive: self.configuration.connectionKeepalive,
+          connectionIdleTimeout: self.configuration.connectionIdleTimeout,
+          errorDelegate: self.configuration.errorDelegate,
+          requiresZeroLengthWriteWorkaround: PlatformSupport.requiresZeroLengthWriteWorkaround(
+            group: eventLoop,
+            hasTLS: self.configuration.tls != nil
+          ),
+          logger: logger,
+          customVerificationCallback: self.configuration.tls?.customVerificationCallback
+        )
+
+        // Run the debug initializer, if there is one.
+        if let debugInitializer = self.configuration.debugChannelInitializer {
+          return initialized.flatMap {
+            debugInitializer(channel)
+          }
+        } else {
+          return initialized
+        }
+      }
+
+    if let connectTimeout = connectTimeout {
+      _ = bootstrap.connectTimeout(connectTimeout)
+    }
+
+    return bootstrap.connect(to: self.configuration.target)
+  }
+}