GRPCServerTests.swift 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613
  1. /*
  2. * Copyright 2023, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import GRPCCore
  17. import GRPCInProcessTransport
  18. import Testing
  19. import XCTest
  /// XCTest-based tests for `GRPCServer`, driven over the in-process transport.
  final class GRPCServerTests: XCTestCase {
    /// Runs `body` with an in-process client connected to a `GRPCServer` hosting `services`.
    ///
    /// - Parameters:
    ///   - services: The services registered with the server.
    ///   - interceptorPipeline: Interceptor operations the server applies to incoming RPCs.
    ///   - body: The test body; receives the client transport and the server.
    func withInProcessClientConnectedToServer(
      services: [any RegistrableRPCService],
      interceptorPipeline: [ServerInterceptorPipelineOperation] = [],
      _ body: (InProcessTransport.Client, GRPCServer) async throws -> Void
    ) async throws {
      let inProcess = InProcessTransport()
      try await withGRPCServer(
        transport: inProcess.server,
        services: services,
        interceptorPipeline: interceptorPipeline
      ) { server in
        try await withThrowingTaskGroup(of: Void.self) { group in
          // The client's connection runs in a child task alongside the test body.
          group.addTask {
            try await inProcess.client.connect()
          }
          try await body(inProcess.client, server)
          // The group waits for the `connect()` child before returning; shutting the
          // client down is presumably what lets that call return (confirm against
          // `InProcessTransport.Client`).
          inProcess.client.beginGracefulShutdown()
        }
      }
    }

    /// A unary RPC echoes the request message back and finishes with an OK status.
    func testServerHandlesUnary() async throws {
      try await self.withInProcessClientConnectedToServer(services: [BinaryEcho()]) { client, _ in
        try await client.withStream(
          descriptor: BinaryEcho.Methods.get,
          options: .defaults
        ) { stream in
          try await stream.outbound.write(.metadata([:]))
          try await stream.outbound.write(.message([3, 1, 4, 1, 5]))
          await stream.outbound.finish()

          var responseParts = stream.inbound.makeAsyncIterator()
          let metadata = try await responseParts.next()
          XCTAssertMetadata(metadata)

          let message = try await responseParts.next()
          XCTAssertMessage(message) {
            XCTAssertEqual($0, [3, 1, 4, 1, 5])
          }

          let status = try await responseParts.next()
          XCTAssertStatus(status) { status, _ in
            XCTAssertEqual(status.code, .ok)
          }
        }
      }
    }

    /// A client-streaming RPC responds with one message concatenating all request messages.
    func testServerHandlesClientStreaming() async throws {
      try await self.withInProcessClientConnectedToServer(services: [BinaryEcho()]) { client, _ in
        try await client.withStream(
          descriptor: BinaryEcho.Methods.collect,
          options: .defaults
        ) { stream in
          try await stream.outbound.write(.metadata([:]))
          try await stream.outbound.write(.message([3]))
          try await stream.outbound.write(.message([1]))
          try await stream.outbound.write(.message([4]))
          try await stream.outbound.write(.message([1]))
          try await stream.outbound.write(.message([5]))
          await stream.outbound.finish()

          var responseParts = stream.inbound.makeAsyncIterator()
          let metadata = try await responseParts.next()
          XCTAssertMetadata(metadata)

          let message = try await responseParts.next()
          XCTAssertMessage(message) {
            XCTAssertEqual($0, [3, 1, 4, 1, 5])
          }

          let status = try await responseParts.next()
          XCTAssertStatus(status) { status, _ in
            XCTAssertEqual(status.code, .ok)
          }
        }
      }
    }

    /// A server-streaming RPC responds with one single-byte message per byte of the request.
    func testServerHandlesServerStreaming() async throws {
      try await self.withInProcessClientConnectedToServer(services: [BinaryEcho()]) { client, _ in
        try await client.withStream(
          descriptor: BinaryEcho.Methods.expand,
          options: .defaults
        ) { stream in
          try await stream.outbound.write(.metadata([:]))
          try await stream.outbound.write(.message([3, 1, 4, 1, 5]))
          await stream.outbound.finish()

          var responseParts = stream.inbound.makeAsyncIterator()
          let metadata = try await responseParts.next()
          XCTAssertMetadata(metadata)

          for byte in [3, 1, 4, 1, 5] as [UInt8] {
            let message = try await responseParts.next()
            XCTAssertMessage(message) {
              XCTAssertEqual($0, [byte])
            }
          }

          let status = try await responseParts.next()
          XCTAssertStatus(status) { status, _ in
            XCTAssertEqual(status.code, .ok)
          }
        }
      }
    }

    /// A bidirectional-streaming RPC echoes each request message back as its own response.
    func testServerHandlesBidirectionalStreaming() async throws {
      try await self.withInProcessClientConnectedToServer(services: [BinaryEcho()]) { client, _ in
        try await client.withStream(
          descriptor: BinaryEcho.Methods.update,
          options: .defaults
        ) { stream in
          try await stream.outbound.write(.metadata([:]))
          for byte in [3, 1, 4, 1, 5] as [UInt8] {
            try await stream.outbound.write(.message([byte]))
          }
          await stream.outbound.finish()

          var responseParts = stream.inbound.makeAsyncIterator()
          let metadata = try await responseParts.next()
          XCTAssertMetadata(metadata)

          for byte in [3, 1, 4, 1, 5] as [UInt8] {
            let message = try await responseParts.next()
            XCTAssertMessage(message) {
              XCTAssertEqual($0, [byte])
            }
          }

          let status = try await responseParts.next()
          XCTAssertStatus(status) { status, _ in
            XCTAssertEqual(status.code, .ok)
          }
        }
      }
    }

    /// An RPC to a method no registered service implements fails with `.unimplemented`.
    func testUnimplementedMethod() async throws {
      try await self.withInProcessClientConnectedToServer(services: [BinaryEcho()]) { client, _ in
        try await client.withStream(
          descriptor: MethodDescriptor(service: "not", method: "implemented"),
          options: .defaults
        ) { stream in
          try await stream.outbound.write(.metadata([:]))
          await stream.outbound.finish()

          var responseParts = stream.inbound.makeAsyncIterator()
          let status = try await responseParts.next()
          XCTAssertStatus(status) { status, _ in
            XCTAssertEqual(status.code, .unimplemented)
          }
        }
      }
    }

    /// Many unary RPCs issued concurrently over one connection each get their own echo.
    func testMultipleConcurrentRequests() async throws {
      try await self.withInProcessClientConnectedToServer(services: [BinaryEcho()]) { client, _ in
        try await withThrowingTaskGroup(of: Void.self) { group in
          for i in UInt8.min ..< UInt8.max {
            group.addTask {
              try await client.withStream(
                descriptor: BinaryEcho.Methods.get,
                options: .defaults
              ) { stream in
                try await stream.outbound.write(.metadata([:]))
                try await stream.outbound.write(.message([i]))
                await stream.outbound.finish()

                var responseParts = stream.inbound.makeAsyncIterator()
                let metadata = try await responseParts.next()
                XCTAssertMetadata(metadata)

                let message = try await responseParts.next()
                XCTAssertMessage(message) { XCTAssertEqual($0, [i]) }

                let status = try await responseParts.next()
                XCTAssertStatus(status) { status, _ in
                  XCTAssertEqual(status.code, .ok)
                }
              }
            }
          }

          // A task group only surfaces child errors that are explicitly awaited; the
          // implicit drain at the end of the closure discards them. Without this, a
          // failing RPC would be silently ignored and the test would still pass.
          try await group.waitForAll()
        }
      }
    }

    /// Interceptors run in pipeline order: once one rejects the RPC, later ones never run.
    func testInterceptorsAreAppliedInOrder() async throws {
      let counter1 = AtomicCounter()
      let counter2 = AtomicCounter()

      try await self.withInProcessClientConnectedToServer(
        services: [BinaryEcho()],
        interceptorPipeline: [
          .apply(.requestCounter(counter1), to: .all),
          .apply(.rejectAll(with: RPCError(code: .unavailable, message: "")), to: .all),
          .apply(.requestCounter(counter2), to: .all),
        ]
      ) { client, _ in
        try await client.withStream(
          descriptor: BinaryEcho.Methods.get,
          options: .defaults
        ) { stream in
          try await stream.outbound.write(.metadata([:]))
          await stream.outbound.finish()

          let parts = try await stream.inbound.collect()
          XCTAssertStatus(parts.first) { status, _ in
            XCTAssertEqual(status.code, .unavailable)
          }
        }
      }

      XCTAssertEqual(counter1.value, 1)
      XCTAssertEqual(counter2.value, 0)
    }

    /// RPCs to unimplemented methods fail without passing through any interceptor.
    func testInterceptorsAreNotAppliedToUnimplementedMethods() async throws {
      let counter = AtomicCounter()

      try await self.withInProcessClientConnectedToServer(
        services: [BinaryEcho()],
        interceptorPipeline: [.apply(.requestCounter(counter), to: .all)]
      ) { client, _ in
        try await client.withStream(
          descriptor: MethodDescriptor(service: "not", method: "implemented"),
          options: .defaults
        ) { stream in
          try await stream.outbound.write(.metadata([:]))
          await stream.outbound.finish()

          let parts = try await stream.inbound.collect()
          XCTAssertStatus(parts.first) { status, _ in
            XCTAssertEqual(status.code, .unimplemented)
          }
        }
      }

      XCTAssertEqual(counter.value, 0)
    }

    /// Once the server begins graceful shutdown, opening a new stream fails with `.failedPrecondition`.
    func testNoNewRPCsAfterServerStopListening() async throws {
      try await self.withInProcessClientConnectedToServer(services: [BinaryEcho()]) { client, server in
        // Run an RPC so we know the server is up.
        try await self.doEchoGet(using: client)

        // New streams should fail immediately after this.
        server.beginGracefulShutdown()

        // RPC should fail now.
        await XCTAssertThrowsRPCErrorAsync {
          try await client.withStream(
            descriptor: BinaryEcho.Methods.get,
            options: .defaults
          ) { _ in
            XCTFail("Stream shouldn't be opened")
          }
        } errorHandler: { error in
          XCTAssertEqual(error.code, .failedPrecondition)
        }
      }
    }

    /// An RPC already in flight when graceful shutdown begins can still complete normally.
    func testInFlightRPCsCanContinueAfterServerStopListening() async throws {
      try await self.withInProcessClientConnectedToServer(services: [BinaryEcho()]) { client, server in
        try await client.withStream(
          descriptor: BinaryEcho.Methods.update,
          options: .defaults
        ) { stream in
          try await stream.outbound.write(.metadata([:]))
          var iterator = stream.inbound.makeAsyncIterator()
          // Don't need to validate the response, just that the server is running.
          let metadata = try await iterator.next()
          XCTAssertMetadata(metadata)

          // New streams should fail immediately after this.
          server.beginGracefulShutdown()

          // ...but this already-open stream keeps working.
          try await stream.outbound.write(.message([0]))
          await stream.outbound.finish()

          let message = try await iterator.next()
          XCTAssertMessage(message) { XCTAssertEqual($0, [0]) }
          let status = try await iterator.next()
          XCTAssertStatus(status)
        }
      }
    }

    /// Cancelling the task running `serve()` stops the server and lets `serve()` return.
    func testCancelRunningServer() async throws {
      let inProcess = InProcessTransport()
      let task = Task {
        let server = GRPCServer(transport: inProcess.server, services: [BinaryEcho()])
        try await server.serve()
      }
      // The server runs in an unstructured task; make sure it cannot outlive the test
      // if an RPC below throws before the explicit cancellation is reached.
      defer { task.cancel() }

      try await withThrowingTaskGroup(of: Void.self) { group in
        group.addTask {
          try? await inProcess.client.connect()
        }

        try await self.doEchoGet(using: inProcess.client)
        // The server must be running at this point as an RPC has completed.
        task.cancel()
        try await task.value

        group.cancelAll()
      }
    }

    /// Calling `serve()` on a server that has already stopped throws `.serverIsStopped`.
    func testTestRunStoppedServer() async throws {
      let server = GRPCServer(transport: InProcessTransport.Server(), services: [])
      // Run the server.
      let task = Task { try await server.serve() }
      task.cancel()
      try await task.value

      // Server is stopped, should throw an error.
      await XCTAssertThrowsErrorAsync(ofType: RuntimeError.self) {
        try await server.serve()
      } errorHandler: { error in
        XCTAssertEqual(error.code, .serverIsStopped)
      }
    }

    /// A failure from the underlying transport surfaces from `serve()` as `.transportError`.
    func testRunServerWhenTransportThrows() async throws {
      let server = GRPCServer(transport: ThrowOnRunServerTransport(), services: [])
      await XCTAssertThrowsErrorAsync(ofType: RuntimeError.self) {
        try await server.serve()
      } errorHandler: { error in
        XCTAssertEqual(error.code, .transportError)
      }
    }

    /// Performs a `BinaryEcho.get` RPC over `transport` and asserts that exactly three
    /// response parts are received.
    ///
    /// - Parameter transport: A client transport already connected to a server hosting `BinaryEcho`.
    private func doEchoGet(using transport: some ClientTransport) async throws {
      try await transport.withStream(
        descriptor: BinaryEcho.Methods.get,
        options: .defaults
      ) { stream in
        try await stream.outbound.write(.metadata([:]))
        try await stream.outbound.write(.message([0]))
        await stream.outbound.finish()
        // Don't need to validate the response, just that the server is running.
        let parts = try await stream.inbound.collect()
        XCTAssertEqual(parts.count, 3)
      }
    }
  }
  /// Swift Testing suite checking that server interceptors honour their target subjects.
  @Suite("GRPC Server Tests")
  struct ServerTests {
    @Test("Interceptors are applied only to specified services")
    func testInterceptorsAreAppliedToSpecifiedServices() async throws {
      let onlyBinaryEchoCounter = AtomicCounter()
      let allServicesCounter = AtomicCounter()
      let onlyHelloWorldCounter = AtomicCounter()
      let bothServicesCounter = AtomicCounter()

      try await self.withInProcessClientConnectedToServer(
        services: [BinaryEcho(), HelloWorld()],
        interceptorPipeline: [
          .apply(
            .requestCounter(onlyBinaryEchoCounter),
            to: .services([BinaryEcho.serviceDescriptor])
          ),
          .apply(.requestCounter(allServicesCounter), to: .all),
          .apply(
            .requestCounter(onlyHelloWorldCounter),
            to: .services([HelloWorld.serviceDescriptor])
          ),
          .apply(
            .requestCounter(bothServicesCounter),
            to: .services([BinaryEcho.serviceDescriptor, HelloWorld.serviceDescriptor])
          ),
        ]
      ) { client, _ in
        // Make a request to the `BinaryEcho` service and assert that only
        // the counters associated to interceptors that apply to it are incremented.
        try await client.withStream(
          descriptor: BinaryEcho.Methods.get,
          options: .defaults
        ) { stream in
          try await stream.outbound.write(.metadata([:]))
          try await stream.outbound.write(.message(Array("hello".utf8)))
          await stream.outbound.finish()

          var responseParts = stream.inbound.makeAsyncIterator()
          let metadata = try await responseParts.next()
          self.assertMetadata(metadata)

          let message = try await responseParts.next()
          self.assertMessage(message) {
            #expect($0 == Array("hello".utf8))
          }

          let status = try await responseParts.next()
          self.assertStatus(status) { status, _ in
            #expect(status.code == .ok, Comment(rawValue: status.description))
          }
        }

        #expect(onlyBinaryEchoCounter.value == 1)
        #expect(allServicesCounter.value == 1)
        #expect(onlyHelloWorldCounter.value == 0)
        #expect(bothServicesCounter.value == 1)

        // Now, make a request to the `HelloWorld` service and assert that only
        // the counters associated to interceptors that apply to it are incremented.
        try await client.withStream(
          descriptor: HelloWorld.Methods.sayHello,
          options: .defaults
        ) { stream in
          try await stream.outbound.write(.metadata([:]))
          try await stream.outbound.write(.message(Array("Swift".utf8)))
          await stream.outbound.finish()

          var responseParts = stream.inbound.makeAsyncIterator()
          let metadata = try await responseParts.next()
          self.assertMetadata(metadata)

          let message = try await responseParts.next()
          self.assertMessage(message) {
            #expect($0 == Array("Hello, Swift!".utf8))
          }

          let status = try await responseParts.next()
          self.assertStatus(status) { status, _ in
            #expect(status.code == .ok, Comment(rawValue: status.description))
          }
        }

        #expect(onlyBinaryEchoCounter.value == 1)
        #expect(allServicesCounter.value == 2)
        #expect(onlyHelloWorldCounter.value == 1)
        #expect(bothServicesCounter.value == 2)
      }
    }

    @Test("Interceptors are applied only to specified methods")
    func testInterceptorsAreAppliedToSpecifiedMethods() async throws {
      let onlyBinaryEchoGetCounter = AtomicCounter()
      let onlyBinaryEchoCollectCounter = AtomicCounter()
      let bothBinaryEchoMethodsCounter = AtomicCounter()
      let allMethodsCounter = AtomicCounter()

      try await self.withInProcessClientConnectedToServer(
        services: [BinaryEcho()],
        interceptorPipeline: [
          .apply(
            .requestCounter(onlyBinaryEchoGetCounter),
            to: .methods([BinaryEcho.Methods.get])
          ),
          .apply(.requestCounter(allMethodsCounter), to: .all),
          .apply(
            .requestCounter(onlyBinaryEchoCollectCounter),
            to: .methods([BinaryEcho.Methods.collect])
          ),
          .apply(
            .requestCounter(bothBinaryEchoMethodsCounter),
            to: .methods([BinaryEcho.Methods.get, BinaryEcho.Methods.collect])
          ),
        ]
      ) { client, _ in
        // Make a request to the `BinaryEcho/get` method and assert that only
        // the counters associated to interceptors that apply to it are incremented.
        try await client.withStream(
          descriptor: BinaryEcho.Methods.get,
          options: .defaults
        ) { stream in
          try await stream.outbound.write(.metadata([:]))
          try await stream.outbound.write(.message(Array("hello".utf8)))
          await stream.outbound.finish()

          var responseParts = stream.inbound.makeAsyncIterator()
          let metadata = try await responseParts.next()
          self.assertMetadata(metadata)

          let message = try await responseParts.next()
          self.assertMessage(message) {
            #expect($0 == Array("hello".utf8))
          }

          let status = try await responseParts.next()
          self.assertStatus(status) { status, _ in
            #expect(status.code == .ok, Comment(rawValue: status.description))
          }
        }

        #expect(onlyBinaryEchoGetCounter.value == 1)
        #expect(allMethodsCounter.value == 1)
        #expect(onlyBinaryEchoCollectCounter.value == 0)
        #expect(bothBinaryEchoMethodsCounter.value == 1)

        // Now, make a request to the `BinaryEcho/collect` method and assert that only
        // the counters associated to interceptors that apply to it are incremented.
        try await client.withStream(
          descriptor: BinaryEcho.Methods.collect,
          options: .defaults
        ) { stream in
          try await stream.outbound.write(.metadata([:]))
          try await stream.outbound.write(.message(Array("hello".utf8)))
          await stream.outbound.finish()

          var responseParts = stream.inbound.makeAsyncIterator()
          let metadata = try await responseParts.next()
          self.assertMetadata(metadata)

          let message = try await responseParts.next()
          self.assertMessage(message) {
            #expect($0 == Array("hello".utf8))
          }

          let status = try await responseParts.next()
          self.assertStatus(status) { status, _ in
            #expect(status.code == .ok, Comment(rawValue: status.description))
          }
        }

        #expect(onlyBinaryEchoGetCounter.value == 1)
        #expect(allMethodsCounter.value == 2)
        #expect(onlyBinaryEchoCollectCounter.value == 1)
        #expect(bothBinaryEchoMethodsCounter.value == 2)
      }
    }

    /// Runs `body` with an in-process client connected to a `GRPCServer` hosting `services`.
    ///
    /// The server's `serve()` and the client's `connect()` each run in a child task. After
    /// `body` returns, both are asked to shut down gracefully so the task group can complete.
    /// If `body` throws, the group cancels both child tasks.
    ///
    /// - Parameters:
    ///   - services: The services registered with the server.
    ///   - interceptorPipeline: Interceptor operations the server applies to incoming RPCs.
    ///   - body: The test body; receives the client transport and the server.
    func withInProcessClientConnectedToServer(
      services: [any RegistrableRPCService],
      interceptorPipeline: [ServerInterceptorPipelineOperation] = [],
      _ body: (InProcessTransport.Client, GRPCServer) async throws -> Void
    ) async throws {
      let inProcess = InProcessTransport()
      let server = GRPCServer(
        transport: inProcess.server,
        services: services,
        interceptorPipeline: interceptorPipeline
      )

      try await withThrowingTaskGroup(of: Void.self) { group in
        group.addTask {
          try await server.serve()
        }
        group.addTask {
          try await inProcess.client.connect()
        }

        try await body(inProcess.client, server)

        inProcess.client.beginGracefulShutdown()
        server.beginGracefulShutdown()
      }
    }

    /// Records an issue unless `part` is a `.metadata` part; otherwise passes the metadata
    /// to `metadataHandler`.
    ///
    /// - Parameters:
    ///   - part: The response part to check (`nil` if the stream ended).
    ///   - sourceLocation: Where a failure is reported; defaults to the caller's location.
    ///   - metadataHandler: Further checks to run on the metadata.
    func assertMetadata(
      _ part: RPCResponsePart?,
      sourceLocation: SourceLocation = #_sourceLocation,
      metadataHandler: (Metadata) -> Void = { _ in }
    ) {
      switch part {
      case .some(.metadata(let metadata)):
        metadataHandler(metadata)
      default:
        Issue.record(
          "Expected '.metadata' but found '\(String(describing: part))'",
          sourceLocation: sourceLocation
        )
      }
    }

    /// Records an issue unless `part` is a `.message` part; otherwise passes the message
    /// bytes to `messageHandler`.
    ///
    /// - Parameters:
    ///   - part: The response part to check (`nil` if the stream ended).
    ///   - sourceLocation: Where a failure is reported; defaults to the caller's location.
    ///   - messageHandler: Further checks to run on the message bytes.
    func assertMessage(
      _ part: RPCResponsePart?,
      sourceLocation: SourceLocation = #_sourceLocation,
      messageHandler: ([UInt8]) -> Void = { _ in }
    ) {
      switch part {
      case .some(.message(let message)):
        messageHandler(message)
      default:
        Issue.record(
          "Expected '.message' but found '\(String(describing: part))'",
          sourceLocation: sourceLocation
        )
      }
    }

    /// Records an issue unless `part` is a `.status` part; otherwise passes the status and
    /// trailing metadata to `statusHandler`.
    ///
    /// - Parameters:
    ///   - part: The response part to check (`nil` if the stream ended).
    ///   - sourceLocation: Where a failure is reported; defaults to the caller's location.
    ///   - statusHandler: Further checks to run on the status and its metadata.
    func assertStatus(
      _ part: RPCResponsePart?,
      sourceLocation: SourceLocation = #_sourceLocation,
      statusHandler: (Status, Metadata) -> Void = { _, _ in }
    ) {
      switch part {
      case .some(.status(let status, let metadata)):
        statusHandler(status, metadata)
      default:
        Issue.record(
          "Expected '.status' but found '\(String(describing: part))'",
          sourceLocation: sourceLocation
        )
      }
    }
  }