Browse Source

Refactor round robin tests (#1894)

Motivation:

The test harness for the round-robin load balancer tests doesn't need to
be so tightly coupled to round-robin tests. It can also be used for the
pick-first load balancer.

Modifications:

- Move the RR LB test into its own file
- Refactor so the test context holds a `LoadBalancer` rather than a
  `RoundRobinLoadBalancer`

Result:

Can be easily extended for pick-first LB tests
George Barnett 1 year ago
parent
commit
90bea733b5

+ 148 - 0
Tests/GRPCHTTP2CoreTests/Client/Connection/LoadBalancers/LoadBalancerTest.swift

@@ -0,0 +1,148 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+@_spi(Package) @testable import GRPCHTTP2Core
+import XCTest
+
+@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
+enum LoadBalancerTest {
+  struct Context {
+    let servers: [(server: TestServer, address: GRPCHTTP2Core.SocketAddress)]
+    let loadBalancer: LoadBalancer
+  }
+
+  static func roundRobin(
+    servers serverCount: Int,
+    connector: any HTTP2Connector,
+    backoff: ConnectionBackoff = .defaults,
+    timeout: Duration = .seconds(10),
+    function: String = #function,
+    handleEvent: @escaping @Sendable (Context, LoadBalancerEvent) async throws -> Void,
+    verifyEvents: @escaping @Sendable ([LoadBalancerEvent]) -> Void = { _ in }
+  ) async throws {
+    try await Self.run(
+      servers: serverCount,
+      timeout: timeout,
+      function: function,
+      handleEvent: handleEvent,
+      verifyEvents: verifyEvents
+    ) {
+      let roundRobin = RoundRobinLoadBalancer(
+        connector: connector,
+        backoff: backoff,
+        defaultCompression: .none,
+        enabledCompression: .none
+      )
+      return .roundRobin(roundRobin)
+    }
+  }
+
+  private static func run(
+    servers serverCount: Int,
+    timeout: Duration,
+    function: String,
+    handleEvent: @escaping @Sendable (Context, LoadBalancerEvent) async throws -> Void,
+    verifyEvents: @escaping @Sendable ([LoadBalancerEvent]) -> Void = { _ in },
+    makeLoadBalancer: @escaping @Sendable () -> LoadBalancer
+  ) async throws {
+    enum TestEvent {
+      case timedOut
+      case completed(Result<Void, Error>)
+    }
+
+    try await withThrowingTaskGroup(of: TestEvent.self) { group in
+      group.addTask {
+        try? await Task.sleep(for: timeout)
+        return .timedOut
+      }
+
+      group.addTask {
+        do {
+          try await Self._run(
+            servers: serverCount,
+            handleEvent: handleEvent,
+            verifyEvents: verifyEvents,
+            makeLoadBalancer: makeLoadBalancer
+          )
+          return .completed(.success(()))
+        } catch {
+          return .completed(.failure(error))
+        }
+      }
+
+      let result = try await group.next()!
+      group.cancelAll()
+
+      switch result {
+      case .timedOut:
+        XCTFail("'\(function)' timed out after \(timeout)")
+      case .completed(let result):
+        try result.get()
+      }
+    }
+  }
+
+  private static func _run(
+    servers serverCount: Int,
+    handleEvent: @escaping @Sendable (Context, LoadBalancerEvent) async throws -> Void,
+    verifyEvents: @escaping @Sendable ([LoadBalancerEvent]) -> Void,
+    makeLoadBalancer: @escaping @Sendable () -> LoadBalancer
+  ) async throws {
+    try await withThrowingTaskGroup(of: Void.self) { group in
+      // Create the test servers.
+      var servers = [(server: TestServer, address: GRPCHTTP2Core.SocketAddress)]()
+      for _ in 0 ..< serverCount {
+        let server = TestServer(eventLoopGroup: .singletonMultiThreadedEventLoopGroup)
+        let address = try await server.bind()
+        servers.append((server, address))
+
+        group.addTask {
+          try await server.run { _, _ in
+            XCTFail("Unexpected stream")
+          }
+        }
+      }
+
+      // Create the load balancer.
+      let loadBalancer = makeLoadBalancer()
+
+      group.addTask {
+        await loadBalancer.run()
+      }
+
+      let context = Context(servers: servers, loadBalancer: loadBalancer)
+
+      var events = [LoadBalancerEvent]()
+      for await event in loadBalancer.events {
+        events.append(event)
+        try await handleEvent(context, event)
+      }
+
+      verifyEvents(events)
+      group.cancelAll()
+    }
+  }
+}
+
+@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
+extension LoadBalancerTest.Context {
+  var roundRobin: RoundRobinLoadBalancer? {
+    switch self.loadBalancer {
+    case .roundRobin(let loadBalancer):
+      return loadBalancer
+    }
+  }
+}

+ 18 - 121
Tests/GRPCHTTP2CoreTests/Client/Connection/LoadBalancers/RoundRobinLoadBalancerTests.swift

@@ -24,13 +24,13 @@ import XCTest
 @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
 final class RoundRobinLoadBalancerTests: XCTestCase {
   func testMultipleConnectionsAreEstablished() async throws {
-    try await RoundRobinLoadBalancerTest.run(servers: 3, connector: .posix()) { context, event in
+    try await LoadBalancerTest.roundRobin(servers: 3, connector: .posix()) { context, event in
       switch event {
       case .connectivityStateChanged(.idle):
         // Update the addresses for the load balancer, this will trigger subchannels to be created
         // for each.
         let endpoints = context.servers.map { Endpoint(addresses: [$0.address]) }
-        context.loadBalancer.updateAddresses(endpoints)
+        context.roundRobin!.updateAddresses(endpoints)
 
       case .connectivityStateChanged(.ready):
         // Poll until each server has one connected client.
@@ -56,13 +56,13 @@ final class RoundRobinLoadBalancerTests: XCTestCase {
   }
 
   func testSubchannelsArePickedEvenly() async throws {
-    try await RoundRobinLoadBalancerTest.run(servers: 3, connector: .posix()) { context, event in
+    try await LoadBalancerTest.roundRobin(servers: 3, connector: .posix()) { context, event in
       switch event {
       case .connectivityStateChanged(.idle):
         // Update the addresses for the load balancer, this will trigger subchannels to be created
         // for each.
         let endpoints = context.servers.map { Endpoint(addresses: [$0.address]) }
-        context.loadBalancer.updateAddresses(endpoints)
+        context.roundRobin!.updateAddresses(endpoints)
 
       case .connectivityStateChanged(.ready):
         // Subchannel is ready. This happens when any subchannel becomes ready. Loop until
@@ -110,12 +110,12 @@ final class RoundRobinLoadBalancerTests: XCTestCase {
   }
 
   func testAddressUpdatesAreHandledGracefully() async throws {
-    try await RoundRobinLoadBalancerTest.run(servers: 3, connector: .posix()) { context, event in
+    try await LoadBalancerTest.roundRobin(servers: 3, connector: .posix()) { context, event in
       switch event {
       case .connectivityStateChanged(.idle):
         // Do the first connect.
         let endpoints = [Endpoint(addresses: [context.servers[0].address])]
-        context.loadBalancer.updateAddresses(endpoints)
+        context.roundRobin!.updateAddresses(endpoints)
 
       case .connectivityStateChanged(.ready):
         // Now the first connection should be established.
@@ -131,7 +131,7 @@ final class RoundRobinLoadBalancerTests: XCTestCase {
             Endpoint(addresses: [context.servers[0].address]),
             Endpoint(addresses: [context.servers[1].address]),
           ]
-          context.loadBalancer.updateAddresses(endpoints)
+          context.roundRobin!.updateAddresses(endpoints)
 
           try await XCTPoll(every: .milliseconds(10)) {
             context.servers.prefix(2).allSatisfy { $0.server.clients.count == 1 }
@@ -141,7 +141,7 @@ final class RoundRobinLoadBalancerTests: XCTestCase {
         // Remove those two endpoints and add a third.
         do {
           let endpoints = [Endpoint(addresses: [context.servers[2].address])]
-          context.loadBalancer.updateAddresses(endpoints)
+          context.roundRobin!.updateAddresses(endpoints)
 
           try await XCTPoll(every: .milliseconds(10)) {
             let disconnected = context.servers.prefix(2).allSatisfy { $0.server.clients.isEmpty }
@@ -169,16 +169,16 @@ final class RoundRobinLoadBalancerTests: XCTestCase {
   }
 
   func testSameAddressUpdatesAreIgnored() async throws {
-    try await RoundRobinLoadBalancerTest.run(servers: 3, connector: .posix()) { context, event in
+    try await LoadBalancerTest.roundRobin(servers: 3, connector: .posix()) { context, event in
       switch event {
       case .connectivityStateChanged(.idle):
         let endpoints = context.servers.map { _, address in Endpoint(addresses: [address]) }
-        context.loadBalancer.updateAddresses(endpoints)
+        context.roundRobin!.updateAddresses(endpoints)
 
       case .connectivityStateChanged(.ready):
         // Update with the same addresses, these should be ignored.
         let endpoints = context.servers.map { _, address in Endpoint(addresses: [address]) }
-        context.loadBalancer.updateAddresses(endpoints)
+        context.roundRobin!.updateAddresses(endpoints)
 
         // We should still have three connections.
         try await XCTPoll(every: .milliseconds(10)) {
@@ -202,15 +202,15 @@ final class RoundRobinLoadBalancerTests: XCTestCase {
   }
 
   func testEmptyAddressUpdatesAreIgnored() async throws {
-    try await RoundRobinLoadBalancerTest.run(servers: 3, connector: .posix()) { context, event in
+    try await LoadBalancerTest.roundRobin(servers: 3, connector: .posix()) { context, event in
       switch event {
       case .connectivityStateChanged(.idle):
         let endpoints = context.servers.map { _, address in Endpoint(addresses: [address]) }
-        context.loadBalancer.updateAddresses(endpoints)
+        context.roundRobin!.updateAddresses(endpoints)
 
       case .connectivityStateChanged(.ready):
         // Update with no-addresses, should be ignored so a subchannel can still be picked.
-        context.loadBalancer.updateAddresses([])
+        context.roundRobin!.updateAddresses([])
 
         // We should still have three connections.
         try await XCTPoll(every: .milliseconds(10)) {
@@ -234,12 +234,12 @@ final class RoundRobinLoadBalancerTests: XCTestCase {
   }
 
   func testSubchannelReceivesGoAway() async throws {
-    try await RoundRobinLoadBalancerTest.run(servers: 3, connector: .posix()) { context, event in
+    try await LoadBalancerTest.roundRobin(servers: 3, connector: .posix()) { context, event in
       switch event {
       case .connectivityStateChanged(.idle):
         // Trigger the connect.
         let endpoints = context.servers.map { Endpoint(addresses: [$0.address]) }
-        context.loadBalancer.updateAddresses(endpoints)
+        context.roundRobin!.updateAddresses(endpoints)
 
       case .connectivityStateChanged(.ready):
         // Wait for all servers to become ready.
@@ -284,7 +284,6 @@ final class RoundRobinLoadBalancerTests: XCTestCase {
       default:
         ()
       }
-
     } verifyEvents: { events in
       let expected: [LoadBalancerEvent] = [
         .connectivityStateChanged(.idle),
@@ -326,7 +325,7 @@ final class RoundRobinLoadBalancerTests: XCTestCase {
     let idle = ManagedAtomic(0)
     let ready = ManagedAtomic(0)
 
-    try await RoundRobinLoadBalancerTest.run(
+    try await LoadBalancerTest.roundRobin(
       servers: 1,
       connector: .posix(maxIdleTime: .milliseconds(25))  // Aggressively idle the connection
     ) { context, event in
@@ -340,7 +339,7 @@ final class RoundRobinLoadBalancerTests: XCTestCase {
           // which it will connect to. Wait for it to be ready and then idle again.
           let address = context.servers[0].address
           let endpoints = [Endpoint(addresses: [address])]
-          context.loadBalancer.updateAddresses(endpoints)
+          context.roundRobin!.updateAddresses(endpoints)
 
         case 2:
           // Load-balancer has the endpoints but all are idle. Picking will trigger a connect.
@@ -379,105 +378,3 @@ final class RoundRobinLoadBalancerTests: XCTestCase {
     }
   }
 }
-
-@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
-enum RoundRobinLoadBalancerTest {
-  struct Context {
-    let servers: [(server: TestServer, address: GRPCHTTP2Core.SocketAddress)]
-    let loadBalancer: RoundRobinLoadBalancer
-  }
-
-  static func run(
-    servers serverCount: Int,
-    connector: any HTTP2Connector,
-    backoff: ConnectionBackoff = .defaults,
-    timeout: Duration = .seconds(10),
-    function: String = #function,
-    handleEvent: @escaping @Sendable (Context, LoadBalancerEvent) async throws -> Void,
-    verifyEvents: @escaping @Sendable ([LoadBalancerEvent]) -> Void = { _ in }
-  ) async throws {
-    enum TestEvent {
-      case timedOut
-      case completed(Result<Void, Error>)
-    }
-
-    try await withThrowingTaskGroup(of: TestEvent.self) { group in
-      group.addTask {
-        try? await Task.sleep(for: timeout)
-        return .timedOut
-      }
-
-      group.addTask {
-        do {
-          try await Self._run(
-            servers: serverCount,
-            connector: connector,
-            backoff: backoff,
-            handleEvent: handleEvent,
-            verifyEvents: verifyEvents
-          )
-          return .completed(.success(()))
-        } catch {
-          return .completed(.failure(error))
-        }
-      }
-
-      let result = try await group.next()!
-      group.cancelAll()
-
-      switch result {
-      case .timedOut:
-        XCTFail("'\(function)' timed out after \(timeout)")
-      case .completed(let result):
-        try result.get()
-      }
-    }
-  }
-
-  private static func _run(
-    servers serverCount: Int,
-    connector: some HTTP2Connector,
-    backoff: ConnectionBackoff,
-    handleEvent: @escaping @Sendable (Context, LoadBalancerEvent) async throws -> Void,
-    verifyEvents: @escaping @Sendable ([LoadBalancerEvent]) -> Void
-  ) async throws {
-    try await withThrowingTaskGroup(of: Void.self) { group in
-      // Create the test servers.
-      var servers = [(server: TestServer, address: GRPCHTTP2Core.SocketAddress)]()
-      for _ in 1 ... serverCount {
-        let server = TestServer(eventLoopGroup: .singletonMultiThreadedEventLoopGroup)
-        let address = try await server.bind()
-        servers.append((server, address))
-
-        group.addTask {
-          try await server.run { _, _ in
-            XCTFail("Unexpected stream")
-          }
-        }
-      }
-
-      // Create the load balancer.
-      let loadBalancer = RoundRobinLoadBalancer(
-        connector: connector,
-        backoff: backoff,
-        defaultCompression: .none,
-        enabledCompression: .none
-      )
-
-      group.addTask {
-        await loadBalancer.run()
-      }
-
-      let context = Context(servers: servers, loadBalancer: loadBalancer)
-
-      var events = [LoadBalancerEvent]()
-      for await event in loadBalancer.events {
-        events.append(event)
-        try await handleEvent(context, event)
-      }
-
-      verifyEvents(events)
-      group.cancelAll()
-    }
-  }
-}