ClientInterceptorPipelineTests.swift 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. /*
  2. * Copyright 2020, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import Logging
  17. import NIOCore
  18. import NIOEmbedded
  19. import NIOHPACK
  20. import XCTest
  21. @testable import GRPC
  class ClientInterceptorPipelineTests: GRPCTestCase {
    /// The event loop every pipeline in this test case runs on. It is recreated in `setUp()` so
    /// each test starts with a fresh loop (and a fresh clock for `advanceTime(to:)`).
    private var embeddedEventLoop: EmbeddedEventLoop!

    override func setUp() {
      super.setUp()
      self.embeddedEventLoop = EmbeddedEventLoop()
    }

    /// Builds a `ClientInterceptorPipeline` running on `embeddedEventLoop`.
    ///
    /// - Parameters:
    ///   - requests: The request type; unused at runtime, only helps generic inference.
    ///   - responses: The response type; unused at runtime, only helps generic inference.
    ///   - details: Call details for the RPC; defaults to `makeCallDetails()`.
    ///   - interceptors: Interceptors to install in the pipeline.
    ///   - errorDelegate: Optional delegate handed to the pipeline.
    ///   - onError: Passed through as the pipeline's `onError` callback.
    ///   - onCancel: Passed through as the pipeline's `onCancel` callback.
    ///   - onRequestPart: Passed through as the pipeline's `onRequestPart` callback.
    ///   - onResponsePart: Passed through as the pipeline's `onResponsePart` callback.
    /// - Returns: The pipeline, logging via the logger from the call details' options.
    private func makePipeline<Request, Response>(
      requests: Request.Type = Request.self,
      responses: Response.Type = Response.self,
      details: CallDetails? = nil,
      interceptors: [ClientInterceptor<Request, Response>] = [],
      errorDelegate: ClientErrorDelegate? = nil,
      onError: @escaping (Error) -> Void = { _ in },
      onCancel: @escaping (EventLoopPromise<Void>?) -> Void = { _ in },
      onRequestPart: @escaping (GRPCClientRequestPart<Request>, EventLoopPromise<Void>?) -> Void,
      onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
    ) -> ClientInterceptorPipeline<Request, Response> {
      let callDetails = details ?? self.makeCallDetails()
      return ClientInterceptorPipeline(
        eventLoop: self.embeddedEventLoop,
        details: callDetails,
        logger: callDetails.options.logger,
        interceptors: interceptors,
        errorDelegate: errorDelegate,
        onError: onError,
        onCancel: onCancel,
        onRequestPart: onRequestPart,
        onResponsePart: onResponsePart
      )
    }

    /// Unary call details with placeholder path/authority/scheme and the given time limit.
    private func makeCallDetails(timeLimit: TimeLimit = .none) -> CallDetails {
      return CallDetails(
        type: .unary,
        path: "ignored",
        authority: "ignored",
        scheme: "ignored",
        options: CallOptions(timeLimit: timeLimit, logger: self.clientLogger)
      )
    }

    func testEmptyPipeline() throws {
      var requestParts: [GRPCClientRequestPart<String>] = []
      var responseParts: [GRPCClientResponsePart<String>] = []
      let pipeline = self.makePipeline(
        requests: String.self,
        responses: String.self,
        onRequestPart: { request, promise in
          requestParts.append(request)
          XCTAssertNil(promise)
        },
        onResponsePart: { responseParts.append($0) }
      )

      // Write some request parts.
      pipeline.send(.metadata([:]), promise: nil)
      pipeline.send(.message("foo", .init(compress: false, flush: false)), promise: nil)
      pipeline.send(.end, promise: nil)

      XCTAssertEqual(requestParts.count, 3)
      // Stop here rather than trapping on an out-of-bounds index (which would abort the run).
      guard requestParts.count == 3 else { return }
      XCTAssertEqual(requestParts[0].metadata, [:])
      let (message, metadata) = try assertNotNil(requestParts[1].message)
      XCTAssertEqual(message, "foo")
      XCTAssertEqual(metadata, .init(compress: false, flush: false))
      XCTAssertTrue(requestParts[2].isEnd)

      // Write some responses parts.
      pipeline.receive(.metadata([:]))
      pipeline.receive(.message("bar"))
      pipeline.receive(.end(.ok, [:]))

      XCTAssertEqual(responseParts.count, 3)
      guard responseParts.count == 3 else { return }
      XCTAssertEqual(responseParts[0].metadata, [:])
      XCTAssertEqual(responseParts[1].message, "bar")
      let (status, trailers) = try assertNotNil(responseParts[2].end)
      XCTAssertEqual(status, .ok)
      XCTAssertEqual(trailers, [:])
    }

    func testPipelineWhenClosed() throws {
      let pipeline = self.makePipeline(
        requests: String.self,
        responses: String.self,
        onRequestPart: { _, promise in
          XCTAssertNil(promise)
        },
        onResponsePart: { _ in
          // Nothing should reach the application once the pipeline is closed.
          XCTFail("Unexpected response part")
        }
      )

      // Fire an error; this should close the pipeline.
      struct DummyError: Error {}
      pipeline.errorCaught(DummyError())

      // We're closed, writes should fail.
      let writePromise = pipeline.eventLoop.makePromise(of: Void.self)
      pipeline.send(.end, promise: writePromise)
      XCTAssertThrowsError(try writePromise.futureResult.wait())

      // As should cancellation.
      let cancelPromise = pipeline.eventLoop.makePromise(of: Void.self)
      pipeline.cancel(promise: cancelPromise)
      XCTAssertThrowsError(try cancelPromise.futureResult.wait())

      // And reads should be ignored. (We'll fail in `onResponsePart` if this goes through.)
      pipeline.receive(.metadata([:]))
    }

    func testPipelineWithTimeout() throws {
      var cancelled = false
      var timedOut = false

      /// Fails the test if a cancellation ever reaches an interceptor.
      class FailOnCancel<Request, Response>: ClientInterceptor<Request, Response>,
        @unchecked Sendable
      {
        override func cancel(
          promise: EventLoopPromise<Void>?,
          context: ClientInterceptorContext<Request, Response>
        ) {
          XCTFail("Unexpected cancellation")
          context.cancel(promise: promise)
        }
      }

      let deadline = NIODeadline.uptimeNanoseconds(100)
      let pipeline = self.makePipeline(
        requests: String.self,
        responses: String.self,
        details: self.makeCallDetails(timeLimit: .deadline(deadline)),
        interceptors: [FailOnCancel()],
        onError: { error in
          assertThat(error, .is(.instanceOf(GRPCError.RPCTimedOut.self)))
          assertThat(timedOut, .is(false))
          timedOut = true
        },
        onCancel: { promise in
          assertThat(cancelled, .is(false))
          cancelled = true
          // We don't expect a promise: this cancellation is fired by the pipeline.
          assertThat(promise, .is(.none()))
        },
        onRequestPart: { _, _ in
          XCTFail("Unexpected request part")
        },
        onResponsePart: { _ in
          XCTFail("Unexpected response part")
        }
      )

      // Trigger the timeout.
      self.embeddedEventLoop.advanceTime(to: deadline)
      assertThat(timedOut, .is(true))

      // We'll receive a cancellation; we only get this 'onCancel' callback. We'll fail in the
      // interceptor if a cancellation is received.
      assertThat(cancelled, .is(true))

      // Pipeline should be torn down. Writes and cancellation should fail.
      let p1 = pipeline.eventLoop.makePromise(of: Void.self)
      pipeline.send(.end, promise: p1)
      assertThat(try p1.futureResult.wait(), .throws(.instanceOf(GRPCError.AlreadyComplete.self)))

      let p2 = pipeline.eventLoop.makePromise(of: Void.self)
      pipeline.cancel(promise: p2)
      assertThat(try p2.futureResult.wait(), .throws(.instanceOf(GRPCError.AlreadyComplete.self)))

      // Reads should be ignored too. (We'll fail in `onResponsePart` if this goes through.)
      pipeline.receive(.metadata([:]))
    }

    func testTimeoutIsCancelledOnCompletion() throws {
      let deadline = NIODeadline.uptimeNanoseconds(100)
      var cancellations = 0

      let pipeline = self.makePipeline(
        requests: String.self,
        responses: String.self,
        details: self.makeCallDetails(timeLimit: .deadline(deadline)),
        onCancel: { promise in
          assertThat(cancellations, .is(0))
          cancellations += 1
          // We don't expect a promise: this cancellation is fired by the pipeline.
          assertThat(promise, .is(.none()))
        },
        onRequestPart: { _, _ in
          XCTFail("Unexpected request part")
        },
        onResponsePart: { part in
          // We only expect the end.
          assertThat(part.end, .is(.some()))
        }
      )

      // Read the end part.
      pipeline.receive(.end(.ok, [:]))
      // Just a single cancellation.
      assertThat(cancellations, .is(1))

      // Pass the deadline.
      self.embeddedEventLoop.advanceTime(to: deadline)
      // We should still have just the one cancellation.
      assertThat(cancellations, .is(1))
    }

    func testPipelineWithInterceptor() throws {
      // Checks ordering only: outbound parts flow through interceptors in array order, so the
      // reverser (first) must transform the message before the recorder (second) sees it.
      let recorder = RecordingInterceptor<String, String>()
      let pipeline = self.makePipeline(
        interceptors: [StringRequestReverser(), recorder],
        onRequestPart: { _, _ in },
        onResponsePart: { _ in }
      )

      pipeline.send(.message("foo", .init(compress: false, flush: false)), promise: nil)
      XCTAssertEqual(recorder.requestParts.count, 1)
      // `first?` avoids trapping if nothing was recorded; `assertNotNil` reports it instead.
      let (message, _) = try assertNotNil(recorder.requestParts.first?.message)
      XCTAssertEqual(message, "oof")
    }

    func testErrorDelegateIsCalled() throws {
      /// Checks the error (and, when provided, the file/line) it is given and counts its calls.
      /// Mutable state is acceptable here: the pipeline runs on an `EmbeddedEventLoop`, so
      /// every call happens on the test's own thread.
      final class Delegate: ClientErrorDelegate, @unchecked Sendable {
        let expectedError: GRPCError.InvalidState
        let file: StaticString?
        let line: Int?
        /// Number of times `didCatchError` has been invoked.
        private(set) var callCount = 0

        /// - Precondition: `file` and `line` are either both `nil` or both non-`nil`.
        init(
          expected: GRPCError.InvalidState,
          file: StaticString?,
          line: Int?
        ) {
          precondition((file == nil) == (line == nil))
          self.expectedError = expected
          self.file = file
          self.line = line
        }

        func didCatchError(_ error: Error, logger: Logger, file: StaticString, line: Int) {
          self.callCount += 1
          XCTAssertEqual(error as? GRPCError.InvalidState, self.expectedError)

          // Check the file and line, if expected.
          if let expectedFile = self.file, let expectedLine = self.line {
            XCTAssertEqual("\(file)", "\(expectedFile)")  // StaticString isn't Equatable
            XCTAssertEqual(line, expectedLine)
          }
        }
      }

      /// Feeds `error` into a fresh pipeline and checks `delegate` was consulted exactly once.
      func doTest(withDelegate delegate: Delegate, error: Error) {
        let pipeline = self.makePipeline(
          requests: String.self,
          responses: String.self,
          errorDelegate: delegate,
          onRequestPart: { _, _ in },
          onResponsePart: { _ in }
        )
        pipeline.errorCaught(error)
        // Without this, the test would pass even if the delegate were never called.
        XCTAssertEqual(delegate.callCount, 1)
      }

      let invalidState = GRPCError.InvalidState("invalid state")
      let withContext = GRPCError.WithContext(invalidState)

      doTest(
        withDelegate: .init(expected: invalidState, file: withContext.file, line: withContext.line),
        error: withContext
      )

      doTest(
        withDelegate: .init(expected: invalidState, file: nil, line: nil),
        error: invalidState
      )
    }
  }
  261. // MARK: - Test Interceptors
  /// A pass-through interceptor that appends every request and response part it observes to
  /// `requestParts` / `responseParts`, then forwards that same part, unchanged, via the context.
  class RecordingInterceptor<Request, Response>:
    ClientInterceptor<Request, Response>,
    @unchecked Sendable
  {
    /// Request parts seen by `send(_:promise:context:)`, in arrival order.
    var requestParts: [GRPCClientRequestPart<Request>] = []
    /// Response parts seen by `receive(_:context:)`, in arrival order.
    var responseParts: [GRPCClientResponsePart<Response>] = []

    override func send(
      _ outbound: GRPCClientRequestPart<Request>,
      promise: EventLoopPromise<Void>?,
      context: ClientInterceptorContext<Request, Response>
    ) {
      self.requestParts += [outbound]
      context.send(outbound, promise: promise)
    }

    override func receive(
      _ inbound: GRPCClientResponsePart<Response>,
      context: ClientInterceptorContext<Request, Response>
    ) {
      self.responseParts += [inbound]
      context.receive(inbound)
    }
  }
  /// Reverses the characters of every outbound `.message` part; all other request parts are
  /// forwarded as-is.
  class StringRequestReverser: ClientInterceptor<String, String>, @unchecked Sendable {
    override func send(
      _ part: GRPCClientRequestPart<String>,
      promise: EventLoopPromise<Void>?,
      context: ClientInterceptorContext<String, String>
    ) {
      let forwarded: GRPCClientRequestPart<String>
      if case .message(let text, let messageMetadata) = part {
        forwarded = .message(String(text.reversed()), messageMetadata)
      } else {
        forwarded = part
      }
      context.send(forwarded, promise: promise)
    }
  }
  299. // MARK: - Request/Response part helpers
  extension GRPCClientRequestPart {
    /// The headers if this is a `.metadata` part; `nil` otherwise.
    var metadata: HPACKHeaders? {
      switch self {
      case .message, .end:
        return nil
      case .metadata(let headers):
        return headers
      }
    }

    /// The request and its message metadata if this is a `.message` part; `nil` otherwise.
    var message: (Request, MessageMetadata)? {
      switch self {
      case .metadata, .end:
        return nil
      case .message(let request, let messageMetadata):
        return (request, messageMetadata)
      }
    }

    /// `true` only for the `.end` part.
    var isEnd: Bool {
      switch self {
      case .metadata, .message:
        return false
      case .end:
        return true
      }
    }
  }
  extension GRPCClientResponsePart {
    /// The headers if this is a `.metadata` part; `nil` otherwise.
    var metadata: HPACKHeaders? {
      switch self {
      case .message, .end:
        return nil
      case .metadata(let headers):
        return headers
      }
    }

    /// The response if this is a `.message` part; `nil` otherwise.
    var message: Response? {
      switch self {
      case .metadata, .end:
        return nil
      case .message(let response):
        return response
      }
    }

    /// The status and trailers if this is the `.end` part; `nil` otherwise.
    var end: (GRPCStatus, HPACKHeaders)? {
      switch self {
      case .metadata, .message:
        return nil
      case .end(let status, let trailers):
        return (status, trailers)
      }
    }
  }