
Add pick first to grpc-channel (#1923)

Motivation:

The pick first load balancer was added in 28523856. However,
`GRPCChannel` wasn't updated to support it.

Modifications:

Add support for the pick-first load balancer to `GRPCChannel`.

Result:

`GRPCChannel` can use the pick first LB
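
For reference, a minimal usage sketch of opting into pick-first via the service config. The `ServiceConfig`, `.pickFirst(shuffleAddressList:)`, `Endpoint`, and `GRPCChannel` initializer shapes are taken from the diff and tests below; the helper function and its `address` parameter are illustrative placeholders, not part of this change:

```swift
// Sketch only: configure a channel to use the pick-first policy explicitly.
// Everything not shown in the diff below (this function, its parameter) is hypothetical.
func makePickFirstChannel(address: SocketAddress) async {
  var serviceConfig = ServiceConfig()
  serviceConfig.loadBalancingConfig = [.pickFirst(shuffleAddressList: false)]

  let channel = GRPCChannel(
    resolver: .static(endpoints: [Endpoint(addresses: [address])]),
    connector: .posix(),
    config: .defaults,
    defaultServiceConfig: serviceConfig
  )

  // Runs the channel until it is closed elsewhere.
  await channel.connect()
}
```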
George Barnett, 1 year ago
Commit 1f20dfa017

+ 72 - 42
Sources/GRPCHTTP2Core/Client/Connection/GRPCChannel.swift

@@ -397,40 +397,87 @@ extension GRPCChannel {
     }
   }
 
+  enum SupportedLoadBalancerConfig {
+    case roundRobin
+    case pickFirst(ServiceConfig.LoadBalancingConfig.PickFirst)
+
+    init?(_ config: ServiceConfig.LoadBalancingConfig) {
+      if let pickFirst = config.pickFirst {
+        self = .pickFirst(pickFirst)
+      } else if config.roundRobin != nil {
+        self = .roundRobin
+      } else {
+        return nil
+      }
+    }
+
+    func matches(loadBalancer: LoadBalancer) -> Bool {
+      switch (self, loadBalancer) {
+      case (.roundRobin, .roundRobin):
+        return true
+      case (.pickFirst, .pickFirst):
+        return true
+      case (.roundRobin, .pickFirst),
+        (.pickFirst, .roundRobin):
+        return false
+      }
+    }
+  }
+
   private func updateLoadBalancer(
     serviceConfig: ServiceConfig,
     endpoints: [Endpoint],
     in group: inout DiscardingTaskGroup
   ) {
-    // Pick the first applicable policy, else fallback to pick-first.
-    for policy in serviceConfig.loadBalancingConfig {
-      let onUpdatePolicy: GRPCChannel.StateMachine.OnChangeLoadBalancer
-
-      if policy.roundRobin != nil {
-        onUpdatePolicy = self.state.withLockedValue { state in
-          state.changeLoadBalancerKind(to: .roundRobin) {
-            let loadBalancer = RoundRobinLoadBalancer(
-              connector: self.connector,
-              backoff: self.backoff,
-              defaultCompression: self.defaultCompression,
-              enabledCompression: self.enabledCompression
-            )
-            return .roundRobin(loadBalancer)
-          }
+    assert(!endpoints.isEmpty, "endpoints must be non-empty")
+
+    // Find the first supported config.
+    var configFromServiceConfig: SupportedLoadBalancerConfig?
+    for config in serviceConfig.loadBalancingConfig {
+      if let config = SupportedLoadBalancerConfig(config) {
+        configFromServiceConfig = config
+        break
+      }
+    }
+
+    let onUpdatePolicy: GRPCChannel.StateMachine.OnChangeLoadBalancer
+    var endpoints = endpoints
+
+    // Fallback to pick-first if no other config applies.
+    let loadBalancerConfig = configFromServiceConfig ?? .pickFirst(.init(shuffleAddressList: false))
+    switch loadBalancerConfig {
+    case .roundRobin:
+      onUpdatePolicy = self.state.withLockedValue { state in
+        state.changeLoadBalancerKind(to: loadBalancerConfig) {
+          let loadBalancer = RoundRobinLoadBalancer(
+            connector: self.connector,
+            backoff: self.backoff,
+            defaultCompression: self.defaultCompression,
+            enabledCompression: self.enabledCompression
+          )
+          return .roundRobin(loadBalancer)
         }
-      } else if policy.pickFirst != nil {
-        fatalError("TODO: use pick-first when supported")
-      } else {
-        // Policy isn't known, ignore it.
-        continue
       }
 
-      self.handleLoadBalancerChange(onUpdatePolicy, endpoints: endpoints, in: &group)
-      return
+    case .pickFirst(let pickFirst):
+      if pickFirst.shuffleAddressList {
+        endpoints[0].addresses.shuffle()
+      }
+
+      onUpdatePolicy = self.state.withLockedValue { state in
+        state.changeLoadBalancerKind(to: loadBalancerConfig) {
+          let loadBalancer = PickFirstLoadBalancer(
+            connector: self.connector,
+            backoff: self.backoff,
+            defaultCompression: self.defaultCompression,
+            enabledCompression: self.enabledCompression
+          )
+          return .pickFirst(loadBalancer)
+        }
+      }
     }
 
-    // No suitable config was found, fallback to pick-first.
-    fatalError("TODO: use pick-first when supported")
+    self.handleLoadBalancerChange(onUpdatePolicy, endpoints: endpoints, in: &group)
   }
 
   private func handleLoadBalancerChange(
@@ -620,23 +667,6 @@ extension GRPCChannel.StateMachine {
     }
   }
 
-  enum LoadBalancerKind {
-    case roundRobin
-    case pickFirst
-
-    func matches(loadBalancer: LoadBalancer) -> Bool {
-      switch (self, loadBalancer) {
-      case (.roundRobin, .roundRobin):
-        return true
-      case (.pickFirst, .pickFirst):
-        return true
-      case (.roundRobin, .pickFirst),
-        (.pickFirst, .roundRobin):
-        return false
-      }
-    }
-  }
-
   enum OnChangeLoadBalancer {
     case runLoadBalancer(LoadBalancer, stop: LoadBalancer?)
     case updateLoadBalancer(LoadBalancer)
@@ -644,7 +674,7 @@ extension GRPCChannel.StateMachine {
   }
 
   mutating func changeLoadBalancerKind(
-    to newLoadBalancerKind: LoadBalancerKind,
+    to newLoadBalancerKind: GRPCChannel.SupportedLoadBalancerConfig,
     _ makeLoadBalancer: () -> LoadBalancer
   ) -> OnChangeLoadBalancer {
     let onChangeLoadBalancer: OnChangeLoadBalancer

+ 200 - 0
Tests/GRPCHTTP2CoreTests/Client/Connection/GRPCChannelTests.swift

@@ -547,6 +547,206 @@ final class GRPCChannelTests: XCTestCase {
       }
     }
   }
+
+  func testLoadBalancerChangingFromRoundRobinToPickFirst() async throws {
+    // The test will push different configs to the resolver, first a round-robin LB, then a
+    // pick-first LB.
+    let (resolver, continuation) = NameResolver.dynamic(updateMode: .push)
+    let channel = GRPCChannel(
+      resolver: resolver,
+      connector: .posix(),
+      config: .defaults,
+      defaultServiceConfig: ServiceConfig()
+    )
+
+    // Start a few servers.
+    let server1 = TestServer(eventLoopGroup: .singletonMultiThreadedEventLoopGroup)
+    let address1 = try await server1.bind()
+
+    let server2 = TestServer(eventLoopGroup: .singletonMultiThreadedEventLoopGroup)
+    let address2 = try await server2.bind()
+
+    let server3 = TestServer(eventLoopGroup: .singletonMultiThreadedEventLoopGroup)
+    let address3 = try await server3.bind()
+
+    try await withThrowingTaskGroup(of: Void.self) { group in
+      // Run the servers, no RPCs will be run against them.
+      for server in [server1, server2, server3] {
+        group.addTask {
+          try await server.run(.never)
+        }
+      }
+
+      group.addTask {
+        await channel.connect()
+      }
+
+      for await event in channel.connectivityState {
+        switch event {
+        case .idle:
+          let endpoints = [address1, address2].map { Endpoint(addresses: [$0]) }
+          var serviceConfig = ServiceConfig()
+          serviceConfig.loadBalancingConfig = [.roundRobin]
+          let resolutionResult = NameResolutionResult(
+            endpoints: endpoints,
+            serviceConfig: .success(serviceConfig)
+          )
+
+          // Push the first resolution result which uses round robin. This will result in the
+          // channel becoming ready.
+          continuation.yield(resolutionResult)
+
+        case .ready:
+          // Channel is ready, server 1 and 2 should have clients shortly.
+          try await XCTPoll(every: .milliseconds(10)) {
+            server1.clients.count == 1 && server2.clients.count == 1 && server3.clients.count == 0
+          }
+
+          // Both subchannels are ready, prepare and yield an update to the resolver.
+          var serviceConfig = ServiceConfig()
+          serviceConfig.loadBalancingConfig = [.pickFirst(shuffleAddressList: false)]
+          let resolutionResult = NameResolutionResult(
+            endpoints: [Endpoint(addresses: [address3])],
+            serviceConfig: .success(serviceConfig)
+          )
+          continuation.yield(resolutionResult)
+
+          // Only server 3 should have a connection.
+          try await XCTPoll(every: .milliseconds(10)) {
+            server1.clients.count == 0 && server2.clients.count == 0 && server3.clients.count == 1
+          }
+
+          channel.close()
+
+        case .shutdown:
+          group.cancelAll()
+
+        default:
+          ()
+        }
+      }
+    }
+  }
+
+  func testPickFirstShufflingAddressList() async throws {
+    // This test checks that the pick first load-balancer has its address list shuffled. We can't
+    // assert this deterministically, so instead we'll run an experiment a number of times. Each
+    // round will create N servers and provide them as endpoints to the pick-first load balancer.
+    // The channel will establish a connection to one of the servers and its identity will be noted.
+    let numberOfRounds = 100
+    let numberOfServers = 2
+
+    let servers = (0 ..< numberOfServers).map { _ in
+      TestServer(eventLoopGroup: .singletonMultiThreadedEventLoopGroup)
+    }
+
+    var addresses = [SocketAddress]()
+    for server in servers {
+      let address = try await server.bind()
+      addresses.append(address)
+    }
+
+    let endpoint = Endpoint(addresses: addresses)
+    var counts = Array(repeating: 0, count: addresses.count)
+
+    // Supply service config on init, not via the load-balancer.
+    var serviceConfig = ServiceConfig()
+    serviceConfig.loadBalancingConfig = [.pickFirst(shuffleAddressList: true)]
+
+    try await withThrowingDiscardingTaskGroup { group in
+      // Run the servers.
+      for server in servers {
+        group.addTask {
+          try await server.run(.never)
+        }
+      }
+
+      // Run the experiment.
+      for _ in 0 ..< numberOfRounds {
+        let channel = GRPCChannel(
+          resolver: .static(endpoints: [endpoint]),
+          connector: .posix(),
+          config: .defaults,
+          defaultServiceConfig: serviceConfig
+        )
+
+        group.addTask {
+          await channel.connect()
+        }
+
+        for await state in channel.connectivityState {
+          switch state {
+          case .ready:
+            for index in servers.indices {
+              if servers[index].clients.count == 1 {
+                counts[index] += 1
+                break
+              }
+            }
+            channel.close()
+          default:
+            ()
+          }
+        }
+      }
+
+      // Stop the servers.
+      group.cancelAll()
+    }
+
+    // The address list is shuffled, so there's no guarantee how many times we'll hit each server.
+    // Assert that the minimum a server should be hit is 10% of the time.
+    let expected = Double(numberOfRounds) / Double(numberOfServers)
+    let minimum = expected * 0.1
+    XCTAssert(counts.allSatisfy({ Double($0) >= minimum }), "\(counts)")
+  }
+
+  func testPickFirstIsFallbackPolicy() async throws {
+    // Start a few servers.
+    let server1 = TestServer(eventLoopGroup: .singletonMultiThreadedEventLoopGroup)
+    let address1 = try await server1.bind()
+
+    let server2 = TestServer(eventLoopGroup: .singletonMultiThreadedEventLoopGroup)
+    let address2 = try await server2.bind()
+
+    // Prepare a channel with an empty service config.
+    let channel = GRPCChannel(
+      resolver: .static(endpoints: [Endpoint(address1, address2)]),
+      connector: .posix(),
+      config: .defaults,
+      defaultServiceConfig: ServiceConfig()
+    )
+
+    try await withThrowingDiscardingTaskGroup { group in
+      // Run the servers.
+      for server in [server1, server2] {
+        group.addTask {
+          try await server.run(.never)
+        }
+      }
+
+      group.addTask {
+        await channel.connect()
+      }
+
+      for try await state in channel.connectivityState {
+        switch state {
+        case .ready:
+          // Only server 1 should have a connection.
+          try await XCTPoll(every: .milliseconds(10)) {
+            server1.clients.count == 1 && server2.clients.count == 0
+          }
+
+          channel.close()
+
+        default:
+          ()
+        }
+      }
+
+      group.cancelAll()
+    }
+  }
 }
 
 @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)