GRPCServerTests.swift 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616
  1. /*
  2. * Copyright 2023, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import GRPCCore
  17. import GRPCInProcessTransport
  18. import Testing
  19. import XCTest
  20. final class GRPCServerTests: XCTestCase {
  /// Runs `body` against a `GRPCServer` hosted on a fresh in-process transport, together with
  /// the client half of that transport.
  ///
  /// The server is started via `withGRPCServer`. The client's `connect()` runs as a child task
  /// for the duration of `body`; once `body` returns, graceful shutdown is begun on the client.
  ///
  /// - Parameters:
  ///   - services: Services to register with the server.
  ///   - interceptorPipeline: Interceptors to install on the server. Defaults to none.
  ///   - body: The test body, given the in-process client and the server.
  /// - Throws: Any error thrown by `body`, by the client's `connect()` or by `withGRPCServer`.
  func withInProcessClientConnectedToServer(
    services: [any RegistrableRPCService],
    interceptorPipeline: [ConditionalInterceptor<any ServerInterceptor>] = [],
    _ body: (InProcessTransport.Client, GRPCServer<InProcessTransport.Server>) async throws -> Void
  ) async throws {
    let transportPair = InProcessTransport()

    try await withGRPCServer(
      transport: transportPair.server,
      services: services,
      interceptorPipeline: interceptorPipeline
    ) { runningServer in
      try await withThrowingTaskGroup(of: Void.self) { tasks in
        // Keep the client connection running in the background while the body executes.
        tasks.addTask {
          try await transportPair.client.connect()
        }

        try await body(transportPair.client, runningServer)

        // The body is done: ask the client to wind down so the connect task can finish.
        transportPair.client.beginGracefulShutdown()
      }
    }
  }
  /// Sends one message to `BinaryEcho.Methods.get` and expects, in order: metadata, a message
  /// carrying the same bytes, and an `.ok` status.
  func testServerHandlesUnary() async throws {
    let payload: [UInt8] = [3, 1, 4, 1, 5]

    try await self.withInProcessClientConnectedToServer(services: [BinaryEcho()]) { client, _ in
      try await client.withStream(
        descriptor: BinaryEcho.Methods.get,
        options: .defaults
      ) { stream, _ in
        // Request: metadata, a single message, then close the outbound side.
        try await stream.outbound.write(.metadata([:]))
        try await stream.outbound.write(.message(payload))
        await stream.outbound.finish()

        var inbound = stream.inbound.makeAsyncIterator()

        let firstPart = try await inbound.next()
        XCTAssertMetadata(firstPart)

        let secondPart = try await inbound.next()
        XCTAssertMessage(secondPart) { bytes in
          XCTAssertEqual(bytes, payload)
        }

        let finalPart = try await inbound.next()
        XCTAssertStatus(finalPart) { status, _ in
          XCTAssertEqual(status.code, .ok)
        }
      }
    }
  }
  /// Sends five single-byte messages to `BinaryEcho.Methods.collect` and expects, in order:
  /// metadata, one message containing all five bytes, and an `.ok` status.
  func testServerHandlesClientStreaming() async throws {
    let payload: [UInt8] = [3, 1, 4, 1, 5]

    try await self.withInProcessClientConnectedToServer(services: [BinaryEcho()]) { client, _ in
      try await client.withStream(
        descriptor: BinaryEcho.Methods.collect,
        options: .defaults
      ) { stream, _ in
        try await stream.outbound.write(.metadata([:]))
        // One request message per byte.
        for byte in payload {
          try await stream.outbound.write(.message([byte]))
        }
        await stream.outbound.finish()

        var inbound = stream.inbound.makeAsyncIterator()

        let firstPart = try await inbound.next()
        XCTAssertMetadata(firstPart)

        // The individual bytes are expected back as a single combined message.
        let secondPart = try await inbound.next()
        XCTAssertMessage(secondPart) { bytes in
          XCTAssertEqual(bytes, payload)
        }

        let finalPart = try await inbound.next()
        XCTAssertStatus(finalPart) { status, _ in
          XCTAssertEqual(status.code, .ok)
        }
      }
    }
  }
  /// Sends one five-byte message to `BinaryEcho.Methods.expand` and expects, in order:
  /// metadata, five single-byte messages (one per request byte), and an `.ok` status.
  func testServerHandlesServerStreaming() async throws {
    let payload: [UInt8] = [3, 1, 4, 1, 5]

    try await self.withInProcessClientConnectedToServer(services: [BinaryEcho()]) { client, _ in
      try await client.withStream(
        descriptor: BinaryEcho.Methods.expand,
        options: .defaults
      ) { stream, _ in
        try await stream.outbound.write(.metadata([:]))
        try await stream.outbound.write(.message(payload))
        await stream.outbound.finish()

        var inbound = stream.inbound.makeAsyncIterator()

        let firstPart = try await inbound.next()
        XCTAssertMetadata(firstPart)

        // Each byte of the request is expected back as its own message, in order.
        for expectedByte in payload {
          let part = try await inbound.next()
          XCTAssertMessage(part) { bytes in
            XCTAssertEqual(bytes, [expectedByte])
          }
        }

        let finalPart = try await inbound.next()
        XCTAssertStatus(finalPart) { status, _ in
          XCTAssertEqual(status.code, .ok)
        }
      }
    }
  }
  /// Sends five single-byte messages to `BinaryEcho.Methods.update` and expects, in order:
  /// metadata, the same five single-byte messages, and an `.ok` status.
  func testServerHandlesBidirectionalStreaming() async throws {
    let payload: [UInt8] = [3, 1, 4, 1, 5]

    try await self.withInProcessClientConnectedToServer(services: [BinaryEcho()]) { client, _ in
      try await client.withStream(
        descriptor: BinaryEcho.Methods.update,
        options: .defaults
      ) { stream, _ in
        try await stream.outbound.write(.metadata([:]))
        for byte in payload {
          try await stream.outbound.write(.message([byte]))
        }
        await stream.outbound.finish()

        var inbound = stream.inbound.makeAsyncIterator()

        let firstPart = try await inbound.next()
        XCTAssertMetadata(firstPart)

        // One response message is expected per request message, in the order sent.
        for expectedByte in payload {
          let part = try await inbound.next()
          XCTAssertMessage(part) { bytes in
            XCTAssertEqual(bytes, [expectedByte])
          }
        }

        let finalPart = try await inbound.next()
        XCTAssertStatus(finalPart) { status, _ in
          XCTAssertEqual(status.code, .ok)
        }
      }
    }
  }
  /// Opens a stream for a method that no registered service provides and expects the first
  /// response part to be a status with code `.unimplemented`.
  func testUnimplementedMethod() async throws {
    let unknownMethod = MethodDescriptor(fullyQualifiedService: "not", method: "implemented")

    try await self.withInProcessClientConnectedToServer(services: [BinaryEcho()]) { client, _ in
      try await client.withStream(
        descriptor: unknownMethod,
        options: .defaults
      ) { stream, _ in
        try await stream.outbound.write(.metadata([:]))
        await stream.outbound.finish()

        var inbound = stream.inbound.makeAsyncIterator()
        let firstPart = try await inbound.next()
        XCTAssertStatus(firstPart) { status, _ in
          XCTAssertEqual(status.code, .unimplemented)
        }
      }
    }
  }
  /// Issues 255 concurrent `BinaryEcho.Methods.get` RPCs (one per `UInt8` value in
  /// `UInt8.min ..< UInt8.max`) and expects each to yield metadata, a message carrying its
  /// own byte, and an `.ok` status.
  ///
  /// - Throws: The first error thrown by any of the concurrent RPCs.
  func testMultipleConcurrentRequests() async throws {
    try await self.withInProcessClientConnectedToServer(services: [BinaryEcho()]) { client, _ in
      try await withThrowingTaskGroup(of: Void.self) { group in
        for i in UInt8.min ..< UInt8.max {
          group.addTask {
            try await client.withStream(
              descriptor: BinaryEcho.Methods.get,
              options: .defaults
            ) { stream, _ in
              try await stream.outbound.write(.metadata([:]))
              try await stream.outbound.write(.message([i]))
              await stream.outbound.finish()

              var responseParts = stream.inbound.makeAsyncIterator()

              let metadata = try await responseParts.next()
              XCTAssertMetadata(metadata)

              let message = try await responseParts.next()
              XCTAssertMessage(message) { XCTAssertEqual($0, [i]) }

              let status = try await responseParts.next()
              XCTAssertStatus(status) { status, _ in
                XCTAssertEqual(status.code, .ok)
              }
            }
          }
        }

        // Drain the group explicitly. If the group is only awaited implicitly on scope exit,
        // errors thrown by child tasks are discarded and a failing RPC would go unnoticed.
        try await group.waitForAll()
      }
    }
  }
  /// Installs a counting interceptor, a reject-all interceptor and a second counting
  /// interceptor, in that order, then makes one RPC.
  ///
  /// Expects the RPC to end with `.unavailable`, the first counter to read 1 and the second
  /// counter to read 0.
  func testInterceptorsAreAppliedInOrder() async throws {
    let beforeRejection = AtomicCounter()
    let afterRejection = AtomicCounter()

    try await self.withInProcessClientConnectedToServer(
      services: [BinaryEcho()],
      interceptorPipeline: [
        .apply(.requestCounter(beforeRejection), to: .all),
        .apply(.rejectAll(with: RPCError(code: .unavailable, message: "")), to: .all),
        .apply(.requestCounter(afterRejection), to: .all),
      ]
    ) { client, _ in
      try await client.withStream(
        descriptor: BinaryEcho.Methods.get,
        options: .defaults
      ) { stream, _ in
        try await stream.outbound.write(.metadata([:]))
        await stream.outbound.finish()

        let responseParts = try await stream.inbound.collect()
        XCTAssertStatus(responseParts.first) { status, _ in
          XCTAssertEqual(status.code, .unavailable)
        }
      }
    }

    XCTAssertEqual(beforeRejection.value, 1)
    XCTAssertEqual(afterRejection.value, 0)
  }
  /// Installs a counting interceptor for all methods, calls a method that no registered
  /// service provides, and expects an `.unimplemented` status with the counter still at 0.
  func testInterceptorsAreNotAppliedToUnimplementedMethods() async throws {
    let requestCount = AtomicCounter()

    try await self.withInProcessClientConnectedToServer(
      services: [BinaryEcho()],
      interceptorPipeline: [.apply(.requestCounter(requestCount), to: .all)]
    ) { client, _ in
      try await client.withStream(
        descriptor: MethodDescriptor(fullyQualifiedService: "not", method: "implemented"),
        options: .defaults
      ) { stream, _ in
        try await stream.outbound.write(.metadata([:]))
        await stream.outbound.finish()

        let responseParts = try await stream.inbound.collect()
        XCTAssertStatus(responseParts.first) { status, _ in
          XCTAssertEqual(status.code, .unimplemented)
        }
      }
    }

    XCTAssertEqual(requestCount.value, 0)
  }
  /// Begins graceful shutdown on the server after one successful RPC and expects an attempt
  /// to open a new stream to throw an RPC error with code `.failedPrecondition`.
  func testNoNewRPCsAfterServerStopListening() async throws {
    try await self.withInProcessClientConnectedToServer(services: [BinaryEcho()]) { client, server in
      // A completed RPC shows the server is up before it is told to stop.
      try await self.doEchoGet(using: client)

      // From this point new streams are expected to be refused.
      server.beginGracefulShutdown()

      await XCTAssertThrowsRPCErrorAsync {
        try await client.withStream(
          descriptor: BinaryEcho.Methods.get,
          options: .defaults
        ) { _, _ in
          XCTFail("Stream shouldn't be opened")
        }
      } errorHandler: { rpcError in
        XCTAssertEqual(rpcError.code, .failedPrecondition)
      }
    }
  }
  /// Begins graceful shutdown on the server while a `BinaryEcho.Methods.update` stream is
  /// open and expects that stream to still deliver a message and a status afterwards.
  func testInFlightRPCsCanContinueAfterServerStopListening() async throws {
    try await self.withInProcessClientConnectedToServer(services: [BinaryEcho()]) { client, server in
      try await client.withStream(
        descriptor: BinaryEcho.Methods.update,
        options: .defaults
      ) { stream, _ in
        try await stream.outbound.write(.metadata([:]))

        var inbound = stream.inbound.makeAsyncIterator()

        // Receiving metadata is enough to know the server is handling this RPC.
        let firstPart = try await inbound.next()
        XCTAssertMetadata(firstPart)

        // Stop accepting new streams; the stream already open should be unaffected.
        server.beginGracefulShutdown()

        try await stream.outbound.write(.message([0]))
        await stream.outbound.finish()

        let secondPart = try await inbound.next()
        XCTAssertMessage(secondPart) { bytes in
          XCTAssertEqual(bytes, [0])
        }

        let finalPart = try await inbound.next()
        XCTAssertStatus(finalPart)
      }
    }
  }
  /// Cancels the task running `serve()` after an RPC has completed and expects awaiting that
  /// task's value not to throw.
  func testCancelRunningServer() async throws {
    let transportPair = InProcessTransport()

    // The server runs in an unstructured task so that it can be cancelled independently.
    let serverTask = Task {
      let server = GRPCServer(transport: transportPair.server, services: [BinaryEcho()])
      try await server.serve()
    }

    try await withThrowingTaskGroup(of: Void.self) { tasks in
      tasks.addTask {
        try? await transportPair.client.connect()
      }

      // A completed RPC means the server must be running by now.
      try await self.doEchoGet(using: transportPair.client)

      serverTask.cancel()
      try await serverTask.value

      tasks.cancelAll()
    }
  }
  /// Runs a server in a task, cancels that task, waits for it, and then expects a second
  /// `serve()` call to throw a `RuntimeError` with code `.serverIsStopped`.
  func testTestRunStoppedServer() async throws {
    let stoppedServer = GRPCServer(
      transport: InProcessTransport.Server(peer: "in-process:1234"),
      services: []
    )

    // First run: start serving, then cancel and wait for the run to end.
    let firstRun = Task { try await stoppedServer.serve() }
    firstRun.cancel()
    try await firstRun.value

    // Second run: serving again is expected to be rejected.
    await XCTAssertThrowsErrorAsync(ofType: RuntimeError.self) {
      try await stoppedServer.serve()
    } errorHandler: { runtimeError in
      XCTAssertEqual(runtimeError.code, .serverIsStopped)
    }
  }
  /// Serves on a `ThrowOnRunServerTransport` and expects `serve()` to throw a `RuntimeError`
  /// with code `.transportError`.
  func testRunServerWhenTransportThrows() async throws {
    let failingServer = GRPCServer(transport: ThrowOnRunServerTransport(), services: [])

    await XCTAssertThrowsErrorAsync(ofType: RuntimeError.self) {
      try await failingServer.serve()
    } errorHandler: { runtimeError in
      XCTAssertEqual(runtimeError.code, .transportError)
    }
  }
  /// Performs one `BinaryEcho.Methods.get` RPC over `transport`, sending a single one-byte
  /// message, and asserts that exactly three response parts come back.
  ///
  /// - Parameter transport: The client transport to open the stream on.
  /// - Throws: Any error thrown while opening, writing to or reading from the stream.
  private func doEchoGet(using transport: some ClientTransport<[UInt8]>) async throws {
    try await transport.withStream(
      descriptor: BinaryEcho.Methods.get,
      options: .defaults
    ) { rpc, _ in
      try await rpc.outbound.write(.metadata([:]))
      try await rpc.outbound.write(.message([0]))
      await rpc.outbound.finish()

      // Only the number of parts is checked; callers use this as a liveness probe.
      let received = try await rpc.inbound.collect()
      XCTAssertEqual(received.count, 3)
    }
  }
  327. }
  328. @Suite("GRPC Server Tests")
  329. struct ServerTests {
  /// Registers `BinaryEcho` and `HelloWorld` with four counting interceptors scoped to
  /// `BinaryEcho` only, all services, `HelloWorld` only, and both services. Makes one RPC to
  /// each service and checks after each that only the applicable counters have advanced.
  @Test("Interceptors are applied only to specified services")
  func testInterceptorsAreAppliedToSpecifiedServices() async throws {
    let onlyBinaryEchoCounter = AtomicCounter()
    let allServicesCounter = AtomicCounter()
    let onlyHelloWorldCounter = AtomicCounter()
    let bothServicesCounter = AtomicCounter()

    try await self.withInProcessClientConnectedToServer(
      services: [BinaryEcho(), HelloWorld()],
      interceptorPipeline: [
        .apply(
          .requestCounter(onlyBinaryEchoCounter),
          to: .services([BinaryEcho.serviceDescriptor])
        ),
        .apply(.requestCounter(allServicesCounter), to: .all),
        .apply(
          .requestCounter(onlyHelloWorldCounter),
          to: .services([HelloWorld.serviceDescriptor])
        ),
        .apply(
          .requestCounter(bothServicesCounter),
          to: .services([BinaryEcho.serviceDescriptor, HelloWorld.serviceDescriptor])
        ),
      ]
    ) { client, _ in
      // Step 1: call `BinaryEcho`. Every counter except the `HelloWorld`-only one should move.
      let echoRequest = Array("hello".utf8)
      try await client.withStream(
        descriptor: BinaryEcho.Methods.get,
        options: .defaults
      ) { stream, _ in
        try await stream.outbound.write(.metadata([:]))
        try await stream.outbound.write(.message(echoRequest))
        await stream.outbound.finish()

        var inbound = stream.inbound.makeAsyncIterator()

        let firstPart = try await inbound.next()
        self.assertMetadata(firstPart)

        let secondPart = try await inbound.next()
        self.assertMessage(secondPart) { bytes in
          #expect(bytes == echoRequest)
        }

        let finalPart = try await inbound.next()
        self.assertStatus(finalPart) { status, _ in
          #expect(status.code == .ok, Comment(rawValue: status.description))
        }
      }

      #expect(onlyBinaryEchoCounter.value == 1)
      #expect(allServicesCounter.value == 1)
      #expect(onlyHelloWorldCounter.value == 0)
      #expect(bothServicesCounter.value == 1)

      // Step 2: call `HelloWorld`. Every counter except the `BinaryEcho`-only one should move.
      try await client.withStream(
        descriptor: HelloWorld.Methods.sayHello,
        options: .defaults
      ) { stream, _ in
        try await stream.outbound.write(.metadata([:]))
        try await stream.outbound.write(.message(Array("Swift".utf8)))
        await stream.outbound.finish()

        var inbound = stream.inbound.makeAsyncIterator()

        let firstPart = try await inbound.next()
        self.assertMetadata(firstPart)

        let secondPart = try await inbound.next()
        self.assertMessage(secondPart) { bytes in
          #expect(bytes == Array("Hello, Swift!".utf8))
        }

        let finalPart = try await inbound.next()
        self.assertStatus(finalPart) { status, _ in
          #expect(status.code == .ok, Comment(rawValue: status.description))
        }
      }

      #expect(onlyBinaryEchoCounter.value == 1)
      #expect(allServicesCounter.value == 2)
      #expect(onlyHelloWorldCounter.value == 1)
      #expect(bothServicesCounter.value == 2)
    }
  }
  /// Registers `BinaryEcho` with four counting interceptors scoped to `get` only, all
  /// methods, `collect` only, and both `get` and `collect`. Makes one RPC to each method and
  /// checks after each that only the applicable counters have advanced.
  @Test("Interceptors are applied only to specified methods")
  func testInterceptorsAreAppliedToSpecifiedMethods() async throws {
    let onlyBinaryEchoGetCounter = AtomicCounter()
    let onlyBinaryEchoCollectCounter = AtomicCounter()
    let bothBinaryEchoMethodsCounter = AtomicCounter()
    let allMethodsCounter = AtomicCounter()

    try await self.withInProcessClientConnectedToServer(
      services: [BinaryEcho()],
      interceptorPipeline: [
        .apply(
          .requestCounter(onlyBinaryEchoGetCounter),
          to: .methods([BinaryEcho.Methods.get])
        ),
        .apply(.requestCounter(allMethodsCounter), to: .all),
        .apply(
          .requestCounter(onlyBinaryEchoCollectCounter),
          to: .methods([BinaryEcho.Methods.collect])
        ),
        .apply(
          .requestCounter(bothBinaryEchoMethodsCounter),
          to: .methods([BinaryEcho.Methods.get, BinaryEcho.Methods.collect])
        ),
      ]
    ) { client, _ in
      let request = Array("hello".utf8)

      // Step 1: call `BinaryEcho/get`. Every counter except the `collect`-only one should move.
      try await client.withStream(
        descriptor: BinaryEcho.Methods.get,
        options: .defaults
      ) { stream, _ in
        try await stream.outbound.write(.metadata([:]))
        try await stream.outbound.write(.message(request))
        await stream.outbound.finish()

        var inbound = stream.inbound.makeAsyncIterator()

        let firstPart = try await inbound.next()
        self.assertMetadata(firstPart)

        let secondPart = try await inbound.next()
        self.assertMessage(secondPart) { bytes in
          #expect(bytes == request)
        }

        let finalPart = try await inbound.next()
        self.assertStatus(finalPart) { status, _ in
          #expect(status.code == .ok, Comment(rawValue: status.description))
        }
      }

      #expect(onlyBinaryEchoGetCounter.value == 1)
      #expect(allMethodsCounter.value == 1)
      #expect(onlyBinaryEchoCollectCounter.value == 0)
      #expect(bothBinaryEchoMethodsCounter.value == 1)

      // Step 2: call `BinaryEcho/collect`. Every counter except the `get`-only one should move.
      try await client.withStream(
        descriptor: BinaryEcho.Methods.collect,
        options: .defaults
      ) { stream, _ in
        try await stream.outbound.write(.metadata([:]))
        try await stream.outbound.write(.message(request))
        await stream.outbound.finish()

        var inbound = stream.inbound.makeAsyncIterator()

        let firstPart = try await inbound.next()
        self.assertMetadata(firstPart)

        let secondPart = try await inbound.next()
        self.assertMessage(secondPart) { bytes in
          #expect(bytes == request)
        }

        let finalPart = try await inbound.next()
        self.assertStatus(finalPart) { status, _ in
          #expect(status.code == .ok, Comment(rawValue: status.description))
        }
      }

      #expect(onlyBinaryEchoGetCounter.value == 1)
      #expect(allMethodsCounter.value == 2)
      #expect(onlyBinaryEchoCollectCounter.value == 1)
      #expect(bothBinaryEchoMethodsCounter.value == 2)
    }
  }
  /// Creates a `GRPCServer` on a fresh in-process transport and runs `body` with the paired
  /// client and the server.
  ///
  /// The server's `serve()` and the client's `connect()` each run as a child task of one
  /// throwing task group. After `body` returns, graceful shutdown is begun on the client and
  /// then on the server.
  ///
  /// - Parameters:
  ///   - services: Services to register with the server.
  ///   - interceptorPipeline: Interceptors to install on the server. Defaults to none.
  ///   - body: The test body, given the in-process client and the server.
  /// - Throws: Any error thrown by `body`, `serve()` or `connect()`.
  func withInProcessClientConnectedToServer(
    services: [any RegistrableRPCService],
    interceptorPipeline: [ConditionalInterceptor<any ServerInterceptor>] = [],
    _ body: (InProcessTransport.Client, GRPCServer<InProcessTransport.Server>) async throws -> Void
  ) async throws {
    let transportPair = InProcessTransport()
    let grpcServer = GRPCServer(
      transport: transportPair.server,
      services: services,
      interceptorPipeline: interceptorPipeline
    )

    try await withThrowingTaskGroup(of: Void.self) { tasks in
      // Server side.
      tasks.addTask {
        try await grpcServer.serve()
      }

      // Client side.
      tasks.addTask {
        try await transportPair.client.connect()
      }

      try await body(transportPair.client, grpcServer)

      // Wind both halves down so the two child tasks can complete.
      transportPair.client.beginGracefulShutdown()
      grpcServer.beginGracefulShutdown()
    }
  }
  /// Records an issue unless `part` is a `.metadata` response part.
  ///
  /// - Parameters:
  ///   - part: The response part to inspect. `nil` is treated as a mismatch.
  ///   - sourceLocation: The location a mismatch is attributed to. Defaults to the call site,
  ///     so a failure points at the calling test rather than at this helper.
  ///   - metadataHandler: Called with the metadata when `part` is `.metadata`. Defaults to a
  ///     no-op.
  func assertMetadata<Bytes: GRPCContiguousBytes>(
    _ part: RPCResponsePart<Bytes>?,
    sourceLocation: SourceLocation = #_sourceLocation,
    metadataHandler: (Metadata) -> Void = { _ in }
  ) {
    guard case .some(.metadata(let metadata)) = part else {
      Issue.record(
        "Expected '.metadata' but found '\(String(describing: part))'",
        sourceLocation: sourceLocation
      )
      return
    }

    metadataHandler(metadata)
  }
  /// Records an issue unless `part` is a `.message` response part.
  ///
  /// - Parameters:
  ///   - part: The response part to inspect. `nil` is treated as a mismatch.
  ///   - sourceLocation: The location a mismatch is attributed to. Defaults to the call site,
  ///     so a failure points at the calling test rather than at this helper.
  ///   - messageHandler: Called with the message bytes when `part` is `.message`. Defaults to
  ///     a no-op.
  func assertMessage<Bytes: GRPCContiguousBytes>(
    _ part: RPCResponsePart<Bytes>?,
    sourceLocation: SourceLocation = #_sourceLocation,
    messageHandler: (Bytes) -> Void = { _ in }
  ) {
    guard case .some(.message(let message)) = part else {
      Issue.record(
        "Expected '.message' but found '\(String(describing: part))'",
        sourceLocation: sourceLocation
      )
      return
    }

    messageHandler(message)
  }
  /// Records an issue unless `part` is a `.status` response part.
  ///
  /// - Parameters:
  ///   - part: The response part to inspect. `nil` is treated as a mismatch.
  ///   - sourceLocation: The location a mismatch is attributed to. Defaults to the call site,
  ///     so a failure points at the calling test rather than at this helper.
  ///   - statusHandler: Called with the status and its accompanying metadata when `part` is
  ///     `.status`. Defaults to a no-op.
  func assertStatus<Bytes: GRPCContiguousBytes>(
    _ part: RPCResponsePart<Bytes>?,
    sourceLocation: SourceLocation = #_sourceLocation,
    statusHandler: (Status, Metadata) -> Void = { _, _ in }
  ) {
    guard case .some(.status(let status, let metadata)) = part else {
      Issue.record(
        "Expected '.status' but found '\(String(describing: part))'",
        sourceLocation: sourceLocation
      )
      return
    }

    statusHandler(status, metadata)
  }
  538. }