Browse Source

Provide the codegen with an option to generate test clients (#870)

* Provide the codegen with an option to generate test clients

Motivation:

When consuming gRPC it is often helpful to be able to write tests that
ensure the client is integrated correctly. At the moment this is only
possible by running a local gRPC server with a custom service handler to
return the responses you would like to test.

Modifications:

This builds on work in #855, #864, and #865.

This pull request introduces code generation for the test clients. It
provides type-safe wrappers and convenience methods on top of
`FakeChannel` and a code-gen option to enable 'TestClient' generation.

It also removes an `init` requirement on the `GRPCClient` protocol.

Result:

Users can generate test clients.

* Regenerate

* fix type, add assertion

* Add "Remaining"
George Barnett 5 years ago
parent
commit
56c455dce8

+ 171 - 63
Sources/protoc-gen-grpc-swift/Generator-Client.swift

@@ -25,17 +25,23 @@ extension Generator {
     printClientProtocolExtension()
     println()
     printServiceClientImplementation()
+
+    if self.options.generateTestClient {
+      self.println()
+      self.printTestClient()
+    }
   }
 
-  private func printFunction(name: String, arguments: [String], returnType: String, access: String? = nil, bodyBuilder: (() -> ())?) {
+  private func printFunction(name: String, arguments: [String], returnType: String?, access: String? = nil, bodyBuilder: (() -> ())?) {
     // Add a space after access, if it exists.
     let accessOrEmpty = access.map { $0 + " " } ?? ""
+    let `return` = returnType.map { "-> " + $0 } ?? ""
 
     let hasBody = bodyBuilder != nil
 
     if arguments.isEmpty {
       // Don't bother splitting across multiple lines if there are no arguments.
-      self.println("\(accessOrEmpty)func \(name)() -> \(returnType)", newline: !hasBody)
+      self.println("\(accessOrEmpty)func \(name)() \(`return`)", newline: !hasBody)
     } else {
       self.println("\(accessOrEmpty)func \(name)(")
       self.withIndentation {
@@ -46,7 +52,7 @@ extension Generator {
           self.println($0)
         })
       }
-      self.println(") -> \(returnType)", newline: !hasBody)
+      self.println(") \(`return`)", newline: !hasBody)
     }
 
     if let bodyBuilder = bodyBuilder {
@@ -67,7 +73,7 @@ extension Generator {
 
         self.printFunction(
           name: self.methodFunctionName,
-          arguments: self.methodArguments,
+          arguments: self.methodArgumentsWithoutDefaults,
           returnType: self.methodReturnType,
           bodyBuilder: nil
         )
@@ -80,43 +86,11 @@ extension Generator {
 
   private func printClientProtocolExtension() {
     self.println("extension \(self.clientProtocolName) {")
-    self.withIndentation {
-      for method in service.methods {
-        self.method = method
-        let body: () -> ()
-
-        switch streamingType(method) {
-        case .unary:
-          body = {
-            self.println("return self.\(self.methodFunctionName)(request, callOptions: self.defaultCallOptions)")
-          }
-
-        case .serverStreaming:
-          body = {
-            self.println("return self.\(self.methodFunctionName)(request, callOptions: self.defaultCallOptions, handler: handler)")
-          }
-
-        case .clientStreaming:
-          body = {
-            self.println("return self.\(self.methodFunctionName)(callOptions: self.defaultCallOptions)")
-          }
-
-        case .bidirectionalStreaming:
-          body = {
-            self.println("return self.\(self.methodFunctionName)(callOptions: self.defaultCallOptions, handler: handler)")
-          }
-        }
 
-        self.printFunction(
-          name: self.methodFunctionName,
-          arguments: self.methodArgumentsWithoutCallOptions,
-          returnType: self.methodReturnType,
-          access: self.access,
-          bodyBuilder: body
-        )
-        self.println()
-      }
+    self.withIndentation {
+      self.printMethods()
     }
+
     self.println("}")
   }
 
@@ -137,35 +111,32 @@ extension Generator {
     println("self.defaultCallOptions = defaultCallOptions")
     outdent()
     println("}")
-
-    self.printMethods()
-
     outdent()
     println("}")
   }
 
-  private func printMethods(callFactory: String = "self") {
+  private func printMethods() {
     for method in self.service.methods {
       self.println()
 
       self.method = method
       switch self.streamType {
       case .unary:
-        self.printUnaryCall(callFactory: callFactory)
+        self.printUnaryCall()
 
       case .serverStreaming:
-        self.printServerStreamingCall(callFactory: callFactory)
+        self.printServerStreamingCall()
 
       case .clientStreaming:
-        self.printClientStreamingCall(callFactory: callFactory)
+        self.printClientStreamingCall()
 
       case .bidirectionalStreaming:
-        self.printBidirectionalStreamingCall(callFactory: callFactory)
+        self.printBidirectionalStreamingCall()
       }
     }
   }
 
-  private func printUnaryCall(callFactory: String) {
+  private func printUnaryCall() {
     self.println(self.method.documentation(streamingType: self.streamType), newline: false)
     self.println("///")
     self.printParameters()
@@ -178,17 +149,17 @@ extension Generator {
       returnType: self.methodReturnType,
       access: self.access
     ) {
-      self.println("return \(callFactory).makeUnaryCall(")
+      self.println("return self.makeUnaryCall(")
       self.withIndentation {
         self.println("path: \(self.methodPath),")
         self.println("request: request,")
-        self.println("callOptions: callOptions")
+        self.println("callOptions: callOptions ?? self.defaultCallOptions")
       }
       self.println(")")
     }
   }
 
-  private func printServerStreamingCall(callFactory: String) {
+  private func printServerStreamingCall() {
     self.println(self.method.documentation(streamingType: self.streamType), newline: false)
     self.println("///")
     self.printParameters()
@@ -202,18 +173,18 @@ extension Generator {
       returnType: self.methodReturnType,
       access: self.access
     ) {
-      self.println("return \(callFactory).makeServerStreamingCall(") // path: \"/\(servicePath)/\(method.name)\",")
+      self.println("return self.makeServerStreamingCall(") // path: \"/\(servicePath)/\(method.name)\",")
       self.withIndentation {
         self.println("path: \(self.methodPath),")
         self.println("request: request,")
-        self.println("callOptions: callOptions,")
+        self.println("callOptions: callOptions ?? self.defaultCallOptions,")
         self.println("handler: handler")
       }
       self.println(")")
     }
   }
 
-  private func printClientStreamingCall(callFactory: String) {
+  private func printClientStreamingCall() {
     self.println(self.method.documentation(streamingType: self.streamType), newline: false)
     self.println("///")
     self.printClientStreamingDetails()
@@ -227,16 +198,16 @@ extension Generator {
       returnType: self.methodReturnType,
       access: self.access
     ) {
-      self.println("return \(callFactory).makeClientStreamingCall(")
+      self.println("return self.makeClientStreamingCall(")
       self.withIndentation {
         self.println("path: \(self.methodPath),")
-        self.println("callOptions: callOptions")
+        self.println("callOptions: callOptions ?? self.defaultCallOptions")
       }
       self.println(")")
     }
   }
 
-  private func printBidirectionalStreamingCall(callFactory: String) {
+  private func printBidirectionalStreamingCall() {
     self.println(self.method.documentation(streamingType: self.streamType), newline: false)
     self.println("///")
     self.printClientStreamingDetails()
@@ -251,10 +222,10 @@ extension Generator {
       returnType: self.methodReturnType,
       access: self.access
     ) {
-      self.println("return \(callFactory).makeBidirectionalStreamingCall(")
+      self.println("return self.makeBidirectionalStreamingCall(")
       self.withIndentation {
         self.println("path: \(self.methodPath),")
-        self.println("callOptions: callOptions,")
+        self.println("callOptions: callOptions ?? self.defaultCallOptions,")
         self.println("handler: handler")
       }
       self.println(")")
@@ -283,6 +254,133 @@ extension Generator {
   }
 }
 
+extension Generator {
+  fileprivate func printFakeResponseStreams() {
+    for method in self.service.methods {
+      self.println()
+
+      self.method = method
+      switch self.streamType {
+      case .unary, .clientStreaming:
+        self.printUnaryResponse()
+
+      case .serverStreaming, .bidirectionalStreaming:
+        self.printStreamingResponse()
+      }
+    }
+  }
+
+  fileprivate func printUnaryResponse() {
+    self.printResponseStream(isUnary: true)
+    self.println()
+    self.printEnqueueUnaryResponse(isUnary: true)
+    self.println()
+    self.printHasResponseStreamEnqueued()
+  }
+
+  fileprivate func printStreamingResponse() {
+    self.printResponseStream(isUnary: false)
+    self.println()
+    self.printEnqueueUnaryResponse(isUnary: false)
+    self.println()
+    self.printHasResponseStreamEnqueued()
+  }
+
+  private func printEnqueueUnaryResponse(isUnary: Bool) {
+    let name: String
+    let responseArg: String
+    let responseArgAndType: String
+    if isUnary {
+      name = "enqueue\(self.method.name)Response"
+      responseArg = "response"
+      responseArgAndType = "_ \(responseArg): \(self.methodOutputName)"
+    } else {
+      name = "enqueue\(self.method.name)Responses"
+      responseArg = "responses"
+      responseArgAndType = "_ \(responseArg): [\(self.methodOutputName)]"
+    }
+
+    self.printFunction(
+      name: name,
+      arguments: [
+        responseArgAndType,
+        "_ requestHandler: @escaping (FakeRequestPart<\(self.methodInputName)>) -> () = { _ in }"
+      ],
+      returnType: nil,
+      access: self.access
+    ) {
+      self.println("let stream = self.make\(self.method.name)ResponseStream(requestHandler)")
+      if isUnary {
+        self.println("// This is the only operation on the stream; try! is fine.")
+        self.println("try! stream.sendMessage(\(responseArg))")
+      } else {
+        self.println("// These are the only operation on the stream; try! is fine.")
+        self.println("\(responseArg).forEach { try! stream.sendMessage($0) }")
+        self.println("try! stream.sendEnd()")
+      }
+    }
+  }
+
+  private func printResponseStream(isUnary: Bool) {
+    let type = isUnary ? "FakeUnaryResponse" : "FakeStreamingResponse"
+    let factory = isUnary ? "makeFakeUnaryResponse" : "makeFakeStreamingResponse"
+
+    self.println("/// Make a \(isUnary ? "unary" : "streaming") response for the \(self.method.name) RPC. This must be called")
+    self.println("/// before calling '\(self.methodFunctionName)'. See also '\(type)'.")
+    self.println("///")
+    self.println("/// - Parameter requestHandler: a handler for request parts sent by the RPC.")
+    self.printFunction(
+      name: "make\(self.method.name)ResponseStream",
+      arguments: ["_ requestHandler: @escaping (FakeRequestPart<\(self.methodInputName)>) -> () = { _ in }"],
+      returnType: "\(type)<\(self.methodInputName), \(self.methodOutputName)>",
+      access: self.access
+    ) {
+      self.println("return self.fakeChannel.\(factory)(path: \(self.methodPath), requestHandler: requestHandler)")
+    }
+  }
+
+  private func printHasResponseStreamEnqueued() {
+    self.println("/// Returns true if there are response streams enqueued for '\(self.method.name)'")
+    self.println("\(self.access) var has\(self.method.name)ResponsesRemaining: Bool {")
+    self.withIndentation {
+      self.println("return self.fakeChannel.hasFakeResponseEnqueued(forPath: \(self.methodPath))")
+    }
+    self.println("}")
+  }
+
+  fileprivate func printTestClient() {
+    self.println("\(self.access) final class \(self.testClientClassName): \(self.clientProtocolName) {")
+    self.withIndentation {
+      self.println("private let fakeChannel: FakeChannel")
+      self.println("\(self.access) var defaultCallOptions: CallOptions")
+
+      self.println()
+      self.println("\(self.access) var channel: GRPCChannel {")
+      self.withIndentation {
+        self.println("return self.fakeChannel")
+      }
+      self.println("}")
+      self.println()
+
+      self.println("\(self.access) init(")
+      self.withIndentation {
+        self.println("fakeChannel: FakeChannel = FakeChannel(),")
+        self.println("defaultCallOptions callOptions: CallOptions = CallOptions()")
+      }
+      self.println(") {")
+      self.withIndentation {
+        self.println("self.fakeChannel = fakeChannel")
+        self.println("self.defaultCallOptions = callOptions")
+      }
+      self.println("}")
+
+      self.printFakeResponseStreams()
+    }
+
+    self.println("}")  // end class
+  }
+}
+
 fileprivate extension Generator {
   var streamType: StreamingType {
     return streamingType(self.method)
@@ -295,26 +393,36 @@ extension Generator {
     case .unary:
       return [
         "_ request: \(self.methodInputName)",
-        "callOptions: CallOptions"
+        "callOptions: CallOptions? = nil"
       ]
     case .serverStreaming:
       return [
         "_ request: \(self.methodInputName)",
-        "callOptions: CallOptions",
+        "callOptions: CallOptions? = nil",
         "handler: @escaping (\(methodOutputName)) -> Void"
       ]
 
     case .clientStreaming:
-      return ["callOptions: CallOptions"]
+      return ["callOptions: CallOptions? = nil"]
 
     case .bidirectionalStreaming:
       return [
-        "callOptions: CallOptions",
+        "callOptions: CallOptions? = nil",
         "handler: @escaping (\(methodOutputName)) -> Void"
       ]
     }
   }
 
+  fileprivate var methodArgumentsWithoutDefaults: [String] {
+    return self.methodArguments.map { arg in
+      // Remove default arg from call options.
+      if arg == "callOptions: CallOptions? = nil" {
+        return "callOptions: CallOptions?"
+      } else {
+        return arg
+      }
+    }
+  }
 
   fileprivate var methodArgumentsWithoutCallOptions: [String] {
     return self.methodArguments.filter {

+ 4 - 0
Sources/protoc-gen-grpc-swift/Generator-Names.swift

@@ -50,6 +50,10 @@ extension Generator {
     return nameForPackageService(file, service) + "Client"
   }
 
+  internal var testClientClassName: String {
+    return nameForPackageService(self.file, self.service) + "TestClient"
+  }
+
   internal var clientProtocolName: String {
     return nameForPackageService(file, service) + "ClientProtocol"
   }

+ 8 - 0
Sources/protoc-gen-grpc-swift/options.swift

@@ -54,6 +54,7 @@ final class GeneratorOptions {
   private(set) var visibility = Visibility.internal
   private(set) var generateServer = true
   private(set) var generateClient = true
+  private(set) var generateTestClient = false
   private(set) var protoToModuleMappings = ProtoFileToModuleMappings()
   private(set) var fileNaming = FileNaming.FullPath
   private(set) var extraModuleImports: [String] = []
@@ -82,6 +83,13 @@ final class GeneratorOptions {
           throw GenerationError.invalidParameterValue(name: pair.key, value: pair.value)
         }
 
+      case "TestClient":
+        if let value = Bool(pair.value) {
+          self.generateTestClient = value
+        } else {
+          throw GenerationError.invalidParameterValue(name: pair.key, value: pair.value)
+        }
+
       case "ProtoPathModuleMappings":
         if !pair.value.isEmpty {
           do {