Browse Source

Require generated client protocols to accept `CallOptions` (#865)

Motivation:

The generated protocol for clients allows the caller to optionally pass
in some `CallOptions`. The generated client conforming to this protocol
defaults this argument to `nil` and falls back to default options set on
that client in that case.

It is therefore possible to call, for example,
`theClient.someUnaryRPC(someRequest)` but not
`theProtocol.someUnaryRPC(someRequest)`.

Without explicitly passing the call options, the protocol offers little
value when coding to the protocol rather than to a concrete
implementation.

Modifications:

- Generated client protocols now extend `GRPCClient` (previously the
  generated client implementations conformed to `GRPCClient` and the
  protocol)
- Generated client protocol requires `CallOptions`
- Generated client protocols now have a generated extension for each RPC
  which does not require call options whose implementation forwards the
  default call options courtesy of `GRPCClient`.

Result:

- It is possible to reasonably implement code against the generated
  protocol rather than a concrete implementation
- We lose the ability to call an RPC with `callOptions: nil`
George Barnett 5 years ago
parent
commit
fe3d286f19
1 changed files with 181 additions and 54 deletions
  1. 181 54
      Sources/protoc-gen-grpc-swift/Generator-Client.swift

+ 181 - 54
Sources/protoc-gen-grpc-swift/Generator-Client.swift

@@ -22,35 +22,106 @@ extension Generator {
     println()
     printServiceClientProtocol()
     println()
+    printClientProtocolExtension()
+    println()
     printServiceClientImplementation()
   }
 
-  private func printServiceClientProtocol() {
-    println("/// Usage: instantiate \(clientClassName), then call methods of this protocol to make API calls.")
-    println("\(options.visibility.sourceSnippet) protocol \(clientProtocolName) {")
-    indent()
-    for method in service.methods {
-      self.method = method
-      switch streamingType(method) {
-      case .unary:
-        println("func \(methodFunctionName)(_ request: \(methodInputName), callOptions: CallOptions?) -> UnaryCall<\(methodInputName), \(methodOutputName)>")
+  private func printFunction(name: String, arguments: [String], returnType: String, access: String? = nil, bodyBuilder: (() -> ())?) {
+    // Add a space after access, if it exists.
+    let accessOrEmpty = access.map { $0 + " " } ?? ""
 
-      case .serverStreaming:
-        println("func \(methodFunctionName)(_ request: \(methodInputName), callOptions: CallOptions?, handler: @escaping (\(methodOutputName)) -> Void) -> ServerStreamingCall<\(methodInputName), \(methodOutputName)>")
+    let hasBody = bodyBuilder != nil
 
-      case .clientStreaming:
-        println("func \(methodFunctionName)(callOptions: CallOptions?) -> ClientStreamingCall<\(methodInputName), \(methodOutputName)>")
+    if arguments.isEmpty {
+      // Don't bother splitting across multiple lines if there are no arguments.
+      self.println("\(accessOrEmpty)func \(name)() -> \(returnType)", newline: !hasBody)
+    } else {
+      self.println("\(accessOrEmpty)func \(name)(")
+      self.withIndentation {
+        // Add a comma after each argument except the last.
+        arguments.forEach(beforeLast: {
+          self.println($0 + ",")
+        }, onLast: {
+          self.println($0)
+        })
+      }
+      self.println(") -> \(returnType)", newline: !hasBody)
+    }
 
-      case .bidirectionalStreaming:
-        println("func \(methodFunctionName)(callOptions: CallOptions?, handler: @escaping (\(methodOutputName)) -> Void) -> BidirectionalStreamingCall<\(methodInputName), \(methodOutputName)>")
+    if let bodyBuilder = bodyBuilder {
+      self.println(" {")
+      self.withIndentation {
+        bodyBuilder()
+      }
+      self.println("}")
+    }
+  }
+
+  private func printServiceClientProtocol() {
+    self.println("/// Usage: instantiate \(self.clientClassName), then call methods of this protocol to make API calls.")
+    self.println("\(self.access) protocol \(self.clientProtocolName): GRPCClient {")
+    self.withIndentation {
+      for method in service.methods {
+        self.method = method
+
+        self.printFunction(
+          name: self.methodFunctionName,
+          arguments: self.methodArguments,
+          returnType: self.methodReturnType,
+          bodyBuilder: nil
+        )
+
+        self.println()
       }
     }
-    outdent()
     println("}")
   }
 
+  private func printClientProtocolExtension() {
+    self.println("extension \(self.clientProtocolName) {")
+    self.withIndentation {
+      for method in service.methods {
+        self.method = method
+        let body: () -> ()
+
+        switch streamingType(method) {
+        case .unary:
+          body = {
+            self.println("return self.\(self.methodFunctionName)(request, callOptions: self.defaultCallOptions)")
+          }
+
+        case .serverStreaming:
+          body = {
+            self.println("return self.\(self.methodFunctionName)(request, callOptions: self.defaultCallOptions, handler: handler)")
+          }
+
+        case .clientStreaming:
+          body = {
+            self.println("return self.\(self.methodFunctionName)(callOptions: self.defaultCallOptions)")
+          }
+
+        case .bidirectionalStreaming:
+          body = {
+            self.println("return self.\(self.methodFunctionName)(callOptions: self.defaultCallOptions, handler: handler)")
+          }
+        }
+
+        self.printFunction(
+          name: self.methodFunctionName,
+          arguments: self.methodArgumentsWithoutCallOptions,
+          returnType: self.methodReturnType,
+          access: self.access,
+          bodyBuilder: body
+        )
+        self.println()
+      }
+    }
+    self.println("}")
+  }
+
   private func printServiceClientImplementation() {
-    println("\(access) final class \(clientClassName): GRPCClient, \(clientProtocolName) {")
+    println("\(access) final class \(clientClassName): \(clientProtocolName) {")
     indent()
     println("\(access) let channel: GRPCChannel")
     println("\(access) var defaultCallOptions: CallOptions")
@@ -101,22 +172,20 @@ extension Generator {
     self.printRequestParameter()
     self.printCallOptionsParameter()
     self.println("/// - Returns: A `UnaryCall` with futures for the metadata, status and response.")
-    self.println("\(self.access) func \(self.methodFunctionName)(")
-    self.withIndentation {
-      self.println("_ request: \(self.methodInputName),")
-      self.println("callOptions: CallOptions? = nil")
-    }
-    self.println(") -> UnaryCall<\(self.methodInputName), \(self.methodOutputName)> {")
-    self.withIndentation {
+    self.printFunction(
+      name: self.methodFunctionName,
+      arguments: self.methodArguments,
+      returnType: self.methodReturnType,
+      access: self.access
+    ) {
       self.println("return \(callFactory).makeUnaryCall(")
       self.withIndentation {
         self.println("path: \(self.methodPath),")
         self.println("request: request,")
-        self.println("callOptions: callOptions ?? self.defaultCallOptions")
+        self.println("callOptions: callOptions")
       }
       self.println(")")
     }
-    self.println("}")
   }
 
   private func printServerStreamingCall(callFactory: String) {
@@ -127,24 +196,21 @@ extension Generator {
     self.printCallOptionsParameter()
     self.printHandlerParameter()
     self.println("/// - Returns: A `ServerStreamingCall` with futures for the metadata and status.")
-    self.println("\(self.access) func \(self.methodFunctionName)(")
-    self.withIndentation {
-      self.println("_ request: \(self.methodInputName),")
-      self.println("callOptions: CallOptions? = nil,")
-      self.println("handler: @escaping (\(methodOutputName)) -> Void")
-    }
-    self.println(") -> ServerStreamingCall<\(methodInputName), \(methodOutputName)> {")
-    self.withIndentation {
+    self.printFunction(
+      name: self.methodFunctionName,
+      arguments: self.methodArguments,
+      returnType: self.methodReturnType,
+      access: self.access
+    ) {
       self.println("return \(callFactory).makeServerStreamingCall(") // path: \"/\(servicePath)/\(method.name)\",")
       self.withIndentation {
         self.println("path: \(self.methodPath),")
         self.println("request: request,")
-        self.println("callOptions: callOptions ?? self.defaultCallOptions,")
+        self.println("callOptions: callOptions,")
         self.println("handler: handler")
       }
       self.println(")")
     }
-    self.println("}")
   }
 
   private func printClientStreamingCall(callFactory: String) {
@@ -155,20 +221,19 @@ extension Generator {
     self.printParameters()
     self.printCallOptionsParameter()
     self.println("/// - Returns: A `ClientStreamingCall` with futures for the metadata, status and response.")
-    self.println("\(self.access) func \(self.methodFunctionName)(")
-    self.withIndentation {
-      self.println("callOptions: CallOptions? = nil")
-    }
-    self.println(") -> ClientStreamingCall<\(self.methodInputName), \(self.methodOutputName)> {")
-    self.withIndentation {
+    self.printFunction(
+      name: self.methodFunctionName,
+      arguments: self.methodArguments,
+      returnType: self.methodReturnType,
+      access: self.access
+    ) {
       self.println("return \(callFactory).makeClientStreamingCall(")
       self.withIndentation {
         self.println("path: \(self.methodPath),")
-        self.println("callOptions: callOptions ?? self.defaultCallOptions")
+        self.println("callOptions: callOptions")
       }
       self.println(")")
     }
-    self.println("}")
   }
 
   private func printBidirectionalStreamingCall(callFactory: String) {
@@ -180,22 +245,20 @@ extension Generator {
     self.printCallOptionsParameter()
     self.printHandlerParameter()
     self.println("/// - Returns: A `ClientStreamingCall` with futures for the metadata and status.")
-    self.println("\(self.access) func \(self.methodFunctionName)(")
-    self.withIndentation {
-      self.println("callOptions: CallOptions? = nil,")
-      self.println("handler: @escaping (\(self.methodOutputName)) -> Void")
-    }
-    self.println(") -> BidirectionalStreamingCall<\(self.methodInputName), \(self.methodOutputName)> {")
-    self.withIndentation {
+    self.printFunction(
+      name: self.methodFunctionName,
+      arguments: self.methodArguments,
+      returnType: self.methodReturnType,
+      access: self.access
+    ) {
       self.println("return \(callFactory).makeBidirectionalStreamingCall(")
       self.withIndentation {
         self.println("path: \(self.methodPath),")
-        self.println("callOptions: callOptions ?? self.defaultCallOptions,")
+        self.println("callOptions: callOptions,")
         self.println("handler: handler")
       }
       self.println(")")
     }
-    self.println("}")
   }
 
   private func printClientStreamingDetails() {
@@ -212,7 +275,7 @@ extension Generator {
   }
 
   private func printCallOptionsParameter() {
-    println("///   - callOptions: Call options; `self.defaultCallOptions` is used if `nil`.")
+    println("///   - callOptions: Call options.")
   }
 
   private func printHandlerParameter() {
@@ -226,6 +289,57 @@ fileprivate extension Generator {
   }
 }
 
+extension Generator {
+  fileprivate var methodArguments: [String] {
+    switch self.streamType {
+    case .unary:
+      return [
+        "_ request: \(self.methodInputName)",
+        "callOptions: CallOptions"
+      ]
+    case .serverStreaming:
+      return [
+        "_ request: \(self.methodInputName)",
+        "callOptions: CallOptions",
+        "handler: @escaping (\(methodOutputName)) -> Void"
+      ]
+
+    case .clientStreaming:
+      return ["callOptions: CallOptions"]
+
+    case .bidirectionalStreaming:
+      return [
+        "callOptions: CallOptions",
+        "handler: @escaping (\(methodOutputName)) -> Void"
+      ]
+    }
+  }
+
+
+  fileprivate var methodArgumentsWithoutCallOptions: [String] {
+    return self.methodArguments.filter {
+      !$0.hasPrefix("callOptions: ")
+    }
+  }
+
+  fileprivate var methodReturnType: String {
+    switch self.streamType {
+    case .unary:
+      return "UnaryCall<\(self.methodInputName), \(self.methodOutputName)>"
+
+    case .serverStreaming:
+      return "ServerStreamingCall<\(self.methodInputName), \(self.methodOutputName)>"
+
+    case .clientStreaming:
+      return "ClientStreamingCall<\(self.methodInputName), \(self.methodOutputName)>"
+
+    case .bidirectionalStreaming:
+      return "BidirectionalStreamingCall<\(self.methodInputName), \(self.methodOutputName)>"
+    }
+
+  }
+}
+
 fileprivate extension StreamingType {
   var name: String {
     switch self {
@@ -257,3 +371,16 @@ extension MethodDescriptor {
     }
   }
 }
+
+extension Array {
+  /// Like `forEach` except that the `body` closure operates on all elements except for the last,
+  /// and the `last` closure only operates on the last element.
+  fileprivate func forEach(beforeLast body: (Element) -> (), onLast last: (Element) -> ()) {
+    for element in self.dropLast() {
+      body(element)
+    }
+    if let lastElement = self.last {
+      last(lastElement)
+    }
+  }
+}