| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692 |
- /*
- * Copyright 2018, gRPC Authors All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- import Foundation
- import SwiftProtobuf
- import SwiftProtobufPluginLibrary
extension Generator {
  /// Prints the client API for `self.service`.
  ///
  /// When `options.generateClient` is set this prints, in order: the client protocol, the
  /// protocol extension holding default implementations, the deprecated class-backed client,
  /// the struct-backed client, the async/await client API (wrapped in the async/await compiler
  /// guard), the interceptor factory protocol and the client metadata. When
  /// `options.generateTestClient` is set the test client is printed as well.
  internal func printClient() {
    if self.options.generateClient {
      self.println()
      self.printServiceClientProtocol()
      self.println()
      self.printClientProtocolExtension()
      self.println()
      self.printClassBackedServiceClientImplementation()
      self.println()
      self.printStructBackedServiceClientImplementation()
      self.println()
      self.printIfCompilerGuardForAsyncAwait()
      self.printAsyncServiceClientProtocol()
      self.println()
      self.printAsyncClientProtocolExtension()
      self.println()
      self.printAsyncClientProtocolSafeWrappersExtension()
      self.println()
      self.printAsyncServiceClientImplementation()
      self.println()
      self.printEndCompilerGuardForAsyncAwait()
      self.println()
      // Both implementations share definitions for interceptors and metadata.
      self.printServiceClientInterceptorFactoryProtocol()
      self.println()
      self.printClientMetadata()
    }
    if self.options.generateTestClient {
      self.println()
      self.printTestClient()
    }
  }

  /// Prints a function declaration and, when `bodyBuilder` is non-nil, its body.
  ///
  /// With no arguments the whole signature goes on one line. Otherwise each argument is printed
  /// on its own indented line, comma-separated, with the closing parenthesis on a new line.
  ///
  /// - Parameters:
  ///   - name: The function name.
  ///   - arguments: Argument declarations (e.g. `"_ request: Foo"`), printed verbatim.
  ///   - returnType: The return type, or `nil` to omit the `->` clause.
  ///   - access: An access modifier printed before `func`, or `nil` to omit it.
  ///   - sendable: Whether to prefix the declaration with `@Sendable`.
  ///   - async: Whether to mark the function `async`.
  ///   - throws: Whether to mark the function `throws`.
  ///   - genericWhereClause: A clause appended after the return type, if any.
  ///   - bodyBuilder: Prints the body inside an indented `{ ... }`. Pass `nil` to print only the
  ///     declaration, e.g. for a protocol requirement.
  internal func printFunction(
    name: String,
    arguments: [String],
    returnType: String?,
    access: String? = nil,
    sendable: Bool = false,
    async: Bool = false,
    throws: Bool = false,
    genericWhereClause: String? = nil,
    bodyBuilder: (() -> Void)?
  ) {
    // Add a space after access, if it exists.
    let functionHead = (access.map { $0 + " " } ?? "") + (sendable ? "@Sendable " : "")
    let `return` = returnType.map { " -> " + $0 } ?? ""
    let genericWhere = genericWhereClause.map { " " + $0 } ?? ""

    let asyncThrows: String
    switch (async, `throws`) {
    case (true, true):
      asyncThrows = " async throws"
    case (true, false):
      asyncThrows = " async"
    case (false, true):
      asyncThrows = " throws"
    case (false, false):
      asyncThrows = ""
    }

    // Without a body the signature line must end the line itself; with a body, " {" follows.
    let hasBody = bodyBuilder != nil
    if arguments.isEmpty {
      // Don't bother splitting across multiple lines if there are no arguments.
      self.println(
        "\(functionHead)func \(name)()\(asyncThrows)\(`return`)\(genericWhere)",
        newline: !hasBody
      )
    } else {
      self.println("\(functionHead)func \(name)(")
      self.withIndentation {
        // Add a comma after each argument except the last.
        arguments.forEach(beforeLast: {
          self.println($0 + ",")
        }, onLast: {
          self.println($0)
        })
      }
      self.println(")\(asyncThrows)\(`return`)\(genericWhere)", newline: !hasBody)
    }

    if let bodyBuilder = bodyBuilder {
      self.println(" {")
      self.withIndentation {
        bodyBuilder()
      }
      self.println("}")
    }
  }

  /// Prints the client protocol (`clientProtocolName`): the service's source comments, a usage
  /// note, `serviceName` and `interceptors` requirements, and one requirement per RPC.
  ///
  /// Requirements use `methodArgumentsWithoutDefaults` because protocol requirements cannot
  /// declare default argument values. `printClientProtocolExtension()` prints the
  /// implementations that carry the defaults.
  private func printServiceClientProtocol() {
    let comments = self.service.protoSourceComments()
    if !comments.isEmpty {
      // Source comments already have the leading '///'
      self.println(comments, newline: false)
      self.println("///")
    }
    // The class-backed client is deprecated in favour of the struct-backed one, so point users
    // at the struct.
    self.println(
      "/// Usage: instantiate `\(self.clientStructName)`, then call methods of this protocol to make API calls."
    )
    self.println("\(self.access) protocol \(self.clientProtocolName): GRPCClient {")
    self.withIndentation {
      self.println("var serviceName: String { get }")
      self.println("var interceptors: \(self.clientInterceptorProtocolName)? { get }")
      for method in self.service.methods {
        self.println()
        self.method = method
        self.printFunction(
          name: self.methodFunctionName,
          arguments: self.methodArgumentsWithoutDefaults,
          returnType: self.methodReturnType,
          bodyBuilder: nil
        )
      }
    }
    self.println("}")
  }

  /// Prints an extension on the client protocol. It implements `serviceName` (returning
  /// `servicePath`) and provides an implementation of every RPC via `printMethods()`.
  private func printClientProtocolExtension() {
    self.println("extension \(self.clientProtocolName) {")
    self.withIndentation {
      // Service name.
      self.println("\(self.access) var serviceName: String {")
      self.withIndentation {
        self.println("return \"\(self.servicePath)\"")
      }
      self.println("}")

      // Default method implementations.
      self.printMethods()
    }
    self.println("}")
  }

  /// Prints the interceptor factory protocol (`clientInterceptorProtocolName`, refining
  /// `GRPCSendable`), with one interceptor factory requirement per RPC.
  private func printServiceClientInterceptorFactoryProtocol() {
    self.println("\(self.access) protocol \(self.clientInterceptorProtocolName): GRPCSendable {")
    self.withIndentation {
      // Method specific interceptors.
      for method in self.service.methods {
        self.println()
        self.method = method
        self.println(
          "/// - Returns: Interceptors to use when invoking '\(self.methodFunctionName)'."
        )
        // Skip the access, we're defining a protocol.
        self.printMethodInterceptorFactory(access: nil)
      }
    }
    self.println("}")
  }

  /// Prints the interceptor factory function for the current method. It takes no arguments and
  /// returns `[ClientInterceptor<Input, Output>]`.
  ///
  /// - Parameters:
  ///   - access: Access modifier for the declaration, or `nil` (e.g. inside a protocol).
  ///   - bodyBuilder: Prints the body, or `nil` to print only the declaration.
  private func printMethodInterceptorFactory(
    access: String?,
    bodyBuilder: (() -> Void)? = nil
  ) {
    self.printFunction(
      name: self.methodInterceptorFactoryName,
      arguments: [],
      returnType: "[ClientInterceptor<\(self.methodInputName), \(self.methodOutputName)>]",
      access: access,
      bodyBuilder: bodyBuilder
    )
  }

  /// Prints the documentation and signature of the initializer shared by the class- and
  /// struct-backed clients, up to and including the line `") {"`. The caller prints the
  /// initializer body and its closing brace.
  private func printClientInitializerSignature() {
    self.println("/// Creates a client for the \(self.servicePath) service.")
    self.println("///")
    self.printParameters()
    self.println("/// - channel: `GRPCChannel` to the service host.")
    self.println(
      "/// - defaultCallOptions: Options to use for each service call if the user doesn't provide them."
    )
    self.println("/// - interceptors: A factory providing interceptors for each RPC.")
    self.println("\(self.access) init(")
    self.withIndentation {
      self.println("channel: GRPCChannel,")
      self.println("defaultCallOptions: CallOptions = CallOptions(),")
      self.println("interceptors: \(self.clientInterceptorProtocolName)? = nil")
    }
    self.println(") {")
  }

  /// Prints the deprecated, class-backed client (`clientClassName`).
  ///
  /// The generated class guards its mutable `defaultCallOptions` and `interceptors` with a
  /// `Lock`, and is marked `@unchecked Sendable` inside the async/await compiler guard.
  private func printClassBackedServiceClientImplementation() {
    self.printIfCompilerGuardForAsyncAwait()
    self.println("@available(*, deprecated)")
    self.println("extension \(self.clientClassName): @unchecked Sendable {}")
    self.printEndCompilerGuardForAsyncAwait()
    self.println()
    self.println("@available(*, deprecated, renamed: \"\(self.clientStructName)\")")
    self.println("\(self.access) final class \(self.clientClassName): \(self.clientProtocolName) {")
    self.withIndentation {
      self.println("private let lock = Lock()")
      self.println("private var _defaultCallOptions: CallOptions")
      self.println("private var _interceptors: \(self.clientInterceptorProtocolName)?")
      self.println("\(self.access) let channel: GRPCChannel")
      self.println("\(self.access) var defaultCallOptions: CallOptions {")
      self.withIndentation {
        self.println("get { self.lock.withLock { return self._defaultCallOptions } }")
        self.println("set { self.lock.withLockVoid { self._defaultCallOptions = newValue } }")
      }
      self.println("}")
      self.println("\(self.access) var interceptors: \(self.clientInterceptorProtocolName)? {")
      self.withIndentation {
        self.println("get { self.lock.withLock { return self._interceptors } }")
        self.println("set { self.lock.withLockVoid { self._interceptors = newValue } }")
      }
      self.println("}")
      self.println()
      self.printClientInitializerSignature()
      self.withIndentation {
        self.println("self.channel = channel")
        self.println("self._defaultCallOptions = defaultCallOptions")
        self.println("self._interceptors = interceptors")
      }
      self.println("}")
    }
    self.println("}")
  }

  /// Prints the struct-backed client (`clientStructName`). Its `channel`, `defaultCallOptions`
  /// and `interceptors` are plain stored properties.
  private func printStructBackedServiceClientImplementation() {
    self.println("\(self.access) struct \(self.clientStructName): \(self.clientProtocolName) {")
    self.withIndentation {
      self.println("\(self.access) var channel: GRPCChannel")
      self.println("\(self.access) var defaultCallOptions: CallOptions")
      self.println("\(self.access) var interceptors: \(self.clientInterceptorProtocolName)?")
      self.println()
      self.printClientInitializerSignature()
      self.withIndentation {
        self.println("self.channel = channel")
        self.println("self.defaultCallOptions = defaultCallOptions")
        self.println("self.interceptors = interceptors")
      }
      self.println("}")
    }
    self.println("}")
  }

  /// Prints an implementation for every RPC of the service, choosing the printer by the
  /// method's streaming type. Each method is preceded by a blank line.
  private func printMethods() {
    for method in self.service.methods {
      self.println()
      self.method = method
      switch self.streamType {
      case .unary:
        self.printUnaryCall()
      case .serverStreaming:
        self.printServerStreamingCall()
      case .clientStreaming:
        self.printClientStreamingCall()
      case .bidirectionalStreaming:
        self.printBidirectionalStreamingCall()
      }
    }
  }

  /// Prints a documented unary RPC implementation that forwards to `makeUnaryCall`.
  private func printUnaryCall() {
    self.println(self.method.documentation(streamingType: self.streamType), newline: false)
    self.println("///")
    self.printParameters()
    self.printRequestParameter()
    self.printCallOptionsParameter()
    self.println("/// - Returns: A `UnaryCall` with futures for the metadata, status and response.")
    self.printFunction(
      name: self.methodFunctionName,
      arguments: self.methodArguments,
      returnType: self.methodReturnType,
      access: self.access
    ) {
      self.println("return self.makeUnaryCall(")
      self.withIndentation {
        self.println("path: \(self.methodPathUsingClientMetadata),")
        self.println("request: request,")
        self.println("callOptions: callOptions ?? self.defaultCallOptions,")
        self.println(
          "interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? []"
        )
      }
      self.println(")")
    }
  }

  /// Prints a documented server-streaming RPC implementation that forwards to
  /// `makeServerStreamingCall`.
  private func printServerStreamingCall() {
    self.println(self.method.documentation(streamingType: self.streamType), newline: false)
    self.println("///")
    self.printParameters()
    self.printRequestParameter()
    self.printCallOptionsParameter()
    self.printHandlerParameter()
    self.println("/// - Returns: A `ServerStreamingCall` with futures for the metadata and status.")
    self.printFunction(
      name: self.methodFunctionName,
      arguments: self.methodArguments,
      returnType: self.methodReturnType,
      access: self.access
    ) {
      self.println("return self.makeServerStreamingCall(")
      self.withIndentation {
        self.println("path: \(self.methodPathUsingClientMetadata),")
        self.println("request: request,")
        self.println("callOptions: callOptions ?? self.defaultCallOptions,")
        self.println(
          "interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? [],"
        )
        self.println("handler: handler")
      }
      self.println(")")
    }
  }

  /// Prints a documented client-streaming RPC implementation that forwards to
  /// `makeClientStreamingCall`.
  private func printClientStreamingCall() {
    self.println(self.method.documentation(streamingType: self.streamType), newline: false)
    self.println("///")
    self.printClientStreamingDetails()
    self.println("///")
    self.printParameters()
    self.printCallOptionsParameter()
    self.println(
      "/// - Returns: A `ClientStreamingCall` with futures for the metadata, status and response."
    )
    self.printFunction(
      name: self.methodFunctionName,
      arguments: self.methodArguments,
      returnType: self.methodReturnType,
      access: self.access
    ) {
      self.println("return self.makeClientStreamingCall(")
      self.withIndentation {
        self.println("path: \(self.methodPathUsingClientMetadata),")
        self.println("callOptions: callOptions ?? self.defaultCallOptions,")
        self.println(
          "interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? []"
        )
      }
      self.println(")")
    }
  }

  /// Prints a documented bidirectional-streaming RPC implementation that forwards to
  /// `makeBidirectionalStreamingCall`.
  private func printBidirectionalStreamingCall() {
    self.println(self.method.documentation(streamingType: self.streamType), newline: false)
    self.println("///")
    self.printClientStreamingDetails()
    self.println("///")
    self.printParameters()
    self.printCallOptionsParameter()
    self.printHandlerParameter()
    // The generated method returns a `BidirectionalStreamingCall` (see `methodReturnType`), so
    // document that type rather than `ClientStreamingCall`.
    self.println(
      "/// - Returns: A `BidirectionalStreamingCall` with futures for the metadata and status."
    )
    self.printFunction(
      name: self.methodFunctionName,
      arguments: self.methodArguments,
      returnType: self.methodReturnType,
      access: self.access
    ) {
      self.println("return self.makeBidirectionalStreamingCall(")
      self.withIndentation {
        self.println("path: \(self.methodPathUsingClientMetadata),")
        self.println("callOptions: callOptions ?? self.defaultCallOptions,")
        self.println(
          "interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? [],"
        )
        self.println("handler: handler")
      }
      self.println(")")
    }
  }

  /// Prints the doc lines explaining how callers send messages on a client-streaming call.
  private func printClientStreamingDetails() {
    self.println("/// Callers should use the `send` method on the returned object to send messages")
    self.println(
      "/// to the server. The caller should send an `.end` after the final message has been sent."
    )
  }

  /// Prints the `- Parameters:` doc header.
  private func printParameters() {
    self.println("/// - Parameters:")
  }

  /// Prints the doc line for the `request` parameter of the current method.
  private func printRequestParameter() {
    self.println("/// - request: Request to send to \(self.method.name).")
  }

  /// Prints the doc line for the `callOptions` parameter.
  private func printCallOptionsParameter() {
    self.println("/// - callOptions: Call options.")
  }

  /// Prints the doc line for the response `handler` parameter.
  private func printHandlerParameter() {
    self.println("/// - handler: A closure called when each response is received from the server.")
  }
}
extension Generator {
  /// Prints the fake-response helpers of the test client for every method of the service,
  /// each preceded by a blank line.
  ///
  /// Unary and client-streaming methods produce a single response, so they get single-response
  /// helpers. Server- and bidirectional-streaming methods get multi-response helpers.
  private func printFakeResponseStreams() {
    for rpc in self.service.methods {
      self.println()
      self.method = rpc

      let respondsOnce: Bool
      switch self.streamType {
      case .serverStreaming, .bidirectionalStreaming:
        respondsOnce = false
      case .unary, .clientStreaming:
        respondsOnce = true
      }
      self.printFakeResponseHelpers(isUnary: respondsOnce)
    }
  }

  /// Prints, separated by blank lines, the response-stream factory, the enqueue helper and the
  /// "responses remaining" property for the current method.
  private func printFakeResponseHelpers(isUnary: Bool) {
    self.printResponseStream(isUnary: isUnary)
    self.println()
    self.printEnqueueFakeResponses(isUnary: isUnary)
    self.println()
    self.printHasResponseStreamEnqueued()
  }

  /// Prints `enqueue<Method>Response` (unary) or `enqueue<Method>Responses` (streaming).
  ///
  /// The generated helper makes a response stream via `make<Method>ResponseStream` and sends
  /// the given response(s) on it. In the streaming case it also sends `.end`.
  private func printEnqueueFakeResponses(isUnary: Bool) {
    let rpcName = self.method.name
    let outputType = self.methodOutputName
    let (functionName, valueName, valueType) = isUnary
      ? ("enqueue\(rpcName)Response", "response", outputType)
      : ("enqueue\(rpcName)Responses", "responses", "[\(outputType)]")
    let requestHandlerArgument =
      "_ requestHandler: @escaping (FakeRequestPart<\(self.methodInputName)>) -> () = { _ in }"

    self.printFunction(
      name: functionName,
      arguments: ["_ \(valueName): \(valueType)", requestHandlerArgument],
      returnType: nil,
      access: self.access
    ) {
      self.println("let stream = self.make\(rpcName)ResponseStream(requestHandler)")
      guard isUnary else {
        self.println("// These are the only operation on the stream; try! is fine.")
        self.println("\(valueName).forEach { try! stream.sendMessage($0) }")
        self.println("try! stream.sendEnd()")
        return
      }
      self.println("// This is the only operation on the stream; try! is fine.")
      self.println("try! stream.sendMessage(\(valueName))")
    }
  }

  /// Prints `make<Method>ResponseStream`, which asks the fake channel for a unary or streaming
  /// fake response on the method's path.
  private func printResponseStream(isUnary: Bool) {
    let responseKind = isUnary ? "unary" : "streaming"
    let responseType = isUnary ? "FakeUnaryResponse" : "FakeStreamingResponse"
    let channelFactory = isUnary ? "makeFakeUnaryResponse" : "makeFakeStreamingResponse"
    let rpcName = self.method.name

    self.println(
      "/// Make a \(responseKind) response for the \(rpcName) RPC. This must be called"
    )
    self.println("/// before calling '\(self.methodFunctionName)'. See also '\(responseType)'.")
    self.println("///")
    self.println("/// - Parameter requestHandler: a handler for request parts sent by the RPC.")
    self.printFunction(
      name: "make\(rpcName)ResponseStream",
      arguments: [
        "_ requestHandler: @escaping (FakeRequestPart<\(self.methodInputName)>) -> () = { _ in }",
      ],
      returnType: "\(responseType)<\(self.methodInputName), \(self.methodOutputName)>",
      access: self.access
    ) {
      let path = self.methodPathUsingClientMetadata
      self.println(
        "return self.fakeChannel.\(channelFactory)(path: \(path), requestHandler: requestHandler)"
      )
    }
  }

  /// Prints the `has<Method>ResponsesRemaining` property, which queries the fake channel for
  /// enqueued responses on the method's path.
  private func printHasResponseStreamEnqueued() {
    let rpcName = self.method.name
    self.println("/// Returns true if there are response streams enqueued for '\(rpcName)'")
    self.println("\(self.access) var has\(rpcName)ResponsesRemaining: Bool {")
    self.withIndentation {
      let path = self.methodPathUsingClientMetadata
      self.println("return self.fakeChannel.hasFakeResponseEnqueued(forPath: \(path))")
    }
    self.println("}")
  }

  /// Prints the test client class (`testClientClassName`), which is backed by a `FakeChannel`.
  /// It includes a guarded `@unchecked Sendable` conformance, a Swift 5.6 deprecation notice
  /// and the per-method fake-response helpers.
  private func printTestClient() {
    let className = self.testClientClassName
    let interceptorsType = "\(self.clientInterceptorProtocolName)?"

    self.printIfCompilerGuardForAsyncAwait()
    self.println("@available(swift, deprecated: 5.6)")
    self.println("extension \(className): @unchecked Sendable {}")
    self.printEndCompilerGuardForAsyncAwait()
    self.println()

    let deprecationMessage = "Test clients are not Sendable "
      + "but the 'GRPCClient' API requires clients to be Sendable. Using a localhost client and "
      + "server is the recommended alternative."
    self.println("@available(swift, deprecated: 5.6, message: \"\(deprecationMessage)\")")
    self.println("\(self.access) final class \(className): \(self.clientProtocolName) {")
    self.withIndentation {
      // Stored properties.
      self.println("private let fakeChannel: FakeChannel")
      self.println("\(self.access) var defaultCallOptions: CallOptions")
      self.println("\(self.access) var interceptors: \(interceptorsType)")
      self.println()

      // `channel` exposes the fake channel through the `GRPCChannel` requirement.
      self.println("\(self.access) var channel: GRPCChannel {")
      self.withIndentation {
        self.println("return self.fakeChannel")
      }
      self.println("}")
      self.println()

      // Initializer.
      self.println("\(self.access) init(")
      self.withIndentation {
        self.println("fakeChannel: FakeChannel = FakeChannel(),")
        self.println("defaultCallOptions callOptions: CallOptions = CallOptions(),")
        self.println("interceptors: \(interceptorsType) = nil")
      }
      self.println(") {")
      self.withIndentation {
        self.println("self.fakeChannel = fakeChannel")
        self.println("self.defaultCallOptions = callOptions")
        self.println("self.interceptors = interceptors")
      }
      self.println("}")

      self.printFakeResponseStreams()
    }
    self.println("}") // end class
  }
}
extension Generator {
  /// The streaming kind of the method currently being generated (`self.method`).
  private var streamType: StreamingType {
    streamingType(self.method)
  }
}
extension Generator {
  /// Argument declarations for the current method's client API, including default values.
  ///
  /// Requests come first (unary and server-streaming only), then call options, then a response
  /// handler (server- and bidirectional-streaming only).
  private var methodArguments: [String] {
    let requestArgument = "_ request: \(self.methodInputName)"
    let callOptionsArgument = "callOptions: CallOptions? = nil"
    let handlerArgument = "handler: @escaping (\(self.methodOutputName)) -> Void"

    switch self.streamType {
    case .unary:
      return [requestArgument, callOptionsArgument]
    case .serverStreaming:
      return [requestArgument, callOptionsArgument, handlerArgument]
    case .clientStreaming:
      return [callOptionsArgument]
    case .bidirectionalStreaming:
      return [callOptionsArgument, handlerArgument]
    }
  }

  /// `methodArguments` with the `= nil` default removed from the call options argument, for use
  /// in protocol requirements.
  private var methodArgumentsWithoutDefaults: [String] {
    let defaultedCallOptions = "callOptions: CallOptions? = nil"
    return self.methodArguments.map { argument in
      argument == defaultedCallOptions ? "callOptions: CallOptions?" : argument
    }
  }

  /// `methodArguments` without the call options argument.
  private var methodArgumentsWithoutCallOptions: [String] {
    self.methodArguments.filter { argument in
      argument.hasPrefix("callOptions: ") == false
    }
  }

  /// The call type returned by the current method's client API, specialised on the method's
  /// input and output types.
  private var methodReturnType: String {
    let callType: String
    switch self.streamType {
    case .unary:
      callType = "UnaryCall"
    case .serverStreaming:
      callType = "ServerStreamingCall"
    case .clientStreaming:
      callType = "ClientStreamingCall"
    case .bidirectionalStreaming:
      callType = "BidirectionalStreamingCall"
    }
    return "\(callType)<\(self.methodInputName), \(self.methodOutputName)>"
  }
}
extension StreamingType {
  /// A human-readable label for the streaming kind, e.g. `"Server streaming"`.
  fileprivate var name: String {
    let label: String
    switch self {
    case .unary: label = "Unary"
    case .clientStreaming: label = "Client streaming"
    case .serverStreaming: label = "Server streaming"
    case .bidirectionalStreaming: label = "Bidirectional streaming"
    }
    return label
  }
}
extension MethodDescriptor {
  /// The method's proto source comments with an empty comment prefix, or `nil` if there are none.
  var documentation: String? {
    let text = self.protoSourceComments(commentPrefix: "")
    guard !text.isEmpty else { return nil }
    return text
  }

  /// Doc comment lines for the generated client method.
  ///
  /// Returns the proto source comments (already `///`-prefixed) when present. Otherwise it
  /// returns a one-line fallback naming the streaming kind and the method.
  fileprivate func documentation(streamingType: StreamingType) -> String {
    let comments = self.protoSourceComments()
    guard comments.isEmpty else {
      return comments
    }
    // Terminate with "\n" to match source comments, which already end with a newline.
    return "/// \(streamingType.name) call to \(self.name)\n"
  }
}
extension Array {
  /// Visits the elements in order, passing every element except the final one to `body` and the
  /// final element to `last`.
  ///
  /// For an empty array neither closure is called. For a single-element array only `last` is
  /// called.
  ///
  /// - Parameters:
  ///   - body: Called for each element before the last one.
  ///   - last: Called once with the final element.
  fileprivate func forEach(beforeLast body: (Element) -> Void, onLast last: (Element) -> Void) {
    guard let finalElement = self.last else { return }
    for item in self.prefix(self.count - 1) {
      body(item)
    }
    last(finalElement)
  }
}
|