2
0
Эх сурвалжийг харах

Add a request queue for the grpc channel (#1881)

Motivation:

The grpc channel needs to enqueue requests when the channel isn't ready
to handle RPCs. When the channel becomes ready, it can attempt to
execute the RPCs on a load balancer.

Modifications:

- Add a request queue. The queue stores continuations for a
  `LoadBalancer`. Elements can be removed in-order (popped) or by ID (in
  case of cancellation).
- Add a queue ID

Result:

Requests can be queued
George Barnett 1 жил өмнө
parent
commit
c4353e9e44

+ 69 - 0
Sources/GRPCHTTP2Core/Client/Connection/LoadBalancers/LoadBalancer.swift

@@ -0,0 +1,69 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
+enum LoadBalancer: Sendable {
+  case roundRobin(RoundRobinLoadBalancer)
+}
+
+@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
+extension LoadBalancer {
+  init(_ loadBalancer: RoundRobinLoadBalancer) {
+    self = .roundRobin(loadBalancer)
+  }
+
+  var id: LoadBalancerID {
+    switch self {
+    case .roundRobin(let loadBalancer):
+      return loadBalancer.id
+    }
+  }
+
+  var events: AsyncStream<LoadBalancerEvent> {
+    switch self {
+    case .roundRobin(let loadBalancer):
+      return loadBalancer.events
+    }
+  }
+
+  func run() async {
+    switch self {
+    case .roundRobin(let loadBalancer):
+      await loadBalancer.run()
+    }
+  }
+
+  func updateAddresses(_ endpoints: [Endpoint]) {
+    switch self {
+    case .roundRobin(let loadBalancer):
+      loadBalancer.updateAddresses(endpoints)
+    }
+  }
+
+  func close() {
+    switch self {
+    case .roundRobin(let loadBalancer):
+      loadBalancer.close()
+    }
+  }
+
+  func pickSubchannel() -> Subchannel? {
+    switch self {
+    case .roundRobin(let loadBalancer):
+      return loadBalancer.pickSubchannel()
+    }
+  }
+}

+ 4 - 0
Sources/GRPCHTTP2Core/Client/Connection/LoadBalancers/RoundRobinLoadBalancer.swift

@@ -117,6 +117,9 @@ struct RoundRobinLoadBalancer {
   /// The set of enabled compression algorithms.
   private let enabledCompression: CompressionAlgorithmSet
 
+  /// The ID of this load balancer.
+  internal let id: LoadBalancerID
+
   init(
     connector: any HTTP2Connector,
     backoff: ConnectionBackoff,
@@ -127,6 +130,7 @@ struct RoundRobinLoadBalancer {
     self.backoff = backoff
     self.defaultCompression = defaultCompression
     self.enabledCompression = enabledCompression
+    self.id = LoadBalancerID()
 
     self.event = AsyncStream.makeStream(of: LoadBalancerEvent.self)
     self.input = AsyncStream.makeStream(of: Input.self)

+ 105 - 0
Sources/GRPCHTTP2Core/Client/Connection/RequestQueue.swift

@@ -0,0 +1,105 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import DequeModule
+
+@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
+struct RequestQueue {
+  typealias Continuation = CheckedContinuation<LoadBalancer, Error>
+
+  private struct QueueEntry {
+    var continuation: Continuation
+    var waitForReady: Bool
+  }
+
+  /// IDs of entries in the order they should be processed.
+  ///
+  /// If an ID is popped from the queue but isn't present in `entriesByID` then it must've
+  /// been removed directly by its ID, this is fine.
+  private var ids: Deque<QueueEntryID>
+
+  /// Entries keyed by their ID.
+  private var entriesByID: [QueueEntryID: QueueEntry]
+
+  init() {
+    self.ids = []
+    self.entriesByID = [:]
+  }
+
+  /// Remove the first continuation from the queue.
+  mutating func popFirst() -> Continuation? {
+    while let id = self.ids.popFirst() {
+      if let waiter = self.entriesByID.removeValue(forKey: id) {
+        return waiter.continuation
+      }
+    }
+
+    assert(self.entriesByID.isEmpty)
+    return nil
+  }
+
+  /// Append a continuation to the queue.
+  ///
+  /// - Parameters:
+  ///   - continuation: The continuation to append.
+  ///   - waitForReady: Whether the request associated with the continuation is willing to wait for
+  ///       the channel to become ready.
+  ///   - id: The unique ID of the queue entry.
+  mutating func append(continuation: Continuation, waitForReady: Bool, id: QueueEntryID) {
+    let entry = QueueEntry(continuation: continuation, waitForReady: waitForReady)
+    let removed = self.entriesByID.updateValue(entry, forKey: id)
+    assert(removed == nil, "id '\(id)' reused")
+    self.ids.append(id)
+  }
+
+  /// Remove the waiter with the given ID, if it exists.
+  mutating func removeEntry(withID id: QueueEntryID) -> Continuation? {
+    let waiter = self.entriesByID.removeValue(forKey: id)
+    return waiter?.continuation
+  }
+
+  /// Remove all waiters, returning their continuations.
+  mutating func removeAll() -> [Continuation] {
+    let continuations = Array(self.entriesByID.values.map { $0.continuation })
+    self.ids.removeAll(keepingCapacity: true)
+    self.entriesByID.removeAll(keepingCapacity: true)
+    return continuations
+  }
+
+  /// Remove all entries which were appended to the queue with a value of `false`
+  /// for `waitForReady`.
+  mutating func removeFastFailingEntries() -> [Continuation] {
+    var removed = [Continuation]()
+    var remainingIDs = Deque<QueueEntryID>()
+    var remainingEntriesByID = [QueueEntryID: QueueEntry]()
+
+    while let id = self.ids.popFirst() {
+      guard let waiter = self.entriesByID.removeValue(forKey: id) else { continue }
+
+      if waiter.waitForReady {
+        remainingEntriesByID[id] = waiter
+        remainingIDs.append(id)
+      } else {
+        removed.append(waiter.continuation)
+      }
+    }
+
+    assert(self.entriesByID.isEmpty)
+    self.entriesByID = remainingEntriesByID
+    self.ids = remainingIDs
+    return removed
+  }
+}

+ 8 - 0
Sources/GRPCHTTP2Core/Internal/ProcessUniqueID.swift

@@ -45,3 +45,11 @@ struct LoadBalancerID: Hashable, Sendable, CustomStringConvertible {
     "lb_\(self.id)"
   }
 }
+
+/// A process-unique ID for an entry in a queue.
+struct QueueEntryID: Hashable, Sendable, CustomStringConvertible {
+  private let id = ProcessUniqueID()
+  var description: String {
+    "q_entry_\(self.id)"
+  }
+}

+ 259 - 0
Tests/GRPCHTTP2CoreTests/Client/Connection/RequestQueueTests.swift

@@ -0,0 +1,259 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import GRPCCore
+import XCTest
+
+@testable import GRPCHTTP2Core
+
+@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
+final class RequestQueueTests: XCTestCase {
+  struct AnErrorToAvoidALeak: Error {}
+
+  func testPopFirstEmpty() {
+    var queue = RequestQueue()
+    XCTAssertNil(queue.popFirst())
+  }
+
+  func testPopFirstNonEmpty() async {
+    _ = try? await withCheckedThrowingContinuation { continuation in
+      var queue = RequestQueue()
+      let id = QueueEntryID()
+
+      queue.append(continuation: continuation, waitForReady: false, id: id)
+      guard let popped = queue.popFirst() else {
+        return XCTFail("Missing continuation")
+      }
+      XCTAssertNil(queue.popFirst())
+
+      popped.resume(throwing: AnErrorToAvoidALeak())
+    }
+  }
+
+  func testPopFirstMultiple() async {
+    await withTaskGroup(of: QueueEntryID.self) { group in
+      let queue = _LockedValueBox(RequestQueue())
+      let signal1 = AsyncStream.makeStream(of: Void.self)
+      let signal2 = AsyncStream.makeStream(of: Void.self)
+
+      let id1 = QueueEntryID()
+      let id2 = QueueEntryID()
+
+      group.addTask {
+        _ = try? await withCheckedThrowingContinuation { continuation in
+          queue.withLockedValue {
+            $0.append(continuation: continuation, waitForReady: false, id: id1)
+          }
+
+          signal1.continuation.yield()
+          signal1.continuation.finish()
+        }
+
+        return id1
+      }
+
+      group.addTask {
+        // Wait until instructed to append.
+        for await _ in signal1.stream {}
+
+        _ = try? await withCheckedThrowingContinuation { continuation in
+          queue.withLockedValue {
+            $0.append(continuation: continuation, waitForReady: false, id: id2)
+          }
+
+          signal2.continuation.yield()
+          signal2.continuation.finish()
+        }
+
+        return id2
+      }
+
+      // Wait for both continuations to be enqueued.
+      for await _ in signal2.stream {}
+
+      for id in [id1, id2] {
+        let continuation = queue.withLockedValue { $0.popFirst() }
+        continuation?.resume(throwing: AnErrorToAvoidALeak())
+        let actual = await group.next()
+        XCTAssertEqual(id, actual)
+      }
+    }
+  }
+
+  func testRemoveEntryByID() async {
+    _ = try? await withCheckedThrowingContinuation { continuation in
+      var queue = RequestQueue()
+      let id = QueueEntryID()
+
+      queue.append(continuation: continuation, waitForReady: false, id: id)
+      guard let popped = queue.removeEntry(withID: id) else {
+        return XCTFail("Missing continuation")
+      }
+      XCTAssertNil(queue.removeEntry(withID: id))
+
+      popped.resume(throwing: AnErrorToAvoidALeak())
+    }
+  }
+
+  func testRemoveEntryByIDMultiple() async {
+    await withTaskGroup(of: QueueEntryID.self) { group in
+      let queue = _LockedValueBox(RequestQueue())
+      let signal1 = AsyncStream.makeStream(of: Void.self)
+      let signal2 = AsyncStream.makeStream(of: Void.self)
+
+      let id1 = QueueEntryID()
+      let id2 = QueueEntryID()
+
+      group.addTask {
+        _ = try? await withCheckedThrowingContinuation { continuation in
+          queue.withLockedValue {
+            $0.append(continuation: continuation, waitForReady: false, id: id1)
+          }
+
+          signal1.continuation.yield()
+          signal1.continuation.finish()
+        }
+
+        return id1
+      }
+
+      group.addTask {
+        // Wait until instructed to append.
+        for await _ in signal1.stream {}
+
+        _ = try? await withCheckedThrowingContinuation { continuation in
+          queue.withLockedValue {
+            $0.append(continuation: continuation, waitForReady: false, id: id2)
+          }
+
+          signal2.continuation.yield()
+          signal2.continuation.finish()
+        }
+
+        return id2
+      }
+
+      // Wait for both continuations to be enqueued.
+      for await _ in signal2.stream {}
+
+      for id in [id1, id2] {
+        let continuation = queue.withLockedValue { $0.removeEntry(withID: id) }
+        continuation?.resume(throwing: AnErrorToAvoidALeak())
+        let actual = await group.next()
+        XCTAssertEqual(id, actual)
+      }
+    }
+  }
+
+  func testRemoveFastFailingEntries() async throws {
+    let queue = _LockedValueBox(RequestQueue())
+    let enqueued = AsyncStream.makeStream(of: Void.self)
+
+    try await withThrowingTaskGroup(of: Void.self) { group in
+      var waitForReadyIDs = [QueueEntryID]()
+      var failFastIDs = [QueueEntryID]()
+
+      for _ in 0 ..< 50 {
+        waitForReadyIDs.append(QueueEntryID())
+        failFastIDs.append(QueueEntryID())
+      }
+
+      for ids in [waitForReadyIDs, failFastIDs] {
+        let waitForReady = ids == waitForReadyIDs
+        for id in ids {
+          group.addTask {
+            do {
+              _ = try await withCheckedThrowingContinuation { continuation in
+                queue.withLockedValue {
+                  $0.append(continuation: continuation, waitForReady: waitForReady, id: id)
+                }
+                enqueued.continuation.yield()
+              }
+            } catch is AnErrorToAvoidALeak {
+              ()
+            }
+          }
+        }
+      }
+
+      // Wait for all continuations to be enqueued.
+      var numberEnqueued = 0
+      for await _ in enqueued.stream {
+        numberEnqueued += 1
+        if numberEnqueued == (waitForReadyIDs.count + failFastIDs.count) {
+          enqueued.continuation.finish()
+        }
+      }
+
+      // Remove all fast-failing continuations.
+      let continuations = queue.withLockedValue {
+        $0.removeFastFailingEntries()
+      }
+
+      for continuation in continuations {
+        continuation.resume(throwing: AnErrorToAvoidALeak())
+      }
+
+      for id in failFastIDs {
+        queue.withLockedValue {
+          XCTAssertNil($0.removeEntry(withID: id))
+        }
+      }
+
+      for id in waitForReadyIDs {
+        let maybeContinuation = queue.withLockedValue { $0.removeEntry(withID: id) }
+        let continuation = try XCTUnwrap(maybeContinuation)
+        continuation.resume(throwing: AnErrorToAvoidALeak())
+      }
+    }
+  }
+
+  func testRemoveAll() async throws {
+    let queue = _LockedValueBox(RequestQueue())
+    let enqueued = AsyncStream.makeStream(of: Void.self)
+
+    await withThrowingTaskGroup(of: Void.self) { group in
+      for _ in 0 ..< 10 {
+        group.addTask {
+          _ = try await withCheckedThrowingContinuation { continuation in
+            queue.withLockedValue {
+              $0.append(continuation: continuation, waitForReady: false, id: QueueEntryID())
+            }
+
+            enqueued.continuation.yield()
+          }
+        }
+      }
+
+      // Wait for all continuations to be enqueued.
+      var numberEnqueued = 0
+      for await _ in enqueued.stream {
+        numberEnqueued += 1
+        if numberEnqueued == 10 {
+          enqueued.continuation.finish()
+        }
+      }
+
+      let continuations = queue.withLockedValue { $0.removeAll() }
+      XCTAssertEqual(continuations.count, 10)
+      XCTAssertNil(queue.withLockedValue { $0.popFirst() })
+
+      for continuation in continuations {
+        continuation.resume(throwing: AnErrorToAvoidALeak())
+      }
+    }
+  }
+}

+ 6 - 0
Tests/GRPCHTTP2CoreTests/Internal/ProcessUniqueIDTests.swift

@@ -48,4 +48,10 @@ final class ProcessUniqueIDTests: XCTestCase {
     let description = String(describing: id)
     XCTAssert(description.hasPrefix("lb_"))
   }
+
+  func testQueueEntryDescription() {
+    let id = QueueEntryID()
+    let description = String(describing: id)
+    XCTAssert(description.hasPrefix("q_entry_"))
+  }
 }