Prechádzať zdrojové kódy

Make Subchannel internal state handling clearer (#1949)

Motivation:

The subchannel conflates closing (recoverable) with shutting down
(terminal). Shutting down is initiated by a higher level (load-balancer,
user) while closing happens when the subchannel closes unexpectedly or
is no longer required (i.e. becomes idle). A closed subchannel can be
re-opened, a shutdown subchannel can't. This distinction isn't clear
enough in the state handling.

Modifications:

- Rename 'close' to 'shutDown' where applicable
- Add a new 'shutting down' state and rename the 'closing' state to
  'going away'.
- Add a state machine diagram
- Fix up a few state transitions

Result:

More robust state handling
George Barnett 1 rok pred
rodič
commit
b5ca79a915

+ 5 - 5
Sources/GRPCHTTP2Core/Client/Connection/LoadBalancers/PickFirstLoadBalancer.swift

@@ -185,7 +185,7 @@ extension PickFirstLoadBalancer {
     switch onUpdate {
     case .connect(let newSubchannel, close: let oldSubchannel):
       self.runSubchannel(newSubchannel, in: &group)
-      oldSubchannel?.close()
+      oldSubchannel?.shutDown()
 
     case .none:
       ()
@@ -226,9 +226,9 @@ extension PickFirstLoadBalancer {
 
     switch onUpdateState {
     case .close(let subchannel):
-      subchannel.close()
+      subchannel.shutDown()
     case .closeAndPublishStateChange(let subchannel, let connectivityState):
-      subchannel.close()
+      subchannel.shutDown()
       self.event.continuation.yield(.connectivityStateChanged(connectivityState))
     case .publishStateChange(let connectivityState):
       self.event.continuation.yield(.connectivityStateChanged(connectivityState))
@@ -251,8 +251,8 @@ extension PickFirstLoadBalancer {
     switch onClose {
     case .closeSubchannels(let subchannel1, let subchannel2):
       self.event.continuation.yield(.connectivityStateChanged(.shutdown))
-      subchannel1.close()
-      subchannel2?.close()
+      subchannel1.shutDown()
+      subchannel2?.shutDown()
 
     case .closed:
       self.event.continuation.yield(.connectivityStateChanged(.shutdown))

+ 5 - 5
Sources/GRPCHTTP2Core/Client/Connection/LoadBalancers/RoundRobinLoadBalancer.swift

@@ -245,7 +245,7 @@ extension RoundRobinLoadBalancer {
     // present if there are more to remove than to add. These are the excess subchannels which
     // are closed now.
     for subchannel in removed {
-      subchannel.close()
+      subchannel.shutDown()
     }
   }
 
@@ -288,10 +288,10 @@ extension RoundRobinLoadBalancer {
 
     case .closeAndPublishStateChange(let subchannel, let aggregateState):
       self.event.continuation.yield(.connectivityStateChanged(aggregateState))
-      subchannel.close()
+      subchannel.shutDown()
 
     case .close(let subchannel):
-      subchannel.close()
+      subchannel.shutDown()
 
     case .closed:
       // All subchannels are closed; finish the streams so the run loop exits.
@@ -306,7 +306,7 @@ extension RoundRobinLoadBalancer {
   private func handleSubchannelGoingAway(key: EndpointKey) {
     switch self.state.withLockedValue({ $0.parkSubchannel(withKey: key) }) {
     case .closeAndUpdateState(let subchannel, let connectivityState):
-      subchannel.close()
+      subchannel.shutDown()
       if let connectivityState = connectivityState {
         self.event.continuation.yield(.connectivityStateChanged(connectivityState))
       }
@@ -323,7 +323,7 @@ extension RoundRobinLoadBalancer {
 
       // Close the subchannels.
       for subchannel in subchannels {
-        subchannel.close()
+        subchannel.shutDown()
       }
 
     case .closed:

+ 143 - 96
Sources/GRPCHTTP2Core/Client/Connection/LoadBalancers/Subchannel.swift

@@ -23,7 +23,7 @@ import NIOConcurrencyHelpers
 /// endpoint. You can tell it to start connecting by calling ``connect()`` and you can listen
 /// to connectivity state changes by consuming the ``events`` sequence.
 ///
-/// You must call ``close()`` on the ``Subchannel`` when it's no longer required. This will move
+/// You must call ``shutDown()`` on the ``Subchannel`` when it's no longer required. This will move
 /// it to the ``ConnectivityState/shutdown`` state: existing RPCs may continue but all subsequent
 /// calls to ``makeStream(descriptor:options:)`` will fail.
 ///
@@ -60,8 +60,8 @@ struct Subchannel {
     case connect
     /// A backoff period has ended.
     case backedOff
-    /// Close the connection, if possible.
-    case close
+    /// Shuts down the connection, if possible.
+    case shutDown
     /// Handle the event from the underlying connection object.
     case handleConnectionEvent(Connection.Event)
   }
@@ -103,7 +103,7 @@ struct Subchannel {
   ) {
     assert(!endpoint.addresses.isEmpty, "endpoint.addresses mustn't be empty")
 
-    self.state = NIOLockedValueBox(.notConnected)
+    self.state = NIOLockedValueBox(.notConnected(.initial))
     self.endpoint = endpoint
     self.id = id
     self.connector = connector
@@ -140,8 +140,8 @@ extension Subchannel {
           self.handleConnectInput(in: &group)
         case .backedOff:
           self.handleBackedOffInput(in: &group)
-        case .close:
-          self.handleCloseInput(in: &group)
+        case .shutDown:
+          self.handleShutDownInput(in: &group)
         case .handleConnectionEvent(let event):
           self.handleConnectionEvent(event, in: &group)
         }
@@ -161,8 +161,8 @@ extension Subchannel {
   }
 
   /// Initiates graceful shutdown, if possible.
-  func close() {
-    self.input.continuation.yield(.close)
+  func shutDown() {
+    self.input.continuation.yield(.shutDown)
   }
 
   /// Make a stream using the subchannel if it's ready.
@@ -175,7 +175,7 @@ extension Subchannel {
   ) async throws -> Connection.Stream {
     let connection: Connection? = self.state.withLockedValue { state in
       switch state {
-      case .notConnected, .connecting, .closing, .closed:
+      case .notConnected, .connecting, .goingAway, .shuttingDown, .shutDown:
         return nil
       case .connected(let connected):
         return connected.connection
@@ -218,8 +218,9 @@ extension Subchannel {
     case .none:
       ()
 
-    case .close(let connection):
-      connection.close()
+    case .finish:
+      self.event.continuation.finish()
+      self.input.continuation.finish()
 
     case .connect(let connection):
       // About to start connecting, emit a state change event.
@@ -228,11 +229,15 @@ extension Subchannel {
     }
   }
 
-  private func handleCloseInput(in group: inout DiscardingTaskGroup) {
-    switch self.state.withLockedValue({ $0.close() }) {
+  private func handleShutDownInput(in group: inout DiscardingTaskGroup) {
+    switch self.state.withLockedValue({ $0.shutDown() }) {
     case .none:
       ()
 
+    case .emitShutdown:
+      // Connection closed because the load balancer asked it to, so notify the load balancer.
+      self.event.continuation.yield(.connectivityStateChanged(.shutdown))
+
     case .emitShutdownAndClose(let connection):
       // Connection closed because the load balancer asked it to, so notify the load balancer.
       self.event.continuation.yield(.connectivityStateChanged(.shutdown))
@@ -269,8 +274,10 @@ extension Subchannel {
       // Emit a connectivity state change: the load balancer can now use this subchannel.
       self.event.continuation.yield(.connectivityStateChanged(.ready))
 
-    case .closeAndEmitShutdown(let connection):
+    case .finishAndClose(let connection):
       self.event.continuation.yield(.connectivityStateChanged(.shutdown))
+      self.event.continuation.finish()
+      self.input.continuation.finish()
       connection.close()
 
     case .none:
@@ -299,9 +306,9 @@ extension Subchannel {
         }
       }
 
-    case .closeAndEmitShutdownEvent(let connection):
-      self.event.continuation.yield(.connectivityStateChanged(.shutdown))
-      connection.close()
+    case .finish:
+      self.event.continuation.finish()
+      self.input.continuation.finish()
 
     case .none:
       ()
@@ -363,24 +370,58 @@ extension Subchannel {
 
 @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
 extension Subchannel {
+  ///            ┌───────────────┐
+  ///   ┌───────▶│ NOT CONNECTED │───────────shutDown─────────────┐
+  ///   │        └───────────────┘                                │
+  ///   │                │                                        │
+  ///   │    connFailed──┤connect                                 │
+  ///   │    /backedOff  │                                        │
+  ///   │     │          ▼                                        │
+  ///   │     │  ┌───────────────┐                                │
+  ///   │     └──│  CONNECTING   │──────┐                         │
+  ///   │        └───────────────┘      │                         │
+  ///   │                │              │                         │
+  /// closed       connSucceeded        │                         │
+  ///   │                │              │                         │
+  ///   │                ▼              │                         │
+  ///   │        ┌───────────────┐      │      ┌───────────────┐  │
+  ///   │        │   CONNECTED   │──shutDown──▶│ SHUTTING DOWN │  │
+  ///   │        └───────────────┘      │      └───────────────┘  │
+  ///   │                │              │              │          │
+  ///   │             goAway            │            closed       │
+  ///   │                │              │              │          │
+  ///   │                ▼              │              ▼          │
+  ///   │        ┌───────────────┐      │      ┌───────────────┐  │
+  ///   └────────│  GOING AWAY   │──────┘      │   SHUT DOWN   │◀─┘
+  ///            └───────────────┘             └───────────────┘
   private enum State {
     /// Not connected and not actively connecting.
-    case notConnected
+    case notConnected(NotConnected)
     /// A connection attempt is in-progress.
     case connecting(Connecting)
     /// A connection has been established.
     case connected(Connected)
-    /// The subchannel is closing.
-    case closing(Closing)
-    /// The subchannel is closed.
-    case closed
+    /// The subchannel is going away. It may return to the 'notConnected' state when the underlying
+    /// connection has closed.
+    case goingAway(GoingAway)
+    /// The subchannel is shutting down, it will enter the 'shutDown' state when closed, it may not
+    /// enter any other state.
+    case shuttingDown(ShuttingDown)
+    /// The subchannel is shutdown, this is a terminal state.
+    case shutDown(ShutDown)
+
+    struct NotConnected {
+      private init() {}
+      static let initial = NotConnected()
+      init(from state: Connected) {}
+      init(from state: GoingAway) {}
+    }
 
     struct Connecting {
       var connection: Connection
       let addresses: [SocketAddress]
       var addressIterator: Array<SocketAddress>.Iterator
       var backoff: ConnectionBackoff.Iterator
-      var shutdownRequested: Bool = false
     }
 
     struct Connected {
@@ -391,7 +432,7 @@ extension Subchannel {
       }
     }
 
-    struct Closing {
+    struct GoingAway {
       var connection: Connection
 
       init(from state: Connecting) {
@@ -403,6 +444,28 @@ extension Subchannel {
       }
     }
 
+    struct ShuttingDown {
+      var connection: Connection
+
+      init(from state: Connecting) {
+        self.connection = state.connection
+      }
+
+      init(from state: Connected) {
+        self.connection = state.connection
+      }
+
+      init(from state: GoingAway) {
+        self.connection = state.connection
+      }
+    }
+
+    struct ShutDown {
+      init(from state: ShuttingDown) {}
+      init(from state: GoingAway) {}
+      init(from state: NotConnected) {}
+    }
+
     mutating func makeConnection(
       to addresses: [SocketAddress],
       using connector: any HTTP2Connector,
@@ -432,7 +495,7 @@ extension Subchannel {
         self = .connecting(connecting)
         return connection
 
-      case .connecting, .connected, .closing, .closed:
+      case .connecting, .connected, .goingAway, .shuttingDown, .shutDown:
         return nil
       }
     }
@@ -441,63 +504,62 @@ extension Subchannel {
       case none
       case emitShutdownAndFinish
       case emitShutdownAndClose(Connection)
+      case emitShutdown
     }
 
-    mutating func close() -> OnClose {
-      let onClose: OnClose
+    mutating func shutDown() -> OnClose {
+      let onShutDown: OnClose
 
       switch self {
-      case .notConnected:
-        onClose = .emitShutdownAndFinish
+      case .notConnected(let state):
+        self = .shutDown(ShutDown(from: state))
+        onShutDown = .emitShutdownAndFinish
 
-      case .connecting(var state):
-        state.shutdownRequested = true
-        self = .connecting(state)
-        // Do nothing; the connection hasn't been established yet so can't be closed.
-        onClose = .none
+      case .connecting(let state):
+        // Only emit the shutdown; there's no connection to close yet.
+        self = .shuttingDown(ShuttingDown(from: state))
+        onShutDown = .emitShutdown
 
       case .connected(let state):
-        self = .closing(Closing(from: state))
-        onClose = .emitShutdownAndClose(state.connection)
+        self = .shuttingDown(ShuttingDown(from: state))
+        onShutDown = .emitShutdownAndClose(state.connection)
 
-      case .closing, .closed:
-        onClose = .none
+      case .goingAway(let state):
+        self = .shuttingDown(ShuttingDown(from: state))
+        onShutDown = .emitShutdown
+
+      case .shuttingDown, .shutDown:
+        onShutDown = .none
       }
 
-      return onClose
+      return onShutDown
     }
 
     enum OnConnectSucceeded {
       case updateStateToReady
-      case closeAndEmitShutdown(Connection)
+      case finishAndClose(Connection)
       case none
     }
 
     mutating func connectSucceeded() -> OnConnectSucceeded {
       switch self {
       case .connecting(let state):
-        if state.shutdownRequested {
-          self = .closing(Closing(from: state))
-          return .closeAndEmitShutdown(state.connection)
-        } else {
-          self = .connected(Connected(from: state))
-          return .updateStateToReady
-        }
+        self = .connected(Connected(from: state))
+        return .updateStateToReady
 
-      case .closing(let state):
-        // Shouldn't happen via the connecting state.
-        assertionFailure("Invalid state")
-        return .closeAndEmitShutdown(state.connection)
+      case .shuttingDown(let state):
+        self = .shutDown(ShutDown(from: state))
+        return .finishAndClose(state.connection)
 
-      case .notConnected, .connected, .closed:
+      case .notConnected, .connected, .goingAway, .shutDown:
         return .none
       }
     }
 
     enum OnConnectFailed {
       case none
+      case finish
       case connect(Connection)
-      case closeAndEmitShutdownEvent(Connection)
       case backoff(Duration)
     }
 
@@ -506,11 +568,7 @@ extension Subchannel {
 
       switch self {
       case .connecting(var state):
-        if state.shutdownRequested {
-          // Subchannel has been asked to shutdown, do so now.
-          self = .closing(Closing(from: state))
-          onConnectFailed = .closeAndEmitShutdownEvent(state.connection)
-        } else if let address = state.addressIterator.next() {
+        if let address = state.addressIterator.next() {
           state.connection = Connection(
             address: address,
             http2Connector: connector,
@@ -533,12 +591,11 @@ extension Subchannel {
           onConnectFailed = .backoff(backoff)
         }
 
-      case .closing:
-        // Should be handled via connection.closeRequested
-        assertionFailure("Invalid state")
-        onConnectFailed = .none
+      case .shuttingDown(let state):
+        self = .shutDown(ShutDown(from: state))
+        onConnectFailed = .finish
 
-      case .notConnected, .connected, .closed:
+      case .notConnected, .connected, .goingAway, .shutDown:
         onConnectFailed = .none
       }
 
@@ -548,26 +605,20 @@ extension Subchannel {
     enum OnBackedOff {
       case none
       case connect(Connection)
-      case close(Connection)
+      case finish
     }
 
     mutating func backedOff() -> OnBackedOff {
       switch self {
       case .connecting(let state):
-        if state.shutdownRequested {
-          self = .closing(Closing(from: state))
-          return .close(state.connection)
-        } else {
-          self = .connecting(state)
-          return .connect(state.connection)
-        }
+        self = .connecting(state)
+        return .connect(state.connection)
 
-      case .closing:
-        // Shouldn't happen via the connecting state.
-        assertionFailure("Invalid state")
-        return .none
+      case .shuttingDown(let state):
+        self = .shutDown(ShutDown(from: state))
+        return .finish
 
-      case .notConnected, .connected, .closed:
+      case .notConnected, .connected, .goingAway, .shutDown:
         return .none
       }
     }
@@ -575,9 +626,9 @@ extension Subchannel {
     mutating func goingAway() -> Bool {
       switch self {
       case .connected(let state):
-        self = .closing(Closing(from: state))
+        self = .goingAway(GoingAway(from: state))
         return true
-      case .notConnected, .closing, .connecting, .closed:
+      case .notConnected, .goingAway, .connecting, .shuttingDown, .shutDown:
         return false
       }
     }
@@ -593,37 +644,33 @@ extension Subchannel {
       let onClosed: OnClosed
 
       switch self {
-      case .connected:
+      case .connected(let state):
         switch reason {
         case .idleTimeout, .remote, .error(_, wasIdle: true):
-          self = .notConnected
+          self = .notConnected(NotConnected(from: state))
           onClosed = .emitIdle
 
         case .keepaliveTimeout, .error(_, wasIdle: false):
-          self = .notConnected
+          self = .notConnected(NotConnected(from: state))
           onClosed = .emitTransientFailureAndReconnect
 
         case .initiatedLocally:
-          self = .closed
+          // Should be in the 'shuttingDown' state.
+          assertionFailure("Invalid state")
+          let shuttingDown = State.ShuttingDown(from: state)
+          self = .shutDown(ShutDown(from: shuttingDown))
           onClosed = .finish(emitShutdown: true)
         }
 
-      case .closing:
-        switch reason {
-        case .idleTimeout, .remote, .error(_, wasIdle: true):
-          self = .notConnected
-          onClosed = .emitIdle
-
-        case .keepaliveTimeout, .error(_, wasIdle: false):
-          self = .notConnected
-          onClosed = .emitTransientFailureAndReconnect
+      case .goingAway(let state):
+        self = .notConnected(NotConnected(from: state))
+        onClosed = .emitIdle
 
-        case .initiatedLocally:
-          self = .closed
-          onClosed = .finish(emitShutdown: false)
-        }
+      case .shuttingDown(let state):
+        self = .shutDown(ShutDown(from: state))
+        return .finish(emitShutdown: false)
 
-      case .notConnected, .connecting, .closed:
+      case .notConnected, .connecting, .shutDown:
         onClosed = .nothing
       }
 

+ 0 - 3
Tests/GRPCHTTP2CoreTests/Client/Connection/GRPCChannelTests.swift

@@ -749,9 +749,6 @@ final class GRPCChannelTests: XCTestCase {
   }
 
   func testQueueRequestsThenClose() async throws {
-    let (resolver, continuation) = NameResolver.dynamic(updateMode: .push)
-    continuation.yield(.init(endpoints: [Endpoint()], serviceConfig: nil))
-
     // Set a high backoff so the channel stays in transient failure for long enough.
     var config = GRPCChannel.Config.defaults
     config.backoff.initial = .seconds(120)

+ 11 - 11
Tests/GRPCHTTP2CoreTests/Client/Connection/LoadBalancers/SubchannelTests.swift

@@ -35,7 +35,7 @@ final class SubchannelTests: XCTestCase {
       XCTAssertEqual(error.code, .unavailable)
     }
 
-    subchannel.close()
+    subchannel.shutDown()
   }
 
   func testMakeStreamOnShutdownSubchannel() async throws {
@@ -48,7 +48,7 @@ final class SubchannelTests: XCTestCase {
       connector: .never
     )
 
-    subchannel.close()
+    subchannel.shutDown()
     await subchannel.run()
 
     await XCTAssertThrowsErrorAsync(ofType: RPCError.self) {
@@ -105,7 +105,7 @@ final class SubchannelTests: XCTestCase {
               }
             }
           }
-          subchannel.close()
+          subchannel.shutDown()
 
         default:
           ()
@@ -121,7 +121,7 @@ final class SubchannelTests: XCTestCase {
     let subchannel = self.makeSubchannel(
       address: .unixDomainSocket(path: path),
       connector: .posix(),
-      backoff: .fixed(at: .milliseconds(100))
+      backoff: .fixed(at: .milliseconds(10))
     )
 
     await withThrowingTaskGroup(of: Void.self) { group in
@@ -150,7 +150,7 @@ final class SubchannelTests: XCTestCase {
           }
 
         case .connectivityStateChanged(.ready):
-          subchannel.close()
+          subchannel.shutDown()
 
         case .connectivityStateChanged(.shutdown):
           group.cancelAll()
@@ -212,7 +212,7 @@ final class SubchannelTests: XCTestCase {
         case .connectivityStateChanged(.idle):
           subchannel.connect()
         case .connectivityStateChanged(.ready):
-          subchannel.close()
+          subchannel.shutDown()
         case .connectivityStateChanged(.shutdown):
           group.cancelAll()
         default:
@@ -263,7 +263,7 @@ final class SubchannelTests: XCTestCase {
           }
 
         case .connectivityStateChanged(.ready):
-          subchannel.close()
+          subchannel.shutDown()
 
         case .connectivityStateChanged(.shutdown):
           group.cancelAll()
@@ -304,7 +304,7 @@ final class SubchannelTests: XCTestCase {
           if idleCount == 1 {
             subchannel.connect()
           } else {
-            subchannel.close()
+            subchannel.shutDown()
           }
 
         case .connectivityStateChanged(.shutdown):
@@ -356,7 +356,7 @@ final class SubchannelTests: XCTestCase {
           case 1:
             subchannel.connect()
           case 2:
-            subchannel.close()
+            subchannel.shutDown()
           default:
             XCTFail("Unexpected idle")
           }
@@ -430,7 +430,7 @@ final class SubchannelTests: XCTestCase {
               let _ = try await iterator.next()
             }
           } else if readyCount == 2 {
-            subchannel.close()
+            subchannel.shutDown()
           }
 
         case .connectivityStateChanged(.shutdown):
@@ -484,7 +484,7 @@ final class SubchannelTests: XCTestCase {
           if idleCount == 1 {
             subchannel.connect()
           } else if idleCount == 2 {
-            subchannel.close()
+            subchannel.shutDown()
           }
 
         case .connectivityStateChanged(.ready):