Browse Source

Make 'finish()' 'async' (#2044)

Motivation:

Finishing writes should be `async` as the underlying writer may need to
flush and write out any buffered data.

Modifications:

- Mark `finish()` as `async`
- Refactor the in-proc client transport slightly to avoid async calls
while holding a lock

Result:

`finish` is `async`
George Barnett 1 year ago
parent
commit
d4d1a2ef1d

+ 2 - 2
Sources/GRPCCore/Call/Client/Internal/ClientStreamExecutor.swift

@@ -95,9 +95,9 @@ internal enum ClientStreamExecutor {
 
     switch result {
     case .success:
-      stream.finish()
+      await stream.finish()
     case .failure(let error):
-      stream.finish(throwing: error)
+      await stream.finish(throwing: error)
     }
   }
 

+ 2 - 2
Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift

@@ -58,7 +58,7 @@ struct ServerRPCExecutor {
       // Stream can't be handled; write an error status and close.
       let status = Status(code: Status.Code(error.code), message: error.message)
       try? await stream.outbound.write(.status(status, error.metadata))
-      stream.outbound.finish()
+      await stream.outbound.finish()
     }
   }
 
@@ -231,7 +231,7 @@ struct ServerRPCExecutor {
     }
 
     try? await outbound.write(.status(status, metadata))
-    outbound.finish()
+    await outbound.finish()
   }
 
   @inlinable

+ 1 - 1
Sources/GRPCCore/Call/Server/RPCRouter.swift

@@ -155,7 +155,7 @@ extension RPCRouter {
       // If this throws then the stream must be closed which we can't do anything about, so ignore
       // any error.
       try? await stream.outbound.write(.status(.rpcNotImplemented, [:]))
-      stream.outbound.finish()
+      await stream.outbound.finish()
     }
   }
 }

+ 4 - 4
Sources/GRPCCore/Streaming/RPCWriter+Closable.swift

@@ -55,8 +55,8 @@ extension RPCWriter {
     /// All writes after ``finish()`` has been called should result in an error
     /// being thrown.
     @inlinable
-    public func finish() {
-      self.writer.finish()
+    public func finish() async {
+      await self.writer.finish()
     }
 
     /// Indicate to the writer that no more writes are to be accepted because an error occurred.
@@ -64,8 +64,8 @@ extension RPCWriter {
     /// All writes after ``finish(throwing:)`` has been called should result in an error
     /// being thrown.
     @inlinable
-    public func finish(throwing error: any Error) {
-      self.writer.finish(throwing: error)
+    public func finish(throwing error: any Error) async {
+      await self.writer.finish(throwing: error)
     }
   }
 }

+ 2 - 2
Sources/GRPCCore/Streaming/RPCWriterProtocol.swift

@@ -57,13 +57,13 @@ public protocol ClosableRPCWriterProtocol<Element>: RPCWriterProtocol {
   ///
   /// All writes after ``finish()`` has been called should result in an error
   /// being thrown.
-  func finish()
+  func finish() async
 
   /// Indicate to the writer that no more writes are to be accepted because an error occurred.
   ///
   /// All writes after ``finish(throwing:)`` has been called should result in an error
   /// being thrown.
-  func finish(throwing error: any Error)
+  func finish(throwing error: any Error) async
 }
 
 @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)

+ 50 - 42
Sources/GRPCInProcessTransport/InProcessClientTransport.swift

@@ -179,8 +179,8 @@ public final class InProcessClientTransport: ClientTransport {
     }
 
     for (clientStream, serverStream) in openStreams {
-      clientStream.outbound.finish(throwing: CancellationError())
-      serverStream.outbound.finish(throwing: CancellationError())
+      await clientStream.outbound.finish(throwing: CancellationError())
+      await serverStream.outbound.finish(throwing: CancellationError())
     }
   }
 
@@ -265,7 +265,7 @@ public final class InProcessClientTransport: ClientTransport {
       try Task.checkCancellation()
     }
 
-    let streamID = try self.state.withLock { state in
+    let acceptStream: Result<Int, RPCError> = self.state.withLock { state in
       switch state {
       case .unconnected:
         // The state cannot be unconnected because if it was, then the above
@@ -281,56 +281,64 @@ public final class InProcessClientTransport: ClientTransport {
           connectedState.openStreams[streamID] = (clientStream, serverStream)
           connectedState.nextStreamID += 1
           state = .connected(connectedState)
+          return .success(streamID)
         } catch let acceptStreamError as RPCError {
-          serverStream.outbound.finish(throwing: acceptStreamError)
-          clientStream.outbound.finish(throwing: acceptStreamError)
-          throw acceptStreamError
+          return .failure(acceptStreamError)
         } catch {
-          serverStream.outbound.finish(throwing: error)
-          clientStream.outbound.finish(throwing: error)
-          throw RPCError(code: .unknown, message: "Unknown error: \(error).")
+          return .failure(RPCError(code: .unknown, message: "Unknown error: \(error)."))
         }
-        return streamID
 
       case .closed:
-        let error = RPCError(
-          code: .failedPrecondition,
-          message: "The client transport is closed."
-        )
-        serverStream.outbound.finish(throwing: error)
-        clientStream.outbound.finish(throwing: error)
-        throw error
+        let error = RPCError(code: .failedPrecondition, message: "The client transport is closed.")
+        return .failure(error)
       }
     }
 
-    defer {
-      clientStream.outbound.finish()
-
-      let maybeEndContinuation = self.state.withLock { state in
-        switch state {
-        case .unconnected:
-          // The state cannot be unconnected at this point, because if we made
-          // it this far, it's because the transport was connected.
-          // Once connected, it's impossible to transition back to unconnected,
-          // so this is an invalid state.
-          fatalError("Invalid state")
-        case .connected(var connectedState):
-          connectedState.openStreams.removeValue(forKey: streamID)
-          state = .connected(connectedState)
-        case .closed(var closedState):
-          closedState.openStreams.removeValue(forKey: streamID)
-          state = .closed(closedState)
-          if closedState.openStreams.isEmpty {
-            // This was the last open stream: signal the closure of the client.
-            return closedState.signalEndContinuation
-          }
-        }
-        return nil
+    switch acceptStream {
+    case .success(let streamID):
+      let streamHandlingResult: Result<T, any Error>
+      do {
+        let result = try await closure(clientStream)
+        streamHandlingResult = .success(result)
+      } catch {
+        streamHandlingResult = .failure(error)
       }
-      maybeEndContinuation?.finish()
+
+      await clientStream.outbound.finish()
+      self.removeStream(id: streamID)
+
+      return try streamHandlingResult.get()
+
+    case .failure(let error):
+      await serverStream.outbound.finish(throwing: error)
+      await clientStream.outbound.finish(throwing: error)
+      throw error
     }
+  }
 
-    return try await closure(clientStream)
+  private func removeStream(id streamID: Int) {
+    let maybeEndContinuation = self.state.withLock { state in
+      switch state {
+      case .unconnected:
+        // The state cannot be unconnected at this point, because if we made
+        // it this far, it's because the transport was connected.
+        // Once connected, it's impossible to transition back to unconnected,
+        // so this is an invalid state.
+        fatalError("Invalid state")
+      case .connected(var connectedState):
+        connectedState.openStreams.removeValue(forKey: streamID)
+        state = .connected(connectedState)
+      case .closed(var closedState):
+        closedState.openStreams.removeValue(forKey: streamID)
+        state = .closed(closedState)
+        if closedState.openStreams.isEmpty {
+          // This was the last open stream: signal the closure of the client.
+          return closedState.signalEndContinuation
+        }
+      }
+      return nil
+    }
+    maybeEndContinuation?.finish()
   }
 
   /// Returns the execution configuration for a given method.

+ 3 - 3
Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTestSupport/ClientRPCExecutorTestHarness+ServerBehavior.swift

@@ -74,7 +74,7 @@ extension ClientRPCExecutorTestHarness.ServerStreamHandler {
 
       try await stream.outbound.write(contentsOf: response)
       try await stream.outbound.write(.status(Status(code: .ok, message: ""), [:]))
-      stream.outbound.finish()
+      await stream.outbound.finish()
     }
   }
 
@@ -90,7 +90,7 @@ extension ClientRPCExecutorTestHarness.ServerStreamHandler {
       // All error codes are valid status codes, '!' is safe.
       let status = Status(code: Status.Code(error.code), message: error.message)
       try await stream.outbound.write(.status(status, error.metadata))
-      stream.outbound.finish()
+      await stream.outbound.finish()
     }
   }
 
@@ -99,7 +99,7 @@ extension ClientRPCExecutorTestHarness.ServerStreamHandler {
       XCTFail("Server accepted unexpected stream")
       let status = Status(code: .unknown, message: "Unexpected stream")
       try await stream.outbound.write(.status(status, [:]))
-      stream.outbound.finish()
+      await stream.outbound.finish()
     }
   }
 

+ 16 - 16
Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift

@@ -24,7 +24,7 @@ final class ServerRPCExecutorTests: XCTestCase {
     let harness = ServerRPCExecutorTestHarness()
     try await harness.execute(handler: .echo) { inbound in
       try await inbound.write(.metadata(["foo": "bar"]))
-      inbound.finish()
+      await inbound.finish()
     } consumer: { outbound in
       let parts = try await outbound.collect()
       XCTAssertEqual(
@@ -42,7 +42,7 @@ final class ServerRPCExecutorTests: XCTestCase {
     try await harness.execute(handler: .echo) { inbound in
       try await inbound.write(.metadata(["foo": "bar"]))
       try await inbound.write(.message([0]))
-      inbound.finish()
+      await inbound.finish()
     } consumer: { outbound in
       let parts = try await outbound.collect()
       XCTAssertEqual(
@@ -63,7 +63,7 @@ final class ServerRPCExecutorTests: XCTestCase {
       try await inbound.write(.message([0]))
       try await inbound.write(.message([1]))
       try await inbound.write(.message([2]))
-      inbound.finish()
+      await inbound.finish()
     } consumer: { outbound in
       let parts = try await outbound.collect()
       XCTAssertEqual(
@@ -94,7 +94,7 @@ final class ServerRPCExecutorTests: XCTestCase {
     } producer: { inbound in
       try await inbound.write(.metadata(["foo": "bar"]))
       try await inbound.write(.message(Array("\"hello\"".utf8)))
-      inbound.finish()
+      await inbound.finish()
     } consumer: { outbound in
       let parts = try await outbound.collect()
       XCTAssertEqual(
@@ -125,7 +125,7 @@ final class ServerRPCExecutorTests: XCTestCase {
       try await inbound.write(.metadata(["foo": "bar"]))
       try await inbound.write(.message(Array("\"hello\"".utf8)))
       try await inbound.write(.message(Array("\"world\"".utf8)))
-      inbound.finish()
+      await inbound.finish()
     } consumer: { outbound in
       let parts = try await outbound.collect()
       XCTAssertEqual(
@@ -151,7 +151,7 @@ final class ServerRPCExecutorTests: XCTestCase {
       }
     } producer: { inbound in
       try await inbound.write(.metadata(["foo": "bar"]))
-      inbound.finish()
+      await inbound.finish()
     } consumer: { outbound in
       let parts = try await outbound.collect()
       XCTAssertEqual(
@@ -167,7 +167,7 @@ final class ServerRPCExecutorTests: XCTestCase {
   func testEmptyInbound() async throws {
     let harness = ServerRPCExecutorTestHarness()
     try await harness.execute(handler: .echo) { inbound in
-      inbound.finish()
+      await inbound.finish()
     } consumer: { outbound in
       let part = try await outbound.collect().first
       XCTAssertStatus(part) { status, _ in
@@ -180,7 +180,7 @@ final class ServerRPCExecutorTests: XCTestCase {
     let harness = ServerRPCExecutorTestHarness()
     try await harness.execute(handler: .echo) { inbound in
       try await inbound.write(.message([0]))
-      inbound.finish()
+      await inbound.finish()
     } consumer: { outbound in
       let part = try await outbound.collect().first
       XCTAssertStatus(part) { status, _ in
@@ -192,7 +192,7 @@ final class ServerRPCExecutorTests: XCTestCase {
   func testInboundStreamThrows() async throws {
     let harness = ServerRPCExecutorTestHarness()
     try await harness.execute(handler: .echo) { inbound in
-      inbound.finish(throwing: RPCError(code: .aborted, message: ""))
+      await inbound.finish(throwing: RPCError(code: .aborted, message: ""))
     } consumer: { outbound in
       let part = try await outbound.collect().first
       XCTAssertStatus(part) { status, _ in
@@ -206,7 +206,7 @@ final class ServerRPCExecutorTests: XCTestCase {
     let harness = ServerRPCExecutorTestHarness()
     try await harness.execute(handler: .throwing(SomeError())) { inbound in
       try await inbound.write(.metadata([:]))
-      inbound.finish()
+      await inbound.finish()
     } consumer: { outbound in
       let part = try await outbound.collect().first
       XCTAssertStatus(part) { status, _ in
@@ -220,7 +220,7 @@ final class ServerRPCExecutorTests: XCTestCase {
     let harness = ServerRPCExecutorTestHarness()
     try await harness.execute(handler: .throwing(error)) { inbound in
       try await inbound.write(.metadata([:]))
-      inbound.finish()
+      await inbound.finish()
     } consumer: { outbound in
       let part = try await outbound.collect().first
       XCTAssertStatus(part) { status, metadata in
@@ -247,7 +247,7 @@ final class ServerRPCExecutorTests: XCTestCase {
       return ServerResponse.Stream(error: RPCError(code: .failedPrecondition, message: ""))
     } producer: { inbound in
       try await inbound.write(.metadata(["grpc-timeout": "1000n"]))
-      inbound.finish()
+      await inbound.finish()
     } consumer: { outbound in
       let part = try await outbound.collect().first
       XCTAssertStatus(part) { status, _ in
@@ -277,7 +277,7 @@ final class ServerRPCExecutorTests: XCTestCase {
       )
     } producer: { inbound in
       try await inbound.write(.metadata([:]))
-      inbound.finish()
+      await inbound.finish()
     } consumer: { outbound in
       let part = try await outbound.collect().first
       XCTAssertStatus(part) { status, metadata in
@@ -302,7 +302,7 @@ final class ServerRPCExecutorTests: XCTestCase {
 
     try await harness.execute(handler: .echo) { inbound in
       try await inbound.write(.metadata([:]))
-      inbound.finish()
+      await inbound.finish()
     } consumer: { outbound in
       let parts = try await outbound.collect()
       XCTAssertEqual(parts, [.metadata([:]), .status(.ok, [:])])
@@ -327,7 +327,7 @@ final class ServerRPCExecutorTests: XCTestCase {
 
     try await harness.execute(handler: .echo) { inbound in
       try await inbound.write(.metadata([:]))
-      inbound.finish()
+      await inbound.finish()
     } consumer: { outbound in
       let parts = try await outbound.collect()
       XCTAssertEqual(parts, [.status(Status(code: .unavailable, message: ""), [:])])
@@ -345,7 +345,7 @@ final class ServerRPCExecutorTests: XCTestCase {
 
     try await harness.execute(handler: .echo) { inbound in
       try await inbound.write(.metadata([:]))
-      inbound.finish()
+      await inbound.finish()
     } consumer: { outbound in
       let parts = try await outbound.collect()
       XCTAssertEqual(parts, [.status(Status(code: .unavailable, message: "Unavailable"), [:])])

+ 10 - 10
Tests/GRPCCoreTests/GRPCServerTests.swift

@@ -55,7 +55,7 @@ final class GRPCServerTests: XCTestCase {
       ) { stream in
         try await stream.outbound.write(.metadata([:]))
         try await stream.outbound.write(.message([3, 1, 4, 1, 5]))
-        stream.outbound.finish()
+        await stream.outbound.finish()
 
         var responseParts = stream.inbound.makeAsyncIterator()
         let metadata = try await responseParts.next()
@@ -86,7 +86,7 @@ final class GRPCServerTests: XCTestCase {
         try await stream.outbound.write(.message([4]))
         try await stream.outbound.write(.message([1]))
         try await stream.outbound.write(.message([5]))
-        stream.outbound.finish()
+        await stream.outbound.finish()
 
         var responseParts = stream.inbound.makeAsyncIterator()
         let metadata = try await responseParts.next()
@@ -113,7 +113,7 @@ final class GRPCServerTests: XCTestCase {
       ) { stream in
         try await stream.outbound.write(.metadata([:]))
         try await stream.outbound.write(.message([3, 1, 4, 1, 5]))
-        stream.outbound.finish()
+        await stream.outbound.finish()
 
         var responseParts = stream.inbound.makeAsyncIterator()
         let metadata = try await responseParts.next()
@@ -144,7 +144,7 @@ final class GRPCServerTests: XCTestCase {
         for byte in [3, 1, 4, 1, 5] as [UInt8] {
           try await stream.outbound.write(.message([byte]))
         }
-        stream.outbound.finish()
+        await stream.outbound.finish()
 
         var responseParts = stream.inbound.makeAsyncIterator()
         let metadata = try await responseParts.next()
@@ -172,7 +172,7 @@ final class GRPCServerTests: XCTestCase {
         options: .defaults
       ) { stream in
         try await stream.outbound.write(.metadata([:]))
-        stream.outbound.finish()
+        await stream.outbound.finish()
 
         var responseParts = stream.inbound.makeAsyncIterator()
         let status = try await responseParts.next()
@@ -194,7 +194,7 @@ final class GRPCServerTests: XCTestCase {
             ) { stream in
               try await stream.outbound.write(.metadata([:]))
               try await stream.outbound.write(.message([i]))
-              stream.outbound.finish()
+              await stream.outbound.finish()
 
               var responseParts = stream.inbound.makeAsyncIterator()
               let metadata = try await responseParts.next()
@@ -231,7 +231,7 @@ final class GRPCServerTests: XCTestCase {
         options: .defaults
       ) { stream in
         try await stream.outbound.write(.metadata([:]))
-        stream.outbound.finish()
+        await stream.outbound.finish()
 
         let parts = try await stream.inbound.collect()
         XCTAssertStatus(parts.first) { status, _ in
@@ -256,7 +256,7 @@ final class GRPCServerTests: XCTestCase {
         options: .defaults
       ) { stream in
         try await stream.outbound.write(.metadata([:]))
-        stream.outbound.finish()
+        await stream.outbound.finish()
 
         let parts = try await stream.inbound.collect()
         XCTAssertStatus(parts.first) { status, _ in
@@ -306,7 +306,7 @@ final class GRPCServerTests: XCTestCase {
         server.beginGracefulShutdown()
 
         try await stream.outbound.write(.message([0]))
-        stream.outbound.finish()
+        await stream.outbound.finish()
 
         let message = try await iterator.next()
         XCTAssertMessage(message) { XCTAssertEqual($0, [0]) }
@@ -368,7 +368,7 @@ final class GRPCServerTests: XCTestCase {
     ) { stream in
       try await stream.outbound.write(.metadata([:]))
       try await stream.outbound.write(.message([0]))
-      stream.outbound.finish()
+      await stream.outbound.finish()
       // Don't need to validate the response, just that the server is running.
       let parts = try await stream.inbound.collect()
       XCTAssertEqual(parts.count, 3)

+ 3 - 3
Tests/GRPCHTTP2CoreTests/Client/Connection/GRPCChannelTests.swift

@@ -371,7 +371,7 @@ final class GRPCChannelTests: XCTestCase {
           switch state {
           case .shutdown:
             // Happens when shutting-down has been initiated, so finish the RPC.
-            stream.outbound.finish()
+            await stream.outbound.finish()
 
             let part2 = try await iterator.next()
             switch part2 {
@@ -444,7 +444,7 @@ final class GRPCChannelTests: XCTestCase {
         group.addTask {
           try await channel.withStream(descriptor: .echoGet, options: .defaults) { stream in
             try await stream.outbound.write(.metadata([:]))
-            stream.outbound.finish()
+            await stream.outbound.finish()
 
             for try await part in stream.inbound {
               switch part {
@@ -824,7 +824,7 @@ extension GRPCChannel {
       options: .defaults
     ) { stream in
       try await stream.outbound.write(.metadata([:]))
-      stream.outbound.finish()
+      await stream.outbound.finish()
 
       for try await part in stream.inbound {
         switch part {

+ 2 - 2
Tests/GRPCInProcessTransportTests/InProcessClientTransportTests.swift

@@ -163,7 +163,7 @@ final class InProcessClientTransportTests: XCTestCase {
           options: .defaults
         ) { stream in
           try await stream.outbound.write(.message([1]))
-          stream.outbound.finish()
+          await stream.outbound.finish()
           let receivedMessages = try await stream.inbound.reduce(into: []) { $0.append($1) }
 
           XCTAssertEqual(receivedMessages, [.message([42])])
@@ -174,7 +174,7 @@ final class InProcessClientTransportTests: XCTestCase {
         try await server.listen { stream in
           let receivedMessages = try? await stream.inbound.reduce(into: []) { $0.append($1) }
           try? await stream.outbound.write(RPCResponsePart.message([42]))
-          stream.outbound.finish()
+          await stream.outbound.finish()
 
           XCTAssertEqual(receivedMessages, [.message([1])])
         }