瀏覽代碼

Codegen for client interceptors (#1022)

Motivation:

We have client interceptors wired all the way through but no convnenient
way for them to be used.

Modifications:

- Generate an client interceptor factory which has a
  `make<Method>Interceptors()` function for each method on the service
- The interceptor factory is an optional requirement on the generated
  client protocol, RPC invocations will pull interceptors from the
  factory, if present.

Result:

Users can provide interceptor factory implementations to their generated
clients to use interceptors.
George Barnett 5 年之前
父節點
當前提交
4d5fd4e461
共有 2 個文件被更改,包括 136 次插入31 次删除
  1. 128 31
      Sources/protoc-gen-grpc-swift/Generator-Client.swift
  2. 8 0
      Sources/protoc-gen-grpc-swift/Generator-Names.swift

+ 128 - 31
Sources/protoc-gen-grpc-swift/Generator-Client.swift

@@ -25,6 +25,10 @@ extension Generator {
       self.println()
       self.printClientProtocolExtension()
       self.println()
+      self.printServiceClientInterceptorFactoryProtocol()
+      self.println()
+      self.printServiceClientInterceptorFactoryProtocolExtension()
+      self.println()
       self.printServiceClientImplementation()
     }
 
@@ -73,13 +77,15 @@ extension Generator {
   }
 
   private func printServiceClientProtocol() {
-    self
-      .println(
-        "/// Usage: instantiate \(self.clientClassName), then call methods of this protocol to make API calls."
-      )
+    self.println(
+      "/// Usage: instantiate \(self.clientClassName), then call methods of this protocol to make API calls."
+    )
     self.println("\(self.access) protocol \(self.clientProtocolName): GRPCClient {")
     self.withIndentation {
+      self.println("var interceptors: \(self.clientInterceptorProtocolName)? { get }")
+
       for method in service.methods {
+        self.println()
         self.method = method
 
         self.printFunction(
@@ -88,8 +94,6 @@ extension Generator {
           returnType: self.methodReturnType,
           bodyBuilder: nil
         )
-
-        self.println()
       }
     }
     println("}")
@@ -98,6 +102,7 @@ extension Generator {
   private func printClientProtocolExtension() {
     self.println("extension \(self.clientProtocolName) {")
 
+    // Default method implementations.
     self.withIndentation {
       self.printMethods()
     }
@@ -105,28 +110,106 @@ extension Generator {
     self.println("}")
   }
 
+  private func printServiceClientInterceptorFactoryProtocol() {
+    self.println("\(self.access) protocol \(self.clientInterceptorProtocolName) {")
+    self.withIndentation {
+      // Generic interceptor.
+      self.println("/// Makes an array of generic interceptors. The per-method interceptor")
+      self.println("/// factories default to calling this function and it therefore provides a")
+      self.println("/// convenient way of setting interceptors for all methods on a client.")
+      self.println("/// - Returns: An array of interceptors generic over `Request` and `Response`.")
+      self.println("///   Defaults to an empty array.")
+      self.printFunction(
+        name: "makeInterceptors<Request: SwiftProtobuf.Message, Response: SwiftProtobuf.Message>",
+        arguments: [],
+        returnType: "[ClientInterceptor<Request, Response>]",
+        bodyBuilder: nil
+      )
+
+      // Method specific interceptors.
+      for method in service.methods {
+        self.println()
+        self.method = method
+        self.println(
+          "/// - Returns: Interceptors to use when invoking '\(self.methodFunctionName)'."
+        )
+        self.println("///   Defaults to calling `self.makeInterceptors()`.")
+        // Skip the access, we're defining a protocol.
+        self.printMethodInterceptorFactory(access: nil)
+      }
+    }
+    self.println("}")
+  }
+
+  private func printMethodInterceptorFactory(
+    access: String?,
+    bodyBuilder: (() -> Void)? = nil
+  ) {
+    self.printFunction(
+      name: self.methodInterceptorFactoryName,
+      arguments: [],
+      returnType: "[ClientInterceptor<\(self.methodInputName), \(self.methodOutputName)>]",
+      access: access,
+      bodyBuilder: bodyBuilder
+    )
+  }
+
+  private func printServiceClientInterceptorFactoryProtocolExtension() {
+    self.println("extension \(self.clientInterceptorProtocolName) {")
+
+    self.withIndentation {
+      // Default interceptor factory.
+      self.printFunction(
+        name: "makeInterceptors<Request: SwiftProtobuf.Message, Response: SwiftProtobuf.Message>",
+        arguments: [],
+        returnType: "[ClientInterceptor<Request, Response>]",
+        access: self.access
+      ) {
+        self.println("return []")
+      }
+
+      for method in self.service.methods {
+        self.println()
+
+        self.method = method
+        self.printMethodInterceptorFactory(access: self.access) {
+          self.println("return self.makeInterceptors()")
+        }
+      }
+    }
+
+    self.println("}")
+  }
+
   private func printServiceClientImplementation() {
     println("\(access) final class \(clientClassName): \(clientProtocolName) {")
-    indent()
-    println("\(access) let channel: GRPCChannel")
-    println("\(access) var defaultCallOptions: CallOptions")
-    println()
-    println("/// Creates a client for the \(servicePath) service.")
-    println("///")
-    self.printParameters()
-    println("///   - channel: `GRPCChannel` to the service host.")
-    println(
-      "///   - defaultCallOptions: Options to use for each service call if the user doesn't provide them."
-    )
-    println(
-      "\(access) init(channel: GRPCChannel, defaultCallOptions: CallOptions = CallOptions()) {"
-    )
-    indent()
-    println("self.channel = channel")
-    println("self.defaultCallOptions = defaultCallOptions")
-    outdent()
-    println("}")
-    outdent()
+    self.withIndentation {
+      println("\(access) let channel: GRPCChannel")
+      println("\(access) var defaultCallOptions: CallOptions")
+      println("\(access) var interceptors: \(clientInterceptorProtocolName)?")
+      println()
+      println("/// Creates a client for the \(servicePath) service.")
+      println("///")
+      self.printParameters()
+      println("///   - channel: `GRPCChannel` to the service host.")
+      println(
+        "///   - defaultCallOptions: Options to use for each service call if the user doesn't provide them."
+      )
+      println("///   - interceptors: A factory providing interceptors for each RPC.")
+      println("\(access) init(")
+      self.withIndentation {
+        println("channel: GRPCChannel,")
+        println("defaultCallOptions: CallOptions = CallOptions(),")
+        println("interceptors: \(clientInterceptorProtocolName)? = nil")
+      }
+      self.println(") {")
+      self.withIndentation {
+        println("self.channel = channel")
+        println("self.defaultCallOptions = defaultCallOptions")
+        println("self.interceptors = interceptors")
+      }
+      self.println("}")
+    }
     println("}")
   }
 
@@ -168,7 +251,10 @@ extension Generator {
       self.withIndentation {
         self.println("path: \(self.methodPath),")
         self.println("request: request,")
-        self.println("callOptions: callOptions ?? self.defaultCallOptions")
+        self.println("callOptions: callOptions ?? self.defaultCallOptions,")
+        self.println(
+          "interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? []"
+        )
       }
       self.println(")")
     }
@@ -188,12 +274,14 @@ extension Generator {
       returnType: self.methodReturnType,
       access: self.access
     ) {
-      self
-        .println("return self.makeServerStreamingCall(") // path: \"/\(servicePath)/\(method.name)\",")
+      self.println("return self.makeServerStreamingCall(")
       self.withIndentation {
         self.println("path: \(self.methodPath),")
         self.println("request: request,")
         self.println("callOptions: callOptions ?? self.defaultCallOptions,")
+        self.println(
+          "interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? [],"
+        )
         self.println("handler: handler")
       }
       self.println(")")
@@ -220,7 +308,10 @@ extension Generator {
       self.println("return self.makeClientStreamingCall(")
       self.withIndentation {
         self.println("path: \(self.methodPath),")
-        self.println("callOptions: callOptions ?? self.defaultCallOptions")
+        self.println("callOptions: callOptions ?? self.defaultCallOptions,")
+        self.println(
+          "interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? []"
+        )
       }
       self.println(")")
     }
@@ -245,6 +336,9 @@ extension Generator {
       self.withIndentation {
         self.println("path: \(self.methodPath),")
         self.println("callOptions: callOptions ?? self.defaultCallOptions,")
+        self.println(
+          "interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? [],"
+        )
         self.println("handler: handler")
       }
       self.println(")")
@@ -386,6 +480,7 @@ extension Generator {
     self.withIndentation {
       self.println("private let fakeChannel: FakeChannel")
       self.println("\(self.access) var defaultCallOptions: CallOptions")
+      self.println("\(self.access) var interceptors: \(self.clientInterceptorProtocolName)?")
 
       self.println()
       self.println("\(self.access) var channel: GRPCChannel {")
@@ -398,12 +493,14 @@ extension Generator {
       self.println("\(self.access) init(")
       self.withIndentation {
         self.println("fakeChannel: FakeChannel = FakeChannel(),")
-        self.println("defaultCallOptions callOptions: CallOptions = CallOptions()")
+        self.println("defaultCallOptions callOptions: CallOptions = CallOptions(),")
+        self.println("interceptors: \(clientInterceptorProtocolName)? = nil")
       }
       self.println(") {")
       self.withIndentation {
         self.println("self.fakeChannel = fakeChannel")
         self.println("self.defaultCallOptions = callOptions")
+        self.println("self.interceptors = interceptors")
       }
       self.println("}")
 

+ 8 - 0
Sources/protoc-gen-grpc-swift/Generator-Names.swift

@@ -57,6 +57,10 @@ extension Generator {
     return nameForPackageService(file, service) + "ClientProtocol"
   }
 
+  internal var clientInterceptorProtocolName: String {
+    return nameForPackageService(file, service) + "ClientInterceptorFactoryProtocol"
+  }
+
   internal var callName: String {
     return nameForPackageServiceMethod(file, service, method) + "Call"
   }
@@ -74,6 +78,10 @@ extension Generator {
     return protobufNamer.fullName(message: method.outputType)
   }
 
+  internal var methodInterceptorFactoryName: String {
+    return "make\(self.method.name)Interceptors"
+  }
+
   internal var servicePath: String {
     if !file.package.isEmpty {
       return file.package + "." + service.name