ClientInterceptorPipelineTests.swift 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. /*
  2. * Copyright 2020, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. @testable import GRPC
  17. import Logging
  18. import NIOCore
  19. import NIOEmbedded
  20. import NIOHPACK
  21. import XCTest
  /// Tests for `ClientInterceptorPipeline`, driven by an `EmbeddedEventLoop`.
  class ClientInterceptorPipelineTests: GRPCTestCase {
    override func setUp() {
      super.setUp()
      self.embeddedEventLoop = EmbeddedEventLoop()
    }

    /// The loop every pipeline in these tests runs on; a fresh one is created in `setUp()`.
    private var embeddedEventLoop: EmbeddedEventLoop!

    /// Makes a pipeline on `embeddedEventLoop` which reports everything reaching its outer edges
    /// to the given closures.
    ///
    /// - Parameters:
    ///   - requests: The request type; only used to help type inference.
    ///   - responses: The response type; only used to help type inference.
    ///   - details: The call details; `makeCallDetails()` is used when `nil`.
    ///   - interceptors: The interceptors to install. The first interceptor sees outbound request
    ///     parts before the second (see `testPipelineWithInterceptor()`).
    ///   - errorDelegate: An optional delegate which is told about caught errors.
    ///   - onError: Receives errors emitted by the pipeline, e.g. `GRPCError.RPCTimedOut`.
    ///   - onCancel: Receives cancellations sent towards the transport, e.g. when the pipeline
    ///     closes.
    ///   - onRequestPart: Receives request parts which made it through every interceptor.
    ///   - onResponsePart: Receives response parts which made it through every interceptor.
    /// - Returns: The configured pipeline.
    private func makePipeline<Request, Response>(
      requests: Request.Type = Request.self,
      responses: Response.Type = Response.self,
      details: CallDetails? = nil,
      interceptors: [ClientInterceptor<Request, Response>] = [],
      errorDelegate: ClientErrorDelegate? = nil,
      onError: @escaping (Error) -> Void = { _ in },
      onCancel: @escaping (EventLoopPromise<Void>?) -> Void = { _ in },
      onRequestPart: @escaping (GRPCClientRequestPart<Request>, EventLoopPromise<Void>?) -> Void,
      onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
    ) -> ClientInterceptorPipeline<Request, Response> {
      let callDetails = details ?? self.makeCallDetails()
      return ClientInterceptorPipeline(
        eventLoop: self.embeddedEventLoop,
        details: callDetails,
        logger: callDetails.options.logger.wrapped,
        interceptors: interceptors,
        errorDelegate: errorDelegate,
        onError: onError,
        onCancel: onCancel,
        onRequestPart: onRequestPart,
        onResponsePart: onResponsePart
      )
    }

    /// Makes `CallDetails` for a unary call with placeholder path, authority and scheme, the
    /// given time limit, and `clientLogger` as the logger.
    private func makeCallDetails(timeLimit: TimeLimit = .none) -> CallDetails {
      return CallDetails(
        type: .unary,
        path: "ignored",
        authority: "ignored",
        scheme: "ignored",
        options: CallOptions(timeLimit: timeLimit, logger: self.clientLogger)
      )
    }

    /// Without interceptors, parts pass straight through the pipeline in both directions.
    func testEmptyPipeline() throws {
      var requestParts: [GRPCClientRequestPart<String>] = []
      var responseParts: [GRPCClientResponsePart<String>] = []

      let pipeline = self.makePipeline(
        requests: String.self,
        responses: String.self,
        onRequestPart: { request, promise in
          requestParts.append(request)
          XCTAssertNil(promise)
        },
        onResponsePart: { responseParts.append($0) }
      )

      // Write some request parts.
      pipeline.send(.metadata([:]), promise: nil)
      pipeline.send(.message("foo", .init(compress: false, flush: false)), promise: nil)
      pipeline.send(.end, promise: nil)

      XCTAssertEqual(requestParts.count, 3)
      XCTAssertEqual(requestParts[0].metadata, [:])
      let (message, metadata) = try assertNotNil(requestParts[1].message)
      XCTAssertEqual(message, "foo")
      XCTAssertEqual(metadata, .init(compress: false, flush: false))
      XCTAssertTrue(requestParts[2].isEnd)

      // Write some responses parts.
      pipeline.receive(.metadata([:]))
      pipeline.receive(.message("bar"))
      pipeline.receive(.end(.ok, [:]))

      XCTAssertEqual(responseParts.count, 3)
      XCTAssertEqual(responseParts[0].metadata, [:])
      XCTAssertEqual(responseParts[1].message, "bar")
      let (status, trailers) = try assertNotNil(responseParts[2].end)
      XCTAssertEqual(status, .ok)
      XCTAssertEqual(trailers, [:])
    }

    /// Once an error closes the pipeline: the error is reported once, writes and cancellation
    /// fail without reaching the transport, and reads are dropped.
    func testPipelineWhenClosed() throws {
      struct DummyError: Error {}
      var errors: [Error] = []

      let pipeline = self.makePipeline(
        requests: String.self,
        responses: String.self,
        onError: { errors.append($0) },
        onRequestPart: { _, _ in
          XCTFail("Unexpected request part")
        },
        onResponsePart: { _ in
          XCTFail("Unexpected response part")
        }
      )

      // Fire an error; this should close the pipeline and surface the error.
      pipeline.errorCaught(DummyError())
      XCTAssertEqual(errors.count, 1)
      XCTAssertNotNil(errors.first as? DummyError)

      // We're closed, writes should fail.
      let writePromise = pipeline.eventLoop.makePromise(of: Void.self)
      pipeline.send(.end, promise: writePromise)
      XCTAssertThrowsError(try writePromise.futureResult.wait())

      // As should cancellation.
      let cancelPromise = pipeline.eventLoop.makePromise(of: Void.self)
      pipeline.cancel(promise: cancelPromise)
      XCTAssertThrowsError(try cancelPromise.futureResult.wait())

      // And reads should be ignored; `onResponsePart` fails the test if this gets through.
      pipeline.receive(.metadata([:]))
    }

    /// When the deadline passes, the pipeline reports a timeout, cancels towards the transport
    /// and then rejects further writes, cancellations and reads.
    func testPipelineWithTimeout() throws {
      var cancelled = false
      var timedOut = false

      class FailOnCancel<Request, Response>: ClientInterceptor<Request, Response> {
        override func cancel(
          promise: EventLoopPromise<Void>?,
          context: ClientInterceptorContext<Request, Response>
        ) {
          XCTFail("Unexpected cancellation")
          context.cancel(promise: promise)
        }
      }

      let deadline = NIODeadline.uptimeNanoseconds(100)
      let pipeline = self.makePipeline(
        requests: String.self,
        responses: String.self,
        details: self.makeCallDetails(timeLimit: .deadline(deadline)),
        interceptors: [FailOnCancel()],
        onError: { error in
          assertThat(error, .is(.instanceOf(GRPCError.RPCTimedOut.self)))
          assertThat(timedOut, .is(false))
          timedOut = true
        },
        onCancel: { promise in
          assertThat(cancelled, .is(false))
          cancelled = true
          // We don't expect a promise: this cancellation is fired by the pipeline.
          assertThat(promise, .is(.nil()))
        },
        onRequestPart: { _, _ in
          XCTFail("Unexpected request part")
        },
        onResponsePart: { _ in
          XCTFail("Unexpected response part")
        }
      )

      // Trigger the timeout.
      self.embeddedEventLoop.advanceTime(to: deadline)
      assertThat(timedOut, .is(true))

      // The pipeline cancels the RPC on the transport side only: the 'onCancel' callback sees
      // it, and `FailOnCancel` fails the test if the cancellation passes through the interceptors.
      assertThat(cancelled, .is(true))

      // Pipeline should be torn down. Writes and cancellation should fail.
      let p1 = pipeline.eventLoop.makePromise(of: Void.self)
      pipeline.send(.end, promise: p1)
      assertThat(try p1.futureResult.wait(), .throws(.instanceOf(GRPCError.AlreadyComplete.self)))

      let p2 = pipeline.eventLoop.makePromise(of: Void.self)
      pipeline.cancel(promise: p2)
      assertThat(try p2.futureResult.wait(), .throws(.instanceOf(GRPCError.AlreadyComplete.self)))

      // Reads should be ignored too. (We'll fail in `onResponsePart` if this goes through.)
      pipeline.receive(.metadata([:]))
    }

    /// Receiving `.end` closes the pipeline with a single cancellation; passing the deadline
    /// afterwards must not cancel again.
    func testTimeoutIsCancelledOnCompletion() throws {
      let deadline = NIODeadline.uptimeNanoseconds(100)
      var cancellations = 0

      let pipeline = self.makePipeline(
        requests: String.self,
        responses: String.self,
        details: self.makeCallDetails(timeLimit: .deadline(deadline)),
        onCancel: { promise in
          assertThat(cancellations, .is(0))
          cancellations += 1
          // We don't expect a promise: this cancellation is fired by the pipeline.
          assertThat(promise, .is(.nil()))
        },
        onRequestPart: { _, _ in
          XCTFail("Unexpected request part")
        },
        onResponsePart: { part in
          // We only expect the end.
          assertThat(part.end, .is(.notNil()))
        }
      )

      // Read the end part.
      pipeline.receive(.end(.ok, [:]))
      // Just a single cancellation.
      assertThat(cancellations, .is(1))

      // Pass the deadline.
      self.embeddedEventLoop.advanceTime(to: deadline)
      // We should still have just the one cancellation.
      assertThat(cancellations, .is(1))
    }

    /// Interceptors are applied in array order on the outbound path.
    func testPipelineWithInterceptor() throws {
      // We're not testing much here, just that the interceptors are in the right order, from
      // outbound to inbound: the recorder sits after the reverser, so it must see "oof".
      let recorder = RecordingInterceptor<String, String>()
      let pipeline = self.makePipeline(
        interceptors: [StringRequestReverser(), recorder],
        onRequestPart: { _, _ in },
        onResponsePart: { _ in }
      )

      pipeline.send(.message("foo", .init(compress: false, flush: false)), promise: nil)
      XCTAssertEqual(recorder.requestParts.count, 1)
      let (message, _) = try assertNotNil(recorder.requestParts[0].message)
      XCTAssertEqual(message, "oof")
    }

    /// The error delegate receives caught errors. A `GRPCError.WithContext` is unwrapped and its
    /// file and line are passed along.
    func testErrorDelegateIsCalled() throws {
      final class Delegate: ClientErrorDelegate {
        let expectedError: GRPCError.InvalidState
        let file: StaticString?
        let line: Int?

        init(
          expected: GRPCError.InvalidState,
          file: StaticString?,
          line: Int?
        ) {
          // Either both or neither of `file` and `line` must be given.
          precondition((file == nil && line == nil) || (file != nil && line != nil))
          self.expectedError = expected
          self.file = file
          self.line = line
        }

        func didCatchError(_ error: Error, logger: Logger, file: StaticString, line: Int) {
          XCTAssertEqual(error as? GRPCError.InvalidState, self.expectedError)

          // Check the file and line, if expected.
          if let expectedFile = self.file, let expectedLine = self.line {
            XCTAssertEqual("\(file)", "\(expectedFile)") // StaticString isn't Equatable
            XCTAssertEqual(line, expectedLine)
          }
        }
      }

      // NOTE(review): this passes even if `didCatchError` is never invoked. Counting calls
      // would need mutable state on the delegate, which may be required to be `Sendable`.
      // Consider a lock-protected counter.
      func doTest(withDelegate delegate: Delegate, error: Error) {
        let pipeline = self.makePipeline(
          requests: String.self,
          responses: String.self,
          errorDelegate: delegate,
          onRequestPart: { _, _ in },
          onResponsePart: { _ in }
        )
        pipeline.errorCaught(error)
      }

      let invalidState = GRPCError.InvalidState("invalid state")
      let withContext = GRPCError.WithContext(invalidState)

      doTest(
        withDelegate: .init(expected: invalidState, file: withContext.file, line: withContext.line),
        error: withContext
      )

      doTest(
        withDelegate: .init(expected: invalidState, file: nil, line: nil),
        error: invalidState
      )
    }
  }
  259. // MARK: - Test Interceptors
  /// A simple interceptor which records every request and response part it sees, then forwards
  /// the part unchanged.
  class RecordingInterceptor<Request, Response>: ClientInterceptor<Request, Response> {
    /// Request parts passed to `send(_:promise:context:)`, in arrival order.
    var requestParts: [GRPCClientRequestPart<Request>] = []
    /// Response parts passed to `receive(_:context:)`, in arrival order.
    var responseParts: [GRPCClientResponsePart<Response>] = []

    override func send(
      _ requestPart: GRPCClientRequestPart<Request>,
      promise: EventLoopPromise<Void>?,
      context: ClientInterceptorContext<Request, Response>
    ) {
      // Record first, then pass the part (and the caller's promise) on untouched.
      self.requestParts += [requestPart]
      context.send(requestPart, promise: promise)
    }

    override func receive(
      _ responsePart: GRPCClientResponsePart<Response>,
      context: ClientInterceptorContext<Request, Response>
    ) {
      // Record first, then pass the part on untouched.
      self.responseParts += [responsePart]
      context.receive(responsePart)
    }
  }
  /// An interceptor which reverses the characters of outbound `String` request messages. Metadata
  /// and end parts are forwarded untouched.
  class StringRequestReverser: ClientInterceptor<String, String> {
    override func send(
      _ part: GRPCClientRequestPart<String>,
      promise: EventLoopPromise<Void>?,
      context: ClientInterceptorContext<String, String>
    ) {
      switch part {
      case let .message(value, metadata):
        context.send(.message(String(value.reversed()), metadata), promise: promise)
      // Listed explicitly rather than via `default`, so a new part kind becomes a compile error
      // here instead of being forwarded silently.
      case .metadata, .end:
        context.send(part, promise: promise)
      }
    }
  }
  295. // MARK: - Request/Response part helpers
  extension GRPCClientRequestPart {
    /// The headers of a `.metadata` part; `nil` for any other part.
    var metadata: HPACKHeaders? {
      guard case let .metadata(headers) = self else {
        return nil
      }
      return headers
    }

    /// The request and its `MessageMetadata` from a `.message` part; `nil` for any other part.
    var message: (Request, MessageMetadata)? {
      guard case let .message(request, messageMetadata) = self else {
        return nil
      }
      return (request, messageMetadata)
    }

    /// `true` only for the `.end` part.
    var isEnd: Bool {
      if case .end = self {
        return true
      }
      return false
    }
  }
  extension GRPCClientResponsePart {
    /// The headers of a `.metadata` part; `nil` for any other part.
    var metadata: HPACKHeaders? {
      guard case let .metadata(headers) = self else {
        return nil
      }
      return headers
    }

    /// The response carried by a `.message` part; `nil` for any other part.
    var message: Response? {
      guard case let .message(response) = self else {
        return nil
      }
      return response
    }

    /// The status and trailers of an `.end` part; `nil` for any other part.
    var end: (GRPCStatus, HPACKHeaders)? {
      guard case let .end(status, trailers) = self else {
        return nil
      }
      return (status, trailers)
    }
  }