TracingInterceptorTests.swift 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. /*
  2. * Copyright 2024, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import GRPCCore
  17. import Tracing
  18. import XCTest
  19. @testable import GRPCInterceptors
  20. @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  21. final class TracingInterceptorTests: XCTestCase {
  22. override class func setUp() {
  23. InstrumentationSystem.bootstrap(TestTracer())
  24. }
  25. #if swift(>=5.8) // Compiling these tests fails in 5.7
  26. func testClientInterceptor() async throws {
  27. var serviceContext = ServiceContext.topLevel
  28. let traceIDString = UUID().uuidString
  29. let interceptor = ClientTracingInterceptor(emitEventOnEachWrite: false)
  30. let (stream, continuation) = AsyncStream<String>.makeStream()
  31. serviceContext.traceID = traceIDString
  32. try await ServiceContext.withValue(serviceContext) {
  33. let methodDescriptor = MethodDescriptor(
  34. service: "TracingInterceptorTests",
  35. method: "testClientInterceptor"
  36. )
  37. let response = try await interceptor.intercept(
  38. request: .init(producer: { writer in
  39. try await writer.write(contentsOf: ["request1"])
  40. try await writer.write(contentsOf: ["request2"])
  41. }),
  42. context: .init(descriptor: methodDescriptor)
  43. ) { stream, _ in
  44. // Assert the metadata contains the injected context key-value.
  45. XCTAssertEqual(stream.metadata, ["trace-id": "\(traceIDString)"])
  46. // Write into the response stream to make sure the `producer` closure's called.
  47. let writer = RPCWriter(wrapping: TestWriter(streamContinuation: continuation))
  48. try await stream.producer(writer)
  49. continuation.finish()
  50. return .init(
  51. metadata: [],
  52. bodyParts: .init(
  53. wrapping: AsyncStream<ClientResponse.Stream.Contents.BodyPart> { cont in
  54. cont.yield(.message(["response"]))
  55. cont.finish()
  56. }
  57. )
  58. )
  59. }
  60. var streamIterator = stream.makeAsyncIterator()
  61. var element = await streamIterator.next()
  62. XCTAssertEqual(element, "request1")
  63. element = await streamIterator.next()
  64. XCTAssertEqual(element, "request2")
  65. element = await streamIterator.next()
  66. XCTAssertNil(element)
  67. var messages = response.messages.makeAsyncIterator()
  68. var message = try await messages.next()
  69. XCTAssertEqual(message, ["response"])
  70. message = try await messages.next()
  71. XCTAssertNil(message)
  72. let tracer = InstrumentationSystem.tracer as! TestTracer
  73. XCTAssertEqual(
  74. tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map {
  75. $0.name
  76. },
  77. [
  78. "Request started",
  79. "Received response end",
  80. ]
  81. )
  82. }
  83. }
  84. func testClientInterceptorAllEventsRecorded() async throws {
  85. let methodDescriptor = MethodDescriptor(
  86. service: "TracingInterceptorTests",
  87. method: "testClientInterceptorAllEventsRecorded"
  88. )
  89. var serviceContext = ServiceContext.topLevel
  90. let traceIDString = UUID().uuidString
  91. let interceptor = ClientTracingInterceptor(emitEventOnEachWrite: true)
  92. let (stream, continuation) = AsyncStream<String>.makeStream()
  93. serviceContext.traceID = traceIDString
  94. try await ServiceContext.withValue(serviceContext) {
  95. let response = try await interceptor.intercept(
  96. request: .init(producer: { writer in
  97. try await writer.write(contentsOf: ["request1"])
  98. try await writer.write(contentsOf: ["request2"])
  99. }),
  100. context: .init(descriptor: methodDescriptor)
  101. ) { stream, _ in
  102. // Assert the metadata contains the injected context key-value.
  103. XCTAssertEqual(stream.metadata, ["trace-id": "\(traceIDString)"])
  104. // Write into the response stream to make sure the `producer` closure's called.
  105. let writer = RPCWriter(wrapping: TestWriter(streamContinuation: continuation))
  106. try await stream.producer(writer)
  107. continuation.finish()
  108. return .init(
  109. metadata: [],
  110. bodyParts: .init(
  111. wrapping: AsyncStream<ClientResponse.Stream.Contents.BodyPart> { cont in
  112. cont.yield(.message(["response"]))
  113. cont.finish()
  114. }
  115. )
  116. )
  117. }
  118. var streamIterator = stream.makeAsyncIterator()
  119. var element = await streamIterator.next()
  120. XCTAssertEqual(element, "request1")
  121. element = await streamIterator.next()
  122. XCTAssertEqual(element, "request2")
  123. element = await streamIterator.next()
  124. XCTAssertNil(element)
  125. var messages = response.messages.makeAsyncIterator()
  126. var message = try await messages.next()
  127. XCTAssertEqual(message, ["response"])
  128. message = try await messages.next()
  129. XCTAssertNil(message)
  130. let tracer = InstrumentationSystem.tracer as! TestTracer
  131. XCTAssertEqual(
  132. tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map {
  133. $0.name
  134. },
  135. [
  136. "Request started",
  137. // Recorded when `request1` is sent
  138. "Sending request part",
  139. "Sent request part",
  140. // Recorded when `request2` is sent
  141. "Sending request part",
  142. "Sent request part",
  143. // Recorded after all request parts have been sent
  144. "Request end",
  145. // Recorded when receiving response part
  146. "Received response part",
  147. // Recorded at end of response
  148. "Received response end",
  149. ]
  150. )
  151. }
  152. }
  153. #endif // swift >= 5.7
  154. func testServerInterceptorErrorResponse() async throws {
  155. let methodDescriptor = MethodDescriptor(
  156. service: "TracingInterceptorTests",
  157. method: "testServerInterceptorErrorResponse"
  158. )
  159. let interceptor = ServerTracingInterceptor(emitEventOnEachWrite: false)
  160. let single = ServerRequest.Single(metadata: ["trace-id": "some-trace-id"], message: [UInt8]())
  161. let response = try await interceptor.intercept(
  162. request: .init(single: single),
  163. context: .init(descriptor: methodDescriptor)
  164. ) { _, _ in
  165. ServerResponse.Stream<String>(error: .init(code: .unknown, message: "Test error"))
  166. }
  167. XCTAssertThrowsError(try response.accepted.get())
  168. let tracer = InstrumentationSystem.tracer as! TestTracer
  169. XCTAssertEqual(
  170. tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map {
  171. $0.name
  172. },
  173. [
  174. "Received request start",
  175. "Received request end",
  176. "Sent error response",
  177. ]
  178. )
  179. }
  180. func testServerInterceptor() async throws {
  181. let methodDescriptor = MethodDescriptor(
  182. service: "TracingInterceptorTests",
  183. method: "testServerInterceptor"
  184. )
  185. let (stream, continuation) = AsyncStream<String>.makeStream()
  186. let interceptor = ServerTracingInterceptor(emitEventOnEachWrite: false)
  187. let single = ServerRequest.Single(metadata: ["trace-id": "some-trace-id"], message: [UInt8]())
  188. let response = try await interceptor.intercept(
  189. request: .init(single: single),
  190. context: .init(descriptor: methodDescriptor)
  191. ) { _, _ in
  192. { [serviceContext = ServiceContext.current] in
  193. return ServerResponse.Stream<String>(
  194. accepted: .success(
  195. .init(
  196. metadata: [],
  197. producer: { writer in
  198. guard let serviceContext else {
  199. XCTFail("There should be a service context present.")
  200. return ["Result": "Test failed"]
  201. }
  202. let traceID = serviceContext.traceID
  203. XCTAssertEqual("some-trace-id", traceID)
  204. try await writer.write("response1")
  205. try await writer.write("response2")
  206. return ["Result": "Trailing metadata"]
  207. }
  208. )
  209. )
  210. )
  211. }()
  212. }
  213. let responseContents = try response.accepted.get()
  214. let trailingMetadata = try await responseContents.producer(
  215. RPCWriter(wrapping: TestWriter(streamContinuation: continuation))
  216. )
  217. continuation.finish()
  218. XCTAssertEqual(trailingMetadata, ["Result": "Trailing metadata"])
  219. var streamIterator = stream.makeAsyncIterator()
  220. var element = await streamIterator.next()
  221. XCTAssertEqual(element, "response1")
  222. element = await streamIterator.next()
  223. XCTAssertEqual(element, "response2")
  224. element = await streamIterator.next()
  225. XCTAssertNil(element)
  226. let tracer = InstrumentationSystem.tracer as! TestTracer
  227. XCTAssertEqual(
  228. tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map {
  229. $0.name
  230. },
  231. [
  232. "Received request start",
  233. "Received request end",
  234. "Sent response end",
  235. ]
  236. )
  237. }
  238. func testServerInterceptorAllEventsRecorded() async throws {
  239. let methodDescriptor = MethodDescriptor(
  240. service: "TracingInterceptorTests",
  241. method: "testServerInterceptorAllEventsRecorded"
  242. )
  243. let (stream, continuation) = AsyncStream<String>.makeStream()
  244. let interceptor = ServerTracingInterceptor(emitEventOnEachWrite: true)
  245. let single = ServerRequest.Single(metadata: ["trace-id": "some-trace-id"], message: [UInt8]())
  246. let response = try await interceptor.intercept(
  247. request: .init(single: single),
  248. context: .init(descriptor: methodDescriptor)
  249. ) { _, _ in
  250. { [serviceContext = ServiceContext.current] in
  251. return ServerResponse.Stream<String>(
  252. accepted: .success(
  253. .init(
  254. metadata: [],
  255. producer: { writer in
  256. guard let serviceContext else {
  257. XCTFail("There should be a service context present.")
  258. return ["Result": "Test failed"]
  259. }
  260. let traceID = serviceContext.traceID
  261. XCTAssertEqual("some-trace-id", traceID)
  262. try await writer.write("response1")
  263. try await writer.write("response2")
  264. return ["Result": "Trailing metadata"]
  265. }
  266. )
  267. )
  268. )
  269. }()
  270. }
  271. let responseContents = try response.accepted.get()
  272. let trailingMetadata = try await responseContents.producer(
  273. RPCWriter(wrapping: TestWriter(streamContinuation: continuation))
  274. )
  275. continuation.finish()
  276. XCTAssertEqual(trailingMetadata, ["Result": "Trailing metadata"])
  277. var streamIterator = stream.makeAsyncIterator()
  278. var element = await streamIterator.next()
  279. XCTAssertEqual(element, "response1")
  280. element = await streamIterator.next()
  281. XCTAssertEqual(element, "response2")
  282. element = await streamIterator.next()
  283. XCTAssertNil(element)
  284. let tracer = InstrumentationSystem.tracer as! TestTracer
  285. XCTAssertEqual(
  286. tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map {
  287. $0.name
  288. },
  289. [
  290. "Received request start",
  291. "Received request end",
  292. // Recorded when `response1` is sent
  293. "Sending response part",
  294. "Sent response part",
  295. // Recorded when `response2` is sent
  296. "Sending response part",
  297. "Sent response part",
  298. // Recorded when we're done sending response
  299. "Sent response end",
  300. ]
  301. )
  302. }
  303. }