ClientInterceptorPipelineTests.swift 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. /*
  2. * Copyright 2020, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. @testable import GRPC
  17. import Logging
  18. import NIO
  19. import NIOHPACK
  20. import XCTest
  21. class ClientInterceptorPipelineTests: GRPCTestCase {
  22. override func setUp() {
  23. super.setUp()
  24. self.embeddedEventLoop = EmbeddedEventLoop()
  25. }
  26. private var embeddedEventLoop: EmbeddedEventLoop!
  27. private func makePipeline<Request, Response>(
  28. requests: Request.Type = Request.self,
  29. responses: Response.Type = Response.self,
  30. details: CallDetails? = nil,
  31. interceptors: [ClientInterceptor<Request, Response>] = [],
  32. errorDelegate: ClientErrorDelegate? = nil,
  33. onCancel: @escaping (EventLoopPromise<Void>?) -> Void = { _ in },
  34. onRequestPart: @escaping (ClientRequestPart<Request>, EventLoopPromise<Void>?) -> Void,
  35. onResponsePart: @escaping (ClientResponsePart<Response>) -> Void
  36. ) -> ClientInterceptorPipeline<Request, Response> {
  37. return ClientInterceptorPipeline(
  38. eventLoop: self.embeddedEventLoop,
  39. details: details ?? self.makeCallDetails(),
  40. interceptors: interceptors,
  41. errorDelegate: errorDelegate,
  42. onCancel: onCancel,
  43. onRequestPart: onRequestPart,
  44. onResponsePart: onResponsePart
  45. )
  46. }
  47. private func makeCallDetails(timeLimit: TimeLimit = .none) -> CallDetails {
  48. return CallDetails(
  49. type: .unary,
  50. path: "ignored",
  51. authority: "ignored",
  52. scheme: "ignored",
  53. options: CallOptions(timeLimit: timeLimit, logger: self.clientLogger)
  54. )
  55. }
  56. func testEmptyPipeline() throws {
  57. var requestParts: [ClientRequestPart<String>] = []
  58. var responseParts: [ClientResponsePart<String>] = []
  59. let pipeline = self.makePipeline(
  60. requests: String.self,
  61. responses: String.self,
  62. onRequestPart: { request, promise in
  63. requestParts.append(request)
  64. XCTAssertNil(promise)
  65. },
  66. onResponsePart: { responseParts.append($0) }
  67. )
  68. // Write some request parts.
  69. pipeline.send(.metadata([:]), promise: nil)
  70. pipeline.send(.message("foo", .init(compress: false, flush: false)), promise: nil)
  71. pipeline.send(.end, promise: nil)
  72. XCTAssertEqual(requestParts.count, 3)
  73. XCTAssertEqual(requestParts[0].metadata, [:])
  74. let (message, metadata) = try assertNotNil(requestParts[1].message)
  75. XCTAssertEqual(message, "foo")
  76. XCTAssertEqual(metadata, .init(compress: false, flush: false))
  77. XCTAssertTrue(requestParts[2].isEnd)
  78. // Write some responses parts.
  79. pipeline.receive(.metadata([:]))
  80. pipeline.receive(.message("bar"))
  81. pipeline.receive(.end(.ok, [:]))
  82. XCTAssertEqual(responseParts.count, 3)
  83. XCTAssertEqual(responseParts[0].metadata, [:])
  84. XCTAssertEqual(responseParts[1].message, "bar")
  85. let (status, trailers) = try assertNotNil(responseParts[2].end)
  86. XCTAssertEqual(status, .ok)
  87. XCTAssertEqual(trailers, [:])
  88. }
  89. func testPipelineWhenClosed() throws {
  90. let pipeline = self.makePipeline(
  91. requests: String.self,
  92. responses: String.self,
  93. onRequestPart: { _, promise in
  94. XCTAssertNil(promise)
  95. },
  96. onResponsePart: {
  97. XCTAssertNotNil($0.error)
  98. }
  99. )
  100. // Fire an error; this should close the pipeline.
  101. struct DummyError: Error {}
  102. pipeline.receive(.error(DummyError()))
  103. // We're closed, writes should fail.
  104. let writePromise = pipeline.eventLoop.makePromise(of: Void.self)
  105. pipeline.send(.end, promise: writePromise)
  106. XCTAssertThrowsError(try writePromise.futureResult.wait())
  107. // As should cancellation.
  108. let cancelPromise = pipeline.eventLoop.makePromise(of: Void.self)
  109. pipeline.cancel(promise: cancelPromise)
  110. XCTAssertThrowsError(try cancelPromise.futureResult.wait())
  111. // And reads should be ignored. (We only expect errors in the response handler.)
  112. pipeline.receive(.metadata([:]))
  113. }
  114. func testPipelineWithTimeout() throws {
  115. var cancelled = false
  116. var timedOut = false
  117. class FailOnCancel<Request, Response>: ClientInterceptor<Request, Response> {
  118. override func cancel(
  119. promise: EventLoopPromise<Void>?,
  120. context: ClientInterceptorContext<Request, Response>
  121. ) {
  122. XCTFail("Unexpected cancellation")
  123. context.cancel(promise: promise)
  124. }
  125. }
  126. let deadline = NIODeadline.uptimeNanoseconds(100)
  127. let pipeline = self.makePipeline(
  128. requests: String.self,
  129. responses: String.self,
  130. details: self.makeCallDetails(timeLimit: .deadline(deadline)),
  131. interceptors: [FailOnCancel()],
  132. onCancel: { promise in
  133. assertThat(cancelled, .is(false))
  134. cancelled = true
  135. // We don't expect a promise: this cancellation is fired by the pipeline.
  136. assertThat(promise, .is(.nil()))
  137. },
  138. onRequestPart: { _, _ in
  139. XCTFail("Unexpected request part")
  140. },
  141. onResponsePart: { part in
  142. assertThat(part.error, .is(.instanceOf(GRPCError.RPCTimedOut.self)))
  143. assertThat(timedOut, .is(false))
  144. timedOut = true
  145. }
  146. )
  147. // Trigger the timeout.
  148. self.embeddedEventLoop.advanceTime(to: deadline)
  149. assertThat(timedOut, .is(true))
  150. // We'll receive a cancellation; we only get this 'onCancel' callback. We'll fail in the
  151. // interceptor if a cancellation is received.
  152. assertThat(cancelled, .is(true))
  153. // Pipeline should be torn down. Writes and cancellation should fail.
  154. let p1 = pipeline.eventLoop.makePromise(of: Void.self)
  155. pipeline.send(.end, promise: p1)
  156. assertThat(try p1.futureResult.wait(), .throws(.instanceOf(GRPCError.AlreadyComplete.self)))
  157. let p2 = pipeline.eventLoop.makePromise(of: Void.self)
  158. pipeline.cancel(promise: p2)
  159. assertThat(try p2.futureResult.wait(), .throws(.instanceOf(GRPCError.AlreadyComplete.self)))
  160. // Reads should be ignored too. (We'll fail in `onRequestPart` if this goes through.)
  161. pipeline.receive(.metadata([:]))
  162. }
  163. func testTimeoutIsCancelledOnCompletion() throws {
  164. let deadline = NIODeadline.uptimeNanoseconds(100)
  165. var cancellations = 0
  166. let pipeline = self.makePipeline(
  167. requests: String.self,
  168. responses: String.self,
  169. details: self.makeCallDetails(timeLimit: .deadline(deadline)),
  170. onCancel: { promise in
  171. assertThat(cancellations, .is(0))
  172. cancellations += 1
  173. // We don't expect a promise: this cancellation is fired by the pipeline.
  174. assertThat(promise, .is(.nil()))
  175. },
  176. onRequestPart: { _, _ in
  177. XCTFail("Unexpected request part")
  178. },
  179. onResponsePart: { part in
  180. // We only expect the end.
  181. assertThat(part.end, .is(.notNil()))
  182. }
  183. )
  184. // Read the end part.
  185. pipeline.receive(.end(.ok, [:]))
  186. // Just a single cancellation.
  187. assertThat(cancellations, .is(1))
  188. // Pass the deadline.
  189. self.embeddedEventLoop.advanceTime(to: deadline)
  190. // We should still have just the one cancellation.
  191. assertThat(cancellations, .is(1))
  192. }
  193. func testPipelineWithInterceptor() throws {
  194. // We're not testing much here, just that the interceptors are in the right order, from outbound
  195. // to inbound.
  196. let recorder = RecordingInterceptor<String, String>()
  197. let pipeline = self.makePipeline(
  198. interceptors: [StringRequestReverser(), recorder],
  199. onRequestPart: { _, _ in },
  200. onResponsePart: { _ in }
  201. )
  202. pipeline.send(.message("foo", .init(compress: false, flush: false)), promise: nil)
  203. XCTAssertEqual(recorder.requestParts.count, 1)
  204. let (message, _) = try assertNotNil(recorder.requestParts[0].message)
  205. XCTAssertEqual(message, "oof")
  206. }
  207. func testErrorDelegateIsCalled() throws {
  208. class Delegate: ClientErrorDelegate {
  209. let expectedError: GRPCError.InvalidState
  210. let file: StaticString?
  211. let line: Int?
  212. init(
  213. expected: GRPCError.InvalidState,
  214. file: StaticString?,
  215. line: Int?
  216. ) {
  217. precondition(file == nil && line == nil || file != nil && line != nil)
  218. self.expectedError = expected
  219. self.file = file
  220. self.line = line
  221. }
  222. func didCatchError(_ error: Error, logger: Logger, file: StaticString, line: Int) {
  223. XCTAssertEqual(error as? GRPCError.InvalidState, self.expectedError)
  224. // Check the file and line, if expected.
  225. if let expectedFile = self.file, let expectedLine = self.line {
  226. XCTAssertEqual("\(file)", "\(expectedFile)") // StaticString isn't Equatable
  227. XCTAssertEqual(line, expectedLine)
  228. }
  229. }
  230. }
  231. func doTest(withDelegate delegate: Delegate, error: Error) {
  232. let pipeline = self.makePipeline(
  233. requests: String.self,
  234. responses: String.self,
  235. errorDelegate: delegate,
  236. onRequestPart: { _, _ in },
  237. onResponsePart: { _ in }
  238. )
  239. pipeline.receive(.error(error))
  240. }
  241. let invalidState = GRPCError.InvalidState("invalid state")
  242. let withContext = GRPCError.WithContext(invalidState)
  243. doTest(
  244. withDelegate: .init(expected: invalidState, file: withContext.file, line: withContext.line),
  245. error: withContext
  246. )
  247. doTest(
  248. withDelegate: .init(expected: invalidState, file: nil, line: nil),
  249. error: invalidState
  250. )
  251. }
  252. }
  253. // MARK: - Test Interceptors
  254. /// A simple interceptor which records and then forwards and request and response parts it sees.
  255. class RecordingInterceptor<Request, Response>: ClientInterceptor<Request, Response> {
  256. var requestParts: [ClientRequestPart<Request>] = []
  257. var responseParts: [ClientResponsePart<Response>] = []
  258. override func send(
  259. _ part: ClientRequestPart<Request>,
  260. promise: EventLoopPromise<Void>?,
  261. context: ClientInterceptorContext<Request, Response>
  262. ) {
  263. self.requestParts.append(part)
  264. context.send(part, promise: promise)
  265. }
  266. override func receive(
  267. _ part: ClientResponsePart<Response>,
  268. context: ClientInterceptorContext<Request, Response>
  269. ) {
  270. self.responseParts.append(part)
  271. context.receive(part)
  272. }
  273. }
  274. /// An interceptor which reverses string request messages.
  275. class StringRequestReverser: ClientInterceptor<String, String> {
  276. override func send(
  277. _ part: ClientRequestPart<String>,
  278. promise: EventLoopPromise<Void>?,
  279. context: ClientInterceptorContext<String, String>
  280. ) {
  281. switch part {
  282. case let .message(value, metadata):
  283. context.send(.message(String(value.reversed()), metadata), promise: promise)
  284. default:
  285. context.send(part, promise: promise)
  286. }
  287. }
  288. }
  289. // MARK: - Request/Response part helpers
  290. extension ClientRequestPart {
  291. var metadata: HPACKHeaders? {
  292. switch self {
  293. case let .metadata(headers):
  294. return headers
  295. case .message, .end:
  296. return nil
  297. }
  298. }
  299. var message: (Request, Metadata)? {
  300. switch self {
  301. case let .message(request, metadata):
  302. return (request, metadata)
  303. case .metadata, .end:
  304. return nil
  305. }
  306. }
  307. var isEnd: Bool {
  308. switch self {
  309. case .end:
  310. return true
  311. case .metadata, .message:
  312. return false
  313. }
  314. }
  315. }
  316. extension ClientResponsePart {
  317. var metadata: HPACKHeaders? {
  318. switch self {
  319. case let .metadata(headers):
  320. return headers
  321. case .message, .end, .error:
  322. return nil
  323. }
  324. }
  325. var message: Response? {
  326. switch self {
  327. case let .message(response):
  328. return response
  329. case .metadata, .end, .error:
  330. return nil
  331. }
  332. }
  333. var end: (GRPCStatus, HPACKHeaders)? {
  334. switch self {
  335. case let .end(status, trailers):
  336. return (status, trailers)
  337. case .metadata, .message, .error:
  338. return nil
  339. }
  340. }
  341. var error: Error? {
  342. switch self {
  343. case let .error(error):
  344. return error
  345. case .metadata, .message, .end:
  346. return nil
  347. }
  348. }
  349. }