|
|
@@ -25,17 +25,23 @@ extension Generator {
|
|
|
printClientProtocolExtension()
|
|
|
println()
|
|
|
printServiceClientImplementation()
|
|
|
+
|
|
|
+ if self.options.generateTestClient {
|
|
|
+ self.println()
|
|
|
+ self.printTestClient()
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- private func printFunction(name: String, arguments: [String], returnType: String, access: String? = nil, bodyBuilder: (() -> ())?) {
|
|
|
+ private func printFunction(name: String, arguments: [String], returnType: String?, access: String? = nil, bodyBuilder: (() -> ())?) {
|
|
|
// Add a space after access, if it exists.
|
|
|
let accessOrEmpty = access.map { $0 + " " } ?? ""
|
|
|
+ let `return` = returnType.map { "-> " + $0 } ?? ""
|
|
|
|
|
|
let hasBody = bodyBuilder != nil
|
|
|
|
|
|
if arguments.isEmpty {
|
|
|
// Don't bother splitting across multiple lines if there are no arguments.
|
|
|
- self.println("\(accessOrEmpty)func \(name)() -> \(returnType)", newline: !hasBody)
|
|
|
+ self.println("\(accessOrEmpty)func \(name)() \(`return`)", newline: !hasBody)
|
|
|
} else {
|
|
|
self.println("\(accessOrEmpty)func \(name)(")
|
|
|
self.withIndentation {
|
|
|
@@ -46,7 +52,7 @@ extension Generator {
|
|
|
self.println($0)
|
|
|
})
|
|
|
}
|
|
|
- self.println(") -> \(returnType)", newline: !hasBody)
|
|
|
+ self.println(") \(`return`)", newline: !hasBody)
|
|
|
}
|
|
|
|
|
|
if let bodyBuilder = bodyBuilder {
|
|
|
@@ -67,7 +73,7 @@ extension Generator {
|
|
|
|
|
|
self.printFunction(
|
|
|
name: self.methodFunctionName,
|
|
|
- arguments: self.methodArguments,
|
|
|
+ arguments: self.methodArgumentsWithoutDefaults,
|
|
|
returnType: self.methodReturnType,
|
|
|
bodyBuilder: nil
|
|
|
)
|
|
|
@@ -80,43 +86,11 @@ extension Generator {
|
|
|
|
|
|
private func printClientProtocolExtension() {
|
|
|
self.println("extension \(self.clientProtocolName) {")
|
|
|
- self.withIndentation {
|
|
|
- for method in service.methods {
|
|
|
- self.method = method
|
|
|
- let body: () -> ()
|
|
|
-
|
|
|
- switch streamingType(method) {
|
|
|
- case .unary:
|
|
|
- body = {
|
|
|
- self.println("return self.\(self.methodFunctionName)(request, callOptions: self.defaultCallOptions)")
|
|
|
- }
|
|
|
-
|
|
|
- case .serverStreaming:
|
|
|
- body = {
|
|
|
- self.println("return self.\(self.methodFunctionName)(request, callOptions: self.defaultCallOptions, handler: handler)")
|
|
|
- }
|
|
|
-
|
|
|
- case .clientStreaming:
|
|
|
- body = {
|
|
|
- self.println("return self.\(self.methodFunctionName)(callOptions: self.defaultCallOptions)")
|
|
|
- }
|
|
|
-
|
|
|
- case .bidirectionalStreaming:
|
|
|
- body = {
|
|
|
- self.println("return self.\(self.methodFunctionName)(callOptions: self.defaultCallOptions, handler: handler)")
|
|
|
- }
|
|
|
- }
|
|
|
|
|
|
- self.printFunction(
|
|
|
- name: self.methodFunctionName,
|
|
|
- arguments: self.methodArgumentsWithoutCallOptions,
|
|
|
- returnType: self.methodReturnType,
|
|
|
- access: self.access,
|
|
|
- bodyBuilder: body
|
|
|
- )
|
|
|
- self.println()
|
|
|
- }
|
|
|
+ self.withIndentation {
|
|
|
+ self.printMethods()
|
|
|
}
|
|
|
+
|
|
|
self.println("}")
|
|
|
}
|
|
|
|
|
|
@@ -137,35 +111,32 @@ extension Generator {
|
|
|
println("self.defaultCallOptions = defaultCallOptions")
|
|
|
outdent()
|
|
|
println("}")
|
|
|
-
|
|
|
- self.printMethods()
|
|
|
-
|
|
|
outdent()
|
|
|
println("}")
|
|
|
}
|
|
|
|
|
|
- private func printMethods(callFactory: String = "self") {
|
|
|
+ private func printMethods() {
|
|
|
for method in self.service.methods {
|
|
|
self.println()
|
|
|
|
|
|
self.method = method
|
|
|
switch self.streamType {
|
|
|
case .unary:
|
|
|
- self.printUnaryCall(callFactory: callFactory)
|
|
|
+ self.printUnaryCall()
|
|
|
|
|
|
case .serverStreaming:
|
|
|
- self.printServerStreamingCall(callFactory: callFactory)
|
|
|
+ self.printServerStreamingCall()
|
|
|
|
|
|
case .clientStreaming:
|
|
|
- self.printClientStreamingCall(callFactory: callFactory)
|
|
|
+ self.printClientStreamingCall()
|
|
|
|
|
|
case .bidirectionalStreaming:
|
|
|
- self.printBidirectionalStreamingCall(callFactory: callFactory)
|
|
|
+ self.printBidirectionalStreamingCall()
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- private func printUnaryCall(callFactory: String) {
|
|
|
+ private func printUnaryCall() {
|
|
|
self.println(self.method.documentation(streamingType: self.streamType), newline: false)
|
|
|
self.println("///")
|
|
|
self.printParameters()
|
|
|
@@ -178,17 +149,17 @@ extension Generator {
|
|
|
returnType: self.methodReturnType,
|
|
|
access: self.access
|
|
|
) {
|
|
|
- self.println("return \(callFactory).makeUnaryCall(")
|
|
|
+ self.println("return self.makeUnaryCall(")
|
|
|
self.withIndentation {
|
|
|
self.println("path: \(self.methodPath),")
|
|
|
self.println("request: request,")
|
|
|
- self.println("callOptions: callOptions")
|
|
|
+ self.println("callOptions: callOptions ?? self.defaultCallOptions")
|
|
|
}
|
|
|
self.println(")")
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- private func printServerStreamingCall(callFactory: String) {
|
|
|
+ private func printServerStreamingCall() {
|
|
|
self.println(self.method.documentation(streamingType: self.streamType), newline: false)
|
|
|
self.println("///")
|
|
|
self.printParameters()
|
|
|
@@ -202,18 +173,18 @@ extension Generator {
|
|
|
returnType: self.methodReturnType,
|
|
|
access: self.access
|
|
|
) {
|
|
|
- self.println("return \(callFactory).makeServerStreamingCall(") // path: \"/\(servicePath)/\(method.name)\",")
|
|
|
+ self.println("return self.makeServerStreamingCall(") // path: \"/\(servicePath)/\(method.name)\",")
|
|
|
self.withIndentation {
|
|
|
self.println("path: \(self.methodPath),")
|
|
|
self.println("request: request,")
|
|
|
- self.println("callOptions: callOptions,")
|
|
|
+ self.println("callOptions: callOptions ?? self.defaultCallOptions,")
|
|
|
self.println("handler: handler")
|
|
|
}
|
|
|
self.println(")")
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- private func printClientStreamingCall(callFactory: String) {
|
|
|
+ private func printClientStreamingCall() {
|
|
|
self.println(self.method.documentation(streamingType: self.streamType), newline: false)
|
|
|
self.println("///")
|
|
|
self.printClientStreamingDetails()
|
|
|
@@ -227,16 +198,16 @@ extension Generator {
|
|
|
returnType: self.methodReturnType,
|
|
|
access: self.access
|
|
|
) {
|
|
|
- self.println("return \(callFactory).makeClientStreamingCall(")
|
|
|
+ self.println("return self.makeClientStreamingCall(")
|
|
|
self.withIndentation {
|
|
|
self.println("path: \(self.methodPath),")
|
|
|
- self.println("callOptions: callOptions")
|
|
|
+ self.println("callOptions: callOptions ?? self.defaultCallOptions")
|
|
|
}
|
|
|
self.println(")")
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- private func printBidirectionalStreamingCall(callFactory: String) {
|
|
|
+ private func printBidirectionalStreamingCall() {
|
|
|
self.println(self.method.documentation(streamingType: self.streamType), newline: false)
|
|
|
self.println("///")
|
|
|
self.printClientStreamingDetails()
|
|
|
@@ -251,10 +222,10 @@ extension Generator {
|
|
|
returnType: self.methodReturnType,
|
|
|
access: self.access
|
|
|
) {
|
|
|
- self.println("return \(callFactory).makeBidirectionalStreamingCall(")
|
|
|
+ self.println("return self.makeBidirectionalStreamingCall(")
|
|
|
self.withIndentation {
|
|
|
self.println("path: \(self.methodPath),")
|
|
|
- self.println("callOptions: callOptions,")
|
|
|
+ self.println("callOptions: callOptions ?? self.defaultCallOptions,")
|
|
|
self.println("handler: handler")
|
|
|
}
|
|
|
self.println(")")
|
|
|
@@ -283,6 +254,133 @@ extension Generator {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+extension Generator {
|
|
|
+ fileprivate func printFakeResponseStreams() {
|
|
|
+ for method in self.service.methods {
|
|
|
+ self.println()
|
|
|
+
|
|
|
+ self.method = method
|
|
|
+ switch self.streamType {
|
|
|
+ case .unary, .clientStreaming:
|
|
|
+ self.printUnaryResponse()
|
|
|
+
|
|
|
+ case .serverStreaming, .bidirectionalStreaming:
|
|
|
+ self.printStreamingResponse()
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ fileprivate func printUnaryResponse() {
|
|
|
+ self.printResponseStream(isUnary: true)
|
|
|
+ self.println()
|
|
|
+ self.printEnqueueUnaryResponse(isUnary: true)
|
|
|
+ self.println()
|
|
|
+ self.printHasResponseStreamEnqueued()
|
|
|
+ }
|
|
|
+
|
|
|
+ fileprivate func printStreamingResponse() {
|
|
|
+ self.printResponseStream(isUnary: false)
|
|
|
+ self.println()
|
|
|
+ self.printEnqueueUnaryResponse(isUnary: false)
|
|
|
+ self.println()
|
|
|
+ self.printHasResponseStreamEnqueued()
|
|
|
+ }
|
|
|
+
|
|
|
+ private func printEnqueueUnaryResponse(isUnary: Bool) {
|
|
|
+ let name: String
|
|
|
+ let responseArg: String
|
|
|
+ let responseArgAndType: String
|
|
|
+ if isUnary {
|
|
|
+ name = "enqueue\(self.method.name)Response"
|
|
|
+ responseArg = "response"
|
|
|
+ responseArgAndType = "_ \(responseArg): \(self.methodOutputName)"
|
|
|
+ } else {
|
|
|
+ name = "enqueue\(self.method.name)Responses"
|
|
|
+ responseArg = "responses"
|
|
|
+ responseArgAndType = "_ \(responseArg): [\(self.methodOutputName)]"
|
|
|
+ }
|
|
|
+
|
|
|
+ self.printFunction(
|
|
|
+ name: name,
|
|
|
+ arguments: [
|
|
|
+ responseArgAndType,
|
|
|
+ "_ requestHandler: @escaping (FakeRequestPart<\(self.methodInputName)>) -> () = { _ in }"
|
|
|
+ ],
|
|
|
+ returnType: nil,
|
|
|
+ access: self.access
|
|
|
+ ) {
|
|
|
+ self.println("let stream = self.make\(self.method.name)ResponseStream(requestHandler)")
|
|
|
+ if isUnary {
|
|
|
+ self.println("// This is the only operation on the stream; try! is fine.")
|
|
|
+ self.println("try! stream.sendMessage(\(responseArg))")
|
|
|
+ } else {
|
|
|
+ self.println("// These are the only operation on the stream; try! is fine.")
|
|
|
+ self.println("\(responseArg).forEach { try! stream.sendMessage($0) }")
|
|
|
+ self.println("try! stream.sendEnd()")
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private func printResponseStream(isUnary: Bool) {
|
|
|
+ let type = isUnary ? "FakeUnaryResponse" : "FakeStreamingResponse"
|
|
|
+ let factory = isUnary ? "makeFakeUnaryResponse" : "makeFakeStreamingResponse"
|
|
|
+
|
|
|
+ self.println("/// Make a \(isUnary ? "unary" : "streaming") response for the \(self.method.name) RPC. This must be called")
|
|
|
+ self.println("/// before calling '\(self.methodFunctionName)'. See also '\(type)'.")
|
|
|
+ self.println("///")
|
|
|
+ self.println("/// - Parameter requestHandler: a handler for request parts sent by the RPC.")
|
|
|
+ self.printFunction(
|
|
|
+ name: "make\(self.method.name)ResponseStream",
|
|
|
+ arguments: ["_ requestHandler: @escaping (FakeRequestPart<\(self.methodInputName)>) -> () = { _ in }"],
|
|
|
+ returnType: "\(type)<\(self.methodInputName), \(self.methodOutputName)>",
|
|
|
+ access: self.access
|
|
|
+ ) {
|
|
|
+ self.println("return self.fakeChannel.\(factory)(path: \(self.methodPath), requestHandler: requestHandler)")
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private func printHasResponseStreamEnqueued() {
|
|
|
+ self.println("/// Returns true if there are response streams enqueued for '\(self.method.name)'")
|
|
|
+ self.println("\(self.access) var has\(self.method.name)ResponsesRemaining: Bool {")
|
|
|
+ self.withIndentation {
|
|
|
+ self.println("return self.fakeChannel.hasFakeResponseEnqueued(forPath: \(self.methodPath))")
|
|
|
+ }
|
|
|
+ self.println("}")
|
|
|
+ }
|
|
|
+
|
|
|
+ fileprivate func printTestClient() {
|
|
|
+ self.println("\(self.access) final class \(self.testClientClassName): \(self.clientProtocolName) {")
|
|
|
+ self.withIndentation {
|
|
|
+ self.println("private let fakeChannel: FakeChannel")
|
|
|
+ self.println("\(self.access) var defaultCallOptions: CallOptions")
|
|
|
+
|
|
|
+ self.println()
|
|
|
+ self.println("\(self.access) var channel: GRPCChannel {")
|
|
|
+ self.withIndentation {
|
|
|
+ self.println("return self.fakeChannel")
|
|
|
+ }
|
|
|
+ self.println("}")
|
|
|
+ self.println()
|
|
|
+
|
|
|
+ self.println("\(self.access) init(")
|
|
|
+ self.withIndentation {
|
|
|
+ self.println("fakeChannel: FakeChannel = FakeChannel(),")
|
|
|
+ self.println("defaultCallOptions callOptions: CallOptions = CallOptions()")
|
|
|
+ }
|
|
|
+ self.println(") {")
|
|
|
+ self.withIndentation {
|
|
|
+ self.println("self.fakeChannel = fakeChannel")
|
|
|
+ self.println("self.defaultCallOptions = callOptions")
|
|
|
+ }
|
|
|
+ self.println("}")
|
|
|
+
|
|
|
+ self.printFakeResponseStreams()
|
|
|
+ }
|
|
|
+
|
|
|
+ self.println("}") // end class
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
fileprivate extension Generator {
|
|
|
var streamType: StreamingType {
|
|
|
return streamingType(self.method)
|
|
|
@@ -295,26 +393,36 @@ extension Generator {
|
|
|
case .unary:
|
|
|
return [
|
|
|
"_ request: \(self.methodInputName)",
|
|
|
- "callOptions: CallOptions"
|
|
|
+ "callOptions: CallOptions? = nil"
|
|
|
]
|
|
|
case .serverStreaming:
|
|
|
return [
|
|
|
"_ request: \(self.methodInputName)",
|
|
|
- "callOptions: CallOptions",
|
|
|
+ "callOptions: CallOptions? = nil",
|
|
|
"handler: @escaping (\(methodOutputName)) -> Void"
|
|
|
]
|
|
|
|
|
|
case .clientStreaming:
|
|
|
- return ["callOptions: CallOptions"]
|
|
|
+ return ["callOptions: CallOptions? = nil"]
|
|
|
|
|
|
case .bidirectionalStreaming:
|
|
|
return [
|
|
|
- "callOptions: CallOptions",
|
|
|
+ "callOptions: CallOptions? = nil",
|
|
|
"handler: @escaping (\(methodOutputName)) -> Void"
|
|
|
]
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ fileprivate var methodArgumentsWithoutDefaults: [String] {
|
|
|
+ return self.methodArguments.map { arg in
|
|
|
+ // Remove default arg from call options.
|
|
|
+ if arg == "callOptions: CallOptions? = nil" {
|
|
|
+ return "callOptions: CallOptions?"
|
|
|
+ } else {
|
|
|
+ return arg
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
fileprivate var methodArgumentsWithoutCallOptions: [String] {
|
|
|
return self.methodArguments.filter {
|