
Add support for retries (#1708)

Motivation:

The `ClientRPCExecutor` currently ignores retry and hedging policies.
This change adds support for retries.

Modifications:

- Add a retry executor and wire it up to the client RPC executor
- Add a few missing state transitions to the broadcast sequence

Result:

RPCs can be retried under certain conditions.
George Barnett · 2 years ago · commit 017dc09f11
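
For orientation, here is a minimal sketch of how a caller might opt in to retries, assuming the RetryPolicy and ClientRPCExecutionConfiguration initialisers exercised by the tests below; the concrete values are illustrative only. The nth retry is delayed by a random amount between zero and min(initialBackoff * backoffMultiplier^(n-1), maximumBackoff), unless the server overrides the delay via pushback.

import GRPCCore

// Sketch only: the initialisers mirror those used in the tests below; values are illustrative.
let policy = RetryPolicy(
  maximumAttempts: 3,                   // Upper bound on attempts, including the first.
  initialBackoff: .milliseconds(100),   // Base delay before the first retry.
  maximumBackoff: .seconds(1),          // Computed delays are clamped to this value.
  backoffMultiplier: 2.0,               // Each attempt's delay bound grows by this factor.
  retryableStatusCodes: [.unavailable]  // Only these status codes trigger a retry.
)

// Attach the policy to an execution configuration, optionally alongside an overall timeout.
let configuration = ClientRPCExecutionConfiguration(retryPolicy: policy, timeout: .seconds(5))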

+ 4 - 7
Sources/GRPCCore/Call/Client/Internal/ClientRPCExecutor+OneShotExecutor.swift

@@ -80,7 +80,8 @@ extension ClientRPCExecutor.OneShotExecutor {
 
       let streamExecutor = ClientStreamExecutor(transport: self.transport)
       group.addTask {
-        return .streamExecutorCompleted(await streamExecutor.run())
+        await streamExecutor.run()
+        return .streamExecutorCompleted
       }
 
       group.addTask {
@@ -103,14 +104,10 @@ extension ClientRPCExecutor.OneShotExecutor {
 
       while let result = await group.next() {
         switch result {
-        case .streamExecutorCompleted(.success):
+        case .streamExecutorCompleted:
           // Stream finished; wait for the response to be handled.
           ()
 
-        case .streamExecutorCompleted(.failure):
-          // Stream execution threw: cancel and wait.
-          group.cancelAll()
-
         case .timedOut(.success):
           // The deadline passed; cancel the ongoing work group.
           group.cancelAll()
@@ -137,7 +134,7 @@ extension ClientRPCExecutor.OneShotExecutor {
 
 @usableFromInline
 enum _OneShotExecutorTask<R> {
-  case streamExecutorCompleted(Result<Void, RPCError>)
+  case streamExecutorCompleted
   case timedOut(Result<Void, Error>)
   case responseHandled(Result<R, Error>)
 }

+ 306 - 0
Sources/GRPCCore/Call/Client/Internal/ClientRPCExecutor+RetryExecutor.swift

@@ -0,0 +1,306 @@
+/*
+ * Copyright 2023, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
+extension ClientRPCExecutor {
+  @usableFromInline
+  struct RetryExecutor<
+    Transport: ClientTransport,
+    Serializer: MessageSerializer,
+    Deserializer: MessageDeserializer
+  > {
+    @usableFromInline
+    typealias Input = Serializer.Message
+    @usableFromInline
+    typealias Output = Deserializer.Message
+
+    @usableFromInline
+    let transport: Transport
+    @usableFromInline
+    let policy: RetryPolicy
+    @usableFromInline
+    let timeout: Duration?
+    @usableFromInline
+    let interceptors: [any ClientInterceptor]
+    @usableFromInline
+    let serializer: Serializer
+    @usableFromInline
+    let deserializer: Deserializer
+    @usableFromInline
+    let bufferSize: Int
+
+    @inlinable
+    init(
+      transport: Transport,
+      policy: RetryPolicy,
+      timeout: Duration?,
+      interceptors: [any ClientInterceptor],
+      serializer: Serializer,
+      deserializer: Deserializer,
+      bufferSize: Int
+    ) {
+      self.transport = transport
+      self.policy = policy
+      self.timeout = timeout
+      self.interceptors = interceptors
+      self.serializer = serializer
+      self.deserializer = deserializer
+      self.bufferSize = bufferSize
+    }
+  }
+}
+
+@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
+extension ClientRPCExecutor.RetryExecutor {
+  @inlinable
+  func execute<R: Sendable>(
+    request: ClientRequest.Stream<Input>,
+    method: MethodDescriptor,
+    responseHandler: @Sendable @escaping (ClientResponse.Stream<Output>) async throws -> R
+  ) async throws -> R {
+    // There's quite a lot going on here...
+    //
+    // The high level approach is to have two levels of task group. In the outer level tasks are
+    // run to:
+    // - run a timeout task (if necessary),
+    // - run the request producer so that it writes into a broadcast sequence (in this instance we
+    //   don't care about broadcasting, but rather the sequence's ability to replay),
+    // - run the inner task group.
+    //
+    // An inner task group is run for each RPC attempt. We might also pause between attempts. The
+    // inner group runs two tasks:
+    // - a stream executor, and
+    // - the unsafe RPC executor which inspects the response, either passing it to the handler or
+    //   deciding a retry should be undertaken.
+    //
+    // It is also worth noting that the server can override the retry delay using "pushback" and
+    // retries may be skipped if the throttle is applied.
+    let result = await withTaskGroup(
+      of: _RetryExecutorTask<R>.self,
+      returning: Result<R, Error>.self
+    ) { group in
+      // Add a task to limit the overall execution time of the RPC.
+      if let timeout = self.timeout {
+        group.addTask {
+          let result = await Result {
+            try await Task.sleep(until: .now.advanced(by: timeout), clock: .continuous)
+          }
+          return .timedOut(result)
+        }
+      }
+
+      // Play the original request into the broadcast sequence and construct a replayable request.
+      let retry = BroadcastAsyncSequence<Input>.makeStream(bufferSize: self.bufferSize)
+      group.addTask {
+        let result = await Result {
+          try await request.producer(RPCWriter(wrapping: retry.continuation))
+        }
+        retry.continuation.finish(with: result)
+        return .outboundFinished(result)
+      }
+
+      // The sequence isn't limited by the number of attempts as the iterator is reset when the
+      // server applies pushback.
+      let delaySequence = RetryDelaySequence(policy: self.policy)
+      var delayIterator = delaySequence.makeIterator()
+
+      for attempt in 1 ... self.policy.maximumAttempts {
+        group.addTask {
+          await withTaskGroup(
+            of: _RetryExecutorSubTask<R>.self,
+            returning: _RetryExecutorTask<R>.self
+          ) { thisAttemptGroup in
+            let streamExecutor = ClientStreamExecutor(transport: self.transport)
+            thisAttemptGroup.addTask {
+              await streamExecutor.run()
+              return .streamProcessed
+            }
+
+            thisAttemptGroup.addTask {
+              let response = await ClientRPCExecutor.unsafeExecute(
+                request: ClientRequest.Stream(metadata: request.metadata) {
+                  try await $0.write(contentsOf: retry.stream)
+                },
+                method: method,
+                attempt: attempt,
+                serializer: self.serializer,
+                deserializer: self.deserializer,
+                interceptors: self.interceptors,
+                streamProcessor: streamExecutor
+              )
+
+              let shouldRetry: Bool
+              let retryDelayOverride: Duration?
+
+              switch response.accepted {
+              case .success:
+                // Request was accepted. This counts as success to the throttle and there's no need
+                // to retry.
+                self.transport.retryThrottle.recordSuccess()
+                retryDelayOverride = nil
+                shouldRetry = false
+
+              case .failure(let error):
+                // The request was rejected. Determine whether a retry should be carried out. The
+                // following conditions must be checked:
+                //
+                // - Whether the status code is retryable.
+                // - Whether more attempts are permitted by the config.
+                // - Whether the throttle permits another retry to be carried out.
+                // - Whether the server pushed back to either stop further retries or to override
+                //   the delay before the next retry.
+                let code = Status.Code(error.code)
+                let isRetryableStatusCode = self.policy.retryableStatusCodes.contains(code)
+
+                if isRetryableStatusCode {
+                  // Counted as failure for throttling.
+                  let throttled = self.transport.retryThrottle.recordFailure()
+
+                  // Status code can be retried. Did the server send pushback?
+                  switch error.metadata.retryPushback {
+                  case .retryAfter(let delay):
+                    // Pushback: only retry if our config permits it.
+                    shouldRetry = (attempt < self.policy.maximumAttempts) && !throttled
+                    retryDelayOverride = delay
+                  case .stopRetrying:
+                    // Server told us to stop trying.
+                    shouldRetry = false
+                    retryDelayOverride = nil
+                  case .none:
+                    // No pushback: only retry if our config permits it.
+                    shouldRetry = (attempt < self.policy.maximumAttempts) && !throttled
+                    retryDelayOverride = nil
+                    break
+                  }
+                } else {
+                  // Not-retryable; this is considered a success.
+                  self.transport.retryThrottle.recordSuccess()
+                  shouldRetry = false
+                  retryDelayOverride = nil
+                }
+              }
+
+              if shouldRetry {
+                // Cancel subscribers of the broadcast sequence. This is safe as we are the only
+                // subscriber and maximises the chances that 'isKnownSafeForNextSubscriber' will
+                // return true.
+                //
+                // Note: this must only be called if we should retry, otherwise we may cancel a
+                // subscriber for an accepted request.
+                retry.stream.invalidateAllSubscriptions()
+
+                // Only retry if we know it's safe for the next subscriber, that is, the first
+                // element is still in the buffer. It's safe to call this because there's only
+                // ever one attempt at a time and the existing subscribers have been invalidated.
+                if retry.stream.isKnownSafeForNextSubscriber {
+                  return .retry(retryDelayOverride)
+                }
+              }
+
+              // Not retrying or not safe to retry.
+              let result = await Result {
+                // Check for cancellation; the RPC may have timed out in which case we should skip
+                // the response handler.
+                try Task.checkCancellation()
+                return try await responseHandler(response)
+              }
+              return .handledResponse(result)
+            }
+
+            while let result = await thisAttemptGroup.next() {
+              switch result {
+              case .streamProcessed:
+                ()  // Continue processing; wait for the response to be handled.
+
+              case .retry(let delayOverride):
+                thisAttemptGroup.cancelAll()
+                return .retry(delayOverride)
+
+              case .handledResponse(let result):
+                thisAttemptGroup.cancelAll()
+                return .handledResponse(result)
+              }
+            }
+
+            fatalError("Internal inconsistency")
+          }
+        }
+
+        loop: while let next = await group.next() {
+          switch next {
+          case .handledResponse(let result):
+            // A usable response; cancel the remaining work and return the result.
+            group.cancelAll()
+            return result
+
+          case .retry(let delayOverride):
+            // The attempt failed, wait a bit and then retry. The server might have overridden the
+            // delay via pushback so preferentially use that value.
+            //
+            // Any error will come from cancellation: if it happens while we're sleeping we can
+            // just loop around, the next attempt will be cancelled immediately and we will return
+            // its response to the client.
+            if let delayOverride = delayOverride {
+              // If the delay is overridden with server pushback then reset the iterator for the
+              // next retry.
+              delayIterator = delaySequence.makeIterator()
+              try? await Task.sleep(until: .now.advanced(by: delayOverride), clock: .continuous)
+            } else {
+              // The delay iterator never terminates.
+              try? await Task.sleep(
+                until: .now.advanced(by: delayIterator.next()!),
+                clock: .continuous
+              )
+            }
+
+            break loop  // from the while loop so another attempt can be started.
+
+          case .timedOut(.success), .outboundFinished(.failure):
+            // Timeout task fired successfully or failed to process the outbound stream. Cancel and
+            // wait for a usable response (which is likely to be an error).
+            group.cancelAll()
+
+          case .timedOut(.failure), .outboundFinished(.success):
+            // Timeout task failed, which means it was cancelled (so no need to cancel again), or the
+            // outbound stream was successfully processed (so there's nothing to do).
+            ()
+          }
+        }
+      }
+
+      fatalError("Internal inconsistency")
+    }
+
+    return try result.get()
+  }
+}
+
+@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
+@usableFromInline
+enum _RetryExecutorTask<R> {
+  case timedOut(Result<Void, Error>)
+  case handledResponse(Result<R, Error>)
+  case retry(Duration?)
+  case outboundFinished(Result<Void, Error>)
+}
+
+@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
+@usableFromInline
+enum _RetryExecutorSubTask<R> {
+  case streamProcessed
+  case handledResponse(Result<R, Error>)
+  case retry(Duration?)
+}

+ 18 - 1
Sources/GRPCCore/Call/Client/Internal/ClientRPCExecutor.swift

@@ -58,7 +58,24 @@ enum ClientRPCExecutor {
         responseHandler: handler
       )
 
-    case .retry, .hedge:
+    case .retry(let policy):
+      let retryExecutor = RetryExecutor(
+        transport: transport,
+        policy: policy,
+        timeout: configuration.timeout,
+        interceptors: interceptors,
+        serializer: serializer,
+        deserializer: deserializer,
+        bufferSize: 64  // TODO: the client should have some control over this.
+      )
+
+      return try await retryExecutor.execute(
+        request: request,
+        method: method,
+        responseHandler: handler
+      )
+
+    case .hedge:
       fatalError()
     }
   }

+ 7 - 21
Sources/GRPCCore/Call/Client/Internal/ClientStreamExecutor.swift

@@ -51,8 +51,8 @@ internal struct ClientStreamExecutor<Transport: ClientTransport> {
   /// This is required to be running until the response returned from ``execute(request:method:)``
   /// has been processed.
   @inlinable
-  func run() async -> Result<Void, RPCError> {
-    await withTaskGroup(of: Result<Void, RPCError>.self) { group in
+  func run() async {
+    await withTaskGroup(of: Void.self) { group in
       for await event in self._work.stream {
         switch event {
         case .request(let request, let outboundStream):
@@ -66,18 +66,6 @@ internal struct ClientStreamExecutor<Transport: ClientTransport> {
           }
         }
       }
-
-      while let result = await group.next() {
-        switch result {
-        case .success:
-          ()
-        case .failure:
-          group.cancelAll()
-          return result
-        }
-      }
-
-      return .success(())
     }
   }
 
@@ -117,8 +105,10 @@ internal struct ClientStreamExecutor<Transport: ClientTransport> {
     // Start processing the request.
     self._work.continuation.yield(.request(request, stream.outbound))
 
+    let part = await self._waitForFirstResponsePart(on: stream.inbound)
+
     // Wait for the first response to determine how to handle the response.
-    switch await self._waitForFirstResponsePart(on: stream.inbound) {
+    switch part {
     case .metadata(let metadata, let iterator):
       // Expected happy case: the server is processing the request.
 
@@ -146,7 +136,7 @@ internal struct ClientStreamExecutor<Transport: ClientTransport> {
   func _processRequest<Stream: ClosableRPCWriterProtocol<RPCRequestPart>>(
     _ request: ClientRequest.Stream<[UInt8]>,
     on stream: Stream
-  ) async -> Result<Void, RPCError> {
+  ) async {
     let result = await Result {
       try await stream.write(.metadata(request.metadata))
       try await request.producer(.map(into: stream) { .message($0) })
@@ -160,8 +150,6 @@ internal struct ClientStreamExecutor<Transport: ClientTransport> {
     case .failure(let error):
       stream.finish(throwing: error)
     }
-
-    return result
   }
 
   @usableFromInline
@@ -224,7 +212,7 @@ internal struct ClientStreamExecutor<Transport: ClientTransport> {
   func _processResponse(
     writer: RPCWriter<ClientResponse.Stream<[UInt8]>.Contents.BodyPart>.Closable,
     iterator: UnsafeTransfer<Transport.Inbound.AsyncIterator>
-  ) async -> Result<Void, RPCError> {
+  ) async {
     var iterator = iterator.wrappedValue
     let result = await Result {
       while let next = try await iterator.next() {
@@ -265,7 +253,5 @@ internal struct ClientStreamExecutor<Transport: ClientTransport> {
     case .failure(let error):
       writer.finish(throwing: error)
     }
-
-    return result
   }
 }

+ 95 - 0
Sources/GRPCCore/Call/Client/Internal/RetryDelaySequence.swift

@@ -0,0 +1,95 @@
+/*
+ * Copyright 2023, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#if canImport(Darwin)
+import Darwin
+#else
+import Glibc
+#endif
+
+@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
+@usableFromInline
+struct RetryDelaySequence: Sequence {
+  @usableFromInline
+  typealias Element = Duration
+
+  @usableFromInline
+  let policy: RetryPolicy
+
+  @inlinable
+  init(policy: RetryPolicy) {
+    self.policy = policy
+  }
+
+  @inlinable
+  func makeIterator() -> Iterator {
+    Iterator(policy: self.policy)
+  }
+
+  @usableFromInline
+  struct Iterator: IteratorProtocol {
+    @usableFromInline
+    let policy: RetryPolicy
+    @usableFromInline
+    private(set) var n = 1
+
+    @inlinable
+    init(policy: RetryPolicy) {
+      self.policy = policy
+    }
+
+    @inlinable
+    var _initialBackoffSeconds: Double {
+      Self._durationToTimeInterval(self.policy.initialBackoff)
+    }
+
+    @inlinable
+    var _maximumBackoffSeconds: Double {
+      Self._durationToTimeInterval(self.policy.maximumBackoff)
+    }
+
+    @inlinable
+    mutating func next() -> Duration? {
+      defer { self.n += 1 }
+
+      /// The nth retry will happen after a randomly chosen delay between zero and
+      /// `min(initialBackoff * backoffMultiplier^(n-1), maximumBackoff)`.
+      let factor = pow(self.policy.backoffMultiplier, Double(self.n - 1))
+      let computedBackoff = self._initialBackoffSeconds * factor
+      let clampedBackoff = Swift.min(computedBackoff, self._maximumBackoffSeconds)
+      let randomisedBackoff = Double.random(in: 0.0 ... clampedBackoff)
+
+      return Self._timeIntervalToDuration(randomisedBackoff)
+    }
+
+    @inlinable
+    static func _timeIntervalToDuration(_ seconds: Double) -> Duration {
+      let secondsComponent = Int64(seconds)
+      let attoseconds = (seconds - Double(secondsComponent)) * 1e18
+      let attosecondsComponent = Int64(attoseconds)
+      return Duration(
+        secondsComponent: secondsComponent,
+        attosecondsComponent: attosecondsComponent
+      )
+    }
+
+    @inlinable
+    static func _durationToTimeInterval(_ duration: Duration) -> Double {
+      var seconds = Double(duration.components.seconds)
+      seconds += (Double(duration.components.attoseconds) / 1e18)
+      return seconds
+    }
+  }
+}

+ 25 - 0
Sources/GRPCCore/Streaming/Internal/BroadcastAsyncSequence+RPCWriter.swift

@@ -0,0 +1,25 @@
+/*
+ * Copyright 2023, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+extension BroadcastAsyncSequence.Source: ClosableRPCWriterProtocol {
+  @inlinable
+  func write(contentsOf elements: some Sequence<Element>) async throws {
+    for element in elements {
+      try await self.write(element)
+    }
+  }
+}

+ 105 - 56
Sources/GRPCCore/Streaming/Internal/BroadcastAsyncSequence.swift

@@ -430,10 +430,7 @@ struct _BroadcastSequenceStateMachine<Element: Sendable>: Sendable {
 
       @inlinable
       mutating func finish(result: Result<Void, Error>) -> OnFinish {
-        guard let continuations = self.subscriptions.removeSubscribersWithContinuations() else {
-          return .none
-        }
-
+        let continuations = self.subscriptions.removeSubscribersWithContinuations()
         return .resume(
           .init(continuations: continuations, result: result.map { nil }),
           .init(continuations: [], result: .success(()))
@@ -455,8 +452,7 @@ struct _BroadcastSequenceStateMachine<Element: Sendable>: Sendable {
       mutating func cancel(
         _ id: _BroadcastSequenceStateMachine<Element>.Subscriptions.ID
       ) -> OnCancelSubscription {
-        let (removed, continuation) = self.subscriptions.removeSubscriber(withID: id)
-        assert(removed)
+        let (_, continuation) = self.subscriptions.removeSubscriber(withID: id)
         if let continuation = continuation {
           return .resume(continuation, .failure(CancellationError()))
         } else {
@@ -469,8 +465,11 @@ struct _BroadcastSequenceStateMachine<Element: Sendable>: Sendable {
         _ continuation: ConsumerContinuation,
         forSubscription id: _BroadcastSequenceStateMachine<Element>.Subscriptions.ID
       ) -> OnSetContinuation {
-        let didSet = self.subscriptions.setContinuation(continuation, forSubscriber: id)
-        return didSet ? .none : .resume(continuation, .failure(CancellationError()))
+        if self.subscriptions.setContinuation(continuation, forSubscriber: id) {
+          return .none
+        } else {
+          return .resume(continuation, .failure(CancellationError()))
+        }
       }
 
       @inlinable
@@ -480,23 +479,28 @@ struct _BroadcastSequenceStateMachine<Element: Sendable>: Sendable {
 
       @inlinable
       mutating func invalidateAllSubscriptions() -> OnInvalidateAllSubscriptions {
+        // Remove subscriptions with continuations, they need to be failed.
+        let continuations = self.subscriptions.removeSubscribersWithContinuations()
+        let consumerContinuations = ConsumerContinuations(
+          continuations: continuations,
+          result: .failure(BroadcastAsyncSequenceError.consumingTooSlow)
+        )
+
+        // Remove any others to be failed when they next call 'next'.
         let ids = self.subscriptions.removeAllSubscribers()
         self.subscriptionsToDrop.append(contentsOf: ids)
-        return .none
+        return .resume(consumerContinuations)
       }
 
       @inlinable
       mutating func dropResources(error: BroadcastAsyncSequenceError) -> OnDropResources {
-        if let continuations = self.subscriptions.removeSubscribersWithContinuations() {
-          let consumerContinuations = ConsumerContinuations(
-            continuations: continuations,
-            result: .failure(error)
-          )
-          let producerContinuations = ProducerContinuations(continuations: [], result: .success(()))
-          return .resume(consumerContinuations, producerContinuations)
-        } else {
-          return .none
-        }
+        let continuations = self.subscriptions.removeSubscribersWithContinuations()
+        let consumerContinuations = ConsumerContinuations(
+          continuations: continuations,
+          result: .failure(error)
+        )
+        let producerContinuations = ProducerContinuations(continuations: [], result: .success(()))
+        return .resume(consumerContinuations, producerContinuations)
       }
     }
 
@@ -682,16 +686,18 @@ struct _BroadcastSequenceStateMachine<Element: Sendable>: Sendable {
         _ continuation: ConsumerContinuation,
         forSubscription id: _BroadcastSequenceStateMachine<Element>.Subscriptions.ID
       ) -> OnSetContinuation {
-        let didSet = self.subscriptions.setContinuation(continuation, forSubscriber: id)
-        return didSet ? .none : .resume(continuation, .failure(CancellationError()))
+        if self.subscriptions.setContinuation(continuation, forSubscriber: id) {
+          return .none
+        } else {
+          return .resume(continuation, .failure(CancellationError()))
+        }
       }
 
       @inlinable
       mutating func cancel(
         _ id: _BroadcastSequenceStateMachine<Element>.Subscriptions.ID
       ) -> OnCancelSubscription {
-        let (removed, continuation) = self.subscriptions.removeSubscriber(withID: id)
-        assert(removed)
+        let (_, continuation) = self.subscriptions.removeSubscriber(withID: id)
         if let continuation = continuation {
           return .resume(continuation, .failure(CancellationError()))
         } else {
@@ -739,7 +745,7 @@ struct _BroadcastSequenceStateMachine<Element: Sendable>: Sendable {
         let producers = self.producers.map { $0.0 }
         self.producers.removeAll()
         return .resume(
-          .init(continuations: continuations ?? .many([]), result: result.map { nil }),
+          .init(continuations: continuations, result: result.map { nil }),
           .init(continuations: producers, result: .success(()))
         )
       }
@@ -751,33 +757,26 @@ struct _BroadcastSequenceStateMachine<Element: Sendable>: Sendable {
 
       @inlinable
       mutating func invalidateAllSubscriptions() -> OnInvalidateAllSubscriptions {
-        let onCancel: OnInvalidateAllSubscriptions
-
         // Remove subscriptions with continuations, they need to be failed.
-        switch self.subscriptions.removeSubscribersWithContinuations() {
-        case .some(let oneOrMany):
-          let continuations = ConsumerContinuations(
-            continuations: oneOrMany,
-            result: .failure(
-              BroadcastAsyncSequenceError.consumingTooSlow
-            )
-          )
-          onCancel = .resume(continuations)
-        case .none:
-          onCancel = .none
-        }
+        let continuations = self.subscriptions.removeSubscribersWithContinuations()
+        let consumerContinuations = ConsumerContinuations(
+          continuations: continuations,
+          result: .failure(BroadcastAsyncSequenceError.consumingTooSlow)
+        )
 
         // Remove any others to be failed when they next call 'next'.
         let ids = self.subscriptions.removeAllSubscribers()
         self.subscriptionsToDrop.append(contentsOf: ids)
-        return onCancel
+        return .resume(consumerContinuations)
       }
 
       @inlinable
       mutating func dropResources(error: BroadcastAsyncSequenceError) -> OnDropResources {
-        let consumers = self.subscriptions.removeSubscribersWithContinuations().map {
-          ConsumerContinuations(continuations: $0, result: .failure(error))
-        }
+        let continuations = self.subscriptions.removeSubscribersWithContinuations()
+        let consumerContinuations = ConsumerContinuations(
+          continuations: continuations,
+          result: .failure(error)
+        )
 
         let producers = ProducerContinuations(
           continuations: self.producers.map { $0.0 },
@@ -786,10 +785,7 @@ struct _BroadcastSequenceStateMachine<Element: Sendable>: Sendable {
 
         self.producers.removeAll()
 
-        return .resume(
-          consumers ?? ConsumerContinuations(continuations: .many([]), result: .failure(error)),
-          producers
-        )
+        return .resume(consumerContinuations, producers)
       }
 
       @inlinable
@@ -894,10 +890,49 @@ struct _BroadcastSequenceStateMachine<Element: Sendable>: Sendable {
         self.subscriptions.subscribe()
       }
 
+      @inlinable
+      mutating func invalidateAllSubscriptions() -> OnInvalidateAllSubscriptions {
+        // Remove subscriptions with continuations, they need to be failed.
+        let continuations = self.subscriptions.removeSubscribersWithContinuations()
+        let consumerContinuations = ConsumerContinuations(
+          continuations: continuations,
+          result: .failure(BroadcastAsyncSequenceError.consumingTooSlow)
+        )
+
+        // Remove any others to be failed when they next call 'next'.
+        let ids = self.subscriptions.removeAllSubscribers()
+        self.subscriptionsToDrop.append(contentsOf: ids)
+        return .resume(consumerContinuations)
+      }
+
+      @inlinable
+      mutating func dropResources(error: BroadcastAsyncSequenceError) -> OnDropResources {
+        let continuations = self.subscriptions.removeSubscribersWithContinuations()
+        let consumerContinuations = ConsumerContinuations(
+          continuations: continuations,
+          result: .failure(error)
+        )
+
+        let producers = ProducerContinuations(continuations: [], result: .failure(error))
+        return .resume(consumerContinuations, producers)
+      }
+
       @inlinable
       func nextSubscriptionIsValid() -> Bool {
         self.elements.lowestID == .initial
       }
+
+      @inlinable
+      mutating func cancel(
+        _ id: _BroadcastSequenceStateMachine<Element>.Subscriptions.ID
+      ) -> OnCancelSubscription {
+        let (_, continuation) = self.subscriptions.removeSubscriber(withID: id)
+        if let continuation = continuation {
+          return .resume(continuation, .failure(CancellationError()))
+        } else {
+          return .none
+        }
+      }
     }
   }
 
@@ -944,15 +979,19 @@ struct _BroadcastSequenceStateMachine<Element: Sendable>: Sendable {
       onCancel = .none
 
     case .subscribed(var state):
+      self._state = ._modifying
       onCancel = state.invalidateAllSubscriptions()
       self._state = .subscribed(state)
 
     case .streaming(var state):
+      self._state = ._modifying
       onCancel = state.invalidateAllSubscriptions()
       self._state = .streaming(state)
 
-    case .finished:
-      onCancel = .none
+    case .finished(var state):
+      self._state = ._modifying
+      onCancel = state.invalidateAllSubscriptions()
+      self._state = .finished(state)
 
     case ._modifying:
       fatalError("Internal inconsistency")
@@ -1124,7 +1163,10 @@ struct _BroadcastSequenceStateMachine<Element: Sendable>: Sendable {
       onSetContinuation = state.setContinuation(continuation, forSubscription: id)
       self._state = .streaming(state)
 
-    case .finished, ._modifying:
+    case .finished(let state):
+      onSetContinuation = .resume(continuation, state.result.map { _ in nil })
+
+    case ._modifying:
       // All values must have been produced, nothing to wait for.
       fatalError("Internal inconsistency")
     }
@@ -1159,7 +1201,12 @@ struct _BroadcastSequenceStateMachine<Element: Sendable>: Sendable {
       onCancel = state.cancel(id)
       self._state = .streaming(state)
 
-    case .finished, ._modifying:
+    case .finished(var state):
+      self._state = ._modifying
+      onCancel = state.cancel(id)
+      self._state = .finished(state)
+
+    case ._modifying:
       // All values must have been produced, nothing to wait for.
       fatalError("Internal inconsistency")
     }
@@ -1293,8 +1340,10 @@ struct _BroadcastSequenceStateMachine<Element: Sendable>: Sendable {
       onDrop = state.dropResources(error: error)
       self._state = .finished(State.Finished(from: state, result: .failure(error)))
 
-    case .finished:
-      onDrop = .none
+    case .finished(var state):
+      self._state = ._modifying
+      onDrop = state.dropResources(error: error)
+      self._state = .finished(state)
 
     case ._modifying:
       fatalError("Internal inconsistency")
@@ -1461,7 +1510,7 @@ extension _BroadcastSequenceStateMachine {
 
 @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
 extension _BroadcastSequenceStateMachine {
-  /// A collection of subcriptions.
+  /// A collection of subscriptions.
   @usableFromInline
   struct Subscriptions: Sendable {
     @usableFromInline
@@ -1675,14 +1724,14 @@ extension _BroadcastSequenceStateMachine {
 
     /// Removes all subscribers which have continuations and return their continuations.
     @inlinable
-    mutating func removeSubscribersWithContinuations() -> _OneOrMany<ConsumerContinuation>? {
+    mutating func removeSubscribersWithContinuations() -> _OneOrMany<ConsumerContinuation> {
       // Avoid allocs if there's only one subscriber.
       let count = self._countPendingContinuations()
-      let result: _OneOrMany<ConsumerContinuation>?
+      let result: _OneOrMany<ConsumerContinuation>
 
       switch count {
       case 0:
-        result = nil
+        result = .many([])
 
       case 1:
         let index = self._subscribers.firstIndex(where: { $0.continuation != nil })!

+ 303 - 0
Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTests+Retries.swift

@@ -0,0 +1,303 @@
+/*
+ * Copyright 2023, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import GRPCCore
+import XCTest
+
+extension ClientRPCExecutorTests {
+  fileprivate func makeHarnessForRetries(
+    rejectUntilAttempt firstSuccessfulAttempt: Int,
+    withCode code: RPCError.Code,
+    consumeInboundStream: Bool = false
+  ) -> ClientRPCExecutorTestHarness {
+    return ClientRPCExecutorTestHarness(
+      server: .attemptBased { attempt in
+        guard attempt < firstSuccessfulAttempt else {
+          return .echo
+        }
+
+        return .reject(
+          withError: RPCError(code: code, message: ""),
+          consumeInbound: consumeInboundStream
+        )
+      }
+    )
+  }
+
+  func testRetriesEventuallySucceed() async throws {
+    let harness = self.makeHarnessForRetries(rejectUntilAttempt: 3, withCode: .unavailable)
+    try await harness.bidirectional(
+      request: ClientRequest.Stream(metadata: ["foo": "bar"]) {
+        try await $0.write([0])
+        try await $0.write([1])
+        try await $0.write([2])
+      },
+      configuration: .retry(codes: [.unavailable])
+    ) { response in
+      XCTAssertEqual(
+        response.metadata,
+        [
+          "foo": "bar",
+          "grpc-previous-rpc-attempts": "2",
+        ]
+      )
+      let messages = try await response.messages.collect()
+      XCTAssertEqual(messages, [[0], [1], [2]])
+    }
+
+    // Success on the third attempt.
+    XCTAssertEqual(harness.clientStreamsOpened, 3)
+    XCTAssertEqual(harness.serverStreamsAccepted, 3)
+  }
+
+  func testRetriesRespectRetryableCodes() async throws {
+    let harness = self.makeHarnessForRetries(rejectUntilAttempt: 3, withCode: .unavailable)
+    try await harness.bidirectional(
+      request: ClientRequest.Stream(metadata: ["foo": "bar"]) {
+        try await $0.write([0, 1, 2])
+      },
+      configuration: .retry(codes: [.aborted])
+    ) { response in
+      switch response.accepted {
+      case .success:
+        XCTFail("Expected response to be rejected")
+      case .failure(let error):
+        XCTAssertEqual(error.code, .unavailable)
+      }
+    }
+
+    // Error code wasn't retryable, only one stream.
+    XCTAssertEqual(harness.clientStreamsOpened, 1)
+    XCTAssertEqual(harness.serverStreamsAccepted, 1)
+  }
+
+  func testRetriesRespectRetryLimit() async throws {
+    let harness = self.makeHarnessForRetries(rejectUntilAttempt: 5, withCode: .unavailable)
+    try await harness.bidirectional(
+      request: ClientRequest.Stream(metadata: ["foo": "bar"]) {
+        try await $0.write([0, 1, 2])
+      },
+      configuration: .retry(maximumAttempts: 2, codes: [.unavailable])
+    ) { response in
+      switch response.accepted {
+      case .success:
+        XCTFail("Expected response to be rejected")
+      case .failure(let error):
+        XCTAssertEqual(error.code, .unavailable)
+        XCTAssertEqual(Array(error.metadata[stringValues: "grpc-previous-rpc-attempts"]), ["1"])
+      }
+    }
+
+    // Only two attempts permitted.
+    XCTAssertEqual(harness.clientStreamsOpened, 2)
+    XCTAssertEqual(harness.serverStreamsAccepted, 2)
+  }
+
+  func testRetriesCantBeExecutedForTooManyRequestMessages() async throws {
+    let harness = self.makeHarnessForRetries(
+      rejectUntilAttempt: 3,
+      withCode: .unavailable,
+      consumeInboundStream: true
+    )
+
+    try await harness.bidirectional(
+      request: ClientRequest.Stream {
+        for _ in 0 ..< 1000 {
+          try await $0.write([])
+        }
+      },
+      configuration: .retry(codes: [.unavailable])
+    ) { response in
+      switch response.accepted {
+      case .success:
+        XCTFail("Expected response to be rejected")
+      case .failure(let error):
+        XCTAssertEqual(error.code, .unavailable)
+        XCTAssertFalse(error.metadata.contains { $0.key == "grpc-previous-rpc-attempts" })
+      }
+    }
+
+    // The request stream can't be buffered as (a) it's large, and (b) the server consumes it before
+    // responding. Even though the server responded with a retryable status code, the request buffer
+    // was dropped, so only one attempt was made.
+    XCTAssertEqual(harness.clientStreamsOpened, 1)
+    XCTAssertEqual(harness.serverStreamsAccepted, 1)
+  }
+
+  func testRetriesWithImmediateTimeout() async throws {
+    let harness = ClientRPCExecutorTestHarness(
+      server: .sleepFor(duration: .milliseconds(250), then: .echo)
+    )
+
+    await XCTAssertThrowsErrorAsync {
+      try await harness.bidirectional(
+        request: ClientRequest.Stream {
+          try await $0.write([0])
+          try await $0.write([1])
+          try await $0.write([2])
+        },
+        configuration: .retry(codes: [.unavailable], timeout: .zero)
+      ) { response in
+        XCTFail("Response not expected to be handled")
+      }
+    } errorHandler: { error in
+      XCTAssert(error is CancellationError)
+    }
+  }
+
+  func testRetriesWithTimeoutDuringFirstAttempt() async throws {
+    let harness = ClientRPCExecutorTestHarness(
+      server: .sleepFor(duration: .milliseconds(250), then: .echo)
+    )
+
+    await XCTAssertThrowsErrorAsync {
+      try await harness.bidirectional(
+        request: ClientRequest.Stream {
+          try await $0.write([0])
+          try await $0.write([1])
+          try await $0.write([2])
+        },
+        configuration: .retry(codes: [.unavailable], timeout: .milliseconds(50))
+      ) { response in
+        XCTFail("Response not expected to be handled")
+      }
+    } errorHandler: { error in
+      XCTAssert(error is CancellationError)
+    }
+  }
+
+  func testRetriesWithTimeoutDuringSecondAttempt() async throws {
+    let harness = ClientRPCExecutorTestHarness(
+      server: .sleepFor(
+        duration: .milliseconds(100),
+        then: .reject(withError: RPCError(code: .unavailable, message: ""))
+      )
+    )
+
+    await XCTAssertThrowsErrorAsync {
+      try await harness.bidirectional(
+        request: ClientRequest.Stream {
+          try await $0.write([0])
+          try await $0.write([1])
+          try await $0.write([2])
+        },
+        configuration: .retry(codes: [.unavailable], timeout: .milliseconds(150))
+      ) { response in
+        XCTFail("Response not expected to be handled")
+      }
+    } errorHandler: { error in
+      XCTAssert(error is CancellationError)
+    }
+  }
+
+  func testRetriesWithServerPushback() async throws {
+    let harness = ClientRPCExecutorTestHarness(
+      server: .attemptBased { attempt in
+        if attempt == 2 {
+          return .echo
+        } else {
+          return .init { stream in
+            // Use a short pushback to override the long configured retry delay.
+            let status = Status(code: .unavailable, message: "")
+            let metadata: Metadata = ["grpc-retry-pushback-ms": "10"]
+            try await stream.outbound.write(.status(status, metadata))
+          }
+        }
+      }
+    )
+
+    let retryPolicy = RetryPolicy(
+      maximumAttempts: 5,
+      initialBackoff: .seconds(60),
+      maximumBackoff: .seconds(50),
+      backoffMultiplier: 1,
+      retryableStatusCodes: [.unavailable]
+    )
+
+    let start = ContinuousClock.now
+    try await harness.bidirectional(
+      request: ClientRequest.Stream {
+        try await $0.write([0])
+      },
+      configuration: .init(retryPolicy: retryPolicy)
+    ) { response in
+      let end = ContinuousClock.now
+      let duration = end - start
+      // Loosely check whether the RPC completed in less than 60 seconds (i.e. the configured retry
+      // delay). Allow lots of headroom to avoid false negatives; CI systems can be slow.
+      XCTAssertLessThanOrEqual(duration, .seconds(5))
+      XCTAssertEqual(Array(response.metadata[stringValues: "grpc-previous-rpc-attempts"]), ["1"])
+    }
+  }
+
+  func testRetriesWithNegativeServerPushback() async throws {
+    // Negative values and values which can't be parsed should halt retries.
+    for pushback in ["-1", "not-an-int"] {
+      let harness = ClientRPCExecutorTestHarness(
+        server: .reject(
+          withError: RPCError(
+            code: .unavailable,
+            message: "",
+            metadata: ["grpc-retry-pushback-ms": "\(pushback)"]
+          )
+        )
+      )
+
+      let retryPolicy = RetryPolicy(
+        maximumAttempts: 5,
+        initialBackoff: .seconds(60),
+        maximumBackoff: .seconds(50),
+        backoffMultiplier: 1,
+        retryableStatusCodes: [.unavailable]
+      )
+
+      try await harness.bidirectional(
+        request: ClientRequest.Stream {
+          try await $0.write([0])
+        },
+        configuration: .init(retryPolicy: retryPolicy)
+      ) { response in
+        switch response.accepted {
+        case .success:
+          XCTFail("Expected RPC to fail")
+        case .failure(let error):
+          XCTAssertEqual(error.code, .unavailable)
+        }
+      }
+
+      // Only one attempt should be made.
+      XCTAssertEqual(harness.clientStreamsOpened, 1)
+      XCTAssertEqual(harness.serverStreamsAccepted, 1)
+    }
+  }
+}
+
+extension ClientRPCExecutionConfiguration {
+  fileprivate static func retry(
+    maximumAttempts: Int = 5,
+    codes: Set<Status.Code>,
+    timeout: Duration? = nil
+  ) -> Self {
+    let policy = RetryPolicy(
+      maximumAttempts: maximumAttempts,
+      initialBackoff: .milliseconds(10),
+      maximumBackoff: .milliseconds(100),
+      backoffMultiplier: 1.6,
+      retryableStatusCodes: codes
+    )
+
+    return ClientRPCExecutionConfiguration(retryPolicy: policy, timeout: timeout)
+  }
+}

+ 92 - 0
Tests/GRPCCoreTests/Call/Client/RetryDelaySequenceTests.swift

@@ -0,0 +1,92 @@
+/*
+ * Copyright 2023, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import XCTest
+
+@testable import GRPCCore
+
+final class RetryDelaySequenceTests: XCTestCase {
+  func testSequence() {
+    let policy = RetryPolicy(
+      maximumAttempts: 1,  // ignored here
+      initialBackoff: .seconds(1),
+      maximumBackoff: .seconds(8),
+      backoffMultiplier: 2.0,
+      retryableStatusCodes: [.aborted]  // ignored here
+    )
+
+    let sequence = RetryDelaySequence(policy: policy)
+    var iterator = sequence.makeIterator()
+
+    // The iterator will never return 'nil', so '!' is safe.
+    XCTAssertLessThanOrEqual(iterator.next()!, .seconds(1))
+    XCTAssertLessThanOrEqual(iterator.next()!, .seconds(2))
+    XCTAssertLessThanOrEqual(iterator.next()!, .seconds(4))
+    XCTAssertLessThanOrEqual(iterator.next()!, .seconds(8))
+    XCTAssertLessThanOrEqual(iterator.next()!, .seconds(8))  // Clamped
+  }
+
+  func testSequenceSupportsMultipleIteration() {
+    let policy = RetryPolicy(
+      maximumAttempts: 1,  // ignored here
+      initialBackoff: .seconds(1),
+      maximumBackoff: .seconds(8),
+      backoffMultiplier: 2.0,
+      retryableStatusCodes: [.aborted]  // ignored here
+    )
+
+    let sequence = RetryDelaySequence(policy: policy)
+    for _ in 0 ..< 10 {
+      var iterator = sequence.makeIterator()
+      // The iterator will never return 'nil', so '!' is safe.
+      XCTAssertLessThanOrEqual(iterator.next()!, .seconds(1))
+      XCTAssertLessThanOrEqual(iterator.next()!, .seconds(2))
+      XCTAssertLessThanOrEqual(iterator.next()!, .seconds(4))
+      XCTAssertLessThanOrEqual(iterator.next()!, .seconds(8))
+      XCTAssertLessThanOrEqual(iterator.next()!, .seconds(8))  // Clamped
+    }
+  }
+
+  func testDurationToDouble() {
+    let testData: [(Duration, Double)] = [
+      (.zero, 0.0),
+      (.seconds(1), 1.0),
+      (.milliseconds(1500), 1.5),
+      (.nanoseconds(1_000_000_000), 1.0),
+      (.nanoseconds(3_141_592_653), 3.141592653),
+    ]
+
+    for (duration, expected) in testData {
+      XCTAssertEqual(RetryDelaySequence.Iterator._durationToTimeInterval(duration), expected)
+    }
+  }
+
+  func testDoubleToDuration() {
+    let testData: [(Double, Duration)] = [
+      (0.0, .zero),
+      (1.0, .seconds(1)),
+      (1.5, .milliseconds(1500)),
+      (1.0, .nanoseconds(1_000_000_000)),
+      (3.141592653, .nanoseconds(3_141_592_653)),
+    ]
+
+    for (seconds, expected) in testData {
+      let actual = RetryDelaySequence.Iterator._timeIntervalToDuration(seconds)
+      XCTAssertEqual(actual.components.seconds, expected.components.seconds)
+      // We lose some precision in the conversion; that's fine.
+      XCTAssertEqual(actual.components.attoseconds / 1_000, expected.components.attoseconds / 1_000)
+    }
+  }
+}