Browse Source

Refactor the async server call context (#1407)

Motivation:

The async call context provided to async server methods provides an API
to access request headers and the logger as well as modify the response
headers, trailers and user info.

This is currently implemented as a class with its mutable state
protected by a lock. This works fine but is racy: the response headers
can be updated after writing the first message and still be sent before
that message (as writing the message requires executing onto the event
loop).

Modifications:

- Turn the context into a `struct` and break it into request and
  response components.
- Request components are immutable (or, more correctly, mutating them
  has no impact on the underlying RPC).
- Response components are mutable via `async` calls which execute onto
  the underlying event loop.
- This introduces allocations as the work is submitted onto the
  event loop, however setting headers is not a frequent operation and is
  worth the trade off for a less surprising API.
- Trap when headers/trailers are set after they have been sent
  (this is an assertion failre).

Result:

- Async server context has a less surprising API
- Setting headers/trailers is not racy
George Barnett 3 years ago
parent
commit
981e6cdd21

+ 33 - 2
Sources/GRPC/AsyncAwaitSupport/AsyncServerHandler/ServerHandlerStateMachine/ServerHandlerStateMachine+Actions.swift

@@ -19,9 +19,9 @@ import NIOHPACK
 @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
 extension ServerHandlerStateMachine {
   @usableFromInline
-  enum HandleMetadataAction {
+  enum HandleMetadataAction: Hashable {
     /// Invoke the user handler.
-    case invokeHandler(Ref<UserInfo>, CallHandlerContext)
+    case invokeHandler
     /// Cancel the RPC, the metadata was not expected.
     case cancel
   }
@@ -62,5 +62,36 @@ extension ServerHandlerStateMachine {
     /// Don't do anything.
     case none
   }
+
+  /// Tracks whether response metadata has been written.
+  @usableFromInline
+  internal enum ResponseMetadata {
+    case notWritten(HPACKHeaders)
+    case written
+
+    /// Update the metadata. It must not have been written yet.
+    @inlinable
+    mutating func update(_ metadata: HPACKHeaders) {
+      switch self {
+      case .notWritten:
+        self = .notWritten(metadata)
+      case .written:
+        assertionFailure("Metadata must not be set after it has been sent")
+      }
+    }
+
+    /// Returns the metadata if it has not been written and moves the state to
+    /// `written`. Returns `nil` if it has already been written.
+    @inlinable
+    mutating func getIfNotWritten() -> HPACKHeaders? {
+      switch self {
+      case let .notWritten(metadata):
+        self = .written
+        return metadata
+      case .written:
+        return nil
+      }
+    }
+  }
 }
 #endif // compiler(>=5.6)

+ 33 - 16
Sources/GRPC/AsyncAwaitSupport/AsyncServerHandler/ServerHandlerStateMachine/ServerHandlerStateMachine+Draining.swift

@@ -31,16 +31,37 @@ extension ServerHandlerStateMachine {
         Output
       >
 
-    /// Whether the response headers have been written yet.
+    /// The response headers.
     @usableFromInline
-    internal private(set) var headersWritten: Bool
+    internal private(set) var responseHeaders: ResponseMetadata
+    /// The response trailers.
     @usableFromInline
-    internal let context: GRPCAsyncServerCallContext
+    internal private(set) var responseTrailers: ResponseMetadata
+    /// The request headers.
+    @usableFromInline
+    internal let requestHeaders: HPACKHeaders
 
     @inlinable
     init(from state: ServerHandlerStateMachine.Handling) {
-      self.headersWritten = state.headersWritten
-      self.context = state.context
+      self.responseHeaders = state.responseHeaders
+      self.responseTrailers = state.responseTrailers
+      self.requestHeaders = state.requestHeaders
+    }
+
+    @inlinable
+    mutating func setResponseHeaders(
+      _ metadata: HPACKHeaders
+    ) -> Self.NextStateAndOutput<Void> {
+      self.responseHeaders.update(metadata)
+      return .init(nextState: .draining(self))
+    }
+
+    @inlinable
+    mutating func setResponseTrailers(
+      _ metadata: HPACKHeaders
+    ) -> Self.NextStateAndOutput<Void> {
+      self.responseTrailers.update(metadata)
+      return .init(nextState: .draining(self))
     }
 
     @inlinable
@@ -63,24 +84,20 @@ extension ServerHandlerStateMachine {
 
     @inlinable
     mutating func sendMessage() -> Self.NextStateAndOutput<SendMessageAction> {
-      let headers: HPACKHeaders?
-
-      if self.headersWritten {
-        headers = nil
-      } else {
-        self.headersWritten = true
-        headers = self.context.initialResponseMetadata
-      }
-
+      let headers = self.responseHeaders.getIfNotWritten()
       return .init(nextState: .draining(self), output: .intercept(headers: headers))
     }
 
     @inlinable
     mutating func sendStatus() -> Self.NextStateAndOutput<SendStatusAction> {
-      let trailers = self.context.trailingResponseMetadata
       return .init(
         nextState: .finished(from: self),
-        output: .intercept(requestHeaders: self.context.requestMetadata, trailers: trailers)
+        output: .intercept(
+          requestHeaders: self.requestHeaders,
+          // If trailers had been written we'd already be in the finished state so
+          // the force unwrap is okay here.
+          trailers: self.responseTrailers.getIfNotWritten()!
+        )
       )
     }
 

+ 15 - 0
Sources/GRPC/AsyncAwaitSupport/AsyncServerHandler/ServerHandlerStateMachine/ServerHandlerStateMachine+Finished.swift

@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 #if compiler(>=5.6)
+import NIOHPACK
 
 @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
 extension ServerHandlerStateMachine {
@@ -32,6 +33,20 @@ extension ServerHandlerStateMachine {
     @inlinable
     internal init(from state: ServerHandlerStateMachine.Draining) {}
 
+    @inlinable
+    mutating func setResponseHeaders(
+      _ headers: HPACKHeaders
+    ) -> Self.NextStateAndOutput<Void> {
+      return .init(nextState: .finished(self))
+    }
+
+    @inlinable
+    mutating func setResponseTrailers(
+      _ metadata: HPACKHeaders
+    ) -> Self.NextStateAndOutput<Void> {
+      return .init(nextState: .finished(self))
+    }
+
     @inlinable
     mutating func handleMetadata() -> Self.NextStateAndOutput<HandleMetadataAction> {
       return .init(nextState: .finished(self), output: .cancel)

+ 34 - 23
Sources/GRPC/AsyncAwaitSupport/AsyncServerHandler/ServerHandlerStateMachine/ServerHandlerStateMachine+Handling.swift

@@ -30,20 +30,38 @@ extension ServerHandlerStateMachine {
       Output
     >
 
-    /// Whether response headers have been written (they are written lazily rather than on receipt
-    /// of the request headers).
+    /// The response headers.
     @usableFromInline
-    internal private(set) var headersWritten: Bool
-
-    /// A context held by user handler which may be used to alter the response headers or trailers.
+    internal private(set) var responseHeaders: ResponseMetadata
+    /// The response trailers.
+    @usableFromInline
+    internal private(set) var responseTrailers: ResponseMetadata
+    /// The request headers.
     @usableFromInline
-    internal let context: GRPCAsyncServerCallContext
+    internal let requestHeaders: HPACKHeaders
 
     /// Transition from the 'Idle' state.
     @inlinable
-    init(from state: ServerHandlerStateMachine.Idle, context: GRPCAsyncServerCallContext) {
-      self.headersWritten = false
-      self.context = context
+    init(from state: ServerHandlerStateMachine.Idle, requestHeaders: HPACKHeaders) {
+      self.responseHeaders = .notWritten([:])
+      self.responseTrailers = .notWritten([:])
+      self.requestHeaders = requestHeaders
+    }
+
+    @inlinable
+    mutating func setResponseHeaders(
+      _ metadata: HPACKHeaders
+    ) -> Self.NextStateAndOutput<Void> {
+      self.responseHeaders.update(metadata)
+      return .init(nextState: .handling(self))
+    }
+
+    @inlinable
+    mutating func setResponseTrailers(
+      _ metadata: HPACKHeaders
+    ) -> Self.NextStateAndOutput<Void> {
+      self.responseTrailers.update(metadata)
+      return .init(nextState: .handling(self))
     }
 
     @inlinable
@@ -69,27 +87,20 @@ extension ServerHandlerStateMachine {
 
     @inlinable
     mutating func sendMessage() -> Self.NextStateAndOutput<SendMessageAction> {
-      let headers: HPACKHeaders?
-
-      // We send headers once, lazily, when the first message is sent back.
-      if self.headersWritten {
-        headers = nil
-      } else {
-        self.headersWritten = true
-        headers = self.context.initialResponseMetadata
-      }
-
+      let headers = self.responseHeaders.getIfNotWritten()
       return .init(nextState: .handling(self), output: .intercept(headers: headers))
     }
 
     @inlinable
     mutating func sendStatus() -> Self.NextStateAndOutput<SendStatusAction> {
-      // Sending the status is the final action taken by the user handler. We can always send
-      // them from this state and doing so means the user handler has completed.
-      let trailers = self.context.trailingResponseMetadata
       return .init(
         nextState: .finished(from: self),
-        output: .intercept(requestHeaders: self.context.requestMetadata, trailers: trailers)
+        output: .intercept(
+          requestHeaders: self.requestHeaders,
+          // If trailers had been written we'd already be in the finished state so
+          // the force unwrap is okay here.
+          trailers: self.responseTrailers.getIfNotWritten()!
+        )
       )
     }
 

+ 5 - 15
Sources/GRPC/AsyncAwaitSupport/AsyncServerHandler/ServerHandlerStateMachine/ServerHandlerStateMachine+Idle.swift

@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 #if compiler(>=5.6)
+import NIOHPACK
 
 @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
 extension ServerHandlerStateMachine {
@@ -27,21 +28,12 @@ extension ServerHandlerStateMachine {
       Output
     >
 
-    /// A ref to the `UserInfo`. We hold on to this until we're ready to invoke the handler.
-    @usableFromInline
-    let userInfoRef: Ref<UserInfo>
-    /// A bag of bits required to construct a context passed to the user handler when it is invoked.
-    @usableFromInline
-    let callHandlerContext: CallHandlerContext
-
     /// The state of the inbound stream, i.e. the request stream.
     @usableFromInline
     internal private(set) var inboundState: ServerInterceptorStateMachine.InboundStreamState
 
     @inlinable
-    init(userInfoRef: Ref<UserInfo>, context: CallHandlerContext) {
-      self.userInfoRef = userInfoRef
-      self.callHandlerContext = context
+    init() {
       self.inboundState = .idle
     }
 
@@ -53,7 +45,7 @@ extension ServerHandlerStateMachine {
       case .accept:
         // We tell the caller to invoke the handler immediately: they should then call
         // 'handlerInvoked' on the state machine which will cause a transition to the next state.
-        action = .invokeHandler(self.userInfoRef, self.callHandlerContext)
+        action = .invokeHandler
       case .reject:
         action = .cancel
       }
@@ -74,11 +66,9 @@ extension ServerHandlerStateMachine {
     }
 
     @inlinable
-    mutating func handlerInvoked(
-      context: GRPCAsyncServerCallContext
-    ) -> Self.NextStateAndOutput<Void> {
+    mutating func handlerInvoked(requestHeaders: HPACKHeaders) -> Self.NextStateAndOutput<Void> {
       // The handler was invoked as a result of receiving metadata. Move to the next state.
-      return .init(nextState: .handling(from: self, context: context))
+      return .init(nextState: .handling(from: self, requestHeaders: requestHeaders))
     }
 
     @inlinable

+ 47 - 6
Sources/GRPC/AsyncAwaitSupport/AsyncServerHandler/ServerHandlerStateMachine/ServerHandlerStateMachine.swift

@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 #if compiler(>=5.6)
+import NIOHPACK
 
 @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
 @usableFromInline
@@ -22,8 +23,48 @@ internal struct ServerHandlerStateMachine {
   internal private(set) var state: Self.State
 
   @inlinable
-  init(userInfoRef: Ref<UserInfo>, context: CallHandlerContext) {
-    self.state = .idle(.init(userInfoRef: userInfoRef, context: context))
+  init() {
+    self.state = .idle(.init())
+  }
+
+  @inlinable
+  mutating func setResponseHeaders(_ headers: HPACKHeaders) {
+    switch self.state {
+    case var .handling(handling):
+      let nextStateAndOutput = handling.setResponseHeaders(headers)
+      self.state = nextStateAndOutput.nextState.state
+      return nextStateAndOutput.output
+    case var .draining(draining):
+      let nextStateAndOutput = draining.setResponseHeaders(headers)
+      self.state = nextStateAndOutput.nextState.state
+      return nextStateAndOutput.output
+    case var .finished(finished):
+      let nextStateAndOutput = finished.setResponseHeaders(headers)
+      self.state = nextStateAndOutput.nextState.state
+      return nextStateAndOutput.output
+    case .idle:
+      preconditionFailure()
+    }
+  }
+
+  @inlinable
+  mutating func setResponseTrailers(_ trailers: HPACKHeaders) {
+    switch self.state {
+    case var .handling(handling):
+      let nextStateAndOutput = handling.setResponseTrailers(trailers)
+      self.state = nextStateAndOutput.nextState.state
+      return nextStateAndOutput.output
+    case var .draining(draining):
+      let nextStateAndOutput = draining.setResponseTrailers(trailers)
+      self.state = nextStateAndOutput.nextState.state
+      return nextStateAndOutput.output
+    case var .finished(finished):
+      let nextStateAndOutput = finished.setResponseTrailers(trailers)
+      self.state = nextStateAndOutput.nextState.state
+      return nextStateAndOutput.output
+    case .idle:
+      preconditionFailure()
+    }
   }
 
   @inlinable
@@ -155,10 +196,10 @@ internal struct ServerHandlerStateMachine {
   }
 
   @inlinable
-  mutating func handlerInvoked(context: GRPCAsyncServerCallContext) {
+  mutating func handlerInvoked(requestHeaders: HPACKHeaders) {
     switch self.state {
     case var .idle(idle):
-      let nextStateAndOutput = idle.handlerInvoked(context: context)
+      let nextStateAndOutput = idle.handlerInvoked(requestHeaders: requestHeaders)
       self.state = nextStateAndOutput.nextState.state
       return nextStateAndOutput.output
     case .handling:
@@ -232,9 +273,9 @@ extension ServerHandlerStateMachine.Idle {
     @inlinable
     internal static func handling(
       from: ServerHandlerStateMachine.Idle,
-      context: GRPCAsyncServerCallContext
+      requestHeaders: HPACKHeaders
     ) -> Self {
-      return Self(_state: .handling(.init(from: from, context: context)))
+      return Self(_state: .handling(.init(from: from, requestHeaders: requestHeaders)))
     }
 
     @inlinable

+ 3 - 22
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncResponseStreamWriter.swift

@@ -52,48 +52,29 @@ internal final class AsyncResponseStreamWriterDelegate<Response: Sendable>: Asyn
   internal typealias End = GRPCStatus
 
   @usableFromInline
-  internal let _context: GRPCAsyncServerCallContext
-
-  @usableFromInline
-  internal let _send: @Sendable (Response, MessageMetadata) -> Void
+  internal let _send: @Sendable (Response, Compression) -> Void
 
   @usableFromInline
   internal let _finish: @Sendable (GRPCStatus) -> Void
 
-  @usableFromInline
-  internal let _compressionEnabledOnServer: Bool
-
   // Create a new AsyncResponseStreamWriterDelegate.
   //
   // - Important: the `send` and `finish` closures must be thread-safe.
   @inlinable
   internal init(
-    context: GRPCAsyncServerCallContext,
-    compressionIsEnabled: Bool,
-    send: @escaping @Sendable (Response, MessageMetadata) -> Void,
+    send: @escaping @Sendable (Response, Compression) -> Void,
     finish: @escaping @Sendable (GRPCStatus) -> Void
   ) {
-    self._context = context
-    self._compressionEnabledOnServer = compressionIsEnabled
     self._send = send
     self._finish = finish
   }
 
-  @inlinable
-  internal func _shouldCompress(_ compression: Compression) -> Bool {
-    guard self._compressionEnabledOnServer else {
-      return false
-    }
-    return compression.isEnabled(callDefault: self._context.compressionEnabled)
-  }
-
   @inlinable
   internal func _send(
     _ response: Response,
     compression: Compression = .deferToCallDefault
   ) {
-    let compress = self._shouldCompress(compression)
-    self._send(response, .init(compress: compress, flush: true))
+    self._send(response, compression)
   }
 
   // MARK: - AsyncWriterDelegate conformance.

+ 80 - 90
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerCallContext.swift

@@ -15,114 +15,104 @@
  */
 #if compiler(>=5.6)
 
-import Logging
+@preconcurrency import Logging
 import NIOConcurrencyHelpers
-import NIOHPACK
-
-// We use a `class` here because we do not want copy-on-write semantics. The instance that the async
-// handler holds must not diverge from the instance the implementor of the RPC holds. They hold these
-// instances on different threads (EventLoop vs Task).
-//
-// We considered wrapping this in a `struct` and pass it `inout` to the RPC. This would communicate
-// explicitly that it stores mutable state. However, without copy-on-write semantics, this could
-// make for a surprising API.
-//
-// We also considered an `actor` but that felt clunky at the point of use since adopters would need
-// to `await` the retrieval of a logger or the updating of the trailers and each would require a
-// promise to glue the NIO and async-await paradigms in the handler.
-//
-// Note: this is `@unchecked Sendable`; all mutable state is protected by a lock.
-@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
-public final class GRPCAsyncServerCallContext: @unchecked Sendable {
-  private let lock = Lock()
-
-  /// Metadata for this request.
-  public let requestMetadata: HPACKHeaders
-
-  /// The logger used for this call.
-  public var logger: Logger {
-    get { self.lock.withLock {
-      self._logger
-    } }
-    set { self.lock.withLock {
-      self._logger = newValue
-    } }
-  }
+@preconcurrency import NIOHPACK
 
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+public struct GRPCAsyncServerCallContext {
   @usableFromInline
-  internal var _logger: Logger
-
-  /// Whether compression should be enabled for responses, defaulting to `true`. Note that for
-  /// this value to take effect compression must have been enabled on the server and a compression
-  /// algorithm must have been negotiated with the client.
-  public var compressionEnabled: Bool {
-    get { self.lock.withLock {
-      self._compressionEnabled
-    } }
-    set { self.lock.withLock {
-      self._compressionEnabled = newValue
-    } }
-  }
+  let contextProvider: AsyncServerCallContextProvider
 
-  private var _compressionEnabled: Bool = true
+  /// Details of the request, including request headers and a logger.
+  public var request: Request
 
-  /// A `UserInfo` dictionary which is shared with the interceptor contexts for this RPC.
-  ///
-  /// - Important: While `UserInfo` has value-semantics, this property retrieves from, and sets a
-  ///   reference wrapped `UserInfo`. The contexts passed to interceptors provide the same
-  ///   reference. As such this may be used as a mechanism to pass information between interceptors
-  ///   and service providers.
-  public var userInfo: UserInfo {
-    get { self.lock.withLock {
-      self.userInfoRef.value
-    } }
-    set { self.lock.withLock {
-      self.userInfoRef.value = newValue
-    } }
+  /// A response context which may be used to set response headers and trailers.
+  public var response: Response {
+    Response(contextProvider: self.contextProvider)
   }
 
-  /// A reference to an underlying `UserInfo`. We share this with the interceptors.
-  @usableFromInline
-  internal let userInfoRef: Ref<UserInfo>
-
-  /// Metadata to return at the start of the RPC.
+  /// Access the `UserInfo` dictionary which is shared with the interceptor contexts for this RPC.
   ///
-  /// - Important: If this is required it should be updated _before_ the first response is sent via
-  /// the response stream writer. Any updates made after the first response will be ignored.
-  public var initialResponseMetadata: HPACKHeaders {
-    get { self.lock.withLock {
-      return self._initialResponseMetadata
-    } }
-    set { self.lock.withLock {
-      self._initialResponseMetadata = newValue
-    } }
+  /// - Important: While `UserInfo` has value-semantics, this function accesses a reference
+  ///   wrapped `UserInfo`. The contexts passed to interceptors provide the same reference. As such
+  ///   this may be used as a mechanism to pass information between interceptors and service
+  ///   providers.
+  public func withUserInfo<Result: Sendable>(
+    _ body: @Sendable @escaping (UserInfo) throws -> Result
+  ) async throws -> Result {
+    return try await self.contextProvider.withUserInfo(body)
   }
 
-  private var _initialResponseMetadata: HPACKHeaders = [:]
-
-  /// Metadata to return at the end of the RPC.
+  /// Modify the `UserInfo` dictionary which is shared with the interceptor contexts for this RPC.
   ///
-  /// If this is required it should be updated before returning from the handler.
-  public var trailingResponseMetadata: HPACKHeaders {
-    get { self.lock.withLock {
-      return self._trailingResponseMetadata
-    } }
-    set { self.lock.withLock {
-      self._trailingResponseMetadata = newValue
-    } }
+  /// - Important: While `UserInfo` has value-semantics, this function accesses a reference
+  ///   wrapped `UserInfo`. The contexts passed to interceptors provide the same reference. As such
+  ///   this may be used as a mechanism to pass information between interceptors and service
+  ///   providers.
+  public func withMutableUserInfo<Result: Sendable>(
+    _ modify: @Sendable @escaping (inout UserInfo) -> Result
+  ) async throws -> Result {
+    return try await self.contextProvider.withMutableUserInfo(modify)
   }
 
-  private var _trailingResponseMetadata: HPACKHeaders = [:]
-
   @inlinable
   internal init(
     headers: HPACKHeaders,
     logger: Logger,
-    userInfoRef: Ref<UserInfo>
+    contextProvider: AsyncServerCallContextProvider
   ) {
-    self.requestMetadata = headers
-    self.userInfoRef = userInfoRef
-    self._logger = logger
+    self.request = Request(headers: headers, logger: logger)
+    self.contextProvider = contextProvider
+  }
+}
+
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+extension GRPCAsyncServerCallContext {
+  public struct Request: Sendable {
+    /// The request headers received from the client at the start of the RPC.
+    public var headers: HPACKHeaders
+
+    /// A logger.
+    public var logger: Logger
+
+    @usableFromInline
+    init(headers: HPACKHeaders, logger: Logger) {
+      self.headers = headers
+      self.logger = logger
+    }
+  }
+
+  public struct Response {
+    private let contextProvider: AsyncServerCallContextProvider
+
+    /// Set the metadata to return at the start of the RPC.
+    ///
+    /// - Important: If this is required it should be updated _before_ the first response is sent
+    ///   via the response stream writer. Updates must not be made after the first response has
+    ///   been sent.
+    public func setHeaders(_ headers: HPACKHeaders) async throws {
+      try await self.contextProvider.setResponseHeaders(headers)
+    }
+
+    /// Set the metadata to return at the end of the RPC.
+    ///
+    /// If this is required it must be updated before returning from the handler.
+    public func setTrailers(_ trailers: HPACKHeaders) async throws {
+      try await self.contextProvider.setResponseTrailers(trailers)
+    }
+
+    /// Whether compression should be enabled for responses, defaulting to `true`. Note that for
+    /// this value to take effect compression must have been enabled on the server and a compression
+    /// algorithm must have been negotiated with the client.
+    public func compressResponses(_ compress: Bool) async throws {
+      try await self.contextProvider.setResponseCompression(compress)
+    }
+
+    @usableFromInline
+    internal init(contextProvider: AsyncServerCallContextProvider) {
+      self.contextProvider = contextProvider
+    }
   }
 }
 

+ 112 - 14
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerHandler.swift

@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 #if compiler(>=5.6)
+import Logging
 import NIOCore
 import NIOHPACK
 
@@ -182,9 +183,30 @@ internal final class AsyncServerHandler<
   @usableFromInline
   internal let allocator: ByteBufferAllocator
 
+  /// A user-provided error delegate which, if provided, is used to transform errors and potentially
+  /// pack errors into trailers.
   @usableFromInline
   internal let errorDelegate: ServerErrorDelegate?
 
+  /// A logger.
+  @usableFromInline
+  internal let logger: Logger
+
+  /// A reference to the user info. This is shared with the interceptor pipeline and may be accessed
+  /// from the async call context. `UserInfo` is _not_ `Sendable` and must always be accessed from
+  /// an appropriate event loop.
+  @usableFromInline
+  internal let userInfoRef: Ref<UserInfo>
+
+  /// Whether compression is enabled on the server and an algorithm has been negotiated with
+  /// the client
+  @usableFromInline
+  internal let compressionEnabledOnRPC: Bool
+
+  /// Whether the RPC method would like to compress responses (if possible). Defaults to true.
+  @usableFromInline
+  internal var compressResponsesIfPossible: Bool
+
   /// A state machine for the interceptor pipeline.
   @usableFromInline
   internal private(set) var interceptorStateMachine: ServerInterceptorStateMachine
@@ -232,9 +254,12 @@ internal final class AsyncServerHandler<
     self.allocator = context.allocator
     self.responseWriter = context.responseWriter
     self.errorDelegate = context.errorDelegate
+    self.compressionEnabledOnRPC = context.encoding.isEnabled
+    self.compressResponsesIfPossible = true
+    self.logger = context.logger
 
-    let userInfoRef = Ref(UserInfo())
-    self.handlerStateMachine = .init(userInfoRef: userInfoRef, context: context)
+    self.userInfoRef = Ref(UserInfo())
+    self.handlerStateMachine = .init()
     self.handlerComponents = nil
 
     self.userHandler = userHandler
@@ -247,7 +272,7 @@ internal final class AsyncServerHandler<
       path: context.path,
       callType: callType,
       remoteAddress: context.remoteAddress,
-      userInfoRef: userInfoRef,
+      userInfoRef: self.userInfoRef,
       interceptors: interceptors,
       onRequestPart: self.receiveInterceptedPart(_:),
       onResponsePart: self.sendInterceptedPart(_:promise:)
@@ -377,7 +402,7 @@ internal final class AsyncServerHandler<
     }
 
     switch self.handlerStateMachine.handleMetadata() {
-    case let .invokeHandler(userInfoRef, callHandlerContext):
+    case .invokeHandler:
       // We're going to invoke the handler. We need to create a handful of things in order to do
       // that:
       //
@@ -393,16 +418,14 @@ internal final class AsyncServerHandler<
       // as a result of an error or when `self.finish()` is called).
       let handlerContext = GRPCAsyncServerCallContext(
         headers: headers,
-        logger: callHandlerContext.logger,
-        userInfoRef: userInfoRef
+        logger: self.logger,
+        contextProvider: self
       )
 
       let requestSource = PassthroughMessageSource<Request, Error>()
 
       let writerDelegate = AsyncResponseStreamWriterDelegate(
-        context: handlerContext,
-        compressionIsEnabled: callHandlerContext.encoding.isEnabled,
-        send: self.interceptResponseMessage(_:metadata:),
+        send: self.interceptResponseMessage(_:compression:),
         finish: self.interceptResponseStatus(_:)
       )
       let writer = AsyncWriter(delegate: writerDelegate)
@@ -423,7 +446,7 @@ internal final class AsyncServerHandler<
       }
 
       // Update our state before invoke the handler.
-      self.handlerStateMachine.handlerInvoked(context: handlerContext)
+      self.handlerStateMachine.handlerInvoked(requestHeaders: headers)
       self.handlerComponents = ServerHandlerComponents(
         requestSource: requestSource,
         responseWriter: writer,
@@ -540,7 +563,7 @@ internal final class AsyncServerHandler<
   // MARK: - User Function To Interceptors
 
   @inlinable
-  internal func _interceptResponseMessage(_ response: Response, metadata: MessageMetadata) {
+  internal func _interceptResponseMessage(_ response: Response, compression: Compression) {
     self.eventLoop.assertInEventLoop()
 
     switch self.handlerStateMachine.sendMessage() {
@@ -559,6 +582,13 @@ internal final class AsyncServerHandler<
     case .intercept(.none):
       switch self.interceptorStateMachine.interceptResponseMessage() {
       case .intercept:
+        let senderWantsCompression = compression.isEnabled(
+          callDefault: self.compressResponsesIfPossible
+        )
+
+        let compress = self.compressionEnabledOnRPC && senderWantsCompression
+
+        let metadata = MessageMetadata(compress: compress, flush: true)
         self.interceptors?.send(.message(response, metadata), promise: nil)
       case .cancel:
         return self.cancel(error: nil)
@@ -573,12 +603,12 @@ internal final class AsyncServerHandler<
 
   @Sendable
   @inlinable
-  internal func interceptResponseMessage(_ response: Response, metadata: MessageMetadata) {
+  internal func interceptResponseMessage(_ response: Response, compression: Compression) {
     if self.eventLoop.inEventLoop {
-      self._interceptResponseMessage(response, metadata: metadata)
+      self._interceptResponseMessage(response, compression: compression)
     } else {
       self.eventLoop.execute {
-        self._interceptResponseMessage(response, metadata: metadata)
+        self._interceptResponseMessage(response, compression: compression)
       }
     }
   }
@@ -698,6 +728,74 @@ internal final class AsyncServerHandler<
   }
 }
 
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+extension AsyncServerHandler: AsyncServerCallContextProvider {
+  @usableFromInline
+  internal func setResponseHeaders(_ headers: HPACKHeaders) async throws {
+    let completed = self.eventLoop.submit {
+      self.handlerStateMachine.setResponseHeaders(headers)
+    }
+    try await completed.get()
+  }
+
+  @usableFromInline
+  internal func setResponseTrailers(_ headers: HPACKHeaders) async throws {
+    let completed = self.eventLoop.submit {
+      self.handlerStateMachine.setResponseTrailers(headers)
+    }
+    try await completed.get()
+  }
+
+  @usableFromInline
+  internal func setResponseCompression(_ enabled: Bool) async throws {
+    let completed = self.eventLoop.submit {
+      self.compressResponsesIfPossible = enabled
+    }
+    try await completed.get()
+  }
+
+  @usableFromInline
+  func withUserInfo<Result: Sendable>(
+    _ modify: @Sendable @escaping (UserInfo) throws -> Result
+  ) async throws -> Result {
+    let result = self.eventLoop.submit {
+      try modify(self.userInfoRef.value)
+    }
+    return try await result.get()
+  }
+
+  @usableFromInline
+  func withMutableUserInfo<Result: Sendable>(
+    _ modify: @Sendable @escaping (inout UserInfo) throws -> Result
+  ) async throws -> Result {
+    let result = self.eventLoop.submit {
+      try modify(&self.userInfoRef.value)
+    }
+    return try await result.get()
+  }
+}
+
+/// This protocol exists so that the generic server handler can be erased from the
+/// `GRPCAsyncServerCallContext`.
+///
+/// It provides methods which update context on the async handler by first executing onto the
+/// correct event loop.
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+@usableFromInline
+protocol AsyncServerCallContextProvider {
+  func setResponseHeaders(_ headers: HPACKHeaders) async throws
+  func setResponseTrailers(_ trailers: HPACKHeaders) async throws
+  func setResponseCompression(_ enabled: Bool) async throws
+
+  func withUserInfo<Result: Sendable>(
+    _ modify: @Sendable @escaping (UserInfo) throws -> Result
+  ) async throws -> Result
+
+  func withMutableUserInfo<Result: Sendable>(
+    _ modify: @Sendable @escaping (inout UserInfo) throws -> Result
+  ) async throws -> Result
+}
+
 @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
 @usableFromInline
 internal struct ServerHandlerComponents<Request, Delegate: AsyncWriterDelegate> {

+ 5 - 5
Sources/GRPCInteroperabilityTestsImplementation/TestServiceAsyncProvider.swift

@@ -74,7 +74,7 @@ public class TestServiceAsyncProvider: Grpc_Testing_TestServiceAsyncProvider {
     // We can't validate messages at the wire-encoding layer (i.e. where the compression byte is
     // set), so we have to check via the encoding header. Note that it is possible for the header
     // to be set and for the message to not be compressed.
-    if request.expectCompressed.value, !context.requestMetadata.contains(name: "grpc-encoding") {
+    if request.expectCompressed.value, !context.request.headers.contains(name: "grpc-encoding") {
       throw GRPCStatus(
         code: .invalidArgument,
         message: "Expected compressed request, but 'grpc-encoding' was missing"
@@ -83,14 +83,14 @@ public class TestServiceAsyncProvider: Grpc_Testing_TestServiceAsyncProvider {
 
     // Should we enable compression? The C++ interoperability client only expects compression if
     // explicitly requested; we'll do the same.
-    context.compressionEnabled = request.responseCompressed.value
+    try await context.response.compressResponses(request.responseCompressed.value)
 
     if request.shouldEchoStatus {
       let code = GRPCStatus.Code(rawValue: numericCast(request.responseStatus.code)) ?? .unknown
       throw GRPCStatus(code: code, message: request.responseStatus.message)
     }
 
-    if context.requestMetadata.shouldEchoMetadata {
+    if context.request.headers.shouldEchoMetadata {
       throw Self.echoMetadataNotImplemented
     }
 
@@ -158,7 +158,7 @@ public class TestServiceAsyncProvider: Grpc_Testing_TestServiceAsyncProvider {
 
     for try await request in requestStream {
       if request.expectCompressed.value {
-        guard context.requestMetadata.contains(name: "grpc-encoding") else {
+        guard context.request.headers.contains(name: "grpc-encoding") else {
           throw GRPCStatus(
             code: .invalidArgument,
             message: "Expected compressed request, but 'grpc-encoding' was missing"
@@ -183,7 +183,7 @@ public class TestServiceAsyncProvider: Grpc_Testing_TestServiceAsyncProvider {
     context: GRPCAsyncServerCallContext
   ) async throws {
     // We don't have support for this yet so just fail the call.
-    if context.requestMetadata.shouldEchoMetadata {
+    if context.request.headers.shouldEchoMetadata {
       throw Self.echoMetadataNotImplemented
     }
 

+ 73 - 31
Tests/GRPCTests/AsyncAwaitSupport/AsyncServerHandler/ServerHandlerStateMachine/ServerHandlerStateMachineTests.swift

@@ -30,21 +30,18 @@ internal final class ServerHandlerStateMachineTests: GRPCTestCase {
   }
 
   private func makeStateMachine(inState state: InitialState = .idle) -> ServerHandlerStateMachine {
-    var stateMachine = ServerHandlerStateMachine(
-      userInfoRef: Ref(UserInfo()),
-      context: self.makeCallHandlerContext()
-    )
+    var stateMachine = ServerHandlerStateMachine()
 
     switch state {
     case .idle:
       return stateMachine
     case .handling:
       stateMachine.handleMetadata().assertInvokeHandler()
-      stateMachine.handlerInvoked(context: self.makeAsyncServerContext())
+      stateMachine.handlerInvoked(requestHeaders: [:])
       return stateMachine
     case .draining:
       stateMachine.handleMetadata().assertInvokeHandler()
-      stateMachine.handlerInvoked(context: self.makeAsyncServerContext())
+      stateMachine.handlerInvoked(requestHeaders: [:])
       stateMachine.handleEnd().assertForward()
       return stateMachine
     case .finished:
@@ -69,14 +66,6 @@ internal final class ServerHandlerStateMachineTests: GRPCTestCase {
     )
   }
 
-  private func makeAsyncServerContext() -> GRPCAsyncServerCallContext {
-    return GRPCAsyncServerCallContext(
-      headers: [:],
-      logger: self.logger,
-      userInfoRef: Ref(UserInfo())
-    )
-  }
-
   // MARK: - Test Cases
 
   func testHandleMetadataWhenIdle() {
@@ -84,7 +73,7 @@ internal final class ServerHandlerStateMachineTests: GRPCTestCase {
     // Receiving metadata is the signal to invoke the user handler.
     stateMachine.handleMetadata().assertInvokeHandler()
     // On invoking the handler we move to the next state. No output.
-    stateMachine.handlerInvoked(context: self.makeAsyncServerContext())
+    stateMachine.handlerInvoked(requestHeaders: [:])
   }
 
   func testHandleMetadataWhenHandling() {
@@ -219,27 +208,70 @@ internal final class ServerHandlerStateMachineTests: GRPCTestCase {
     var stateMachine = self.makeStateMachine(inState: .finished)
     stateMachine.cancel().assertDoCancel()
   }
+
+  func testSetResponseHeadersWhenHandling() {
+    var stateMachine = self.makeStateMachine(inState: .handling)
+    stateMachine.setResponseHeaders(["foo": "bar"])
+    stateMachine.sendMessage().assertInterceptHeadersThenMessage { headers in
+      XCTAssertEqual(headers, ["foo": "bar"])
+    }
+  }
+
+  func testSetResponseHeadersWhenHandlingAreMovedToDraining() {
+    var stateMachine = self.makeStateMachine(inState: .handling)
+    stateMachine.setResponseHeaders(["foo": "bar"])
+    stateMachine.handleEnd().assertForward()
+    stateMachine.sendMessage().assertInterceptHeadersThenMessage { headers in
+      XCTAssertEqual(headers, ["foo": "bar"])
+    }
+  }
+
+  func testSetResponseHeadersWhenDraining() {
+    var stateMachine = self.makeStateMachine(inState: .draining)
+    stateMachine.setResponseHeaders(["foo": "bar"])
+    stateMachine.sendMessage().assertInterceptHeadersThenMessage { headers in
+      XCTAssertEqual(headers, ["foo": "bar"])
+    }
+  }
+
+  func testSetResponseHeadersWhenFinished() {
+    var stateMachine = self.makeStateMachine(inState: .finished)
+    stateMachine.setResponseHeaders(["foo": "bar"])
+    // Nothing we can assert on, only that we don't crash.
+  }
+
+  func testSetResponseTrailersWhenHandling() {
+    var stateMachine = self.makeStateMachine(inState: .handling)
+    stateMachine.setResponseTrailers(["foo": "bar"])
+    stateMachine.sendStatus().assertIntercept { trailers in
+      XCTAssertEqual(trailers, ["foo": "bar"])
+    }
+  }
+
+  func testSetResponseTrailersWhenDraining() {
+    var stateMachine = self.makeStateMachine(inState: .draining)
+    stateMachine.setResponseTrailers(["foo": "bar"])
+    stateMachine.sendStatus().assertIntercept { trailers in
+      XCTAssertEqual(trailers, ["foo": "bar"])
+    }
+  }
+
+  func testSetResponseTrailersWhenFinished() {
+    var stateMachine = self.makeStateMachine(inState: .finished)
+    stateMachine.setResponseTrailers(["foo": "bar"])
+    // Nothing we can assert on, only that we don't crash.
+  }
 }
 
 // MARK: - Action Assertions
 
 extension ServerHandlerStateMachine.HandleMetadataAction {
   func assertInvokeHandler() {
-    switch self {
-    case .invokeHandler:
-      ()
-    case .cancel:
-      XCTFail("Expected 'invokeHandler' but got \(self)")
-    }
+    XCTAssertEqual(self, .invokeHandler)
   }
 
   func assertInvokeCancel() {
-    switch self {
-    case .cancel:
-      ()
-    case .invokeHandler:
-      XCTFail("Expected 'cancel' but got \(self)")
-    }
+    XCTAssertEqual(self, .cancel)
   }
 }
 
@@ -254,8 +286,13 @@ extension ServerHandlerStateMachine.HandleMessageAction {
 }
 
 extension ServerHandlerStateMachine.SendMessageAction {
-  func assertInterceptHeadersThenMessage() {
-    XCTAssertEqual(self, .intercept(headers: [:]))
+  func assertInterceptHeadersThenMessage(_ verify: (HPACKHeaders) -> Void = { _ in }) {
+    switch self {
+    case let .intercept(headers: .some(headers)):
+      verify(headers)
+    default:
+      XCTFail("Expected .intercept(.some) but got \(self)")
+    }
   }
 
   func assertInterceptMessage() {
@@ -268,8 +305,13 @@ extension ServerHandlerStateMachine.SendMessageAction {
 }
 
 extension ServerHandlerStateMachine.SendStatusAction {
-  func assertIntercept() {
-    XCTAssertEqual(self, .intercept(requestHeaders: [:], trailers: [:]))
+  func assertIntercept(_ verify: (HPACKHeaders) -> Void = { _ in }) {
+    switch self {
+    case let .intercept(_, trailers: trailers):
+      verify(trailers)
+    case .drop:
+      XCTFail("Expected .intercept but got .drop")
+    }
   }
 
   func assertDrop() {

+ 3 - 31
Tests/GRPCTests/GRPCAsyncServerHandlerTests.swift

@@ -173,7 +173,7 @@ class AsyncServerHandlerTests: GRPCTestCase {
     let handler = self.makeHandler(
       encoding: .enabled(.init(decompressionLimit: .absolute(.max)))
     ) { requests, responseStreamWriter, context in
-      context.compressionEnabled = false
+      try await context.response.compressResponses(false)
       return try await Self.echo(
         requests: requests,
         responseStreamWriter: responseStreamWriter,
@@ -206,9 +206,9 @@ class AsyncServerHandlerTests: GRPCTestCase {
 
   func testResponseHeadersAndTrailersSentFromContext() async throws {
     let handler = self.makeHandler { _, responseStreamWriter, context in
-      context.initialResponseMetadata = ["pontiac": "bandit"]
+      try await context.response.setHeaders(["pontiac": "bandit"])
       try await responseStreamWriter.send("1")
-      context.trailingResponseMetadata = ["disco": "strangler"]
+      try await context.response.setTrailers(["disco": "strangler"])
     }
     defer {
       XCTAssertNoThrow(try self.loop.submit { handler.finish() }.wait())
@@ -230,34 +230,6 @@ class AsyncServerHandlerTests: GRPCTestCase {
     try await responseStream.next().assertNil()
   }
 
-  func testResponseHeadersDroppedIfSetAfterFirstResponse() async throws {
-    throw XCTSkip("Setting metadata is racy. This test will not reliably pass until that is fixed.")
-    // let handler = self.makeHandler { _, responseStreamWriter, context in
-    //   // try await context.sendHeaders(...)
-    //   try await responseStreamWriter.send("1")
-    //   context.initialResponseMetadata = ["pontiac": "bandit"]
-    //   context.trailingResponseMetadata = ["disco": "strangler"]
-    // }
-    // defer {
-    //   XCTAssertNoThrow(try self.loop.submit { handler.finish() }.wait())
-    // }
-    //
-    // self.loop.execute {
-    //   handler.receiveMetadata([:])
-    //   handler.receiveEnd()
-    // }
-    //
-    // let responseStream = self.recorder.responseSequence.makeAsyncIterator()
-    // try await responseStream.next().assertMetadata { headers in
-    //   XCTAssertEqual(headers, [:])
-    // }
-    // try await responseStream.next().assertMessage()
-    // try await responseStream.next().assertStatus { _, trailers in
-    //   XCTAssertEqual(trailers, ["disco": "strangler"])
-    // }
-    // try await responseStream.next().assertNil()
-  }
-
   func testThrowingDeserializer() async throws {
     let handler = AsyncServerHandler(
       context: self.makeCallHandlerContext(),