Quellcode durchsuchen

Add the pick-first load balancer (#1907)

George Barnett vor 1 Jahr
Ursprung
Commit
28523856f7

+ 20 - 2
Sources/GRPCHTTP2Core/Client/Connection/GRPCChannel.swift

@@ -438,10 +438,17 @@ extension GRPCChannel {
     endpoints: [Endpoint],
     in group: inout DiscardingTaskGroup
   ) {
+    assert(!endpoints.isEmpty, "endpoints must be non-empty")
+
     switch update {
     case .runLoadBalancer(let new, let old):
       old?.close()
-      new.updateAddresses(endpoints)
+      switch new {
+      case .roundRobin(let loadBalancer):
+        loadBalancer.updateAddresses(endpoints)
+      case .pickFirst(let loadBalancer):
+        loadBalancer.updateEndpoint(endpoints.first!)
+      }
 
       group.addTask {
         await new.run()
@@ -454,7 +461,12 @@ extension GRPCChannel {
       }
 
     case .updateLoadBalancer(let existing):
-      existing.updateAddresses(endpoints)
+      switch existing {
+      case .roundRobin(let loadBalancer):
+        loadBalancer.updateAddresses(endpoints)
+      case .pickFirst(let loadBalancer):
+        loadBalancer.updateEndpoint(endpoints.first!)
+      }
 
     case .none:
       ()
@@ -610,11 +622,17 @@ extension GRPCChannel.StateMachine {
 
   enum LoadBalancerKind {
     case roundRobin
+    case pickFirst
 
     func matches(loadBalancer: LoadBalancer) -> Bool {
       switch (self, loadBalancer) {
       case (.roundRobin, .roundRobin):
         return true
+      case (.pickFirst, .pickFirst):
+        return true
+      case (.roundRobin, .pickFirst),
+        (.pickFirst, .roundRobin):
+        return false
       }
     }
   }

+ 11 - 7
Sources/GRPCHTTP2Core/Client/Connection/LoadBalancers/LoadBalancer.swift

@@ -17,6 +17,7 @@
 @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
 enum LoadBalancer: Sendable {
   case roundRobin(RoundRobinLoadBalancer)
+  case pickFirst(PickFirstLoadBalancer)
 }
 
 @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
@@ -29,6 +30,8 @@ extension LoadBalancer {
     switch self {
     case .roundRobin(let loadBalancer):
       return loadBalancer.id
+    case .pickFirst(let loadBalancer):
+      return loadBalancer.id
     }
   }
 
@@ -36,6 +39,8 @@ extension LoadBalancer {
     switch self {
     case .roundRobin(let loadBalancer):
       return loadBalancer.events
+    case .pickFirst(let loadBalancer):
+      return loadBalancer.events
     }
   }
 
@@ -43,13 +48,8 @@ extension LoadBalancer {
     switch self {
     case .roundRobin(let loadBalancer):
       await loadBalancer.run()
-    }
-  }
-
-  func updateAddresses(_ endpoints: [Endpoint]) {
-    switch self {
-    case .roundRobin(let loadBalancer):
-      loadBalancer.updateAddresses(endpoints)
+    case .pickFirst(let loadBalancer):
+      await loadBalancer.run()
     }
   }
 
@@ -57,6 +57,8 @@ extension LoadBalancer {
     switch self {
     case .roundRobin(let loadBalancer):
       loadBalancer.close()
+    case .pickFirst(let loadBalancer):
+      loadBalancer.close()
     }
   }
 
@@ -64,6 +66,8 @@ extension LoadBalancer {
     switch self {
     case .roundRobin(let loadBalancer):
       return loadBalancer.pickSubchannel()
+    case .pickFirst(let loadBalancer):
+      return loadBalancer.pickSubchannel()
     }
   }
 }

+ 609 - 0
Sources/GRPCHTTP2Core/Client/Connection/LoadBalancers/PickFirstLoadBalancer.swift

@@ -0,0 +1,609 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import GRPCCore
+
+/// A load-balancer which has a single subchannel.
+///
+/// This load-balancer starts in an 'idle' state and begins connecting when a set of addresses is
+/// provided to it with ``updateEndpoint(_:)``. Repeated calls to ``updateEndpoint(_:)`` will
+/// update the subchannel gracefully: RPCs will continue to use the old subchannel until the new
+/// subchannel becomes ready.
+///
+/// You must call ``close()`` on the load-balancer when it's no longer required. This will move
+/// it to the ``ConnectivityState/shutdown`` state: existing RPCs may continue but all subsequent
+/// calls to ``makeStream(descriptor:options:)`` will fail.
+///
+/// To use this load-balancer you must run it in a task:
+///
+/// ```swift
+/// await withDiscardingTaskGroup { group in
+///   // Run the load-balancer
+///   group.addTask { await pickFirst.run() }
+///
+///   // Update its endpoint.
+///   let endpoint = Endpoint(
+///     addresses: [
+///       .ipv4(host: "127.0.0.1", port: 1001),
+///       .ipv4(host: "127.0.0.1", port: 1002),
+///       .ipv4(host: "127.0.0.1", port: 1003)
+///     ]
+///   )
+///   pickFirst.updateEndpoint(endpoint)
+///
+///   // Consume state update events
+///   for await event in pickFirst.events {
+///     switch event {
+///     case .connectivityStateChanged(.ready):
+///       // ...
+///     default:
+///       // ...
+///     }
+///   }
+/// }
+/// ```
+@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
+struct PickFirstLoadBalancer {
+  enum Input: Sendable, Hashable {
+    /// Update the addresses used by the load balancer to the following endpoints.
+    case updateEndpoint(Endpoint)
+    /// Close the load balancer.
+    case close
+  }
+
+  /// Events which can happen to the load balancer.
+  private let event:
+    (
+      stream: AsyncStream<LoadBalancerEvent>,
+      continuation: AsyncStream<LoadBalancerEvent>.Continuation
+    )
+
+  /// Inputs which this load balancer should react to.
+  private let input: (stream: AsyncStream<Input>, continuation: AsyncStream<Input>.Continuation)
+
+  /// A connector, capable of creating connections.
+  private let connector: any HTTP2Connector
+
+  /// Connection backoff configuration.
+  private let backoff: ConnectionBackoff
+
+  /// The default compression algorithm to use. Can be overridden on a per-call basis.
+  private let defaultCompression: CompressionAlgorithm
+
+  /// The set of enabled compression algorithms.
+  private let enabledCompression: CompressionAlgorithmSet
+
+  /// The state of the load-balancer.
+  private let state: _LockedValueBox<State>
+
+  /// The ID of this load balancer.
+  internal let id: LoadBalancerID
+
+  init(
+    connector: any HTTP2Connector,
+    backoff: ConnectionBackoff,
+    defaultCompression: CompressionAlgorithm,
+    enabledCompression: CompressionAlgorithmSet
+  ) {
+    self.connector = connector
+    self.backoff = backoff
+    self.defaultCompression = defaultCompression
+    self.enabledCompression = enabledCompression
+    self.id = LoadBalancerID()
+    self.state = _LockedValueBox(State())
+
+    self.event = AsyncStream.makeStream(of: LoadBalancerEvent.self)
+    self.input = AsyncStream.makeStream(of: Input.self)
+    // The load balancer starts in the idle state.
+    self.event.continuation.yield(.connectivityStateChanged(.idle))
+  }
+
+  /// A stream of events which can happen to the load balancer.
+  var events: AsyncStream<LoadBalancerEvent> {
+    self.event.stream
+  }
+
+  /// Runs the load balancer, returning when it has closed.
+  ///
+  /// You can monitor events which happen on the load balancer with ``events``.
+  func run() async {
+    await withDiscardingTaskGroup { group in
+      for await input in self.input.stream {
+        switch input {
+        case .updateEndpoint(let endpoint):
+          self.handleUpdateEndpoint(endpoint, in: &group)
+        case .close:
+          self.handleCloseInput()
+        }
+      }
+    }
+
+    if Task.isCancelled {
+      // Finish the event stream as it's unlikely to have been finished by a regular code path.
+      self.event.continuation.finish()
+    }
+  }
+
+  /// Update the addresses used by the load balancer.
+  ///
+  /// This may result in new subchannels being created and some subchannels being removed.
+  func updateEndpoint(_ endpoint: Endpoint) {
+    self.input.continuation.yield(.updateEndpoint(endpoint))
+  }
+
+  /// Close the load balancer, and all subchannels it manages.
+  func close() {
+    self.input.continuation.yield(.close)
+  }
+
+  /// Pick a ready subchannel from the load balancer.
+  ///
+  /// - Returns: A subchannel, or `nil` if there aren't any ready subchannels.
+  func pickSubchannel() -> Subchannel? {
+    let onPickSubchannel = self.state.withLockedValue { $0.pickSubchannel() }
+    switch onPickSubchannel {
+    case .picked(let subchannel):
+      return subchannel
+    case .notAvailable(let subchannel):
+      subchannel?.connect()
+      return nil
+    }
+  }
+}
+
+@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
+extension PickFirstLoadBalancer {
+  private func handleUpdateEndpoint(_ endpoint: Endpoint, in group: inout DiscardingTaskGroup) {
+    if endpoint.addresses.isEmpty { return }
+
+    let onUpdate = self.state.withLockedValue { state in
+      state.updateEndpoint(endpoint) { endpoint, id in
+        Subchannel(
+          endpoint: endpoint,
+          id: id,
+          connector: self.connector,
+          backoff: self.backoff,
+          defaultCompression: self.defaultCompression,
+          enabledCompression: self.enabledCompression
+        )
+      }
+    }
+
+    switch onUpdate {
+    case .connect(let newSubchannel, close: let oldSubchannel):
+      self.runSubchannel(newSubchannel, in: &group)
+      oldSubchannel?.close()
+
+    case .none:
+      ()
+    }
+  }
+
+  private func runSubchannel(
+    _ subchannel: Subchannel,
+    in group: inout DiscardingTaskGroup
+  ) {
+    // Start running it and tell it to connect.
+    subchannel.connect()
+    group.addTask {
+      await subchannel.run()
+    }
+
+    group.addTask {
+      for await event in subchannel.events {
+        switch event {
+        case .connectivityStateChanged(let state):
+          self.handleSubchannelConnectivityStateChange(state, id: subchannel.id)
+        case .goingAway:
+          self.handleGoAway(id: subchannel.id)
+        case .requiresNameResolution:
+          self.event.continuation.yield(.requiresNameResolution)
+        }
+      }
+    }
+  }
+
+  private func handleSubchannelConnectivityStateChange(
+    _ connectivityState: ConnectivityState,
+    id: SubchannelID
+  ) {
+    let onUpdateState = self.state.withLockedValue {
+      $0.updateSubchannelConnectivityState(connectivityState, id: id)
+    }
+
+    switch onUpdateState {
+    case .close(let subchannel):
+      subchannel.close()
+    case .closeAndPublishStateChange(let subchannel, let connectivityState):
+      subchannel.close()
+      self.event.continuation.yield(.connectivityStateChanged(connectivityState))
+    case .publishStateChange(let connectivityState):
+      self.event.continuation.yield(.connectivityStateChanged(connectivityState))
+    case .closed:
+      self.event.continuation.finish()
+      self.input.continuation.finish()
+    case .none:
+      ()
+    }
+  }
+
+  private func handleGoAway(id: SubchannelID) {
+    self.state.withLockedValue { state in
+      state.receivedGoAway(id: id)
+    }
+  }
+
+  private func handleCloseInput() {
+    let onClose = self.state.withLockedValue { $0.close() }
+    switch onClose {
+    case .closeSubchannels(let subchannel1, let subchannel2):
+      self.event.continuation.yield(.connectivityStateChanged(.shutdown))
+      subchannel1.close()
+      subchannel2?.close()
+
+    case .closed:
+      self.event.continuation.yield(.connectivityStateChanged(.shutdown))
+      self.event.continuation.finish()
+      self.input.continuation.finish()
+
+    case .none:
+      ()
+    }
+  }
+}
+
+@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
+extension PickFirstLoadBalancer {
+  enum State: Sendable {
+    case active(Active)
+    case closing(Closing)
+    case closed
+
+    init() {
+      self = .active(Active())
+    }
+  }
+}
+
+@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
+extension PickFirstLoadBalancer.State {
+  struct Active: Sendable {
+    var endpoint: Endpoint?
+    var connectivityState: ConnectivityState
+    var current: Subchannel?
+    var next: Subchannel?
+    var parked: [SubchannelID: Subchannel]
+    var isCurrentGoingAway: Bool
+
+    init() {
+      self.endpoint = nil
+      self.connectivityState = .idle
+      self.current = nil
+      self.next = nil
+      self.parked = [:]
+      self.isCurrentGoingAway = false
+    }
+  }
+
+  struct Closing: Sendable {
+    var parked: [SubchannelID: Subchannel]
+
+    init(from state: Active) {
+      self.parked = state.parked
+    }
+  }
+}
+
+@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
+extension PickFirstLoadBalancer.State.Active {
+  mutating func updateEndpoint(
+    _ endpoint: Endpoint,
+    makeSubchannel: (_ endpoint: Endpoint, _ id: SubchannelID) -> Subchannel
+  ) -> PickFirstLoadBalancer.State.OnUpdateEndpoint {
+    if self.endpoint == endpoint { return .none }
+
+    let onUpdateEndpoint: PickFirstLoadBalancer.State.OnUpdateEndpoint
+
+    let id = SubchannelID()
+    let newSubchannel = makeSubchannel(endpoint, id)
+
+    switch (self.current, self.next) {
+    case (.some(let current), .none):
+      if self.connectivityState == .idle {
+        // Current subchannel is idle and we have a new endpoint, move straight to the new
+        // subchannel.
+        self.current = newSubchannel
+        self.parked[current.id] = current
+        onUpdateEndpoint = .connect(newSubchannel, close: current)
+      } else {
+        // Current subchannel is in a non-idle state, set it as the next subchannel and promote
+        // it when it becomes ready.
+        self.next = newSubchannel
+        onUpdateEndpoint = .connect(newSubchannel, close: nil)
+      }
+
+    case (.some, .some(let next)):
+      // Current and next subchannel exist. Replace the next subchannel.
+      self.next = newSubchannel
+      self.parked[next.id] = next
+      onUpdateEndpoint = .connect(newSubchannel, close: next)
+
+    case (.none, .none):
+      self.current = newSubchannel
+      onUpdateEndpoint = .connect(newSubchannel, close: nil)
+
+    case (.none, .some(let next)):
+      self.current = newSubchannel
+      self.next = nil
+      self.parked[next.id] = next
+      onUpdateEndpoint = .connect(newSubchannel, close: next)
+    }
+
+    return onUpdateEndpoint
+  }
+
+  mutating func updateSubchannelConnectivityState(
+    _ connectivityState: ConnectivityState,
+    id: SubchannelID
+  ) -> (PickFirstLoadBalancer.State.OnConnectivityStateUpdate, PickFirstLoadBalancer.State) {
+    let onUpdate: PickFirstLoadBalancer.State.OnConnectivityStateUpdate
+
+    if let current = self.current, current.id == id {
+      if connectivityState == self.connectivityState {
+        onUpdate = .none
+      } else {
+        self.connectivityState = connectivityState
+        onUpdate = .publishStateChange(connectivityState)
+      }
+    } else if let next = self.next, next.id == id {
+      // if it becomes ready then promote it
+      switch connectivityState {
+      case .ready:
+        if self.connectivityState != connectivityState {
+          self.connectivityState = connectivityState
+
+          if let current = self.current {
+            onUpdate = .closeAndPublishStateChange(current, connectivityState)
+          } else {
+            onUpdate = .publishStateChange(connectivityState)
+          }
+
+          self.current = next
+          self.isCurrentGoingAway = false
+        } else {
+          // No state change to publish, just roll over.
+          onUpdate = self.current.map { .close($0) } ?? .none
+          self.current = next
+          self.isCurrentGoingAway = false
+        }
+
+      case .idle, .connecting, .transientFailure, .shutdown:
+        onUpdate = .none
+      }
+
+    } else {
+      switch connectivityState {
+      case .idle:
+        if let subchannel = self.parked[id] {
+          onUpdate = .close(subchannel)
+        } else {
+          onUpdate = .none
+        }
+
+      case .shutdown:
+        self.parked.removeValue(forKey: id)
+        onUpdate = .none
+
+      case .connecting, .ready, .transientFailure:
+        onUpdate = .none
+      }
+    }
+
+    return (onUpdate, .active(self))
+  }
+
+  mutating func receivedGoAway(id: SubchannelID) {
+    if let current = self.current, current.id == id {
+      // When receiving a GOAWAY the subchannel will ask for an address to be re-resolved and the
+      // connection will eventually become idle. At this point we wait: the connection remains
+      // in its current state.
+      self.isCurrentGoingAway = true
+    } else if let next = self.next, next.id == id {
+      // The next connection is going away, park it.
+      // connection.
+      self.next = nil
+      self.parked[next.id] = next
+    }
+  }
+
+  mutating func close() -> (PickFirstLoadBalancer.State.OnClose, PickFirstLoadBalancer.State) {
+    let onClose: PickFirstLoadBalancer.State.OnClose
+    let nextState: PickFirstLoadBalancer.State
+
+    if let current = self.current {
+      self.parked[current.id] = current
+      if let next = self.next {
+        self.parked[next.id] = next
+        onClose = .closeSubchannels(current, next)
+      } else {
+        onClose = .closeSubchannels(current, nil)
+      }
+      nextState = .closing(PickFirstLoadBalancer.State.Closing(from: self))
+    } else {
+      onClose = .closed
+      nextState = .closed
+    }
+
+    return (onClose, nextState)
+  }
+
+  func pickSubchannel() -> PickFirstLoadBalancer.State.OnPickSubchannel {
+    let onPick: PickFirstLoadBalancer.State.OnPickSubchannel
+
+    if let current = self.current, !self.isCurrentGoingAway {
+      switch self.connectivityState {
+      case .idle:
+        onPick = .notAvailable(current)
+      case .ready:
+        onPick = .picked(current)
+      case .connecting, .transientFailure, .shutdown:
+        onPick = .notAvailable(nil)
+      }
+    } else {
+      onPick = .notAvailable(nil)
+    }
+
+    return onPick
+  }
+}
+
+@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
+extension PickFirstLoadBalancer.State.Closing {
+  mutating func updateSubchannelConnectivityState(
+    _ connectivityState: ConnectivityState,
+    id: SubchannelID
+  ) -> (PickFirstLoadBalancer.State.OnConnectivityStateUpdate, PickFirstLoadBalancer.State) {
+    let onUpdate: PickFirstLoadBalancer.State.OnConnectivityStateUpdate
+    let nextState: PickFirstLoadBalancer.State
+
+    switch connectivityState {
+    case .idle:
+      if let subchannel = self.parked[id] {
+        onUpdate = .close(subchannel)
+      } else {
+        onUpdate = .none
+      }
+      nextState = .closing(self)
+
+    case .shutdown:
+      if self.parked.removeValue(forKey: id) != nil {
+        if self.parked.isEmpty {
+          onUpdate = .closed
+          nextState = .closed
+        } else {
+          onUpdate = .none
+          nextState = .closing(self)
+        }
+      } else {
+        onUpdate = .none
+        nextState = .closing(self)
+      }
+
+    case .connecting, .ready, .transientFailure:
+      onUpdate = .none
+      nextState = .closing(self)
+    }
+
+    return (onUpdate, nextState)
+  }
+}
+
+@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
+extension PickFirstLoadBalancer.State {
+  enum OnUpdateEndpoint {
+    case connect(Subchannel, close: Subchannel?)
+    case none
+  }
+
+  mutating func updateEndpoint(
+    _ endpoint: Endpoint,
+    makeSubchannel: (_ endpoint: Endpoint, _ id: SubchannelID) -> Subchannel
+  ) -> OnUpdateEndpoint {
+    let onUpdateEndpoint: OnUpdateEndpoint
+
+    switch self {
+    case .active(var state):
+      onUpdateEndpoint = state.updateEndpoint(endpoint) { endpoint, id in
+        makeSubchannel(endpoint, id)
+      }
+      self = .active(state)
+
+    case .closing, .closed:
+      onUpdateEndpoint = .none
+    }
+
+    return onUpdateEndpoint
+  }
+
+  enum OnConnectivityStateUpdate {
+    case closeAndPublishStateChange(Subchannel, ConnectivityState)
+    case publishStateChange(ConnectivityState)
+    case close(Subchannel)
+    case closed
+    case none
+  }
+
+  mutating func updateSubchannelConnectivityState(
+    _ connectivityState: ConnectivityState,
+    id: SubchannelID
+  ) -> OnConnectivityStateUpdate {
+    let onUpdateState: OnConnectivityStateUpdate
+
+    switch self {
+    case .active(var state):
+      (onUpdateState, self) = state.updateSubchannelConnectivityState(connectivityState, id: id)
+    case .closing(var state):
+      (onUpdateState, self) = state.updateSubchannelConnectivityState(connectivityState, id: id)
+    case .closed:
+      onUpdateState = .none
+    }
+
+    return onUpdateState
+  }
+
+  mutating func receivedGoAway(id: SubchannelID) {
+    switch self {
+    case .active(var state):
+      state.receivedGoAway(id: id)
+      self = .active(state)
+    case .closing, .closed:
+      ()
+    }
+  }
+
+  enum OnClose {
+    case closeSubchannels(Subchannel, Subchannel?)
+    case closed
+    case none
+  }
+
+  mutating func close() -> OnClose {
+    let onClose: OnClose
+
+    switch self {
+    case .active(var state):
+      (onClose, self) = state.close()
+    case .closing, .closed:
+      onClose = .none
+    }
+
+    return onClose
+  }
+
+  enum OnPickSubchannel {
+    case picked(Subchannel)
+    case notAvailable(Subchannel?)
+  }
+
+  func pickSubchannel() -> OnPickSubchannel {
+    switch self {
+    case .active(let state):
+      return state.pickSubchannel()
+    case .closing, .closed:
+      return .notAvailable(nil)
+    }
+  }
+}

+ 37 - 0
Tests/GRPCHTTP2CoreTests/Client/Connection/LoadBalancers/LoadBalancerTest.swift

@@ -24,6 +24,32 @@ enum LoadBalancerTest {
     let loadBalancer: LoadBalancer
   }
 
+  static func pickFirst(
+    servers serverCount: Int,
+    connector: any HTTP2Connector,
+    backoff: ConnectionBackoff = .defaults,
+    timeout: Duration = .seconds(10),
+    function: String = #function,
+    handleEvent: @escaping @Sendable (Context, LoadBalancerEvent) async throws -> Void,
+    verifyEvents: @escaping @Sendable ([LoadBalancerEvent]) -> Void = { _ in }
+  ) async throws {
+    try await Self.run(
+      servers: serverCount,
+      timeout: timeout,
+      function: function,
+      handleEvent: handleEvent,
+      verifyEvents: verifyEvents
+    ) {
+      let pickFirst = PickFirstLoadBalancer(
+        connector: connector,
+        backoff: backoff,
+        defaultCompression: .none,
+        enabledCompression: .none
+      )
+      return .pickFirst(pickFirst)
+    }
+  }
+
   static func roundRobin(
     servers serverCount: Int,
     connector: any HTTP2Connector,
@@ -143,6 +169,17 @@ extension LoadBalancerTest.Context {
     switch self.loadBalancer {
     case .roundRobin(let loadBalancer):
       return loadBalancer
+    case .pickFirst:
+      return nil
+    }
+  }
+
+  var pickFirst: PickFirstLoadBalancer? {
+    switch self.loadBalancer {
+    case .roundRobin:
+      return nil
+    case .pickFirst(let loadBalancer):
+      return loadBalancer
     }
   }
 }

+ 333 - 0
Tests/GRPCHTTP2CoreTests/Client/Connection/LoadBalancers/PickFirstLoadBalancerTests.swift

@@ -0,0 +1,333 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import Atomics
+import GRPCCore
+@_spi(Package) @testable import GRPCHTTP2Core
+import NIOHTTP2
+import NIOPosix
+import XCTest
+
+@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
+final class PickFirstLoadBalancerTests: XCTestCase {
+  func testPickFirstConnectsToServer() async throws {
+    try await LoadBalancerTest.pickFirst(servers: 1, connector: .posix()) { context, event in
+      switch event {
+      case .connectivityStateChanged(.idle):
+        let endpoint = Endpoint(addresses: context.servers.map { $0.address })
+        context.pickFirst!.updateEndpoint(endpoint)
+      case .connectivityStateChanged(.ready):
+        context.loadBalancer.close()
+      default:
+        ()
+      }
+    } verifyEvents: { events in
+      let expected: [LoadBalancerEvent] = [
+        .connectivityStateChanged(.idle),
+        .connectivityStateChanged(.connecting),
+        .connectivityStateChanged(.ready),
+        .connectivityStateChanged(.shutdown),
+      ]
+      XCTAssertEqual(events, expected)
+    }
+  }
+
+  func testPickSubchannelWhenNotReady() async throws {
+    try await LoadBalancerTest.pickFirst(servers: 1, connector: .posix()) { context, event in
+      switch event {
+      case .connectivityStateChanged(.idle):
+        XCTAssertNil(context.loadBalancer.pickSubchannel())
+        context.loadBalancer.close()
+      case .connectivityStateChanged(.shutdown):
+        XCTAssertNil(context.loadBalancer.pickSubchannel())
+      default:
+        ()
+      }
+    } verifyEvents: { events in
+      let expected: [LoadBalancerEvent] = [
+        .connectivityStateChanged(.idle),
+        .connectivityStateChanged(.shutdown),
+      ]
+      XCTAssertEqual(events, expected)
+    }
+  }
+
+  func testPickSubchannelReturnsSameSubchannel() async throws {
+    try await LoadBalancerTest.pickFirst(servers: 1, connector: .posix()) { context, event in
+      switch event {
+      case .connectivityStateChanged(.idle):
+        let endpoint = Endpoint(addresses: context.servers.map { $0.address })
+        context.pickFirst!.updateEndpoint(endpoint)
+
+      case .connectivityStateChanged(.ready):
+        var ids = Set<SubchannelID>()
+        for _ in 0 ..< 100 {
+          let subchannel = try XCTUnwrap(context.loadBalancer.pickSubchannel())
+          ids.insert(subchannel.id)
+        }
+        XCTAssertEqual(ids.count, 1)
+        context.loadBalancer.close()
+
+      default:
+        ()
+      }
+    } verifyEvents: { events in
+      let expected: [LoadBalancerEvent] = [
+        .connectivityStateChanged(.idle),
+        .connectivityStateChanged(.connecting),
+        .connectivityStateChanged(.ready),
+        .connectivityStateChanged(.shutdown),
+      ]
+      XCTAssertEqual(events, expected)
+    }
+  }
+
+  func testEndpointUpdateHandledGracefully() async throws {
+    try await LoadBalancerTest.pickFirst(servers: 2, connector: .posix()) { context, event in
+      switch event {
+      case .connectivityStateChanged(.idle):
+        let endpoint = Endpoint(addresses: [context.servers[0].address])
+        context.pickFirst!.updateEndpoint(endpoint)
+
+      case .connectivityStateChanged(.ready):
+        // Must be connected to server-0.
+        try await XCTPoll(every: .milliseconds(10)) {
+          context.servers[0].server.clients.count == 1
+        }
+
+        // Update the endpoint so that it contains server-1.
+        let endpoint = Endpoint(addresses: [context.servers[1].address])
+        context.pickFirst!.updateEndpoint(endpoint)
+
+        // Should remain in the ready state
+        try await XCTPoll(every: .milliseconds(10)) {
+          context.servers[0].server.clients.isEmpty && context.servers[1].server.clients.count == 1
+        }
+
+        context.loadBalancer.close()
+
+      default:
+        ()
+      }
+    } verifyEvents: { events in
+      let expected: [LoadBalancerEvent] = [
+        .connectivityStateChanged(.idle),
+        .connectivityStateChanged(.connecting),
+        .connectivityStateChanged(.ready),
+        .connectivityStateChanged(.shutdown),
+      ]
+      XCTAssertEqual(events, expected)
+    }
+  }
+
+  func testSameEndpointUpdateIsIgnored() async throws {
+    try await LoadBalancerTest.pickFirst(servers: 1, connector: .posix()) { context, event in
+      switch event {
+      case .connectivityStateChanged(.idle):
+        let endpoint = Endpoint(addresses: context.servers.map { $0.address })
+        context.pickFirst!.updateEndpoint(endpoint)
+
+      case .connectivityStateChanged(.ready):
+        // Must be connected to server-0.
+        try await XCTPoll(every: .milliseconds(10)) {
+          context.servers[0].server.clients.count == 1
+        }
+
+        // Update the endpoint. This should be a no-op, server should remain connected.
+        let endpoint = Endpoint(addresses: context.servers.map { $0.address })
+        context.pickFirst!.updateEndpoint(endpoint)
+        try await XCTPoll(every: .milliseconds(10)) {
+          context.servers[0].server.clients.count == 1
+        }
+
+        context.loadBalancer.close()
+
+      default:
+        ()
+      }
+    } verifyEvents: { events in
+      let expected: [LoadBalancerEvent] = [
+        .connectivityStateChanged(.idle),
+        .connectivityStateChanged(.connecting),
+        .connectivityStateChanged(.ready),
+        .connectivityStateChanged(.shutdown),
+      ]
+      XCTAssertEqual(events, expected)
+    }
+  }
+
+  func testEmptyEndpointUpdateIsIgnored() async throws {
+    // Checks that an update using the empty endpoint is ignored.
+    try await LoadBalancerTest.pickFirst(servers: 0, connector: .posix()) { context, event in
+      switch event {
+      case .connectivityStateChanged(.idle):
+        let endpoint = Endpoint(addresses: [])
+        // Should no-op.
+        context.pickFirst!.updateEndpoint(endpoint)
+        context.loadBalancer.close()
+
+      default:
+        ()
+      }
+    } verifyEvents: { events in
+      let expected: [LoadBalancerEvent] = [
+        .connectivityStateChanged(.idle),
+        .connectivityStateChanged(.shutdown),
+      ]
+      XCTAssertEqual(events, expected)
+    }
+  }
+
+  func testPickOnIdleTriggersConnect() async throws {
+    // Tests that picking a subchannel when the load balancer is idle triggers a reconnect and
+    // becomes ready again. Uses a very short idle time to re-enter the idle state.
+    let idle = ManagedAtomic(0)
+
+    try await LoadBalancerTest.pickFirst(
+      servers: 1,
+      connector: .posix(maxIdleTime: .milliseconds(1))  // Aggressively idle the connection
+    ) { context, event in
+      switch event {
+      case .connectivityStateChanged(.idle):
+        let idleCount = idle.wrappingIncrementThenLoad(ordering: .sequentiallyConsistent)
+
+        switch idleCount {
+        case 1:
+          // The first idle happens when the load balancer in started, give it an endpoint
+          // which it will connect to. Wait for it to be ready and then idle again.
+          let endpoint = Endpoint(addresses: context.servers.map { $0.address })
+          context.pickFirst!.updateEndpoint(endpoint)
+        case 2:
+          // Load-balancer has the endpoints but all are idle. Picking will trigger a connect.
+          XCTAssertNil(context.loadBalancer.pickSubchannel())
+        case 3:
+          // Connection idled again. Shut it down.
+          context.loadBalancer.close()
+
+        default:
+          XCTFail("Became idle too many times")
+        }
+
+      default:
+        ()
+      }
+    } verifyEvents: { events in
+      let expected: [LoadBalancerEvent] = [
+        .connectivityStateChanged(.idle),
+        .connectivityStateChanged(.connecting),
+        .connectivityStateChanged(.ready),
+        .connectivityStateChanged(.idle),
+        .connectivityStateChanged(.connecting),
+        .connectivityStateChanged(.ready),
+        .connectivityStateChanged(.idle),
+        .connectivityStateChanged(.shutdown),
+      ]
+      XCTAssertEqual(events, expected)
+    }
+  }
+
+  func testPickFirstConnectionDropReturnsToIdle() async throws {
+    // Checks that when the load balancers connection is unexpectedly dropped when there are no
+    // open streams that it returns to the idle state.
+    let idleCount = ManagedAtomic(0)
+
+    try await LoadBalancerTest.pickFirst(servers: 1, connector: .posix()) { context, event in
+      switch event {
+      case .connectivityStateChanged(.idle):
+        switch idleCount.wrappingIncrementThenLoad(ordering: .sequentiallyConsistent) {
+        case 1:
+          let endpoint = Endpoint(addresses: context.servers.map { $0.address })
+          context.pickFirst!.updateEndpoint(endpoint)
+        case 2:
+          context.loadBalancer.close()
+        default:
+          ()
+        }
+
+      case .connectivityStateChanged(.ready):
+        // Drop the connection.
+        context.servers[0].server.clients[0].close(mode: .all, promise: nil)
+
+      default:
+        ()
+      }
+    } verifyEvents: { events in
+      let expected: [LoadBalancerEvent] = [
+        .connectivityStateChanged(.idle),
+        .connectivityStateChanged(.connecting),
+        .connectivityStateChanged(.ready),
+        .connectivityStateChanged(.idle),
+        .connectivityStateChanged(.shutdown),
+      ]
+      XCTAssertEqual(events, expected)
+    }
+  }
+
+  func testPickFirstReceivesGoAway() async throws {
+    let idleCount = ManagedAtomic(0)
+    try await LoadBalancerTest.pickFirst(servers: 2, connector: .posix()) { context, event in
+      switch event {
+      case .connectivityStateChanged(.idle):
+        switch idleCount.wrappingIncrementThenLoad(ordering: .sequentiallyConsistent) {
+        case 1:
+          // Provide the address of the first server.
+          context.pickFirst!.updateEndpoint(Endpoint(context.servers[0].address))
+        case 2:
+          // Provide the address of the second server.
+          context.pickFirst!.updateEndpoint(Endpoint(context.servers[1].address))
+        default:
+          ()
+        }
+
+      case .connectivityStateChanged(.ready):
+        switch idleCount.load(ordering: .sequentiallyConsistent) {
+        case 1:
+          // Must be connected to server 1, send a GOAWAY frame.
+          let channel = context.servers[0].server.clients.first!
+          let goAway = HTTP2Frame(
+            streamID: .rootStream,
+            payload: .goAway(lastStreamID: 0, errorCode: .noError, opaqueData: nil)
+          )
+          channel.writeAndFlush(goAway, promise: nil)
+
+        case 2:
+          // Must only be connected to server 2 now.
+          XCTAssertEqual(context.servers[0].server.clients.count, 0)
+          XCTAssertEqual(context.servers[1].server.clients.count, 1)
+          context.loadBalancer.close()
+
+        default:
+          ()
+        }
+
+      default:
+        ()
+      }
+    } verifyEvents: { events in
+      let expected: [LoadBalancerEvent] = [
+        .connectivityStateChanged(.idle),
+        .connectivityStateChanged(.connecting),
+        .connectivityStateChanged(.ready),
+        .requiresNameResolution,
+        .connectivityStateChanged(.idle),
+        .connectivityStateChanged(.connecting),
+        .connectivityStateChanged(.ready),
+        .connectivityStateChanged(.shutdown),
+      ]
+      XCTAssertEqual(events, expected)
+    }
+  }
+}