Avoid dropping continuations in grpc channel (#1942)

Motivation:

If the request queue was non-empty when close was called on the gRPC
channel, the queue was dropped along with any continuations it held. The
request queue can be non-empty whenever the active load balancer isn't in
the ready state (e.g. it's still connecting).

Moreover, if close is called while a subchannel is connecting, the
shutdown event can end up being fired twice.

Modifications:

- Fail any continuations left in the request queue when closing the gRPC
  channel
- Alter when the subchannel fires shutdown events. At the moment they're
  typically fired when close is called and, on some paths, again when the
  connection is closed. Shutdown events are now fired when entering the
  closing state or when transitioning directly to closed (i.e. from the
  idle or connected states).

Result:

Continuations aren't dropped when the channel is closed.
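
To make the first modification concrete, here is a minimal sketch of the pattern the diff below follows: the locked state machine hands any queued continuations back to the caller, which fails them outside the lock instead of silently dropping them. All names here (ChannelStateMachine, ClosedError, handleClose) are hypothetical stand-ins rather than grpc-swift API; the real change resumes the continuations with RPCError(code: .unavailable, message: "channel is closed").

```swift
// Hypothetical stand-in for the error the channel reports.
struct ClosedError: Error {
    let message: String
}

// Hypothetical, simplified model of the channel's state machine. close() returns
// an action describing what the caller must do after releasing the lock.
struct ChannelStateMachine {
    enum State {
        case running(queued: [CheckedContinuation<Void, any Error>])
        case closed
    }

    enum OnClose {
        case none
        case failQueued([CheckedContinuation<Void, any Error>])
    }

    private var state: State = .running(queued: [])

    // Park a waiter while the load balancer isn't ready (called under the lock).
    mutating func enqueue(_ continuation: CheckedContinuation<Void, any Error>) {
        switch state {
        case .running(var queued):
            queued.append(continuation)
            state = .running(queued: queued)
        case .closed:
            continuation.resume(throwing: ClosedError(message: "channel is closed"))
        }
    }

    // Close the channel (called under the lock). Instead of dropping the queue,
    // hand the continuations back to the caller so they can be failed.
    mutating func close() -> OnClose {
        switch state {
        case .running(let queued):
            state = .closed
            return .failQueued(queued)
        case .closed:
            return .none
        }
    }
}

// Caller side, mirroring the shape of handleClose(in:) in the diff below:
// perform the side effects outside the lock and yield shutdown exactly once.
func handleClose(_ stateMachine: inout ChannelStateMachine) {
    switch stateMachine.close() {
    case .failQueued(let continuations):
        for continuation in continuations {
            continuation.resume(throwing: ClosedError(message: "channel is closed"))
        }
        // ... yield the .shutdown connectivity event here ...
    case .none:
        ()
    }
}
```

Returning an action from the locked transition and performing the side effects (resuming continuations, yielding the shutdown event) outside the lock keeps the lock hold time short and makes it straightforward to guarantee the shutdown event is yielded only once.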
George Barnett 1 year ago
parent commit 909d1aafce

+ 7 - 3
Sources/GRPCHTTP2Core/Client/Connection/GRPCChannel.swift

@@ -355,10 +355,13 @@ extension GRPCChannel {
 extension GRPCChannel {
   private func handleClose(in group: inout DiscardingTaskGroup) {
     switch self.state.withLockedValue({ $0.close() }) {
-    case .close(let current, let next, let resolver):
+    case .close(let current, let next, let resolver, let continuations):
       resolver?.cancel()
       current.close()
       next?.close()
+      for continuation in continuations {
+        continuation.resume(throwing: RPCError(code: .unavailable, message: "channel is closed"))
+      }
       self._connectivityState.continuation.yield(.shutdown)
 
     case .cancelAll(let continuations):
@@ -924,7 +927,7 @@ extension GRPCChannel.StateMachine {
   enum OnClose {
     case none
     case cancelAll([RequestQueue.Continuation])
-    case close(LoadBalancer, LoadBalancer?, CancellableTaskHandle?)
+    case close(LoadBalancer, LoadBalancer?, CancellableTaskHandle?, [RequestQueue.Continuation])
   }
 
   mutating func close() -> OnClose {
@@ -936,7 +939,8 @@ extension GRPCChannel.StateMachine {
       onClose = .cancelAll(state.queue.removeAll())
 
     case .running(var state):
-      onClose = .close(state.current, state.next, state.nameResolverHandle)
+      let continuations = state.queue.removeAll()
+      onClose = .close(state.current, state.next, state.nameResolverHandle, continuations)
 
       state.past[state.current.id] = state.current
       if let next = state.next {

+ 117 - 61
Sources/GRPCHTTP2Core/Client/Connection/LoadBalancers/Subchannel.swift

@@ -218,16 +218,13 @@ extension Subchannel {
     case .none:
       ()
 
+    case .close(let connection):
+      connection.close()
+
     case .connect(let connection):
       // About to start connecting, emit a state change event.
       self.event.continuation.yield(.connectivityStateChanged(.connecting))
       self.runConnection(connection, in: &group)
-
-    case .shutdown:
-      self.event.continuation.yield(.connectivityStateChanged(.shutdown))
-      // Close the event streams.
-      self.event.continuation.finish()
-      self.input.continuation.finish()
     }
   }
 
@@ -236,10 +233,12 @@ extension Subchannel {
     case .none:
       ()
 
-    case .close(let connection):
+    case .emitShutdownAndClose(let connection):
+      // Connection closed because the load balancer asked it to, so notify the load balancer.
+      self.event.continuation.yield(.connectivityStateChanged(.shutdown))
       connection.close()
 
-    case .shutdown:
+    case .emitShutdownAndFinish:
       // Connection closed because the load balancer asked it to, so notify the load balancer.
       self.event.continuation.yield(.connectivityStateChanged(.shutdown))
       // At this point there are no more events: close the event streams.
@@ -266,11 +265,12 @@ extension Subchannel {
 
   private func handleConnectSucceededEvent() {
     switch self.state.withLockedValue({ $0.connectSucceeded() }) {
-    case .updateState:
+    case .updateStateToReady:
       // Emit a connectivity state change: the load balancer can now use this subchannel.
       self.event.continuation.yield(.connectivityStateChanged(.ready))
 
-    case .close(let connection):
+    case .closeAndEmitShutdown(let connection):
+      self.event.continuation.yield(.connectivityStateChanged(.shutdown))
       connection.close()
 
     case .none:
@@ -299,11 +299,9 @@ extension Subchannel {
         }
       }
 
-    case .shutdown:
+    case .closeAndEmitShutdownEvent(let connection):
       self.event.continuation.yield(.connectivityStateChanged(.shutdown))
-      // No more events, close the streams.
-      self.event.continuation.finish()
-      self.input.continuation.finish()
+      connection.close()
 
     case .none:
       ()
@@ -325,26 +323,24 @@ extension Subchannel {
     _ reason: Connection.CloseReason,
     in group: inout DiscardingTaskGroup
   ) {
-    let isClosed = self.state.withLockedValue { $0.closed(reason: reason) }
-    guard isClosed else { return }
+    switch self.state.withLockedValue({ $0.closed(reason: reason) }) {
+    case .nothing:
+      ()
 
-    switch reason {
-    case .idleTimeout, .remote, .error(_, wasIdle: true):
-      // Connection closed due to an idle timeout or the remote telling it to GOAWAY; notify the
-      // load balancer about this.
+    case .emitIdle:
       self.event.continuation.yield(.connectivityStateChanged(.idle))
 
-    case .keepaliveTimeout, .error(_, wasIdle: false):
+    case .emitTransientFailureAndReconnect:
       // Unclean closes trigger a transient failure state change and a name resolution.
       self.event.continuation.yield(.connectivityStateChanged(.transientFailure))
       self.event.continuation.yield(.requiresNameResolution)
-
       // Attempt to reconnect.
       self.handleConnectInput(in: &group)
 
-    case .initiatedLocally:
-      // Connection closed because the load balancer asked it to, so notify the load balancer.
-      self.event.continuation.yield(.connectivityStateChanged(.shutdown))
+    case .finish(let emitShutdown):
+      if emitShutdown {
+        self.event.continuation.yield(.connectivityStateChanged(.shutdown))
+      }
 
       // At this point there are no more events: close the event streams.
       self.event.continuation.finish()
@@ -384,6 +380,7 @@ extension Subchannel {
       let addresses: [SocketAddress]
       var addressIterator: Array<SocketAddress>.Iterator
       var backoff: ConnectionBackoff.Iterator
+      var shutdownRequested: Bool = false
     }
 
     struct Connected {
@@ -442,8 +439,8 @@ extension Subchannel {
 
     enum OnClose {
       case none
-      case shutdown
-      case close(Connection)
+      case emitShutdownAndFinish
+      case emitShutdownAndClose(Connection)
     }
 
     mutating func close() -> OnClose {
@@ -451,16 +448,17 @@ extension Subchannel {
 
       switch self {
       case .notConnected:
-        onClose = .shutdown
+        onClose = .emitShutdownAndFinish
 
-      case .connecting(let state):
-        self = .closing(Closing(from: state))
+      case .connecting(var state):
+        state.shutdownRequested = true
+        self = .connecting(state)
         // Do nothing; the connection hasn't been established yet so can't be closed.
         onClose = .none
 
       case .connected(let state):
         self = .closing(Closing(from: state))
-        onClose = .close(state.connection)
+        onClose = .emitShutdownAndClose(state.connection)
 
       case .closing, .closed:
         onClose = .none
@@ -470,19 +468,27 @@ extension Subchannel {
     }
 
     enum OnConnectSucceeded {
-      case updateState
-      case close(Connection)
+      case updateStateToReady
+      case closeAndEmitShutdown(Connection)
       case none
     }
 
     mutating func connectSucceeded() -> OnConnectSucceeded {
       switch self {
       case .connecting(let state):
-        self = .connected(Connected(from: state))
-        return .updateState
+        if state.shutdownRequested {
+          self = .closing(Closing(from: state))
+          return .closeAndEmitShutdown(state.connection)
+        } else {
+          self = .connected(Connected(from: state))
+          return .updateStateToReady
+        }
+
       case .closing(let state):
-        self = .closing(state)
-        return .close(state.connection)
+        // Shouldn't happen via the connecting state.
+        assertionFailure("Invalid state")
+        return .closeAndEmitShutdown(state.connection)
+
       case .notConnected, .connected, .closed:
         return .none
       }
@@ -491,58 +497,76 @@ extension Subchannel {
     enum OnConnectFailed {
       case none
       case connect(Connection)
+      case closeAndEmitShutdownEvent(Connection)
       case backoff(Duration)
-      case shutdown
     }
 
     mutating func connectFailed(connector: any HTTP2Connector) -> OnConnectFailed {
+      let onConnectFailed: OnConnectFailed
+
       switch self {
-      case .connecting(var connecting):
-        if let address = connecting.addressIterator.next() {
-          connecting.connection = Connection(
+      case .connecting(var state):
+        if state.shutdownRequested {
+          // Subchannel has been asked to shutdown, do so now.
+          self = .closing(Closing(from: state))
+          onConnectFailed = .closeAndEmitShutdownEvent(state.connection)
+        } else if let address = state.addressIterator.next() {
+          state.connection = Connection(
             address: address,
             http2Connector: connector,
             defaultCompression: .none,
             enabledCompression: .all
           )
-          self = .connecting(connecting)
-          return .connect(connecting.connection)
+          self = .connecting(state)
+          onConnectFailed = .connect(state.connection)
         } else {
-          connecting.addressIterator = connecting.addresses.makeIterator()
-          let address = connecting.addressIterator.next()!
-          connecting.connection = Connection(
+          state.addressIterator = state.addresses.makeIterator()
+          let address = state.addressIterator.next()!
+          state.connection = Connection(
             address: address,
             http2Connector: connector,
             defaultCompression: .none,
             enabledCompression: .all
           )
-          let backoff = connecting.backoff.next()
-          self = .connecting(connecting)
-          return .backoff(backoff)
+          let backoff = state.backoff.next()
+          self = .connecting(state)
+          onConnectFailed = .backoff(backoff)
         }
 
       case .closing:
-        self = .closed
-        return .shutdown
+        // Should be handled via connection.closeRequested
+        assertionFailure("Invalid state")
+        onConnectFailed = .none
 
       case .notConnected, .connected, .closed:
-        return .none
+        onConnectFailed = .none
       }
+
+      return onConnectFailed
     }
 
     enum OnBackedOff {
       case none
       case connect(Connection)
-      case shutdown
+      case close(Connection)
     }
 
     mutating func backedOff() -> OnBackedOff {
       switch self {
       case .connecting(let state):
-        return .connect(state.connection)
+        if state.shutdownRequested {
+          self = .closing(Closing(from: state))
+          return .close(state.connection)
+        } else {
+          self = .connecting(state)
+          return .connect(state.connection)
+        }
+
       case .closing:
-        self = .closed
-        return .shutdown
+        // Shouldn't happen via the connecting state.
+        assertionFailure("Invalid state")
+        return .none
+
       case .notConnected, .connected, .closed:
         return .none
       }
@@ -558,20 +582,52 @@ extension Subchannel {
       }
     }
 
-    mutating func closed(reason: Connection.CloseReason) -> Bool {
+    enum OnClosed {
+      case nothing
+      case emitIdle
+      case emitTransientFailureAndReconnect
+      case finish(emitShutdown: Bool)
+    }
+
+    mutating func closed(reason: Connection.CloseReason) -> OnClosed {
+      let onClosed: OnClosed
+
       switch self {
-      case .connected, .closing:
+      case .connected:
+        switch reason {
+        case .idleTimeout, .remote, .error(_, wasIdle: true):
+          self = .notConnected
+          onClosed = .emitIdle
+
+        case .keepaliveTimeout, .error(_, wasIdle: false):
+          self = .notConnected
+          onClosed = .emitTransientFailureAndReconnect
+
+        case .initiatedLocally:
+          self = .closed
+          onClosed = .finish(emitShutdown: true)
+        }
+
+      case .closing:
         switch reason {
-        case .idleTimeout, .keepaliveTimeout, .error, .remote:
+        case .idleTimeout, .remote, .error(_, wasIdle: true):
+          self = .notConnected
+          onClosed = .emitIdle
+
+        case .keepaliveTimeout, .error(_, wasIdle: false):
           self = .notConnected
+          onClosed = .emitTransientFailureAndReconnect
+
         case .initiatedLocally:
           self = .closed
+          onClosed = .finish(emitShutdown: false)
         }
 
-        return true
       case .notConnected, .connecting, .closed:
-        return false
+        onClosed = .nothing
       }
+
+      return onClosed
     }
   }
 }
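
The bulk of the Subchannel change above is the new shutdownRequested flag on the connecting state: calling close() while a connect attempt is in flight no longer transitions to closing (which previously allowed the shutdown event to fire twice); instead the attempt is marked, and whichever event resolves it (success, failure, or back-off) performs the shutdown. A simplified sketch of that deferral, using hypothetical type names rather than the real Subchannel state machine:

```swift
// Hypothetical, simplified model of the "close while connecting" handling.
// Instead of moving to the closing state straight away, the connecting state
// records that shutdown was requested and acts on it once the attempt resolves.
enum SubchannelState {
    case notConnected
    case connecting(shutdownRequested: Bool)
    case connected
    case closing
    case closed
}

enum SubchannelAction {
    case none
    case emitShutdownAndClose  // yield the shutdown event, then close the connection
    case emitReady             // yield the ready event
}

struct SubchannelModel {
    var state: SubchannelState = .notConnected

    // close() during an in-flight connect attempt only sets the flag; no event yet.
    mutating func close() -> SubchannelAction {
        switch state {
        case .connecting:
            state = .connecting(shutdownRequested: true)
            return .none
        case .connected:
            state = .closing
            return .emitShutdownAndClose
        case .notConnected, .closing, .closed:
            return .none
        }
    }

    // When the connect attempt succeeds, honour a deferred shutdown request
    // instead of advertising the subchannel as ready.
    mutating func connectSucceeded() -> SubchannelAction {
        switch state {
        case .connecting(let shutdownRequested):
            if shutdownRequested {
                state = .closing
                return .emitShutdownAndClose
            } else {
                state = .connected
                return .emitReady
            }
        case .notConnected, .connected, .closing, .closed:
            return .none
        }
    }
}
```

Because the shutdown event is only emitted on the transition into closing (or directly into closed), it can no longer be emitted once from close() and again when the connection later reports that it has closed.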

+ 52 - 0
Tests/GRPCHTTP2CoreTests/Client/Connection/GRPCChannelTests.swift

@@ -747,6 +747,58 @@ final class GRPCChannelTests: XCTestCase {
       group.cancelAll()
     }
   }
+
+  func testQueueRequestsThenClose() async throws {
+    let (resolver, continuation) = NameResolver.dynamic(updateMode: .push)
+    continuation.yield(.init(endpoints: [Endpoint()], serviceConfig: nil))
+
+    // Set a high backoff so the channel stays in transient failure for long enough.
+    var config = GRPCChannel.Config.defaults
+    config.backoff.initial = .seconds(120)
+
+    let channel = GRPCChannel(
+      resolver: .static(
+        endpoints: [
+          Endpoint(.unixDomainSocket(path: "/testQueueRequestsThenClose"))
+        ]
+      ),
+      connector: .posix(),
+      config: .defaults,
+      defaultServiceConfig: ServiceConfig()
+    )
+
+    try await withThrowingDiscardingTaskGroup { group in
+      group.addTask {
+        await channel.connect()
+      }
+
+      for try await state in channel.connectivityState {
+        switch state {
+        case .transientFailure:
+          group.addTask {
+            // Sleep a little to increase the chances of the stream being queued before the channel
+            // reacts to the close.
+            try await Task.sleep(for: .milliseconds(10))
+            channel.close()
+          }
+
+          // Try to open a new stream.
+          await XCTAssertThrowsErrorAsync(ofType: RPCError.self) {
+            try await channel.withStream(descriptor: .echoGet, options: .defaults) { stream in
+              XCTFail("Unexpected new stream")
+            }
+          } errorHandler: { error in
+            XCTAssertEqual(error.code, .unavailable)
+          }
+
+        default:
+          ()
+        }
+      }
+
+      group.cancelAll()
+    }
+  }
 }
 
 @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)