// GRPCAsyncServerHandlerTests.swift
  1. /*
  2. * Copyright 2021, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #if compiler(>=5.6)
  17. @testable import GRPC
  18. import NIOCore
  19. import NIOEmbedded
  20. import NIOHPACK
  21. import NIOPosix
  22. import XCTest
  23. // MARK: - Tests
  24. @available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
  25. class AsyncServerHandlerTests: GRPCTestCase {
  26. private let recorder = AsyncResponseStream()
  27. private var group: EventLoopGroup!
  28. private var loop: EventLoop!
  29. override func setUp() {
  30. super.setUp()
  31. self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
  32. self.loop = self.group.next()
  33. }
  34. override func tearDown() {
  35. XCTAssertNoThrow(try self.group.syncShutdownGracefully())
  36. super.tearDown()
  37. }
  38. func makeCallHandlerContext(
  39. encoding: ServerMessageEncoding = .disabled
  40. ) -> CallHandlerContext {
  41. let closeFuture = self.loop.makeSucceededVoidFuture()
  42. return CallHandlerContext(
  43. errorDelegate: nil,
  44. logger: self.logger,
  45. encoding: encoding,
  46. eventLoop: self.loop,
  47. path: "/ignored",
  48. remoteAddress: nil,
  49. responseWriter: self.recorder,
  50. allocator: ByteBufferAllocator(),
  51. closeFuture: closeFuture
  52. )
  53. }
  54. private func makeHandler(
  55. encoding: ServerMessageEncoding = .disabled,
  56. callType: GRPCCallType = .bidirectionalStreaming,
  57. observer: @escaping @Sendable (
  58. GRPCAsyncRequestStream<String>,
  59. GRPCAsyncResponseStreamWriter<String>,
  60. GRPCAsyncServerCallContext
  61. ) async throws -> Void
  62. ) -> AsyncServerHandler<StringSerializer, StringDeserializer, String, String> {
  63. return AsyncServerHandler(
  64. context: self.makeCallHandlerContext(encoding: encoding),
  65. requestDeserializer: StringDeserializer(),
  66. responseSerializer: StringSerializer(),
  67. callType: callType,
  68. interceptors: [],
  69. userHandler: observer
  70. )
  71. }
  72. @Sendable
  73. private static func echo(
  74. requests: GRPCAsyncRequestStream<String>,
  75. responseStreamWriter: GRPCAsyncResponseStreamWriter<String>,
  76. context: GRPCAsyncServerCallContext
  77. ) async throws {
  78. for try await message in requests {
  79. try await responseStreamWriter.send(message)
  80. }
  81. }
  82. @Sendable
  83. private static func neverReceivesMessage(
  84. requests: GRPCAsyncRequestStream<String>,
  85. responseStreamWriter: GRPCAsyncResponseStreamWriter<String>,
  86. context: GRPCAsyncServerCallContext
  87. ) async throws {
  88. for try await message in requests {
  89. XCTFail("Unexpected message: '\(message)'")
  90. }
  91. }
  92. @Sendable
  93. private static func neverCalled(
  94. requests: GRPCAsyncRequestStream<String>,
  95. responseStreamWriter: GRPCAsyncResponseStreamWriter<String>,
  96. context: GRPCAsyncServerCallContext
  97. ) async throws {
  98. XCTFail("This observer should never be called")
  99. }
  100. func testHappyPath() async throws {
  101. let handler = self.makeHandler(
  102. observer: Self.echo(requests:responseStreamWriter:context:)
  103. )
  104. defer {
  105. XCTAssertNoThrow(try self.loop.submit { handler.finish() }.wait())
  106. }
  107. self.loop.execute {
  108. handler.receiveMetadata([:])
  109. handler.receiveMessage(ByteBuffer(string: "1"))
  110. handler.receiveMessage(ByteBuffer(string: "2"))
  111. handler.receiveMessage(ByteBuffer(string: "3"))
  112. handler.receiveEnd()
  113. }
  114. let responseStream = self.recorder.responseSequence.makeAsyncIterator()
  115. await responseStream.next().assertMetadata()
  116. for expected in ["1", "2", "3"] {
  117. await responseStream.next().assertMessage { buffer, metadata in
  118. XCTAssertEqual(buffer, .init(string: expected))
  119. XCTAssertFalse(metadata.compress)
  120. }
  121. }
  122. await responseStream.next().assertStatus { status, _ in
  123. XCTAssertEqual(status.code, .ok)
  124. }
  125. await responseStream.next().assertNil()
  126. }
  127. func testHappyPathWithCompressionEnabled() async throws {
  128. let handler = self.makeHandler(
  129. encoding: .enabled(.init(decompressionLimit: .absolute(.max))),
  130. observer: Self.echo(requests:responseStreamWriter:context:)
  131. )
  132. defer {
  133. XCTAssertNoThrow(try self.loop.submit { handler.finish() }.wait())
  134. }
  135. self.loop.execute {
  136. handler.receiveMetadata([:])
  137. handler.receiveMessage(ByteBuffer(string: "1"))
  138. handler.receiveMessage(ByteBuffer(string: "2"))
  139. handler.receiveMessage(ByteBuffer(string: "3"))
  140. handler.receiveEnd()
  141. }
  142. let responseStream = self.recorder.responseSequence.makeAsyncIterator()
  143. await responseStream.next().assertMetadata()
  144. for expected in ["1", "2", "3"] {
  145. await responseStream.next().assertMessage { buffer, metadata in
  146. XCTAssertEqual(buffer, .init(string: expected))
  147. XCTAssertTrue(metadata.compress)
  148. }
  149. }
  150. await responseStream.next().assertStatus()
  151. await responseStream.next().assertNil()
  152. }
  153. func testHappyPathWithCompressionEnabledButDisabledByCaller() async throws {
  154. let handler = self.makeHandler(
  155. encoding: .enabled(.init(decompressionLimit: .absolute(.max)))
  156. ) { requests, responseStreamWriter, context in
  157. try await context.response.compressResponses(false)
  158. return try await Self.echo(
  159. requests: requests,
  160. responseStreamWriter: responseStreamWriter,
  161. context: context
  162. )
  163. }
  164. defer {
  165. XCTAssertNoThrow(try self.loop.submit { handler.finish() }.wait())
  166. }
  167. self.loop.execute {
  168. handler.receiveMetadata([:])
  169. handler.receiveMessage(ByteBuffer(string: "1"))
  170. handler.receiveMessage(ByteBuffer(string: "2"))
  171. handler.receiveMessage(ByteBuffer(string: "3"))
  172. handler.receiveEnd()
  173. }
  174. let responseStream = self.recorder.responseSequence.makeAsyncIterator()
  175. await responseStream.next().assertMetadata()
  176. for expected in ["1", "2", "3"] {
  177. await responseStream.next().assertMessage { buffer, metadata in
  178. XCTAssertEqual(buffer, .init(string: expected))
  179. XCTAssertFalse(metadata.compress)
  180. }
  181. }
  182. await responseStream.next().assertStatus()
  183. await responseStream.next().assertNil()
  184. }
  185. func testResponseHeadersAndTrailersSentFromContext() async throws {
  186. let handler = self.makeHandler { _, responseStreamWriter, context in
  187. try await context.response.setHeaders(["pontiac": "bandit"])
  188. try await responseStreamWriter.send("1")
  189. try await context.response.setTrailers(["disco": "strangler"])
  190. }
  191. defer {
  192. XCTAssertNoThrow(try self.loop.submit { handler.finish() }.wait())
  193. }
  194. self.loop.execute {
  195. handler.receiveMetadata([:])
  196. handler.receiveEnd()
  197. }
  198. let responseStream = self.recorder.responseSequence.makeAsyncIterator()
  199. await responseStream.next().assertMetadata { headers in
  200. XCTAssertEqual(headers, ["pontiac": "bandit"])
  201. }
  202. await responseStream.next().assertMessage()
  203. await responseStream.next().assertStatus { _, trailers in
  204. XCTAssertEqual(trailers, ["disco": "strangler"])
  205. }
  206. await responseStream.next().assertNil()
  207. }
  208. func testResponseSequence() async throws {
  209. let handler = self.makeHandler { _, responseStreamWriter, _ in
  210. try await responseStreamWriter.send(contentsOf: ["1", "2", "3"])
  211. }
  212. defer {
  213. XCTAssertNoThrow(try self.loop.submit { handler.finish() }.wait())
  214. }
  215. self.loop.execute {
  216. handler.receiveMetadata([:])
  217. handler.receiveEnd()
  218. }
  219. let responseStream = self.recorder.responseSequence.makeAsyncIterator()
  220. await responseStream.next().assertMetadata { _ in }
  221. await responseStream.next().assertMessage()
  222. await responseStream.next().assertMessage()
  223. await responseStream.next().assertMessage()
  224. await responseStream.next().assertStatus { _, _ in }
  225. await responseStream.next().assertNil()
  226. }
  227. func testThrowingDeserializer() async throws {
  228. let handler = AsyncServerHandler(
  229. context: self.makeCallHandlerContext(),
  230. requestDeserializer: ThrowingStringDeserializer(),
  231. responseSerializer: StringSerializer(),
  232. callType: .bidirectionalStreaming,
  233. interceptors: [],
  234. userHandler: Self.neverReceivesMessage(requests:responseStreamWriter:context:)
  235. )
  236. defer {
  237. XCTAssertNoThrow(try self.loop.submit { handler.finish() }.wait())
  238. }
  239. self.loop.execute {
  240. handler.receiveMetadata([:])
  241. handler.receiveMessage(ByteBuffer(string: "hello"))
  242. }
  243. let responseStream = self.recorder.responseSequence.makeAsyncIterator()
  244. await responseStream.next().assertStatus { status, _ in
  245. XCTAssertEqual(status.code, .internalError)
  246. }
  247. await responseStream.next().assertNil()
  248. }
  249. func testThrowingSerializer() async throws {
  250. let handler = AsyncServerHandler(
  251. context: self.makeCallHandlerContext(),
  252. requestDeserializer: StringDeserializer(),
  253. responseSerializer: ThrowingStringSerializer(),
  254. callType: .bidirectionalStreaming,
  255. interceptors: [],
  256. userHandler: Self.echo(requests:responseStreamWriter:context:)
  257. )
  258. defer {
  259. XCTAssertNoThrow(try self.loop.submit { handler.finish() }.wait())
  260. }
  261. self.loop.execute {
  262. handler.receiveMetadata([:])
  263. handler.receiveMessage(ByteBuffer(string: "hello"))
  264. }
  265. let responseStream = self.recorder.responseSequence.makeAsyncIterator()
  266. await responseStream.next().assertMetadata()
  267. await responseStream.next().assertStatus { status, _ in
  268. XCTAssertEqual(status.code, .internalError)
  269. }
  270. await responseStream.next().assertNil()
  271. }
  272. func testReceiveMessageBeforeHeaders() async throws {
  273. let handler = self.makeHandler(
  274. observer: Self.neverCalled(requests:responseStreamWriter:context:)
  275. )
  276. defer {
  277. XCTAssertNoThrow(try self.loop.submit { handler.finish() }.wait())
  278. }
  279. self.loop.execute {
  280. handler.receiveMessage(ByteBuffer(string: "foo"))
  281. }
  282. let responseStream = self.recorder.responseSequence.makeAsyncIterator()
  283. await responseStream.next().assertStatus { status, _ in
  284. XCTAssertEqual(status.code, .internalError)
  285. }
  286. await responseStream.next().assertNil()
  287. }
  288. func testReceiveMultipleHeaders() async throws {
  289. let handler = self.makeHandler(
  290. observer: Self.neverReceivesMessage(requests:responseStreamWriter:context:)
  291. )
  292. defer {
  293. XCTAssertNoThrow(try self.loop.submit { handler.finish() }.wait())
  294. }
  295. self.loop.execute {
  296. handler.receiveMetadata([:])
  297. handler.receiveMetadata([:])
  298. }
  299. let responseStream = self.recorder.responseSequence.makeAsyncIterator()
  300. await responseStream.next().assertStatus { status, _ in
  301. XCTAssertEqual(status.code, .internalError)
  302. }
  303. await responseStream.next().assertNil()
  304. }
  305. func testFinishBeforeStarting() async throws {
  306. let handler = self.makeHandler(
  307. observer: Self.neverCalled(requests:responseStreamWriter:context:)
  308. )
  309. self.loop.execute {
  310. handler.finish()
  311. }
  312. let responseStream = self.recorder.responseSequence.makeAsyncIterator()
  313. await responseStream.next().assertStatus()
  314. await responseStream.next().assertNil()
  315. }
  316. func testFinishAfterHeaders() async throws {
  317. let handler = self.makeHandler(
  318. observer: Self.neverReceivesMessage(requests:responseStreamWriter:context:)
  319. )
  320. self.loop.execute {
  321. handler.receiveMetadata([:])
  322. handler.finish()
  323. }
  324. let responseStream = self.recorder.responseSequence.makeAsyncIterator()
  325. await responseStream.next().assertStatus()
  326. await responseStream.next().assertNil()
  327. }
  328. func testFinishAfterMessage() async throws {
  329. let handler = self.makeHandler(observer: Self.echo(requests:responseStreamWriter:context:))
  330. self.loop.execute {
  331. handler.receiveMetadata([:])
  332. handler.receiveMessage(ByteBuffer(string: "hello"))
  333. }
  334. // Await the metadata and message so we know the user function is running.
  335. let responseStream = self.recorder.responseSequence.makeAsyncIterator()
  336. await responseStream.next().assertMetadata()
  337. await responseStream.next().assertMessage()
  338. // Finish, i.e. terminate early.
  339. self.loop.execute {
  340. handler.finish()
  341. }
  342. await responseStream.next().assertStatus { status, _ in
  343. XCTAssertEqual(status.code, .internalError)
  344. }
  345. await responseStream.next().assertNil()
  346. }
  347. func testErrorAfterHeaders() async throws {
  348. let handler = self.makeHandler(observer: Self.echo(requests:responseStreamWriter:context:))
  349. self.loop.execute {
  350. handler.receiveMetadata([:])
  351. handler.receiveError(CancellationError())
  352. }
  353. // We don't send a message so we don't expect any responses. As metadata is sent lazily on the
  354. // first message we don't expect to get metadata back either.
  355. let responseStream = self.recorder.responseSequence.makeAsyncIterator()
  356. await responseStream.next().assertStatus { status, _ in
  357. XCTAssertEqual(status.code, .unavailable)
  358. }
  359. await responseStream.next().assertNil()
  360. }
  361. func testErrorAfterMessage() async throws {
  362. let handler = self.makeHandler(observer: Self.echo(requests:responseStreamWriter:context:))
  363. self.loop.execute {
  364. handler.receiveMetadata([:])
  365. handler.receiveMessage(ByteBuffer(string: "hello"))
  366. }
  367. // Wait the metadata and message; i.e. for function to have been invoked.
  368. let responseStream = self.recorder.responseSequence.makeAsyncIterator()
  369. await responseStream.next().assertMetadata()
  370. await responseStream.next().assertMessage()
  371. // Throw in an error.
  372. self.loop.execute {
  373. handler.receiveError(CancellationError())
  374. }
  375. // The RPC should end.
  376. await responseStream.next().assertStatus { status, _ in
  377. XCTAssertEqual(status.code, .unavailable)
  378. }
  379. await responseStream.next().assertNil()
  380. }
  381. func testHandlerThrowsGRPCStatusOKResultsInUnknownStatus() async throws {
  382. // Create a user function that immediately throws GRPCStatus.ok.
  383. let handler = self.makeHandler { _, _, _ in
  384. throw GRPCStatus.ok
  385. }
  386. // Send some metadata to trigger the creation of the async task with the user function.
  387. self.loop.execute {
  388. handler.receiveMetadata([:])
  389. }
  390. let responseStream = self.recorder.responseSequence.makeAsyncIterator()
  391. await responseStream.next().assertStatus { status, _ in
  392. XCTAssertEqual(status.code, .unknown)
  393. }
  394. await responseStream.next().assertNil()
  395. }
  396. func testUnaryHandlerReceivingMultipleMessages() async throws {
  397. @Sendable
  398. func neverCalled(_: String, _: GRPCAsyncServerCallContext) async throws -> String {
  399. XCTFail("Should not be called")
  400. return ""
  401. }
  402. let handler = GRPCAsyncServerHandler(
  403. context: self.makeCallHandlerContext(),
  404. requestDeserializer: StringDeserializer(),
  405. responseSerializer: StringSerializer(),
  406. interceptors: [],
  407. wrapping: neverCalled(_:_:)
  408. )
  409. defer {
  410. XCTAssertNoThrow(try self.loop.submit { handler.finish() }.wait())
  411. }
  412. self.loop.execute {
  413. handler.receiveMetadata([:])
  414. handler.receiveMessage(ByteBuffer(string: "1"))
  415. handler.receiveMessage(ByteBuffer(string: "2"))
  416. }
  417. let responseStream = self.recorder.responseSequence.makeAsyncIterator()
  418. await responseStream.next().assertStatus { status, _ in
  419. XCTAssertEqual(status.code, .internalError)
  420. }
  421. }
  422. func testServerStreamingHandlerReceivingMultipleMessages() async throws {
  423. @Sendable
  424. func neverCalled(
  425. _: String,
  426. _: GRPCAsyncResponseStreamWriter<String>,
  427. _: GRPCAsyncServerCallContext
  428. ) async throws {
  429. XCTFail("Should not be called")
  430. }
  431. let handler = GRPCAsyncServerHandler(
  432. context: self.makeCallHandlerContext(),
  433. requestDeserializer: StringDeserializer(),
  434. responseSerializer: StringSerializer(),
  435. interceptors: [],
  436. wrapping: neverCalled(_:_:_:)
  437. )
  438. defer {
  439. XCTAssertNoThrow(try self.loop.submit { handler.finish() }.wait())
  440. }
  441. self.loop.execute {
  442. handler.receiveMetadata([:])
  443. handler.receiveMessage(ByteBuffer(string: "1"))
  444. handler.receiveMessage(ByteBuffer(string: "2"))
  445. }
  446. let responseStream = self.recorder.responseSequence.makeAsyncIterator()
  447. await responseStream.next().assertStatus { status, _ in
  448. XCTAssertEqual(status.code, .internalError)
  449. }
  450. }
  451. }
  452. internal final class AsyncResponseStream: GRPCServerResponseWriter {
  453. private let source: NIOAsyncSequenceProducer<
  454. GRPCServerResponsePart<ByteBuffer>,
  455. NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark,
  456. GRPCAsyncSequenceProducerDelegate
  457. >.Source
  458. internal var responseSequence: NIOAsyncSequenceProducer<
  459. GRPCServerResponsePart<ByteBuffer>,
  460. NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark,
  461. GRPCAsyncSequenceProducerDelegate
  462. >
  463. init() {
  464. let backpressureStrategy = NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark(
  465. lowWatermark: 10,
  466. highWatermark: 50
  467. )
  468. let sequence = NIOAsyncSequenceProducer.makeSequence(
  469. elementType: GRPCServerResponsePart<ByteBuffer>.self,
  470. backPressureStrategy: backpressureStrategy,
  471. delegate: GRPCAsyncSequenceProducerDelegate()
  472. )
  473. self.source = sequence.source
  474. self.responseSequence = sequence.sequence
  475. }
  476. func sendMetadata(
  477. _ metadata: HPACKHeaders,
  478. flush: Bool,
  479. promise: EventLoopPromise<Void>?
  480. ) {
  481. _ = self.source.yield(.metadata(metadata))
  482. promise?.succeed(())
  483. }
  484. func sendMessage(
  485. _ bytes: ByteBuffer,
  486. metadata: MessageMetadata,
  487. promise: EventLoopPromise<Void>?
  488. ) {
  489. _ = self.source.yield(.message(bytes, metadata))
  490. promise?.succeed(())
  491. }
  492. func sendEnd(
  493. status: GRPCStatus,
  494. trailers: HPACKHeaders,
  495. promise: EventLoopPromise<Void>?
  496. ) {
  497. _ = self.source.yield(.end(status, trailers))
  498. self.source.finish()
  499. promise?.succeed(())
  500. }
  501. func stopRecording() {
  502. self.source.finish()
  503. }
  504. }
  505. extension Optional where Wrapped == GRPCServerResponsePart<ByteBuffer> {
  506. func assertNil() {
  507. XCTAssertNil(self)
  508. }
  509. func assertMetadata(_ verify: (HPACKHeaders) -> Void = { _ in }) {
  510. switch self {
  511. case let .some(.metadata(headers)):
  512. verify(headers)
  513. default:
  514. XCTFail("Expected metadata but value was \(String(describing: self))")
  515. }
  516. }
  517. func assertMessage(_ verify: (ByteBuffer, MessageMetadata) -> Void = { _, _ in }) {
  518. switch self {
  519. case let .some(.message(buffer, metadata)):
  520. verify(buffer, metadata)
  521. default:
  522. XCTFail("Expected message but value was \(String(describing: self))")
  523. }
  524. }
  525. func assertStatus(_ verify: (GRPCStatus, HPACKHeaders) -> Void = { _, _ in }) {
  526. switch self {
  527. case let .some(.end(status, trailers)):
  528. verify(status, trailers)
  529. default:
  530. XCTFail("Expected status but value was \(String(describing: self))")
  531. }
  532. }
  533. }
  534. #endif