浏览代码

Some tweaks to the original PR.

Daniel Alm 7 年之前
父节点
当前提交
2e8e79d090

+ 2 - 1
Sources/CgRPC/shim/cgrpc.h

@@ -152,6 +152,7 @@ const char *grpc_version_string(void);
 const char *grpc_g_stands_for(void);
 
 char *gpr_strdup(const char *src);
+void gpr_free(void *p);
 
 void cgrpc_completion_queue_drain(cgrpc_completion_queue *cq);
 void grpc_completion_queue_destroy(cgrpc_completion_queue *cq);
@@ -160,7 +161,7 @@ void grpc_completion_queue_destroy(cgrpc_completion_queue *cq);
 void cgrpc_free_copied_string(char *string);
 
 // channel support
-cgrpc_channel *cgrpc_channel_create(const char *address, 
+cgrpc_channel *cgrpc_channel_create(const char *address,
                                     grpc_arg *args,
                                     int num_args);
 cgrpc_channel *cgrpc_channel_create_secure(const char *address,

+ 10 - 14
Sources/CgRPC/shim/channel.c

@@ -25,35 +25,31 @@
 
 cgrpc_channel *cgrpc_channel_create(const char *address,
                                     grpc_arg *args,
-                                    int number_args) {
+                                    int num_args) {
   cgrpc_channel *c = (cgrpc_channel *) malloc(sizeof (cgrpc_channel));
 
-  grpc_channel_args *channel_args = gpr_malloc(sizeof(grpc_channel_args));
-  channel_args->args = args;
-  channel_args->num_args = number_args;
+  grpc_channel_args channel_args;
+  channel_args.args = args;
+  channel_args.num_args = num_args;
 
-  // create the channel
-  c->channel = grpc_insecure_channel_create(address, channel_args, NULL);
+  c->channel = grpc_insecure_channel_create(address, &channel_args, NULL);
   c->completion_queue = grpc_completion_queue_create_for_next(NULL);
-  gpr_free(channel_args);
   return c;
 }
 
 cgrpc_channel *cgrpc_channel_create_secure(const char *address,
                                            const char *pem_root_certs,
                                            grpc_arg *args,
-                                           int number_args) {
+                                           int num_args) {
   cgrpc_channel *c = (cgrpc_channel *) malloc(sizeof (cgrpc_channel));
-  // create the channel
 
-  grpc_channel_args *channel_args = gpr_malloc(sizeof(grpc_channel_args));
-  channel_args->args = args;
-  channel_args->num_args = number_args;
+  grpc_channel_args channel_args;
+  channel_args.args = args;
+  channel_args.num_args = num_args;
 
   grpc_channel_credentials *creds = grpc_ssl_credentials_create(pem_root_certs, NULL, NULL);
-  c->channel = grpc_secure_channel_create(creds, address, channel_args, NULL);
+  c->channel = grpc_secure_channel_create(creds, address, &channel_args, NULL);
   c->completion_queue = grpc_completion_queue_create_for_next(NULL);
-  gpr_free(channel_args);
   return c;
 }
 

+ 3 - 3
Sources/Examples/Echo/main.swift

@@ -39,10 +39,10 @@ func buildEchoService(_ ssl: Bool, _ address: String, _ port: String, _: String)
     let certificateURL = URL(fileURLWithPath: "ssl.crt")
     let certificates = try! String(contentsOf: certificateURL)
     let arguments: [Channel.Argument] = [.sslTargetNameOverride("example.com")]
-
+    
     service = Echo_EchoServiceClient(address: address + ":" + port,
-                               certificates: certificates,
-                               arguments: arguments)
+                                     certificates: certificates,
+                                     arguments: arguments)
     service.host = "example.com"
   } else {
     service = Echo_EchoServiceClient(address: address + ":" + port, secure: false)

+ 9 - 5
Sources/SwiftGRPC/Core/Channel.swift

@@ -42,13 +42,15 @@ public class Channel {
   /// - Parameter secure: if true, use TLS
   /// - Parameter arguments: list of channel configuration options
   public init(address: String, secure: Bool = true, arguments: [Argument] = []) {
+    gRPC.initialize()
     host = address
-    var cargs = arguments.map { $0.toCArg() }
+    let argumentWrappers = arguments.map { $0.toCArg() }
+    var argumentValues = argumentWrappers.map { $0.wrapped }
 
     if secure {
-      underlyingChannel = cgrpc_channel_create_secure(address, roots_pem(), &cargs, Int32(arguments.count))
+      underlyingChannel = cgrpc_channel_create_secure(address, roots_pem(), &argumentValues, Int32(arguments.count))
     } else {
-      underlyingChannel = cgrpc_channel_create(address, &cargs, Int32(arguments.count))
+      underlyingChannel = cgrpc_channel_create(address, &argumentValues, Int32(arguments.count))
     }
     completionQueue = CompletionQueue(
       underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client")
@@ -61,10 +63,12 @@ public class Channel {
   /// - Parameter certificates: a PEM representation of certificates to use
   /// - Parameter arguments: list of channel configuration options
   public init(address: String, certificates: String, arguments: [Argument] = []) {
+    gRPC.initialize()
     self.host = address
-    var cargs = arguments.map { $0.toCArg() }
+    let argumentWrappers = arguments.map { $0.toCArg() }
+    var argumentValues = argumentWrappers.map { $0.wrapped }
 
-    underlyingChannel = cgrpc_channel_create_secure(address, certificates, &cargs, Int32(arguments.count))
+    underlyingChannel = cgrpc_channel_create_secure(address, certificates, &argumentValues, Int32(arguments.count))
     completionQueue = CompletionQueue(
       underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client")
     completionQueue.run() // start a loop that watches the channel's completion queue

+ 25 - 9
Sources/SwiftGRPC/Core/ChannelArgument.swift

@@ -75,7 +75,23 @@ public extension Channel {
 }
 
 extension Channel.Argument {
-  func toCArg() -> grpc_arg {
+  class Wrapper {
+    // Creating a `grpc_arg` allocates memory. This wrapper ensures that the memory is freed after use.
+    let wrapped: grpc_arg
+    
+    init(_ wrapped: grpc_arg) {
+      self.wrapped = wrapped
+    }
+    
+    deinit {
+      gpr_free(wrapped.key)
+      if wrapped.type == GRPC_ARG_STRING {
+        gpr_free(wrapped.value.string)
+      }
+    }
+  }
+  
+  func toCArg() -> Wrapper {
     switch self {
     case let .defaultAuthority(value):
       return makeArgument("grpc.default_authority", value: value)
@@ -107,34 +123,34 @@ extension Channel.Argument {
   }
 }
 
-private func makeArgument(_ key: String, value: String) -> grpc_arg {
+private func makeArgument(_ key: String, value: String) -> Channel.Argument.Wrapper {
   var arg = grpc_arg()
   arg.key = gpr_strdup(key)
   arg.type = GRPC_ARG_STRING
   arg.value.string = gpr_strdup(value)
-  return arg
+  return Channel.Argument.Wrapper(arg)
 }
 
-private func makeArgument(_ key: String, value: Bool) -> grpc_arg {
+private func makeArgument(_ key: String, value: Bool) -> Channel.Argument.Wrapper {
   return makeArgument(key, value: Int32(value ? 1 : 0))
 }
 
-private func makeArgument(_ key: String, value: Double) -> grpc_arg {
+private func makeArgument(_ key: String, value: Double) -> Channel.Argument.Wrapper {
   return makeArgument(key, value: Int32(value))
 }
 
-private func makeArgument(_ key: String, value: UInt) -> grpc_arg {
+private func makeArgument(_ key: String, value: UInt) -> Channel.Argument.Wrapper {
   return makeArgument(key, value: Int32(value))
 }
 
-private func makeArgument(_ key: String, value: Int) -> grpc_arg {
+private func makeArgument(_ key: String, value: Int) -> Channel.Argument.Wrapper {
   return makeArgument(key, value: Int32(value))
 }
 
-private func makeArgument(_ key: String, value: Int32) -> grpc_arg {
+private func makeArgument(_ key: String, value: Int32) -> Channel.Argument.Wrapper {
   var arg = grpc_arg()
   arg.key = gpr_strdup(key)
   arg.type = GRPC_ARG_INTEGER
   arg.value.integer = value
-  return arg
+  return Channel.Argument.Wrapper(arg)
 }

+ 14 - 0
Sources/SwiftGRPC/Core/Metadata.swift

@@ -110,3 +110,17 @@ public class Metadata: CustomStringConvertible {
     return underlyingArray
   }
 }
+
+extension Metadata {
+  public subscript(_ key: String) -> String? {
+    for i in 0..<self.count() {
+      let currentKey = self.key(i)
+      guard currentKey == key
+        else { continue }
+      
+      return self.value(i)
+    }
+    
+    return nil
+  }
+}

+ 6 - 0
Sources/SwiftGRPC/Runtime/ServiceClient.swift

@@ -55,6 +55,12 @@ open class ServiceClientBase: ServiceClient {
     metadata = Metadata()
   }
 
+  /// Create a client using a pre-defined channel.
+  public init(channel: Channel) {
+    self.channel = channel
+    metadata = Metadata()
+  }
+  
   /// Create a client that makes secure connections with a custom certificate.
   public init(address: String, certificates: String, arguments: [Channel.Argument] = []) {
     gRPC.initialize()

+ 1 - 0
Tests/LinuxMain.swift

@@ -18,6 +18,7 @@ import XCTest
 
 XCTMain([
   testCase(gRPCTests.allTests),
+  testCase(ChannelArgumentTests.allTests),
   testCase(ClientCancellingTests.allTests),
   testCase(ClientTestExample.allTests),
   testCase(ClientTimeoutTests.allTests),

+ 59 - 11
Tests/SwiftGRPCTests/ChannelArgumentTests.swift

@@ -20,29 +20,77 @@ import Foundation
 @testable import SwiftGRPC
 import XCTest
 
-class ChannelArgumentTests: XCTestCase {
+fileprivate class ChannelArgumentTestProvider: Echo_EchoProvider {
+  func get(request: Echo_EchoRequest, session: Echo_EchoGetSession) throws -> Echo_EchoResponse {
+    // We simply return the user agent we received, which can then be inspected by the test code.
+    return Echo_EchoResponse(text: (session as! ServerSessionBase).handler.requestMetadata["user-agent"]!)
+  }
+  
+  func expand(request: Echo_EchoRequest, session: Echo_EchoExpandSession) throws {
+    fatalError("not implemented")
+  }
+  
+  func collect(session: Echo_EchoCollectSession) throws {
+    fatalError("not implemented")
+  }
+  
+  func update(session: Echo_EchoUpdateSession) throws {
+    fatalError("not implemented")
+  }
+}
+
+class ChannelArgumentTests: BasicEchoTestCase {
+  static var allTests: [(String, (ChannelArgumentTests) -> () throws -> Void)] {
+    return [
+      ("testArgumentKey", testArgumentKey),
+      ("testStringArgument", testStringArgument),
+      ("testIntegerArgument", testIntegerArgument),
+      ("testBoolArgument", testBoolArgument),
+      ("testTimeIntervalArgument", testTimeIntervalArgument),
+    ]
+  }
+  
+  fileprivate func makeClient(_ arguments: [Channel.Argument]) -> Echo_EchoServiceClient {
+    let client = Echo_EchoServiceClient(address: address, secure: false, arguments: arguments)
+    client.timeout = defaultTimeout
+    return client
+  }
+  
+  override func makeProvider() -> Echo_EchoProvider { return ChannelArgumentTestProvider() }
+}
+
+extension ChannelArgumentTests {
   func testArgumentKey() {
-    let argument: Channel.Argument = .defaultAuthority("default")
-    XCTAssertEqual(String(cString: argument.toCArg().key), "grpc.default_authority")
+    let argument = Channel.Argument.defaultAuthority("default")
+    XCTAssertEqual(String(cString: argument.toCArg().wrapped.key), "grpc.default_authority")
   }
 
   func testStringArgument() {
-    let argument: Channel.Argument = .primaryUserAgent("Primary/0.1")
-    XCTAssertEqual(String(cString: argument.toCArg().value.string), "Primary/0.1")
+    let argument = Channel.Argument.primaryUserAgent("Primary/0.1")
+    XCTAssertEqual(String(cString: argument.toCArg().wrapped.value.string), "Primary/0.1")
   }
 
   func testIntegerArgument() {
-    let argument: Channel.Argument = .http2MaxPingsWithoutData(5)
-    XCTAssertEqual(argument.toCArg().value.integer, 5)
+    let argument = Channel.Argument.http2MaxPingsWithoutData(5)
+    XCTAssertEqual(argument.toCArg().wrapped.value.integer, 5)
   }
 
   func testBoolArgument() {
-    let argument: Channel.Argument = .keepAlivePermitWithoutCalls(true)
-    XCTAssertEqual(argument.toCArg().value.integer, 1)
+    let argument = Channel.Argument.keepAlivePermitWithoutCalls(true)
+    XCTAssertEqual(argument.toCArg().wrapped.value.integer, 1)
   }
 
   func testTimeIntervalArgument() {
-    let argument: Channel.Argument = .keepAliveTime(2.5)
-    XCTAssertEqual(argument.toCArg().value.integer, 2500) // in ms
+    let argument = Channel.Argument.keepAliveTime(2.5)
+    XCTAssertEqual(argument.toCArg().wrapped.value.integer, 2500) // in ms
+  }
+}
+
+extension ChannelArgumentTests {
+  func testPracticalUse() {
+    let client = makeClient([.primaryUserAgent("FOO"), .secondaryUserAgent("BAR")])
+    let responseText = try! client.get(Echo_EchoRequest(text: "")).text
+    XCTAssertTrue(responseText.hasPrefix("FOO "), "user agent \(responseText) should begin with 'FOO '")
+    XCTAssertTrue(responseText.hasSuffix(" BAR"), "user agent \(responseText) should end with ' BAR'")
   }
 }