ClientInterceptorPipelineTests.swift 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. /*
  2. * Copyright 2020, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. @testable import GRPC
  17. import Logging
  18. import NIO
  19. import NIOHPACK
  20. import XCTest
  /// Tests for `ClientInterceptorPipeline`.
  ///
  /// Every pipeline is created on an `EmbeddedEventLoop`, so tests can advance time explicitly
  /// (see the timeout tests) and observe callbacks without waiting on real threads.
  class ClientInterceptorPipelineTests: GRPCTestCase {
    /// The event loop used by every pipeline in a test; replaced with a fresh loop in `setUp()`.
    private var embeddedEventLoop: EmbeddedEventLoop!

    override func setUp() {
      super.setUp()
      self.embeddedEventLoop = EmbeddedEventLoop()
    }

    /// Builds a pipeline on `embeddedEventLoop`, passing every argument straight through to
    /// `ClientInterceptorPipeline`'s initializer.
    ///
    /// - Parameters:
    ///   - requests: The request message type; only used to pin the generic parameter.
    ///   - responses: The response message type; only used to pin the generic parameter.
    ///   - details: Call details for the pipeline. When `nil`, `makeCallDetails()` is used.
    ///   - interceptors: The interceptors to install. Defaults to none.
    ///   - errorDelegate: An optional delegate handed to the pipeline.
    ///   - onError: Callback handed to the pipeline as its error handler.
    ///   - onCancel: Callback handed to the pipeline as its cancellation handler.
    ///   - onRequestPart: Callback handed to the pipeline as its request part handler.
    ///   - onResponsePart: Callback handed to the pipeline as its response part handler.
    /// - Returns: The newly created pipeline.
    private func makePipeline<Request, Response>(
      requests: Request.Type = Request.self,
      responses: Response.Type = Response.self,
      details: CallDetails? = nil,
      interceptors: [ClientInterceptor<Request, Response>] = [],
      errorDelegate: ClientErrorDelegate? = nil,
      onError: @escaping (Error) -> Void = { _ in },
      onCancel: @escaping (EventLoopPromise<Void>?) -> Void = { _ in },
      onRequestPart: @escaping (GRPCClientRequestPart<Request>, EventLoopPromise<Void>?) -> Void,
      onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
    ) -> ClientInterceptorPipeline<Request, Response> {
      return ClientInterceptorPipeline(
        eventLoop: self.embeddedEventLoop,
        details: details ?? self.makeCallDetails(),
        interceptors: interceptors,
        errorDelegate: errorDelegate,
        onError: onError,
        onCancel: onCancel,
        onRequestPart: onRequestPart,
        onResponsePart: onResponsePart
      )
    }

    /// Makes details for a unary call with placeholder path, authority and scheme.
    ///
    /// - Parameter timeLimit: The time limit to store in the call options. Defaults to `.none`.
    /// - Returns: Call details whose options carry `timeLimit` and the test case's client logger.
    private func makeCallDetails(timeLimit: TimeLimit = .none) -> CallDetails {
      return CallDetails(
        type: .unary,
        path: "ignored",
        authority: "ignored",
        scheme: "ignored",
        options: CallOptions(timeLimit: timeLimit, logger: self.clientLogger)
      )
    }

    /// With no interceptors installed, request and response parts must arrive at the
    /// pipeline's callbacks unchanged and in order.
    func testEmptyPipeline() throws {
      var requestParts: [GRPCClientRequestPart<String>] = []
      var responseParts: [GRPCClientResponsePart<String>] = []

      let pipeline = self.makePipeline(
        requests: String.self,
        responses: String.self,
        onRequestPart: { request, promise in
          requestParts.append(request)
          XCTAssertNil(promise)
        },
        onResponsePart: { responseParts.append($0) }
      )

      // Write some request parts.
      pipeline.send(.metadata([:]), promise: nil)
      pipeline.send(.message("foo", .init(compress: false, flush: false)), promise: nil)
      pipeline.send(.end, promise: nil)

      XCTAssertEqual(requestParts.count, 3)
      XCTAssertEqual(requestParts[0].metadata, [:])
      let (message, metadata) = try assertNotNil(requestParts[1].message)
      XCTAssertEqual(message, "foo")
      XCTAssertEqual(metadata, .init(compress: false, flush: false))
      XCTAssertTrue(requestParts[2].isEnd)

      // Write some response parts.
      pipeline.receive(.metadata([:]))
      pipeline.receive(.message("bar"))
      pipeline.receive(.end(.ok, [:]))

      XCTAssertEqual(responseParts.count, 3)
      XCTAssertEqual(responseParts[0].metadata, [:])
      XCTAssertEqual(responseParts[1].message, "bar")
      let (status, trailers) = try assertNotNil(responseParts[2].end)
      XCTAssertEqual(status, .ok)
      XCTAssertEqual(trailers, [:])
    }

    /// After an error has been fired into the pipeline, writes and cancellation must fail and
    /// reads must be dropped.
    func testPipelineWhenClosed() throws {
      let pipeline = self.makePipeline(
        requests: String.self,
        responses: String.self,
        onRequestPart: { _, promise in
          XCTAssertNil(promise)
        },
        onResponsePart: { _ in
          // Nothing is read before the error is fired, so any part arriving here means a
          // read made it through a closed pipeline.
          XCTFail("Unexpected response part")
        }
      )

      // Fire an error; this should close the pipeline.
      struct DummyError: Error {}
      pipeline.errorCaught(DummyError())

      // We're closed, writes should fail.
      let writePromise = pipeline.eventLoop.makePromise(of: Void.self)
      pipeline.send(.end, promise: writePromise)
      XCTAssertThrowsError(try writePromise.futureResult.wait())

      // As should cancellation.
      let cancelPromise = pipeline.eventLoop.makePromise(of: Void.self)
      pipeline.cancel(promise: cancelPromise)
      XCTAssertThrowsError(try cancelPromise.futureResult.wait())

      // And reads should be ignored. (We'll fail in `onResponsePart` if this goes through.)
      pipeline.receive(.metadata([:]))
    }

    /// Passing the call's deadline must report `GRPCError.RPCTimedOut` once, cancel the call
    /// once via `onCancel` (without going through the interceptors), and leave the pipeline
    /// closed.
    func testPipelineWithTimeout() throws {
      var cancelled = false
      var timedOut = false

      // An interceptor which fails the test if a cancellation passes through it.
      class FailOnCancel<Request, Response>: ClientInterceptor<Request, Response> {
        override func cancel(
          promise: EventLoopPromise<Void>?,
          context: ClientInterceptorContext<Request, Response>
        ) {
          XCTFail("Unexpected cancellation")
          context.cancel(promise: promise)
        }
      }

      let deadline = NIODeadline.uptimeNanoseconds(100)
      let pipeline = self.makePipeline(
        requests: String.self,
        responses: String.self,
        details: self.makeCallDetails(timeLimit: .deadline(deadline)),
        interceptors: [FailOnCancel()],
        onError: { error in
          assertThat(error, .is(.instanceOf(GRPCError.RPCTimedOut.self)))
          assertThat(timedOut, .is(false))
          timedOut = true
        },
        onCancel: { promise in
          assertThat(cancelled, .is(false))
          cancelled = true
          // We don't expect a promise: this cancellation is fired by the pipeline.
          assertThat(promise, .is(.nil()))
        },
        onRequestPart: { _, _ in
          XCTFail("Unexpected request part")
        },
        onResponsePart: { _ in
          XCTFail("Unexpected response part")
        }
      )

      // Trigger the timeout.
      self.embeddedEventLoop.advanceTime(to: deadline)
      assertThat(timedOut, .is(true))

      // We'll receive a cancellation; we only get this 'onCancel' callback. We'll fail in the
      // interceptor if a cancellation is received.
      assertThat(cancelled, .is(true))

      // Pipeline should be torn down. Writes and cancellation should fail.
      let p1 = pipeline.eventLoop.makePromise(of: Void.self)
      pipeline.send(.end, promise: p1)
      assertThat(try p1.futureResult.wait(), .throws(.instanceOf(GRPCError.AlreadyComplete.self)))

      let p2 = pipeline.eventLoop.makePromise(of: Void.self)
      pipeline.cancel(promise: p2)
      assertThat(try p2.futureResult.wait(), .throws(.instanceOf(GRPCError.AlreadyComplete.self)))

      // Reads should be ignored too. (We'll fail in `onResponsePart` if this goes through.)
      pipeline.receive(.metadata([:]))
    }

    /// Once the end of the response stream has been read, passing the deadline afterwards
    /// must not produce a second cancellation.
    func testTimeoutIsCancelledOnCompletion() throws {
      let deadline = NIODeadline.uptimeNanoseconds(100)
      var cancellations = 0

      let pipeline = self.makePipeline(
        requests: String.self,
        responses: String.self,
        details: self.makeCallDetails(timeLimit: .deadline(deadline)),
        onCancel: { promise in
          assertThat(cancellations, .is(0))
          cancellations += 1
          // We don't expect a promise: this cancellation is fired by the pipeline.
          assertThat(promise, .is(.nil()))
        },
        onRequestPart: { _, _ in
          XCTFail("Unexpected request part")
        },
        onResponsePart: { part in
          // We only expect the end.
          assertThat(part.end, .is(.notNil()))
        }
      )

      // Read the end part.
      pipeline.receive(.end(.ok, [:]))
      // Just a single cancellation.
      assertThat(cancellations, .is(1))

      // Pass the deadline.
      self.embeddedEventLoop.advanceTime(to: deadline)
      // We should still have just the one cancellation.
      assertThat(cancellations, .is(1))
    }

    /// A message sent through `[StringRequestReverser, RecordingInterceptor]` must already be
    /// reversed by the time the recorder sees it.
    func testPipelineWithInterceptor() throws {
      // We're not testing much here, just that the interceptors are in the right order, from
      // outbound to inbound.
      let recorder = RecordingInterceptor<String, String>()
      let pipeline = self.makePipeline(
        interceptors: [StringRequestReverser(), recorder],
        onRequestPart: { _, _ in },
        onResponsePart: { _ in }
      )

      pipeline.send(.message("foo", .init(compress: false, flush: false)), promise: nil)
      XCTAssertEqual(recorder.requestParts.count, 1)
      let (message, _) = try assertNotNil(recorder.requestParts[0].message)
      XCTAssertEqual(message, "oof")
    }

    /// The error delegate must receive the underlying `GRPCError.InvalidState`, whether the
    /// error is fired bare or wrapped in `GRPCError.WithContext`; in the wrapped case the
    /// reported file and line must match those of the wrapper.
    func testErrorDelegateIsCalled() throws {
      // Asserts on the error it is given and, when a file and line were supplied at
      // construction, on the reported source location too.
      class Delegate: ClientErrorDelegate {
        let expectedError: GRPCError.InvalidState
        let file: StaticString?
        let line: Int?

        init(
          expected: GRPCError.InvalidState,
          file: StaticString?,
          line: Int?
        ) {
          // File and line must be provided together or not at all.
          precondition(file == nil && line == nil || file != nil && line != nil)
          self.expectedError = expected
          self.file = file
          self.line = line
        }

        func didCatchError(_ error: Error, logger: Logger, file: StaticString, line: Int) {
          XCTAssertEqual(error as? GRPCError.InvalidState, self.expectedError)

          // Check the file and line, if expected.
          if let expectedFile = self.file, let expectedLine = self.line {
            XCTAssertEqual("\(file)", "\(expectedFile)") // StaticString isn't Equatable
            XCTAssertEqual(line, expectedLine)
          }
        }
      }

      // Builds a pipeline using `delegate` and fires `error` into it.
      func doTest(withDelegate delegate: Delegate, error: Error) {
        let pipeline = self.makePipeline(
          requests: String.self,
          responses: String.self,
          errorDelegate: delegate,
          onRequestPart: { _, _ in },
          onResponsePart: { _ in }
        )
        pipeline.errorCaught(error)
      }

      let invalidState = GRPCError.InvalidState("invalid state")
      let withContext = GRPCError.WithContext(invalidState)

      doTest(
        withDelegate: .init(expected: invalidState, file: withContext.file, line: withContext.line),
        error: withContext
      )

      doTest(
        withDelegate: .init(expected: invalidState, file: nil, line: nil),
        error: invalidState
      )
    }
  }
  256. // MARK: - Test Interceptors
  /// An interceptor which keeps a copy of every request and response part it sees, then
  /// forwards the part unchanged via its context.
  class RecordingInterceptor<Request, Response>: ClientInterceptor<Request, Response> {
    /// Request parts passed to `send`, in arrival order.
    var requestParts = [GRPCClientRequestPart<Request>]()

    /// Response parts passed to `receive`, in arrival order.
    var responseParts = [GRPCClientResponsePart<Response>]()

    /// Records `requestPart` and then hands it, with its promise, on to `context`.
    override func send(
      _ requestPart: GRPCClientRequestPart<Request>,
      promise: EventLoopPromise<Void>?,
      context: ClientInterceptorContext<Request, Response>
    ) {
      // Record before forwarding so the part is visible by the time anything downstream runs.
      requestParts.append(requestPart)
      context.send(requestPart, promise: promise)
    }

    /// Records `responsePart` and then hands it on to `context`.
    override func receive(
      _ responsePart: GRPCClientResponsePart<Response>,
      context: ClientInterceptorContext<Request, Response>
    ) {
      responseParts.append(responsePart)
      context.receive(responsePart)
    }
  }
  /// An interceptor which reverses the characters of string request messages.
  ///
  /// Only `.message` parts are rewritten; their message metadata and promise are forwarded
  /// untouched. `.metadata` and `.end` parts pass through as they are.
  class StringRequestReverser: ClientInterceptor<String, String> {
    override func send(
      _ part: GRPCClientRequestPart<String>,
      promise: EventLoopPromise<Void>?,
      context: ClientInterceptorContext<String, String>
    ) {
      // Every case is listed explicitly (no `default`) so that a newly added request part
      // kind is a compile error here instead of being forwarded without a decision.
      switch part {
      case let .message(value, metadata):
        context.send(.message(String(value.reversed()), metadata), promise: promise)
      case .metadata, .end:
        context.send(part, promise: promise)
      }
    }
  }
  292. // MARK: - Request/Response part helpers
  /// Test-only accessors for pulling the payload out of a request part.
  extension GRPCClientRequestPart {
    /// The headers if this part is `.metadata`; `nil` for any other part.
    var metadata: HPACKHeaders? {
      guard case let .metadata(headers) = self else {
        return nil
      }
      return headers
    }

    /// The request and its message metadata if this part is `.message`; `nil` for any other
    /// part.
    var message: (Request, MessageMetadata)? {
      guard case let .message(request, messageMetadata) = self else {
        return nil
      }
      return (request, messageMetadata)
    }

    /// `true` only when this part is `.end`.
    var isEnd: Bool {
      if case .end = self {
        return true
      }
      return false
    }
  }
  /// Test-only accessors for pulling the payload out of a response part.
  extension GRPCClientResponsePart {
    /// The headers if this part is `.metadata`; `nil` for any other part.
    var metadata: HPACKHeaders? {
      guard case let .metadata(headers) = self else {
        return nil
      }
      return headers
    }

    /// The response message if this part is `.message`; `nil` for any other part.
    var message: Response? {
      guard case let .message(response) = self else {
        return nil
      }
      return response
    }

    /// The status and trailers if this part is `.end`; `nil` for any other part.
    var end: (GRPCStatus, HPACKHeaders)? {
      guard case let .end(status, trailers) = self else {
        return nil
      }
      return (status, trailers)
    }
  }