|
|
@@ -25,9 +25,25 @@ extension Generator {
|
|
|
self.println()
|
|
|
self.printClientProtocolExtension()
|
|
|
self.println()
|
|
|
+ self.printClassBackedServiceClientImplementation()
|
|
|
+ self.println()
|
|
|
+ self.printStructBackedServiceClientImplementation()
|
|
|
+ self.println()
|
|
|
+ self.printIfCompilerGuardForAsyncAwait()
|
|
|
+ self.printAsyncServiceClientProtocol()
|
|
|
+ self.println()
|
|
|
+ self.printAsyncClientProtocolExtension()
|
|
|
+ self.println()
|
|
|
+ self.printAsyncClientProtocolSafeWrappersExtension()
|
|
|
+ self.println()
|
|
|
+ self.printAsyncServiceClientImplementation()
|
|
|
+ self.println()
|
|
|
+ self.printEndCompilerGuardForAsyncAwait()
|
|
|
+ self.println()
|
|
|
+ // Both implementations share definitions for interceptors and metadata.
|
|
|
self.printServiceClientInterceptorFactoryProtocol()
|
|
|
self.println()
|
|
|
- self.printServiceClientImplementation()
|
|
|
+ self.printClientMetadata()
|
|
|
}
|
|
|
|
|
|
if self.options.generateTestClient {
|
|
|
@@ -41,19 +57,39 @@ extension Generator {
|
|
|
arguments: [String],
|
|
|
returnType: String?,
|
|
|
access: String? = nil,
|
|
|
+ sendable: Bool = false,
|
|
|
+ async: Bool = false,
|
|
|
+ throws: Bool = false,
|
|
|
+ genericWhereClause: String? = nil,
|
|
|
bodyBuilder: (() -> Void)?
|
|
|
) {
|
|
|
// Add a space after access, if it exists.
|
|
|
- let accessOrEmpty = access.map { $0 + " " } ?? ""
|
|
|
- let `return` = returnType.map { "-> " + $0 } ?? ""
|
|
|
+ let functionHead = (access.map { $0 + " " } ?? "") + (sendable ? "@Sendable " : "")
|
|
|
+ let `return` = returnType.map { " -> " + $0 } ?? ""
|
|
|
+ let genericWhere = genericWhereClause.map { " " + $0 } ?? ""
|
|
|
+
|
|
|
+ let asyncThrows: String
|
|
|
+ switch (async, `throws`) {
|
|
|
+ case (true, true):
|
|
|
+ asyncThrows = " async throws"
|
|
|
+ case (true, false):
|
|
|
+ asyncThrows = " async"
|
|
|
+ case (false, true):
|
|
|
+ asyncThrows = " throws"
|
|
|
+ case (false, false):
|
|
|
+ asyncThrows = ""
|
|
|
+ }
|
|
|
|
|
|
let hasBody = bodyBuilder != nil
|
|
|
|
|
|
if arguments.isEmpty {
|
|
|
// Don't bother splitting across multiple lines if there are no arguments.
|
|
|
- self.println("\(accessOrEmpty)func \(name)() \(`return`)", newline: !hasBody)
|
|
|
+ self.println(
|
|
|
+ "\(functionHead)func \(name)()\(asyncThrows)\(`return`)\(genericWhere)",
|
|
|
+ newline: !hasBody
|
|
|
+ )
|
|
|
} else {
|
|
|
- self.println("\(accessOrEmpty)func \(name)(")
|
|
|
+ self.println("\(functionHead)func \(name)(")
|
|
|
self.withIndentation {
|
|
|
// Add a comma after each argument except the last.
|
|
|
arguments.forEach(beforeLast: {
|
|
|
@@ -62,7 +98,7 @@ extension Generator {
|
|
|
self.println($0)
|
|
|
})
|
|
|
}
|
|
|
- self.println(") \(`return`)", newline: !hasBody)
|
|
|
+ self.println(")\(asyncThrows)\(`return`)\(genericWhere)", newline: !hasBody)
|
|
|
}
|
|
|
|
|
|
if let bodyBuilder = bodyBuilder {
|
|
|
@@ -123,7 +159,7 @@ extension Generator {
|
|
|
}
|
|
|
|
|
|
private func printServiceClientInterceptorFactoryProtocol() {
|
|
|
- self.println("\(self.access) protocol \(self.clientInterceptorProtocolName) {")
|
|
|
+ self.println("\(self.access) protocol \(self.clientInterceptorProtocolName): GRPCSendable {")
|
|
|
self.withIndentation {
|
|
|
// Method specific interceptors.
|
|
|
for method in service.methods {
|
|
|
@@ -152,10 +188,62 @@ extension Generator {
|
|
|
)
|
|
|
}
|
|
|
|
|
|
- private func printServiceClientImplementation() {
|
|
|
+ private func printClassBackedServiceClientImplementation() {
|
|
|
+ self.printIfCompilerGuardForAsyncAwait()
|
|
|
+ self.println("@available(*, deprecated)")
|
|
|
+ self.println("extension \(clientClassName): @unchecked Sendable {}")
|
|
|
+ self.printEndCompilerGuardForAsyncAwait()
|
|
|
+ self.println()
|
|
|
+ self.println("@available(*, deprecated, renamed: \"\(clientStructName)\")")
|
|
|
println("\(access) final class \(clientClassName): \(clientProtocolName) {")
|
|
|
self.withIndentation {
|
|
|
+ println("private let lock = Lock()")
|
|
|
+ println("private var _defaultCallOptions: CallOptions")
|
|
|
+ println("private var _interceptors: \(clientInterceptorProtocolName)?")
|
|
|
+
|
|
|
println("\(access) let channel: GRPCChannel")
|
|
|
+ println("\(access) var defaultCallOptions: CallOptions {")
|
|
|
+ self.withIndentation {
|
|
|
+ println("get { self.lock.withLock { return self._defaultCallOptions } }")
|
|
|
+ println("set { self.lock.withLockVoid { self._defaultCallOptions = newValue } }")
|
|
|
+ }
|
|
|
+ self.println("}")
|
|
|
+ println("\(access) var interceptors: \(clientInterceptorProtocolName)? {")
|
|
|
+ self.withIndentation {
|
|
|
+ println("get { self.lock.withLock { return self._interceptors } }")
|
|
|
+ println("set { self.lock.withLockVoid { self._interceptors = newValue } }")
|
|
|
+ }
|
|
|
+ println("}")
|
|
|
+ println()
|
|
|
+ println("/// Creates a client for the \(servicePath) service.")
|
|
|
+ println("///")
|
|
|
+ self.printParameters()
|
|
|
+ println("/// - channel: `GRPCChannel` to the service host.")
|
|
|
+ println(
|
|
|
+ "/// - defaultCallOptions: Options to use for each service call if the user doesn't provide them."
|
|
|
+ )
|
|
|
+ println("/// - interceptors: A factory providing interceptors for each RPC.")
|
|
|
+ println("\(access) init(")
|
|
|
+ self.withIndentation {
|
|
|
+ println("channel: GRPCChannel,")
|
|
|
+ println("defaultCallOptions: CallOptions = CallOptions(),")
|
|
|
+ println("interceptors: \(clientInterceptorProtocolName)? = nil")
|
|
|
+ }
|
|
|
+ self.println(") {")
|
|
|
+ self.withIndentation {
|
|
|
+ println("self.channel = channel")
|
|
|
+ println("self._defaultCallOptions = defaultCallOptions")
|
|
|
+ println("self._interceptors = interceptors")
|
|
|
+ }
|
|
|
+ self.println("}")
|
|
|
+ }
|
|
|
+ println("}")
|
|
|
+ }
|
|
|
+
|
|
|
+ private func printStructBackedServiceClientImplementation() {
|
|
|
+ println("\(access) struct \(clientStructName): \(clientProtocolName) {")
|
|
|
+ self.withIndentation {
|
|
|
+ println("\(access) var channel: GRPCChannel")
|
|
|
println("\(access) var defaultCallOptions: CallOptions")
|
|
|
println("\(access) var interceptors: \(clientInterceptorProtocolName)?")
|
|
|
println()
|
|
|
@@ -220,7 +308,7 @@ extension Generator {
|
|
|
) {
|
|
|
self.println("return self.makeUnaryCall(")
|
|
|
self.withIndentation {
|
|
|
- self.println("path: \(self.methodPath),")
|
|
|
+ self.println("path: \(self.methodPathUsingClientMetadata),")
|
|
|
self.println("request: request,")
|
|
|
self.println("callOptions: callOptions ?? self.defaultCallOptions,")
|
|
|
self.println(
|
|
|
@@ -247,7 +335,7 @@ extension Generator {
|
|
|
) {
|
|
|
self.println("return self.makeServerStreamingCall(")
|
|
|
self.withIndentation {
|
|
|
- self.println("path: \(self.methodPath),")
|
|
|
+ self.println("path: \(self.methodPathUsingClientMetadata),")
|
|
|
self.println("request: request,")
|
|
|
self.println("callOptions: callOptions ?? self.defaultCallOptions,")
|
|
|
self.println(
|
|
|
@@ -278,7 +366,7 @@ extension Generator {
|
|
|
) {
|
|
|
self.println("return self.makeClientStreamingCall(")
|
|
|
self.withIndentation {
|
|
|
- self.println("path: \(self.methodPath),")
|
|
|
+ self.println("path: \(self.methodPathUsingClientMetadata),")
|
|
|
self.println("callOptions: callOptions ?? self.defaultCallOptions,")
|
|
|
self.println(
|
|
|
"interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? []"
|
|
|
@@ -305,7 +393,7 @@ extension Generator {
|
|
|
) {
|
|
|
self.println("return self.makeBidirectionalStreamingCall(")
|
|
|
self.withIndentation {
|
|
|
- self.println("path: \(self.methodPath),")
|
|
|
+ self.println("path: \(self.methodPathUsingClientMetadata),")
|
|
|
self.println("callOptions: callOptions ?? self.defaultCallOptions,")
|
|
|
self.println(
|
|
|
"interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? [],"
|
|
|
@@ -428,7 +516,7 @@ extension Generator {
|
|
|
) {
|
|
|
self
|
|
|
.println(
|
|
|
- "return self.fakeChannel.\(factory)(path: \(self.methodPath), requestHandler: requestHandler)"
|
|
|
+ "return self.fakeChannel.\(factory)(path: \(self.methodPathUsingClientMetadata), requestHandler: requestHandler)"
|
|
|
)
|
|
|
}
|
|
|
}
|
|
|
@@ -438,21 +526,31 @@ extension Generator {
|
|
|
.println("/// Returns true if there are response streams enqueued for '\(self.method.name)'")
|
|
|
self.println("\(self.access) var has\(self.method.name)ResponsesRemaining: Bool {")
|
|
|
self.withIndentation {
|
|
|
- self.println("return self.fakeChannel.hasFakeResponseEnqueued(forPath: \(self.methodPath))")
|
|
|
+ self.println(
|
|
|
+ "return self.fakeChannel.hasFakeResponseEnqueued(forPath: \(self.methodPathUsingClientMetadata))"
|
|
|
+ )
|
|
|
}
|
|
|
self.println("}")
|
|
|
}
|
|
|
|
|
|
private func printTestClient() {
|
|
|
- self
|
|
|
- .println(
|
|
|
- "\(self.access) final class \(self.testClientClassName): \(self.clientProtocolName) {"
|
|
|
- )
|
|
|
+ self.printIfCompilerGuardForAsyncAwait()
|
|
|
+ self.println("@available(swift, deprecated: 5.6)")
|
|
|
+ self.println("extension \(self.testClientClassName): @unchecked Sendable {}")
|
|
|
+ self.printEndCompilerGuardForAsyncAwait()
|
|
|
+ self.println()
|
|
|
+ self.println(
|
|
|
+ "@available(swift, deprecated: 5.6, message: \"Test clients are not Sendable "
|
|
|
+ + "but the 'GRPCClient' API requires clients to be Sendable. Using a localhost client and "
|
|
|
+ + "server is the recommended alternative.\")"
|
|
|
+ )
|
|
|
+ self.println(
|
|
|
+ "\(self.access) final class \(self.testClientClassName): \(self.clientProtocolName) {"
|
|
|
+ )
|
|
|
self.withIndentation {
|
|
|
self.println("private let fakeChannel: FakeChannel")
|
|
|
- self.println("\(self.access) var defaultCallOptions: CallOptions")
|
|
|
- self.println("\(self.access) var interceptors: \(self.clientInterceptorProtocolName)?")
|
|
|
-
|
|
|
+ self.println("\(access) var defaultCallOptions: CallOptions")
|
|
|
+ self.println("\(access) var interceptors: \(clientInterceptorProtocolName)?")
|
|
|
self.println()
|
|
|
self.println("\(self.access) var channel: GRPCChannel {")
|
|
|
self.withIndentation {
|