Simplify client rpc execution (#1993)

* Simplify client rpc execution

Motivation:

The existing code for executing RPCs on the client is layered: the
one-shot, retry, and hedging executors serialize the messages of the
requests they're processing into bytes. These 'raw' requests are
processed by the stream executor, which is structured in such a way as
to require a task group and a stream of events it needs to handle. These
events are handled in a run method (which is in turn another task for
the caller to run). This requires quite a bit of machinery (and
allocations) to work.

Modifications:

- Push the serialization/deserialization of messages down into the
  stream executor. This lets a typed request be transformed directly
  into `RPCRequestPart`s and `RPCResponsePart`s be transformed directly
  into a typed response (rather than going via an intermediary
  request/response typed to `[UInt8]`).
- Remove the sendability requirement from the `next` function for client
  interceptors. This makes sense: interceptors should be straight-line
  code, so they shouldn't be shuffled off into a subtask. This unlocks the
  ability to remove the task group and event stream previously used by
  the stream executor, as it can now add child tasks to a task group
  provided by the caller (rather than needing a separate task group with
  an event stream); see the sketch after this list.
- Apply this change to the one-shot, hedging, and retry executors.
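
For illustration only, here is a minimal sketch of the caller-owned task
group pattern described above; the names (`runRequest`, `example`) are
hypothetical and not the library's actual API:

  // Hypothetical helper: it adds outbound work as a child task of the
  // caller's group and produces the "response" on the caller's own task,
  // as straight-line code with no extra group or event stream.
  func runRequest(in group: inout TaskGroup<Void>, messages: [String]) async -> String {
    group.addTask {
      for message in messages {
        print("sending:", message)
      }
    }
    return "ok"
  }

  func example() async -> String {
    await withTaskGroup(of: Void.self, returning: String.self) { group in
      let response = await runRequest(in: &group, messages: ["a", "b"])
      // The handler can finish before the child tasks do; cancel what's left.
      group.cancelAll()
      return response
    }
  }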

Result:

- ~35% reduction in allocations for unary RPCs on the client

* use task group void

* fix comment
George Barnett, 1 year ago
commit 3f749833b2

+ 1 - 1
Sources/GRPCCore/Call/Client/ClientInterceptor.swift

@@ -102,7 +102,7 @@ public protocol ClientInterceptor: Sendable {
   func intercept<Input: Sendable, Output: Sendable>(
     request: ClientRequest.Stream<Input>,
     context: ClientInterceptorContext,
-    next: @Sendable (
+    next: (
       _ request: ClientRequest.Stream<Input>,
       _ context: ClientInterceptorContext
     ) async throws -> ClientResponse.Stream<Output>

+ 43 - 44
Sources/GRPCCore/Call/Client/Internal/ClientRPCExecutor+HedgingExecutor.swift

@@ -103,12 +103,7 @@ extension ClientRPCExecutor.HedgingExecutor {
       }
 
       group.addTask {
-        var metadata = request.metadata
-        if let deadline = self.deadline {
-          metadata.timeout = ContinuousClock.now.duration(to: deadline)
-        }
-
-        let replayableRequest = ClientRequest.Stream(metadata: metadata) { writer in
+        let replayableRequest = ClientRequest.Stream(metadata: request.metadata) { writer in
           try await writer.write(contentsOf: broadcast.stream)
         }
 
@@ -243,6 +238,10 @@ extension ClientRPCExecutor.HedgingExecutor {
             ()
           }
 
+        case .attemptPicked:
+          // Not used by this task group.
+          fatalError("Internal inconsistency")
+
         case .attemptCompleted(let outcome):
           switch outcome {
           case .usableResponse(let response):
@@ -327,7 +326,7 @@ extension ClientRPCExecutor.HedgingExecutor {
         descriptor: method,
         options: options
       ) { stream -> _HedgingAttemptTaskResult<R, Output>.AttemptResult in
-        return await withTaskGroup(of: _HedgingAttemptSubtaskResult<Output>.self) { group in
+        return await withTaskGroup(of: _HedgingAttemptTaskResult<R, Output>.self) { group in
           group.addTask {
             do {
               // The picker stream will have at most one element.
@@ -340,35 +339,27 @@ extension ClientRPCExecutor.HedgingExecutor {
             }
           }
 
-          let processor = ClientStreamExecutor(transport: self.transport)
-
           group.addTask {
-            await processor.run()
-            return .processorFinished
-          }
-
-          group.addTask {
-            let response = await ClientRPCExecutor.unsafeExecute(
-              request: request,
-              method: method,
-              attempt: attempt,
-              serializer: self.serializer,
-              deserializer: self.deserializer,
-              interceptors: self.interceptors,
-              streamProcessor: processor,
-              stream: stream
-            )
-            return .response(response)
-          }
-
-          for await next in group {
-            switch next {
-            case .attemptPicked(let wasPicked):
-              if !wasPicked {
-                group.cancelAll()
+            let result = await withTaskGroup(
+              of: Void.self,
+              returning: _HedgingAttemptTaskResult<R, Output>.AttemptResult.self
+            ) { group in
+              var request = request
+              if let deadline = self.deadline {
+                request.metadata.timeout = ContinuousClock.now.duration(to: deadline)
               }
 
-            case .response(let response):
+              let response = await ClientRPCExecutor._execute(
+                in: &group,
+                request: request,
+                method: method,
+                attempt: attempt,
+                serializer: self.serializer,
+                deserializer: self.deserializer,
+                interceptors: self.interceptors,
+                stream: stream
+              )
+
               switch response.accepted {
               case .success:
                 self.transport.retryThrottle?.recordSuccess()
@@ -405,10 +396,25 @@ extension ClientRPCExecutor.HedgingExecutor {
                   }
                 }
               }
+            }
 
-            case .processorFinished:
-              // Processor finished, wait for the response outcome.
-              ()
+            return .attemptCompleted(result)
+          }
+
+          for await next in group {
+            switch next {
+            case .attemptPicked(let wasPicked):
+              if !wasPicked {
+                group.cancelAll()
+              }
+
+            case .attemptCompleted(let result):
+              group.cancelAll()
+              return result
+
+            case .scheduledAttemptFired:
+              // Not used by this task group.
+              fatalError("Internal inconsistency")
             }
           }
 
@@ -516,6 +522,7 @@ enum _HedgingTaskResult<R> {
 @available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *)
 @usableFromInline
 enum _HedgingAttemptTaskResult<R, Output> {
+  case attemptPicked(Bool)
   case attemptCompleted(AttemptResult)
   case scheduledAttemptFired(ScheduleEvent)
 
@@ -532,11 +539,3 @@ enum _HedgingAttemptTaskResult<R, Output> {
     case cancelled
   }
 }
-
-@available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *)
-@usableFromInline
-enum _HedgingAttemptSubtaskResult<Output> {
-  case attemptPicked(Bool)
-  case processorFinished
-  case response(ClientResponse.Stream<Output>)
-}

+ 96 - 67
Sources/GRPCCore/Call/Client/Internal/ClientRPCExecutor+OneShotExecutor.swift

@@ -66,86 +66,115 @@ extension ClientRPCExecutor.OneShotExecutor {
     options: CallOptions,
     responseHandler: @Sendable @escaping (ClientResponse.Stream<Output>) async throws -> R
   ) async throws -> R {
-    let result = await withTaskGroup(
-      of: _OneShotExecutorTask<R>.self,
-      returning: Result<R, any Error>.self
-    ) { group in
+    let result: Result<R, any Error>
+
+    if let deadline = self.deadline {
+      var request = request
+      request.metadata.timeout = ContinuousClock.now.duration(to: deadline)
+      result = await withDeadline(deadline) {
+        await self._execute(
+          request: request,
+          method: method,
+          options: options,
+          responseHandler: responseHandler
+        )
+      }
+    } else {
+      result = await self._execute(
+        request: request,
+        method: method,
+        options: options,
+        responseHandler: responseHandler
+      )
+    }
+
+    return try result.get()
+  }
+}
+
+@available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *)
+extension ClientRPCExecutor.OneShotExecutor {
+  @inlinable
+  func _execute<R>(
+    request: ClientRequest.Stream<Input>,
+    method: MethodDescriptor,
+    options: CallOptions,
+    responseHandler: @Sendable @escaping (ClientResponse.Stream<Output>) async throws -> R
+  ) async -> Result<R, any Error> {
+    return await withTaskGroup(of: Void.self, returning: Result<R, any Error>.self) { group in
       do {
         return try await self.transport.withStream(descriptor: method, options: options) { stream in
-          var request = request
-
-          if let deadline = self.deadline {
-            request.metadata.timeout = ContinuousClock.now.duration(to: deadline)
-            group.addTask {
-              let result = await Result {
-                try await Task.sleep(until: deadline, clock: .continuous)
-              }
-              return .timedOut(result)
-            }
+          let response = await ClientRPCExecutor._execute(
+            in: &group,
+            request: request,
+            method: method,
+            attempt: 1,
+            serializer: self.serializer,
+            deserializer: self.deserializer,
+            interceptors: self.interceptors,
+            stream: stream
+          )
+
+          let result = await Result {
+            try await responseHandler(response)
           }
 
-          let streamExecutor = ClientStreamExecutor(transport: self.transport)
-          group.addTask {
-            await streamExecutor.run()
-            return .streamExecutorCompleted
-          }
-
-          group.addTask { [request] in
-            let response = await ClientRPCExecutor.unsafeExecute(
-              request: request,
-              method: method,
-              attempt: 1,
-              serializer: self.serializer,
-              deserializer: self.deserializer,
-              interceptors: self.interceptors,
-              streamProcessor: streamExecutor,
-              stream: stream
-            )
-
-            let result = await Result {
-              try await responseHandler(response)
-            }
-
-            return .responseHandled(result)
-          }
+          // The user handler can finish before the stream. Cancel it if that's the case.
+          group.cancelAll()
 
-          while let result = await group.next() {
-            switch result {
-            case .streamExecutorCompleted:
-              // Stream finished; wait for the response to be handled.
-              ()
-
-            case .timedOut(.success):
-              // The deadline passed; cancel the ongoing work group.
-              group.cancelAll()
-
-            case .timedOut(.failure):
-              // The deadline task failed (because the task was cancelled). Wait for the response
-              // to be handled.
-              ()
-
-            case .responseHandled(let result):
-              // Response handled: cancel any other remaining tasks.
-              group.cancelAll()
-              return result
-            }
-          }
-
-          // Unreachable: exactly one task returns `responseHandled` and we return when it completes.
-          fatalError("Internal inconsistency")
+          return result
         }
       } catch {
         return .failure(error)
       }
     }
+  }
+}
 
-    return try result.get()
+@available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *)
+@inlinable
+func withDeadline<Result>(
+  _ deadline: ContinuousClock.Instant,
+  execute: @escaping () async -> Result
+) async -> Result {
+  return await withTaskGroup(of: _DeadlineChildTaskResult<Result>.self) { group in
+    group.addTask {
+      do {
+        try await Task.sleep(until: deadline)
+        return .deadlinePassed
+      } catch {
+        return .timeoutCancelled
+      }
+    }
+
+    group.addTask {
+      let result = await execute()
+      return .taskCompleted(result)
+    }
+
+    while let next = await group.next() {
+      switch next {
+      case .deadlinePassed:
+        // Timeout expired; cancel the work.
+        group.cancelAll()
+
+      case .timeoutCancelled:
+        ()  // Wait for more tasks to finish.
+
+      case .taskCompleted(let result):
+        // The work finished. Cancel any remaining tasks.
+        group.cancelAll()
+        return result
+      }
+    }
+
+    fatalError("Internal inconsistency")
   }
 }
 
 @usableFromInline
-enum _OneShotExecutorTask<R> {
-  case streamExecutorCompleted
-  case timedOut(Result<Void, any Error>)
-  case responseHandled(Result<R, any Error>)
+enum _DeadlineChildTaskResult<Value> {
+  case deadlinePassed
+  case timeoutCancelled
+  case taskCompleted(Value)
 }

+ 120 - 132
Sources/GRPCCore/Call/Client/Internal/ClientRPCExecutor+RetryExecutor.swift

@@ -125,131 +125,20 @@ extension ClientRPCExecutor.RetryExecutor {
             options: options
           ) { stream in
             group.addTask {
-              await withTaskGroup(
-                of: _RetryExecutorSubTask<R>.self,
-                returning: _RetryExecutorTask<R>.self
-              ) { thisAttemptGroup in
-                let streamExecutor = ClientStreamExecutor(transport: self.transport)
-                thisAttemptGroup.addTask {
-                  await streamExecutor.run()
-                  return .streamProcessed
-                }
-
-                thisAttemptGroup.addTask {
-                  var metadata = request.metadata
-                  // Work out the timeout from the deadline.
-                  if let deadline = self.deadline {
-                    metadata.timeout = ContinuousClock.now.duration(to: deadline)
-                  }
-
-                  let response = await ClientRPCExecutor.unsafeExecute(
-                    request: ClientRequest.Stream(metadata: metadata) {
-                      try await $0.write(contentsOf: retry.stream)
-                    },
-                    method: method,
-                    attempt: attempt,
-                    serializer: self.serializer,
-                    deserializer: self.deserializer,
-                    interceptors: self.interceptors,
-                    streamProcessor: streamExecutor,
-                    stream: stream
-                  )
-
-                  let shouldRetry: Bool
-                  let retryDelayOverride: Duration?
-
-                  switch response.accepted {
-                  case .success:
-                    // Request was accepted. This counts as success to the throttle and there's no need
-                    // to retry.
-                    self.transport.retryThrottle?.recordSuccess()
-                    retryDelayOverride = nil
-                    shouldRetry = false
-
-                  case .failure(let error):
-                    // The request was rejected. Determine whether a retry should be carried out. The
-                    // following conditions must be checked:
-                    //
-                    // - Whether the status code is retryable.
-                    // - Whether more attempts are permitted by the config.
-                    // - Whether the throttle permits another retry to be carried out.
-                    // - Whether the server pushed back to either stop further retries or to override
-                    //   the delay before the next retry.
-                    let code = Status.Code(error.code)
-                    let isRetryableStatusCode = self.policy.retryableStatusCodes.contains(code)
-
-                    if isRetryableStatusCode {
-                      // Counted as failure for throttling.
-                      let throttled = self.transport.retryThrottle?.recordFailure() ?? false
-
-                      // Status code can be retried, Did the server send pushback?
-                      switch error.metadata.retryPushback {
-                      case .retryAfter(let delay):
-                        // Pushback: only retry if our config permits it.
-                        shouldRetry = (attempt < self.policy.maximumAttempts) && !throttled
-                        retryDelayOverride = delay
-                      case .stopRetrying:
-                        // Server told us to stop trying.
-                        shouldRetry = false
-                        retryDelayOverride = nil
-                      case .none:
-                        // No pushback: only retry if our config permits it.
-                        shouldRetry = (attempt < self.policy.maximumAttempts) && !throttled
-                        retryDelayOverride = nil
-                        break
-                      }
-                    } else {
-                      // Not-retryable; this is considered a success.
-                      self.transport.retryThrottle?.recordSuccess()
-                      shouldRetry = false
-                      retryDelayOverride = nil
-                    }
-                  }
-
-                  if shouldRetry {
-                    // Cancel subscribers of the broadcast sequence. This is safe as we are the only
-                    // subscriber and maximises the chances that 'isKnownSafeForNextSubscriber' will
-                    // return true.
-                    //
-                    // Note: this must only be called if we should retry, otherwise we may cancel a
-                    // subscriber for an accepted request.
-                    retry.stream.invalidateAllSubscriptions()
-
-                    // Only retry if we know it's safe for the next subscriber, that is, the first
-                    // element is still in the buffer. It's safe to call this because there's only
-                    // ever one attempt at a time and the existing subscribers have been invalidated.
-                    if retry.stream.isKnownSafeForNextSubscriber {
-                      return .retry(retryDelayOverride)
-                    }
-                  }
-
-                  // Not retrying or not safe to retry.
-                  let result = await Result {
-                    // Check for cancellation; the RPC may have timed out in which case we should skip
-                    // the response handler.
-                    try Task.checkCancellation()
-                    return try await responseHandler(response)
-                  }
-                  return .handledResponse(result)
-                }
-
-                while let result = await thisAttemptGroup.next() {
-                  switch result {
-                  case .streamProcessed:
-                    ()  // Continue processing; wait for the response to be handled.
-
-                  case .retry(let delayOverride):
-                    thisAttemptGroup.cancelAll()
-                    return .retry(delayOverride)
-
-                  case .handledResponse(let result):
-                    thisAttemptGroup.cancelAll()
-                    return .handledResponse(result)
-                  }
-                }
-
-                fatalError("Internal inconsistency")
+              var metadata = request.metadata
+              // Work out the timeout from the deadline.
+              if let deadline = self.deadline {
+                metadata.timeout = ContinuousClock.now.duration(to: deadline)
               }
+
+              return await self.executeAttempt(
+                stream: stream,
+                metadata: metadata,
+                retryStream: retry.stream,
+                method: method,
+                attempt: attempt,
+                responseHandler: responseHandler
+              )
             }
 
             loop: while let next = await group.next() {
@@ -307,6 +196,113 @@ extension ClientRPCExecutor.RetryExecutor {
 
     return try result.get()
   }
+
+  @inlinable
+  func executeAttempt<R>(
+    stream: RPCStream<ClientTransport.Inbound, ClientTransport.Outbound>,
+    metadata: Metadata,
+    retryStream: BroadcastAsyncSequence<Input>,
+    method: MethodDescriptor,
+    attempt: Int,
+    responseHandler: @Sendable @escaping (ClientResponse.Stream<Output>) async throws -> R
+  ) async -> _RetryExecutorTask<R> {
+    return await withTaskGroup(
+      of: Void.self,
+      returning: _RetryExecutorTask<R>.self
+    ) { group in
+      let request = ClientRequest.Stream(metadata: metadata) {
+        try await $0.write(contentsOf: retryStream)
+      }
+
+      let response = await ClientRPCExecutor._execute(
+        in: &group,
+        request: request,
+        method: method,
+        attempt: attempt,
+        serializer: self.serializer,
+        deserializer: self.deserializer,
+        interceptors: self.interceptors,
+        stream: stream
+      )
+
+      let shouldRetry: Bool
+      let retryDelayOverride: Duration?
+
+      switch response.accepted {
+      case .success:
+        // Request was accepted. This counts as success to the throttle and there's no need
+        // to retry.
+        self.transport.retryThrottle?.recordSuccess()
+        retryDelayOverride = nil
+        shouldRetry = false
+
+      case .failure(let error):
+        // The request was rejected. Determine whether a retry should be carried out. The
+        // following conditions must be checked:
+        //
+        // - Whether the status code is retryable.
+        // - Whether more attempts are permitted by the config.
+        // - Whether the throttle permits another retry to be carried out.
+        // - Whether the server pushed back to either stop further retries or to override
+        //   the delay before the next retry.
+        let code = Status.Code(error.code)
+        let isRetryableStatusCode = self.policy.retryableStatusCodes.contains(code)
+
+        if isRetryableStatusCode {
+          // Counted as failure for throttling.
+          let throttled = self.transport.retryThrottle?.recordFailure() ?? false
+
+          // Status code can be retried, Did the server send pushback?
+          switch error.metadata.retryPushback {
+          case .retryAfter(let delay):
+            // Pushback: only retry if our config permits it.
+            shouldRetry = (attempt < self.policy.maximumAttempts) && !throttled
+            retryDelayOverride = delay
+          case .stopRetrying:
+            // Server told us to stop trying.
+            shouldRetry = false
+            retryDelayOverride = nil
+          case .none:
+            // No pushback: only retry if our config permits it.
+            shouldRetry = (attempt < self.policy.maximumAttempts) && !throttled
+            retryDelayOverride = nil
+            break
+          }
+        } else {
+          // Not-retryable; this is considered a success.
+          self.transport.retryThrottle?.recordSuccess()
+          shouldRetry = false
+          retryDelayOverride = nil
+        }
+      }
+
+      if shouldRetry {
+        // Cancel subscribers of the broadcast sequence. This is safe as we are the only
+        // subscriber and maximises the chances that 'isKnownSafeForNextSubscriber' will
+        // return true.
+        //
+        // Note: this must only be called if we should retry, otherwise we may cancel a
+        // subscriber for an accepted request.
+        retryStream.invalidateAllSubscriptions()
+
+        // Only retry if we know it's safe for the next subscriber, that is, the first
+        // element is still in the buffer. It's safe to call this because there's only
+        // ever one attempt at a time and the existing subscribers have been invalidated.
+        if retryStream.isKnownSafeForNextSubscriber {
+          return .retry(retryDelayOverride)
+        }
+      }
+
+      // Not retrying or not safe to retry.
+      let result = await Result {
+        // Check for cancellation; the RPC may have timed out in which case we should skip
+        // the response handler.
+        try Task.checkCancellation()
+        return try await responseHandler(response)
+      }
+      return .handledResponse(result)
+    }
+  }
 }
 
 @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
@@ -317,11 +313,3 @@ enum _RetryExecutorTask<R> {
   case retry(Duration?)
   case outboundFinished(Result<Void, any Error>)
 }
-
-@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
-@usableFromInline
-enum _RetryExecutorSubTask<R> {
-  case streamProcessed
-  case handledResponse(Result<R, any Error>)
-  case retry(Duration?)
-}

+ 22 - 77
Sources/GRPCCore/Call/Client/Internal/ClientRPCExecutor.swift

@@ -104,9 +104,6 @@ enum ClientRPCExecutor {
 extension ClientRPCExecutor {
   /// Executes a request on a given stream processor.
   ///
-  /// - Warning: This method is "unsafe" because the `streamProcessor` must be running in a task
-  ///   while this function is executing.
-  ///
   /// - Parameters:
   ///   - request: The request to execute.
   ///   - method: A description of the method to execute the request against.
@@ -115,116 +112,58 @@ extension ClientRPCExecutor {
   ///   - deserializer: A deserializer to convert bytes to output messages.
   ///   - interceptors: An array of interceptors which the request and response pass through. The
   ///       interceptors will be called in the order of the array.
-  ///   - streamProcessor: A processor which executes the serialized request.
   /// - Returns: The deserialized response.
-  @inlinable
-  static func unsafeExecute<Transport: ClientTransport, Input: Sendable, Output: Sendable>(
+  @inlinable  // would be private
+  static func _execute<Input: Sendable, Output: Sendable>(
+    in group: inout TaskGroup<Void>,
     request: ClientRequest.Stream<Input>,
     method: MethodDescriptor,
     attempt: Int,
     serializer: some MessageSerializer<Input>,
     deserializer: some MessageDeserializer<Output>,
     interceptors: [any ClientInterceptor],
-    streamProcessor: ClientStreamExecutor<Transport>,
-    stream: RPCStream<Transport.Inbound, Transport.Outbound>
+    stream: RPCStream<ClientTransport.Inbound, ClientTransport.Outbound>
   ) async -> ClientResponse.Stream<Output> {
     let context = ClientInterceptorContext(descriptor: method)
 
     if interceptors.isEmpty {
-      return await Self._runRPC(
+      return await ClientStreamExecutor.execute(
+        in: &group,
         request: request,
         context: context,
         attempt: attempt,
         serializer: serializer,
         deserializer: deserializer,
-        streamProcessor: streamProcessor,
         stream: stream
       )
     } else {
       return await Self._intercept(
+        in: &group,
         request: request,
         context: context,
-        interceptors: interceptors
-      ) { request, context in
-        return await Self._runRPC(
+        iterator: interceptors.makeIterator()
+      ) { group, request, context in
+        return await ClientStreamExecutor.execute(
+          in: &group,
           request: request,
           context: context,
           attempt: attempt,
           serializer: serializer,
           deserializer: deserializer,
-          streamProcessor: streamProcessor,
           stream: stream
         )
       }
     }
   }
 
-  @inlinable
-  static func _runRPC<Transport: ClientTransport, Input: Sendable, Output: Sendable>(
-    request: ClientRequest.Stream<Input>,
-    context: ClientInterceptorContext,
-    attempt: Int,
-    serializer: some MessageSerializer<Input>,
-    deserializer: some MessageDeserializer<Output>,
-    streamProcessor: ClientStreamExecutor<Transport>,
-    stream: RPCStream<Transport.Inbound, Transport.Outbound>
-  ) async -> ClientResponse.Stream<Output> {
-    // Let the server know this is a retry.
-    var metadata = request.metadata
-    if attempt > 1 {
-      metadata.previousRPCAttempts = attempt &- 1
-    }
-
-    var response = await streamProcessor.execute(
-      request: ClientRequest.Stream<[UInt8]>(metadata: metadata) { writer in
-        try await request.producer(.serializing(into: writer, with: serializer))
-      },
-      method: context.descriptor,
-      stream: stream
-    )
-
-    // Attach the number of previous attempts, it can be useful information for callers.
-    if attempt > 1 {
-      switch response.accepted {
-      case .success(var contents):
-        contents.metadata.previousRPCAttempts = attempt &- 1
-        response.accepted = .success(contents)
-
-      case .failure(var error):
-        error.metadata.previousRPCAttempts = attempt &- 1
-        response.accepted = .failure(error)
-      }
-    }
-
-    return response.map { bytes in
-      try deserializer.deserialize(bytes)
-    }
-  }
-
-  @inlinable
-  static func _intercept<Input, Output>(
-    request: ClientRequest.Stream<Input>,
-    context: ClientInterceptorContext,
-    interceptors: [any ClientInterceptor],
-    finally: @Sendable (
-      _ request: ClientRequest.Stream<Input>,
-      _ context: ClientInterceptorContext
-    ) async -> ClientResponse.Stream<Output>
-  ) async -> ClientResponse.Stream<Output> {
-    return await self._intercept(
-      request: request,
-      context: context,
-      iterator: interceptors.makeIterator(),
-      finally: finally
-    )
-  }
-
   @inlinable
   static func _intercept<Input, Output>(
+    in group: inout TaskGroup<Void>,
     request: ClientRequest.Stream<Input>,
     context: ClientInterceptorContext,
     iterator: Array<any ClientInterceptor>.Iterator,
-    finally: @Sendable (
+    finally: (
+      _ group: inout TaskGroup<Void>,
       _ request: ClientRequest.Stream<Input>,
       _ context: ClientInterceptorContext
     ) async -> ClientResponse.Stream<Output>
@@ -236,7 +175,13 @@ extension ClientRPCExecutor {
       let iter = iterator
       do {
         return try await interceptor.intercept(request: request, context: context) {
-          await self._intercept(request: $0, context: $1, iterator: iter, finally: finally)
+          await self._intercept(
+            in: &group,
+            request: $0,
+            context: $1,
+            iterator: iter,
+            finally: finally
+          )
         }
       } catch let error as RPCError {
         return ClientResponse.Stream(error: error)
@@ -246,7 +191,7 @@ extension ClientRPCExecutor {
       }
 
     case .none:
-      return await finally(request, context)
+      return await finally(&group, request, context)
     }
   }
 }

+ 105 - 108
Sources/GRPCCore/Call/Client/Internal/ClientStreamExecutor.swift

@@ -16,59 +16,7 @@
 
 @available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *)
 @usableFromInline
-internal struct ClientStreamExecutor<Transport: ClientTransport> {
-  /// The client transport to execute the stream on.
-  @usableFromInline
-  let _transport: Transport
-
-  /// An `AsyncStream` and continuation to send and receive processing events on.
-  @usableFromInline
-  let _work: (stream: AsyncStream<_Event>, continuation: AsyncStream<_Event>.Continuation)
-
-  @usableFromInline
-  let _watermarks: (low: Int, high: Int)
-
-  @usableFromInline
-  enum _Event: Sendable {
-    /// Send the request on the outbound stream.
-    case request(ClientRequest.Stream<[UInt8]>, Transport.Outbound)
-    /// Receive the response from the inbound stream.
-    case response(
-      RPCWriter<ClientResponse.Stream<[UInt8]>.Contents.BodyPart>.Closable,
-      UnsafeTransfer<Transport.Inbound.AsyncIterator>
-    )
-  }
-
-  @inlinable
-  init(transport: Transport, responseStreamWatermarks: (low: Int, high: Int) = (16, 32)) {
-    self._transport = transport
-    self._work = AsyncStream.makeStream()
-    self._watermarks = responseStreamWatermarks
-  }
-
-  /// Run the stream executor.
-  ///
-  /// This is required to be running until the response returned from ``execute(request:method:)``
-  /// has been processed.
-  @inlinable
-  func run() async {
-    await withTaskGroup(of: Void.self) { group in
-      for await event in self._work.stream {
-        switch event {
-        case .request(let request, let outboundStream):
-          group.addTask {
-            await self._processRequest(request, on: outboundStream)
-          }
-
-        case .response(let writer, let iterator):
-          group.addTask {
-            await self._processResponse(writer: writer, iterator: iterator)
-          }
-        }
-      }
-    }
-  }
-
+internal enum ClientStreamExecutor {
   /// Execute a request on the stream executor.
   ///
   /// The ``run()`` method must be running at the same time as this method.
@@ -78,36 +26,51 @@ internal struct ClientStreamExecutor<Transport: ClientTransport> {
   ///   - method: A description of the method to call.
   /// - Returns: A streamed response.
   @inlinable
-  func execute(
-    request: ClientRequest.Stream<[UInt8]>,
-    method: MethodDescriptor,
-    stream: RPCStream<Transport.Inbound, Transport.Outbound>
-  ) async -> ClientResponse.Stream<[UInt8]> {
-    // Each execution method can add work to process in the 'run' method. They must not add
-    // new work once they return.
-    defer { self._work.continuation.finish() }
-
-    // Start processing the request.
-    self._work.continuation.yield(.request(request, stream.outbound))
+  static func execute<Input: Sendable, Output: Sendable>(
+    in group: inout TaskGroup<Void>,
+    request: ClientRequest.Stream<Input>,
+    context: ClientInterceptorContext,
+    attempt: Int,
+    serializer: some MessageSerializer<Input>,
+    deserializer: some MessageDeserializer<Output>,
+    stream: RPCStream<ClientTransport.Inbound, ClientTransport.Outbound>
+  ) async -> ClientResponse.Stream<Output> {
+    // Let the server know this is a retry.
+    var metadata = request.metadata
+    if attempt > 1 {
+      metadata.previousRPCAttempts = attempt &- 1
+    }
 
-    let part = await self._waitForFirstResponsePart(on: stream.inbound)
+    group.addTask {
+      await Self._processRequest(on: stream.outbound, request: request, serializer: serializer)
+    }
 
+    let part = await Self._waitForFirstResponsePart(on: stream.inbound)
     // Wait for the first response to determine how to handle the response.
     switch part {
-    case .metadata(let metadata, let iterator):
-      // Expected happy case: the server is processing the request.
+    case .metadata(var metadata, let iterator):
+      // Attach the number of previous attempts, it can be useful information for callers.
+      if attempt > 1 {
+        metadata.previousRPCAttempts = attempt &- 1
+      }
+
+      let bodyParts = RawBodyPartToMessageSequence(
+        base: AsyncIteratorSequence(iterator.wrappedValue),
+        deserializer: deserializer
+      )
 
-      // TODO: (optimisation) use a hint about whether the response is streamed. Use a specialised
-      // sequence to avoid allocations if it isn't
-      let responses = RPCAsyncSequence.makeBackpressuredStream(
-        of: ClientResponse.Stream<[UInt8]>.Contents.BodyPart.self,
-        watermarks: self._watermarks
+      // Expected happy case: the server is processing the request.
+      return ClientResponse.Stream(
+        metadata: metadata,
+        bodyParts: RPCAsyncSequence(wrapping: bodyParts)
       )
 
-      self._work.continuation.yield(.response(responses.writer, iterator))
-      return ClientResponse.Stream(metadata: metadata, bodyParts: responses.stream)
+    case .status(let status, var metadata):
+      // Attach the number of previous attempts, it can be useful information for callers.
+      if attempt > 1 {
+        metadata.previousRPCAttempts = attempt &- 1
+      }
 
-    case .status(let status, let metadata):
       // Expected unhappy (but okay) case; the server rejected the request.
       return ClientResponse.Stream(status: status, metadata: metadata)
 
@@ -117,14 +80,15 @@ internal struct ClientStreamExecutor<Transport: ClientTransport> {
     }
   }
 
-  @inlinable
-  func _processRequest<Stream: ClosableRPCWriterProtocol<RPCRequestPart>>(
-    _ request: ClientRequest.Stream<[UInt8]>,
-    on stream: Stream
+  @inlinable  // would be private
+  static func _processRequest<Outbound>(
+    on stream: some ClosableRPCWriterProtocol<RPCRequestPart>,
+    request: ClientRequest.Stream<Outbound>,
+    serializer: some MessageSerializer<Outbound>
   ) async {
     let result = await Result {
       try await stream.write(.metadata(request.metadata))
-      try await request.producer(.map(into: stream) { .message($0) })
+      try await request.producer(.map(into: stream) { .message(try serializer.serialize($0)) })
     }.castError(to: RPCError.self) { other in
       RPCError(code: .unknown, message: "Write failed.", cause: other)
     }
@@ -139,14 +103,14 @@ internal struct ClientStreamExecutor<Transport: ClientTransport> {
 
   @usableFromInline
   enum OnFirstResponsePart: Sendable {
-    case metadata(Metadata, UnsafeTransfer<Transport.Inbound.AsyncIterator>)
+    case metadata(Metadata, UnsafeTransfer<ClientTransport.Inbound.AsyncIterator>)
     case status(Status, Metadata)
     case failed(RPCError)
   }
 
-  @inlinable
-  func _waitForFirstResponsePart(
-    on stream: Transport.Inbound
+  @inlinable  // would be private
+  static func _waitForFirstResponsePart(
+    on stream: ClientTransport.Inbound
   ) async -> OnFirstResponsePart {
     var iterator = stream.makeAsyncIterator()
     let result = await Result<OnFirstResponsePart, any Error> {
@@ -193,15 +157,56 @@ internal struct ClientStreamExecutor<Transport: ClientTransport> {
     }
   }
 
-  @inlinable
-  func _processResponse(
-    writer: RPCWriter<ClientResponse.Stream<[UInt8]>.Contents.BodyPart>.Closable,
-    iterator: UnsafeTransfer<Transport.Inbound.AsyncIterator>
-  ) async {
-    var iterator = iterator.wrappedValue
-    let result = await Result {
-      while let next = try await iterator.next() {
-        switch next {
+  @usableFromInline
+  @available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *)
+  struct RawBodyPartToMessageSequence<
+    Base: AsyncSequence<RPCResponsePart, Failure>,
+    Message: Sendable,
+    Deserializer: MessageDeserializer<Message>,
+    Failure: Error
+  >: AsyncSequence {
+    @usableFromInline
+    typealias Element = AsyncIterator.Element
+
+    @usableFromInline
+    let base: Base
+    @usableFromInline
+    let deserializer: Deserializer
+
+    @inlinable
+    init(base: Base, deserializer: Deserializer) {
+      self.base = base
+      self.deserializer = deserializer
+    }
+
+    @inlinable
+    func makeAsyncIterator() -> AsyncIterator {
+      AsyncIterator(base: self.base.makeAsyncIterator(), deserializer: self.deserializer)
+    }
+
+    @usableFromInline
+    struct AsyncIterator: AsyncIteratorProtocol {
+      @usableFromInline
+      typealias Element = ClientResponse.Stream<Message>.Contents.BodyPart
+
+      @usableFromInline
+      var base: Base.AsyncIterator
+      @usableFromInline
+      let deserializer: Deserializer
+
+      @inlinable
+      init(base: Base.AsyncIterator, deserializer: Deserializer) {
+        self.base = base
+        self.deserializer = deserializer
+      }
+
+      @inlinable
+      mutating func next(
+        isolation actor: isolated (any Actor)?
+      ) async throws(any Error) -> ClientResponse.Stream<Message>.Contents.BodyPart? {
+        guard let part = try await self.base.next(isolation: `actor`) else { return nil }
+
+        switch part {
         case .metadata(let metadata):
           let error = RPCError(
             code: .internalError,
@@ -213,30 +218,22 @@ internal struct ClientStreamExecutor<Transport: ClientTransport> {
           throw error
 
         case .message(let bytes):
-          try await writer.write(.message(bytes))
+          let message = try self.deserializer.deserialize(bytes)
+          return .message(message)
 
         case .status(let status, let metadata):
           if let error = RPCError(status: status, metadata: metadata) {
             throw error
           } else {
-            try await writer.write(.trailingMetadata(metadata))
+            return .trailingMetadata(metadata)
           }
         }
       }
-    }.castError(to: RPCError.self) { error in
-      RPCError(
-        code: .unknown,
-        message: "Can't write to output stream, cancelling RPC.",
-        cause: error
-      )
-    }
 
-    // Make sure the writer is finished.
-    switch result {
-    case .success:
-      writer.finish()
-    case .failure(let error):
-      writer.finish(throwing: error)
+      @inlinable
+      mutating func next() async throws -> ClientResponse.Stream<Message>.Contents.BodyPart? {
+        try await self.next(isolation: nil)
+      }
     }
   }
 }

+ 4 - 4
Sources/GRPCCore/Streaming/Internal/RPCWriter+Map.swift

@@ -23,10 +23,10 @@ struct MapRPCWriter<Value, Mapped, Base: RPCWriterProtocol<Mapped>>: RPCWriterPr
   @usableFromInline
   let base: Base
   @usableFromInline
-  let transform: @Sendable (Value) -> Mapped
+  let transform: @Sendable (Value) throws -> Mapped
 
   @inlinable
-  init(base: Base, transform: @escaping @Sendable (Value) -> Mapped) {
+  init(base: Base, transform: @escaping @Sendable (Value) throws -> Mapped) {
     self.base = base
     self.transform = transform
   }
@@ -38,7 +38,7 @@ struct MapRPCWriter<Value, Mapped, Base: RPCWriterProtocol<Mapped>>: RPCWriterPr
 
   @inlinable
   func write(contentsOf elements: some Sequence<Value>) async throws {
-    let transformed = elements.lazy.map { self.transform($0) }
+    let transformed = try elements.lazy.map { try self.transform($0) }
     try await self.base.write(contentsOf: transformed)
   }
 }
@@ -48,7 +48,7 @@ extension RPCWriter {
   @inlinable
   static func map<Mapped>(
     into writer: some RPCWriterProtocol<Mapped>,
-    transform: @Sendable @escaping (Element) -> Mapped
+    transform: @Sendable @escaping (Element) throws -> Mapped
   ) -> Self {
     let mapper = MapRPCWriter(base: writer, transform: transform)
     return RPCWriter(wrapping: mapper)

+ 4 - 2
Sources/GRPCInterceptors/ClientTracingInterceptor.swift

@@ -46,8 +46,10 @@ public struct ClientTracingInterceptor: ClientInterceptor {
   public func intercept<Input, Output>(
     request: ClientRequest.Stream<Input>,
     context: ClientInterceptorContext,
-    next: @Sendable (ClientRequest.Stream<Input>, ClientInterceptorContext) async throws ->
-      ClientResponse.Stream<Output>
+    next: (
+      ClientRequest.Stream<Input>,
+      ClientInterceptorContext
+    ) async throws -> ClientResponse.Stream<Output>
   ) async throws -> ClientResponse.Stream<Output> where Input: Sendable, Output: Sendable {
     var request = request
     let tracer = InstrumentationSystem.tracer

+ 2 - 2
Tests/GRPCCoreTests/Test Utilities/Call/Client/ClientInterceptors.swift

@@ -52,7 +52,7 @@ struct RejectAllClientInterceptor: ClientInterceptor {
   func intercept<Input: Sendable, Output: Sendable>(
     request: ClientRequest.Stream<Input>,
     context: ClientInterceptorContext,
-    next: @Sendable (
+    next: (
       ClientRequest.Stream<Input>,
       ClientInterceptorContext
     ) async throws -> ClientResponse.Stream<Output>
@@ -77,7 +77,7 @@ struct RequestCountingClientInterceptor: ClientInterceptor {
   func intercept<Input: Sendable, Output: Sendable>(
     request: ClientRequest.Stream<Input>,
     context: ClientInterceptorContext,
-    next: @Sendable (
+    next: (
       ClientRequest.Stream<Input>,
       ClientInterceptorContext
     ) async throws -> ClientResponse.Stream<Output>