Bladeren bron

Make client RPC cancellation non-async (#1413)

Motivation:

Cancelling an RPC as a client is done via a cancellation function on the
call which `throws` and is also `async`. That it `throws` is odd as an
API because by virtue of cancelling we no longer care about the result
so it doesn't matter if cancellation was successful or not (i.e. if the
RPC was already cancelled). That the cancellation is `async` does not
fit with task cancellation handlers which are also not `async`.

Modifications:

- Make all async call `cancel()` sync and non throwing.
- Add tests for the 'wrapped' RPCs (and fix some bugs along the way)

Result:

Cancelling RPCs from the client is async and more robust.
George Barnett 3 jaren geleden
bovenliggende
commit
938d1413ec

+ 2 - 2
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncBidirectionalStreamingCall.swift

@@ -36,8 +36,8 @@ public struct GRPCAsyncBidirectionalStreamingCall<Request: Sendable, Response: S
   }
 
   /// Cancel this RPC if it hasn't already completed.
-  public func cancel() async throws {
-    try await self.call.cancel().get()
+  public func cancel() {
+    self.call.cancel(promise: nil)
   }
 
   // MARK: - Response Parts

+ 2 - 2
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncClientStreamingCall.swift

@@ -32,8 +32,8 @@ public struct GRPCAsyncClientStreamingCall<Request: Sendable, Response: Sendable
   }
 
   /// Cancel this RPC if it hasn't already completed.
-  public func cancel() async throws {
-    try await self.call.cancel().get()
+  public func cancel() {
+    self.call.cancel(promise: nil)
   }
 
   // MARK: - Response Parts

+ 2 - 1
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncResponseStream.swift

@@ -44,7 +44,8 @@ public struct GRPCAsyncResponseStream<Element>: AsyncSequence {
 
     @inlinable
     public mutating func next() async throws -> Element? {
-      try await self.iterator.next()
+      if Task.isCancelled { throw GRPCStatus(code: .cancelled) }
+      return try await self.iterator.next()
     }
   }
 }

+ 2 - 2
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerStreamingCall.swift

@@ -33,8 +33,8 @@ public struct GRPCAsyncServerStreamingCall<Request: Sendable, Response: Sendable
   }
 
   /// Cancel this RPC if it hasn't already completed.
-  public func cancel() async throws {
-    try await self.call.cancel().get()
+  public func cancel() {
+    self.call.cancel(promise: nil)
   }
 
   // MARK: - Response Parts

+ 2 - 2
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncUnaryCall.swift

@@ -32,8 +32,8 @@ public struct GRPCAsyncUnaryCall<Request: Sendable, Response: Sendable>: Sendabl
   }
 
   /// Cancel this RPC if it hasn't already completed.
-  public func cancel() async throws {
-    try await self.call.cancel().get()
+  public func cancel() {
+    self.call.cancel(promise: nil)
   }
 
   // MARK: - Response Parts

+ 46 - 37
Sources/GRPC/AsyncAwaitSupport/GRPCClient+AsyncAwaitSupport.swift

@@ -165,12 +165,18 @@ extension GRPCClient {
     interceptors: [ClientInterceptor<Request, Response>] = [],
     responseType: Response.Type = Response.self
   ) async throws -> Response {
-    return try await self.channel.makeAsyncUnaryCall(
+    let call = self.channel.makeAsyncUnaryCall(
       path: path,
       request: request,
       callOptions: callOptions ?? self.defaultCallOptions,
       interceptors: interceptors
-    ).response
+    )
+
+    return try await withTaskCancellationHandler {
+      try await call.response
+    } onCancel: {
+      call.cancel()
+    }
   }
 
   public func performAsyncUnaryCall<
@@ -183,12 +189,18 @@ extension GRPCClient {
     interceptors: [ClientInterceptor<Request, Response>] = [],
     responseType: Response.Type = Response.self
   ) async throws -> Response {
-    return try await self.channel.makeAsyncUnaryCall(
+    let call = self.channel.makeAsyncUnaryCall(
       path: path,
       request: request,
       callOptions: callOptions ?? self.defaultCallOptions,
       interceptors: interceptors
-    ).response
+    )
+
+    return try await withTaskCancellationHandler {
+      try await call.response
+    } onCancel: {
+      call.cancel()
+    }
   }
 
   public func performAsyncServerStreamingCall<
@@ -401,30 +413,25 @@ extension GRPCClient {
     _ call: GRPCAsyncClientStreamingCall<Request, Response>,
     with requests: RequestStream
   ) async throws -> Response where RequestStream.Element == Request {
-    // We use a detached task because we use cancellation to signal early, but successful exit.
-    let requestsTask = Task.detached {
-      try Task.checkCancellation()
-      for try await request in requests {
-        try Task.checkCancellation()
-        try await call.requestStream.send(request)
-      }
-      try Task.checkCancellation()
-      try await call.requestStream.finish()
-      try Task.checkCancellation()
-    }
     return try await withTaskCancellationHandler {
-      // Await the response, which may come before the request stream is exhausted.
-      let response = try await call.response
-      // If we have a response, we can stop sending requests.
-      requestsTask.cancel()
-      // Return the response.
-      return response
-    } onCancel: {
-      requestsTask.cancel()
-      // If this outer task is cancelled then we should also cancel the RPC.
-      Task.detached {
-        try await call.cancel()
+      Task {
+        do {
+          // `AsyncSequence`s are encouraged to co-operatively check for cancellation, and we will
+          // cancel the call `onCancel` anyway, so there's no need to check here too.
+          for try await request in requests {
+            try await call.requestStream.send(request)
+          }
+          try await call.requestStream.finish()
+        } catch {
+          // If we throw then cancel the call. We will rely on the response throwing an appropriate
+          // error below.
+          call.cancel()
+        }
       }
+
+      return try await call.response
+    } onCancel: {
+      call.cancel()
     }
   }
 
@@ -438,20 +445,22 @@ extension GRPCClient {
     with requests: RequestStream
   ) -> GRPCAsyncResponseStream<Response> where RequestStream.Element == Request {
     Task {
-      try await withTaskCancellationHandler {
-        try Task.checkCancellation()
-        for try await request in requests {
-          try Task.checkCancellation()
-          try await call.requestStream.send(request)
-        }
-        try Task.checkCancellation()
-        try await call.requestStream.finish()
-      } onCancel: {
-        Task.detached {
-          try await call.cancel()
+      do {
+        try await withTaskCancellationHandler {
+          // `AsyncSequence`s are encouraged to co-operatively check for cancellation, and we will
+          // cancel the call `onCancel` anyway, so there's no need to check here too.
+          for try await request in requests {
+            try await call.requestStream.send(request)
+          }
+          try await call.requestStream.finish()
+        } onCancel: {
+          call.cancel()
         }
+      } catch {
+        call.cancel()
       }
     }
+
     return call.responseStream
   }
 }

+ 4 - 1
Sources/GRPC/AsyncAwaitSupport/PassthroughMessageSource.swift

@@ -109,6 +109,7 @@ internal final class PassthroughMessageSource<Element, Failure: Error> {
       if self._isTerminated {
         return .alreadyTerminated
       } else if let continuation = self._continuation {
+        self._isTerminated = isTerminator
         self._continuation = nil
         return .resume(continuation)
       } else {
@@ -147,9 +148,11 @@ internal final class PassthroughMessageSource<Element, Failure: Error> {
     let continuationResult: _ContinuationResult? = self._lock.withLock {
       if let nextResult = self._continuationResults.popFirst() {
         return nextResult
+      } else if self._isTerminated {
+        return .success(nil)
       } else {
         // Nothing buffered and not terminated yet: save the continuation for later.
-        assert(self._continuation == nil)
+        precondition(self._continuation == nil)
         self._continuation = continuation
         return nil
       }

+ 134 - 16
Tests/GRPCTests/AsyncAwaitSupport/AsyncClientTests.swift

@@ -83,40 +83,104 @@ final class AsyncClientCancellationTests: GRPCTestCase {
     let echo = try self.startServerAndClient(service: NeverResolvingEchoProvider())
 
     let get = echo.makeGetCall(.with { $0.text = "foo bar baz" })
-    try await get.cancel()
-
-    await XCTAssertThrowsError(try await get.response)
+    get.cancel()
+
+    do {
+      _ = try await get.response
+      XCTFail("Expected to throw a status with code .cancelled")
+    } catch let status as GRPCStatus {
+      XCTAssertEqual(status.code, .cancelled)
+    } catch {
+      XCTFail("Expected to throw a status with code .cancelled")
+    }
 
     // Status should be 'cancelled'.
     let status = await get.status
     XCTAssertEqual(status.code, .cancelled)
   }
 
+  func testCancelFailsUnaryResponseForWrappedCall() async throws {
+    // We don't want the RPC to complete before we cancel it so use the never resolving service.
+    let echo = try self.startServerAndClient(service: NeverResolvingEchoProvider())
+
+    let task = Task {
+      try await echo.get(.with { $0.text = "I'll be cancelled" })
+    }
+
+    task.cancel()
+
+    do {
+      _ = try await task.value
+      XCTFail("Expected to throw a status with code .cancelled")
+    } catch let status as GRPCStatus {
+      XCTAssertEqual(status.code, .cancelled)
+    } catch {
+      XCTFail("Expected to throw a status with code .cancelled")
+    }
+  }
+
   func testCancelServerStreamingClosesResponseStream() async throws {
     // We don't want the RPC to complete before we cancel it so use the never resolving service.
     let echo = try self.startServerAndClient(service: NeverResolvingEchoProvider())
 
     let expand = echo.makeExpandCall(.with { $0.text = "foo bar baz" })
-    try await expand.cancel()
+    expand.cancel()
 
     var responseStream = expand.responseStream.makeAsyncIterator()
-    await XCTAssertThrowsError(try await responseStream.next())
+
+    do {
+      _ = try await responseStream.next()
+      XCTFail("Expected to throw a status with code .cancelled")
+    } catch let status as GRPCStatus {
+      XCTAssertEqual(status.code, .cancelled)
+    } catch {
+      XCTFail("Expected to throw a status with code .cancelled")
+    }
 
     // Status should be 'cancelled'.
     let status = await expand.status
     XCTAssertEqual(status.code, .cancelled)
   }
 
+  func testCancelServerStreamingClosesResponseStreamForWrappedCall() async throws {
+    // We don't want the RPC to complete before we cancel it so use the never resolving service.
+    let echo = try self.startServerAndClient(service: NeverResolvingEchoProvider())
+
+    let task = Task {
+      let responseStream = echo.expand(.with { $0.text = "foo bar baz" })
+      var responseIterator = responseStream.makeAsyncIterator()
+      do {
+        _ = try await responseIterator.next()
+        XCTFail("Expected to throw a status with code .cancelled")
+      } catch let status as GRPCStatus {
+        XCTAssertEqual(status.code, .cancelled)
+      } catch {
+        XCTFail("Expected to throw a status with code .cancelled")
+      }
+    }
+
+    task.cancel()
+    await task.value
+  }
+
   func testCancelClientStreamingClosesRequestStreamAndFailsResponse() async throws {
     let echo = try self.startServerAndClient(service: EchoProvider())
 
     let collect = echo.makeCollectCall()
     // Make sure the stream is up before we cancel it.
     try await collect.requestStream.send(.with { $0.text = "foo" })
-    try await collect.cancel()
+    collect.cancel()
+
+    // Cancellation is async so loop until we error.
+    while true {
+      do {
+        try await collect.requestStream.send(.with { $0.text = "foo" })
+        try await Task.sleep(nanoseconds: 1000)
+      } catch {
+        break
+      }
+    }
 
-    // The next send should fail.
-    await XCTAssertThrowsError(try await collect.requestStream.send(.with { $0.text = "foo" }))
     // There should be no response.
     await XCTAssertThrowsError(try await collect.response)
 
@@ -125,6 +189,29 @@ final class AsyncClientCancellationTests: GRPCTestCase {
     XCTAssertEqual(status.code, .cancelled)
   }
 
+  func testCancelClientStreamingClosesRequestStreamAndFailsResponseForWrappedCall() async throws {
+    let echo = try self.startServerAndClient(service: NeverResolvingEchoProvider())
+    let requests = (0 ..< 10).map { i in
+      Echo_EchoRequest.with {
+        $0.text = String(i)
+      }
+    }
+
+    let task = Task {
+      do {
+        let _ = try await echo.collect(requests)
+        XCTFail("Expected to throw a status with code .cancelled")
+      } catch let status as GRPCStatus {
+        XCTAssertEqual(status.code, .cancelled)
+      } catch {
+        XCTFail("Expected to throw a status with code .cancelled")
+      }
+    }
+
+    task.cancel()
+    await task.value
+  }
+
   func testClientStreamingClosesRequestStreamOnEnd() async throws {
     let echo = try self.startServerAndClient(service: EchoProvider())
 
@@ -154,16 +241,49 @@ final class AsyncClientCancellationTests: GRPCTestCase {
     var responseStream = update.responseStream.makeAsyncIterator()
     _ = try await responseStream.next()
 
-    // Now cancel. The next send should fail and we shouldn't receive any more responses.
-    try await update.cancel()
-    await XCTAssertThrowsError(try await update.requestStream.send(.with { $0.text = "foo" }))
-    await XCTAssertThrowsError(try await responseStream.next())
+    update.cancel()
+
+    // Cancellation is async so loop until we error.
+    while true {
+      do {
+        try await update.requestStream.send(.with { $0.text = "foo" })
+        try await Task.sleep(nanoseconds: 1000)
+      } catch {
+        break
+      }
+    }
 
     // Status should be 'cancelled'.
     let status = await update.status
     XCTAssertEqual(status.code, .cancelled)
   }
 
+  func testCancelBidiStreamingClosesRequestStreamAndResponseStreamForWrappedCall() async throws {
+    let echo = try self.startServerAndClient(service: EchoProvider())
+    let requests = (0 ..< 10).map { i in
+      Echo_EchoRequest.with {
+        $0.text = String(i)
+      }
+    }
+
+    let task = Task {
+      let responseStream = echo.update(requests)
+      var responseIterator = responseStream.makeAsyncIterator()
+
+      do {
+        _ = try await responseIterator.next()
+        XCTFail("Expected to throw a status with code .cancelled")
+      } catch let status as GRPCStatus {
+        XCTAssertEqual(status.code, .cancelled)
+      } catch {
+        XCTFail("Expected to throw a status with code .cancelled")
+      }
+    }
+
+    task.cancel()
+    await task.value
+  }
+
   func testBidiStreamingClosesRequestStreamOnEnd() async throws {
     let echo = try self.startServerAndClient(service: EchoProvider())
 
@@ -204,11 +324,9 @@ final class AsyncClientCancellationTests: GRPCTestCase {
     func cancel() {
       switch self {
       case let .clientStreaming(call):
-        // TODO: this should be async
-        Task { try await call.cancel() }
+        call.cancel()
       case let .bidirectionalStreaming(call):
-        // TODO: this should be async
-        Task { try await call.cancel() }
+        call.cancel()
       }
     }
   }