Forráskód Böngészése

Codegen for server interceptors (#1030)

Motivation:

We need codegen to use server interceptors. This PR adds it.

Modifications:

- Add codegen for server interceptors.
- Also removes the default implementation of generated client
  interceptor factories, this makes the requirements clearer to the
  user.
- Adds a handful of tests using interceptors as well.

Result:

Codegen for server interceptors.
George Barnett 5 éve
szülő
commit
185b732f23

+ 1 - 44
Sources/protoc-gen-grpc-swift/Generator-Client.swift

@@ -27,8 +27,6 @@ extension Generator {
       self.println()
       self.printServiceClientInterceptorFactoryProtocol()
       self.println()
-      self.printServiceClientInterceptorFactoryProtocolExtension()
-      self.println()
       self.printServiceClientImplementation()
     }
 
@@ -38,7 +36,7 @@ extension Generator {
     }
   }
 
-  private func printFunction(
+  internal func printFunction(
     name: String,
     arguments: [String],
     returnType: String?,
@@ -113,19 +111,6 @@ extension Generator {
   private func printServiceClientInterceptorFactoryProtocol() {
     self.println("\(self.access) protocol \(self.clientInterceptorProtocolName) {")
     self.withIndentation {
-      // Generic interceptor.
-      self.println("/// Makes an array of generic interceptors. The per-method interceptor")
-      self.println("/// factories default to calling this function and it therefore provides a")
-      self.println("/// convenient way of setting interceptors for all methods on a client.")
-      self.println("/// - Returns: An array of interceptors generic over `Request` and `Response`.")
-      self.println("///   Defaults to an empty array.")
-      self.printFunction(
-        name: "makeInterceptors<Request: SwiftProtobuf.Message, Response: SwiftProtobuf.Message>",
-        arguments: [],
-        returnType: "[ClientInterceptor<Request, Response>]",
-        bodyBuilder: nil
-      )
-
       // Method specific interceptors.
       for method in service.methods {
         self.println()
@@ -133,7 +118,6 @@ extension Generator {
         self.println(
           "/// - Returns: Interceptors to use when invoking '\(self.methodFunctionName)'."
         )
-        self.println("///   Defaults to calling `self.makeInterceptors()`.")
         // Skip the access, we're defining a protocol.
         self.printMethodInterceptorFactory(access: nil)
       }
@@ -154,33 +138,6 @@ extension Generator {
     )
   }
 
-  private func printServiceClientInterceptorFactoryProtocolExtension() {
-    self.println("extension \(self.clientInterceptorProtocolName) {")
-
-    self.withIndentation {
-      // Default interceptor factory.
-      self.printFunction(
-        name: "makeInterceptors<Request: SwiftProtobuf.Message, Response: SwiftProtobuf.Message>",
-        arguments: [],
-        returnType: "[ClientInterceptor<Request, Response>]",
-        access: self.access
-      ) {
-        self.println("return []")
-      }
-
-      for method in self.service.methods {
-        self.println()
-
-        self.method = method
-        self.printMethodInterceptorFactory(access: self.access) {
-          self.println("return self.makeInterceptors()")
-        }
-      }
-    }
-
-    self.println("}")
-  }
-
   private func printServiceClientImplementation() {
     println("\(access) final class \(clientClassName): \(clientProtocolName) {")
     self.withIndentation {

+ 4 - 0
Sources/protoc-gen-grpc-swift/Generator-Names.swift

@@ -61,6 +61,10 @@ extension Generator {
     return nameForPackageService(file, service) + "ClientInterceptorFactoryProtocol"
   }
 
+  internal var serverInterceptorProtocolName: String {
+    return nameForPackageService(file, service) + "ServerInterceptorFactoryProtocol"
+  }
+
   internal var callName: String {
     return nameForPackageServiceMethod(file, service, method) + "Call"
   }

+ 157 - 72
Sources/protoc-gen-grpc-swift/Generator-Server.swift

@@ -20,90 +20,175 @@ import SwiftProtobufPluginLibrary
 extension Generator {
   internal func printServer() {
     self.printServerProtocol()
+    self.println()
+    self.printServerProtocolExtension()
+    self.println()
+    self.printServerInterceptorFactoryProtocol()
   }
 
   private func printServerProtocol() {
     println("/// To build a server, implement a class that conforms to this protocol.")
     println("\(access) protocol \(providerName): CallHandlerProvider {")
-    indent()
-    for method in service.methods {
-      self.method = method
+    self.withIndentation {
+      println("var interceptors: \(self.serverInterceptorProtocolName)? { get }")
+      for method in service.methods {
+        self.method = method
+        self.println()
 
-      switch streamingType(method) {
-      case .unary:
-        println(self.method.protoSourceComments(), newline: false)
-        println(
-          "func \(methodFunctionName)(request: \(methodInputName), context: StatusOnlyCallContext) -> EventLoopFuture<\(methodOutputName)>"
-        )
-      case .serverStreaming:
-        println(self.method.protoSourceComments(), newline: false)
-        println(
-          "func \(methodFunctionName)(request: \(methodInputName), context: StreamingResponseCallContext<\(methodOutputName)>) -> EventLoopFuture<GRPCStatus>"
-        )
-      case .clientStreaming:
-        println(self.method.protoSourceComments(), newline: false)
-        println(
-          "func \(methodFunctionName)(context: UnaryResponseCallContext<\(methodOutputName)>) -> EventLoopFuture<(StreamEvent<\(methodInputName)>) -> Void>"
-        )
-      case .bidirectionalStreaming:
-        println(self.method.protoSourceComments(), newline: false)
-        println(
-          "func \(methodFunctionName)(context: StreamingResponseCallContext<\(methodOutputName)>) -> EventLoopFuture<(StreamEvent<\(methodInputName)>) -> Void>"
-        )
+        switch streamingType(method) {
+        case .unary:
+          println(self.method.protoSourceComments(), newline: false)
+          println(
+            "func \(methodFunctionName)(request: \(methodInputName), context: StatusOnlyCallContext) -> EventLoopFuture<\(methodOutputName)>"
+          )
+        case .serverStreaming:
+          println(self.method.protoSourceComments(), newline: false)
+          println(
+            "func \(methodFunctionName)(request: \(methodInputName), context: StreamingResponseCallContext<\(methodOutputName)>) -> EventLoopFuture<GRPCStatus>"
+          )
+        case .clientStreaming:
+          println(self.method.protoSourceComments(), newline: false)
+          println(
+            "func \(methodFunctionName)(context: UnaryResponseCallContext<\(methodOutputName)>) -> EventLoopFuture<(StreamEvent<\(methodInputName)>) -> Void>"
+          )
+        case .bidirectionalStreaming:
+          println(self.method.protoSourceComments(), newline: false)
+          println(
+            "func \(methodFunctionName)(context: StreamingResponseCallContext<\(methodOutputName)>) -> EventLoopFuture<(StreamEvent<\(methodInputName)>) -> Void>"
+          )
+        }
       }
     }
-    outdent()
     println("}")
-    println()
-    println("extension \(providerName) {")
-    indent()
-    println("\(access) var serviceName: Substring { return \"\(servicePath)\" }")
-    println()
-    println(
-      "/// Determines, calls and returns the appropriate request handler, depending on the request's method."
-    )
-    println("/// Returns nil for methods not handled by this service.")
-    println(
-      "\(access) func handleMethod(_ methodName: Substring, callHandlerContext: CallHandlerContext) -> GRPCCallHandler? {"
-    )
-    indent()
-    println("switch methodName {")
-    for method in service.methods {
-      self.method = method
-      println("case \"\(method.name)\":")
-      indent()
-      let callHandlerType: String
-      switch streamingType(method) {
-      case .unary: callHandlerType = "CallHandlerFactory.makeUnary"
-      case .serverStreaming: callHandlerType = "CallHandlerFactory.makeServerStreaming"
-      case .clientStreaming: callHandlerType = "CallHandlerFactory.makeClientStreaming"
-      case .bidirectionalStreaming: callHandlerType =
-        "CallHandlerFactory.makeBidirectionalStreaming"
+  }
+
+  private func printServerProtocolExtension() {
+    self.println("extension \(self.providerName) {")
+    self.withIndentation {
+      self.println("\(self.access) var serviceName: Substring { return \"\(self.servicePath)\" }")
+      self.println()
+      self.println(
+        "/// Determines, calls and returns the appropriate request handler, depending on the request's method."
+      )
+      self.println("/// Returns nil for methods not handled by this service.")
+      self.printFunction(
+        name: "handleMethod",
+        arguments: [
+          "_ methodName: Substring",
+          "callHandlerContext: CallHandlerContext",
+        ],
+        returnType: "GRPCCallHandler?",
+        access: self.access
+      ) {
+        self.println("switch methodName {")
+        for method in self.service.methods {
+          self.method = method
+          self.println("case \"\(method.name)\":")
+          self.withIndentation {
+            // Get the factory name.
+            let callHandlerType: String
+            switch streamingType(method) {
+            case .unary:
+              callHandlerType = "CallHandlerFactory.makeUnary"
+            case .serverStreaming:
+              callHandlerType = "CallHandlerFactory.makeServerStreaming"
+            case .clientStreaming:
+              callHandlerType = "CallHandlerFactory.makeClientStreaming"
+            case .bidirectionalStreaming:
+              callHandlerType = "CallHandlerFactory.makeBidirectionalStreaming"
+            }
+
+            self.println("return \(callHandlerType)(")
+            self.withIndentation {
+              self.println("callHandlerContext: callHandlerContext,")
+              self.println(
+                "interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? []"
+              )
+            }
+            self.println(") { context in")
+            self.withIndentation {
+              switch streamingType(self.method) {
+              case .unary, .serverStreaming:
+                self.println("return { request in")
+                self.withIndentation {
+                  self.println(
+                    "self.\(self.methodFunctionName)(request: request, context: context)"
+                  )
+                }
+                self.println("}")
+              case .clientStreaming, .bidirectionalStreaming:
+                self.println("self.\(self.methodFunctionName)(context: context)")
+              }
+            }
+            self.println("}")
+          }
+          self.println()
+        }
+
+        // Default case.
+        self.println("default:")
+        self.withIndentation {
+          self.println("return nil")
+        }
+        self.println("}")
       }
-      println("return \(callHandlerType)(callHandlerContext: callHandlerContext) { context in")
-      indent()
-      switch streamingType(method) {
-      case .unary, .serverStreaming:
-        println("return { request in")
-        indent()
-        println("self.\(methodFunctionName)(request: request, context: context)")
-        outdent()
-        println("}")
-      case .clientStreaming, .bidirectionalStreaming:
-        println("return self.\(methodFunctionName)(context: context)")
+    }
+    self.println("}")
+  }
+
+  private func printServerInterceptorFactoryProtocol() {
+    self.println("\(self.access) protocol \(self.serverInterceptorProtocolName) {")
+    self.withIndentation {
+      // Method specific interceptors.
+      for method in service.methods {
+        self.println()
+        self.method = method
+        self.println(
+          "/// - Returns: Interceptors to use when handling '\(self.methodFunctionName)'."
+        )
+        self.println("///   Defaults to calling `self.makeInterceptors()`.")
+        // Skip the access, we're defining a protocol.
+        self.printMethodInterceptorFactory(access: nil)
       }
-      outdent()
-      println("}")
-      outdent()
-      println()
     }
-    println("default: return nil")
-    println("}")
-    outdent()
-    println("}")
+    self.println("}")
+  }
 
-    outdent()
-    println("}")
-    println()
+  private func printMethodInterceptorFactory(
+    access: String?,
+    bodyBuilder: (() -> Void)? = nil
+  ) {
+    self.printFunction(
+      name: self.methodInterceptorFactoryName,
+      arguments: [],
+      returnType: "[ServerInterceptor<\(self.methodInputName), \(self.methodOutputName)>]",
+      access: access,
+      bodyBuilder: bodyBuilder
+    )
+  }
+
+  func printServerInterceptorFactoryProtocolExtension() {
+    self.println("extension \(self.serverInterceptorProtocolName) {")
+    self.withIndentation {
+      // Default interceptor factory.
+      self.printFunction(
+        name: "makeInterceptors<Request: SwiftProtobuf.Message, Response: SwiftProtobuf.Message>",
+        arguments: [],
+        returnType: "[ServerInterceptor<Request, Response>]",
+        access: self.access
+      ) {
+        self.println("return []")
+      }
+
+      for method in self.service.methods {
+        self.println()
+
+        self.method = method
+        self.printMethodInterceptorFactory(access: self.access) {
+          self.println("return self.makeInterceptors()")
+        }
+      }
+    }
+    self.println("}")
   }
 }