Browse Source

Allow server handlers to send response headers directly (#1599)

Motivation:

The async server call context allows users to set headers which are sent
when the first message is sent. In many cases this is fine, however,
some use cases require the headers to be sent immediately.

Modifications:

- Add `sendHeaders(_:)` to the `GRPCAsyncServerCallContext` which sends
  headers to the client and throws if headers have already been written
  or it's too late to send them.

Result:

Headers can be sent directly from a server call handler.
George Barnett 2 years ago
parent
commit
6b55ce088b

+ 3 - 2
Sources/GRPC/AsyncAwaitSupport/AsyncServerHandler/ServerHandlerStateMachine/ServerHandlerStateMachine+Actions.swift

@@ -70,12 +70,13 @@ extension ServerHandlerStateMachine {
 
 
     /// Update the metadata. It must not have been written yet.
     /// Update the metadata. It must not have been written yet.
     @inlinable
     @inlinable
-    mutating func update(_ metadata: HPACKHeaders) {
+    mutating func update(_ metadata: HPACKHeaders) -> Bool {
       switch self {
       switch self {
       case .notWritten:
       case .notWritten:
         self = .notWritten(metadata)
         self = .notWritten(metadata)
+        return true
       case .written:
       case .written:
-        assertionFailure("Metadata must not be set after it has been sent")
+        return false
       }
       }
     }
     }
 
 

+ 4 - 4
Sources/GRPC/AsyncAwaitSupport/AsyncServerHandler/ServerHandlerStateMachine/ServerHandlerStateMachine+Draining.swift

@@ -50,16 +50,16 @@ extension ServerHandlerStateMachine {
     @inlinable
     @inlinable
     mutating func setResponseHeaders(
     mutating func setResponseHeaders(
       _ metadata: HPACKHeaders
       _ metadata: HPACKHeaders
-    ) -> Self.NextStateAndOutput<Void> {
-      self.responseHeaders.update(metadata)
-      return .init(nextState: .draining(self))
+    ) -> Self.NextStateAndOutput<Bool> {
+      let output = self.responseHeaders.update(metadata)
+      return .init(nextState: .draining(self), output: output)
     }
     }
 
 
     @inlinable
     @inlinable
     mutating func setResponseTrailers(
     mutating func setResponseTrailers(
       _ metadata: HPACKHeaders
       _ metadata: HPACKHeaders
     ) -> Self.NextStateAndOutput<Void> {
     ) -> Self.NextStateAndOutput<Void> {
-      self.responseTrailers.update(metadata)
+      _ = self.responseTrailers.update(metadata)
       return .init(nextState: .draining(self))
       return .init(nextState: .draining(self))
     }
     }
 
 

+ 2 - 2
Sources/GRPC/AsyncAwaitSupport/AsyncServerHandler/ServerHandlerStateMachine/ServerHandlerStateMachine+Finished.swift

@@ -35,8 +35,8 @@ extension ServerHandlerStateMachine {
     @inlinable
     @inlinable
     mutating func setResponseHeaders(
     mutating func setResponseHeaders(
       _ headers: HPACKHeaders
       _ headers: HPACKHeaders
-    ) -> Self.NextStateAndOutput<Void> {
-      return .init(nextState: .finished(self))
+    ) -> Self.NextStateAndOutput<Bool> {
+      return .init(nextState: .finished(self), output: false)
     }
     }
 
 
     @inlinable
     @inlinable

+ 4 - 4
Sources/GRPC/AsyncAwaitSupport/AsyncServerHandler/ServerHandlerStateMachine/ServerHandlerStateMachine+Handling.swift

@@ -50,16 +50,16 @@ extension ServerHandlerStateMachine {
     @inlinable
     @inlinable
     mutating func setResponseHeaders(
     mutating func setResponseHeaders(
       _ metadata: HPACKHeaders
       _ metadata: HPACKHeaders
-    ) -> Self.NextStateAndOutput<Void> {
-      self.responseHeaders.update(metadata)
-      return .init(nextState: .handling(self))
+    ) -> Self.NextStateAndOutput<Bool> {
+      let output = self.responseHeaders.update(metadata)
+      return .init(nextState: .handling(self), output: output)
     }
     }
 
 
     @inlinable
     @inlinable
     mutating func setResponseTrailers(
     mutating func setResponseTrailers(
       _ metadata: HPACKHeaders
       _ metadata: HPACKHeaders
     ) -> Self.NextStateAndOutput<Void> {
     ) -> Self.NextStateAndOutput<Void> {
-      self.responseTrailers.update(metadata)
+      _ = self.responseTrailers.update(metadata)
       return .init(nextState: .handling(self))
       return .init(nextState: .handling(self))
     }
     }
 
 

+ 1 - 1
Sources/GRPC/AsyncAwaitSupport/AsyncServerHandler/ServerHandlerStateMachine/ServerHandlerStateMachine.swift

@@ -27,7 +27,7 @@ internal struct ServerHandlerStateMachine {
   }
   }
 
 
   @inlinable
   @inlinable
-  mutating func setResponseHeaders(_ headers: HPACKHeaders) {
+  mutating func setResponseHeaders(_ headers: HPACKHeaders) -> Bool {
     switch self.state {
     switch self.state {
     case var .handling(handling):
     case var .handling(handling):
       let nextStateAndOutput = handling.setResponseHeaders(headers)
       let nextStateAndOutput = handling.setResponseHeaders(headers)

+ 17 - 2
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerCallContext.swift

@@ -30,6 +30,21 @@ public struct GRPCAsyncServerCallContext: Sendable {
     Response(contextProvider: self.contextProvider)
     Response(contextProvider: self.contextProvider)
   }
   }
 
 
+  /// Notifies the client that the RPC has been accepted for processing by the server.
+  ///
+  /// On accepting the RPC the server will send the given headers (which may be empty) along with
+  /// any transport specific headers (such the ":status" pseudo header) to the client.
+  ///
+  /// It is not necessary to call this function: the RPC is implicitly accepted when the first
+  /// response message is sent, however this may be useful when clients require an early indication
+  /// that the RPC has been accepted.
+  ///
+  /// If the RPC has already been accepted (either implicitly or explicitly) then this function is
+  /// a no-op.
+  public func acceptRPC(headers: HPACKHeaders) async {
+    await self.contextProvider.acceptRPC(headers)
+  }
+
   /// Access the ``UserInfo`` dictionary which is shared with the interceptor contexts for this RPC.
   /// Access the ``UserInfo`` dictionary which is shared with the interceptor contexts for this RPC.
   ///
   ///
   /// - Important: While ``UserInfo`` has value-semantics, this function accesses a reference
   /// - Important: While ``UserInfo`` has value-semantics, this function accesses a reference
@@ -87,8 +102,8 @@ extension GRPCAsyncServerCallContext {
     /// Set the metadata to return at the start of the RPC.
     /// Set the metadata to return at the start of the RPC.
     ///
     ///
     /// - Important: If this is required it should be updated _before_ the first response is sent
     /// - Important: If this is required it should be updated _before_ the first response is sent
-    ///   via the response stream writer. Updates must not be made after the first response has
-    ///   been sent.
+    ///   via the response stream writer. Updates must not be made after the RPC has been accepted
+    ///   or the first response has been sent otherwise this method will throw an error.
     public func setHeaders(_ headers: HPACKHeaders) async throws {
     public func setHeaders(_ headers: HPACKHeaders) async throws {
       try await self.contextProvider.setResponseHeaders(headers)
       try await self.contextProvider.setResponseHeaders(headers)
     }
     }

+ 49 - 2
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerHandler.swift

@@ -207,6 +207,16 @@ internal final class AsyncServerHandler<
   @usableFromInline
   @usableFromInline
   internal var compressResponsesIfPossible: Bool
   internal var compressResponsesIfPossible: Bool
 
 
+  /// The interceptor pipeline does not track flushing as a separate event. The flush decision is
+  /// included with metadata alongside each message. For the status and trailers the flush is
+  /// implicit. For headers we track whether to flush here.
+  ///
+  /// In most cases the flush will be delayed until the first message is flushed and this will
+  /// remain unset. However, this may be set when the server handler
+  /// uses ``GRPCAsyncServerCallContext/sendHeaders(_:)``.
+  @usableFromInline
+  internal var flushNextHeaders: Bool
+
   /// A state machine for the interceptor pipeline.
   /// A state machine for the interceptor pipeline.
   @usableFromInline
   @usableFromInline
   internal private(set) var interceptorStateMachine: ServerInterceptorStateMachine
   internal private(set) var interceptorStateMachine: ServerInterceptorStateMachine
@@ -265,6 +275,7 @@ internal final class AsyncServerHandler<
     self.errorDelegate = context.errorDelegate
     self.errorDelegate = context.errorDelegate
     self.compressionEnabledOnRPC = context.encoding.isEnabled
     self.compressionEnabledOnRPC = context.encoding.isEnabled
     self.compressResponsesIfPossible = true
     self.compressResponsesIfPossible = true
+    self.flushNextHeaders = false
     self.logger = context.logger
     self.logger = context.logger
 
 
     self.userInfoRef = Ref(UserInfo())
     self.userInfoRef = Ref(UserInfo())
@@ -685,7 +696,9 @@ internal final class AsyncServerHandler<
     switch self.interceptorStateMachine.interceptedResponseMetadata() {
     switch self.interceptorStateMachine.interceptedResponseMetadata() {
     case .forward:
     case .forward:
       if let responseWriter = self.responseWriter {
       if let responseWriter = self.responseWriter {
-        responseWriter.sendMetadata(metadata, flush: false, promise: promise)
+        let flush = self.flushNextHeaders
+        self.flushNextHeaders = false
+        responseWriter.sendMetadata(metadata, flush: flush, promise: promise)
       } else if let promise = promise {
       } else if let promise = promise {
         promise.fail(GRPCStatus.processingError)
         promise.fail(GRPCStatus.processingError)
       }
       }
@@ -747,11 +760,44 @@ extension AsyncServerHandler: AsyncServerCallContextProvider {
   @usableFromInline
   @usableFromInline
   internal func setResponseHeaders(_ headers: HPACKHeaders) async throws {
   internal func setResponseHeaders(_ headers: HPACKHeaders) async throws {
     let completed = self.eventLoop.submit {
     let completed = self.eventLoop.submit {
-      self.handlerStateMachine.setResponseHeaders(headers)
+      if !self.handlerStateMachine.setResponseHeaders(headers) {
+        throw GRPCStatus(
+          code: .failedPrecondition,
+          message: "Tried to send response headers in an invalid state"
+        )
+      }
     }
     }
     try await completed.get()
     try await completed.get()
   }
   }
 
 
+  @usableFromInline
+  internal func acceptRPC(_ headers: HPACKHeaders) async {
+    let completed = self.eventLoop.submit {
+      guard self.handlerStateMachine.setResponseHeaders(headers) else { return }
+
+      // Shh,it's a lie! We don't really have a message to send but the state machine doesn't know
+      // (or care) about that. It will, however, tell us if we can send the headers or not.
+      switch self.handlerStateMachine.sendMessage() {
+      case let .intercept(.some(headers)):
+        switch self.interceptorStateMachine.interceptResponseMetadata() {
+        case .intercept:
+          self.flushNextHeaders = true
+          self.interceptors?.send(.metadata(headers), promise: nil)
+        case .cancel:
+          return self.cancel(error: nil)
+        case .drop:
+          ()
+        }
+
+      case .intercept(.none), .drop:
+        // intercept(.none) means headers have already been sent; we should never hit this because
+        // we guard on setting the response headers above.
+        ()
+      }
+    }
+    try? await completed.get()
+  }
+
   @usableFromInline
   @usableFromInline
   internal func setResponseTrailers(_ headers: HPACKHeaders) async throws {
   internal func setResponseTrailers(_ headers: HPACKHeaders) async throws {
     let completed = self.eventLoop.submit {
     let completed = self.eventLoop.submit {
@@ -798,6 +844,7 @@ extension AsyncServerHandler: AsyncServerCallContextProvider {
 @usableFromInline
 @usableFromInline
 protocol AsyncServerCallContextProvider: Sendable {
 protocol AsyncServerCallContextProvider: Sendable {
   func setResponseHeaders(_ headers: HPACKHeaders) async throws
   func setResponseHeaders(_ headers: HPACKHeaders) async throws
+  func acceptRPC(_ headers: HPACKHeaders) async
   func setResponseTrailers(_ trailers: HPACKHeaders) async throws
   func setResponseTrailers(_ trailers: HPACKHeaders) async throws
   func setResponseCompression(_ enabled: Bool) async throws
   func setResponseCompression(_ enabled: Bool) async throws
 
 

+ 4 - 5
Tests/GRPCTests/AsyncAwaitSupport/AsyncServerHandler/ServerHandlerStateMachine/ServerHandlerStateMachineTests.swift

@@ -210,7 +210,7 @@ internal final class ServerHandlerStateMachineTests: GRPCTestCase {
 
 
   func testSetResponseHeadersWhenHandling() {
   func testSetResponseHeadersWhenHandling() {
     var stateMachine = self.makeStateMachine(inState: .handling)
     var stateMachine = self.makeStateMachine(inState: .handling)
-    stateMachine.setResponseHeaders(["foo": "bar"])
+    XCTAssertTrue(stateMachine.setResponseHeaders(["foo": "bar"]))
     stateMachine.sendMessage().assertInterceptHeadersThenMessage { headers in
     stateMachine.sendMessage().assertInterceptHeadersThenMessage { headers in
       XCTAssertEqual(headers, ["foo": "bar"])
       XCTAssertEqual(headers, ["foo": "bar"])
     }
     }
@@ -218,7 +218,7 @@ internal final class ServerHandlerStateMachineTests: GRPCTestCase {
 
 
   func testSetResponseHeadersWhenHandlingAreMovedToDraining() {
   func testSetResponseHeadersWhenHandlingAreMovedToDraining() {
     var stateMachine = self.makeStateMachine(inState: .handling)
     var stateMachine = self.makeStateMachine(inState: .handling)
-    stateMachine.setResponseHeaders(["foo": "bar"])
+    XCTAssertTrue(stateMachine.setResponseHeaders(["foo": "bar"]))
     stateMachine.handleEnd().assertForward()
     stateMachine.handleEnd().assertForward()
     stateMachine.sendMessage().assertInterceptHeadersThenMessage { headers in
     stateMachine.sendMessage().assertInterceptHeadersThenMessage { headers in
       XCTAssertEqual(headers, ["foo": "bar"])
       XCTAssertEqual(headers, ["foo": "bar"])
@@ -227,7 +227,7 @@ internal final class ServerHandlerStateMachineTests: GRPCTestCase {
 
 
   func testSetResponseHeadersWhenDraining() {
   func testSetResponseHeadersWhenDraining() {
     var stateMachine = self.makeStateMachine(inState: .draining)
     var stateMachine = self.makeStateMachine(inState: .draining)
-    stateMachine.setResponseHeaders(["foo": "bar"])
+    XCTAssertTrue(stateMachine.setResponseHeaders(["foo": "bar"]))
     stateMachine.sendMessage().assertInterceptHeadersThenMessage { headers in
     stateMachine.sendMessage().assertInterceptHeadersThenMessage { headers in
       XCTAssertEqual(headers, ["foo": "bar"])
       XCTAssertEqual(headers, ["foo": "bar"])
     }
     }
@@ -235,8 +235,7 @@ internal final class ServerHandlerStateMachineTests: GRPCTestCase {
 
 
   func testSetResponseHeadersWhenFinished() {
   func testSetResponseHeadersWhenFinished() {
     var stateMachine = self.makeStateMachine(inState: .finished)
     var stateMachine = self.makeStateMachine(inState: .finished)
-    stateMachine.setResponseHeaders(["foo": "bar"])
-    // Nothing we can assert on, only that we don't crash.
+    XCTAssertFalse(stateMachine.setResponseHeaders(["foo": "bar"]))
   }
   }
 
 
   func testSetResponseTrailersWhenHandling() {
   func testSetResponseTrailersWhenHandling() {

+ 168 - 2
Tests/GRPCTests/GRPCAsyncClientCallTests.swift

@@ -35,12 +35,14 @@ class GRPCAsyncClientCallTests: GRPCTestCase {
     ("grpc-status", "0"),
     ("grpc-status", "0"),
   ])
   ])
 
 
-  private func setUpServerAndChannel() throws -> ClientConnection {
+  private func setUpServerAndChannel(
+    service: CallHandlerProvider = EchoProvider()
+  ) throws -> ClientConnection {
     let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
     let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
     self.group = group
     self.group = group
 
 
     let server = try Server.insecure(group: group)
     let server = try Server.insecure(group: group)
-      .withServiceProviders([EchoProvider()])
+      .withServiceProviders([service])
       .withLogger(self.serverLogger)
       .withLogger(self.serverLogger)
       .bind(host: "127.0.0.1", port: 0)
       .bind(host: "127.0.0.1", port: 0)
       .wait()
       .wait()
@@ -204,6 +206,110 @@ class GRPCAsyncClientCallTests: GRPCTestCase {
     await assertThat(try await update.trailingMetadata, .is(.equalTo(Self.OKTrailingMetadata)))
     await assertThat(try await update.trailingMetadata, .is(.equalTo(Self.OKTrailingMetadata)))
     await assertThat(await update.status, .hasCode(.ok))
     await assertThat(await update.status, .hasCode(.ok))
   }
   }
+
+  func testExplicitAcceptUnary(twice: Bool, function: String = #function) async throws {
+    let headers: HPACKHeaders = ["fn": function]
+    let channel = try self.setUpServerAndChannel(
+      service: AsyncEchoProvider(headers: headers, sendTwice: twice)
+    )
+    let echo = Echo_EchoAsyncClient(channel: channel)
+    let call = echo.makeGetCall(.with { $0.text = "" })
+    let responseHeaders = try await call.initialMetadata
+    XCTAssertEqual(responseHeaders.first(name: "fn"), function)
+    let status = await call.status
+    XCTAssertEqual(status.code, .ok)
+  }
+
+  func testExplicitAcceptUnary() async throws {
+    try await self.testExplicitAcceptUnary(twice: false)
+  }
+
+  func testExplicitAcceptTwiceUnary() async throws {
+    try await self.testExplicitAcceptUnary(twice: true)
+  }
+
+  func testExplicitAcceptClientStreaming(twice: Bool, function: String = #function) async throws {
+    let headers: HPACKHeaders = ["fn": function]
+    let channel = try self.setUpServerAndChannel(
+      service: AsyncEchoProvider(headers: headers, sendTwice: twice)
+    )
+    let echo = Echo_EchoAsyncClient(channel: channel)
+    let call = echo.makeCollectCall()
+    let responseHeaders = try await call.initialMetadata
+    XCTAssertEqual(responseHeaders.first(name: "fn"), function)
+
+    // Close request stream; the response should be empty.
+    call.requestStream.finish()
+    let response = try await call.response
+    XCTAssertEqual(response.text, "")
+
+    let status = await call.status
+    XCTAssertEqual(status.code, .ok)
+  }
+
+  func testExplicitAcceptClientStreaming() async throws {
+    try await self.testExplicitAcceptClientStreaming(twice: false)
+  }
+
+  func testExplicitAcceptTwiceClientStreaming() async throws {
+    try await self.testExplicitAcceptClientStreaming(twice: true)
+  }
+
+  func testExplicitAcceptServerStreaming(twice: Bool, function: String = #function) async throws {
+    let headers: HPACKHeaders = ["fn": #function]
+    let channel = try self.setUpServerAndChannel(
+      service: AsyncEchoProvider(headers: headers, sendTwice: twice)
+    )
+    let echo = Echo_EchoAsyncClient(channel: channel)
+    let call = echo.makeExpandCall(.with { $0.text = "foo bar baz" })
+    let responseHeaders = try await call.initialMetadata
+    XCTAssertEqual(responseHeaders.first(name: "fn"), #function)
+
+    // Close request stream; the response should be empty.
+    let responses = try await call.responseStream.collect()
+    XCTAssertEqual(responses.count, 3)
+
+    let status = await call.status
+    XCTAssertEqual(status.code, .ok)
+  }
+
+  func testExplicitAcceptServerStreaming() async throws {
+    try await self.testExplicitAcceptServerStreaming(twice: false)
+  }
+
+  func testExplicitAcceptTwiceServerStreaming() async throws {
+    try await self.testExplicitAcceptServerStreaming(twice: true)
+  }
+
+  func testExplicitAcceptBidirectionalStreaming(
+    twice: Bool,
+    function: String = #function
+  ) async throws {
+    let headers: HPACKHeaders = ["fn": function]
+    let channel = try self.setUpServerAndChannel(
+      service: AsyncEchoProvider(headers: headers, sendTwice: twice)
+    )
+    let echo = Echo_EchoAsyncClient(channel: channel)
+    let call = echo.makeUpdateCall()
+    let responseHeaders = try await call.initialMetadata
+    XCTAssertEqual(responseHeaders.first(name: "fn"), function)
+
+    // Close request stream; there should be no responses.
+    call.requestStream.finish()
+    let responses = try await call.responseStream.collect()
+    XCTAssertEqual(responses.count, 0)
+
+    let status = await call.status
+    XCTAssertEqual(status.code, .ok)
+  }
+
+  func testExplicitAcceptBidirectionalStreaming() async throws {
+    try await self.testExplicitAcceptBidirectionalStreaming(twice: false)
+  }
+
+  func testExplicitAcceptTwiceBidirectionalStreaming() async throws {
+    try await self.testExplicitAcceptBidirectionalStreaming(twice: true)
+  }
 }
 }
 
 
 // Workaround https://bugs.swift.org/browse/SR-15070 (compiler crashes when defining a class/actor
 // Workaround https://bugs.swift.org/browse/SR-15070 (compiler crashes when defining a class/actor
@@ -221,3 +327,63 @@ private actor RequestResponseCounter {
     self.numRequests += 1
     self.numRequests += 1
   }
   }
 }
 }
+
+private final class AsyncEchoProvider: Echo_EchoAsyncProvider {
+  let headers: HPACKHeaders
+  let sendTwice: Bool
+
+  init(headers: HPACKHeaders, sendTwice: Bool = false) {
+    self.headers = headers
+    self.sendTwice = sendTwice
+  }
+
+  private func accept(context: GRPCAsyncServerCallContext) async {
+    await context.acceptRPC(headers: self.headers)
+    if self.sendTwice {
+      await context.acceptRPC(headers: self.headers) // Should be a no-op.
+    }
+  }
+
+  func get(
+    request: Echo_EchoRequest,
+    context: GRPCAsyncServerCallContext
+  ) async throws -> Echo_EchoResponse {
+    await self.accept(context: context)
+    return Echo_EchoResponse.with { $0.text = request.text }
+  }
+
+  func expand(
+    request: Echo_EchoRequest,
+    responseStream: GRPCAsyncResponseStreamWriter<Echo_EchoResponse>,
+    context: GRPCAsyncServerCallContext
+  ) async throws {
+    await self.accept(context: context)
+    for part in request.text.components(separatedBy: " ") {
+      let response = Echo_EchoResponse.with {
+        $0.text = part
+      }
+      try await responseStream.send(response)
+    }
+  }
+
+  func collect(
+    requestStream: GRPCAsyncRequestStream<Echo_EchoRequest>,
+    context: GRPCAsyncServerCallContext
+  ) async throws -> Echo_EchoResponse {
+    await self.accept(context: context)
+    let collected = try await requestStream.map { $0.text }.collect().joined(separator: " ")
+    return Echo_EchoResponse.with { $0.text = collected }
+  }
+
+  func update(
+    requestStream: GRPCAsyncRequestStream<Echo_EchoRequest>,
+    responseStream: GRPCAsyncResponseStreamWriter<Echo_EchoResponse>,
+    context: GRPCAsyncServerCallContext
+  ) async throws {
+    await self.accept(context: context)
+    for try await request in requestStream {
+      let response = Echo_EchoResponse.with { $0.text = request.text }
+      try await responseStream.send(response)
+    }
+  }
+}