Browse Source

Remove `onNext` from connectivity state delegate (#506)

* Remove `onNext` from connectivity state delegate

* Avoid incorrect state transitions
George Barnett 6 years ago
parent
commit
bfdc15ec6a

+ 11 - 8
Sources/GRPC/ClientConnection.swift

@@ -159,6 +159,11 @@ extension ClientConnection {
     connectivityMonitor: ConnectivityStateMonitor,
     backoffIterator: ConnectionBackoff.Iterator?
   ) -> EventLoopFuture<Channel> {
+    // We could have been shutdown by the user, avoid a connection attempt if this is the case.
+    guard connectivityMonitor.state != .shutdown else {
+      return configuration.eventLoopGroup.next().makeFailedFuture(GRPCStatus.processingError)
+    }
+
     connectivityMonitor.state = .connecting
     let timeoutAndBackoff = backoffIterator?.next()
     var bootstrap = ClientConnection.makeBootstrap(configuration: configuration)
@@ -174,14 +179,12 @@ extension ClientConnection {
       } else {
         return channel.eventLoop.makeSucceededFuture(channel)
       }
-    }.always { result in
-      switch result {
-      case .success:
-        // Update the state once the channel has been assigned, when it may be used for making
-        // RPCs.
-        break
-
-      case .failure:
+    }
+
+    channel.whenFailure { _ in
+      // We could have been shutdown by the user whilst we were connecting. If we were then avoid
+      // the this extra state transition.
+      if connectivityMonitor.state != .shutdown {
         // We might try again in a moment.
         connectivityMonitor.state = timeoutAndBackoff == nil ? .shutdown : .transientFailure
       }

+ 0 - 56
Sources/GRPC/ConnectivityState.swift

@@ -55,12 +55,6 @@ public protocol ConnectivityStateDelegate: class {
 public class ConnectivityStateMonitor {
   public typealias Callback = () -> Void
 
-  private var idleCallback: Callback?
-  private var connectingCallback: Callback?
-  private var readyCallback: Callback?
-  private var transientFailureCallback: Callback?
-  private var shutdownCallback: Callback?
-
   /// A delegate to call when the connectivity state changes.
   public var delegate: ConnectivityStateDelegate?
 
@@ -69,7 +63,6 @@ public class ConnectivityStateMonitor {
     didSet {
       if oldValue != self.state {
         self.delegate?.connectivityStateDidChange(from: oldValue, to: self.state)
-        self.triggerAndResetCallback()
       }
     }
   }
@@ -81,53 +74,4 @@ public class ConnectivityStateMonitor {
     self.delegate = delegate
     self.state = .idle
   }
-
-  /// Registers a callback on the given state and calls it the next time that state is observed.
-  /// Subsequent transitions to that state will **not** trigger the callback.
-  ///
-  /// - Parameter state: The state on which to call the given callback.
-  /// - Parameter callback: The closure to call once the given state has been transitioned to. The
-  ///     `callback` can be removed by passing in `nil`.
-  public func onNext(state: ConnectivityState, callback: Callback?) {
-    switch state {
-    case .idle:
-      self.idleCallback = callback
-
-    case .connecting:
-      self.connectingCallback = callback
-
-    case .ready:
-      self.readyCallback = callback
-
-    case .transientFailure:
-      self.transientFailureCallback = callback
-
-    case .shutdown:
-      self.shutdownCallback = callback
-    }
-  }
-
-  private func triggerAndResetCallback() {
-    switch self.state {
-    case .idle:
-      self.idleCallback?()
-      self.idleCallback = nil
-
-    case .connecting:
-      self.connectingCallback?()
-      self.connectingCallback = nil
-
-    case .ready:
-      self.readyCallback?()
-      self.readyCallback = nil
-
-    case .transientFailure:
-      self.transientFailureCallback?()
-      self.transientFailureCallback = nil
-
-    case .shutdown:
-      self.shutdownCallback?()
-      self.shutdownCallback = nil
-    }
-  }
 }

+ 52 - 28
Tests/GRPCTests/ClientConnectionBackoffTests.swift

@@ -28,8 +28,45 @@ class ConnectivityStateCollectionDelegate: ConnectivityStateDelegate {
     return self.states
   }
 
+  var idleExpectation: XCTestExpectation?
+  var connectingExpectation: XCTestExpectation?
+  var readyExpectation: XCTestExpectation?
+  var transientFailureExpectation: XCTestExpectation?
+  var shutdownExpectation: XCTestExpectation?
+
+  init(
+    idle: XCTestExpectation? = nil,
+    connecting: XCTestExpectation? = nil,
+    ready: XCTestExpectation? = nil,
+    transientFailure: XCTestExpectation? = nil,
+    shutdown: XCTestExpectation? = nil
+  ) {
+    self.idleExpectation = idle
+    self.connectingExpectation = connecting
+    self.readyExpectation = ready
+    self.transientFailureExpectation = transientFailure
+    self.shutdownExpectation = shutdown
+  }
+
   func connectivityStateDidChange(from oldState: ConnectivityState, to newState: ConnectivityState) {
     self.states.append(newState)
+
+    switch newState {
+    case .idle:
+      self.idleExpectation?.fulfill()
+
+    case .connecting:
+      self.connectingExpectation?.fulfill()
+
+    case .ready:
+      self.readyExpectation?.fulfill()
+
+    case .transientFailure:
+      self.transientFailureExpectation?.fulfill()
+
+    case .shutdown:
+      self.shutdownExpectation?.fulfill()
+    }
   }
 }
 
@@ -93,28 +130,21 @@ class ClientConnectionBackoffTests: XCTestCase {
     configuration.connectionBackoff = nil
 
     let connectionShutdown = self.expectation(description: "client shutdown")
+    self.stateDelegate.shutdownExpectation = connectionShutdown
     self.client = self.makeClientConnection(configuration)
-    self.client.connectivity.onNext(state: .shutdown) {
-      connectionShutdown.fulfill()
-    }
 
     self.wait(for: [connectionShutdown], timeout: 1.0)
     XCTAssertEqual(self.stateDelegate.states, [.connecting, .shutdown])
   }
 
   func testClientEventuallyConnects() throws {
-    // Start the client first.
-    self.client = self.makeClientConnection(self.makeClientConfiguration())
-
     let transientFailure = self.expectation(description: "connection transientFailure")
-    self.client.connectivity.onNext(state: .transientFailure) {
-      transientFailure.fulfill()
-    }
-
     let connectionReady = self.expectation(description: "connection ready")
-    self.client.connectivity.onNext(state: .ready) {
-      connectionReady.fulfill()
-    }
+    self.stateDelegate.transientFailureExpectation = transientFailure
+    self.stateDelegate.readyExpectation = connectionReady
+
+    // Start the client first.
+    self.client = self.makeClientConnection(self.makeClientConfiguration())
 
     self.wait(for: [transientFailure], timeout: 1.0)
 
@@ -128,10 +158,8 @@ class ClientConnectionBackoffTests: XCTestCase {
 
   func testClientEventuallyTimesOut() throws {
     let connectionShutdown = self.expectation(description: "connection shutdown")
+    self.stateDelegate.shutdownExpectation = connectionShutdown
     self.client = self.makeClientConnection(self.makeClientConfiguration())
-    self.client.connectivity.onNext(state: .shutdown) {
-      connectionShutdown.fulfill()
-    }
 
     self.wait(for: [connectionShutdown], timeout: 1.0)
     XCTAssertEqual(self.stateDelegate.states, [.connecting, .transientFailure, .connecting, .shutdown])
@@ -141,13 +169,15 @@ class ClientConnectionBackoffTests: XCTestCase {
     self.server = self.makeServer()
     let server = try self.server.wait()
 
-    let connectionReady = self.expectation(description: "connection ready")
     var configuration = self.makeClientConfiguration()
     configuration.connectionBackoff!.maximumBackoff = 2.0
+
+    let connectionReady = self.expectation(description: "connection ready")
+    let transientFailure = self.expectation(description: "connection transientFailure")
+    self.stateDelegate.readyExpectation = connectionReady
+    self.stateDelegate.transientFailureExpectation = transientFailure
+
     self.client = self.makeClientConnection(configuration)
-    self.client.connectivity.onNext(state: .ready) {
-      connectionReady.fulfill()
-    }
 
     // Once the connection is ready we can kill the server.
     self.wait(for: [connectionReady], timeout: 1.0)
@@ -158,18 +188,12 @@ class ClientConnectionBackoffTests: XCTestCase {
     self.server = nil
     self.serverGroup = nil
 
-    let transientFailure = self.expectation(description: "connection transientFailure")
-    self.client.connectivity.onNext(state: .transientFailure) {
-      transientFailure.fulfill()
-    }
-
     self.wait(for: [transientFailure], timeout: 1.0)
     XCTAssertEqual(self.stateDelegate.clearStates(), [.connecting, .transientFailure])
 
+    // Replace the ready expectation (since it's already been fulfilled).
     let reconnectionReady = self.expectation(description: "(re)connection ready")
-    self.client.connectivity.onNext(state: .ready) {
-      reconnectionReady.fulfill()
-    }
+    self.stateDelegate.readyExpectation = reconnectionReady
 
     let echo = Echo_EchoServiceClient(connection: self.client)
     // This should succeed once we get a connection again.

+ 12 - 12
Tests/GRPCTests/ClientTLSFailureTests.swift

@@ -121,10 +121,10 @@ class ClientTLSFailureTests: XCTestCase {
     let errorRecorder = ErrorRecordingDelegate(expectation: errorExpectation)
     configuration.errorDelegate = errorRecorder
 
-    let connection = ClientConnection(configuration: configuration)
-    connection.connectivity.onNext(state: .shutdown) {
-      shutdownExpectation.fulfill()
-    }
+    let stateChangeDelegate = ConnectivityStateCollectionDelegate(shutdown: shutdownExpectation)
+    configuration.connectivityStateDelegate = stateChangeDelegate
+
+    _ = ClientConnection(configuration: configuration)
 
     self.wait(for: [shutdownExpectation, errorExpectation], timeout: self.defaultTestTimeout)
 
@@ -143,10 +143,10 @@ class ClientTLSFailureTests: XCTestCase {
     let errorRecorder = ErrorRecordingDelegate(expectation: errorExpectation)
     configuration.errorDelegate = errorRecorder
 
-    let connection = ClientConnection(configuration: configuration)
-    connection.connectivity.onNext(state: .shutdown) {
-      shutdownExpectation.fulfill()
-    }
+    let stateChangeDelegate = ConnectivityStateCollectionDelegate(shutdown: shutdownExpectation)
+    configuration.connectivityStateDelegate = stateChangeDelegate
+
+    _ = ClientConnection(configuration: configuration)
 
     self.wait(for: [shutdownExpectation, errorExpectation], timeout: self.defaultTestTimeout)
 
@@ -170,10 +170,10 @@ class ClientTLSFailureTests: XCTestCase {
     let errorRecorder = ErrorRecordingDelegate(expectation: errorExpectation)
     configuration.errorDelegate = errorRecorder
 
-    let connection = ClientConnection(configuration: configuration)
-    connection.connectivity.onNext(state: .shutdown) {
-      shutdownExpectation.fulfill()
-    }
+    let stateChangeDelegate = ConnectivityStateCollectionDelegate(shutdown: shutdownExpectation)
+    configuration.connectivityStateDelegate = stateChangeDelegate
+
+    let _ = ClientConnection(configuration: configuration)
 
     self.wait(for: [shutdownExpectation, errorExpectation], timeout: self.defaultTestTimeout)
 

+ 1 - 57
Tests/GRPCTests/ConnectivityStateMonitorTests.swift

@@ -24,7 +24,7 @@ class ConnectivityStateMonitorTests: XCTestCase {
   let states: [ConnectivityState] = [.connecting, .ready, .transientFailure, .shutdown, .idle]
 
   func testDelegateOnlyCalledForChanges() {
-    let recorder = StateRecordingDelegate()
+    let recorder = ConnectivityStateCollectionDelegate()
     self.monitor.delegate = recorder
 
     self.monitor.state = .connecting
@@ -34,60 +34,4 @@ class ConnectivityStateMonitorTests: XCTestCase {
 
     XCTAssertEqual(recorder.states, [.connecting, .ready, .shutdown])
   }
-
-  func testOnNextIsOnlyInvokedOnce() {
-    for state in self.states {
-      let currentState = self.monitor.state
-
-      var calls = 0
-      self.monitor.onNext(state: state) {
-        calls += 1
-      }
-
-      // Trigger the callback.
-      self.monitor.state = state
-      XCTAssertEqual(calls, 1)
-
-      // Go back and forth; the callback should not be triggered again.
-      self.monitor.state = currentState
-      self.monitor.state = state
-      XCTAssertEqual(calls, 1)
-    }
-  }
-
-  func testRemovingCallbacks() {
-    for state in self.states {
-      self.monitor.onNext(state: state) {
-        XCTFail("Callback unexpectedly called")
-      }
-
-      self.monitor.onNext(state: state, callback: nil)
-      self.monitor.state = state
-    }
-  }
-
-  func testMultipleCallbacksRegistered() {
-    var calls = 0
-    self.states.forEach {
-      self.monitor.onNext(state: $0) {
-        calls += 1
-      }
-    }
-
-    self.states.forEach {
-      self.monitor.state = $0
-    }
-
-    XCTAssertEqual(calls, self.states.count)
-  }
-}
-
-extension ConnectivityStateMonitorTests {
-  /// A `ConnectivityStateDelegate` which each new state.
-  class StateRecordingDelegate: ConnectivityStateDelegate {
-    var states: [ConnectivityState] = []
-    func connectivityStateDidChange(from oldState: ConnectivityState, to newState: ConnectivityState) {
-      self.states.append(newState)
-    }
-  }
 }

+ 0 - 3
Tests/GRPCTests/XCTestManifests.swift

@@ -108,9 +108,6 @@ extension ConnectivityStateMonitorTests {
     // to regenerate.
     static let __allTests__ConnectivityStateMonitorTests = [
         ("testDelegateOnlyCalledForChanges", testDelegateOnlyCalledForChanges),
-        ("testMultipleCallbacksRegistered", testMultipleCallbacksRegistered),
-        ("testOnNextIsOnlyInvokedOnce", testOnNextIsOnlyInvokedOnce),
-        ("testRemovingCallbacks", testRemovingCallbacks),
     ]
 }