|
|
@@ -16,19 +16,20 @@
|
|
|
|
|
|
import GRPCCore
|
|
|
import GRPCInProcessTransport
|
|
|
+import Testing
|
|
|
import XCTest
|
|
|
|
|
|
final class GRPCServerTests: XCTestCase {
|
|
|
func withInProcessClientConnectedToServer(
|
|
|
services: [any RegistrableRPCService],
|
|
|
- interceptors: [any ServerInterceptor] = [],
|
|
|
+ interceptorPipeline: [ServerInterceptorPipelineOperation] = [],
|
|
|
_ body: (InProcessTransport.Client, GRPCServer) async throws -> Void
|
|
|
) async throws {
|
|
|
let inProcess = InProcessTransport()
|
|
|
let server = GRPCServer(
|
|
|
transport: inProcess.server,
|
|
|
services: services,
|
|
|
- interceptors: interceptors
|
|
|
+ interceptorPipeline: interceptorPipeline
|
|
|
)
|
|
|
|
|
|
try await withThrowingTaskGroup(of: Void.self) { group in
|
|
|
@@ -219,10 +220,10 @@ final class GRPCServerTests: XCTestCase {
|
|
|
|
|
|
try await self.withInProcessClientConnectedToServer(
|
|
|
services: [BinaryEcho()],
|
|
|
- interceptors: [
|
|
|
- .requestCounter(counter1),
|
|
|
- .rejectAll(with: RPCError(code: .unavailable, message: "")),
|
|
|
- .requestCounter(counter2),
|
|
|
+ interceptorPipeline: [
|
|
|
+ .apply(.requestCounter(counter1), to: .all),
|
|
|
+ .apply(.rejectAll(with: RPCError(code: .unavailable, message: "")), to: .all),
|
|
|
+ .apply(.requestCounter(counter2), to: .all),
|
|
|
]
|
|
|
) { client, _ in
|
|
|
try await client.withStream(
|
|
|
@@ -248,7 +249,7 @@ final class GRPCServerTests: XCTestCase {
|
|
|
|
|
|
try await self.withInProcessClientConnectedToServer(
|
|
|
services: [BinaryEcho()],
|
|
|
- interceptors: [.requestCounter(counter)]
|
|
|
+ interceptorPipeline: [.apply(.requestCounter(counter), to: .all)]
|
|
|
) { client, _ in
|
|
|
try await client.withStream(
|
|
|
descriptor: MethodDescriptor(service: "not", method: "implemented"),
|
|
|
@@ -374,3 +375,243 @@ final class GRPCServerTests: XCTestCase {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+@Suite("GRPC Server Tests")
|
|
|
+struct ServerTests {
|
|
|
+ @Test("Interceptors are applied only to specified services")
|
|
|
+ func testInterceptorsAreAppliedToSpecifiedServices() async throws {
|
|
|
+ let onlyBinaryEchoCounter = AtomicCounter()
|
|
|
+ let allServicesCounter = AtomicCounter()
|
|
|
+ let onlyHelloWorldCounter = AtomicCounter()
|
|
|
+ let bothServicesCounter = AtomicCounter()
|
|
|
+
|
|
|
+ try await self.withInProcessClientConnectedToServer(
|
|
|
+ services: [BinaryEcho(), HelloWorld()],
|
|
|
+ interceptorPipeline: [
|
|
|
+ .apply(
|
|
|
+ .requestCounter(onlyBinaryEchoCounter),
|
|
|
+ to: .services([BinaryEcho.serviceDescriptor])
|
|
|
+ ),
|
|
|
+ .apply(.requestCounter(allServicesCounter), to: .all),
|
|
|
+ .apply(
|
|
|
+ .requestCounter(onlyHelloWorldCounter),
|
|
|
+ to: .services([HelloWorld.serviceDescriptor])
|
|
|
+ ),
|
|
|
+ .apply(
|
|
|
+ .requestCounter(bothServicesCounter),
|
|
|
+ to: .services([BinaryEcho.serviceDescriptor, HelloWorld.serviceDescriptor])
|
|
|
+ ),
|
|
|
+ ]
|
|
|
+ ) { client, _ in
|
|
|
+ // Make a request to the `BinaryEcho` service and assert that only
|
|
|
+ // the counters associated to interceptors that apply to it are incremented.
|
|
|
+ try await client.withStream(
|
|
|
+ descriptor: BinaryEcho.Methods.get,
|
|
|
+ options: .defaults
|
|
|
+ ) { stream in
|
|
|
+ try await stream.outbound.write(.metadata([:]))
|
|
|
+ try await stream.outbound.write(.message(Array("hello".utf8)))
|
|
|
+ await stream.outbound.finish()
|
|
|
+
|
|
|
+ var responseParts = stream.inbound.makeAsyncIterator()
|
|
|
+ let metadata = try await responseParts.next()
|
|
|
+ self.assertMetadata(metadata)
|
|
|
+
|
|
|
+ let message = try await responseParts.next()
|
|
|
+ self.assertMessage(message) {
|
|
|
+ #expect($0 == Array("hello".utf8))
|
|
|
+ }
|
|
|
+
|
|
|
+ let status = try await responseParts.next()
|
|
|
+ self.assertStatus(status) { status, _ in
|
|
|
+ #expect(status.code == .ok, Comment(rawValue: status.description))
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ #expect(onlyBinaryEchoCounter.value == 1)
|
|
|
+ #expect(allServicesCounter.value == 1)
|
|
|
+ #expect(onlyHelloWorldCounter.value == 0)
|
|
|
+ #expect(bothServicesCounter.value == 1)
|
|
|
+
|
|
|
+ // Now, make a request to the `HelloWorld` service and assert that only
|
|
|
+ // the counters associated to interceptors that apply to it are incremented.
|
|
|
+ try await client.withStream(
|
|
|
+ descriptor: HelloWorld.Methods.sayHello,
|
|
|
+ options: .defaults
|
|
|
+ ) { stream in
|
|
|
+ try await stream.outbound.write(.metadata([:]))
|
|
|
+ try await stream.outbound.write(.message(Array("Swift".utf8)))
|
|
|
+ await stream.outbound.finish()
|
|
|
+
|
|
|
+ var responseParts = stream.inbound.makeAsyncIterator()
|
|
|
+ let metadata = try await responseParts.next()
|
|
|
+ self.assertMetadata(metadata)
|
|
|
+
|
|
|
+ let message = try await responseParts.next()
|
|
|
+ self.assertMessage(message) {
|
|
|
+ #expect($0 == Array("Hello, Swift!".utf8))
|
|
|
+ }
|
|
|
+
|
|
|
+ let status = try await responseParts.next()
|
|
|
+ self.assertStatus(status) { status, _ in
|
|
|
+ #expect(status.code == .ok, Comment(rawValue: status.description))
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ #expect(onlyBinaryEchoCounter.value == 1)
|
|
|
+ #expect(allServicesCounter.value == 2)
|
|
|
+ #expect(onlyHelloWorldCounter.value == 1)
|
|
|
+ #expect(bothServicesCounter.value == 2)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test("Interceptors are applied only to specified methods")
|
|
|
+ func testInterceptorsAreAppliedToSpecifiedMethods() async throws {
|
|
|
+ let onlyBinaryEchoGetCounter = AtomicCounter()
|
|
|
+ let onlyBinaryEchoCollectCounter = AtomicCounter()
|
|
|
+ let bothBinaryEchoMethodsCounter = AtomicCounter()
|
|
|
+ let allMethodsCounter = AtomicCounter()
|
|
|
+
|
|
|
+ try await self.withInProcessClientConnectedToServer(
|
|
|
+ services: [BinaryEcho()],
|
|
|
+ interceptorPipeline: [
|
|
|
+ .apply(
|
|
|
+ .requestCounter(onlyBinaryEchoGetCounter),
|
|
|
+ to: .methods([BinaryEcho.Methods.get])
|
|
|
+ ),
|
|
|
+ .apply(.requestCounter(allMethodsCounter), to: .all),
|
|
|
+ .apply(
|
|
|
+ .requestCounter(onlyBinaryEchoCollectCounter),
|
|
|
+ to: .methods([BinaryEcho.Methods.collect])
|
|
|
+ ),
|
|
|
+ .apply(
|
|
|
+ .requestCounter(bothBinaryEchoMethodsCounter),
|
|
|
+ to: .methods([BinaryEcho.Methods.get, BinaryEcho.Methods.collect])
|
|
|
+ ),
|
|
|
+ ]
|
|
|
+ ) { client, _ in
|
|
|
+ // Make a request to the `BinaryEcho/get` method and assert that only
|
|
|
+ // the counters associated to interceptors that apply to it are incremented.
|
|
|
+ try await client.withStream(
|
|
|
+ descriptor: BinaryEcho.Methods.get,
|
|
|
+ options: .defaults
|
|
|
+ ) { stream in
|
|
|
+ try await stream.outbound.write(.metadata([:]))
|
|
|
+ try await stream.outbound.write(.message(Array("hello".utf8)))
|
|
|
+ await stream.outbound.finish()
|
|
|
+
|
|
|
+ var responseParts = stream.inbound.makeAsyncIterator()
|
|
|
+ let metadata = try await responseParts.next()
|
|
|
+ self.assertMetadata(metadata)
|
|
|
+
|
|
|
+ let message = try await responseParts.next()
|
|
|
+ self.assertMessage(message) {
|
|
|
+ #expect($0 == Array("hello".utf8))
|
|
|
+ }
|
|
|
+
|
|
|
+ let status = try await responseParts.next()
|
|
|
+ self.assertStatus(status) { status, _ in
|
|
|
+ #expect(status.code == .ok, Comment(rawValue: status.description))
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ #expect(onlyBinaryEchoGetCounter.value == 1)
|
|
|
+ #expect(allMethodsCounter.value == 1)
|
|
|
+ #expect(onlyBinaryEchoCollectCounter.value == 0)
|
|
|
+ #expect(bothBinaryEchoMethodsCounter.value == 1)
|
|
|
+
|
|
|
+ // Now, make a request to the `BinaryEcho/collect` method and assert that only
|
|
|
+ // the counters associated to interceptors that apply to it are incremented.
|
|
|
+ try await client.withStream(
|
|
|
+ descriptor: BinaryEcho.Methods.collect,
|
|
|
+ options: .defaults
|
|
|
+ ) { stream in
|
|
|
+ try await stream.outbound.write(.metadata([:]))
|
|
|
+ try await stream.outbound.write(.message(Array("hello".utf8)))
|
|
|
+ await stream.outbound.finish()
|
|
|
+
|
|
|
+ var responseParts = stream.inbound.makeAsyncIterator()
|
|
|
+ let metadata = try await responseParts.next()
|
|
|
+ self.assertMetadata(metadata)
|
|
|
+
|
|
|
+ let message = try await responseParts.next()
|
|
|
+ self.assertMessage(message) {
|
|
|
+ #expect($0 == Array("hello".utf8))
|
|
|
+ }
|
|
|
+
|
|
|
+ let status = try await responseParts.next()
|
|
|
+ self.assertStatus(status) { status, _ in
|
|
|
+ #expect(status.code == .ok, Comment(rawValue: status.description))
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ #expect(onlyBinaryEchoGetCounter.value == 1)
|
|
|
+ #expect(allMethodsCounter.value == 2)
|
|
|
+ #expect(onlyBinaryEchoCollectCounter.value == 1)
|
|
|
+ #expect(bothBinaryEchoMethodsCounter.value == 2)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ func withInProcessClientConnectedToServer(
|
|
|
+ services: [any RegistrableRPCService],
|
|
|
+ interceptorPipeline: [ServerInterceptorPipelineOperation] = [],
|
|
|
+ _ body: (InProcessTransport.Client, GRPCServer) async throws -> Void
|
|
|
+ ) async throws {
|
|
|
+ let inProcess = InProcessTransport()
|
|
|
+ let server = GRPCServer(
|
|
|
+ transport: inProcess.server,
|
|
|
+ services: services,
|
|
|
+ interceptorPipeline: interceptorPipeline
|
|
|
+ )
|
|
|
+
|
|
|
+ try await withThrowingTaskGroup(of: Void.self) { group in
|
|
|
+ group.addTask {
|
|
|
+ try await server.serve()
|
|
|
+ }
|
|
|
+
|
|
|
+ group.addTask {
|
|
|
+ try await inProcess.client.connect()
|
|
|
+ }
|
|
|
+
|
|
|
+ try await body(inProcess.client, server)
|
|
|
+ inProcess.client.beginGracefulShutdown()
|
|
|
+ server.beginGracefulShutdown()
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ func assertMetadata(
|
|
|
+ _ part: RPCResponsePart?,
|
|
|
+ metadataHandler: (Metadata) -> Void = { _ in }
|
|
|
+ ) {
|
|
|
+ switch part {
|
|
|
+ case .some(.metadata(let metadata)):
|
|
|
+ metadataHandler(metadata)
|
|
|
+ default:
|
|
|
+ Issue.record("Expected '.metadata' but found '\(String(describing: part))'")
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ func assertMessage(
|
|
|
+ _ part: RPCResponsePart?,
|
|
|
+ messageHandler: ([UInt8]) -> Void = { _ in }
|
|
|
+ ) {
|
|
|
+ switch part {
|
|
|
+ case .some(.message(let message)):
|
|
|
+ messageHandler(message)
|
|
|
+ default:
|
|
|
+ Issue.record("Expected '.message' but found '\(String(describing: part))'")
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ func assertStatus(
|
|
|
+ _ part: RPCResponsePart?,
|
|
|
+ statusHandler: (Status, Metadata) -> Void = { _, _ in }
|
|
|
+ ) {
|
|
|
+ switch part {
|
|
|
+ case .some(.status(let status, let metadata)):
|
|
|
+ statusHandler(status, metadata)
|
|
|
+ default:
|
|
|
+ Issue.record("Expected '.status' but found '\(String(describing: part))'")
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|