Browse Source

Add the server RPC executor. (#1715)

Motivation:

The server needs to handle accepted RPC streams by turning them into
requests and letting a user provided handler handle the request, turning
it into the response. The server executor is then responsible for
handling the response and writing it back to the client.

Modifications:

- Add the server RPC executor
- Add testing Utilities and tests for the server executor

Result:

Server RPCs can be handled
George Barnett 2 years ago
parent
commit
d5a05a2591

+ 1 - 0
Package.swift

@@ -164,6 +164,7 @@ extension Target {
     name: "GRPCCore",
     dependencies: [
       .dequeModule,
+      .atomics
     ],
     path: "Sources/GRPCCore"
   )

+ 299 - 0
Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift

@@ -0,0 +1,299 @@
+/*
+ * Copyright 2023, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
+@usableFromInline
+struct ServerRPCExecutor {
+  /// Executes an RPC using the provided handler.
+  ///
+  /// - Parameters:
+  ///   - stream: The accepted stream to execute the RPC on.
+  ///   - deserializer: A deserializer for messages received from the client.
+  ///   - serializer: A serializer for messages to send to the client.
+  ///   - interceptors: Server interceptors to apply to this RPC.
+  ///   - handler: A handler which turns the request into a response.
+  @inlinable
+  static func execute<Input, Output>(
+    stream: RPCStream<RPCAsyncSequence<RPCRequestPart>, RPCWriter<RPCResponsePart>.Closable>,
+    deserializer: some MessageDeserializer<Input>,
+    serializer: some MessageSerializer<Output>,
+    interceptors: [any ServerInterceptor],
+    handler: @Sendable @escaping (
+      _ request: ServerRequest.Stream<Input>
+    ) async throws -> ServerResponse.Stream<Output>
+  ) async {
+    // Wait for the first request part from the transport.
+    let firstPart = await Self._waitForFirstRequestPart(inbound: stream.inbound)
+
+    switch firstPart {
+    case .process(let metadata, let inbound):
+      await Self._execute(
+        method: stream.descriptor,
+        metadata: metadata,
+        inbound: inbound,
+        outbound: stream.outbound,
+        deserializer: deserializer,
+        serializer: serializer,
+        interceptors: interceptors,
+        handler: handler
+      )
+
+    case .reject(let error):
+      // Stream can't be handled; write an error status and close.
+      let status = Status(code: Status.Code(error.code), message: error.message)
+      try? await stream.outbound.write(.status(status, error.metadata))
+      stream.outbound.finish()
+    }
+  }
+
+  @inlinable
+  static func _execute<Input, Output>(
+    method: MethodDescriptor,
+    metadata: Metadata,
+    inbound: UnsafeTransfer<RPCAsyncSequence<RPCRequestPart>.AsyncIterator>,
+    outbound: RPCWriter<RPCResponsePart>.Closable,
+    deserializer: some MessageDeserializer<Input>,
+    serializer: some MessageSerializer<Output>,
+    interceptors: [any ServerInterceptor],
+    handler: @escaping @Sendable (
+      _ request: ServerRequest.Stream<Input>
+    ) async throws -> ServerResponse.Stream<Output>
+  ) async {
+    await withTaskGroup(of: ServerExecutorTask.self) { group in
+      if let timeout = metadata.timeout {
+        group.addTask {
+          let result = await Result {
+            try await Task.sleep(until: .now.advanced(by: timeout), clock: .continuous)
+          }
+          return .timedOut(result)
+        }
+      }
+
+      group.addTask {
+        await Self._processRPC(
+          method: method,
+          metadata: metadata,
+          inbound: inbound,
+          outbound: outbound,
+          deserializer: deserializer,
+          serializer: serializer,
+          interceptors: interceptors,
+          handler: handler
+        )
+        return .executed
+      }
+
+      while let next = await group.next() {
+        switch next {
+        case .timedOut(.success):
+          // Timeout expired; cancel the work.
+          group.cancelAll()
+
+        case .timedOut(.failure):
+          // Timeout failed (because it was cancelled). Wait for more tasks to finish.
+          ()
+
+        case .executed:
+          // The work finished. Cancel any remaining tasks.
+          group.cancelAll()
+        }
+      }
+    }
+  }
+
+  @inlinable
+  static func _processRPC<Input, Output>(
+    method: MethodDescriptor,
+    metadata: Metadata,
+    inbound: UnsafeTransfer<RPCAsyncSequence<RPCRequestPart>.AsyncIterator>,
+    outbound: RPCWriter<RPCResponsePart>.Closable,
+    deserializer: some MessageDeserializer<Input>,
+    serializer: some MessageSerializer<Output>,
+    interceptors: [any ServerInterceptor],
+    handler: @escaping @Sendable (
+      ServerRequest.Stream<Input>
+    ) async throws -> ServerResponse.Stream<Output>
+  ) async {
+    let messages = AsyncIteratorSequence(inbound.wrappedValue).map { part throws -> Input in
+      switch part {
+      case .message(let bytes):
+        return try deserializer.deserialize(bytes)
+      case .metadata:
+        throw RPCError(
+          code: .internalError,
+          message: """
+            Server received an extra set of metadata. Only one set of metadata may be received \
+            at the start of the RPC. This is likely to be caused by a misbehaving client.
+            """
+        )
+      }
+    }
+
+    let response = await Result {
+      // Run the request through the interceptors, finally passing it to the handler.
+      return try await Self._intercept(
+        request: ServerRequest.Stream(
+          metadata: metadata,
+          messages: RPCAsyncSequence(wrapping: messages)
+        ),
+        context: ServerInterceptorContext(descriptor: method),
+        interceptors: interceptors
+      ) { request, _ in
+        try await handler(request)
+      }
+    }.castError(to: RPCError.self) { error in
+      RPCError(code: .unknown, message: "Service method threw an unknown error.", cause: error)
+    }.flatMap { response in
+      response.accepted
+    }
+
+    let status: Status
+    let metadata: Metadata
+
+    switch response {
+    case .success(let contents):
+      let result = await Result {
+        // Write the metadata and run the producer.
+        try await outbound.write(.metadata(contents.metadata))
+        return try await contents.producer(
+          .serializingToRPCResponsePart(into: outbound, with: serializer)
+        )
+      }.castError(to: RPCError.self) { error in
+        RPCError(code: .unknown, message: "", cause: error)
+      }
+
+      switch result {
+      case .success(let trailingMetadata):
+        status = .ok
+        metadata = trailingMetadata
+      case .failure(let error):
+        status = Status(code: Status.Code(error.code), message: error.message)
+        metadata = error.metadata
+      }
+
+    case .failure(let error):
+      status = Status(code: Status.Code(error.code), message: error.message)
+      metadata = error.metadata
+    }
+
+    try? await outbound.write(.status(status, metadata))
+    outbound.finish()
+  }
+
+  @inlinable
+  static func _waitForFirstRequestPart(
+    inbound: RPCAsyncSequence<RPCRequestPart>
+  ) async -> OnFirstRequestPart {
+    var iterator = inbound.makeAsyncIterator()
+    let part = await Result { try await iterator.next() }
+    let onFirstRequestPart: OnFirstRequestPart
+
+    switch part {
+    case .success(.metadata(let metadata)):
+      // The only valid first part.
+      onFirstRequestPart = .process(metadata, UnsafeTransfer(iterator))
+
+    case .success(.none):
+      // Empty stream; reject.
+      let error = RPCError(code: .internalError, message: "Empty inbound server stream.")
+      onFirstRequestPart = .reject(error)
+
+    case .success(.message):
+      let error = RPCError(
+        code: .internalError,
+        message: """
+          Invalid inbound server stream; received message bytes at start of stream. This is \
+          likely to be a transport specific bug.
+          """
+      )
+      onFirstRequestPart = .reject(error)
+
+    case .failure(let error):
+      let error = RPCError(
+        code: .unknown,
+        message: "Inbound server stream threw error when reading metadata.",
+        cause: error
+      )
+      onFirstRequestPart = .reject(error)
+    }
+
+    return onFirstRequestPart
+  }
+
+  @usableFromInline
+  enum OnFirstRequestPart {
+    case process(Metadata, UnsafeTransfer<RPCAsyncSequence<RPCRequestPart>.AsyncIterator>)
+    case reject(RPCError)
+  }
+
+  @usableFromInline
+  enum ServerExecutorTask {
+    case timedOut(Result<Void, Error>)
+    case executed
+  }
+}
+
+@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
+extension ServerRPCExecutor {
+  @inlinable
+  static func _intercept<Input, Output>(
+    request: ServerRequest.Stream<Input>,
+    context: ServerInterceptorContext,
+    interceptors: [any ServerInterceptor],
+    finally: @escaping @Sendable (
+      _ request: ServerRequest.Stream<Input>,
+      _ context: ServerInterceptorContext
+    ) async throws -> ServerResponse.Stream<Output>
+  ) async throws -> ServerResponse.Stream<Output> {
+    return try await self._intercept(
+      request: request,
+      context: context,
+      iterator: interceptors.makeIterator(),
+      finally: finally
+    )
+  }
+
+  @inlinable
+  static func _intercept<Input, Output>(
+    request: ServerRequest.Stream<Input>,
+    context: ServerInterceptorContext,
+    iterator: Array<any ServerInterceptor>.Iterator,
+    finally: @escaping @Sendable (
+      _ request: ServerRequest.Stream<Input>,
+      _ context: ServerInterceptorContext
+    ) async throws -> ServerResponse.Stream<Output>
+  ) async throws -> ServerResponse.Stream<Output> {
+    var iterator = iterator
+
+    switch iterator.next() {
+    case .some(let interceptor):
+      let iter = iterator
+      do {
+        return try await interceptor.intercept(request: request, context: context) {
+          try await self._intercept(request: $0, context: $1, iterator: iter, finally: finally)
+        }
+      } catch let error as RPCError {
+        return ServerResponse.Stream(error: error)
+      } catch let other {
+        let error = RPCError(code: .unknown, message: "", cause: other)
+        return ServerResponse.Stream(error: error)
+      }
+
+    case .none:
+      return try await finally(request, context)
+    }
+  }
+}

+ 13 - 0
Sources/GRPCCore/Internal/Metadata+GRPC.swift

@@ -36,11 +36,24 @@ extension Metadata {
       RetryPushback(milliseconds: $0)
     }
   }
+
+  @inlinable
+  var timeout: Duration? {
+    // Temporary hack to support tests; only supports nanoseconds.
+    guard let value = self.firstString(forKey: .timeout) else { return nil }
+    guard value.utf8.last == UTF8.CodeUnit(ascii: "n") else { return nil }
+    var index = value.utf8.endIndex
+    value.utf8.formIndex(before: &index)
+    guard let digits = String(value.utf8[..<index]) else { return nil }
+    guard let nanoseconds = Int64(digits) else { return nil }
+    return .nanoseconds(nanoseconds)
+  }
 }
 
 extension Metadata {
   @usableFromInline
   enum GRPCKey: String, Sendable, Hashable {
+    case timeout = "grpc-timeout"
     case retryPushbackMs = "grpc-retry-pushback-ms"
     case previousRPCAttempts = "grpc-previous-rpc-attempts"
   }

+ 6 - 0
Sources/GRPCCore/Status.swift

@@ -65,6 +65,12 @@ public struct Status: @unchecked Sendable, Hashable {
       self.storage = Storage(code: code, message: message)
     }
   }
+
+  /// A status with code ``Code-swift.struct/ok`` and an empty message.
+  @inlinable
+  internal static var ok: Self {
+    Status(code: .ok, message: "")
+  }
 }
 
 extension Status: CustomStringConvertible {

+ 69 - 0
Sources/GRPCCore/Streaming/Internal/AsyncIteratorSequence.swift

@@ -0,0 +1,69 @@
+/*
+ * Copyright 2023, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import Atomics
+
+@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
+@usableFromInline
+/// An `AsyncSequence` which wraps an existing async iterator.
+struct AsyncIteratorSequence<Base: AsyncIteratorProtocol>: AsyncSequence {
+  @usableFromInline
+  typealias Element = Base.Element
+
+  /// The base iterator.
+  @usableFromInline
+  private(set) var base: Base
+
+  /// Set to `true` when an iterator has been made.
+  @usableFromInline
+  let _hasMadeIterator = ManagedAtomic(false)
+
+  @inlinable
+  init(_ base: Base) {
+    self.base = base
+  }
+
+  @usableFromInline
+  struct AsyncIterator: AsyncIteratorProtocol {
+    @usableFromInline
+    private(set) var base: Base
+
+    @inlinable
+    init(base: Base) {
+      self.base = base
+    }
+
+    @inlinable
+    mutating func next() async throws -> Element? {
+      try await self.base.next()
+    }
+  }
+
+  @inlinable
+  func makeAsyncIterator() -> AsyncIterator {
+    let (exchanged, original) = self._hasMadeIterator.compareExchange(
+      expected: false,
+      desired: true,
+      ordering: .relaxed
+    )
+
+    guard exchanged else {
+      fatalError("Only one iterator can be made")
+    }
+
+    assert(!original)
+    return AsyncIterator(base: self.base)
+  }
+}

+ 53 - 0
Sources/GRPCCore/Streaming/Internal/RPCWriter+MessageToRPCResponsePart.swift

@@ -0,0 +1,53 @@
+/*
+ * Copyright 2023, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+@usableFromInline
+struct MessageToRPCResponsePartWriter<Serializer: MessageSerializer>: RPCWriterProtocol {
+  @usableFromInline
+  typealias Element = Serializer.Message
+
+  @usableFromInline
+  let base: RPCWriter<RPCResponsePart>
+  @usableFromInline
+  let serializer: Serializer
+
+  @inlinable
+  init(serializer: Serializer, base: some RPCWriterProtocol<RPCResponsePart>) {
+    self.serializer = serializer
+    self.base = RPCWriter(wrapping: base)
+  }
+
+  @inlinable
+  func write(contentsOf elements: some Sequence<Serializer.Message>) async throws {
+    let requestParts = try elements.map { message -> RPCResponsePart in
+      .message(try self.serializer.serialize(message))
+    }
+
+    try await self.base.write(contentsOf: requestParts)
+  }
+}
+
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+extension RPCWriter {
+  @inlinable
+  static func serializingToRPCResponsePart(
+    into writer: some RPCWriterProtocol<RPCResponsePart>,
+    with serializer: some MessageSerializer<Element>
+  ) -> Self {
+    return RPCWriter(wrapping: MessageToRPCResponsePartWriter(serializer: serializer, base: writer))
+  }
+}

+ 137 - 0
Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTestSupport/ServerRPCExecutorTestHarness.swift

@@ -0,0 +1,137 @@
+/*
+ * Copyright 2023, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import XCTest
+
+@testable import GRPCCore
+
+@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
+struct ServerRPCExecutorTestHarness {
+  struct ServerHandler<Input, Output>: Sendable {
+    let fn: @Sendable (ServerRequest.Stream<Input>) async throws -> ServerResponse.Stream<Output>
+
+    init(
+      _ fn: @escaping @Sendable (
+        ServerRequest.Stream<Input>
+      ) async throws -> ServerResponse.Stream<Output>
+    ) {
+      self.fn = fn
+    }
+
+    func handle(
+      _ request: ServerRequest.Stream<Input>
+    ) async throws -> ServerResponse.Stream<Output> {
+      try await self.fn(request)
+    }
+
+    static func throwing(_ error: any Error) -> Self {
+      return Self { _ in throw error }
+    }
+  }
+
+  let interceptors: [any ServerInterceptor]
+
+  init(interceptors: [any ServerInterceptor] = []) {
+    self.interceptors = interceptors
+  }
+
+  func execute<Input, Output>(
+    deserializer: some MessageDeserializer<Input>,
+    serializer: some MessageSerializer<Output>,
+    handler: @escaping @Sendable (
+      ServerRequest.Stream<Input>
+    ) async throws -> ServerResponse.Stream<Output>,
+    producer: @escaping (RPCWriter<RPCRequestPart>.Closable) async throws -> Void,
+    consumer: @escaping (RPCAsyncSequence<RPCResponsePart>) async throws -> Void
+  ) async throws {
+    try await self.execute(
+      deserializer: deserializer,
+      serializer: serializer,
+      handler: .init(handler),
+      producer: producer,
+      consumer: consumer
+    )
+  }
+
+  func execute<Input, Output>(
+    deserializer: some MessageDeserializer<Input>,
+    serializer: some MessageSerializer<Output>,
+    handler: ServerHandler<Input, Output>,
+    producer: @escaping (RPCWriter<RPCRequestPart>.Closable) async throws -> Void,
+    consumer: @escaping (RPCAsyncSequence<RPCResponsePart>) async throws -> Void
+  ) async throws {
+    let input = RPCAsyncSequence.makeBackpressuredStream(
+      of: RPCRequestPart.self,
+      watermarks: (16, 32)
+    )
+
+    let output = RPCAsyncSequence.makeBackpressuredStream(
+      of: RPCResponsePart.self,
+      watermarks: (16, 32)
+    )
+
+    try await withThrowingTaskGroup(of: Void.self) { group in
+      group.addTask {
+        try await producer(input.writer)
+      }
+
+      group.addTask {
+        try await consumer(output.stream)
+      }
+
+      group.addTask {
+        await ServerRPCExecutor.execute(
+          stream: RPCStream(
+            descriptor: MethodDescriptor(service: "foo", method: "bar"),
+            inbound: input.stream,
+            outbound: output.writer
+          ),
+          deserializer: deserializer,
+          serializer: serializer,
+          interceptors: self.interceptors,
+          handler: { try await handler.handle($0) }
+        )
+      }
+
+      try await group.waitForAll()
+    }
+  }
+
+  func execute(
+    handler: ServerHandler<[UInt8], [UInt8]> = .echo,
+    producer: @escaping (RPCWriter<RPCRequestPart>.Closable) async throws -> Void,
+    consumer: @escaping (RPCAsyncSequence<RPCResponsePart>) async throws -> Void
+  ) async throws {
+    try await self.execute(
+      deserializer: IdentityDeserializer(),
+      serializer: IdentitySerializer(),
+      handler: handler,
+      producer: producer,
+      consumer: consumer
+    )
+  }
+}
+
+@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
+extension ServerRPCExecutorTestHarness.ServerHandler where Input == Output {
+  static var echo: Self {
+    return Self { request in
+      return ServerResponse.Stream(metadata: request.metadata) { writer in
+        try await writer.write(contentsOf: request.messages)
+        return [:]
+      }
+    }
+  }
+}

+ 355 - 0
Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift

@@ -0,0 +1,355 @@
+/*
+ * Copyright 2023, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import Atomics
+import XCTest
+
+@testable import GRPCCore
+
+@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
+final class ServerRPCExecutorTests: XCTestCase {
+  func testEchoNoMessages() async throws {
+    let harness = ServerRPCExecutorTestHarness()
+    try await harness.execute(handler: .echo) { inbound in
+      try await inbound.write(.metadata(["foo": "bar"]))
+      inbound.finish()
+    } consumer: { outbound in
+      let parts = try await outbound.collect()
+      XCTAssertEqual(
+        parts,
+        [
+          .metadata(["foo": "bar"]),
+          .status(.ok, [:]),
+        ]
+      )
+    }
+  }
+
+  func testEchoSingleMessage() async throws {
+    let harness = ServerRPCExecutorTestHarness()
+    try await harness.execute(handler: .echo) { inbound in
+      try await inbound.write(.metadata(["foo": "bar"]))
+      try await inbound.write(.message([0]))
+      inbound.finish()
+    } consumer: { outbound in
+      let parts = try await outbound.collect()
+      XCTAssertEqual(
+        parts,
+        [
+          .metadata(["foo": "bar"]),
+          .message([0]),
+          .status(.ok, [:]),
+        ]
+      )
+    }
+  }
+
+  func testEchoMultipleMessages() async throws {
+    let harness = ServerRPCExecutorTestHarness()
+    try await harness.execute(handler: .echo) { inbound in
+      try await inbound.write(.metadata(["foo": "bar"]))
+      try await inbound.write(.message([0]))
+      try await inbound.write(.message([1]))
+      try await inbound.write(.message([2]))
+      inbound.finish()
+    } consumer: { outbound in
+      let parts = try await outbound.collect()
+      XCTAssertEqual(
+        parts,
+        [
+          .metadata(["foo": "bar"]),
+          .message([0]),
+          .message([1]),
+          .message([2]),
+          .status(.ok, [:]),
+        ]
+      )
+    }
+  }
+
+  func testEchoSingleJSONMessage() async throws {
+    let harness = ServerRPCExecutorTestHarness()
+    try await harness.execute(
+      deserializer: JSONDeserializer<String>(),
+      serializer: JSONSerializer<String>()
+    ) { request in
+      let messages = try await request.messages.collect()
+      XCTAssertEqual(messages, ["hello"])
+      return ServerResponse.Stream(metadata: request.metadata) { writer in
+        try await writer.write("hello")
+        return [:]
+      }
+    } producer: { inbound in
+      try await inbound.write(.metadata(["foo": "bar"]))
+      try await inbound.write(.message(Array("\"hello\"".utf8)))
+      inbound.finish()
+    } consumer: { outbound in
+      let parts = try await outbound.collect()
+      XCTAssertEqual(
+        parts,
+        [
+          .metadata(["foo": "bar"]),
+          .message(Array("\"hello\"".utf8)),
+          .status(.ok, [:]),
+        ]
+      )
+    }
+  }
+
+  func testEchoMultipleJSONMessages() async throws {
+    let harness = ServerRPCExecutorTestHarness()
+    try await harness.execute(
+      deserializer: JSONDeserializer<String>(),
+      serializer: JSONSerializer<String>()
+    ) { request in
+      let messages = try await request.messages.collect()
+      XCTAssertEqual(messages, ["hello", "world"])
+      return ServerResponse.Stream(metadata: request.metadata) { writer in
+        try await writer.write("hello")
+        try await writer.write("world")
+        return [:]
+      }
+    } producer: { inbound in
+      try await inbound.write(.metadata(["foo": "bar"]))
+      try await inbound.write(.message(Array("\"hello\"".utf8)))
+      try await inbound.write(.message(Array("\"world\"".utf8)))
+      inbound.finish()
+    } consumer: { outbound in
+      let parts = try await outbound.collect()
+      XCTAssertEqual(
+        parts,
+        [
+          .metadata(["foo": "bar"]),
+          .message(Array("\"hello\"".utf8)),
+          .message(Array("\"world\"".utf8)),
+          .status(.ok, [:]),
+        ]
+      )
+    }
+  }
+
+  func testReturnTrailingMetadata() async throws {
+    let harness = ServerRPCExecutorTestHarness()
+    try await harness.execute(
+      deserializer: IdentityDeserializer(),
+      serializer: IdentitySerializer()
+    ) { request in
+      return ServerResponse.Stream(metadata: request.metadata) { _ in
+        return ["bar": "baz"]
+      }
+    } producer: { inbound in
+      try await inbound.write(.metadata(["foo": "bar"]))
+      inbound.finish()
+    } consumer: { outbound in
+      let parts = try await outbound.collect()
+      XCTAssertEqual(
+        parts,
+        [
+          .metadata(["foo": "bar"]),
+          .status(.ok, ["bar": "baz"]),
+        ]
+      )
+    }
+  }
+
+  func testEmptyInbound() async throws {
+    let harness = ServerRPCExecutorTestHarness()
+    try await harness.execute(handler: .echo) { inbound in
+      inbound.finish()
+    } consumer: { outbound in
+      let part = try await outbound.collect().first
+      XCTAssertStatus(part) { status, _ in
+        XCTAssertEqual(status.code, .internalError)
+      }
+    }
+  }
+
+  func testInboundStreamMissingMetadata() async throws {
+    let harness = ServerRPCExecutorTestHarness()
+    try await harness.execute(handler: .echo) { inbound in
+      try await inbound.write(.message([0]))
+      inbound.finish()
+    } consumer: { outbound in
+      let part = try await outbound.collect().first
+      XCTAssertStatus(part) { status, _ in
+        XCTAssertEqual(status.code, .internalError)
+      }
+    }
+  }
+
+  func testInboundStreamThrows() async throws {
+    let harness = ServerRPCExecutorTestHarness()
+    try await harness.execute(handler: .echo) { inbound in
+      inbound.finish(throwing: RPCError(code: .aborted, message: ""))
+    } consumer: { outbound in
+      let part = try await outbound.collect().first
+      XCTAssertStatus(part) { status, _ in
+        XCTAssertEqual(status.code, .unknown)
+      }
+    }
+  }
+
+  func testHandlerThrowsAnyError() async throws {
+    struct SomeError: Error {}
+    let harness = ServerRPCExecutorTestHarness()
+    try await harness.execute(handler: .throwing(SomeError())) { inbound in
+      try await inbound.write(.metadata([:]))
+      inbound.finish()
+    } consumer: { outbound in
+      let part = try await outbound.collect().first
+      XCTAssertStatus(part) { status, _ in
+        XCTAssertEqual(status.code, .unknown)
+      }
+    }
+  }
+
+  func testHandlerThrowsRPCError() async throws {
+    let error = RPCError(code: .aborted, message: "RPC aborted", metadata: ["foo": "bar"])
+    let harness = ServerRPCExecutorTestHarness()
+    try await harness.execute(handler: .throwing(error)) { inbound in
+      try await inbound.write(.metadata([:]))
+      inbound.finish()
+    } consumer: { outbound in
+      let part = try await outbound.collect().first
+      XCTAssertStatus(part) { status, metadata in
+        XCTAssertEqual(status.code, .aborted)
+        XCTAssertEqual(status.message, "RPC aborted")
+        XCTAssertEqual(metadata, ["foo": "bar"])
+      }
+    }
+  }
+
+  func testHandlerRespectsTimeout() async throws {
+    let harness = ServerRPCExecutorTestHarness()
+    try await harness.execute(
+      deserializer: IdentityDeserializer(),
+      serializer: IdentitySerializer()
+    ) { request in
+      do {
+        try await Task.sleep(until: .now.advanced(by: .seconds(180)), clock: .continuous)
+      } catch is CancellationError {
+        throw RPCError(code: .cancelled, message: "Sleep was cancelled")
+      }
+
+      XCTFail("Server handler should've been cancelled by timeout.")
+      return ServerResponse.Stream(error: RPCError(code: .failedPrecondition, message: ""))
+    } producer: { inbound in
+      try await inbound.write(.metadata(["grpc-timeout": "1000n"]))
+      inbound.finish()
+    } consumer: { outbound in
+      let part = try await outbound.collect().first
+      XCTAssertStatus(part) { status, _ in
+        XCTAssertEqual(status.code, .cancelled)
+        XCTAssertEqual(status.message, "Sleep was cancelled")
+      }
+    }
+  }
+
+  func testShortCircuitInterceptor() async throws {
+    let error = RPCError(
+      code: .unauthenticated,
+      message: "Unauthenticated",
+      metadata: ["foo": "bar"]
+    )
+
+    // The interceptor skips the handler altogether.
+    let harness = ServerRPCExecutorTestHarness(interceptors: [.rejectAll(with: error)])
+    try await harness.execute(
+      deserializer: IdentityDeserializer(),
+      serializer: IdentitySerializer()
+    ) { request in
+      XCTFail("Unexpected request")
+      return ServerResponse.Stream(
+        of: [UInt8].self,
+        error: RPCError(code: .failedPrecondition, message: "")
+      )
+    } producer: { inbound in
+      try await inbound.write(.metadata([:]))
+      inbound.finish()
+    } consumer: { outbound in
+      let part = try await outbound.collect().first
+      XCTAssertStatus(part) { status, metadata in
+        XCTAssertEqual(status.code, .unauthenticated)
+        XCTAssertEqual(status.message, "Unauthenticated")
+        XCTAssertEqual(metadata, ["foo": "bar"])
+      }
+    }
+  }
+
+  func testMultipleInterceptorsAreCalled() async throws {
+    let counter1 = ManagedAtomic(0)
+    let counter2 = ManagedAtomic(0)
+
+    // The interceptor skips the handler altogether.
+    let harness = ServerRPCExecutorTestHarness(
+      interceptors: [
+        .requestCounter(counter1),
+        .requestCounter(counter2),
+      ]
+    )
+
+    try await harness.execute(handler: .echo) { inbound in
+      try await inbound.write(.metadata([:]))
+      inbound.finish()
+    } consumer: { outbound in
+      let parts = try await outbound.collect()
+      XCTAssertEqual(parts, [.metadata([:]), .status(.ok, [:])])
+    }
+
+    XCTAssertEqual(counter1.load(ordering: .sequentiallyConsistent), 1)
+    XCTAssertEqual(counter2.load(ordering: .sequentiallyConsistent), 1)
+  }
+
+  func testInterceptorsAreCalledInOrder() async throws {
+    let counter1 = ManagedAtomic(0)
+    let counter2 = ManagedAtomic(0)
+
+    // The interceptor skips the handler altogether.
+    let harness = ServerRPCExecutorTestHarness(
+      interceptors: [
+        .requestCounter(counter1),
+        .rejectAll(with: RPCError(code: .unavailable, message: "")),
+        .requestCounter(counter2),
+      ]
+    )
+
+    try await harness.execute(handler: .echo) { inbound in
+      try await inbound.write(.metadata([:]))
+      inbound.finish()
+    } consumer: { outbound in
+      let parts = try await outbound.collect()
+      XCTAssertEqual(parts, [.status(Status(code: .unavailable, message: ""), [:])])
+    }
+
+    XCTAssertEqual(counter1.load(ordering: .sequentiallyConsistent), 1)
+    // Zero because the RPC should've been rejected by the second interceptor.
+    XCTAssertEqual(counter2.load(ordering: .sequentiallyConsistent), 0)
+  }
+
+  func testThrowingInterceptor() async throws {
+    let harness = ServerRPCExecutorTestHarness(
+      interceptors: [.throwError(RPCError(code: .unavailable, message: "Unavailable"))]
+    )
+
+    try await harness.execute(handler: .echo) { inbound in
+      try await inbound.write(.metadata([:]))
+      inbound.finish()
+    } consumer: { outbound in
+      let parts = try await outbound.collect()
+      XCTAssertEqual(parts, [.status(Status(code: .unavailable, message: "Unavailable"), [:])])
+    }
+  }
+}

+ 84 - 0
Tests/GRPCCoreTests/Test Utilities/Call/Server/ServerInterceptors.swift

@@ -0,0 +1,84 @@
+/*
+ * Copyright 2023, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import Atomics
+import GRPCCore
+
+extension ServerInterceptor where Self == RejectAllServerInterceptor {
+  static func rejectAll(with error: RPCError) -> Self {
+    return RejectAllServerInterceptor(error: error, throw: false)
+  }
+
+  static func throwError(_ error: RPCError) -> Self {
+    return RejectAllServerInterceptor(error: error, throw: true)
+  }
+
+}
+
+extension ServerInterceptor where Self == RequestCountingServerInterceptor {
+  static func requestCounter(_ counter: ManagedAtomic<Int>) -> Self {
+    return RequestCountingServerInterceptor(counter: counter)
+  }
+}
+
+/// Rejects all RPCs with the provided error.
+struct RejectAllServerInterceptor: ServerInterceptor {
+  /// The error to reject all RPCs with.
+  let error: RPCError
+  /// Whether the error should be thrown. If `false` then the request is rejected with the error
+  /// instead.
+  let `throw`: Bool
+
+  init(error: RPCError, throw: Bool = false) {
+    self.error = error
+    self.`throw` = `throw`
+  }
+
+  func intercept<Input: Sendable, Output: Sendable>(
+    request: ServerRequest.Stream<Input>,
+    context: ServerInterceptorContext,
+    next: @Sendable (
+      ServerRequest.Stream<Input>,
+      ServerInterceptorContext
+    ) async throws -> ServerResponse.Stream<Output>
+  ) async throws -> ServerResponse.Stream<Output> {
+    if self.throw {
+      throw self.error
+    } else {
+      return ServerResponse.Stream(error: self.error)
+    }
+  }
+}
+
+struct RequestCountingServerInterceptor: ServerInterceptor {
+  /// The error to reject all RPCs with.
+  let counter: ManagedAtomic<Int>
+
+  init(counter: ManagedAtomic<Int>) {
+    self.counter = counter
+  }
+
+  func intercept<Input: Sendable, Output: Sendable>(
+    request: ServerRequest.Stream<Input>,
+    context: ServerInterceptorContext,
+    next: @Sendable (
+      ServerRequest.Stream<Input>,
+      ServerInterceptorContext
+    ) async throws -> ServerResponse.Stream<Output>
+  ) async throws -> ServerResponse.Stream<Output> {
+    self.counter.wrappingIncrement(ordering: .sequentiallyConsistent)
+    return try await next(request, context)
+  }
+}

+ 13 - 0
Tests/GRPCCoreTests/Test Utilities/XCTest+Utilities.swift

@@ -63,3 +63,16 @@ func XCTAssertThrowsRPCErrorAsync<T>(
     XCTFail("Error had unexpected type '\(type(of: error))'")
   }
 }
+
+func XCTAssertStatus(
+  _ part: RPCResponsePart?,
+  statusHandler: (Status, Metadata) -> Void = { _, _ in }
+) {
+  switch part {
+  case .some(.status(let status, let metadata)):
+    statusHandler(status, metadata)
+  default:
+    XCTFail("Expected '.status' but found '\(String(describing: part))'")
+  }
+
+}