Browse Source

Add a passthrough message source and sequence (#1252)

Motivation:

`AsyncThrowingStream` provides an implementation of `AsyncSequence`
which allows the holder to provide new values to that sequence from
within a closure provided to the initializer. This API doesn't fit our
needs: we must be able to provide the values 'from the outside' rather
than during initialization.

Modifications:

- Add a `PassthroughMessageSequence`, an implementation of 
  `AsyncSequence` which consumes messages from a
  `PassthroughMessagesSource`.
- The source  may have values provided to it via `yield(_:)` and terminated
  with `finish()` or `finish(throwing:)`.
- Add tests and a few `AsyncSequence` helpers.

Result:

We have an `AsyncSequence` implementation which can have values provided
to it.
George Barnett 4 years ago
parent
commit
65cbfca60a

+ 58 - 0
Sources/GRPC/AsyncAwaitSupport/PassthroughMessageSequence.swift

@@ -0,0 +1,58 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#if compiler(>=5.5)
+
+/// An ``AsyncSequence`` adapter for a ``PassthroughMessageSource``.`
+@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
+@usableFromInline
+internal struct PassthroughMessageSequence<Element, Failure: Error>: AsyncSequence {
+  @usableFromInline
+  internal typealias Element = Element
+
+  @usableFromInline
+  internal typealias AsyncIterator = Iterator
+
+  /// The source of messages in the sequence.
+  @usableFromInline
+  internal let _source: PassthroughMessageSource<Element, Failure>
+
+  @usableFromInline
+  internal func makeAsyncIterator() -> Iterator {
+    return Iterator(storage: self._source)
+  }
+
+  @usableFromInline
+  internal init(consuming source: PassthroughMessageSource<Element, Failure>) {
+    self._source = source
+  }
+
+  @usableFromInline
+  internal struct Iterator: AsyncIteratorProtocol {
+    @usableFromInline
+    internal let _storage: PassthroughMessageSource<Element, Failure>
+
+    fileprivate init(storage: PassthroughMessageSource<Element, Failure>) {
+      self._storage = storage
+    }
+
+    @inlinable
+    internal func next() async throws -> Element? {
+      return try await self._storage.consumeNextElement()
+    }
+  }
+}
+
+#endif // compiler(>=5.5)

+ 162 - 0
Sources/GRPC/AsyncAwaitSupport/PassthroughMessageSource.swift

@@ -0,0 +1,162 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#if compiler(>=5.5)
+import NIOConcurrencyHelpers
+import NIOCore
+
+/// The source of messages for a ``PassthroughMessageSequence``.`
+///
+/// Values may be provided to the source with calls to ``yield(_:)`` which returns whether the value
+/// was accepted (and how many values are yet to be consumed) -- or dropped.
+///
+/// The backing storage has an unbounded capacity and callers should use the number of unconsumed
+/// values returned from ``yield(_:)`` as an indication of when to stop providing values.
+///
+/// The source must be finished exactly once by calling ``finish()`` or ``finish(throwing:)`` to
+/// indicate that the sequence should end with an error.
+@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
+@usableFromInline
+internal final class PassthroughMessageSource<Element, Failure: Error> {
+  @usableFromInline
+  internal typealias _ContinuationResult = Result<Element?, Error>
+
+  /// All state in this class must be accessed via the lock.
+  ///
+  /// - Important: We use a `class` with a lock rather than an `actor` as we must guarantee that
+  ///   calls to ``yield(_:)`` are not reordered.
+  @usableFromInline
+  internal let _lock: Lock
+
+  /// A queue of elements which may be consumed as soon as there is demand.
+  @usableFromInline
+  internal var _continuationResults: CircularBuffer<_ContinuationResult>
+
+  /// A continuation which will be resumed in the future. The continuation must be `nil`
+  /// if ``continuationResults`` is not empty.
+  @usableFromInline
+  internal var _continuation: Optional<CheckedContinuation<Element?, Error>>
+
+  /// True if a terminal continuation result (`.success(nil)` or `.failure()`) has been seen.
+  /// No more values may be enqueued to `continuationResults` if this is `true`.
+  @usableFromInline
+  internal var _isTerminated: Bool
+
+  @usableFromInline
+  internal init(initialBufferCapacity: Int = 16) {
+    self._lock = Lock()
+    self._continuationResults = CircularBuffer(initialCapacity: initialBufferCapacity)
+    self._continuation = nil
+    self._isTerminated = false
+  }
+
+  // MARK: - Append / Yield
+
+  @usableFromInline
+  internal enum YieldResult: Hashable {
+    /// The value was accepted. The `queueDepth` indicates how many elements are waiting to be
+    /// consumed.
+    ///
+    /// If `queueDepth` is zero then the value was consumed immediately.
+    case accepted(queueDepth: Int)
+
+    /// The value was dropped because the source has already been finished.
+    case dropped
+  }
+
+  @inlinable
+  internal func yield(_ element: Element) -> YieldResult {
+    let continuationResult: _ContinuationResult = .success(element)
+    return self._yield(continuationResult, isTerminator: false)
+  }
+
+  @inlinable
+  internal func finish(throwing error: Failure? = nil) -> YieldResult {
+    let continuationResult: _ContinuationResult = error.map { .failure($0) } ?? .success(nil)
+    return self._yield(continuationResult, isTerminator: true)
+  }
+
+  @usableFromInline
+  internal enum _YieldResult {
+    /// The sequence has already been terminated; drop the element.
+    case alreadyTerminated
+    /// The element was added to the queue to be consumed later.
+    case queued(Int)
+    /// Demand for an element already existed: complete the continuation with the result being
+    /// yielded.
+    case resume(CheckedContinuation<Element?, Error>)
+  }
+
+  @inlinable
+  internal func _yield(
+    _ continuationResult: _ContinuationResult, isTerminator: Bool
+  ) -> YieldResult {
+    let result: _YieldResult = self._lock.withLock {
+      if self._isTerminated {
+        return .alreadyTerminated
+      } else if let continuation = self._continuation {
+        self._continuation = nil
+        return .resume(continuation)
+      } else {
+        self._isTerminated = isTerminator
+        self._continuationResults.append(continuationResult)
+        return .queued(self._continuationResults.count)
+      }
+    }
+
+    let yieldResult: YieldResult
+    switch result {
+    case let .queued(size):
+      yieldResult = .accepted(queueDepth: size)
+    case let .resume(continuation):
+      // If we resume a continuation then the queue must be empty
+      yieldResult = .accepted(queueDepth: 0)
+      continuation.resume(with: continuationResult)
+    case .alreadyTerminated:
+      yieldResult = .dropped
+    }
+
+    return yieldResult
+  }
+
+  // MARK: - Next
+
+  @inlinable
+  internal func consumeNextElement() async throws -> Element? {
+    return try await withCheckedThrowingContinuation {
+      self._consumeNextElement(continuation: $0)
+    }
+  }
+
+  @inlinable
+  internal func _consumeNextElement(continuation: CheckedContinuation<Element?, Error>) {
+    let continuationResult: _ContinuationResult? = self._lock.withLock {
+      if let nextResult = self._continuationResults.popFirst() {
+        return nextResult
+      } else {
+        // Nothing buffered and not terminated yet: save the continuation for later.
+        assert(self._continuation == nil)
+        self._continuation = continuation
+        return nil
+      }
+    }
+
+    if let continuationResult = continuationResult {
+      continuation.resume(with: continuationResult)
+    }
+  }
+}
+
+#endif // compiler(>=5.5)

+ 31 - 0
Tests/GRPCTests/AsyncAwaitSupport/AsyncSequence+Helpers.swift

@@ -0,0 +1,31 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#if compiler(>=5.5)
+
+@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
+extension AsyncSequence {
+  internal func collect() async throws -> [Element] {
+    return try await self.reduce(into: []) { accumulated, next in
+      accumulated.append(next)
+    }
+  }
+
+  internal func count() async throws -> Int {
+    return try await self.reduce(0) { count, _ in count + 1 }
+  }
+}
+
+#endif // compiler(>=5.5)

+ 145 - 0
Tests/GRPCTests/AsyncAwaitSupport/PassthroughMessageSourceTests.swift

@@ -0,0 +1,145 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#if compiler(>=5.5)
+@testable import GRPC
+import XCTest
+
+@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
+class PassthroughMessageSourceTests: GRPCTestCase {
+  func testBasicUsage() {
+    XCTAsyncTest {
+      let source = PassthroughMessageSource<String, Never>()
+      let sequence = PassthroughMessageSequence(consuming: source)
+
+      XCTAssertEqual(source.yield("foo"), .accepted(queueDepth: 1))
+      XCTAssertEqual(source.yield("bar"), .accepted(queueDepth: 2))
+      XCTAssertEqual(source.yield("baz"), .accepted(queueDepth: 3))
+
+      let firstTwo = try await sequence.prefix(2).collect()
+      XCTAssertEqual(firstTwo, ["foo", "bar"])
+
+      XCTAssertEqual(source.yield("bar"), .accepted(queueDepth: 2))
+      XCTAssertEqual(source.yield("foo"), .accepted(queueDepth: 3))
+
+      XCTAssertEqual(source.finish(), .accepted(queueDepth: 4))
+
+      let theRest = try await sequence.collect()
+      XCTAssertEqual(theRest, ["baz", "bar", "foo"])
+    }
+  }
+
+  func testFinishWithError() {
+    XCTAsyncTest {
+      let source = PassthroughMessageSource<String, TestError>()
+
+      XCTAssertEqual(source.yield("one"), .accepted(queueDepth: 1))
+      XCTAssertEqual(source.yield("two"), .accepted(queueDepth: 2))
+      XCTAssertEqual(source.yield("three"), .accepted(queueDepth: 3))
+      XCTAssertEqual(source.finish(throwing: TestError()), .accepted(queueDepth: 4))
+
+      // We should still be able to get the elements before the error.
+      let sequence = PassthroughMessageSequence(consuming: source)
+      let elements = try await sequence.prefix(3).collect()
+      XCTAssertEqual(elements, ["one", "two", "three"])
+
+      do {
+        for try await element in sequence {
+          XCTFail("Unexpected value '\(element)'")
+        }
+        XCTFail("AsyncSequence did not throw")
+      } catch {
+        XCTAssert(error is TestError)
+      }
+    }
+  }
+
+  func testYieldAfterFinish() {
+    XCTAsyncTest {
+      let source = PassthroughMessageSource<String, Never>()
+      XCTAssertEqual(source.finish(), .accepted(queueDepth: 1))
+      XCTAssertEqual(source.yield("foo"), .dropped)
+
+      let sequence = PassthroughMessageSequence(consuming: source)
+      let elements = try await sequence.count()
+      XCTAssertEqual(elements, 0)
+    }
+  }
+
+  func testMultipleFinishes() {
+    XCTAsyncTest {
+      let source = PassthroughMessageSource<String, TestError>()
+      XCTAssertEqual(source.finish(), .accepted(queueDepth: 1))
+      XCTAssertEqual(source.finish(), .dropped)
+      XCTAssertEqual(source.finish(throwing: TestError()), .dropped)
+
+      let sequence = PassthroughMessageSequence(consuming: source)
+      let elements = try await sequence.count()
+      XCTAssertEqual(elements, 0)
+    }
+  }
+
+  func testConsumeBeforeYield() {
+    XCTAsyncTest {
+      let source = PassthroughMessageSource<String, Never>()
+      let sequence = PassthroughMessageSequence(consuming: source)
+
+      await withThrowingTaskGroup(of: Void.self) { group in
+        group.addTask(priority: .high) {
+          let iterator = sequence.makeAsyncIterator()
+          if let next = try await iterator.next() {
+            XCTAssertEqual(next, "one")
+          } else {
+            XCTFail("No value produced")
+          }
+        }
+
+        group.addTask(priority: .low) {
+          let result = source.yield("one")
+          // We can't guarantee that this task will run after the other so we *may* have a queue
+          // depth of one.
+          XCTAssert(result == .accepted(queueDepth: 0) || result == .accepted(queueDepth: 1))
+        }
+      }
+    }
+  }
+
+  func testConsumeBeforeFinish() {
+    XCTAsyncTest {
+      let source = PassthroughMessageSource<String, TestError>()
+      let sequence = PassthroughMessageSequence(consuming: source)
+
+      await withThrowingTaskGroup(of: Void.self) { group in
+        group.addTask(priority: .high) {
+          let iterator = sequence.makeAsyncIterator()
+          await XCTAssertThrowsError(_ = try await iterator.next()) { error in
+            XCTAssert(error is TestError)
+          }
+        }
+
+        group.addTask(priority: .low) {
+          let result = source.finish(throwing: TestError())
+          // We can't guarantee that this task will run after the other so we *may* have a queue
+          // depth of one.
+          XCTAssert(result == .accepted(queueDepth: 0) || result == .accepted(queueDepth: 1))
+        }
+      }
+    }
+  }
+}
+
+fileprivate struct TestError: Error {}
+
+#endif // compiler(>=5.5)

+ 66 - 0
Tests/GRPCTests/AsyncAwaitSupport/XCTest+AsyncAwait.swift

@@ -0,0 +1,66 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#if compiler(>=5.5)
+import XCTest
+
+extension XCTestCase {
+  @available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
+  /// Cross-platform XCTest support for async-await tests.
+  ///
+  /// Currently the Linux implementation of XCTest doesn't have async-await support.
+  /// Until it does, we make use of this shim which uses a detached `Task` along with
+  /// `XCTest.wait(for:timeout:)` to wrap the operation.
+  ///
+  /// - NOTE: Support for Linux is tracked by https://bugs.swift.org/browse/SR-14403.
+  /// - NOTE: Implementation currently in progress: https://github.com/apple/swift-corelibs-xctest/pull/326
+  func XCTAsyncTest(
+    expectationDescription: String = "Async operation",
+    timeout: TimeInterval = 30,
+    file: StaticString = #filePath,
+    line: UInt = #line,
+    function: StaticString = #function,
+    operation: @escaping () async throws -> Void
+  ) {
+    let expectation = self.expectation(description: expectationDescription)
+    Task {
+      do {
+        try await operation()
+      } catch {
+        XCTFail("Error thrown while executing \(function): \(error)", file: file, line: line)
+        Thread.callStackSymbols.forEach { print($0) }
+      }
+      expectation.fulfill()
+    }
+    self.wait(for: [expectation], timeout: timeout)
+  }
+}
+
+@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
+internal func XCTAssertThrowsError<T>(
+  _ expression: @autoclosure () async throws -> T,
+  verify: (Error) -> Void = { _ in },
+  file: StaticString = #file,
+  line: UInt = #line
+) async {
+  do {
+    _ = try await expression()
+    XCTFail("Expression did not throw error", file: file, line: line)
+  } catch {
+    verify(error)
+  }
+}
+
+#endif // compiler(>=5.5)

+ 22 - 8
Tests/GRPCTests/GRPCAsyncClientCallTests.swift

@@ -178,10 +178,7 @@ class GRPCAsyncClientCallTests: GRPCTestCase {
 
     await assertThat(try await update.initialMetadata, .is(.equalTo(Self.OKInitialMetadata)))
 
-    actor TestResults {
-      static var numResponses = 0
-      static var numRequests = 0
-    }
+    let counter = RequestResponseCounter()
 
     // Send the requests and get responses in separate concurrent tasks and await the group.
     _ = await withThrowingTaskGroup(of: Void.self) { taskGroup in
@@ -189,22 +186,39 @@ class GRPCAsyncClientCallTests: GRPCTestCase {
       taskGroup.addTask {
         for word in ["boyle", "jeffers", "holt"] {
           try await update.sendMessage(.with { $0.text = word })
-          TestResults.numRequests += 1
+          await counter.incrementRequests()
         }
         try await update.sendEnd()
       }
       // Get responses in a separate task.
       taskGroup.addTask {
         for try await _ in update.responses {
-          TestResults.numResponses += 1
+          await counter.incrementResponses()
         }
       }
     }
-    await assertThat(TestResults.numRequests, .is(.equalTo(3)))
-    await assertThat(TestResults.numResponses, .is(.equalTo(3)))
+
+    await assertThat(await counter.numRequests, .is(.equalTo(3)))
+    await assertThat(await counter.numResponses, .is(.equalTo(3)))
     await assertThat(try await update.trailingMetadata, .is(.equalTo(Self.OKTrailingMetadata)))
     await assertThat(await update.status, .hasCode(.ok))
   } }
 }
 
+// Workaround https://bugs.swift.org/browse/SR-15070 (compiler crashes when defining a class/actor
+// in an async context).
+@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
+fileprivate actor RequestResponseCounter {
+  var numResponses = 0
+  var numRequests = 0
+
+  func incrementResponses() async {
+    self.numResponses += 1
+  }
+
+  func incrementRequests() async {
+    self.numRequests += 1
+  }
+}
+
 #endif

+ 0 - 32
Tests/GRPCTests/XCTestHelpers.swift

@@ -696,36 +696,4 @@ func assertThat<Value>(
   }
 }
 
-@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
-extension XCTestCase {
-  /// Cross-platform XCTest support for async-await tests.
-  ///
-  /// Currently the Linux implementation of XCTest doesn't have async-await support.
-  /// Until it does, we make use of this shim which uses a detached `Task` along with
-  /// `XCTest.wait(for:timeout:)` to wrap the operation.
-  ///
-  /// - NOTE: Support for Linux is tracked by https://bugs.swift.org/browse/SR-14403.
-  /// - NOTE: Implementation currently in progress: https://github.com/apple/swift-corelibs-xctest/pull/326
-  func XCTAsyncTest(
-    expectationDescription: String = "Async operation",
-    timeout: TimeInterval = 30,
-    file: StaticString = #filePath,
-    line: UInt = #line,
-    function: StaticString = #function,
-    operation: @escaping () async throws -> Void
-  ) {
-    let expectation = self.expectation(description: expectationDescription)
-    Task {
-      do {
-        try await operation()
-      } catch {
-        XCTFail("Error thrown while executing \(function): \(error)", file: file, line: line)
-        Thread.callStackSymbols.forEach { print($0) }
-      }
-      expectation.fulfill()
-    }
-    self.wait(for: [expectation], timeout: timeout)
-  }
-}
-
 #endif