Pārlūkot izejas kodu

Move to completed state before cancelling task during finish() (#1302)

As reported in #1299 and #1301, there are some scenarios where the closing the channel or using the response stream writer after the channel has been closed will lead to a crash.

Specifically, when `finish()` is called the state was not progressed to `.completed` before cancelling the task. This was to maintain parity with the ELG-based API where the status and the trailers were still sent after `finish()` is called. We now believe this to be misguided and we shouldn't expect to be able to send anything on the channel at this point because we are tearing the handler and the channel down.

This changes `finish()` to move to the `.completed` state before cancelling the `userHandlerTask`. As a result, when the completion handler for the user function fires, it will call `handleError(_:)` with `CancellationError` (as before) but now the error handler will not attempt to send the status or trailers back via the interceptors because the state will be in `.completed`.

Tests for receiving an error after headers and after a message have been added.
Si Beaumont 4 gadi atpakaļ
vecāks
revīzija
fb54300850

+ 6 - 2
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerHandler.swift

@@ -327,6 +327,8 @@ internal final class AsyncServerHandler<
       self.state = .completed
 
     case .active:
+      self.state = .completed
+      self.interceptors = nil
       self.userHandlerTask?.cancel()
 
     case .completed:
@@ -524,8 +526,10 @@ internal final class AsyncServerHandler<
       self.interceptors.send(.message(response, metadata), promise: nil)
 
     case .completed:
-      /// If we are in the completed state then the async writer delegate must have terminated.
-      preconditionFailure()
+      /// If we are in the completed state then the async writer delegate will have been cancelled,
+      /// however the cancellation is asynchronous so there's a chance that we receive this callback
+      /// after that has happened. We can drop the response.
+      ()
     }
   }
 

+ 14 - 3
Tests/GRPCTests/AsyncAwaitSupport/AsyncIntegrationTests.swift

@@ -49,9 +49,9 @@ final class AsyncIntegrationTests: GRPCTestCase {
   }
 
   override func tearDown() {
-    XCTAssertNoThrow(try self.client.close().wait())
-    XCTAssertNoThrow(try self.server.close().wait())
-    XCTAssertNoThrow(try self.group.syncShutdownGracefully())
+    XCTAssertNoThrow(try self.client?.close().wait())
+    XCTAssertNoThrow(try self.server?.close().wait())
+    XCTAssertNoThrow(try self.group?.syncShutdownGracefully())
     super.tearDown()
   }
 
@@ -195,6 +195,17 @@ final class AsyncIntegrationTests: GRPCTestCase {
       ])
     }
   }
+
+  func testServerCloseAfterMessage() {
+    XCTAsyncTest {
+      let update = self.echo.makeUpdateCall()
+      try await update.requestStream.send(.with { $0.text = "hello" })
+      _ = try await update.responses.first(where: { _ in true })
+      XCTAssertNoThrow(try self.server.close().wait())
+      self.server = nil // So that tearDown() does not call close() again.
+      try await update.requestStream.finish()
+    }
+  }
 }
 
 extension HPACKHeaders {

+ 34 - 2
Tests/GRPCTests/GRPCAsyncServerHandlerTests.swift

@@ -278,8 +278,8 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase {
 
     await assertThat(self.recorder.metadata, .nil())
     await assertThat(self.recorder.messages, .isEmpty())
-    await assertThat(self.recorder.status, .notNil(.hasCode(.unavailable)))
-    await assertThat(self.recorder.trailers, .is([:]))
+    await assertThat(self.recorder.status, .nil())
+    await assertThat(self.recorder.trailers, .nil())
   } }
 
   func testFinishAfterMessage() { XCTAsyncTest {
@@ -296,6 +296,38 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase {
     // Wait for tasks to finish.
     await handler.userHandlerTask?.value
 
+    await assertThat(self.recorder.messages.first, .is(ByteBuffer(string: "hello")))
+    await assertThat(self.recorder.status, .nil())
+    await assertThat(self.recorder.trailers, .nil())
+  } }
+
+  func testErrorAfterHeaders() { XCTAsyncTest {
+    let handler = self.makeHandler(observer: self.echo(requests:responseStreamWriter:context:))
+
+    handler.receiveMetadata([:])
+    handler.receiveError(CancellationError())
+
+    // Wait for tasks to finish.
+    await handler.userHandlerTask?.value
+
+    await assertThat(self.recorder.status, .notNil(.hasCode(.unavailable)))
+    await assertThat(self.recorder.trailers, .is([:]))
+  } }
+
+  func testErrorAfterMessage() { XCTAsyncTest {
+    let handler = self.makeHandler(observer: self.echo(requests:responseStreamWriter:context:))
+
+    handler.receiveMetadata([:])
+    handler.receiveMessage(ByteBuffer(string: "hello"))
+
+    // Wait for the async user function to have processed the message.
+    try self.recorder.recordedMessagePromise.futureResult.wait()
+
+    handler.receiveError(CancellationError())
+
+    // Wait for tasks to finish.
+    await handler.userHandlerTask?.value
+
     await assertThat(self.recorder.messages.first, .is(ByteBuffer(string: "hello")))
     await assertThat(self.recorder.status, .notNil(.hasCode(.unavailable)))
     await assertThat(self.recorder.trailers, .is([:]))