Przeglądaj źródła

Add co-operative cancellation to async writer and passthrough source (#1414)

Motivation:

Whenever we create continuations we should be careful to add support for
co-operative cancellation via a cancellation handler.

Modifications:

- Add co-operative cancellation to the async write and passthrough
  source
- Tests

Result:

Better cancellation support
George Barnett 3 lat temu
rodzic
commit
0680b7b359

+ 29 - 34
Sources/GRPC/AsyncAwaitSupport/AsyncWriter.swift

@@ -245,50 +245,45 @@ internal final actor AsyncWriter<Delegate: AsyncWriterDelegate>: Sendable {
   ///   have been suspended.
   @inlinable
   internal func write(_ element: Element) async throws {
-    try await withCheckedThrowingContinuation { continuation in
-      self._write(element, continuation: continuation)
-    }
-  }
-
-  @inlinable
-  internal func _write(_ element: Element, continuation: CheckedContinuation<Void, Error>) {
     // There are three outcomes of writing:
     // - write the element directly (if the writer isn't paused and no writes are pending)
     // - queue the element (the writer is paused or there are writes already pending)
     // - error (the writer is complete or the queue is full).
-
-    if self._completionState.isPendingOrCompleted {
-      continuation.resume(throwing: GRPCAsyncWriterError.alreadyFinished)
-    } else if !self._isPaused, self._pendingElements.isEmpty {
-      self._delegate.write(element)
-      continuation.resume()
-    } else if self._pendingElements.count < self._maxPendingElements {
-      // The continuation will be resumed later.
-      self._pendingElements.append(PendingElement(element, continuation: continuation))
-    } else {
-      continuation.resume(throwing: GRPCAsyncWriterError.tooManyPendingWrites)
+    return try await withTaskCancellationHandler {
+      if self._completionState.isPendingOrCompleted {
+        throw GRPCAsyncWriterError.alreadyFinished
+      } else if !self._isPaused, self._pendingElements.isEmpty {
+        self._delegate.write(element)
+      } else if self._pendingElements.count < self._maxPendingElements {
+        // The continuation will be resumed later.
+        try await withCheckedThrowingContinuation { continuation in
+          self._pendingElements.append(PendingElement(element, continuation: continuation))
+        }
+      } else {
+        throw GRPCAsyncWriterError.tooManyPendingWrites
+      }
+    } onCancel: {
+      self.cancelAsynchronously()
     }
   }
 
   /// Write the final element
   @inlinable
   internal func finish(_ end: End) async throws {
-    try await withCheckedThrowingContinuation { continuation in
-      self._finish(end, continuation: continuation)
-    }
-  }
-
-  @inlinable
-  internal func _finish(_ end: End, continuation: CheckedContinuation<Void, Error>) {
-    if self._completionState.isPendingOrCompleted {
-      continuation.resume(throwing: GRPCAsyncWriterError.alreadyFinished)
-    } else if !self._isPaused, self._pendingElements.isEmpty {
-      self._completionState = .completed
-      self._delegate.writeEnd(end)
-      continuation.resume()
-    } else {
-      // Either we're paused or there are pending writes which must be consumed first.
-      self._completionState = .pending(PendingEnd(end, continuation: continuation))
+    return try await withTaskCancellationHandler {
+      if self._completionState.isPendingOrCompleted {
+        throw GRPCAsyncWriterError.alreadyFinished
+      } else if !self._isPaused, self._pendingElements.isEmpty {
+        self._completionState = .completed
+        self._delegate.writeEnd(end)
+      } else {
+        try await withCheckedThrowingContinuation { continuation in
+          // Either we're paused or there are pending writes which must be consumed first.
+          self._completionState = .pending(PendingEnd(end, continuation: continuation))
+        }
+      }
+    } onCancel: {
+      self.cancelAsynchronously()
     }
   }
 }

+ 1 - 0
Sources/GRPC/AsyncAwaitSupport/PassthroughMessageSequence.swift

@@ -50,6 +50,7 @@ internal struct PassthroughMessageSequence<Element, Failure: Error>: AsyncSequen
 
     @inlinable
     internal func next() async throws -> Element? {
+      // The storage handles co-operative cancellation, so we don't bother checking here.
       return try await self._storage.consumeNextElement()
     }
   }

+ 22 - 17
Sources/GRPC/AsyncAwaitSupport/PassthroughMessageSource.swift

@@ -108,12 +108,14 @@ internal final class PassthroughMessageSource<Element, Failure: Error> {
     let result: _YieldResult = self._lock.withLock {
       if self._isTerminated {
         return .alreadyTerminated
-      } else if let continuation = self._continuation {
+      } else {
         self._isTerminated = isTerminator
+      }
+
+      if let continuation = self._continuation {
         self._continuation = nil
         return .resume(continuation)
       } else {
-        self._isTerminated = isTerminator
         self._continuationResults.append(continuationResult)
         return .queued(self._continuationResults.count)
       }
@@ -138,28 +140,31 @@ internal final class PassthroughMessageSource<Element, Failure: Error> {
 
   @inlinable
   internal func consumeNextElement() async throws -> Element? {
-    return try await withCheckedThrowingContinuation {
-      self._consumeNextElement(continuation: $0)
+    self._lock.lock()
+    if let nextResult = self._continuationResults.popFirst() {
+      self._lock.unlock()
+      return try nextResult.get()
+    } else if self._isTerminated {
+      self._lock.unlock()
+      return nil
     }
-  }
 
-  @inlinable
-  internal func _consumeNextElement(continuation: CheckedContinuation<Element?, Error>) {
-    let continuationResult: _ContinuationResult? = self._lock.withLock {
-      if let nextResult = self._continuationResults.popFirst() {
-        return nextResult
-      } else if self._isTerminated {
-        return .success(nil)
-      } else {
+    // Slow path; we need a continuation.
+    return try await withTaskCancellationHandler {
+      try await withCheckedThrowingContinuation { continuation in
         // Nothing buffered and not terminated yet: save the continuation for later.
         precondition(self._continuation == nil)
         self._continuation = continuation
-        return nil
+        self._lock.unlock()
+      }
+    } onCancel: {
+      let continuation: CheckedContinuation<Element?, Error>? = self._lock.withLock {
+        let cont = self._continuation
+        self._continuation = nil
+        return cont
       }
-    }
 
-    if let continuationResult = continuationResult {
-      continuation.resume(with: continuationResult)
+      continuation?.resume(throwing: CancellationError())
     }
   }
 }

+ 28 - 0
Tests/GRPCTests/AsyncAwaitSupport/AsyncWriterTests.swift

@@ -243,6 +243,34 @@ internal class AsyncWriterTests: GRPCTestCase {
     XCTAssertTrue(delegate.elements.isEmpty)
     XCTAssertNil(delegate.end)
   }
+
+  func testCooperativeCancellationOnWrite() async throws {
+    let delegate = CollectingDelegate<String, Void>()
+    let writer = AsyncWriter(isWritable: false, delegate: delegate)
+    try await withTaskCancelledAfter(nanoseconds: 100_000) {
+      do {
+        // Without co-operative cancellation then this will suspend indefinitely.
+        try await writer.write("I should be cancelled")
+        XCTFail("write(_:) should throw CancellationError")
+      } catch {
+        XCTAssert(error is CancellationError)
+      }
+    }
+  }
+
+  func testCooperativeCancellationOnFinish() async throws {
+    let delegate = CollectingDelegate<String, Void>()
+    let writer = AsyncWriter(isWritable: false, delegate: delegate)
+    try await withTaskCancelledAfter(nanoseconds: 100_000) {
+      do {
+        // Without co-operative cancellation then this will suspend indefinitely.
+        try await writer.finish()
+        XCTFail("finish() should throw CancellationError")
+      } catch {
+        XCTAssert(error is CancellationError)
+      }
+    }
+  }
 }
 
 fileprivate final class CollectingDelegate<

+ 27 - 0
Tests/GRPCTests/AsyncAwaitSupport/PassthroughMessageSourceTests.swift

@@ -126,6 +126,33 @@ class PassthroughMessageSourceTests: GRPCTestCase {
       }
     }
   }
+
+  func testCooperativeCancellationOfSourceOnNext() async throws {
+    let source = PassthroughMessageSource<String, TestError>()
+    try await withTaskCancelledAfter(nanoseconds: 100_000) {
+      do {
+        _ = try await source.consumeNextElement()
+        XCTFail("consumeNextElement() should throw CancellationError")
+      } catch {
+        XCTAssert(error is CancellationError)
+      }
+    }
+  }
+
+  func testCooperativeCancellationOfSequenceOnNext() async throws {
+    let source = PassthroughMessageSource<String, TestError>()
+    let sequence = PassthroughMessageSequence(consuming: source)
+    try await withTaskCancelledAfter(nanoseconds: 100_000) {
+      do {
+        for try await _ in sequence {
+          XCTFail("consumeNextElement() should throw CancellationError")
+        }
+        XCTFail("consumeNextElement() should throw CancellationError")
+      } catch {
+        XCTAssert(error is CancellationError)
+      }
+    }
+  }
 }
 
 fileprivate struct TestError: Error {}

+ 40 - 0
Tests/GRPCTests/AsyncAwaitSupport/XCTest+AsyncAwait.swift

@@ -31,4 +31,44 @@ internal func XCTAssertThrowsError<T>(
   }
 }
 
+fileprivate enum TaskResult<Result> {
+  case operation(Result)
+  case cancellation
+}
+
+@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
+func withTaskCancelledAfter<Result>(
+  nanoseconds: UInt64,
+  operation: @escaping @Sendable () async -> Result
+) async throws {
+  try await withThrowingTaskGroup(of: TaskResult<Result>.self) { group in
+    group.addTask {
+      return .operation(await operation())
+    }
+
+    group.addTask {
+      try await Task.sleep(nanoseconds: nanoseconds)
+      return .cancellation
+    }
+
+    // Only the sleeping task can throw if it's cancelled, in which case we want to throw.
+    let firstResult = try await group.next()
+    // A task completed, cancel the rest.
+    group.cancelAll()
+
+    // Check which task completed.
+    switch firstResult {
+    case .cancellation:
+      () // Fine, what we expect.
+    case .operation:
+      XCTFail("Operation completed before cancellation")
+    case .none:
+      XCTFail("No tasks completed")
+    }
+
+    // Wait for the other task. The operation cannot, only the sleeping task can.
+    try await group.waitForAll()
+  }
+}
+
 #endif // compiler(>=5.6)