浏览代码

Allow client to specify metadata per call (#356)

To address https://github.com/grpc/grpc-swift/pull/355#issuecomment-453744353. It also closes https://github.com/grpc/grpc-swift/issues/190.
Taeho Kim 6 年之前
父节点
当前提交
3cd505acdc
共有 2 个文件被更改,包括 122 次插入40 次删除
  1. 47 20
      Sources/Examples/Echo/Generated/echo.grpc.swift
  2. 75 20
      Sources/protoc-gen-swiftgrpc/Generator-Client.swift

+ 47 - 20
Sources/Examples/Echo/Generated/echo.grpc.swift

@@ -117,61 +117,88 @@ class Echo_EchoUpdateCallTestStub: ClientCallBidirectionalStreamingTestStub<Echo
 /// Instantiate Echo_EchoServiceClient, then call methods of this protocol to make API calls.
 internal protocol Echo_EchoService: ServiceClient {
   /// Synchronous. Unary.
-  func get(_ request: Echo_EchoRequest) throws -> Echo_EchoResponse
+  func get(_ request: Echo_EchoRequest, metadata customMetadata: Metadata) throws -> Echo_EchoResponse
   /// Asynchronous. Unary.
-  func get(_ request: Echo_EchoRequest, completion: @escaping (Echo_EchoResponse?, CallResult) -> Void) throws -> Echo_EchoGetCall
+  func get(_ request: Echo_EchoRequest, metadata customMetadata: Metadata, completion: @escaping (Echo_EchoResponse?, CallResult) -> Void) throws -> Echo_EchoGetCall
 
   /// Asynchronous. Server-streaming.
   /// Send the initial message.
   /// Use methods on the returned object to get streamed responses.
-  func expand(_ request: Echo_EchoRequest, completion: ((CallResult) -> Void)?) throws -> Echo_EchoExpandCall
+  func expand(_ request: Echo_EchoRequest, metadata customMetadata: Metadata, completion: ((CallResult) -> Void)?) throws -> Echo_EchoExpandCall
 
   /// Asynchronous. Client-streaming.
   /// Use methods on the returned object to stream messages and
   /// to close the connection and wait for a final response.
-  func collect(completion: ((CallResult) -> Void)?) throws -> Echo_EchoCollectCall
+  func collect(metadata customMetadata: Metadata, completion: ((CallResult) -> Void)?) throws -> Echo_EchoCollectCall
 
   /// Asynchronous. Bidirectional-streaming.
   /// Use methods on the returned object to stream messages,
   /// to wait for replies, and to close the connection.
-  func update(completion: ((CallResult) -> Void)?) throws -> Echo_EchoUpdateCall
+  func update(metadata customMetadata: Metadata, completion: ((CallResult) -> Void)?) throws -> Echo_EchoUpdateCall
+
+}
+
+internal extension Echo_EchoService {
+  /// Synchronous. Unary.
+  func get(_ request: Echo_EchoRequest) throws -> Echo_EchoResponse {
+    return try self.get(request, metadata: self.metadata)
+  }
+  /// Asynchronous. Unary.
+  func get(_ request: Echo_EchoRequest, completion: @escaping (Echo_EchoResponse?, CallResult) -> Void) throws -> Echo_EchoGetCall {
+    return try self.get(request, metadata: self.metadata, completion: completion)
+  }
+
+  /// Asynchronous. Server-streaming.
+  func expand(_ request: Echo_EchoRequest, completion: ((CallResult) -> Void)?) throws -> Echo_EchoExpandCall {
+    return try self.expand(request, metadata: self.metadata, completion: completion)
+  }
+
+  /// Asynchronous. Client-streaming.
+  func collect(completion: ((CallResult) -> Void)?) throws -> Echo_EchoCollectCall {
+    return try self.collect(metadata: self.metadata, completion: completion)
+  }
+
+  /// Asynchronous. Bidirectional-streaming.
+  func update(completion: ((CallResult) -> Void)?) throws -> Echo_EchoUpdateCall {
+    return try self.update(metadata: self.metadata, completion: completion)
+  }
 
 }
 
 internal final class Echo_EchoServiceClient: ServiceClientBase, Echo_EchoService {
   /// Synchronous. Unary.
-  internal func get(_ request: Echo_EchoRequest) throws -> Echo_EchoResponse {
+  internal func get(_ request: Echo_EchoRequest, metadata customMetadata: Metadata) throws -> Echo_EchoResponse {
     return try Echo_EchoGetCallBase(channel)
-      .run(request: request, metadata: metadata)
+      .run(request: request, metadata: customMetadata)
   }
   /// Asynchronous. Unary.
-  internal func get(_ request: Echo_EchoRequest, completion: @escaping (Echo_EchoResponse?, CallResult) -> Void) throws -> Echo_EchoGetCall {
+  internal func get(_ request: Echo_EchoRequest, metadata customMetadata: Metadata, completion: @escaping (Echo_EchoResponse?, CallResult) -> Void) throws -> Echo_EchoGetCall {
     return try Echo_EchoGetCallBase(channel)
-      .start(request: request, metadata: metadata, completion: completion)
+      .start(request: request, metadata: customMetadata, completion: completion)
   }
 
   /// Asynchronous. Server-streaming.
   /// Send the initial message.
   /// Use methods on the returned object to get streamed responses.
-  internal func expand(_ request: Echo_EchoRequest, completion: ((CallResult) -> Void)?) throws -> Echo_EchoExpandCall {
+  internal func expand(_ request: Echo_EchoRequest, metadata customMetadata: Metadata, completion: ((CallResult) -> Void)?) throws -> Echo_EchoExpandCall {
     return try Echo_EchoExpandCallBase(channel)
-      .start(request: request, metadata: metadata, completion: completion)
+      .start(request: request, metadata: customMetadata, completion: completion)
   }
 
   /// Asynchronous. Client-streaming.
   /// Use methods on the returned object to stream messages and
   /// to close the connection and wait for a final response.
-  internal func collect(completion: ((CallResult) -> Void)?) throws -> Echo_EchoCollectCall {
+  internal func collect(metadata customMetadata: Metadata, completion: ((CallResult) -> Void)?) throws -> Echo_EchoCollectCall {
     return try Echo_EchoCollectCallBase(channel)
-      .start(metadata: metadata, completion: completion)
+      .start(metadata: customMetadata, completion: completion)
   }
 
   /// Asynchronous. Bidirectional-streaming.
   /// Use methods on the returned object to stream messages,
   /// to wait for replies, and to close the connection.
-  internal func update(completion: ((CallResult) -> Void)?) throws -> Echo_EchoUpdateCall {
+  internal func update(metadata customMetadata: Metadata, completion: ((CallResult) -> Void)?) throws -> Echo_EchoUpdateCall {
     return try Echo_EchoUpdateCallBase(channel)
-      .start(metadata: metadata, completion: completion)
+      .start(metadata: customMetadata, completion: completion)
   }
 
 }
@@ -179,31 +206,31 @@ internal final class Echo_EchoServiceClient: ServiceClientBase, Echo_EchoService
 class Echo_EchoServiceTestStub: ServiceClientTestStubBase, Echo_EchoService {
   var getRequests: [Echo_EchoRequest] = []
   var getResponses: [Echo_EchoResponse] = []
-  func get(_ request: Echo_EchoRequest) throws -> Echo_EchoResponse {
+  func get(_ request: Echo_EchoRequest, metadata customMetadata: Metadata) throws -> Echo_EchoResponse {
     getRequests.append(request)
     defer { getResponses.removeFirst() }
     return getResponses.first!
   }
-  func get(_ request: Echo_EchoRequest, completion: @escaping (Echo_EchoResponse?, CallResult) -> Void) throws -> Echo_EchoGetCall {
+  func get(_ request: Echo_EchoRequest, metadata customMetadata: Metadata, completion: @escaping (Echo_EchoResponse?, CallResult) -> Void) throws -> Echo_EchoGetCall {
     fatalError("not implemented")
   }
 
   var expandRequests: [Echo_EchoRequest] = []
   var expandCalls: [Echo_EchoExpandCall] = []
-  func expand(_ request: Echo_EchoRequest, completion: ((CallResult) -> Void)?) throws -> Echo_EchoExpandCall {
+  func expand(_ request: Echo_EchoRequest, metadata customMetadata: Metadata, completion: ((CallResult) -> Void)?) throws -> Echo_EchoExpandCall {
     expandRequests.append(request)
     defer { expandCalls.removeFirst() }
     return expandCalls.first!
   }
 
   var collectCalls: [Echo_EchoCollectCall] = []
-  func collect(completion: ((CallResult) -> Void)?) throws -> Echo_EchoCollectCall {
+  func collect(metadata customMetadata: Metadata, completion: ((CallResult) -> Void)?) throws -> Echo_EchoCollectCall {
     defer { collectCalls.removeFirst() }
     return collectCalls.first!
   }
 
   var updateCalls: [Echo_EchoUpdateCall] = []
-  func update(completion: ((CallResult) -> Void)?) throws -> Echo_EchoUpdateCall {
+  func update(metadata customMetadata: Metadata, completion: ((CallResult) -> Void)?) throws -> Echo_EchoUpdateCall {
     defer { updateCalls.removeFirst() }
     return updateCalls.first!
   }

+ 75 - 20
Sources/protoc-gen-swiftgrpc/Generator-Client.swift

@@ -37,6 +37,9 @@ extension Generator {
     printServiceClientProtocol(asynchronousCode: asynchronousCode,
                                synchronousCode: synchronousCode)
     println()
+    printServiceClientProtocolExtension(asynchronousCode: asynchronousCode,
+                                        synchronousCode: synchronousCode)
+    println()
     printServiceClientImplementation(asynchronousCode: asynchronousCode,
                                      synchronousCode: synchronousCode)
     if options.generateTestStubs {
@@ -158,27 +161,79 @@ extension Generator {
       case .unary:
         if synchronousCode {
           println("/// Synchronous. Unary.")
-          println("func \(methodFunctionName)(_ request: \(methodInputName)) throws -> \(methodOutputName)")
+          println("func \(methodFunctionName)(_ request: \(methodInputName), metadata customMetadata: Metadata) throws -> \(methodOutputName)")
         }
         if asynchronousCode {
           println("/// Asynchronous. Unary.")
-          println("func \(methodFunctionName)(_ request: \(methodInputName), completion: @escaping (\(methodOutputName)?, CallResult) -> Void) throws -> \(callName)")
+          println("func \(methodFunctionName)(_ request: \(methodInputName), metadata customMetadata: Metadata, completion: @escaping (\(methodOutputName)?, CallResult) -> Void) throws -> \(callName)")
         }
       case .serverStreaming:
         println("/// Asynchronous. Server-streaming.")
         println("/// Send the initial message.")
         println("/// Use methods on the returned object to get streamed responses.")
-        println("func \(methodFunctionName)(_ request: \(methodInputName), completion: ((CallResult) -> Void)?) throws -> \(callName)")
+        println("func \(methodFunctionName)(_ request: \(methodInputName), metadata customMetadata: Metadata, completion: ((CallResult) -> Void)?) throws -> \(callName)")
       case .clientStreaming:
         println("/// Asynchronous. Client-streaming.")
         println("/// Use methods on the returned object to stream messages and")
         println("/// to close the connection and wait for a final response.")
-        println("func \(methodFunctionName)(completion: ((CallResult) -> Void)?) throws -> \(callName)")
+        println("func \(methodFunctionName)(metadata customMetadata: Metadata, completion: ((CallResult) -> Void)?) throws -> \(callName)")
       case .bidirectionalStreaming:
         println("/// Asynchronous. Bidirectional-streaming.")
         println("/// Use methods on the returned object to stream messages,")
         println("/// to wait for replies, and to close the connection.")
-        println("func \(methodFunctionName)(completion: ((CallResult) -> Void)?) throws -> \(callName)")
+        println("func \(methodFunctionName)(metadata customMetadata: Metadata, completion: ((CallResult) -> Void)?) throws -> \(callName)")
+      }
+      println()
+    }
+    outdent()
+    println("}")
+  }
+
+  private func printServiceClientProtocolExtension(asynchronousCode: Bool,
+                                                   synchronousCode: Bool) {
+    println("\(options.visibility.sourceSnippet) extension \(serviceClassName) {")
+    indent()
+    for method in service.methods {
+      self.method = method
+      switch streamingType(method) {
+      case .unary:
+        if synchronousCode {
+          println("/// Synchronous. Unary.")
+          println("func \(methodFunctionName)(_ request: \(methodInputName)) throws -> \(methodOutputName) {")
+          indent()
+          println("return try self.\(methodFunctionName)(request, metadata: self.metadata)")
+          outdent()
+          println("}")
+        }
+        if asynchronousCode {
+          println("/// Asynchronous. Unary.")
+          println("func \(methodFunctionName)(_ request: \(methodInputName), completion: @escaping (\(methodOutputName)?, CallResult) -> Void) throws -> \(callName) {")
+          indent()
+          println("return try self.\(methodFunctionName)(request, metadata: self.metadata, completion: completion)")
+          outdent()
+          println("}")
+        }
+      case .serverStreaming:
+        println("/// Asynchronous. Server-streaming.")
+        println("func \(methodFunctionName)(_ request: \(methodInputName), completion: ((CallResult) -> Void)?) throws -> \(callName) {")
+        indent()
+        println("return try self.\(methodFunctionName)(request, metadata: self.metadata, completion: completion)")
+        outdent()
+        println("}")
+      case .clientStreaming:
+        println("/// Asynchronous. Client-streaming.")
+        println("func \(methodFunctionName)(completion: ((CallResult) -> Void)?) throws -> \(callName) {")
+        indent()
+        println("return try self.\(methodFunctionName)(metadata: self.metadata, completion: completion)")
+        outdent()
+        println("}")
+      case .bidirectionalStreaming:
+        println("/// Asynchronous. Bidirectional-streaming.")
+        println("func \(methodFunctionName)(completion: ((CallResult) -> Void)?) throws -> \(callName) {")
+        indent()
+        println("return try self.\(methodFunctionName)(metadata: self.metadata, completion: completion)")
+        outdent()
+        println("}")
       }
       println()
     }
@@ -196,22 +251,22 @@ extension Generator {
       case .unary:
         if synchronousCode {
           println("/// Synchronous. Unary.")
-          println("\(access) func \(methodFunctionName)(_ request: \(methodInputName)) throws -> \(methodOutputName) {")
+          println("\(access) func \(methodFunctionName)(_ request: \(methodInputName), metadata customMetadata: Metadata) throws -> \(methodOutputName) {")
           indent()
           println("return try \(callName)Base(channel)")
           indent()
-          println(".run(request: request, metadata: metadata)")
+          println(".run(request: request, metadata: customMetadata)")
           outdent()
           outdent()
           println("}")
         }
         if asynchronousCode {
           println("/// Asynchronous. Unary.")
-          println("\(access) func \(methodFunctionName)(_ request: \(methodInputName), completion: @escaping (\(methodOutputName)?, CallResult) -> Void) throws -> \(callName) {")
+          println("\(access) func \(methodFunctionName)(_ request: \(methodInputName), metadata customMetadata: Metadata, completion: @escaping (\(methodOutputName)?, CallResult) -> Void) throws -> \(callName) {")
           indent()
           println("return try \(callName)Base(channel)")
           indent()
-          println(".start(request: request, metadata: metadata, completion: completion)")
+          println(".start(request: request, metadata: customMetadata, completion: completion)")
           outdent()
           outdent()
           println("}")
@@ -220,11 +275,11 @@ extension Generator {
         println("/// Asynchronous. Server-streaming.")
         println("/// Send the initial message.")
         println("/// Use methods on the returned object to get streamed responses.")
-        println("\(access) func \(methodFunctionName)(_ request: \(methodInputName), completion: ((CallResult) -> Void)?) throws -> \(callName) {")
+        println("\(access) func \(methodFunctionName)(_ request: \(methodInputName), metadata customMetadata: Metadata, completion: ((CallResult) -> Void)?) throws -> \(callName) {")
         indent()
         println("return try \(callName)Base(channel)")
         indent()
-        println(".start(request: request, metadata: metadata, completion: completion)")
+        println(".start(request: request, metadata: customMetadata, completion: completion)")
         outdent()
         outdent()
         println("}")
@@ -232,11 +287,11 @@ extension Generator {
         println("/// Asynchronous. Client-streaming.")
         println("/// Use methods on the returned object to stream messages and")
         println("/// to close the connection and wait for a final response.")
-        println("\(access) func \(methodFunctionName)(completion: ((CallResult) -> Void)?) throws -> \(callName) {")
+        println("\(access) func \(methodFunctionName)(metadata customMetadata: Metadata, completion: ((CallResult) -> Void)?) throws -> \(callName) {")
         indent()
         println("return try \(callName)Base(channel)")
         indent()
-        println(".start(metadata: metadata, completion: completion)")
+        println(".start(metadata: customMetadata, completion: completion)")
         outdent()
         outdent()
         println("}")
@@ -244,11 +299,11 @@ extension Generator {
         println("/// Asynchronous. Bidirectional-streaming.")
         println("/// Use methods on the returned object to stream messages,")
         println("/// to wait for replies, and to close the connection.")
-        println("\(access) func \(methodFunctionName)(completion: ((CallResult) -> Void)?) throws -> \(callName) {")
+        println("\(access) func \(methodFunctionName)(metadata customMetadata: Metadata, completion: ((CallResult) -> Void)?) throws -> \(callName) {")
         indent()
         println("return try \(callName)Base(channel)")
         indent()
-        println(".start(metadata: metadata, completion: completion)")
+        println(".start(metadata: customMetadata, completion: completion)")
         outdent()
         outdent()
         println("}")
@@ -268,14 +323,14 @@ extension Generator {
       case .unary:
         println("var \(methodFunctionName)Requests: [\(methodInputName)] = []")
         println("var \(methodFunctionName)Responses: [\(methodOutputName)] = []")
-        println("func \(methodFunctionName)(_ request: \(methodInputName)) throws -> \(methodOutputName) {")
+        println("func \(methodFunctionName)(_ request: \(methodInputName), metadata customMetadata: Metadata) throws -> \(methodOutputName) {")
         indent()
         println("\(methodFunctionName)Requests.append(request)")
         println("defer { \(methodFunctionName)Responses.removeFirst() }")
         println("return \(methodFunctionName)Responses.first!")
         outdent()
         println("}")
-        println("func \(methodFunctionName)(_ request: \(methodInputName), completion: @escaping (\(methodOutputName)?, CallResult) -> Void) throws -> \(callName) {")
+        println("func \(methodFunctionName)(_ request: \(methodInputName), metadata customMetadata: Metadata, completion: @escaping (\(methodOutputName)?, CallResult) -> Void) throws -> \(callName) {")
         indent()
         println("fatalError(\"not implemented\")")
         outdent()
@@ -283,7 +338,7 @@ extension Generator {
       case .serverStreaming:
         println("var \(methodFunctionName)Requests: [\(methodInputName)] = []")
         println("var \(methodFunctionName)Calls: [\(callName)] = []")
-        println("func \(methodFunctionName)(_ request: \(methodInputName), completion: ((CallResult) -> Void)?) throws -> \(callName) {")
+        println("func \(methodFunctionName)(_ request: \(methodInputName), metadata customMetadata: Metadata, completion: ((CallResult) -> Void)?) throws -> \(callName) {")
         indent()
         println("\(methodFunctionName)Requests.append(request)")
         println("defer { \(methodFunctionName)Calls.removeFirst() }")
@@ -292,7 +347,7 @@ extension Generator {
         println("}")
       case .clientStreaming:
         println("var \(methodFunctionName)Calls: [\(callName)] = []")
-        println("func \(methodFunctionName)(completion: ((CallResult) -> Void)?) throws -> \(callName) {")
+        println("func \(methodFunctionName)(metadata customMetadata: Metadata, completion: ((CallResult) -> Void)?) throws -> \(callName) {")
         indent()
         println("defer { \(methodFunctionName)Calls.removeFirst() }")
         println("return \(methodFunctionName)Calls.first!")
@@ -300,7 +355,7 @@ extension Generator {
         println("}")
       case .bidirectionalStreaming:
         println("var \(methodFunctionName)Calls: [\(callName)] = []")
-        println("func \(methodFunctionName)(completion: ((CallResult) -> Void)?) throws -> \(callName) {")
+        println("func \(methodFunctionName)(metadata customMetadata: Metadata, completion: ((CallResult) -> Void)?) throws -> \(callName) {")
         indent()
         println("defer { \(methodFunctionName)Calls.removeFirst() }")
         println("return \(methodFunctionName)Calls.first!")