Browse Source

Add an RPC cancellation handler (#2090)

Motivation:

As a service author it's useful to know if the RPC has been cancelled
(because it's timed out, the remote peer closed it, the connection
dropped etc).

For cases where the stream has already closed this can be surfaced by a
read or write failing. However, for cases like server-streaming RPCs
where there are no reads and writes can be infrequent it's useful to
have a more explicit signal.

Modifications:

- Add a `ServerCancellationManager`, this is internal per-stream storage
for registering cancellation handlers and storing whether the RPC has
been cancelled.
- Add the `RPCCancellationHandle` nested within the `ServerContext`.
This holds an instance of the manager and provides higher level APIs
allowing users to check if the RPC has been cancellation and to wait
until the RPC has been cancelled.
- Add a top-level `withRPCCancellationHandler` which registers a
callback with the manager.
- Add a top-level `withServerContextRPCCancellationHandle` for creating
and binding the task local manager. This is intended for use by
transport implementations rather than users.
- Update the in-process transport to cancel RPCs when shutting down
gracefully.
- Update the server executor to cancel RPCs when the timeout fires.

Result:

Users can watch for cancellation using `withRPCCancellationHandler`.
George Barnett 1 year ago
parent
commit
945cbf6993

+ 254 - 0
Sources/GRPCCore/Call/Server/Internal/ServerCancellationManager.swift

@@ -0,0 +1,254 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+private import Synchronization
+
+/// Stores cancellation state for an RPC on the server .
+package final class ServerCancellationManager: Sendable {
+  private let state: Mutex<State>
+
+  package init() {
+    self.state = Mutex(State())
+  }
+
+  /// Returns whether the RPC has been marked as cancelled.
+  package var isRPCCancelled: Bool {
+    self.state.withLock {
+      return $0.isRPCCancelled
+    }
+  }
+
+  /// Marks the RPC as cancelled, potentially running any cancellation handlers.
+  package func cancelRPC() {
+    switch self.state.withLock({ $0.cancelRPC() }) {
+    case .executeAndResume(let onCancelHandlers, let onCancelWaiters):
+      for handler in onCancelHandlers {
+        handler.handler()
+      }
+
+      for onCancelWaiter in onCancelWaiters {
+        switch onCancelWaiter {
+        case .taskCancelled:
+          ()
+        case .waiting(_, let continuation):
+          continuation.resume(returning: .rpc)
+        }
+      }
+
+    case .doNothing:
+      ()
+    }
+  }
+
+  /// Adds a handler which is invoked when the RPC is cancelled.
+  ///
+  /// - Returns: The ID of the handler, if it was added, or `nil` if the RPC is already cancelled.
+  package func addRPCCancelledHandler(_ handler: @Sendable @escaping () -> Void) -> UInt64? {
+    return self.state.withLock { state -> UInt64? in
+      state.addRPCCancelledHandler(handler)
+    }
+  }
+
+  /// Removes a handler by its ID.
+  package func removeRPCCancelledHandler(withID id: UInt64) {
+    self.state.withLock { state in
+      state.removeRPCCancelledHandler(withID: id)
+    }
+  }
+
+  /// Suspends until the RPC is cancelled or the `Task` is cancelled.
+  package func suspendUntilRPCIsCancelled() async throws(CancellationError) {
+    let id = self.state.withLock { $0.nextID() }
+
+    let source = await withTaskCancellationHandler {
+      await withCheckedContinuation { continuation in
+        let onAddWaiter = self.state.withLock {
+          $0.addRPCIsCancelledWaiter(continuation: continuation, withID: id)
+        }
+
+        switch onAddWaiter {
+        case .doNothing:
+          ()
+        case .complete(let continuation, let result):
+          continuation.resume(returning: result)
+        }
+      }
+    } onCancel: {
+      switch self.state.withLock({ $0.cancelRPCCancellationWaiter(withID: id) }) {
+      case .resume(let continuation, let result):
+        continuation.resume(returning: result)
+      case .doNothing:
+        ()
+      }
+    }
+
+    switch source {
+    case .rpc:
+      ()
+    case .task:
+      throw CancellationError()
+    }
+  }
+}
+
+extension ServerCancellationManager {
+  enum CancellationSource {
+    case rpc
+    case task
+  }
+
+  struct Handler: Sendable {
+    var id: UInt64
+    var handler: @Sendable () -> Void
+  }
+
+  enum Waiter: Sendable {
+    case waiting(UInt64, CheckedContinuation<CancellationSource, Never>)
+    case taskCancelled(UInt64)
+
+    var id: UInt64 {
+      switch self {
+      case .waiting(let id, _):
+        return id
+      case .taskCancelled(let id):
+        return id
+      }
+    }
+  }
+
+  struct State {
+    private var handlers: [Handler]
+    private var waiters: [Waiter]
+    private var _nextID: UInt64
+    var isRPCCancelled: Bool
+
+    mutating func nextID() -> UInt64 {
+      let id = self._nextID
+      self._nextID &+= 1
+      return id
+    }
+
+    init() {
+      self.handlers = []
+      self.waiters = []
+      self._nextID = 0
+      self.isRPCCancelled = false
+    }
+
+    mutating func cancelRPC() -> OnCancelRPC {
+      let onCancel: OnCancelRPC
+
+      if self.isRPCCancelled {
+        onCancel = .doNothing
+      } else {
+        self.isRPCCancelled = true
+        onCancel = .executeAndResume(self.handlers, self.waiters)
+        self.handlers = []
+        self.waiters = []
+      }
+
+      return onCancel
+    }
+
+    mutating func addRPCCancelledHandler(_ handler: @Sendable @escaping () -> Void) -> UInt64? {
+      if self.isRPCCancelled {
+        handler()
+        return nil
+      } else {
+        let id = self.nextID()
+        self.handlers.append(.init(id: id, handler: handler))
+        return id
+      }
+    }
+
+    mutating func removeRPCCancelledHandler(withID id: UInt64) {
+      if let index = self.handlers.firstIndex(where: { $0.id == id }) {
+        self.handlers.remove(at: index)
+      }
+    }
+
+    enum OnCancelRPC {
+      case executeAndResume([Handler], [Waiter])
+      case doNothing
+    }
+
+    enum OnAddWaiter {
+      case complete(CheckedContinuation<CancellationSource, Never>, CancellationSource)
+      case doNothing
+    }
+
+    mutating func addRPCIsCancelledWaiter(
+      continuation: CheckedContinuation<CancellationSource, Never>,
+      withID id: UInt64
+    ) -> OnAddWaiter {
+      let onAddWaiter: OnAddWaiter
+
+      if self.isRPCCancelled {
+        onAddWaiter = .complete(continuation, .rpc)
+      } else if let index = self.waiters.firstIndex(where: { $0.id == id }) {
+        switch self.waiters[index] {
+        case .taskCancelled:
+          onAddWaiter = .complete(continuation, .task)
+        case .waiting:
+          // There's already a continuation enqueued.
+          fatalError("Inconsistent state")
+        }
+      } else {
+        self.waiters.append(.waiting(id, continuation))
+        onAddWaiter = .doNothing
+      }
+
+      return onAddWaiter
+    }
+
+    enum OnCancelRPCCancellationWaiter {
+      case resume(CheckedContinuation<CancellationSource, Never>, CancellationSource)
+      case doNothing
+    }
+
+    mutating func cancelRPCCancellationWaiter(withID id: UInt64) -> OnCancelRPCCancellationWaiter {
+      let onCancelWaiter: OnCancelRPCCancellationWaiter
+
+      if let index = self.waiters.firstIndex(where: { $0.id == id }) {
+        let waiter = self.waiters.removeWithoutMaintainingOrder(at: index)
+        switch waiter {
+        case .taskCancelled:
+          onCancelWaiter = .doNothing
+        case .waiting(_, let continuation):
+          onCancelWaiter = .resume(continuation, .task)
+        }
+      } else {
+        self.waiters.append(.taskCancelled(id))
+        onCancelWaiter = .doNothing
+      }
+
+      return onCancelWaiter
+    }
+  }
+}
+
+extension Array {
+  fileprivate mutating func removeWithoutMaintainingOrder(at index: Int) -> Element {
+    let lastElementIndex = self.index(before: self.endIndex)
+
+    if index == lastElementIndex {
+      return self.remove(at: index)
+    } else {
+      self.swapAt(index, lastElementIndex)
+      return self.removeLast()
+    }
+  }
+}

+ 17 - 31
Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift

@@ -119,43 +119,29 @@ struct ServerRPCExecutor {
       _ context: ServerContext
     ) async throws -> StreamingServerResponse<Output>
   ) async {
-    await withTaskGroup(of: ServerExecutorTask.self) { group in
+    await withTaskGroup(of: Void.self) { group in
       group.addTask {
-        let result = await Result {
+        do {
           try await Task.sleep(for: timeout, clock: .continuous)
+          context.cancellation.cancel()
+        } catch {
+          ()  // Only cancel the RPC if the timeout completes.
         }
-        return .timedOut(result)
       }
 
-      group.addTask {
-        await Self._processRPC(
-          context: context,
-          metadata: metadata,
-          inbound: inbound,
-          outbound: outbound,
-          deserializer: deserializer,
-          serializer: serializer,
-          interceptors: interceptors,
-          handler: handler
-        )
-        return .executed
-      }
-
-      while let next = await group.next() {
-        switch next {
-        case .timedOut(.success):
-          // Timeout expired; cancel the work.
-          group.cancelAll()
-
-        case .timedOut(.failure):
-          // Timeout failed (because it was cancelled). Wait for more tasks to finish.
-          ()
+      await Self._processRPC(
+        context: context,
+        metadata: metadata,
+        inbound: inbound,
+        outbound: outbound,
+        deserializer: deserializer,
+        serializer: serializer,
+        interceptors: interceptors,
+        handler: handler
+      )
 
-        case .executed:
-          // The work finished. Cancel any remaining tasks.
-          group.cancelAll()
-        }
-      }
+      // Cancel the timeout
+      group.cancelAll()
     }
   }
 

+ 117 - 0
Sources/GRPCCore/Call/Server/ServerContext+RPCCancellationHandle.swift

@@ -0,0 +1,117 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import Synchronization
+
+extension ServerContext {
+  @TaskLocal
+  internal static var rpcCancellation: RPCCancellationHandle?
+
+  /// A handle for the cancellation status of the RPC.
+  public struct RPCCancellationHandle: Sendable {
+    internal let manager: ServerCancellationManager
+
+    /// Create a cancellation handle.
+    ///
+    /// To create an instance of this handle appropriately bound to a `Task`
+    /// use ``withServerContextRPCCancellationHandle(_:)``.
+    public init() {
+      self.manager = ServerCancellationManager()
+    }
+
+    /// Returns whether the RPC has been cancelled.
+    public var isCancelled: Bool {
+      self.manager.isRPCCancelled
+    }
+
+    /// Waits until the RPC has been cancelled.
+    ///
+    /// Throws a `CancellationError` if the `Task` is cancelled.
+    ///
+    /// You can also be notified when an RPC is cancelled by using
+    /// ``withRPCCancellationHandler(operation:onCancelRPC:)``.
+    public var cancelled: Void {
+      get async throws {
+        try await self.manager.suspendUntilRPCIsCancelled()
+      }
+    }
+
+    /// Signal that the RPC should be cancelled.
+    ///
+    /// This is idempotent: calling it more than once has no effect.
+    public func cancel() {
+      self.manager.cancelRPC()
+    }
+  }
+}
+
+/// Execute an operation with an RPC cancellation handler that's immediately invoked
+/// if the RPC is canceled.
+///
+/// RPCs can be cancelled for a number of reasons including:
+/// 1. The RPC was taking too long to process and a timeout passed.
+/// 2. The remote peer closed the underlying stream, either because they were no longer
+///    interested in the result or due to a broken connection.
+/// 3. The server began shutting down.
+///
+/// - Important: This only applies to RPCs on the server.
+/// - Parameters:
+///   - operation: The operation to execute.
+///   - handler: The handler which is invoked when the RPC is cancelled.
+/// - Throws: Any error thrown by the `operation` closure.
+/// - Returns: The result of the `operation` closure.
+public func withRPCCancellationHandler<Result, Failure: Error>(
+  operation: () async throws(Failure) -> Result,
+  onCancelRPC handler: @Sendable @escaping () -> Void
+) async throws(Failure) -> Result {
+  guard let manager = ServerContext.rpcCancellation?.manager,
+    let id = manager.addRPCCancelledHandler(handler)
+  else {
+    return try await operation()
+  }
+
+  defer {
+    manager.removeRPCCancelledHandler(withID: id)
+  }
+
+  return try await operation()
+}
+
+/// Provides scoped access to a server RPC cancellation handle.
+///
+/// The cancellation handle should be passed to a ``ServerContext`` and last
+/// the duration of the RPC.
+///
+/// - Important: This function is intended for use when implementing
+///   a ``ServerTransport``.
+///
+/// If you want to be notified about RPCs being cancelled
+/// use ``withRPCCancellationHandler(operation:onCancelRPC:)``.
+///
+/// - Parameter operation: The operation to execute with the handle.
+public func withServerContextRPCCancellationHandle<Success, Failure: Error>(
+  _ operation: (ServerContext.RPCCancellationHandle) async throws(Failure) -> Success
+) async throws(Failure) -> Success {
+  let handle = ServerContext.RPCCancellationHandle()
+  let result = await ServerContext.$rpcCancellation.withValue(handle) {
+    // Wrap up the outcome in a result as 'withValue' doesn't support typed throws.
+    return await Swift.Result { () async throws(Failure) -> Success in
+      return try await operation(handle)
+    }
+  }
+
+  return try result.get()
+}

+ 10 - 1
Sources/GRPCCore/Call/Server/ServerContext.swift

@@ -19,8 +19,17 @@ public struct ServerContext: Sendable {
   /// A description of the method being called.
   public var descriptor: MethodDescriptor
 
+  /// A handle for checking the cancellation status of an RPC.
+  public var cancellation: RPCCancellationHandle
+
   /// Create a new server context.
-  public init(descriptor: MethodDescriptor) {
+  ///
+  /// - Parameters:
+  ///   - descriptor: A description of the method being called.
+  ///   - cancellation: A cancellation handle. You can create a cancellation handle
+  ///     using ``withServerContextRPCCancellationHandle(_:)``.
+  public init(descriptor: MethodDescriptor, cancellation: RPCCancellationHandle) {
     self.descriptor = descriptor
+    self.cancellation = cancellation
   }
 }

+ 2 - 2
Sources/GRPCCore/Internal/Result+Catching.swift

@@ -14,12 +14,12 @@
  * limitations under the License.
  */
 
-extension Result where Failure == any Error {
+extension Result {
   /// Like `Result(catching:)`, but `async`.
   ///
   /// - Parameter body: An `async` closure to catch the result of.
   @inlinable
-  init(catching body: () async throws -> Success) async {
+  init(catching body: () async throws(Failure) -> Success) async {
     do {
       self = .success(try await body())
     } catch {

+ 58 - 3
Sources/GRPCInProcessTransport/InProcessTransport+Server.swift

@@ -15,6 +15,7 @@
  */
 
 public import GRPCCore
+private import Synchronization
 
 extension InProcessTransport {
   /// An in-process implementation of a ``ServerTransport``.
@@ -27,16 +28,54 @@ extension InProcessTransport {
   /// To stop listening to new requests, call ``beginGracefulShutdown()``.
   ///
   /// - SeeAlso: ``ClientTransport``
-  public struct Server: ServerTransport, Sendable {
+  public final class Server: ServerTransport, Sendable {
     public typealias Inbound = RPCAsyncSequence<RPCRequestPart, any Error>
     public typealias Outbound = RPCWriter<RPCResponsePart>.Closable
 
     private let newStreams: AsyncStream<RPCStream<Inbound, Outbound>>
     private let newStreamsContinuation: AsyncStream<RPCStream<Inbound, Outbound>>.Continuation
 
+    private struct State: Sendable {
+      private var _nextID: UInt64
+      private var handles: [UInt64: ServerContext.RPCCancellationHandle]
+      private var isShutdown: Bool
+
+      private mutating func nextID() -> UInt64 {
+        let id = self._nextID
+        self._nextID &+= 1
+        return id
+      }
+
+      init() {
+        self._nextID = 0
+        self.handles = [:]
+        self.isShutdown = false
+      }
+
+      mutating func addHandle(_ handle: ServerContext.RPCCancellationHandle) -> (UInt64, Bool) {
+        let handleID = self.nextID()
+        self.handles[handleID] = handle
+        return (handleID, self.isShutdown)
+      }
+
+      mutating func removeHandle(withID id: UInt64) {
+        self.handles.removeValue(forKey: id)
+      }
+
+      mutating func beginShutdown() -> [ServerContext.RPCCancellationHandle] {
+        self.isShutdown = true
+        let values = Array(self.handles.values)
+        self.handles.removeAll()
+        return values
+      }
+    }
+
+    private let handles: Mutex<State>
+
     /// Creates a new instance of ``Server``.
     public init() {
       (self.newStreams, self.newStreamsContinuation) = AsyncStream.makeStream()
+      self.handles = Mutex(State())
     }
 
     /// Publish a new ``RPCStream``, which will be returned by the transport's ``events``
@@ -64,8 +103,21 @@ extension InProcessTransport {
       await withDiscardingTaskGroup { group in
         for await stream in self.newStreams {
           group.addTask {
-            let context = ServerContext(descriptor: stream.descriptor)
-            await streamHandler(stream, context)
+            await withServerContextRPCCancellationHandle { handle in
+              let (id, isShutdown) = self.handles.withLock({ $0.addHandle(handle) })
+              defer {
+                self.handles.withLock { $0.removeHandle(withID: id) }
+              }
+
+              // This happens if `beginGracefulShutdown` is called after the stream is added to
+              // new streams but before it's dequeued.
+              if isShutdown {
+                handle.cancel()
+              }
+
+              let context = ServerContext(descriptor: stream.descriptor, cancellation: handle)
+              await streamHandler(stream, context)
+            }
           }
         }
       }
@@ -76,6 +128,9 @@ extension InProcessTransport {
     /// - SeeAlso: ``ServerTransport``
     public func beginGracefulShutdown() {
       self.newStreamsContinuation.finish()
+      for handle in self.handles.withLock({ $0.beginShutdown() }) {
+        handle.cancel()
+      }
     }
   }
 }

+ 91 - 0
Tests/GRPCCoreTests/Call/Server/Internal/ServerCancellationManagerTests.swift

@@ -0,0 +1,91 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import GRPCCore
+import Testing
+
+@Suite
+struct ServerCancellationManagerTests {
+  @Test("Isn't cancelled after init")
+  func isNotCancelled() {
+    let manager = ServerCancellationManager()
+    #expect(!manager.isRPCCancelled)
+  }
+
+  @Test("Is cancelled")
+  func isCancelled() {
+    let manager = ServerCancellationManager()
+    manager.cancelRPC()
+    #expect(manager.isRPCCancelled)
+  }
+
+  @Test("Cancellation handler runs")
+  func addCancellationHandler() async throws {
+    let manager = ServerCancellationManager()
+    let signal = AsyncStream.makeStream(of: Void.self)
+
+    let id = manager.addRPCCancelledHandler {
+      signal.continuation.finish()
+    }
+
+    #expect(id != nil)
+    manager.cancelRPC()
+    let events: [Void] = await signal.stream.reduce(into: []) { $0.append($1) }
+    #expect(events.isEmpty)
+  }
+
+  @Test("Cancellation handler runs immediately when already cancelled")
+  func addCancellationHandlerAfterCancelled() async throws {
+    let manager = ServerCancellationManager()
+    let signal = AsyncStream.makeStream(of: Void.self)
+    manager.cancelRPC()
+
+    let id = manager.addRPCCancelledHandler {
+      signal.continuation.finish()
+    }
+
+    #expect(id == nil)
+    let events: [Void] = await signal.stream.reduce(into: []) { $0.append($1) }
+    #expect(events.isEmpty)
+  }
+
+  @Test("Remove cancellation handler")
+  func removeCancellationHandler() async throws {
+    let manager = ServerCancellationManager()
+    let signal = AsyncStream.makeStream(of: Void.self)
+
+    let id = manager.addRPCCancelledHandler {
+      Issue.record("Unexpected cancellation")
+    }
+
+    #expect(id != nil)
+    manager.removeRPCCancelledHandler(withID: id!)
+    manager.cancelRPC()
+  }
+
+  @Test("Wait for cancellation")
+  func waitForCancellation() async throws {
+    let manager = ServerCancellationManager()
+    try await withThrowingTaskGroup(of: Void.self) { group in
+      group.addTask {
+        try await manager.suspendUntilRPCIsCancelled()
+      }
+
+      manager.cancelRPC()
+      try await group.waitForAll()
+    }
+  }
+}

+ 34 - 22
Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTestSupport/ServerRPCExecutorTestHarness.swift

@@ -20,24 +20,29 @@ import XCTest
 struct ServerRPCExecutorTestHarness {
   struct ServerHandler<Input: Sendable, Output: Sendable>: Sendable {
     let fn:
-      @Sendable (StreamingServerRequest<Input>) async throws -> StreamingServerResponse<Output>
+      @Sendable (
+        _ request: StreamingServerRequest<Input>,
+        _ context: ServerContext
+      ) async throws -> StreamingServerResponse<Output>
 
     init(
       _ fn: @escaping @Sendable (
-        StreamingServerRequest<Input>
+        _ request: StreamingServerRequest<Input>,
+        _ context: ServerContext
       ) async throws -> StreamingServerResponse<Output>
     ) {
       self.fn = fn
     }
 
     func handle(
-      _ request: StreamingServerRequest<Input>
+      _ request: StreamingServerRequest<Input>,
+      _ context: ServerContext
     ) async throws -> StreamingServerResponse<Output> {
-      try await self.fn(request)
+      try await self.fn(request, context)
     }
 
     static func throwing(_ error: any Error) -> Self {
-      return Self { _ in throw error }
+      return Self { _, _ in throw error }
     }
   }
 
@@ -51,7 +56,8 @@ struct ServerRPCExecutorTestHarness {
     deserializer: some MessageDeserializer<Input>,
     serializer: some MessageSerializer<Output>,
     handler: @escaping @Sendable (
-      StreamingServerRequest<Input>
+      StreamingServerRequest<Input>,
+      ServerContext
     ) async throws -> StreamingServerResponse<Output>,
     producer: @escaping @Sendable (
       RPCWriter<RPCRequestPart>.Closable
@@ -93,21 +99,27 @@ struct ServerRPCExecutorTestHarness {
       }
 
       group.addTask {
-        let context = ServerContext(descriptor: MethodDescriptor(service: "foo", method: "bar"))
-        await ServerRPCExecutor.execute(
-          context: context,
-          stream: RPCStream(
-            descriptor: context.descriptor,
-            inbound: RPCAsyncSequence(wrapping: input.stream),
-            outbound: RPCWriter.Closable(wrapping: output.continuation)
-          ),
-          deserializer: deserializer,
-          serializer: serializer,
-          interceptors: self.interceptors,
-          handler: { stream, context in
-            try await handler.handle(stream)
-          }
-        )
+        await withServerContextRPCCancellationHandle { cancellation in
+          let context = ServerContext(
+            descriptor: MethodDescriptor(service: "foo", method: "bar"),
+            cancellation: cancellation
+          )
+
+          await ServerRPCExecutor.execute(
+            context: context,
+            stream: RPCStream(
+              descriptor: context.descriptor,
+              inbound: RPCAsyncSequence(wrapping: input.stream),
+              outbound: RPCWriter.Closable(wrapping: output.continuation)
+            ),
+            deserializer: deserializer,
+            serializer: serializer,
+            interceptors: self.interceptors,
+            handler: { stream, context in
+              try await handler.handle(stream, context)
+            }
+          )
+        }
       }
 
       try await group.waitForAll()
@@ -135,7 +147,7 @@ struct ServerRPCExecutorTestHarness {
 
 extension ServerRPCExecutorTestHarness.ServerHandler where Input == Output {
   static var echo: Self {
-    return Self { request in
+    return Self { request, context in
       return StreamingServerResponse(metadata: request.metadata) { writer in
         try await writer.write(contentsOf: request.messages)
         return [:]

+ 8 - 14
Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift

@@ -83,7 +83,7 @@ final class ServerRPCExecutorTests: XCTestCase {
     try await harness.execute(
       deserializer: JSONDeserializer<String>(),
       serializer: JSONSerializer<String>()
-    ) { request in
+    ) { request, _ in
       let messages = try await request.messages.collect()
       XCTAssertEqual(messages, ["hello"])
       return StreamingServerResponse(metadata: request.metadata) { writer in
@@ -112,7 +112,7 @@ final class ServerRPCExecutorTests: XCTestCase {
     try await harness.execute(
       deserializer: JSONDeserializer<String>(),
       serializer: JSONSerializer<String>()
-    ) { request in
+    ) { request, _ in
       let messages = try await request.messages.collect()
       XCTAssertEqual(messages, ["hello", "world"])
       return StreamingServerResponse(metadata: request.metadata) { writer in
@@ -144,7 +144,7 @@ final class ServerRPCExecutorTests: XCTestCase {
     try await harness.execute(
       deserializer: IdentityDeserializer(),
       serializer: IdentitySerializer()
-    ) { request in
+    ) { request, _ in
       return StreamingServerResponse(metadata: request.metadata) { _ in
         return ["bar": "baz"]
       }
@@ -235,15 +235,9 @@ final class ServerRPCExecutorTests: XCTestCase {
     try await harness.execute(
       deserializer: IdentityDeserializer(),
       serializer: IdentitySerializer()
-    ) { request in
-      do {
-        try await Task.sleep(until: .now.advanced(by: .seconds(180)), clock: .continuous)
-      } catch is CancellationError {
-        throw RPCError(code: .cancelled, message: "Sleep was cancelled")
-      }
-
-      XCTFail("Server handler should've been cancelled by timeout.")
-      return StreamingServerResponse(error: RPCError(code: .failedPrecondition, message: ""))
+    ) { request, context in
+      try await context.cancellation.cancelled
+      throw RPCError(code: .cancelled, message: "Cancelled from server handler")
     } producer: { inbound in
       try await inbound.write(.metadata(["grpc-timeout": "1000n"]))
       await inbound.finish()
@@ -251,7 +245,7 @@ final class ServerRPCExecutorTests: XCTestCase {
       let part = try await outbound.collect().first
       XCTAssertStatus(part) { status, _ in
         XCTAssertEqual(status.code, .cancelled)
-        XCTAssertEqual(status.message, "Sleep was cancelled")
+        XCTAssertEqual(status.message, "Cancelled from server handler")
       }
     }
   }
@@ -268,7 +262,7 @@ final class ServerRPCExecutorTests: XCTestCase {
     try await harness.execute(
       deserializer: IdentityDeserializer(),
       serializer: IdentitySerializer()
-    ) { request in
+    ) { request, _ in
       XCTFail("Unexpected request")
       return StreamingServerResponse(
         of: [UInt8].self,

+ 62 - 0
Tests/GRPCCoreTests/Call/Server/ServerContextTests.swift

@@ -0,0 +1,62 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import GRPCCore
+import Testing
+
+@Suite("ServerContext")
+struct ServerContextTests {
+  @Suite("CancellationHandle")
+  struct CancellationHandle {
+    @Test("Is cancelled")
+    func isCancelled() async throws {
+      await withServerContextRPCCancellationHandle { handle in
+        #expect(!handle.isCancelled)
+        handle.cancel()
+        #expect(handle.isCancelled)
+      }
+    }
+
+    @Test("Wait for cancellation")
+    func waitForCancellation() async throws {
+      await withServerContextRPCCancellationHandle { handle in
+        await withTaskGroup(of: Void.self) { group in
+          group.addTask {
+            try? await handle.cancelled
+          }
+          handle.cancel()
+          await group.waitForAll()
+        }
+      }
+    }
+
+    @Test("Binds task local")
+    func bindsTaskLocal() async throws {
+      await withServerContextRPCCancellationHandle { handle in
+        let signal = AsyncStream.makeStream(of: Void.self)
+
+        await withRPCCancellationHandler {
+          handle.cancel()
+          for await _ in signal.stream {}
+        } onCancelRPC: {
+          // If the task local wasn't bound, this wouldn't run.
+          signal.continuation.finish()
+        }
+      }
+
+    }
+  }
+}

+ 125 - 0
Tests/GRPCInProcessTransportTests/InProcessTransportTests.swift

@@ -0,0 +1,125 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import GRPCCore
+import GRPCInProcessTransport
+import Testing
+
+@Suite("InProcess transport")
+struct InProcessTransportTests {
+  private static let cancellationModes = ["await-cancelled", "with-cancellation-handler"]
+
+  private func withTestServerAndClient(
+    execute: (GRPCServer, GRPCClient) async throws -> Void
+  ) async throws {
+    try await withThrowingDiscardingTaskGroup { group in
+      let inProcess = InProcessTransport()
+
+      let server = GRPCServer(transport: inProcess.server, services: [TestService()])
+      group.addTask {
+        try await server.serve()
+      }
+
+      let client = GRPCClient(transport: inProcess.client)
+      group.addTask {
+        try await client.run()
+      }
+
+      try await execute(server, client)
+    }
+  }
+
+  @Test("RPC cancelled by graceful shutdown", arguments: Self.cancellationModes)
+  func cancelledByGracefulShutdown(mode: String) async throws {
+    try await self.withTestServerAndClient { server, client in
+      try await client.serverStreaming(
+        request: ClientRequest(message: mode),
+        descriptor: .testCancellation,
+        serializer: UTF8Serializer(),
+        deserializer: UTF8Deserializer(),
+        options: .defaults
+      ) { response in
+        // Got initial metadata, begin shutdown to cancel the RPC.
+        server.beginGracefulShutdown()
+
+        // Now wait for the response.
+        let messages = try await response.messages.reduce(into: []) { $0.append($1) }
+        #expect(messages == ["isCancelled=true"])
+      }
+
+      // Finally, shutdown the client so its run() method returns.
+      client.beginGracefulShutdown()
+    }
+  }
+}
+
+private struct TestService: RegistrableRPCService {
+  func cancellation(
+    request: ServerRequest<String>,
+    context: ServerContext
+  ) async throws -> StreamingServerResponse<String> {
+    switch request.message {
+    case "await-cancelled":
+      return StreamingServerResponse { body in
+        try await context.cancellation.cancelled
+        try await body.write("isCancelled=\(context.cancellation.isCancelled)")
+        return [:]
+      }
+
+    case "with-cancellation-handler":
+      let signal = AsyncStream.makeStream(of: Void.self)
+      return StreamingServerResponse { body in
+        try await withRPCCancellationHandler {
+          for await _ in signal.stream {}
+          try await body.write("isCancelled=\(context.cancellation.isCancelled)")
+          return [:]
+        } onCancelRPC: {
+          signal.continuation.finish()
+        }
+      }
+
+    default:
+      throw RPCError(code: .invalidArgument, message: "Invalid argument '\(request.message)'")
+    }
+  }
+
+  func registerMethods(with router: inout RPCRouter) {
+    router.registerHandler(
+      forMethod: .testCancellation,
+      deserializer: UTF8Deserializer(),
+      serializer: UTF8Serializer(),
+      handler: {
+        try await self.cancellation(request: ServerRequest(stream: $0), context: $1)
+      }
+    )
+  }
+}
+
+extension MethodDescriptor {
+  fileprivate static let testCancellation = Self(service: "test", method: "cancellation")
+}
+
+private struct UTF8Serializer: MessageSerializer {
+  func serialize(_ message: String) throws -> [UInt8] {
+    Array(message.utf8)
+  }
+}
+
+private struct UTF8Deserializer: MessageDeserializer {
+  func deserialize(_ serializedMessageBytes: [UInt8]) throws -> String {
+    String(decoding: serializedMessageBytes, as: UTF8.self)
+  }
+}