Răsfoiți Sursa

Increase cancellation support for async calls (#1608)

Motivation:

The async client calls have limiteed support for cancellation: they
support it for the "wrapped" calls and request/response streams but not
for metadata/status on the lower level call objects.

Modifications:

- Add support for Task cancellation on the async call types

Result:

Better cancellation support
George Barnett 2 ani în urmă
părinte
comite
76d4ec1dff

+ 15 - 3
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncBidirectionalStreamingCall.swift

@@ -52,6 +52,12 @@ public struct GRPCAsyncBidirectionalStreamingCall<Request: Sendable, Response: S
 
   // MARK: - Response Parts
 
+  private func withRPCCancellation<R: Sendable>(_ fn: () async throws -> R) async rethrows -> R {
+    return try await withTaskCancellationHandler(operation: fn) {
+      self.cancel()
+    }
+  }
+
   /// The initial metadata returned from the server.
   ///
   /// - Important: The initial metadata will only be available when the first response has been
@@ -59,7 +65,9 @@ public struct GRPCAsyncBidirectionalStreamingCall<Request: Sendable, Response: S
   /// this property.
   public var initialMetadata: HPACKHeaders {
     get async throws {
-      try await self.responseParts.initialMetadata.get()
+      try await self.withRPCCancellation {
+        try await self.responseParts.initialMetadata.get()
+      }
     }
   }
 
@@ -68,7 +76,9 @@ public struct GRPCAsyncBidirectionalStreamingCall<Request: Sendable, Response: S
   /// - Important: Awaiting this property will suspend until the responses have been consumed.
   public var trailingMetadata: HPACKHeaders {
     get async throws {
-      try await self.responseParts.trailingMetadata.get()
+      try await self.withRPCCancellation {
+        try await self.responseParts.trailingMetadata.get()
+      }
     }
   }
 
@@ -78,7 +88,9 @@ public struct GRPCAsyncBidirectionalStreamingCall<Request: Sendable, Response: S
   public var status: GRPCStatus {
     get async {
       // force-try acceptable because any error is encapsulated in a successful GRPCStatus future.
-      try! await self.responseParts.status.get()
+      await self.withRPCCancellation {
+        try! await self.responseParts.status.get()
+      }
     }
   }
 

+ 18 - 4
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncClientStreamingCall.swift

@@ -43,19 +43,29 @@ public struct GRPCAsyncClientStreamingCall<Request: Sendable, Response: Sendable
 
   // MARK: - Response Parts
 
+  private func withRPCCancellation<R: Sendable>(_ fn: () async throws -> R) async rethrows -> R {
+    return try await withTaskCancellationHandler(operation: fn) {
+      self.cancel()
+    }
+  }
+
   /// The initial metadata returned from the server.
   ///
   /// - Important: The initial metadata will only be available when the response has been received.
   public var initialMetadata: HPACKHeaders {
     get async throws {
-      try await self.responseParts.initialMetadata.get()
+      return try await self.withRPCCancellation {
+        try await self.responseParts.initialMetadata.get()
+      }
     }
   }
 
   /// The response returned by the server.
   public var response: Response {
     get async throws {
-      try await self.responseParts.response.get()
+      return try await self.withRPCCancellation {
+        try await self.responseParts.response.get()
+      }
     }
   }
 
@@ -64,7 +74,9 @@ public struct GRPCAsyncClientStreamingCall<Request: Sendable, Response: Sendable
   /// - Important: Awaiting this property will suspend until the responses have been consumed.
   public var trailingMetadata: HPACKHeaders {
     get async throws {
-      try await self.responseParts.trailingMetadata.get()
+      return try await self.withRPCCancellation {
+        try await self.responseParts.trailingMetadata.get()
+      }
     }
   }
 
@@ -74,7 +86,9 @@ public struct GRPCAsyncClientStreamingCall<Request: Sendable, Response: Sendable
   public var status: GRPCStatus {
     get async {
       // force-try acceptable because any error is encapsulated in a successful GRPCStatus future.
-      try! await self.responseParts.status.get()
+      return await self.withRPCCancellation {
+        try! await self.responseParts.status.get()
+      }
     }
   }
 

+ 15 - 3
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerStreamingCall.swift

@@ -49,6 +49,12 @@ public struct GRPCAsyncServerStreamingCall<Request: Sendable, Response: Sendable
 
   // MARK: - Response Parts
 
+  private func withRPCCancellation<R: Sendable>(_ fn: () async throws -> R) async rethrows -> R {
+    return try await withTaskCancellationHandler(operation: fn) {
+      self.cancel()
+    }
+  }
+
   /// The initial metadata returned from the server.
   ///
   /// - Important: The initial metadata will only be available when the first response has been
@@ -56,7 +62,9 @@ public struct GRPCAsyncServerStreamingCall<Request: Sendable, Response: Sendable
   /// this property.
   public var initialMetadata: HPACKHeaders {
     get async throws {
-      try await self.responseParts.initialMetadata.get()
+      try await self.withRPCCancellation {
+        try await self.responseParts.initialMetadata.get()
+      }
     }
   }
 
@@ -65,7 +73,9 @@ public struct GRPCAsyncServerStreamingCall<Request: Sendable, Response: Sendable
   /// - Important: Awaiting this property will suspend until the responses have been consumed.
   public var trailingMetadata: HPACKHeaders {
     get async throws {
-      try await self.responseParts.trailingMetadata.get()
+      try await self.withRPCCancellation {
+        try await self.responseParts.trailingMetadata.get()
+      }
     }
   }
 
@@ -75,7 +85,9 @@ public struct GRPCAsyncServerStreamingCall<Request: Sendable, Response: Sendable
   public var status: GRPCStatus {
     get async {
       // force-try acceptable because any error is encapsulated in a successful GRPCStatus future.
-      try! await self.responseParts.status.get()
+      await self.withRPCCancellation {
+        try! await self.responseParts.status.get()
+      }
     }
   }
 

+ 18 - 4
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncUnaryCall.swift

@@ -41,12 +41,20 @@ public struct GRPCAsyncUnaryCall<Request: Sendable, Response: Sendable>: Sendabl
 
   // MARK: - Response Parts
 
+  private func withRPCCancellation<R: Sendable>(_ fn: () async throws -> R) async rethrows -> R {
+    return try await withTaskCancellationHandler(operation: fn) {
+      self.cancel()
+    }
+  }
+
   /// The initial metadata returned from the server.
   ///
   /// - Important: The initial metadata will only be available when the response has been received.
   public var initialMetadata: HPACKHeaders {
     get async throws {
-      try await self.responseParts.initialMetadata.get()
+      try await self.withRPCCancellation {
+        try await self.responseParts.initialMetadata.get()
+      }
     }
   }
 
@@ -56,7 +64,9 @@ public struct GRPCAsyncUnaryCall<Request: Sendable, Response: Sendable>: Sendabl
   /// Callers should rely on the `status` of the call for the canonical outcome.
   public var response: Response {
     get async throws {
-      try await self.responseParts.response.get()
+      try await self.withRPCCancellation {
+        try await self.responseParts.response.get()
+      }
     }
   }
 
@@ -65,7 +75,9 @@ public struct GRPCAsyncUnaryCall<Request: Sendable, Response: Sendable>: Sendabl
   /// - Important: Awaiting this property will suspend until the responses have been consumed.
   public var trailingMetadata: HPACKHeaders {
     get async throws {
-      try await self.responseParts.trailingMetadata.get()
+      try await self.withRPCCancellation {
+        try await self.responseParts.trailingMetadata.get()
+      }
     }
   }
 
@@ -75,7 +87,9 @@ public struct GRPCAsyncUnaryCall<Request: Sendable, Response: Sendable>: Sendabl
   public var status: GRPCStatus {
     get async {
       // force-try acceptable because any error is encapsulated in a successful GRPCStatus future.
-      try! await self.responseParts.status.get()
+      await self.withRPCCancellation {
+        try! await self.responseParts.status.get()
+      }
     }
   }
 

+ 122 - 0
Tests/GRPCTests/AsyncAwaitSupport/AsyncClientTests.swift

@@ -414,4 +414,126 @@ final class AsyncClientCancellationTests: GRPCTestCase {
       XCTAssertFalse(error is CancellationError)
     }
   }
+
+  func testCancelUnary() async throws {
+    // We don't want the RPC to complete before we cancel it so use the never resolving service.
+    let echo = try self.startServerAndClient(service: NeverResolvingEchoProvider())
+
+    do {
+      let get = echo.makeGetCall(.with { $0.text = "foo bar baz" })
+      let task = Task { try await get.initialMetadata }
+      task.cancel()
+      await XCTAssertThrowsError(try await task.value)
+    }
+
+    do {
+      let get = echo.makeGetCall(.with { $0.text = "foo bar baz" })
+      let task = Task { try await get.response }
+      task.cancel()
+      await XCTAssertThrowsError(try await task.value)
+    }
+
+    do {
+      let get = echo.makeGetCall(.with { $0.text = "foo bar baz" })
+      let task = Task { try await get.trailingMetadata }
+      task.cancel()
+      await XCTAssertNoThrowAsync(try await task.value)
+    }
+
+    do {
+      let get = echo.makeGetCall(.with { $0.text = "foo bar baz" })
+      let task = Task { await get.status }
+      task.cancel()
+      let status = await task.value
+      XCTAssertEqual(status.code, .cancelled)
+    }
+  }
+
+  func testCancelClientStreaming() async throws {
+    // We don't want the RPC to complete before we cancel it so use the never resolving service.
+    let echo = try self.startServerAndClient(service: NeverResolvingEchoProvider())
+
+    do {
+      let collect = echo.makeCollectCall()
+      let task = Task { try await collect.initialMetadata }
+      task.cancel()
+      await XCTAssertThrowsError(try await task.value)
+    }
+
+    do {
+      let collect = echo.makeCollectCall()
+      let task = Task { try await collect.response }
+      task.cancel()
+      await XCTAssertThrowsError(try await task.value)
+    }
+
+    do {
+      let collect = echo.makeCollectCall()
+      let task = Task { try await collect.trailingMetadata }
+      task.cancel()
+      await XCTAssertNoThrowAsync(try await task.value)
+    }
+
+    do {
+      let collect = echo.makeCollectCall()
+      let task = Task { await collect.status }
+      task.cancel()
+      let status = await task.value
+      XCTAssertEqual(status.code, .cancelled)
+    }
+  }
+
+  func testCancelServerStreaming() async throws {
+    // We don't want the RPC to complete before we cancel it so use the never resolving service.
+    let echo = try self.startServerAndClient(service: NeverResolvingEchoProvider())
+
+    do {
+      let expand = echo.makeExpandCall(.with { $0.text = "foo bar baz" })
+      let task = Task { try await expand.initialMetadata }
+      task.cancel()
+      await XCTAssertThrowsError(try await task.value)
+    }
+
+    do {
+      let expand = echo.makeExpandCall(.with { $0.text = "foo bar baz" })
+      let task = Task { try await expand.trailingMetadata }
+      task.cancel()
+      await XCTAssertNoThrowAsync(try await task.value)
+    }
+
+    do {
+      let expand = echo.makeExpandCall(.with { $0.text = "foo bar baz" })
+      let task = Task { await expand.status }
+      task.cancel()
+      let status = await task.value
+      XCTAssertEqual(status.code, .cancelled)
+    }
+  }
+
+  func testCancelBidirectionalStreaming() async throws {
+    // We don't want the RPC to complete before we cancel it so use the never resolving service.
+    let echo = try self.startServerAndClient(service: NeverResolvingEchoProvider())
+
+    do {
+      let update = echo.makeUpdateCall()
+      let task = Task { try await update.initialMetadata }
+      task.cancel()
+      await XCTAssertThrowsError(try await task.value)
+    }
+
+    do {
+      let update = echo.makeUpdateCall()
+      let task = Task { try await update.trailingMetadata }
+      task.cancel()
+      await XCTAssertNoThrowAsync(try await task.value)
+    }
+
+    do {
+      let update = echo.makeUpdateCall()
+      let task = Task { await update.status }
+      task.cancel()
+      let status = await task.value
+      XCTAssertEqual(status.code, .cancelled)
+    }
+  }
 }

+ 13 - 0
Tests/GRPCTests/AsyncAwaitSupport/XCTest+AsyncAwait.swift

@@ -30,6 +30,19 @@ internal func XCTAssertThrowsError<T>(
   }
 }
 
+@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
+internal func XCTAssertNoThrowAsync<T>(
+  _ expression: @autoclosure () async throws -> T,
+  file: StaticString = #filePath,
+  line: UInt = #line
+) async {
+  do {
+    _ = try await expression()
+  } catch {
+    XCTFail("Expression throw error '\(error)'", file: file, line: line)
+  }
+}
+
 private enum TaskResult<Result> {
   case operation(Result)
   case cancellation

+ 1 - 1
Tests/GRPCTests/InterceptedRPCCancellationTests.swift

@@ -33,7 +33,7 @@ final class InterceptedRPCCancellationTests: GRPCTestCase {
     }
 
     // Interceptor checks that a "magic" header is present.
-    let serverInterceptors = EchoServerInterceptors(MagicRequiredServerInterceptor.init)
+    let serverInterceptors = EchoServerInterceptors({ MagicRequiredServerInterceptor() })
     let server = try Server.insecure(group: group)
       .withLogger(self.serverLogger)
       .withServiceProviders([EchoProvider(interceptors: serverInterceptors)])