ClientTransportTests.swift 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. /*
  2. * Copyright 2020, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import NIOCore
  17. import NIOEmbedded
  18. import XCTest
  19. @testable import GRPC
  /// Tests for `ClientTransport<String, String>`, driven through an `EmbeddedChannel`.
  ///
  /// This declaration holds the shared state and the helpers used by the test methods.
  class ClientTransportTests: GRPCTestCase {
    // MARK: - State

    /// The channel the transport's handlers are added to; replaced in `setUp()`.
    private var channel: EmbeddedChannel!

    /// The transport under test; assigned by `setUpTransport(details:interceptors:onError:onResponsePart:)`.
    private var transport: ClientTransport<String, String>!

    /// The event loop belonging to the current `channel`.
    private var eventLoop: EventLoop {
      self.channel.eventLoop
    }

    override func setUp() {
      super.setUp()
      self.channel = EmbeddedChannel()
    }

    // MARK: - Setup Helpers

    /// Builds the call details used by the tests.
    ///
    /// - Parameter type: The call type to use, `.unary` unless specified.
    /// - Returns: Details for the path "/echo.Echo/Get" which use this test's logger.
    private func makeDetails(type: GRPCCallType = .unary) -> CallDetails {
      let details = CallDetails(
        type: type,
        path: "/echo.Echo/Get",
        authority: "localhost",
        scheme: "https",
        options: .init(logger: self.logger)
      )
      return details
    }

    /// Creates the transport under test and stores it in `self.transport`.
    ///
    /// - Parameters:
    ///   - details: The call details, or `nil` to use `makeDetails()`.
    ///   - interceptors: Interceptors handed to the transport.
    ///   - onError: Callback handed to the transport as its `onError`.
    ///   - onResponsePart: Callback handed to the transport as its `onResponsePart`.
    private func setUpTransport(
      details: CallDetails? = nil,
      interceptors: [ClientInterceptor<String, String>] = [],
      onError: @escaping (Error) -> Void = { _ in },
      onResponsePart: @escaping (GRPCClientResponsePart<String>) -> Void = { _ in }
    ) {
      let callDetails = details ?? self.makeDetails()
      self.transport = ClientTransport<String, String>(
        details: callDetails,
        eventLoop: self.eventLoop,
        interceptors: interceptors,
        serializer: AnySerializer(wrapping: StringSerializer()),
        deserializer: AnyDeserializer(wrapping: StringDeserializer()),
        errorDelegate: nil,
        onStart: {},
        onError: onError,
        onResponsePart: onResponsePart
      )
    }

    /// Configures the transport by adding handlers to the channel's pipeline.
    ///
    /// The handlers are added in this order: `handlers`, then a
    /// `GRPCClientReverseCodecHandler` using the string (de)serializers, then the handler
    /// supplied by the transport.
    ///
    /// - Parameter handlers: Extra handlers to place before the codec handler.
    private func configureTransport(additionalHandlers handlers: [ChannelHandler] = []) {
      self.transport.configure { transportHandler in
        let codec: ChannelHandler = GRPCClientReverseCodecHandler(
          serializer: StringSerializer(),
          deserializer: StringDeserializer()
        )
        let pipelineHandlers: [ChannelHandler] = handlers + [codec, transportHandler]
        return self.channel.pipeline.addHandlers(pipelineHandlers)
      }
    }

    /// Configures the transport with a caller-provided closure, passed on unchanged.
    ///
    /// - Parameter body: Receives the transport's handler and returns the configuration future.
    private func configureTransport(_ body: @escaping (ChannelHandler) -> EventLoopFuture<Void>) {
      self.transport.configure(body)
    }

    /// Connects the channel to a placeholder Unix domain socket address, asserting that the
    /// connect does not throw.
    private func connect(file: StaticString = #filePath, line: UInt = #line) throws {
      let address = try assertNoThrow(SocketAddress(unixDomainSocketPath: "/whatever"))
      assertThat(
        try self.channel.connect(to: address).wait(),
        .doesNotThrow(),
        file: file,
        line: line
      )
    }

    /// Sends a request part via the transport.
    ///
    /// - Parameters:
    ///   - part: The request part to send.
    ///   - promise: An optional promise passed along with the part.
    private func sendRequest(
      _ part: GRPCClientRequestPart<String>,
      promise: EventLoopPromise<Void>? = nil
    ) {
      self.transport.send(part, promise: promise)
    }

    /// Cancels the call via the transport.
    ///
    /// - Parameter promise: An optional promise passed along with the cancellation.
    private func cancel(promise: EventLoopPromise<Void>? = nil) {
      self.transport.cancel(promise: promise)
    }

    /// Writes a response part inbound on the channel, asserting that the write does not throw.
    ///
    /// - Parameter part: The response part to write.
    private func sendResponse(
      _ part: _GRPCClientResponsePart<String>,
      file: StaticString = #filePath,
      line: UInt = #line
    ) throws {
      assertThat(try self.channel.writeInbound(part), .doesNotThrow(), file: file, line: line)
    }
  }
  100. // MARK: - Tests
  101. extension ClientTransportTests {
  /// Drives a whole unary call: two request parts sent before the transport is configured,
  /// one after, followed by a full set of response parts.
  func testUnaryFlow() throws {
    let writeRecorder = WriteRecorder<_GRPCClientRequestPart<String>>()
    let interceptor = RecordingInterceptor<String, String>()
    self.setUpTransport(interceptors: [interceptor])

    // The transport has not been configured yet, so these parts are buffered.
    self.sendRequest(.metadata([:]))
    self.sendRequest(.message("0", .init(compress: false, flush: false)))

    // Configuring and connecting unbuffers the parts sent above.
    self.configureTransport(additionalHandlers: [writeRecorder])
    try self.connect()

    // No buffering should be needed for the end of the request stream.
    self.sendRequest(.end)

    // All three request parts should have reached the recorder in the `Channel`.
    assertThat(writeRecorder.writes, .hasCount(3))

    // Feed the response back in, one part at a time.
    let responses: [_GRPCClientResponsePart<String>] = [
      .initialMetadata([:]),
      .message(.init("1", compressed: false)),
      .trailingMetadata([:]),
      .status(.ok)
    ]
    for response in responses {
      try self.sendResponse(response)
    }

    // Four parts were written to the channel; three are expected at the interceptor.
    assertThat(interceptor.responseParts, .hasCount(3))
  }
  /// Cancels a transport which has been created but neither configured nor connected.
  func testCancelWhenIdle() throws {
    // Any error reported by the transport must be the client-side cancellation.
    let expectCancelled: (Error) -> Void = { error in
      assertThat(error, .is(.instanceOf(GRPCError.RPCCancelledByClient.self)))
    }
    self.setUpTransport(onError: expectCancelled)

    // The cancellation itself should complete successfully.
    let cancelled = self.eventLoop.makePromise(of: Void.self)
    self.cancel(promise: cancelled)
    assertThat(try cancelled.futureResult.wait(), .doesNotThrow())
  }
  /// Cancels while configuration is still pending: the cancellation should succeed and the
  /// write sent beforehand should fail with `RPCCancelledByClient`.
  func testCancelWhenAwaitingTransport() throws {
    // Any error reported by the transport must be the client-side cancellation.
    self.setUpTransport(onError: { error in
      assertThat(error, .is(.instanceOf(GRPCError.RPCCancelledByClient.self)))
    })

    // Configuration is held open by this promise; complete it on exit so it is not leaked.
    let activation = self.eventLoop.makePromise(of: Void.self)
    defer {
      activation.succeed(())
    }

    self.configureTransport { handler in
      let added = self.channel.pipeline.addHandler(handler)
      return added.flatMap {
        activation.futureResult
      }
    }

    // Send a request part, then cancel.
    let writePromise = self.eventLoop.makePromise(of: Void.self)
    self.sendRequest(.metadata([:]), promise: writePromise)

    let cancelPromise = self.eventLoop.makePromise(of: Void.self)
    self.cancel(promise: cancelPromise)

    // The cancellation succeeds and takes the pending write down with it.
    assertThat(try cancelPromise.futureResult.wait(), .doesNotThrow())
    assertThat(
      try writePromise.futureResult.wait(),
      .throws(.instanceOf(GRPCError.RPCCancelledByClient.self))
    )
  }
  /// Cancels from the completion of the first buffered write, i.e. while the buffered parts
  /// are being written out; the second buffered write should then fail.
  func testCancelWhenActivating() throws {
    // Bidirectional streaming is used so that a flush also follows the metadata write.
    let callDetails = self.makeDetails(type: .bidirectionalStreaming)
    self.setUpTransport(
      details: callDetails,
      onError: { error in
        assertThat(error, .is(.instanceOf(GRPCError.RPCCancelledByClient.self)))
      }
    )

    // First write: buffered because the transport has not been configured.
    let firstWrite = self.eventLoop.makePromise(of: Void.self)
    self.sendRequest(.metadata([:]), promise: firstWrite)

    // Cancel as soon as the first write succeeds.
    let cancellation = self.eventLoop.makePromise(of: Void.self)
    firstWrite.futureResult.whenSuccess {
      self.cancel(promise: cancellation)
    }

    // Second write: queued behind the first.
    let secondWrite = self.eventLoop.makePromise(of: Void.self)
    self.sendRequest(.message("foo", .init(compress: false, flush: false)), promise: secondWrite)

    // Configure and connect to trigger the unbuffering. The recorded writes are not inspected,
    // but the recorder fulfils write promises as it receives them, which this test relies on.
    let promiseFulfillingRecorder = WriteRecorder<_GRPCClientRequestPart<String>>()
    self.configureTransport(additionalHandlers: [promiseFulfillingRecorder])
    try self.connect()

    // The first write and the cancellation both succeed.
    assertThat(try firstWrite.futureResult.wait(), .doesNotThrow())
    assertThat(try cancellation.futureResult.wait(), .doesNotThrow())

    // The second write fails because the cancellation got in ahead of it.
    assertThat(
      try secondWrite.futureResult.wait(),
      .throws(.instanceOf(GRPCError.RPCCancelledByClient.self))
    )
  }
  /// Cancels after the transport has been configured, connected and has written two parts.
  func testCancelWhenActive() throws {
    // Request parts written to the `Channel` are captured by this recorder.
    let writeRecorder = WriteRecorder<_GRPCClientRequestPart<String>>()
    self.setUpTransport()
    self.configureTransport(additionalHandlers: [writeRecorder])
    try self.connect()

    // The transport should be active now, so both parts should go straight to the recorder.
    self.sendRequest(.metadata([:]))
    self.sendRequest(.message("0", .init(compress: false, flush: false)))
    assertThat(writeRecorder.writes, .hasCount(2))

    // Cancelling at this point should succeed.
    let cancelled = self.eventLoop.makePromise(of: Void.self)
    self.cancel(promise: cancelled)
    assertThat(try cancelled.futureResult.wait(), .doesNotThrow())
  }
  /// Sets up a transport whose configuration is held open by a promise.
  ///
  /// NOTE(review): despite its name this test never cancels the transport and makes no
  /// assertions; the intended 'closing' scenario looks unfinished and should be completed by
  /// someone who can confirm the expected outcome.
  func testCancelWhenClosing() throws {
    self.setUpTransport()

    // Hold the configuration until we succeed the promise.
    let configuredPromise = self.eventLoop.makePromise(of: Void.self)
    // Complete the promise on exit so that it is not leaked, mirroring
    // `testCancelWhenAwaitingTransport`.
    defer {
      configuredPromise.succeed(())
    }

    self.configureTransport { handler in
      self.channel.pipeline.addHandler(handler).flatMap {
        configuredPromise.futureResult
      }
    }
  }
  /// Cancels after the channel has been closed; the cancellation should fail with
  /// `AlreadyComplete`.
  func testCancelWhenClosed() throws {
    // Bring the transport up, then close the channel straight away.
    self.setUpTransport()
    self.configureTransport()
    try self.connect()
    assertThat(try self.channel.close().wait(), .doesNotThrow())

    // Attempt to cancel the already-closed transport.
    let cancelPromise = self.eventLoop.makePromise(of: Void.self)
    self.cancel(promise: cancelPromise)

    // There is nothing left to cancel, so the promise should be failed.
    assertThat(
      try cancelPromise.futureResult.wait(),
      .throws(.instanceOf(GRPCError.AlreadyComplete.self))
    )
  }
  /// Fires an error through the pipeline of an active transport and checks that writes made
  /// afterwards fail with `AlreadyComplete`.
  func testErrorWhenActive() throws {
    // The only error the transport should report is the one fired below.
    let expectDummyError: (Error) -> Void = { error in
      assertThat(error, .is(.instanceOf(DummyError.self)))
    }
    self.setUpTransport(onError: expectDummyError)

    // Configure and activate.
    self.configureTransport()
    try self.connect()

    // Send the request metadata.
    let metadataWritten = self.eventLoop.makePromise(of: Void.self)
    self.sendRequest(.metadata([:]), promise: metadataWritten)
    // This is a unary call, so '.end' is required to emit a flush and complete the promise.
    self.sendRequest(.end, promise: nil)
    assertThat(try metadataWritten.futureResult.wait(), .doesNotThrow())

    // Fire an error back; the transport's error callback should see it.
    self.channel.pipeline.fireErrorCaught(DummyError())

    // The transport should now be closed, so a further write must fail.
    let writeAfterError = self.eventLoop.makePromise(of: Void.self)
    self.sendRequest(.end, promise: writeAfterError)
    assertThat(
      try writeAfterError.futureResult.wait(),
      .throws(.instanceOf(GRPCError.AlreadyComplete.self))
    )
  }
  /// Fails the transport's configuration and checks that the buffered writes fail and that a
  /// later cancellation fails with `AlreadyComplete`.
  func testConfigurationFails() throws {
    self.setUpTransport()

    // Two writes sent before the transport has been configured.
    let metadataWrite = self.eventLoop.makePromise(of: Void.self)
    self.sendRequest(.metadata([:]), promise: metadataWrite)

    let messageWrite = self.eventLoop.makePromise(of: Void.self)
    self.sendRequest(.message("0", .init(compress: false, flush: false)), promise: messageWrite)

    // Configuration fails instead of adding the handler to a pipeline.
    self.configureTransport { _ in
      self.eventLoop.makeFailedFuture(DummyError())
    }

    // Both pending writes should have been failed.
    for pendingWrite in [metadataWrite, messageWrite] {
      assertThat(try pendingWrite.futureResult.wait(), .throws())
    }

    // The transport is already closed, so cancelling should fail as well.
    let cancelPromise = self.eventLoop.makePromise(of: Void.self)
    self.cancel(promise: cancelPromise)
    assertThat(
      try cancelPromise.futureResult.wait(),
      .throws(.instanceOf(GRPCError.AlreadyComplete.self))
    )
  }
  279. }
  280. // MARK: - Helper Objects
  /// An outbound handler which stores every write it receives and succeeds the write's
  /// promise, without passing the write on through `context`.
  class WriteRecorder<Write>: ChannelOutboundHandler {
    typealias OutboundIn = Write

    /// The writes received so far, in the order they arrived.
    var writes: [Write] = []

    func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
      let recorded = self.unwrapOutboundIn(data)
      self.writes.append(recorded)

      // Complete the promise here: tests wait on write promises.
      if let promise = promise {
        promise.succeed(())
      }
    }
  }
  /// An error with no payload, used by these tests wherever an arbitrary failure is needed.
  private struct DummyError: Error {
  }
  /// Serializes a `String` by writing it into a newly allocated `ByteBuffer`.
  struct StringSerializer: MessageSerializer {
    typealias Input = String

    /// - Parameters:
    ///   - input: The string to serialize.
    ///   - allocator: The allocator used to create the buffer.
    /// - Returns: A buffer containing `input`.
    func serialize(_ input: String, allocator: ByteBufferAllocator) throws -> ByteBuffer {
      var serialized = allocator.buffer(capacity: input.utf8.count)
      serialized.writeString(input)
      return serialized
    }
  }
  /// Deserializes a `ByteBuffer` into a `String` made from all of its readable bytes.
  internal struct StringDeserializer: MessageDeserializer {
    typealias Output = String

    /// - Parameter byteBuffer: The buffer holding the serialized string.
    /// - Returns: A string created from every readable byte of `byteBuffer`.
    func deserialize(byteBuffer: ByteBuffer) throws -> String {
      // `String(buffer:)` reads all readable bytes, so no mutable copy or force-unwrap is needed.
      return String(buffer: byteBuffer)
    }
  }
  /// A `String` serializer which always fails by throwing `DummyError`.
  struct ThrowingStringSerializer: MessageSerializer {
    typealias Input = String

    /// - Throws: `DummyError`, unconditionally; neither argument is used.
    func serialize(_ input: String, allocator: ByteBufferAllocator) throws -> ByteBuffer {
      let failure = DummyError()
      throw failure
    }
  }
  /// A `String` deserializer which always fails by throwing `DummyError`.
  struct ThrowingStringDeserializer: MessageDeserializer {
    typealias Output = String

    /// - Throws: `DummyError`, unconditionally; the buffer is never read.
    func deserialize(byteBuffer: ByteBuffer) throws -> String {
      let failure = DummyError()
      throw failure
    }
  }