浏览代码

[async-await] Support for sending response headers via context (#1262)

The adopter may wish to set the response headers (aka "initial metadata") from their user handler. Until now this was not possible, even in the existing non–async-await API, and it was only possible for an adopter to set the trailers.

This introduces a new mutable property to the context passed to the user handler that allows them to set the headers that should be sent back to the client before the first response message.

* `let GRPCAsyncServerCallContext.headers` has been renamed to `requestHeaders` to disambiguate from newly introduced property.
* `var GRPCAsyncServerCallContext.trailers` has been renamed to `responseTrailers` to better align with newly introduced property.
* `var GRPCAsyncServerCallContext.responseHeaders` has been introduced.
Si Beaumont 4 年之前
父节点
当前提交
183fd1dde2

+ 4 - 0
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncBidirectionalStreamingCall.swift

@@ -40,6 +40,10 @@ public struct GRPCAsyncBidirectionalStreamingCall<Request, Response> {
   // MARK: - Response Parts
 
   /// The initial metadata returned from the server.
+  ///
+  /// - Important: The initial metadata will only be available when the first response has been
+  /// received. However, it is not necessary for the response to have been consumed before reading
+  /// this property.
   public var initialMetadata: HPACKHeaders {
     // swiftformat:disable:next redundantGet
     get async throws {

+ 2 - 0
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncClientStreamingCall.swift

@@ -36,6 +36,8 @@ public struct GRPCAsyncClientStreamingCall<Request, Response> {
   // MARK: - Response Parts
 
   /// The initial metadata returned from the server.
+  ///
+  /// - Important: The initial metadata will only be available when the response has been received.
   public var initialMetadata: HPACKHeaders {
     // swiftformat:disable:next redundantGet
     get async throws {

+ 25 - 9
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerCallContext.swift

@@ -34,8 +34,8 @@ import NIOHPACK
 public final class GRPCAsyncServerCallContext {
   private let lock = Lock()
 
-  /// Request headers for this request.
-  public let headers: HPACKHeaders
+  /// Metadata for this request.
+  public let requestMetadata: HPACKHeaders
 
   /// The logger used for this call.
   public var logger: Logger {
@@ -83,18 +83,34 @@ public final class GRPCAsyncServerCallContext {
   @usableFromInline
   internal let userInfoRef: Ref<UserInfo>
 
-  /// Metadata to return at the end of the RPC. If this is required it should be updated before
-  /// the `responsePromise` or `statusPromise` is fulfilled.
-  public var trailers: HPACKHeaders {
+  /// Metadata to return at the start of the RPC.
+  ///
+  /// - Important: If this is required it should be updated _before_ the first response is sent via
+  /// the response stream writer. Any updates made after the first response will be ignored.
+  public var initialResponseMetadata: HPACKHeaders {
+    get { self.lock.withLock {
+      return self._initialResponseMetadata
+    } }
+    set { self.lock.withLock {
+      self._initialResponseMetadata = newValue
+    } }
+  }
+
+  private var _initialResponseMetadata: HPACKHeaders = [:]
+
+  /// Metadata to return at the end of the RPC.
+  ///
+  /// If this is required it should be updated before returning from the handler.
+  public var trailingResponseMetadata: HPACKHeaders {
     get { self.lock.withLock {
-      return self._trailers
+      return self._trailingResponseMetadata
     } }
     set { self.lock.withLock {
-      self._trailers = newValue
+      self._trailingResponseMetadata = newValue
     } }
   }
 
-  private var _trailers: HPACKHeaders = [:]
+  private var _trailingResponseMetadata: HPACKHeaders = [:]
 
   @inlinable
   internal init(
@@ -102,7 +118,7 @@ public final class GRPCAsyncServerCallContext {
     logger: Logger,
     userInfoRef: Ref<UserInfo>
   ) {
-    self.headers = headers
+    self.requestMetadata = headers
     self.userInfoRef = userInfoRef
     self._logger = logger
   }

+ 76 - 47
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerHandler.swift

@@ -204,38 +204,56 @@ internal final class AsyncServerHandler<
     /// No headers have been received.
     case idle
 
-    /// Headers have been received, and an async `Task` has been created to execute the user
-    /// handler.
-    ///
-    /// The inputs to the user handler are held in the associated data of this enum value:
-    ///
-    /// - The `PassthroughMessageSource` is the source backing the request stream that is being
-    /// consumed by the user handler.
-    ///
-    /// - The `GRPCAsyncServerContext` is a reference to the context that was passed to the user
-    /// handler.
-    ///
-    /// - The `GRPCAsyncResponseStreamWriter` is the response stream writer that is being written to
-    /// by the user handler. Because this is pausable, it may contain responses after the user
-    /// handler has completed that have yet to be written. However we will remain in the `.active`
-    /// state until the response stream writer has completed.
-    ///
-    /// - The `EventLoopPromise` bridges the NIO and async-await worlds. It is the mechanism that we
-    /// use to run a callback when the user handler has completed. The promise is not passed to the
-    /// user handler directly. Instead it is fulfilled with the result of the async `Task` executing
-    /// the user handler using `completeWithTask(_:)`.
-    ///
-    /// - TODO: It shouldn't really be necessary to stash the `GRPCAsyncResponseStreamWriter` or the
-    /// `EventLoopPromise` in this enum value. Specifically they are never used anywhere when this
-    /// enum value is accessed. However, if we do not store them here then the tests periodically
-    /// segfault. This appears to be an bug in Swift and/or NIO since these should both have been
-    /// captured by `completeWithTask(_:)`.
-    case active(
-      PassthroughMessageSource<Request, Error>,
-      GRPCAsyncServerCallContext,
-      GRPCAsyncResponseStreamWriter<Response>,
-      EventLoopPromise<Void>
-    )
+    @usableFromInline
+    internal struct ActiveState {
+      /// The source backing the request stream that is being consumed by the user handler.
+      @usableFromInline
+      let requestStreamSource: PassthroughMessageSource<Request, Error>
+
+      /// The call context that was passed to the user handler.
+      @usableFromInline
+      let context: GRPCAsyncServerCallContext
+
+      /// The response stream writer that is being used by the user handler.
+      ///
+      /// Because this is pausable, it may contain responses after the user handler has completed
+      /// that have yet to be written. However we will remain in the `.active` state until the
+      /// response stream writer has completed.
+      @usableFromInline
+      let responseStreamWriter: GRPCAsyncResponseStreamWriter<Response>
+
+      /// The response headers have been sent back to the client via the interceptors.
+      @usableFromInline
+      var haveSentResponseHeaders: Bool = false
+
+      /// The promise we are using to bridge the NIO and async-await worlds.
+      ///
+      /// It is the mechanism that we use to run a callback when the user handler has completed.
+      /// The promise is not passed to the user handler directly. Instead it is fulfilled with the
+      /// result of the async `Task` executing the user handler using `completeWithTask(_:)`.
+      ///
+      /// - TODO: It shouldn't really be necessary to stash this promise here. Specifically it is
+      /// never used anywhere when the `.active` enum value is accessed. However, if we do not store
+      /// it here then the tests periodically segfault. This appears to be a reference counting bug
+      /// in Swift and/or NIO since it should have been captured by `completeWithTask(_:)`.
+      let _userHandlerPromise: EventLoopPromise<Void>
+
+      @usableFromInline
+      internal init(
+        requestStreamSource: PassthroughMessageSource<Request, Error>,
+        context: GRPCAsyncServerCallContext,
+        responseStreamWriter: GRPCAsyncResponseStreamWriter<Response>,
+        userHandlerPromise: EventLoopPromise<Void>
+      ) {
+        self.requestStreamSource = requestStreamSource
+        self.context = context
+        self.responseStreamWriter = responseStreamWriter
+        self._userHandlerPromise = userHandlerPromise
+      }
+    }
+
+    /// Headers have been received and an async `Task` has been created to execute the user handler.
+    case active(ActiveState)
 
     /// The handler has completed.
     case completed
@@ -363,15 +381,16 @@ internal final class AsyncServerHandler<
         )
 
       // Set the state to active and bundle in all the associated data.
-      self.state = .active(requestStreamSource, context, responseStreamWriter, userHandlerPromise)
+      self.state = .active(.init(
+        requestStreamSource: requestStreamSource,
+        context: context,
+        responseStreamWriter: responseStreamWriter,
+        userHandlerPromise: userHandlerPromise
+      ))
 
       // Register callback for the completion of the user handler.
       userHandlerPromise.futureResult.whenComplete(self.userHandlerCompleted(_:))
 
-      // Send response headers back via the interceptors.
-      // TODO: In future we may want to defer this until the first response is available from the user handler which will allow the user to set the response headers via the context.
-      self.interceptors.send(.metadata([:]), promise: nil)
-
       // Spin up a task to call the async user handler.
       self.userHandlerTask = userHandlerPromise.completeWithTask {
         return try await withTaskCancellationHandler {
@@ -443,8 +462,8 @@ internal final class AsyncServerHandler<
     switch self.state {
     case .idle:
       self.handleError(GRPCError.ProtocolViolation("Message received before headers"))
-    case let .active(requestStreamSource, _, _, _):
-      switch requestStreamSource.yield(request) {
+    case let .active(activeState):
+      switch activeState.requestStreamSource.yield(request) {
       case .accepted(queueDepth: _):
         // TODO: In future we will potentially issue a read request to the channel based on the value of `queueDepth`.
         break
@@ -467,8 +486,8 @@ internal final class AsyncServerHandler<
     switch self.state {
     case .idle:
       self.handleError(GRPCError.ProtocolViolation("End of stream received before headers"))
-    case let .active(requestStreamSource, _, _, _):
-      switch requestStreamSource.finish() {
+    case let .active(activeState):
+      switch activeState.requestStreamSource.finish() {
       case .accepted(queueDepth: _):
         break
       case .dropped:
@@ -495,7 +514,14 @@ internal final class AsyncServerHandler<
       // The user handler cannot send responses before it has been invoked.
       preconditionFailure()
 
-    case .active:
+    case var .active(activeState):
+      if !activeState.haveSentResponseHeaders {
+        activeState.haveSentResponseHeaders = true
+        self.state = .active(activeState)
+        // Send response headers back via the interceptors.
+        self.interceptors.send(.metadata(activeState.context.initialResponseMetadata), promise: nil)
+      }
+      // Send the response back via the interceptors.
       self.interceptors.send(.message(response, metadata), promise: nil)
 
     case .completed:
@@ -547,10 +573,13 @@ internal final class AsyncServerHandler<
     case .idle:
       preconditionFailure()
 
-    case let .active(_, context, _, _):
+    case let .active(activeState):
       // Now we have drained the response stream writer from the user handler we can send end.
       self.state = .completed
-      self.interceptors.send(.end(status, context.trailers), promise: nil)
+      self.interceptors.send(
+        .end(status, activeState.context.trailingResponseMetadata),
+        promise: nil
+      )
 
     case .completed:
       ()
@@ -580,7 +609,7 @@ internal final class AsyncServerHandler<
       )
       self.interceptors.send(.end(status, trailers), promise: nil)
 
-    case let .active(_, context, _, _):
+    case let .active(activeState):
       self.state = .completed
 
       // If we have an async task, then cancel it, which will terminate the request stream from
@@ -593,8 +622,8 @@ internal final class AsyncServerHandler<
       if isHandlerError {
         (status, trailers) = ServerErrorProcessor.processObserverError(
           error,
-          headers: context.headers,
-          trailers: context.trailers,
+          headers: activeState.context.requestMetadata,
+          trailers: activeState.context.trailingResponseMetadata,
           delegate: self.context.errorDelegate
         )
       } else {

+ 4 - 0
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerStreamingCall.swift

@@ -39,6 +39,10 @@ public struct GRPCAsyncServerStreamingCall<Request, Response> {
   // MARK: - Response Parts
 
   /// The initial metadata returned from the server.
+  ///
+  /// - Important: The initial metadata will only be available when the first response has been
+  /// received. However, it is not necessary for the response to have been consumed before reading
+  /// this property.
   public var initialMetadata: HPACKHeaders {
     // swiftformat:disable:next redundantGet
     get async throws {

+ 2 - 0
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncUnaryCall.swift

@@ -39,6 +39,8 @@ public struct GRPCAsyncUnaryCall<Request, Response> {
   // MARK: - Response Parts
 
   /// The initial metadata returned from the server.
+  ///
+  /// - Important: The initial metadata will only be available when the response has been received.
   public var initialMetadata: HPACKHeaders {
     // swiftformat:disable:next redundantGet
     get async throws {

+ 50 - 39
Tests/GRPCTests/GRPCAsyncServerHandlerTests.swift

@@ -74,8 +74,6 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase {
     )
 
     handler.receiveMetadata([:])
-    await assertThat(self.recorder.metadata, .is([:]))
-
     handler.receiveMessage(ByteBuffer(string: "1"))
     handler.receiveMessage(ByteBuffer(string: "2"))
     handler.receiveMessage(ByteBuffer(string: "3"))
@@ -86,6 +84,7 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase {
 
     handler.finish()
 
+    await assertThat(self.recorder.metadata, .is([:]))
     await assertThat(
       self.recorder.messages,
       .is([ByteBuffer(string: "1"), ByteBuffer(string: "2"), ByteBuffer(string: "3")])
@@ -145,14 +144,46 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase {
     await assertThat(self.recorder.messageMetadata.map { $0.compress }, .is([false, false, false]))
   } }
 
+  func testResponseHeadersAndTrailersSentFromContext() { XCTAsyncTest {
+    let handler = self.makeHandler { _, responseStreamWriter, context in
+      context.initialResponseMetadata = ["pontiac": "bandit"]
+      try await responseStreamWriter.send("1")
+      context.trailingResponseMetadata = ["disco": "strangler"]
+    }
+    handler.receiveMetadata([:])
+    handler.receiveEnd()
+
+    // Wait for tasks to finish.
+    await handler.userHandlerTask?.value
+
+    await assertThat(self.recorder.metadata, .is(["pontiac": "bandit"]))
+    await assertThat(self.recorder.trailers, .is(["disco": "strangler"]))
+  } }
+
+  func testResponseHeadersDroppedIfSetAfterFirstResponse() { XCTAsyncTest {
+    let handler = self.makeHandler { _, responseStreamWriter, context in
+      try await responseStreamWriter.send("1")
+      context.initialResponseMetadata = ["pontiac": "bandit"]
+      context.trailingResponseMetadata = ["disco": "strangler"]
+    }
+    handler.receiveMetadata([:])
+    handler.receiveEnd()
+
+    // Wait for tasks to finish.
+    await handler.userHandlerTask?.value
+
+    await assertThat(self.recorder.metadata, .is([:]))
+    await assertThat(self.recorder.trailers, .is(["disco": "strangler"]))
+  } }
+
   func testTaskOnlyCreatedAfterHeaders() { XCTAsyncTest {
     let handler = self.makeHandler(observer: self.echo(requests:responseStreamWriter:context:))
 
-    await assertThat(handler.userHandlerTask, .is(.nil()))
+    await assertThat(handler.userHandlerTask, .nil())
 
     handler.receiveMetadata([:])
 
-    await assertThat(handler.userHandlerTask, .is(.notNil()))
+    await assertThat(handler.userHandlerTask, .notNil())
   } }
 
   func testThrowingDeserializer() { XCTAsyncTest {
@@ -165,18 +196,12 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase {
     )
 
     handler.receiveMetadata([:])
-
-    // Wait for the async user function to have processed the metadata.
-    try self.recorder.recordedMetadataPromise.futureResult.wait()
-
-    await assertThat(self.recorder.metadata, .is([:]))
-
-    let buffer = ByteBuffer(string: "hello")
-    handler.receiveMessage(buffer)
+    handler.receiveMessage(ByteBuffer(string: "hello"))
 
     // Wait for tasks to finish.
     await handler.userHandlerTask?.value
 
+    await assertThat(self.recorder.metadata, .nil())
     await assertThat(self.recorder.messages, .isEmpty())
     await assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
   } }
@@ -191,15 +216,13 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase {
     )
 
     handler.receiveMetadata([:])
-    await assertThat(self.recorder.metadata, .is([:]))
-
-    let buffer = ByteBuffer(string: "hello")
-    handler.receiveMessage(buffer)
+    handler.receiveMessage(ByteBuffer(string: "hello"))
     handler.receiveEnd()
 
     // Wait for tasks to finish.
     await handler.userHandlerTask?.value
 
+    await assertThat(self.recorder.metadata, .is([:]))
     await assertThat(self.recorder.messages, .isEmpty())
     await assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
   } }
@@ -213,28 +236,22 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase {
     // Wait for tasks to finish.
     await handler.userHandlerTask?.value
 
-    await assertThat(self.recorder.metadata, .is(.nil()))
+    await assertThat(self.recorder.metadata, .nil())
     await assertThat(self.recorder.messages, .isEmpty())
     await assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
   } }
 
-  // TODO: Running this 1000 times shows up a segfault in NIO event loop group.
   func testReceiveMultipleHeaders() { XCTAsyncTest {
     let handler = self
       .makeHandler(observer: self.neverReceivesMessage(requests:responseStreamWriter:context:))
 
     handler.receiveMetadata([:])
-
-    // Wait for the async user function to have processed the metadata.
-    try self.recorder.recordedMetadataPromise.futureResult.wait()
-
-    await assertThat(self.recorder.metadata, .is([:]))
-
     handler.receiveMetadata([:])
 
     // Wait for tasks to finish.
     await handler.userHandlerTask?.value
 
+    await assertThat(self.recorder.metadata, .nil())
     await assertThat(self.recorder.messages, .isEmpty())
     await assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
   } }
@@ -244,26 +261,22 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase {
       .makeHandler(observer: self.neverCalled(requests:responseStreamWriter:context:))
 
     handler.finish()
-    await assertThat(self.recorder.metadata, .is(.nil()))
+    await assertThat(self.recorder.metadata, .nil())
     await assertThat(self.recorder.messages, .isEmpty())
-    await assertThat(self.recorder.status, .is(.nil()))
-    await assertThat(self.recorder.trailers, .is(.nil()))
+    await assertThat(self.recorder.status, .nil())
+    await assertThat(self.recorder.trailers, .nil())
   } }
 
   func testFinishAfterHeaders() { XCTAsyncTest {
     let handler = self.makeHandler(observer: self.echo(requests:responseStreamWriter:context:))
-    handler.receiveMetadata([:])
-
-    // Wait for the async user function to have processed the metadata.
-    try self.recorder.recordedMetadataPromise.futureResult.wait()
-
-    await assertThat(self.recorder.metadata, .is([:]))
 
+    handler.receiveMetadata([:])
     handler.finish()
 
     // Wait for tasks to finish.
     await handler.userHandlerTask?.value
 
+    await assertThat(self.recorder.metadata, .nil())
     await assertThat(self.recorder.messages, .isEmpty())
     await assertThat(self.recorder.status, .notNil(.hasCode(.unavailable)))
     await assertThat(self.recorder.trailers, .is([:]))
@@ -304,8 +317,6 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase {
     await assertThat(self.recorder.status, .notNil(.hasCode(.unknown)))
   } }
 
-  // TODO: We should be consistent about where we put the tasks... maybe even use a task group to simplify cancellation (unless they both go in the enum state which might be better).
-
   func testResponseStreamDrain() { XCTAsyncTest {
     // Set up echo handler.
     let handler = self.makeHandler(
@@ -317,14 +328,14 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase {
 
     // Send two requests and end, pausing the writer in the middle.
     switch handler.state {
-    case let .active(_, _, responseStreamWriter, promise):
+    case let .active(activeState):
       handler.receiveMessage(ByteBuffer(string: "diaz"))
-      await responseStreamWriter.asyncWriter.toggleWritability()
+      await activeState.responseStreamWriter.asyncWriter.toggleWritability()
       handler.receiveMessage(ByteBuffer(string: "santiago"))
       handler.receiveEnd()
-      await responseStreamWriter.asyncWriter.toggleWritability()
+      await activeState.responseStreamWriter.asyncWriter.toggleWritability()
       await handler.userHandlerTask?.value
-      _ = try await promise.futureResult.get()
+      _ = try await activeState._userHandlerPromise.futureResult.get()
     default:
       XCTFail("Unexpected handler state: \(handler.state)")
     }