Browse Source

Suspend request stream writes before the RPC is ready (#1411)

Motivation:

Writes on the request stream should suspend before the RPC is ready.
That's not the case right now.

Modifications:

- Allow the writability of a request stream writer to be specified when
  initialized and have clients default to not writable.
- Plumb through an `onStart` callback which is triggered on the
  `channelActive` of the HTTP/2 stream which toggles the writability.

Result:

Attempting to write on a request stream before the underlying http/2
stream is ready will suspend.
George Barnett 3 years ago
parent
commit
8c5a8af968

+ 2 - 1
.github/workflows/ci.yaml

@@ -25,7 +25,8 @@ jobs:
       matrix:
         include:
           - image: swift:5.6-focal
-            swift-test-flags: "--sanitize=thread"
+            # No TSAN because of: https://github.com/apple/swift/issues/59068
+            # swift-test-flags: "--sanitize=thread"
           - image: swift:5.5-focal
             swift-test-flags: "--sanitize=thread"
           - image: swift:5.4-focal

+ 3 - 1
Sources/GRPC/AsyncAwaitSupport/AsyncWriter.swift

@@ -109,7 +109,7 @@ internal final actor AsyncWriter<Delegate: AsyncWriterDelegate>: Sendable {
 
   /// Whether the writer is paused.
   @usableFromInline
-  internal var _isPaused: Bool = false
+  internal var _isPaused: Bool
 
   /// The delegate to process elements. By convention we call the delegate before resuming any
   /// continuation.
@@ -120,12 +120,14 @@ internal final actor AsyncWriter<Delegate: AsyncWriterDelegate>: Sendable {
   internal init(
     maxPendingElements: Int = 16,
     maxWritesBeforeYield: Int = 5,
+    isWritable: Bool = true,
     delegate: Delegate
   ) {
     self._maxPendingElements = maxPendingElements
     self._maxWritesBeforeYield = maxWritesBeforeYield
     self._pendingElements = CircularBuffer(initialCapacity: maxPendingElements)
     self._completionState = .incomplete
+    self._isPaused = !isWritable
     self._delegate = delegate
   }
 

+ 2 - 1
Sources/GRPC/AsyncAwaitSupport/Call+AsyncRequestStreamWriter.swift

@@ -26,7 +26,8 @@ extension Call {
       self.send(.end, promise: nil)
     }
 
-    return GRPCAsyncRequestStreamWriter(asyncWriter: .init(delegate: delegate))
+    // Start as not-writable; writability will be toggled when the stream comes up.
+    return GRPCAsyncRequestStreamWriter(asyncWriter: .init(isWritable: false, delegate: delegate))
   }
 }
 

+ 3 - 0
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncBidirectionalStreamingCall.swift

@@ -86,6 +86,9 @@ public struct GRPCAsyncBidirectionalStreamingCall<Request: Sendable, Response: S
     let asyncCall = Self(call: call)
 
     asyncCall.call.invokeStreamingRequests(
+      onStart: {
+        asyncCall.requestStream.asyncWriter.toggleWritabilityAsynchronously()
+      },
       onError: { error in
         asyncCall.responseParts.handleError(error)
         asyncCall.responseSource.finish(throwing: error)

+ 3 - 0
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncClientStreamingCall.swift

@@ -85,6 +85,9 @@ public struct GRPCAsyncClientStreamingCall<Request: Sendable, Response: Sendable
     let asyncCall = Self(call: call)
 
     asyncCall.call.invokeStreamingRequests(
+      onStart: {
+        asyncCall.requestStream.asyncWriter.toggleWritabilityAsynchronously()
+      },
       onError: { error in
         asyncCall.responseParts.handleError(error)
         asyncCall.requestStream.asyncWriter.cancelAsynchronously()

+ 1 - 0
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerStreamingCall.swift

@@ -88,6 +88,7 @@ public struct GRPCAsyncServerStreamingCall<Request: Sendable, Response: Sendable
 
     asyncCall.call.invokeUnaryRequest(
       request,
+      onStart: {},
       onError: { error in
         asyncCall.responseParts.handleError(error)
         asyncCall.responseSource.finish(throwing: error)

+ 1 - 0
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncUnaryCall.swift

@@ -84,6 +84,7 @@ public struct GRPCAsyncUnaryCall<Request: Sendable, Response: Sendable>: Sendabl
     self.responseParts = UnaryResponseParts(on: call.eventLoop)
     self.call.invokeUnaryRequest(
       request,
+      onStart: {},
       onError: self.responseParts.handleError(_:),
       onResponsePart: self.responseParts.handle(_:)
     )

+ 1 - 0
Sources/GRPC/ClientCalls/BidirectionalStreamingCall.swift

@@ -84,6 +84,7 @@ public struct BidirectionalStreamingCall<
 
   internal func invoke() {
     self.call.invokeStreamingRequests(
+      onStart: {},
       onError: self.responseParts.handleError(_:),
       onResponsePart: self.responseParts.handle(_:)
     )

+ 32 - 8
Sources/GRPC/ClientCalls/Call.swift

@@ -123,10 +123,10 @@ public final class Call<Request, Response> {
     self.options.logger.debug("starting rpc", metadata: ["path": "\(self.path)"], source: "GRPC")
 
     if self.eventLoop.inEventLoop {
-      self._invoke(onError: onError, onResponsePart: onResponsePart)
+      self._invoke(onStart: {}, onError: onError, onResponsePart: onResponsePart)
     } else {
       self.eventLoop.execute {
-        self._invoke(onError: onError, onResponsePart: onResponsePart)
+        self._invoke(onStart: {}, onError: onError, onResponsePart: onResponsePart)
       }
     }
   }
@@ -262,6 +262,7 @@ extension Call {
   /// - Important: This *must* to be called from the `eventLoop`.
   @usableFromInline
   internal func _invoke(
+    onStart: @escaping () -> Void,
     onError: @escaping (Error) -> Void,
     onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
   ) {
@@ -275,6 +276,7 @@ extension Call {
         withOptions: self.options,
         onEventLoop: self.eventLoop,
         interceptedBy: self._interceptors,
+        onStart: onStart,
         onError: onError,
         onResponsePart: onResponsePart
       )
@@ -354,14 +356,25 @@ extension Call {
   @inlinable
   internal func invokeUnaryRequest(
     _ request: Request,
+    onStart: @escaping () -> Void,
     onError: @escaping (Error) -> Void,
     onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
   ) {
     if self.eventLoop.inEventLoop {
-      self._invokeUnaryRequest(request: request, onError: onError, onResponsePart: onResponsePart)
+      self._invokeUnaryRequest(
+        request: request,
+        onStart: onStart,
+        onError: onError,
+        onResponsePart: onResponsePart
+      )
     } else {
       self.eventLoop.execute {
-        self._invokeUnaryRequest(request: request, onError: onError, onResponsePart: onResponsePart)
+        self._invokeUnaryRequest(
+          request: request,
+          onStart: onStart,
+          onError: onError,
+          onResponsePart: onResponsePart
+        )
       }
     }
   }
@@ -373,14 +386,23 @@ extension Call {
   ///   - onResponsePart: A callback invoked for each response part received.
   @inlinable
   internal func invokeStreamingRequests(
+    onStart: @escaping () -> Void,
     onError: @escaping (Error) -> Void,
     onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
   ) {
     if self.eventLoop.inEventLoop {
-      self._invokeStreamingRequests(onError: onError, onResponsePart: onResponsePart)
+      self._invokeStreamingRequests(
+        onStart: onStart,
+        onError: onError,
+        onResponsePart: onResponsePart
+      )
     } else {
       self.eventLoop.execute {
-        self._invokeStreamingRequests(onError: onError, onResponsePart: onResponsePart)
+        self._invokeStreamingRequests(
+          onStart: onStart,
+          onError: onError,
+          onResponsePart: onResponsePart
+        )
       }
     }
   }
@@ -389,13 +411,14 @@ extension Call {
   @usableFromInline
   internal func _invokeUnaryRequest(
     request: Request,
+    onStart: @escaping () -> Void,
     onError: @escaping (Error) -> Void,
     onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
   ) {
     self.eventLoop.assertInEventLoop()
     assert(self.type == .unary || self.type == .serverStreaming)
 
-    self._invoke(onError: onError, onResponsePart: onResponsePart)
+    self._invoke(onStart: onStart, onError: onError, onResponsePart: onResponsePart)
     self._send(.metadata(self.options.customMetadata), promise: nil)
     self._send(
       .message(request, .init(compress: self.isCompressionEnabled, flush: false)),
@@ -407,13 +430,14 @@ extension Call {
   /// On-`EventLoop` implementation of `invokeStreamingRequests(_:)`.
   @usableFromInline
   internal func _invokeStreamingRequests(
+    onStart: @escaping () -> Void,
     onError: @escaping (Error) -> Void,
     onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
   ) {
     self.eventLoop.assertInEventLoop()
     assert(self.type == .clientStreaming || self.type == .bidirectionalStreaming)
 
-    self._invoke(onError: onError, onResponsePart: onResponsePart)
+    self._invoke(onStart: onStart, onError: onError, onResponsePart: onResponsePart)
     self._send(.metadata(self.options.customMetadata), promise: nil)
   }
 }

+ 1 - 0
Sources/GRPC/ClientCalls/ClientStreamingCall.swift

@@ -84,6 +84,7 @@ public struct ClientStreamingCall<RequestPayload, ResponsePayload>: StreamingReq
 
   internal func invoke() {
     self.call.invokeStreamingRequests(
+      onStart: {},
       onError: self.responseParts.handleError(_:),
       onResponsePart: self.responseParts.handle(_:)
     )

+ 1 - 0
Sources/GRPC/ClientCalls/ServerStreamingCall.swift

@@ -80,6 +80,7 @@ public struct ServerStreamingCall<RequestPayload, ResponsePayload>: ClientCall {
   internal func invoke(_ request: RequestPayload) {
     self.call.invokeUnaryRequest(
       request,
+      onStart: {},
       onError: self.responseParts.handleError(_:),
       onResponsePart: self.responseParts.handle(_:)
     )

+ 1 - 0
Sources/GRPC/ClientCalls/UnaryCall.swift

@@ -84,6 +84,7 @@ public struct UnaryCall<RequestPayload, ResponsePayload>: UnaryResponseClientCal
   internal func invoke(_ request: RequestPayload) {
     self.call.invokeUnaryRequest(
       request,
+      onStart: {},
       onError: self.responseParts.handleError(_:),
       onResponsePart: self.responseParts.handle(_:)
     )

+ 6 - 0
Sources/GRPC/Interceptor/ClientTransport.swift

@@ -93,6 +93,9 @@ internal final class ClientTransport<Request, Response> {
   /// The `NIO.Channel` used by the transport, if it is available.
   private var channel: Channel?
 
+  /// A callback which is invoked once when the stream channel becomes active.
+  private let onStart: () -> Void
+
   /// Our current state as logging metadata.
   private var stateForLogging: Logger.MetadataValue {
     if self.state.mayBuffer {
@@ -109,11 +112,13 @@ internal final class ClientTransport<Request, Response> {
     serializer: AnySerializer<Request>,
     deserializer: AnyDeserializer<Response>,
     errorDelegate: ClientErrorDelegate?,
+    onStart: @escaping () -> Void,
     onError: @escaping (Error) -> Void,
     onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
   ) {
     self.callEventLoop = eventLoop
     self.callDetails = details
+    self.onStart = onStart
     let logger = GRPCLogger(wrapping: details.options.logger)
     self.logger = logger
     self.serializer = serializer
@@ -332,6 +337,7 @@ extension ClientTransport {
       self._pipeline?.logger = self.logger
       self.logger.debug("activated stream channel")
       self.channel = channel
+      self.onStart()
       self.unbuffer()
 
     case .close:

+ 5 - 0
Sources/GRPC/Interceptor/ClientTransportFactory.swift

@@ -140,6 +140,7 @@ internal struct ClientTransportFactory<Request, Response> {
     withOptions options: CallOptions,
     onEventLoop eventLoop: EventLoop,
     interceptedBy interceptors: [ClientInterceptor<Request, Response>],
+    onStart: @escaping () -> Void,
     onError: @escaping (Error) -> Void,
     onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
   ) -> ClientTransport<Request, Response> {
@@ -151,6 +152,7 @@ internal struct ClientTransportFactory<Request, Response> {
         withOptions: options,
         onEventLoop: eventLoop,
         interceptedBy: interceptors,
+        onStart: onStart,
         onError: onError,
         onResponsePart: onResponsePart
       )
@@ -220,6 +222,7 @@ internal struct HTTP2ClientTransportFactory<Request, Response> {
     withOptions options: CallOptions,
     onEventLoop eventLoop: EventLoop,
     interceptedBy interceptors: [ClientInterceptor<Request, Response>],
+    onStart: @escaping () -> Void,
     onError: @escaping (Error) -> Void,
     onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
   ) -> ClientTransport<Request, Response> {
@@ -230,6 +233,7 @@ internal struct HTTP2ClientTransportFactory<Request, Response> {
       serializer: self.serializer,
       deserializer: self.deserializer,
       errorDelegate: self.errorDelegate,
+      onStart: onStart,
       onError: onError,
       onResponsePart: onResponsePart
     )
@@ -333,6 +337,7 @@ internal struct FakeClientTransportFactory<Request, Response> {
       serializer: self.requestSerializer,
       deserializer: self.responseDeserializer,
       errorDelegate: nil,
+      onStart: {},
       onError: onError,
       onResponsePart: onResponsePart
     )

+ 119 - 15
Tests/GRPCTests/AsyncAwaitSupport/AsyncClientTests.swift

@@ -27,17 +27,21 @@ final class AsyncClientCancellationTests: GRPCTestCase {
   private var group: EventLoopGroup!
   private var pool: GRPCChannel!
 
-  override func setUpWithError() throws {
-    try super.setUpWithError()
+  override func setUp() {
+    super.setUp()
     self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
   }
 
   override func tearDown() async throws {
-    try self.pool.close().wait()
-    self.pool = nil
+    if self.pool != nil {
+      try self.pool.close().wait()
+      self.pool = nil
+    }
 
-    try self.server.close().wait()
-    self.server = nil
+    if self.server != nil {
+      try self.server.close().wait()
+      self.server = nil
+    }
 
     try self.group.syncShutdownGracefully()
     self.group = nil
@@ -45,18 +49,26 @@ final class AsyncClientCancellationTests: GRPCTestCase {
     try await super.tearDown()
   }
 
-  private func startServer(service: CallHandlerProvider) throws -> Echo_EchoAsyncClient {
+  private func startServer(service: CallHandlerProvider) throws {
     precondition(self.server == nil)
-    precondition(self.pool == nil)
 
     self.server = try Server.insecure(group: self.group)
       .withServiceProviders([service])
       .withLogger(self.serverLogger)
       .bind(host: "127.0.0.1", port: 0)
       .wait()
+  }
+
+  private func startServerAndClient(service: CallHandlerProvider) throws -> Echo_EchoAsyncClient {
+    try self.startServer(service: service)
+    return try self.makeClient(port: self.server.channel.localAddress!.port!)
+  }
+
+  private func makeClient(port: Int) throws -> Echo_EchoAsyncClient {
+    precondition(self.pool == nil)
 
     self.pool = try GRPCChannelPool.with(
-      target: .host("127.0.0.1", port: self.server.channel.localAddress!.port!),
+      target: .host("127.0.0.1", port: port),
       transportSecurity: .plaintext,
       eventLoopGroup: self.group
     ) {
@@ -68,7 +80,7 @@ final class AsyncClientCancellationTests: GRPCTestCase {
 
   func testCancelUnaryFailsResponse() async throws {
     // We don't want the RPC to complete before we cancel it so use the never resolving service.
-    let echo = try self.startServer(service: NeverResolvingEchoProvider())
+    let echo = try self.startServerAndClient(service: NeverResolvingEchoProvider())
 
     let get = echo.makeGetCall(.with { $0.text = "foo bar baz" })
     try await get.cancel()
@@ -82,7 +94,7 @@ final class AsyncClientCancellationTests: GRPCTestCase {
 
   func testCancelServerStreamingClosesResponseStream() async throws {
     // We don't want the RPC to complete before we cancel it so use the never resolving service.
-    let echo = try self.startServer(service: NeverResolvingEchoProvider())
+    let echo = try self.startServerAndClient(service: NeverResolvingEchoProvider())
 
     let expand = echo.makeExpandCall(.with { $0.text = "foo bar baz" })
     try await expand.cancel()
@@ -96,7 +108,7 @@ final class AsyncClientCancellationTests: GRPCTestCase {
   }
 
   func testCancelClientStreamingClosesRequestStreamAndFailsResponse() async throws {
-    let echo = try self.startServer(service: EchoProvider())
+    let echo = try self.startServerAndClient(service: EchoProvider())
 
     let collect = echo.makeCollectCall()
     // Make sure the stream is up before we cancel it.
@@ -114,7 +126,7 @@ final class AsyncClientCancellationTests: GRPCTestCase {
   }
 
   func testClientStreamingClosesRequestStreamOnEnd() async throws {
-    let echo = try self.startServer(service: EchoProvider())
+    let echo = try self.startServerAndClient(service: EchoProvider())
 
     let collect = echo.makeCollectCall()
     // Send and close.
@@ -133,7 +145,7 @@ final class AsyncClientCancellationTests: GRPCTestCase {
   }
 
   func testCancelBidiStreamingClosesRequestStreamAndResponseStream() async throws {
-    let echo = try self.startServer(service: EchoProvider())
+    let echo = try self.startServerAndClient(service: EchoProvider())
 
     let update = echo.makeUpdateCall()
     // Make sure the stream is up before we cancel it.
@@ -153,7 +165,7 @@ final class AsyncClientCancellationTests: GRPCTestCase {
   }
 
   func testBidiStreamingClosesRequestStreamOnEnd() async throws {
-    let echo = try self.startServer(service: EchoProvider())
+    let echo = try self.startServerAndClient(service: EchoProvider())
 
     let update = echo.makeUpdateCall()
     // Send and close.
@@ -172,6 +184,98 @@ final class AsyncClientCancellationTests: GRPCTestCase {
       try await update.requestStream.send(.with { $0.text = "should throw" })
     )
   }
+
+  private enum RequestStreamingRPC {
+    typealias Request = Echo_EchoRequest
+    typealias Response = Echo_EchoResponse
+
+    case clientStreaming(GRPCAsyncClientStreamingCall<Request, Response>)
+    case bidirectionalStreaming(GRPCAsyncBidirectionalStreamingCall<Request, Response>)
+
+    func sendRequest(_ text: String) async throws {
+      switch self {
+      case let .clientStreaming(call):
+        try await call.requestStream.send(.with { $0.text = text })
+      case let .bidirectionalStreaming(call):
+        try await call.requestStream.send(.with { $0.text = text })
+      }
+    }
+
+    func cancel() {
+      switch self {
+      case let .clientStreaming(call):
+        // TODO: this should be async
+        Task { try await call.cancel() }
+      case let .bidirectionalStreaming(call):
+        // TODO: this should be async
+        Task { try await call.cancel() }
+      }
+    }
+  }
+
+  private func testSendingRequestsSuspendsWhileStreamIsNotReady(
+    makeRPC: @escaping () -> RequestStreamingRPC
+  ) async throws {
+    // The strategy for this test is to race two different tasks. The first will attempt to send a
+    // message on a request stream on a connection which will never establish. The second will sleep
+    // for a little while. Each task returns a `SendOrTimedOut` event. If the message is sent then
+    // the test definitely failed; it should not be possible to send a message on a stream which is
+    // not open. If the time out happens first then it probably did not fail.
+    enum SentOrTimedOut: Equatable, Sendable {
+      case messageSent
+      case timedOut
+    }
+
+    await withThrowingTaskGroup(of: SentOrTimedOut.self) { group in
+      group.addTask {
+        let rpc = makeRPC()
+
+        return try await withTaskCancellationHandler {
+          // This should suspend until we cancel it: we're never going to start a server so it
+          // should never succeed.
+          try await rpc.sendRequest("I should suspend")
+          return .messageSent
+        } onCancel: {
+          rpc.cancel()
+        }
+      }
+
+      group.addTask {
+        // Wait for 100ms.
+        try await Task.sleep(nanoseconds: 100_000_000)
+        return .timedOut
+      }
+
+      do {
+        let event = try await group.next()
+        // If this isn't timed out then the message was sent before the stream was ready.
+        XCTAssertEqual(event, .timedOut)
+      } catch {
+        XCTFail("Unexpected error \(error)")
+      }
+
+      // Cancel the other task.
+      group.cancelAll()
+    }
+  }
+
+  func testClientStreamingSuspendsWritesUntilStreamIsUp() async throws {
+    // Make a client for a server which isn't up yet. It will continually fail to establish a
+    // connection.
+    let echo = try self.makeClient(port: 0)
+    try await self.testSendingRequestsSuspendsWhileStreamIsNotReady {
+      return .clientStreaming(echo.makeCollectCall())
+    }
+  }
+
+  func testBidirectionalStreamingSuspendsWritesUntilStreamIsUp() async throws {
+    // Make a client for a server which isn't up yet. It will continually fail to establish a
+    // connection.
+    let echo = try self.makeClient(port: 0)
+    try await self.testSendingRequestsSuspendsWhileStreamIsNotReady {
+      return .bidirectionalStreaming(echo.makeUpdateCall())
+    }
+  }
 }
 
 #endif // compiler(>=5.6)

+ 4 - 0
Tests/GRPCTests/ClientCallTests.swift

@@ -122,6 +122,7 @@ class ClientCallTests: GRPCTestCase {
     let promise = self.makeStatusPromise()
     get.invokeUnaryRequest(
       .with { $0.text = "get" },
+      onStart: {},
       onError: promise.fail(_:),
       onResponsePart: self.makeResponsePartHandler(completing: promise)
     )
@@ -134,6 +135,7 @@ class ClientCallTests: GRPCTestCase {
 
     let promise = self.makeStatusPromise()
     collect.invokeStreamingRequests(
+      onStart: {},
       onError: promise.fail(_:),
       onResponsePart: self.makeResponsePartHandler(completing: promise)
     )
@@ -152,6 +154,7 @@ class ClientCallTests: GRPCTestCase {
     let promise = self.makeStatusPromise()
     expand.invokeUnaryRequest(
       .with { $0.text = "expand" },
+      onStart: {},
       onError: promise.fail(_:),
       onResponsePart: self.makeResponsePartHandler(completing: promise)
     )
@@ -164,6 +167,7 @@ class ClientCallTests: GRPCTestCase {
 
     let promise = self.makeStatusPromise()
     update.invokeStreamingRequests(
+      onStart: {},
       onError: promise.fail(_:),
       onResponsePart: self.makeResponsePartHandler(completing: promise)
     )

+ 1 - 0
Tests/GRPCTests/ClientTransportTests.swift

@@ -56,6 +56,7 @@ class ClientTransportTests: GRPCTestCase {
       serializer: AnySerializer(wrapping: StringSerializer()),
       deserializer: AnyDeserializer(wrapping: StringDeserializer()),
       errorDelegate: nil,
+      onStart: {},
       onError: onError,
       onResponsePart: onResponsePart
     )