TracingInterceptorTests.swift 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. /*
  2. * Copyright 2024, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import GRPCCore
  17. import Tracing
  18. import XCTest
  19. @testable import GRPCInterceptors
  @available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *)
  final class TracingInterceptorTests: XCTestCase {
    /// Bootstraps the process-wide instrumentation with a `TestTracer`.
    ///
    /// This is the class-level `setUp`, so it runs once for the whole test class rather than
    /// before every test. Each test reads its recorded events back via `testTracer()`.
    override class func setUp() {
      super.setUp()
      InstrumentationSystem.bootstrap(TestTracer())
    }

    /// Returns the tracer installed by `setUp()`, downcast to `TestTracer`.
    ///
    /// Uses `XCTUnwrap` rather than a force-cast. If the bootstrapped tracer has an unexpected
    /// type, the calling test fails at its call site instead of the whole test process crashing.
    private static func testTracer(
      file: StaticString = #filePath,
      line: UInt = #line
    ) throws -> TestTracer {
      try XCTUnwrap(
        InstrumentationSystem.tracer as? TestTracer,
        "InstrumentationSystem should have been bootstrapped with a TestTracer.",
        file: file,
        line: line
      )
    }

    /// Client interceptor with `emitEventOnEachWrite: false`.
    ///
    /// Checks three things:
    /// - The trace ID from the current `ServiceContext` is injected as `trace-id` metadata.
    /// - Both request parts and the response message are forwarded.
    /// - Only the start and end events are recorded on the span.
    func testClientInterceptor() async throws {
      let methodDescriptor = MethodDescriptor(
        service: "TracingInterceptorTests",
        method: "testClientInterceptor"
      )
      var serviceContext = ServiceContext.topLevel
      let traceIDString = UUID().uuidString
      serviceContext.traceID = traceIDString
      let interceptor = ClientTracingInterceptor(emitEventOnEachWrite: false)
      // Receives whatever the request producer writes, so it can be checked afterwards.
      let (stream, continuation) = AsyncStream<String>.makeStream()

      try await ServiceContext.withValue(serviceContext) {
        let response = try await interceptor.intercept(
          request: .init(producer: { writer in
            try await writer.write(contentsOf: ["request1"])
            try await writer.write(contentsOf: ["request2"])
          }),
          context: .init(descriptor: methodDescriptor)
        ) { request, _ in
          // Assert the metadata contains the injected context key-value.
          XCTAssertEqual(request.metadata, ["trace-id": "\(traceIDString)"])
          // Run the request producer against a test writer so the request parts land in `stream`.
          let writer = RPCWriter(wrapping: TestWriter(streamContinuation: continuation))
          try await request.producer(writer)
          continuation.finish()
          return .init(
            metadata: [],
            bodyParts: RPCAsyncSequence(
              wrapping: AsyncThrowingStream<ClientResponse.Stream.Contents.BodyPart, any Error> {
                $0.yield(.message(["response"]))
                $0.finish()
              }
            )
          )
        }

        var streamIterator = stream.makeAsyncIterator()
        var element = await streamIterator.next()
        XCTAssertEqual(element, "request1")
        element = await streamIterator.next()
        XCTAssertEqual(element, "request2")
        element = await streamIterator.next()
        XCTAssertNil(element)

        var messages = response.messages.makeAsyncIterator()
        var message = try await messages.next()
        XCTAssertEqual(message, ["response"])
        message = try await messages.next()
        XCTAssertNil(message)

        let tracer = try Self.testTracer()
        XCTAssertEqual(
          tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map {
            $0.name
          },
          [
            "Request started",
            "Received response end",
          ]
        )
      }
    }

    /// Client interceptor with `emitEventOnEachWrite: true`.
    ///
    /// Performs the same metadata and forwarding checks as `testClientInterceptor`. It also
    /// expects a sending/sent event pair for every request part, plus per-part and end events
    /// for the response.
    func testClientInterceptorAllEventsRecorded() async throws {
      let methodDescriptor = MethodDescriptor(
        service: "TracingInterceptorTests",
        method: "testClientInterceptorAllEventsRecorded"
      )
      var serviceContext = ServiceContext.topLevel
      let traceIDString = UUID().uuidString
      serviceContext.traceID = traceIDString
      let interceptor = ClientTracingInterceptor(emitEventOnEachWrite: true)
      // Receives whatever the request producer writes, so it can be checked afterwards.
      let (stream, continuation) = AsyncStream<String>.makeStream()

      try await ServiceContext.withValue(serviceContext) {
        let response = try await interceptor.intercept(
          request: .init(producer: { writer in
            try await writer.write(contentsOf: ["request1"])
            try await writer.write(contentsOf: ["request2"])
          }),
          context: .init(descriptor: methodDescriptor)
        ) { request, _ in
          // Assert the metadata contains the injected context key-value.
          XCTAssertEqual(request.metadata, ["trace-id": "\(traceIDString)"])
          // Run the request producer against a test writer so the request parts land in `stream`.
          let writer = RPCWriter(wrapping: TestWriter(streamContinuation: continuation))
          try await request.producer(writer)
          continuation.finish()
          return .init(
            metadata: [],
            bodyParts: RPCAsyncSequence(
              wrapping: AsyncThrowingStream<ClientResponse.Stream.Contents.BodyPart, any Error> {
                $0.yield(.message(["response"]))
                $0.finish()
              }
            )
          )
        }

        var streamIterator = stream.makeAsyncIterator()
        var element = await streamIterator.next()
        XCTAssertEqual(element, "request1")
        element = await streamIterator.next()
        XCTAssertEqual(element, "request2")
        element = await streamIterator.next()
        XCTAssertNil(element)

        var messages = response.messages.makeAsyncIterator()
        var message = try await messages.next()
        XCTAssertEqual(message, ["response"])
        message = try await messages.next()
        XCTAssertNil(message)

        let tracer = try Self.testTracer()
        XCTAssertEqual(
          tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map {
            $0.name
          },
          [
            "Request started",
            // Recorded when `request1` is sent
            "Sending request part",
            "Sent request part",
            // Recorded when `request2` is sent
            "Sending request part",
            "Sent request part",
            // Recorded after all request parts have been sent
            "Request end",
            // Recorded when receiving response part
            "Received response part",
            // Recorded at end of response
            "Received response end",
          ]
        )
      }
    }

    /// Server interceptor whose handler returns an error response.
    ///
    /// The returned response must not be accepted. The span records the request events
    /// followed by "Sent error response".
    func testServerInterceptorErrorResponse() async throws {
      let methodDescriptor = MethodDescriptor(
        service: "TracingInterceptorTests",
        method: "testServerInterceptorErrorResponse"
      )
      let interceptor = ServerTracingInterceptor(emitEventOnEachWrite: false)
      let single = ServerRequest.Single(metadata: ["trace-id": "some-trace-id"], message: [UInt8]())
      let response = try await interceptor.intercept(
        request: .init(single: single),
        context: .init(descriptor: methodDescriptor)
      ) { _, _ in
        ServerResponse.Stream<String>(error: .init(code: .unknown, message: "Test error"))
      }
      XCTAssertThrowsError(try response.accepted.get())

      let tracer = try Self.testTracer()
      XCTAssertEqual(
        tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map {
          $0.name
        },
        [
          "Received request start",
          "Received request end",
          "Sent error response",
        ]
      )
    }

    /// Server interceptor with `emitEventOnEachWrite: false`.
    ///
    /// Checks three things:
    /// - The `trace-id` request metadata is visible as the handler's `ServiceContext.traceID`.
    /// - The response parts and trailing metadata are forwarded.
    /// - Only the request events and "Sent response end" are recorded.
    func testServerInterceptor() async throws {
      let methodDescriptor = MethodDescriptor(
        service: "TracingInterceptorTests",
        method: "testServerInterceptor"
      )
      // Receives whatever the response producer writes, so it can be checked afterwards.
      let (stream, continuation) = AsyncStream<String>.makeStream()
      let interceptor = ServerTracingInterceptor(emitEventOnEachWrite: false)
      let single = ServerRequest.Single(metadata: ["trace-id": "some-trace-id"], message: [UInt8]())
      let response = try await interceptor.intercept(
        request: .init(single: single),
        context: .init(descriptor: methodDescriptor)
      ) { _, _ in
        // Capture `ServiceContext.current` while the handler runs. The producer below is only
        // invoked later, after `intercept` has returned, so it checks this captured value.
        { [serviceContext = ServiceContext.current] in
          return ServerResponse.Stream<String>(
            accepted: .success(
              .init(
                metadata: [],
                producer: { writer in
                  guard let serviceContext else {
                    XCTFail("There should be a service context present.")
                    return ["Result": "Test failed"]
                  }
                  let traceID = serviceContext.traceID
                  XCTAssertEqual("some-trace-id", traceID)
                  try await writer.write("response1")
                  try await writer.write("response2")
                  return ["Result": "Trailing metadata"]
                }
              )
            )
          )
        }()
      }

      let responseContents = try response.accepted.get()
      let trailingMetadata = try await responseContents.producer(
        RPCWriter(wrapping: TestWriter(streamContinuation: continuation))
      )
      continuation.finish()
      XCTAssertEqual(trailingMetadata, ["Result": "Trailing metadata"])

      var streamIterator = stream.makeAsyncIterator()
      var element = await streamIterator.next()
      XCTAssertEqual(element, "response1")
      element = await streamIterator.next()
      XCTAssertEqual(element, "response2")
      element = await streamIterator.next()
      XCTAssertNil(element)

      let tracer = try Self.testTracer()
      XCTAssertEqual(
        tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map {
          $0.name
        },
        [
          "Received request start",
          "Received request end",
          "Sent response end",
        ]
      )
    }

    /// Server interceptor with `emitEventOnEachWrite: true`.
    ///
    /// Performs the same context and forwarding checks as `testServerInterceptor`. It also
    /// expects a sending/sent event pair for each response part, before "Sent response end".
    func testServerInterceptorAllEventsRecorded() async throws {
      let methodDescriptor = MethodDescriptor(
        service: "TracingInterceptorTests",
        method: "testServerInterceptorAllEventsRecorded"
      )
      // Receives whatever the response producer writes, so it can be checked afterwards.
      let (stream, continuation) = AsyncStream<String>.makeStream()
      let interceptor = ServerTracingInterceptor(emitEventOnEachWrite: true)
      let single = ServerRequest.Single(metadata: ["trace-id": "some-trace-id"], message: [UInt8]())
      let response = try await interceptor.intercept(
        request: .init(single: single),
        context: .init(descriptor: methodDescriptor)
      ) { _, _ in
        // Capture `ServiceContext.current` while the handler runs. The producer below is only
        // invoked later, after `intercept` has returned, so it checks this captured value.
        { [serviceContext = ServiceContext.current] in
          return ServerResponse.Stream<String>(
            accepted: .success(
              .init(
                metadata: [],
                producer: { writer in
                  guard let serviceContext else {
                    XCTFail("There should be a service context present.")
                    return ["Result": "Test failed"]
                  }
                  let traceID = serviceContext.traceID
                  XCTAssertEqual("some-trace-id", traceID)
                  try await writer.write("response1")
                  try await writer.write("response2")
                  return ["Result": "Trailing metadata"]
                }
              )
            )
          )
        }()
      }

      let responseContents = try response.accepted.get()
      let trailingMetadata = try await responseContents.producer(
        RPCWriter(wrapping: TestWriter(streamContinuation: continuation))
      )
      continuation.finish()
      XCTAssertEqual(trailingMetadata, ["Result": "Trailing metadata"])

      var streamIterator = stream.makeAsyncIterator()
      var element = await streamIterator.next()
      XCTAssertEqual(element, "response1")
      element = await streamIterator.next()
      XCTAssertEqual(element, "response2")
      element = await streamIterator.next()
      XCTAssertNil(element)

      let tracer = try Self.testTracer()
      XCTAssertEqual(
        tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map {
          $0.name
        },
        [
          "Received request start",
          "Received request end",
          // Recorded when `response1` is sent
          "Sending response part",
          "Sent response part",
          // Recorded when `response2` is sent
          "Sending response part",
          "Sent response part",
          // Recorded when we're done sending response
          "Sent response end",
        ]
      )
    }
  }