ClientInterceptorPipelineTests.swift 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. /*
  2. * Copyright 2020, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. @testable import GRPC
  17. import Logging
  18. import NIO
  19. import NIOHPACK
  20. import XCTest
  /// Tests for `ClientInterceptorPipeline`, with each pipeline running on an `EmbeddedEventLoop`.
  class ClientInterceptorPipelineTests: GRPCTestCase {
    /// The event loop handed to every pipeline under test; replaced with a new loop in `setUp()`.
    private var embeddedEventLoop: EmbeddedEventLoop!

    override func setUp() {
      super.setUp()
      self.embeddedEventLoop = EmbeddedEventLoop()
    }

    /// Creates a pipeline on `embeddedEventLoop`, passing the arguments straight through to the
    /// `ClientInterceptorPipeline` initializer.
    ///
    /// - Parameters:
    ///   - requests: The request type; not read, it only pins the `Request` generic parameter.
    ///   - responses: The response type; not read, it only pins the `Response` generic parameter.
    ///   - details: Call details for the pipeline. When `nil`, `makeCallDetails()` supplies them.
    ///   - interceptors: Interceptors to install in the pipeline.
    ///   - errorDelegate: Optional error delegate given to the pipeline.
    ///   - onError: Callback given to the pipeline as `onError`; does nothing by default.
    ///   - onCancel: Callback given to the pipeline as `onCancel`; does nothing by default.
    ///   - onRequestPart: Callback given to the pipeline as `onRequestPart`.
    ///   - onResponsePart: Callback given to the pipeline as `onResponsePart`.
    /// - Returns: The newly created pipeline.
    private func makePipeline<Request, Response>(
      requests: Request.Type = Request.self,
      responses: Response.Type = Response.self,
      details: CallDetails? = nil,
      interceptors: [ClientInterceptor<Request, Response>] = [],
      errorDelegate: ClientErrorDelegate? = nil,
      onError: @escaping (Error) -> Void = { _ in },
      onCancel: @escaping (EventLoopPromise<Void>?) -> Void = { _ in },
      onRequestPart: @escaping (GRPCClientRequestPart<Request>, EventLoopPromise<Void>?) -> Void,
      onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
    ) -> ClientInterceptorPipeline<Request, Response> {
      // Only build default details when the caller didn't provide any.
      let resolvedDetails: CallDetails
      if let details = details {
        resolvedDetails = details
      } else {
        resolvedDetails = self.makeCallDetails()
      }

      return ClientInterceptorPipeline(
        eventLoop: self.embeddedEventLoop,
        details: resolvedDetails,
        logger: resolvedDetails.options.logger.wrapped,
        interceptors: interceptors,
        errorDelegate: errorDelegate,
        onError: onError,
        onCancel: onCancel,
        onRequestPart: onRequestPart,
        onResponsePart: onResponsePart
      )
    }

    /// Creates unary call details with placeholder path, authority and scheme.
    ///
    /// - Parameter timeLimit: The time limit to put in the call options; `.none` by default.
    /// - Returns: Call details whose options carry `timeLimit` and the test case's client logger.
    private func makeCallDetails(timeLimit: TimeLimit = .none) -> CallDetails {
      let options = CallOptions(timeLimit: timeLimit, logger: self.clientLogger)
      return CallDetails(
        type: .unary,
        path: "ignored",
        authority: "ignored",
        scheme: "ignored",
        options: options
      )
    }

    /// With no interceptors installed, parts sent into the pipeline arrive at `onRequestPart` and
    /// parts received by the pipeline arrive at `onResponsePart`, unchanged and in order.
    func testEmptyPipeline() throws {
      var sentParts: [GRPCClientRequestPart<String>] = []
      var receivedParts: [GRPCClientResponsePart<String>] = []

      let pipeline = self.makePipeline(
        requests: String.self,
        responses: String.self,
        onRequestPart: { part, promise in
          sentParts.append(part)
          XCTAssertNil(promise)
        },
        onResponsePart: { part in
          receivedParts.append(part)
        }
      )

      // Write some request parts.
      let messageMetadata = MessageMetadata(compress: false, flush: false)
      pipeline.send(.metadata([:]), promise: nil)
      pipeline.send(.message("foo", messageMetadata), promise: nil)
      pipeline.send(.end, promise: nil)

      XCTAssertEqual(sentParts.count, 3)
      XCTAssertEqual(sentParts[0].metadata, [:])
      let (sentMessage, sentMetadata) = try assertNotNil(sentParts[1].message)
      XCTAssertEqual(sentMessage, "foo")
      XCTAssertEqual(sentMetadata, messageMetadata)
      XCTAssertTrue(sentParts[2].isEnd)

      // Read some response parts.
      pipeline.receive(.metadata([:]))
      pipeline.receive(.message("bar"))
      pipeline.receive(.end(.ok, [:]))

      XCTAssertEqual(receivedParts.count, 3)
      XCTAssertEqual(receivedParts[0].metadata, [:])
      XCTAssertEqual(receivedParts[1].message, "bar")
      let (finalStatus, finalTrailers) = try assertNotNil(receivedParts[2].end)
      XCTAssertEqual(finalStatus, .ok)
      XCTAssertEqual(finalTrailers, [:])
    }

    /// After an error has been fired into the pipeline, sends and cancellations fail their
    /// promises.
    func testPipelineWhenClosed() throws {
      struct DummyError: Error {}

      let pipeline = self.makePipeline(
        requests: String.self,
        responses: String.self,
        onRequestPart: { _, promise in
          XCTAssertNil(promise)
        },
        onResponsePart: { _ in }
      )

      // Fire an error; this should close the pipeline.
      pipeline.errorCaught(DummyError())

      // The pipeline is closed now, so a write should fail its promise.
      let sendPromise = pipeline.eventLoop.makePromise(of: Void.self)
      pipeline.send(.end, promise: sendPromise)
      XCTAssertThrowsError(try sendPromise.futureResult.wait())

      // So should a cancellation.
      let cancellationPromise = pipeline.eventLoop.makePromise(of: Void.self)
      pipeline.cancel(promise: cancellationPromise)
      XCTAssertThrowsError(try cancellationPromise.futureResult.wait())

      // Reads should be ignored. The response handler above is a no-op, so nothing is asserted
      // about this call beyond it being made.
      pipeline.receive(.metadata([:]))
    }

    /// When the call's deadline passes, `onError` sees a `GRPCError.RPCTimedOut` once, `onCancel`
    /// fires once without a promise, and later sends and cancellations fail with
    /// `GRPCError.AlreadyComplete`.
    func testPipelineWithTimeout() throws {
      /// Fails the test if a cancellation passes through the interceptor.
      class FailOnCancel<Request, Response>: ClientInterceptor<Request, Response> {
        override func cancel(
          promise: EventLoopPromise<Void>?,
          context: ClientInterceptorContext<Request, Response>
        ) {
          XCTFail("Unexpected cancellation")
          context.cancel(promise: promise)
        }
      }

      var didCancel = false
      var didTimeOut = false
      let deadline = NIODeadline.uptimeNanoseconds(100)

      let pipeline = self.makePipeline(
        requests: String.self,
        responses: String.self,
        details: self.makeCallDetails(timeLimit: .deadline(deadline)),
        interceptors: [FailOnCancel()],
        onError: { error in
          assertThat(error, .is(.instanceOf(GRPCError.RPCTimedOut.self)))
          assertThat(didTimeOut, .is(false))
          didTimeOut = true
        },
        onCancel: { promise in
          assertThat(didCancel, .is(false))
          didCancel = true
          // We don't expect a promise: this cancellation is fired by the pipeline.
          assertThat(promise, .is(.nil()))
        },
        onRequestPart: { _, _ in
          XCTFail("Unexpected request part")
        },
        onResponsePart: { _ in
          XCTFail("Unexpected response part")
        }
      )

      // Trigger the timeout.
      self.embeddedEventLoop.advanceTime(to: deadline)
      assertThat(didTimeOut, .is(true))

      // The cancellation is only expected to reach the 'onCancel' callback; `FailOnCancel` fails
      // the test if it is routed through the interceptor.
      assertThat(didCancel, .is(true))

      // Pipeline should be torn down. Writes and cancellation should fail.
      let sendPromise = pipeline.eventLoop.makePromise(of: Void.self)
      pipeline.send(.end, promise: sendPromise)
      assertThat(
        try sendPromise.futureResult.wait(),
        .throws(.instanceOf(GRPCError.AlreadyComplete.self))
      )

      let cancellationPromise = pipeline.eventLoop.makePromise(of: Void.self)
      pipeline.cancel(promise: cancellationPromise)
      assertThat(
        try cancellationPromise.futureResult.wait(),
        .throws(.instanceOf(GRPCError.AlreadyComplete.self))
      )

      // Reads should be ignored too. (We'll fail in `onResponsePart` if this goes through.)
      pipeline.receive(.metadata([:]))
    }

    /// Receiving the end of the response stream produces exactly one cancellation, and advancing
    /// the event loop to the deadline afterwards does not produce another.
    func testTimeoutIsCancelledOnCompletion() throws {
      var cancelCount = 0
      let deadline = NIODeadline.uptimeNanoseconds(100)

      let pipeline = self.makePipeline(
        requests: String.self,
        responses: String.self,
        details: self.makeCallDetails(timeLimit: .deadline(deadline)),
        onCancel: { promise in
          assertThat(cancelCount, .is(0))
          cancelCount += 1
          // We don't expect a promise: this cancellation is fired by the pipeline.
          assertThat(promise, .is(.nil()))
        },
        onRequestPart: { _, _ in
          XCTFail("Unexpected request part")
        },
        onResponsePart: { part in
          // The end part is the only one this test feeds in.
          assertThat(part.end, .is(.notNil()))
        }
      )

      // Read the end part.
      pipeline.receive(.end(.ok, [:]))
      // Just a single cancellation.
      assertThat(cancelCount, .is(1))

      // Pass the deadline.
      self.embeddedEventLoop.advanceTime(to: deadline)
      // We should still have just the one cancellation.
      assertThat(cancelCount, .is(1))
    }

    /// A message sent through `[StringRequestReverser(), recorder]` is already reversed by the
    /// time the recorder sees it.
    func testPipelineWithInterceptor() throws {
      // Not much is tested here: only that interceptors are applied in array order for outbound
      // parts, i.e. the reverser runs before the recorder.
      let recorder = RecordingInterceptor<String, String>()
      let pipeline = self.makePipeline(
        interceptors: [StringRequestReverser(), recorder],
        onRequestPart: { _, _ in },
        onResponsePart: { _ in }
      )

      let messageMetadata = MessageMetadata(compress: false, flush: false)
      pipeline.send(.message("foo", messageMetadata), promise: nil)

      XCTAssertEqual(recorder.requestParts.count, 1)
      let (recorded, _) = try assertNotNil(recorder.requestParts[0].message)
      XCTAssertEqual(recorded, "oof")
    }

    /// Errors fired into the pipeline reach the error delegate as the underlying
    /// `GRPCError.InvalidState`, both when wrapped in `GRPCError.WithContext` (in which case the
    /// wrapper's file and line are reported too) and when fired unwrapped.
    func testErrorDelegateIsCalled() throws {
      /// An error delegate which asserts on the error, and optionally the file and line, it is
      /// given.
      class Delegate: ClientErrorDelegate {
        let expectedError: GRPCError.InvalidState
        let file: StaticString?
        let line: Int?

        init(
          expected: GRPCError.InvalidState,
          file: StaticString?,
          line: Int?
        ) {
          // File and line must be provided together or not at all.
          precondition((file == nil) == (line == nil))
          self.expectedError = expected
          self.file = file
          self.line = line
        }

        func didCatchError(_ error: Error, logger: Logger, file: StaticString, line: Int) {
          XCTAssertEqual(error as? GRPCError.InvalidState, self.expectedError)

          // Only check the source location when one is expected.
          guard let expectedFile = self.file, let expectedLine = self.line else {
            return
          }
          XCTAssertEqual("\(file)", "\(expectedFile)") // StaticString isn't Equatable
          XCTAssertEqual(line, expectedLine)
        }
      }

      /// Builds a pipeline using `delegate` and fires `error` into it.
      func doTest(withDelegate delegate: Delegate, error: Error) {
        let pipeline = self.makePipeline(
          requests: String.self,
          responses: String.self,
          errorDelegate: delegate,
          onRequestPart: { _, _ in },
          onResponsePart: { _ in }
        )
        pipeline.errorCaught(error)
      }

      let invalidState = GRPCError.InvalidState("invalid state")
      let withContext = GRPCError.WithContext(invalidState)

      // Wrapped error: the delegate should be told the wrapper's file and line.
      let locationCheckingDelegate = Delegate(
        expected: invalidState,
        file: withContext.file,
        line: withContext.line
      )
      doTest(withDelegate: locationCheckingDelegate, error: withContext)

      // Bare error: no source location is checked.
      let errorOnlyDelegate = Delegate(expected: invalidState, file: nil, line: nil)
      doTest(withDelegate: errorOnlyDelegate, error: invalidState)
    }
  }
  258. // MARK: - Test Interceptors
  /// A pass-through interceptor which records every request and response part it sees and then
  /// forwards the part via its context.
  class RecordingInterceptor<Request, Response>: ClientInterceptor<Request, Response> {
    /// Request parts passed to `send`, in the order they were seen.
    var requestParts: [GRPCClientRequestPart<Request>] = []
    /// Response parts passed to `receive`, in the order they were seen.
    var responseParts: [GRPCClientResponsePart<Response>] = []

    override func send(
      _ part: GRPCClientRequestPart<Request>,
      promise: EventLoopPromise<Void>?,
      context: ClientInterceptorContext<Request, Response>
    ) {
      // Record first, then forward the part untouched.
      self.requestParts += [part]
      context.send(part, promise: promise)
    }

    override func receive(
      _ part: GRPCClientResponsePart<Response>,
      context: ClientInterceptorContext<Request, Response>
    ) {
      // Record first, then forward the part untouched.
      self.responseParts += [part]
      context.receive(part)
    }
  }
  /// An interceptor which reverses the characters of string request messages. Metadata and end
  /// parts are forwarded unchanged.
  class StringRequestReverser: ClientInterceptor<String, String> {
    override func send(
      _ part: GRPCClientRequestPart<String>,
      promise: EventLoopPromise<Void>?,
      context: ClientInterceptorContext<String, String>
    ) {
      switch part {
      case let .message(value, metadata):
        // Forward the reversed payload with its original message metadata.
        context.send(.message(String(value.reversed()), metadata), promise: promise)

      // Spelled out rather than `default:` so that a new request part case becomes a compile
      // error here instead of being silently forwarded.
      case .metadata, .end:
        context.send(part, promise: promise)
      }
    }
  }
  294. // MARK: - Request/Response part helpers
  /// Test-only accessors for pulling the payload out of a request part.
  extension GRPCClientRequestPart {
    /// The headers if this part is `.metadata`, `nil` otherwise.
    var metadata: HPACKHeaders? {
      guard case let .metadata(headers) = self else {
        return nil
      }
      return headers
    }

    /// The request and its message metadata if this part is `.message`, `nil` otherwise.
    var message: (Request, MessageMetadata)? {
      guard case let .message(payload, payloadMetadata) = self else {
        return nil
      }
      return (payload, payloadMetadata)
    }

    /// Whether this part is `.end`.
    var isEnd: Bool {
      if case .end = self {
        return true
      }
      return false
    }
  }
  /// Test-only accessors for pulling the payload out of a response part.
  extension GRPCClientResponsePart {
    /// The headers if this part is `.metadata`, `nil` otherwise.
    var metadata: HPACKHeaders? {
      guard case let .metadata(headers) = self else {
        return nil
      }
      return headers
    }

    /// The response if this part is `.message`, `nil` otherwise.
    var message: Response? {
      guard case let .message(payload) = self else {
        return nil
      }
      return payload
    }

    /// The status and trailers if this part is `.end`, `nil` otherwise.
    var end: (GRPCStatus, HPACKHeaders)? {
      guard case let .end(finalStatus, finalTrailers) = self else {
        return nil
      }
      return (finalStatus, finalTrailers)
    }
  }
  346. }