ソースを参照

Add client and server tracing interceptors (#1756)

Gustavo Cairo 2 年 前
コミット
378b6e8993

+ 27 - 1
Package.swift

@@ -72,6 +72,10 @@ let packageDependencies: [Package.Dependency] = [
     url: "https://github.com/apple/swift-docc-plugin",
     from: "1.0.0"
   ),
+  .package(
+    url: "https://github.com/apple/swift-distributed-tracing.git",
+    from: "1.0.0"
+  ),
 ].appending(
   .package(
     url: "https://github.com/apple/swift-nio-ssl.git",
@@ -131,9 +135,11 @@ extension Target.Dependency {
   )
   static let dequeModule: Self = .product(name: "DequeModule", package: "swift-collections")
   static let atomics: Self = .product(name: "Atomics", package: "swift-atomics")
+  static let tracing: Self = .product(name: "Tracing", package: "swift-distributed-tracing")
 
   static let grpcCore: Self = .target(name: "GRPCCore")
   static let grpcInProcessTransport: Self = .target(name: "GRPCInProcessTransport")
+  static let grpcInterceptors: Self = .target(name: "GRPCInterceptors")
   static let grpcHTTP2Core: Self = .target(name: "GRPCHTTP2Core")
   static let grpcHTTP2TransportNIOPosix: Self = .target(name: "GRPCHTTP2TransportNIOPosix")
   static let grpcHTTP2TransportNIOTransportServices: Self = .target(name: "GRPCHTTP2TransportNIOTransportServices")
@@ -181,6 +187,14 @@ extension Target {
     ]
   )
   
+  static let grpcInterceptors: Target = .target(
+    name: "GRPCInterceptors",
+    dependencies: [
+      .grpcCore,
+      .tracing
+    ]
+  )
+
   static let grpcHTTP2Core: Target = .target(
     name: "GRPCHTTP2Core",
     dependencies: [
@@ -274,10 +288,20 @@ extension Target {
     name: "GRPCInProcessTransportTests",
     dependencies: [
       .grpcCore,
-      .grpcInProcessTransport,
+      .grpcInProcessTransport
     ]
   )
   
+  static let grpcInterceptorsTests: Target = .testTarget(
+    name: "GRPCInterceptorsTests",
+    dependencies: [
+      .grpcCore,
+      .tracing,
+      .nioCore,
+      .grpcInterceptors
+    ]
+  )
+
   static let grpcHTTP2CoreTests: Target = .testTarget(
     name: "GRPCHTTP2CoreTests",
     dependencies: [
@@ -638,6 +662,7 @@ let package = Package(
     .grpcCore,
     .grpcInProcessTransport,
     .grpcCodeGen,
+    .grpcInterceptors,
     .grpcHTTP2Core,
     .grpcHTTP2TransportNIOPosix,
     .grpcHTTP2TransportNIOTransportServices,
@@ -646,6 +671,7 @@ let package = Package(
     .grpcCoreTests,
     .grpcInProcessTransportTests,
     .grpcCodeGenTests,
+    .grpcInterceptorsTests,
     .grpcHTTP2CoreTests,
     .grpcHTTP2TransportNIOPosixTests,
     .grpcHTTP2TransportNIOTransportServicesTests

+ 140 - 0
Sources/GRPCInterceptors/ClientTracingInterceptor.swift

@@ -0,0 +1,140 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import GRPCCore
+import Tracing
+
+/// A client interceptor that injects tracing information into the request.
+///
+/// The tracing information is taken from the current `ServiceContext`, and injected into the request's
+/// metadata. It will then be picked up by the server-side ``ServerTracingInterceptor``.
+///
+/// For more information, refer to the documentation for `swift-distributed-tracing`.
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+public struct ClientTracingInterceptor: ClientInterceptor {
+  private let injector: ClientRequestInjector
+  private let emitEventOnEachWrite: Bool
+
+  /// Create a new instance of a ``ClientTracingInterceptor``.
+  ///
+  /// - Parameter emitEventOnEachWrite: If `true`, each request part sent and response part
+  /// received will be recorded as a separate event in a tracing span. Otherwise, only the request/response
+  /// start and end will be recorded as events.
+  public init(emitEventOnEachWrite: Bool = false) {
+    self.injector = ClientRequestInjector()
+    self.emitEventOnEachWrite = emitEventOnEachWrite
+  }
+
+  /// This interceptor will inject as the request's metadata whatever `ServiceContext` key-value pairs
+  /// have been made available by the tracing implementation bootstrapped in your application.
+  ///
+  /// Which key-value pairs are injected will depend on the specific tracing implementation
+  /// that has been configured when bootstrapping `swift-distributed-tracing` in your application.
+  public func intercept<Input, Output>(
+    request: ClientRequest.Stream<Input>,
+    context: ClientInterceptorContext,
+    next: @Sendable (ClientRequest.Stream<Input>, ClientInterceptorContext) async throws ->
+      ClientResponse.Stream<Output>
+  ) async throws -> ClientResponse.Stream<Output> where Input: Sendable, Output: Sendable {
+    var request = request
+    let tracer = InstrumentationSystem.tracer
+    let serviceContext = ServiceContext.current ?? .topLevel
+
+    tracer.inject(
+      serviceContext,
+      into: &request.metadata,
+      using: self.injector
+    )
+
+    return try await tracer.withSpan(
+      context.descriptor.fullyQualifiedMethod,
+      context: serviceContext,
+      ofKind: .client
+    ) { span in
+      span.addEvent("Request started")
+
+      if self.emitEventOnEachWrite {
+        let wrappedProducer = request.producer
+        request.producer = { writer in
+          let eventEmittingWriter = HookedWriter(
+            wrapping: writer,
+            beforeEachWrite: {
+              span.addEvent("Sending request part")
+            },
+            afterEachWrite: {
+              span.addEvent("Sent request part")
+            }
+          )
+
+          do {
+            try await wrappedProducer(RPCWriter(wrapping: eventEmittingWriter))
+          } catch {
+            span.addEvent("Error encountered")
+            throw error
+          }
+
+          span.addEvent("Request end")
+        }
+      }
+
+      var response: ClientResponse.Stream<Output>
+      do {
+        response = try await next(request, context)
+      } catch {
+        span.addEvent("Error encountered")
+        throw error
+      }
+
+      switch response.accepted {
+      case .success(var success):
+        if self.emitEventOnEachWrite {
+          let onEachPartRecordingSequence = success.bodyParts.map { element in
+            span.addEvent("Received response part")
+            return element
+          }
+          let onFinishRecordingSequence = OnFinishAsyncSequence(
+            wrapping: onEachPartRecordingSequence
+          ) {
+            span.addEvent("Received response end")
+          }
+          success.bodyParts = RPCAsyncSequence(wrapping: onFinishRecordingSequence)
+          response.accepted = .success(success)
+        } else {
+          let onFinishRecordingSequence = OnFinishAsyncSequence(wrapping: success.bodyParts) {
+            span.addEvent("Received response end")
+          }
+          success.bodyParts = RPCAsyncSequence(wrapping: onFinishRecordingSequence)
+          response.accepted = .success(success)
+        }
+      case .failure:
+        span.addEvent("Received error response")
+      }
+
+      return response
+    }
+  }
+}
+
+/// An injector responsible for injecting the required instrumentation keys from the `ServiceContext` into
+/// the request metadata.
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+struct ClientRequestInjector: Instrumentation.Injector {
+  typealias Carrier = Metadata
+
+  func inject(_ value: String, forKey key: String, into carrier: inout Carrier) {
+    carrier.addString(value, forKey: key)
+  }
+}

+ 40 - 0
Sources/GRPCInterceptors/HookedWriter.swift

@@ -0,0 +1,40 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import GRPCCore
+import Tracing
+
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+struct HookedWriter<Element>: RPCWriterProtocol {
+  private let writer: any RPCWriterProtocol<Element>
+  private let beforeEachWrite: @Sendable () -> Void
+  private let afterEachWrite: @Sendable () -> Void
+
+  init(
+    wrapping other: some RPCWriterProtocol<Element>,
+    beforeEachWrite: @Sendable @escaping () -> Void,
+    afterEachWrite: @Sendable @escaping () -> Void
+  ) {
+    self.writer = other
+    self.beforeEachWrite = beforeEachWrite
+    self.afterEachWrite = afterEachWrite
+  }
+
+  func write(contentsOf elements: some Sequence<Element>) async throws {
+    self.beforeEachWrite()
+    try await self.writer.write(contentsOf: elements)
+    self.afterEachWrite()
+  }
+}

+ 57 - 0
Sources/GRPCInterceptors/OnFinishAsyncSequence.swift

@@ -0,0 +1,57 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+struct OnFinishAsyncSequence<Element: Sendable>: AsyncSequence, Sendable {
+  private let _makeAsyncIterator: @Sendable () -> AsyncIterator
+
+  init<S: AsyncSequence>(
+    wrapping other: S,
+    onFinish: @escaping () -> Void
+  ) where S.Element == Element {
+    self._makeAsyncIterator = {
+      AsyncIterator(wrapping: other.makeAsyncIterator(), onFinish: onFinish)
+    }
+  }
+
+  func makeAsyncIterator() -> AsyncIterator {
+    self._makeAsyncIterator()
+  }
+
+  struct AsyncIterator: AsyncIteratorProtocol {
+    private var iterator: any AsyncIteratorProtocol
+    private var onFinish: (() -> Void)?
+
+    fileprivate init<Iterator>(
+      wrapping other: Iterator,
+      onFinish: @escaping () -> Void
+    ) where Iterator: AsyncIteratorProtocol, Iterator.Element == Element {
+      self.iterator = other
+      self.onFinish = onFinish
+    }
+
+    mutating func next() async throws -> Element? {
+      let elem = try await self.iterator.next()
+
+      if elem == nil {
+        self.onFinish?()
+        self.onFinish = nil
+      }
+
+      return elem as? Element
+    }
+  }
+}

+ 148 - 0
Sources/GRPCInterceptors/ServerTracingInterceptor.swift

@@ -0,0 +1,148 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import GRPCCore
+import Tracing
+
+/// A server interceptor that extracts tracing information from the request.
+///
+/// The extracted tracing information is made available to user code via the current `ServiceContext`.
+/// For more information, refer to the documentation for `swift-distributed-tracing`.
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+public struct ServerTracingInterceptor: ServerInterceptor {
+  private let extractor: ServerRequestExtractor
+  private let emitEventOnEachWrite: Bool
+
+  /// Create a new instance of a ``ServerTracingInterceptor``.
+  ///
+  /// - Parameter emitEventOnEachWrite: If `true`, each response part sent and request part
+  /// received will be recorded as a separate event in a tracing span. Otherwise, only the request/response
+  /// start and end will be recorded as events.
+  public init(emitEventOnEachWrite: Bool = false) {
+    self.extractor = ServerRequestExtractor()
+    self.emitEventOnEachWrite = emitEventOnEachWrite
+  }
+
+  /// This interceptor will extract whatever `ServiceContext` key-value pairs have been inserted into the
+  /// request's metadata, and will make them available to user code via the `ServiceContext/current`
+  /// context.
+  ///
+  /// Which key-value pairs are extracted and made available will depend on the specific tracing implementation
+  /// that has been configured when bootstrapping `swift-distributed-tracing` in your application.
+  public func intercept<Input, Output>(
+    request: ServerRequest.Stream<Input>,
+    context: ServerInterceptorContext,
+    next: @Sendable (ServerRequest.Stream<Input>, ServerInterceptorContext) async throws ->
+      ServerResponse.Stream<Output>
+  ) async throws -> ServerResponse.Stream<Output> where Input: Sendable, Output: Sendable {
+    var serviceContext = ServiceContext.topLevel
+    let tracer = InstrumentationSystem.tracer
+
+    tracer.extract(
+      request.metadata,
+      into: &serviceContext,
+      using: self.extractor
+    )
+
+    return try await ServiceContext.withValue(serviceContext) {
+      try await tracer.withSpan(
+        context.descriptor.fullyQualifiedMethod,
+        context: serviceContext,
+        ofKind: .server
+      ) { span in
+        span.addEvent("Received request start")
+
+        var request = request
+
+        if self.emitEventOnEachWrite {
+          request.messages = RPCAsyncSequence(
+            wrapping: request.messages.map { element in
+              span.addEvent("Received request part")
+              return element
+            }
+          )
+        }
+
+        var response = try await next(request, context)
+
+        span.addEvent("Received request end")
+
+        switch response.accepted {
+        case .success(var success):
+          let wrappedProducer = success.producer
+
+          if self.emitEventOnEachWrite {
+            success.producer = { writer in
+              let eventEmittingWriter = HookedWriter(
+                wrapping: writer,
+                beforeEachWrite: {
+                  span.addEvent("Sending response part")
+                },
+                afterEachWrite: {
+                  span.addEvent("Sent response part")
+                }
+              )
+
+              let wrappedResult: Metadata
+              do {
+                wrappedResult = try await wrappedProducer(
+                  RPCWriter(wrapping: eventEmittingWriter)
+                )
+              } catch {
+                span.addEvent("Error encountered")
+                throw error
+              }
+
+              span.addEvent("Sent response end")
+              return wrappedResult
+            }
+          } else {
+            success.producer = { writer in
+              let wrappedResult: Metadata
+              do {
+                wrappedResult = try await wrappedProducer(writer)
+              } catch {
+                span.addEvent("Error encountered")
+                throw error
+              }
+
+              span.addEvent("Sent response end")
+              return wrappedResult
+            }
+          }
+
+          response = .init(accepted: .success(success))
+        case .failure:
+          span.addEvent("Sent error response")
+        }
+
+        return response
+      }
+    }
+  }
+}
+
+/// An extractor responsible for extracting the required instrumentation keys from request metadata.
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+struct ServerRequestExtractor: Instrumentation.Extractor {
+  typealias Carrier = Metadata
+
+  func extract(key: String, from carrier: Carrier) -> String? {
+    var values = carrier[stringValues: key].makeIterator()
+    // There should only be one value for each key. If more, pick just one.
+    return values.next()
+  }
+}

+ 333 - 0
Tests/GRPCInterceptorsTests/TracingInterceptorTests.swift

@@ -0,0 +1,333 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import GRPCCore
+import Tracing
+import XCTest
+
+@testable import GRPCInterceptors
+
+@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
+final class TracingInterceptorTests: XCTestCase {
+  override class func setUp() {
+    InstrumentationSystem.bootstrap(TestTracer())
+  }
+
+  #if swift(>=5.8)  // Compiling these tests fails in 5.7
+  func testClientInterceptor() async throws {
+    var serviceContext = ServiceContext.topLevel
+    let traceIDString = UUID().uuidString
+    let interceptor = ClientTracingInterceptor(emitEventOnEachWrite: false)
+    let (stream, continuation) = AsyncStream<String>.makeStream()
+    serviceContext.traceID = traceIDString
+
+    try await ServiceContext.withValue(serviceContext) {
+      let methodDescriptor = MethodDescriptor(
+        service: "TracingInterceptorTests",
+        method: "testClientInterceptor"
+      )
+      let response = try await interceptor.intercept(
+        request: .init(producer: { writer in
+          try await writer.write(contentsOf: ["request1"])
+          try await writer.write(contentsOf: ["request2"])
+        }),
+        context: .init(descriptor: methodDescriptor)
+      ) { stream, _ in
+        // Assert the metadata contains the injected context key-value.
+        XCTAssertEqual(stream.metadata, ["trace-id": "\(traceIDString)"])
+
+        // Write into the response stream to make sure the `producer` closure's called.
+        let writer = RPCWriter(wrapping: TestWriter(streamContinuation: continuation))
+        try await stream.producer(writer)
+        continuation.finish()
+
+        return .init(
+          metadata: [],
+          bodyParts: .init(
+            wrapping: AsyncStream<ClientResponse.Stream.Contents.BodyPart> { cont in
+              cont.yield(.message(["response"]))
+              cont.finish()
+            }
+          )
+        )
+      }
+
+      var streamIterator = stream.makeAsyncIterator()
+      var element = await streamIterator.next()
+      XCTAssertEqual(element, "request1")
+      element = await streamIterator.next()
+      XCTAssertEqual(element, "request2")
+      element = await streamIterator.next()
+      XCTAssertNil(element)
+
+      var messages = response.messages.makeAsyncIterator()
+      var message = try await messages.next()
+      XCTAssertEqual(message, ["response"])
+      message = try await messages.next()
+      XCTAssertNil(message)
+
+      let tracer = InstrumentationSystem.tracer as! TestTracer
+      XCTAssertEqual(
+        tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map {
+          $0.name
+        },
+        [
+          "Request started",
+          "Received response end",
+        ]
+      )
+    }
+  }
+
+  func testClientInterceptorAllEventsRecorded() async throws {
+    let methodDescriptor = MethodDescriptor(
+      service: "TracingInterceptorTests",
+      method: "testClientInterceptorAllEventsRecorded"
+    )
+    var serviceContext = ServiceContext.topLevel
+    let traceIDString = UUID().uuidString
+    let interceptor = ClientTracingInterceptor(emitEventOnEachWrite: true)
+    let (stream, continuation) = AsyncStream<String>.makeStream()
+    serviceContext.traceID = traceIDString
+
+    try await ServiceContext.withValue(serviceContext) {
+      let response = try await interceptor.intercept(
+        request: .init(producer: { writer in
+          try await writer.write(contentsOf: ["request1"])
+          try await writer.write(contentsOf: ["request2"])
+        }),
+        context: .init(descriptor: methodDescriptor)
+      ) { stream, _ in
+        // Assert the metadata contains the injected context key-value.
+        XCTAssertEqual(stream.metadata, ["trace-id": "\(traceIDString)"])
+
+        // Write into the response stream to make sure the `producer` closure's called.
+        let writer = RPCWriter(wrapping: TestWriter(streamContinuation: continuation))
+        try await stream.producer(writer)
+        continuation.finish()
+
+        return .init(
+          metadata: [],
+          bodyParts: .init(
+            wrapping: AsyncStream<ClientResponse.Stream.Contents.BodyPart> { cont in
+              cont.yield(.message(["response"]))
+              cont.finish()
+            }
+          )
+        )
+      }
+
+      var streamIterator = stream.makeAsyncIterator()
+      var element = await streamIterator.next()
+      XCTAssertEqual(element, "request1")
+      element = await streamIterator.next()
+      XCTAssertEqual(element, "request2")
+      element = await streamIterator.next()
+      XCTAssertNil(element)
+
+      var messages = response.messages.makeAsyncIterator()
+      var message = try await messages.next()
+      XCTAssertEqual(message, ["response"])
+      message = try await messages.next()
+      XCTAssertNil(message)
+
+      let tracer = InstrumentationSystem.tracer as! TestTracer
+      XCTAssertEqual(
+        tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map {
+          $0.name
+        },
+        [
+          "Request started",
+          // Recorded when `request1` is sent
+          "Sending request part",
+          "Sent request part",
+          // Recorded when `request2` is sent
+          "Sending request part",
+          "Sent request part",
+          // Recorded after all request parts have been sent
+          "Request end",
+          // Recorded when receiving response part
+          "Received response part",
+          // Recorded at end of response
+          "Received response end",
+        ]
+      )
+    }
+  }
+  #endif  // swift >= 5.7
+
+  func testServerInterceptorErrorResponse() async throws {
+    let methodDescriptor = MethodDescriptor(
+      service: "TracingInterceptorTests",
+      method: "testServerInterceptorErrorResponse"
+    )
+    let interceptor = ServerTracingInterceptor(emitEventOnEachWrite: false)
+    let response = try await interceptor.intercept(
+      request: .init(single: .init(metadata: ["trace-id": "some-trace-id"], message: [])),
+      context: .init(descriptor: methodDescriptor)
+    ) { _, _ in
+      ServerResponse.Stream<String>(error: .init(code: .unknown, message: "Test error"))
+    }
+    XCTAssertThrowsError(try response.accepted.get())
+
+    let tracer = InstrumentationSystem.tracer as! TestTracer
+    XCTAssertEqual(
+      tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map {
+        $0.name
+      },
+      [
+        "Received request start",
+        "Received request end",
+        "Sent error response",
+      ]
+    )
+  }
+
+  func testServerInterceptor() async throws {
+    let methodDescriptor = MethodDescriptor(
+      service: "TracingInterceptorTests",
+      method: "testServerInterceptor"
+    )
+    let (stream, continuation) = AsyncStream<String>.makeStream()
+    let interceptor = ServerTracingInterceptor(emitEventOnEachWrite: false)
+    let response = try await interceptor.intercept(
+      request: .init(single: .init(metadata: ["trace-id": "some-trace-id"], message: [])),
+      context: .init(descriptor: methodDescriptor)
+    ) { _, _ in
+      { [serviceContext = ServiceContext.current] in
+        return ServerResponse.Stream<String>(
+          accepted: .success(
+            .init(
+              metadata: [],
+              producer: { writer in
+                guard let serviceContext else {
+                  XCTFail("There should be a service context present.")
+                  return ["Result": "Test failed"]
+                }
+
+                let traceID = serviceContext.traceID
+                XCTAssertEqual("some-trace-id", traceID)
+
+                try await writer.write("response1")
+                try await writer.write("response2")
+
+                return ["Result": "Trailing metadata"]
+              }
+            )
+          )
+        )
+      }()
+    }
+
+    let responseContents = try response.accepted.get()
+    let trailingMetadata = try await responseContents.producer(
+      RPCWriter(wrapping: TestWriter(streamContinuation: continuation))
+    )
+    continuation.finish()
+    XCTAssertEqual(trailingMetadata, ["Result": "Trailing metadata"])
+
+    var streamIterator = stream.makeAsyncIterator()
+    var element = await streamIterator.next()
+    XCTAssertEqual(element, "response1")
+    element = await streamIterator.next()
+    XCTAssertEqual(element, "response2")
+    element = await streamIterator.next()
+    XCTAssertNil(element)
+
+    let tracer = InstrumentationSystem.tracer as! TestTracer
+    XCTAssertEqual(
+      tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map {
+        $0.name
+      },
+      [
+        "Received request start",
+        "Received request end",
+        "Sent response end",
+      ]
+    )
+  }
+
+  func testServerInterceptorAllEventsRecorded() async throws {
+    let methodDescriptor = MethodDescriptor(
+      service: "TracingInterceptorTests",
+      method: "testServerInterceptorAllEventsRecorded"
+    )
+    let (stream, continuation) = AsyncStream<String>.makeStream()
+    let interceptor = ServerTracingInterceptor(emitEventOnEachWrite: true)
+    let response = try await interceptor.intercept(
+      request: .init(single: .init(metadata: ["trace-id": "some-trace-id"], message: [])),
+      context: .init(descriptor: methodDescriptor)
+    ) { _, _ in
+      { [serviceContext = ServiceContext.current] in
+        return ServerResponse.Stream<String>(
+          accepted: .success(
+            .init(
+              metadata: [],
+              producer: { writer in
+                guard let serviceContext else {
+                  XCTFail("There should be a service context present.")
+                  return ["Result": "Test failed"]
+                }
+
+                let traceID = serviceContext.traceID
+                XCTAssertEqual("some-trace-id", traceID)
+
+                try await writer.write("response1")
+                try await writer.write("response2")
+
+                return ["Result": "Trailing metadata"]
+              }
+            )
+          )
+        )
+      }()
+    }
+
+    let responseContents = try response.accepted.get()
+    let trailingMetadata = try await responseContents.producer(
+      RPCWriter(wrapping: TestWriter(streamContinuation: continuation))
+    )
+    continuation.finish()
+    XCTAssertEqual(trailingMetadata, ["Result": "Trailing metadata"])
+
+    var streamIterator = stream.makeAsyncIterator()
+    var element = await streamIterator.next()
+    XCTAssertEqual(element, "response1")
+    element = await streamIterator.next()
+    XCTAssertEqual(element, "response2")
+    element = await streamIterator.next()
+    XCTAssertNil(element)
+
+    let tracer = InstrumentationSystem.tracer as! TestTracer
+    XCTAssertEqual(
+      tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map {
+        $0.name
+      },
+      [
+        "Received request start",
+        "Received request end",
+        // Recorded when `response1` is sent
+        "Sending response part",
+        "Sent response part",
+        // Recorded when `response2` is sent
+        "Sending response part",
+        "Sent response part",
+        // Recorded when we're done sending response
+        "Sent response end",
+      ]
+    )
+  }
+}

+ 182 - 0
Tests/GRPCInterceptorsTests/TracingTestsUtilities.swift

@@ -0,0 +1,182 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import GRPCCore
+import NIOConcurrencyHelpers
+import Tracing
+
+final class TestTracer: Tracer {
+  typealias Span = TestSpan
+
+  private var testSpans: NIOLockedValueBox<[String: TestSpan]> = .init([:])
+
+  func getEventsForTestSpan(ofOperationName operationName: String) -> [SpanEvent] {
+    self.testSpans.withLockedValue({ $0[operationName] })?.events ?? []
+  }
+
+  func extract<Carrier, Extract>(
+    _ carrier: Carrier,
+    into context: inout ServiceContextModule.ServiceContext,
+    using extractor: Extract
+  ) where Carrier == Extract.Carrier, Extract: Instrumentation.Extractor {
+    let traceID = extractor.extract(key: TraceID.keyName, from: carrier)
+    context[TraceID.self] = traceID
+  }
+
+  func inject<Carrier, Inject>(
+    _ context: ServiceContextModule.ServiceContext,
+    into carrier: inout Carrier,
+    using injector: Inject
+  ) where Carrier == Inject.Carrier, Inject: Instrumentation.Injector {
+    if let traceID = context.traceID {
+      injector.inject(traceID, forKey: TraceID.keyName, into: &carrier)
+    }
+  }
+
+  func forceFlush() {
+    // no-op
+  }
+
+  func startSpan<Instant>(
+    _ operationName: String,
+    context: @autoclosure () -> ServiceContext,
+    ofKind kind: SpanKind,
+    at instant: @autoclosure () -> Instant,
+    function: String,
+    file fileID: String,
+    line: UInt
+  ) -> TestSpan where Instant: TracerInstant {
+    return self.testSpans.withLockedValue { testSpans in
+      let span = TestSpan(context: context(), operationName: operationName)
+      testSpans[operationName] = span
+      return span
+    }
+  }
+}
+
+class TestSpan: Span {
+  var context: ServiceContextModule.ServiceContext
+  var operationName: String
+  var attributes: Tracing.SpanAttributes
+  var isRecording: Bool
+  private(set) var status: Tracing.SpanStatus?
+  private(set) var events: [Tracing.SpanEvent] = []
+
+  init(
+    context: ServiceContextModule.ServiceContext,
+    operationName: String,
+    attributes: Tracing.SpanAttributes = [:],
+    isRecording: Bool = true
+  ) {
+    self.context = context
+    self.operationName = operationName
+    self.attributes = attributes
+    self.isRecording = isRecording
+  }
+
+  func setStatus(_ status: Tracing.SpanStatus) {
+    self.status = status
+  }
+
+  func addEvent(_ event: Tracing.SpanEvent) {
+    self.events.append(event)
+  }
+
+  func recordError<Instant>(
+    _ error: any Error,
+    attributes: Tracing.SpanAttributes,
+    at instant: @autoclosure () -> Instant
+  ) where Instant: Tracing.TracerInstant {
+    self.setStatus(
+      .init(
+        code: .error,
+        message: "Error: \(error), attributes: \(attributes), at instant: \(instant())"
+      )
+    )
+  }
+
+  func addLink(_ link: Tracing.SpanLink) {
+    self.context.spanLinks?.append(link)
+  }
+
+  func end<Instant>(at instant: @autoclosure () -> Instant) where Instant: Tracing.TracerInstant {
+    self.setStatus(.init(code: .ok, message: "Ended at instant: \(instant())"))
+  }
+}
+
+enum TraceID: ServiceContextModule.ServiceContextKey {
+  typealias Value = String
+
+  static let keyName = "trace-id"
+}
+
+enum ServiceContextSpanLinksKey: ServiceContextModule.ServiceContextKey {
+  typealias Value = [SpanLink]
+
+  static let keyName = "span-links"
+}
+
+extension ServiceContext {
+  var traceID: String? {
+    get {
+      self[TraceID.self]
+    }
+    set {
+      self[TraceID.self] = newValue
+    }
+  }
+
+  var spanLinks: [SpanLink]? {
+    get {
+      self[ServiceContextSpanLinksKey.self]
+    }
+    set {
+      self[ServiceContextSpanLinksKey.self] = newValue
+    }
+  }
+}
+
+struct TestWriter<WriterElement>: RPCWriterProtocol {
+  typealias Element = WriterElement
+
+  private let streamContinuation: AsyncStream<Element>.Continuation
+
+  init(streamContinuation: AsyncStream<Element>.Continuation) {
+    self.streamContinuation = streamContinuation
+  }
+
+  func write(contentsOf elements: some Sequence<Self.Element>) async throws {
+    elements.forEach { element in
+      self.streamContinuation.yield(element)
+    }
+  }
+}
+
+#if swift(<5.9)
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+extension AsyncStream {
+  static func makeStream(
+    of elementType: Element.Type = Element.self,
+    bufferingPolicy limit: AsyncStream<Element>.Continuation.BufferingPolicy = .unbounded
+  ) -> (stream: AsyncStream<Element>, continuation: AsyncStream<Element>.Continuation) {
+    var continuation: AsyncStream<Element>.Continuation!
+    let stream = AsyncStream(Element.self, bufferingPolicy: limit) {
+      continuation = $0
+    }
+    return (stream, continuation)
+  }
+}
+#endif