Browse Source

Generate client test stubs in a dedicated section. (#403)

This will make it simpler to conditionally turn of generating the
implementation, so that test stubs can be compiled separately. See
PR #402 for more discussion.
Martin Petrov 6 years ago
parent
commit
90a5498c0d

+ 18 - 18
Sources/Examples/Echo/Generated/echo.grpc.swift

@@ -31,10 +31,6 @@ fileprivate final class Echo_EchoGetCallBase: ClientCallUnaryBase<Echo_EchoReque
   override class var method: String { return "/echo.Echo/Get" }
 }
 
-class Echo_EchoGetCallTestStub: ClientCallUnaryTestStub, Echo_EchoGetCall {
-  override class var method: String { return "/echo.Echo/Get" }
-}
-
 internal protocol Echo_EchoExpandCall: ClientCallServerStreaming {
   /// Do not call this directly, call `receive()` in the protocol extension below instead.
   func _receive(timeout: DispatchTime) throws -> Echo_EchoResponse?
@@ -51,10 +47,6 @@ fileprivate final class Echo_EchoExpandCallBase: ClientCallServerStreamingBase<E
   override class var method: String { return "/echo.Echo/Expand" }
 }
 
-class Echo_EchoExpandCallTestStub: ClientCallServerStreamingTestStub<Echo_EchoResponse>, Echo_EchoExpandCall {
-  override class var method: String { return "/echo.Echo/Expand" }
-}
-
 internal protocol Echo_EchoCollectCall: ClientCallClientStreaming {
   /// Send a message to the stream. Nonblocking.
   func send(_ message: Echo_EchoRequest, completion: @escaping (Error?) -> Void) throws
@@ -76,12 +68,6 @@ fileprivate final class Echo_EchoCollectCallBase: ClientCallClientStreamingBase<
   override class var method: String { return "/echo.Echo/Collect" }
 }
 
-/// Simple fake implementation of Echo_EchoCollectCall
-/// stores sent values for later verification and finall returns a previously-defined result.
-class Echo_EchoCollectCallTestStub: ClientCallClientStreamingTestStub<Echo_EchoRequest, Echo_EchoResponse>, Echo_EchoCollectCall {
-  override class var method: String { return "/echo.Echo/Collect" }
-}
-
 internal protocol Echo_EchoUpdateCall: ClientCallBidirectionalStreaming {
   /// Do not call this directly, call `receive()` in the protocol extension below instead.
   func _receive(timeout: DispatchTime) throws -> Echo_EchoResponse?
@@ -113,10 +99,6 @@ fileprivate final class Echo_EchoUpdateCallBase: ClientCallBidirectionalStreamin
   override class var method: String { return "/echo.Echo/Update" }
 }
 
-class Echo_EchoUpdateCallTestStub: ClientCallBidirectionalStreamingTestStub<Echo_EchoRequest, Echo_EchoResponse>, Echo_EchoUpdateCall {
-  override class var method: String { return "/echo.Echo/Update" }
-}
-
 
 /// Instantiate Echo_EchoServiceClient, then call methods of this protocol to make API calls.
 internal protocol Echo_EchoService: ServiceClient {
@@ -210,6 +192,24 @@ internal final class Echo_EchoServiceClient: ServiceClientBase, Echo_EchoService
 
 }
 
+class Echo_EchoGetCallTestStub: ClientCallUnaryTestStub, Echo_EchoGetCall {
+  override class var method: String { return "/echo.Echo/Get" }
+}
+
+class Echo_EchoExpandCallTestStub: ClientCallServerStreamingTestStub<Echo_EchoResponse>, Echo_EchoExpandCall {
+  override class var method: String { return "/echo.Echo/Expand" }
+}
+
+/// Simple fake implementation of Echo_EchoCollectCall
+/// stores sent values for later verification and finall returns a previously-defined result.
+class Echo_EchoCollectCallTestStub: ClientCallClientStreamingTestStub<Echo_EchoRequest, Echo_EchoResponse>, Echo_EchoCollectCall {
+  override class var method: String { return "/echo.Echo/Collect" }
+}
+
+class Echo_EchoUpdateCallTestStub: ClientCallBidirectionalStreamingTestStub<Echo_EchoRequest, Echo_EchoResponse>, Echo_EchoUpdateCall {
+  override class var method: String { return "/echo.Echo/Update" }
+}
+
 class Echo_EchoServiceTestStub: ServiceClientTestStubBase, Echo_EchoService {
   var getRequests: [Echo_EchoRequest] = []
   var getResponses: [Echo_EchoResponse] = []

+ 79 - 51
Sources/protoc-gen-swiftgrpc/Generator-Client.swift

@@ -25,6 +25,10 @@ extension Generator {
     } else {
       printCGRPCClient(asynchronousCode: asynchronousCode,
                        synchronousCode: synchronousCode)
+      if options.generateTestStubs {
+        printCGRPCClientTestStubs(asynchronousCode: asynchronousCode,
+                                 synchronousCode: synchronousCode)
+      }
     }
   }
 
@@ -52,10 +56,25 @@ extension Generator {
     println()
     printServiceClientImplementation(asynchronousCode: asynchronousCode,
                                      synchronousCode: synchronousCode)
-    if options.generateTestStubs {
-      println()
-      printServiceClientTestStubs()
+  }
+
+  private func printCGRPCClientTestStubs(asynchronousCode: Bool,
+                                         synchronousCode: Bool) {
+    for method in service.methods {
+      self.method = method
+      switch streamingType(method) {
+      case .unary:
+        printServiceClientMethodCallUnaryTestStub()
+      case .serverStreaming:
+        printServiceClientMethodCallServerStreamingTestStub()
+      case .clientStreaming:
+        printServiceClientMethodCallClientStreamingTestStub()
+      case .bidirectionalStreaming:
+        printServiceClientMethodCallBidiStreamingTestStub()
+      }
     }
+    println()
+    printServiceClientTestStubs(asynchronousCode: asynchronousCode, synchronousCode: synchronousCode)
   }
 
   private func printServiceClientMethodCallUnary() {
@@ -66,17 +85,18 @@ extension Generator {
     println("override class var method: String { return \(methodPath) }")
     outdent()
     println("}")
-    if options.generateTestStubs {
-      println()
-      println("class \(callName)TestStub: ClientCallUnaryTestStub, \(callName) {")
-      indent()
-      println("override class var method: String { return \(methodPath) }")
-      outdent()
-      println("}")
-    }
     println()
   }
 
+  private func printServiceClientMethodCallUnaryTestStub() {
+    println()
+    println("class \(callName)TestStub: ClientCallUnaryTestStub, \(callName) {")
+    indent()
+    println("override class var method: String { return \(methodPath) }")
+    outdent()
+    println("}")
+  }
+
   private func printServiceClientMethodCallServerStreaming() {
     println("\(access) protocol \(callName): ClientCallServerStreaming {")
     indent()
@@ -91,17 +111,18 @@ extension Generator {
     println("override class var method: String { return \(methodPath) }")
     outdent()
     println("}")
-    if options.generateTestStubs {
-      println()
-      println("class \(callName)TestStub: ClientCallServerStreamingTestStub<\(methodOutputName)>, \(callName) {")
-      indent()
-      println("override class var method: String { return \(methodPath) }")
-      outdent()
-      println("}")
-    }
     println()
   }
 
+  private func printServiceClientMethodCallServerStreamingTestStub() {
+    println()
+    println("class \(callName)TestStub: ClientCallServerStreamingTestStub<\(methodOutputName)>, \(callName) {")
+    indent()
+    println("override class var method: String { return \(methodPath) }")
+    outdent()
+    println("}")
+  }
+
   private func printServiceClientMethodCallClientStreaming() {
     println("\(options.visibility.sourceSnippet) protocol \(callName): ClientCallClientStreaming {")
     indent()
@@ -121,19 +142,20 @@ extension Generator {
     println("override class var method: String { return \(methodPath) }")
     outdent()
     println("}")
-    if options.generateTestStubs {
-      println()
-      println("/// Simple fake implementation of \(callName)")
-      println("/// stores sent values for later verification and finall returns a previously-defined result.")
-      println("class \(callName)TestStub: ClientCallClientStreamingTestStub<\(methodInputName), \(methodOutputName)>, \(callName) {")
-      indent()
-      println("override class var method: String { return \(methodPath) }")
-      outdent()
-      println("}")
-    }
     println()
   }
 
+  private func printServiceClientMethodCallClientStreamingTestStub() {
+    println()
+    println("/// Simple fake implementation of \(callName)")
+    println("/// stores sent values for later verification and finall returns a previously-defined result.")
+    println("class \(callName)TestStub: ClientCallClientStreamingTestStub<\(methodInputName), \(methodOutputName)>, \(callName) {")
+    indent()
+    println("override class var method: String { return \(methodPath) }")
+    outdent()
+    println("}")
+  }
+
   private func printServiceClientMethodCallBidiStreaming() {
     println("\(access) protocol \(callName): ClientCallBidirectionalStreaming {")
     indent()
@@ -157,17 +179,18 @@ extension Generator {
     println("override class var method: String { return \(methodPath) }")
     outdent()
     println("}")
-    if options.generateTestStubs {
-      println()
-      println("class \(callName)TestStub: ClientCallBidirectionalStreamingTestStub<\(methodInputName), \(methodOutputName)>, \(callName) {")
-      indent()
-      println("override class var method: String { return \(methodPath) }")
-      outdent()
-      println("}")
-    }
     println()
   }
 
+  private func printServiceClientMethodCallBidiStreamingTestStub() {
+    println()
+    println("class \(callName)TestStub: ClientCallBidirectionalStreamingTestStub<\(methodInputName), \(methodOutputName)>, \(callName) {")
+    indent()
+    println("override class var method: String { return \(methodPath) }")
+    outdent()
+    println("}")
+  }
+
   private func printServiceClientProtocol(asynchronousCode: Bool,
                                           synchronousCode: Bool) {
     println("/// Instantiate \(serviceClassName)Client, then call methods of this protocol to make API calls.")
@@ -335,7 +358,8 @@ extension Generator {
     println("}")
   }
 
-  private func printServiceClientTestStubs() {
+  private func printServiceClientTestStubs(asynchronousCode: Bool,
+                                           synchronousCode: Bool) {
     println("class \(serviceClassName)TestStub: ServiceClientTestStubBase, \(serviceClassName) {")
     indent()
     for method in service.methods {
@@ -344,19 +368,23 @@ extension Generator {
       case .unary:
         println("var \(methodFunctionName)Requests: [\(methodInputName)] = []")
         println("var \(methodFunctionName)Responses: [\(methodOutputName)] = []")
-        println("func \(methodFunctionName)(_ request: \(methodInputName), metadata customMetadata: Metadata) throws -> \(methodOutputName) {")
-        indent()
-        println("\(methodFunctionName)Requests.append(request)")
-        println("defer { \(methodFunctionName)Responses.removeFirst() }")
-        println("return \(methodFunctionName)Responses.first!")
-        outdent()
-        println("}")
-        println("@discardableResult")
-        println("func \(methodFunctionName)(_ request: \(methodInputName), metadata customMetadata: Metadata, completion: @escaping (\(methodOutputName)?, CallResult) -> Void) throws -> \(callName) {")
-        indent()
-        println("fatalError(\"not implemented\")")
-        outdent()
-        println("}")
+        if synchronousCode {
+          println("func \(methodFunctionName)(_ request: \(methodInputName), metadata customMetadata: Metadata) throws -> \(methodOutputName) {")
+          indent()
+          println("\(methodFunctionName)Requests.append(request)")
+          println("defer { \(methodFunctionName)Responses.removeFirst() }")
+          println("return \(methodFunctionName)Responses.first!")
+          outdent()
+          println("}")
+        }
+        if asynchronousCode {
+          println("@discardableResult")
+          println("func \(methodFunctionName)(_ request: \(methodInputName), metadata customMetadata: Metadata, completion: @escaping (\(methodOutputName)?, CallResult) -> Void) throws -> \(callName) {")
+          indent()
+          println("fatalError(\"not implemented\")")
+          outdent()
+          println("}")
+        }
       case .serverStreaming:
         println("var \(methodFunctionName)Requests: [\(methodInputName)] = []")
         println("var \(methodFunctionName)Calls: [\(callName)] = []")