TracingInterceptorTests.swift 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. /*
  2. * Copyright 2024, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import GRPCCore
  17. import Tracing
  18. import XCTest
  19. @testable import GRPCInterceptors
  @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  final class TracingInterceptorTests: XCTestCase {
    /// Bootstraps the global `InstrumentationSystem` with a `TestTracer`.
    ///
    /// This is the class-level `setUp`, which XCTest runs once before the first test of this
    /// class (not before every test), so the bootstrap happens a single time for all tests.
    /// Each test uses a distinct method name, so the spans it inspects do not collide.
    override class func setUp() {
      super.setUp()
      InstrumentationSystem.bootstrap(TestTracer())
    }

    // MARK: - Client interceptor

    #if swift(>=5.8)  // Compiling these tests fails in 5.7
    /// Checks that `ClientTracingInterceptor` injects the current trace ID into the request
    /// metadata, passes request and response messages through, and — with
    /// `emitEventOnEachWrite: false` — records only the request-start and response-end events.
    func testClientInterceptor() async throws {
      var serviceContext = ServiceContext.topLevel
      let traceIDString = UUID().uuidString
      let interceptor = ClientTracingInterceptor(emitEventOnEachWrite: false)
      // Request messages written through `TestWriter` end up in `stream`.
      let (stream, continuation) = AsyncStream<String>.makeStream()
      serviceContext.traceID = traceIDString

      try await ServiceContext.withValue(serviceContext) {
        let methodDescriptor = MethodDescriptor(
          service: "TracingInterceptorTests",
          method: "testClientInterceptor"
        )
        let response = try await interceptor.intercept(
          request: .init(producer: { writer in
            try await writer.write(contentsOf: ["request1"])
            try await writer.write(contentsOf: ["request2"])
          }),
          context: .init(descriptor: methodDescriptor)
        ) { request, _ in
          // Assert the metadata contains the injected context key-value.
          XCTAssertEqual(request.metadata, ["trace-id": "\(traceIDString)"])
          // Run the request's `producer` so its messages are written into `stream`.
          let writer = RPCWriter(wrapping: TestWriter(streamContinuation: continuation))
          try await request.producer(writer)
          continuation.finish()
          return .init(
            metadata: [],
            bodyParts: .init(
              wrapping: AsyncStream<ClientResponse.Stream.Contents.BodyPart> { cont in
                cont.yield(.message(["response"]))
                cont.finish()
              }
            )
          )
        }

        let sentRequests = await self.collectElements(of: stream)
        XCTAssertEqual(sentRequests, ["request1", "request2"])

        var messages = response.messages.makeAsyncIterator()
        var message = try await messages.next()
        XCTAssertEqual(message, ["response"])
        message = try await messages.next()
        XCTAssertNil(message)

        let eventNames = try self.recordedEventNames(for: methodDescriptor)
        XCTAssertEqual(
          eventNames,
          [
            "Request started",
            "Received response end",
          ]
        )
      }
    }

    /// Same scenario as `testClientInterceptor`, but with `emitEventOnEachWrite: true`, so an
    /// event pair is expected around every request part plus the request-end and
    /// response-part events.
    func testClientInterceptorAllEventsRecorded() async throws {
      let methodDescriptor = MethodDescriptor(
        service: "TracingInterceptorTests",
        method: "testClientInterceptorAllEventsRecorded"
      )
      var serviceContext = ServiceContext.topLevel
      let traceIDString = UUID().uuidString
      let interceptor = ClientTracingInterceptor(emitEventOnEachWrite: true)
      // Request messages written through `TestWriter` end up in `stream`.
      let (stream, continuation) = AsyncStream<String>.makeStream()
      serviceContext.traceID = traceIDString

      try await ServiceContext.withValue(serviceContext) {
        let response = try await interceptor.intercept(
          request: .init(producer: { writer in
            try await writer.write(contentsOf: ["request1"])
            try await writer.write(contentsOf: ["request2"])
          }),
          context: .init(descriptor: methodDescriptor)
        ) { request, _ in
          // Assert the metadata contains the injected context key-value.
          XCTAssertEqual(request.metadata, ["trace-id": "\(traceIDString)"])
          // Run the request's `producer` so its messages are written into `stream`.
          let writer = RPCWriter(wrapping: TestWriter(streamContinuation: continuation))
          try await request.producer(writer)
          continuation.finish()
          return .init(
            metadata: [],
            bodyParts: .init(
              wrapping: AsyncStream<ClientResponse.Stream.Contents.BodyPart> { cont in
                cont.yield(.message(["response"]))
                cont.finish()
              }
            )
          )
        }

        let sentRequests = await self.collectElements(of: stream)
        XCTAssertEqual(sentRequests, ["request1", "request2"])

        var messages = response.messages.makeAsyncIterator()
        var message = try await messages.next()
        XCTAssertEqual(message, ["response"])
        message = try await messages.next()
        XCTAssertNil(message)

        let eventNames = try self.recordedEventNames(for: methodDescriptor)
        XCTAssertEqual(
          eventNames,
          [
            "Request started",
            // Recorded when `request1` is sent
            "Sending request part",
            "Sent request part",
            // Recorded when `request2` is sent
            "Sending request part",
            "Sent request part",
            // Recorded after all request parts have been sent
            "Request end",
            // Recorded when receiving response part
            "Received response part",
            // Recorded at end of response
            "Received response end",
          ]
        )
      }
    }
    #endif  // swift(>=5.8)

    // MARK: - Server interceptor

    /// Checks that a handler returning an error response surfaces that error through
    /// `accepted` and that the interceptor records a "Sent error response" event.
    func testServerInterceptorErrorResponse() async throws {
      let methodDescriptor = MethodDescriptor(
        service: "TracingInterceptorTests",
        method: "testServerInterceptorErrorResponse"
      )
      let interceptor = ServerTracingInterceptor(emitEventOnEachWrite: false)
      let response = try await interceptor.intercept(
        request: .init(single: .init(metadata: ["trace-id": "some-trace-id"], message: [])),
        context: .init(descriptor: methodDescriptor)
      ) { _, _ in
        ServerResponse.Stream<String>(error: .init(code: .unknown, message: "Test error"))
      }

      XCTAssertThrowsError(try response.accepted.get())

      let eventNames = try recordedEventNames(for: methodDescriptor)
      XCTAssertEqual(
        eventNames,
        [
          "Received request start",
          "Received request end",
          "Sent error response",
        ]
      )
    }

    /// Checks that `ServerTracingInterceptor` extracts the trace ID from the request metadata
    /// into the handler's service context, passes response messages and trailing metadata
    /// through, and — with `emitEventOnEachWrite: false` — records only start/end events.
    func testServerInterceptor() async throws {
      let methodDescriptor = MethodDescriptor(
        service: "TracingInterceptorTests",
        method: "testServerInterceptor"
      )
      let (stream, continuation) = AsyncStream<String>.makeStream()
      let interceptor = ServerTracingInterceptor(emitEventOnEachWrite: false)
      let response = try await interceptor.intercept(
        request: .init(single: .init(metadata: ["trace-id": "some-trace-id"], message: [])),
        context: .init(descriptor: methodDescriptor)
      ) { _, _ in
        TracingInterceptorTests.makeTraceCheckingResponse()
      }

      // Drive the (interceptor-wrapped) producer so its messages are written into `stream`.
      let responseContents = try response.accepted.get()
      let trailingMetadata = try await responseContents.producer(
        RPCWriter(wrapping: TestWriter(streamContinuation: continuation))
      )
      continuation.finish()
      XCTAssertEqual(trailingMetadata, ["Result": "Trailing metadata"])

      let sentResponses = await collectElements(of: stream)
      XCTAssertEqual(sentResponses, ["response1", "response2"])

      let eventNames = try recordedEventNames(for: methodDescriptor)
      XCTAssertEqual(
        eventNames,
        [
          "Received request start",
          "Received request end",
          "Sent response end",
        ]
      )
    }

    /// Same scenario as `testServerInterceptor`, but with `emitEventOnEachWrite: true`, so an
    /// event pair is expected around every response part.
    func testServerInterceptorAllEventsRecorded() async throws {
      let methodDescriptor = MethodDescriptor(
        service: "TracingInterceptorTests",
        method: "testServerInterceptorAllEventsRecorded"
      )
      let (stream, continuation) = AsyncStream<String>.makeStream()
      let interceptor = ServerTracingInterceptor(emitEventOnEachWrite: true)
      let response = try await interceptor.intercept(
        request: .init(single: .init(metadata: ["trace-id": "some-trace-id"], message: [])),
        context: .init(descriptor: methodDescriptor)
      ) { _, _ in
        TracingInterceptorTests.makeTraceCheckingResponse()
      }

      // Drive the (interceptor-wrapped) producer so its messages are written into `stream`.
      let responseContents = try response.accepted.get()
      let trailingMetadata = try await responseContents.producer(
        RPCWriter(wrapping: TestWriter(streamContinuation: continuation))
      )
      continuation.finish()
      XCTAssertEqual(trailingMetadata, ["Result": "Trailing metadata"])

      let sentResponses = await collectElements(of: stream)
      XCTAssertEqual(sentResponses, ["response1", "response2"])

      let eventNames = try recordedEventNames(for: methodDescriptor)
      XCTAssertEqual(
        eventNames,
        [
          "Received request start",
          "Received request end",
          // Recorded when `response1` is sent
          "Sending response part",
          "Sent response part",
          // Recorded when `response2` is sent
          "Sending response part",
          "Sent response part",
          // Recorded when we're done sending response
          "Sent response end",
        ]
      )
    }

    // MARK: - Helpers

    /// Returns the names of the events the bootstrapped `TestTracer` recorded for the span
    /// whose operation name is `method.fullyQualifiedMethod`, in the order the tracer reports.
    ///
    /// - Throws: An `XCTUnwrap` failure (failing the test instead of crashing the runner)
    ///   when the global tracer is not a `TestTracer`.
    private func recordedEventNames(for method: MethodDescriptor) throws -> [String] {
      let tracer = try XCTUnwrap(
        InstrumentationSystem.tracer as? TestTracer,
        "The bootstrapped tracer should be a TestTracer."
      )
      return tracer.getEventsForTestSpan(ofOperationName: method.fullyQualifiedMethod).map {
        $0.name
      }
    }

    /// Reads `stream` until it finishes and returns every element it yielded, in order.
    ///
    /// Callers finish the stream's continuation before calling this; otherwise it would wait
    /// forever for more elements.
    private func collectElements(of stream: AsyncStream<String>) async -> [String] {
      var collected: [String] = []
      for await element in stream {
        collected.append(element)
      }
      return collected
    }

    /// Builds the accepted response returned by the successful server tests' handlers.
    ///
    /// Call this from inside the interceptor's `next` closure: it snapshots
    /// `ServiceContext.current` there, because the tests run the returned producer later,
    /// outside the handler. The producer asserts that the snapshot carries the trace ID
    /// `"some-trace-id"`. It then writes `"response1"` and `"response2"` and returns the
    /// trailing metadata `["Result": "Trailing metadata"]`.
    private static func makeTraceCheckingResponse() -> ServerResponse.Stream<String> {
      let serviceContext = ServiceContext.current
      return ServerResponse.Stream<String>(
        accepted: .success(
          .init(
            metadata: [],
            producer: { writer in
              guard let serviceContext else {
                XCTFail("There should be a service context present.")
                return ["Result": "Test failed"]
              }
              XCTAssertEqual("some-trace-id", serviceContext.traceID)
              try await writer.write("response1")
              try await writer.write("response2")
              return ["Result": "Trailing metadata"]
            }
          )
        )
      )
    }
  }