Browse Source

Merge pull request #225 from MrMage/args

Finalize adding channel arguments
Tim Burks 7 years ago
parent
commit
64ba6ef8fd

+ 33 - 2
Sources/CgRPC/shim/cgrpc.h

@@ -118,6 +118,31 @@ typedef struct grpc_event {
   void *tag;
 } grpc_event;
 
+typedef enum grpc_arg_type {
+  GRPC_ARG_STRING,
+  GRPC_ARG_INTEGER,
+  GRPC_ARG_POINTER
+} grpc_arg_type;
+
+typedef struct grpc_arg_pointer_vtable {
+  void *(*copy)(void *p);
+  void (*destroy)(void *p);
+  int (*cmp)(void *p, void *q);
+} grpc_arg_pointer_vtable;
+
+typedef struct grpc_arg {
+  grpc_arg_type type;
+  char *key;
+  union grpc_arg_value {
+    char *string;
+    int integer;
+    struct grpc_arg_pointer {
+      void *p;
+      const grpc_arg_pointer_vtable *vtable;
+    } pointer;
+  } value;
+} grpc_arg;
+
 #endif
 
 // directly expose a few grpc library functions
@@ -126,6 +151,9 @@ void grpc_shutdown(void);
 const char *grpc_version_string(void);
 const char *grpc_g_stands_for(void);
 
+char *gpr_strdup(const char *src);
+void gpr_free(void *p);
+
 void cgrpc_completion_queue_drain(cgrpc_completion_queue *cq);
 void grpc_completion_queue_destroy(cgrpc_completion_queue *cq);
 
@@ -133,10 +161,13 @@ void grpc_completion_queue_destroy(cgrpc_completion_queue *cq);
 void cgrpc_free_copied_string(char *string);
 
 // channel support
-cgrpc_channel *cgrpc_channel_create(const char *address);
+cgrpc_channel *cgrpc_channel_create(const char *address,
+                                    grpc_arg *args,
+                                    int num_args);
 cgrpc_channel *cgrpc_channel_create_secure(const char *address,
                                            const char *pem_root_certs,
-                                           const char *host);
+                                           grpc_arg *args,
+                                           int num_args);
 
 void cgrpc_channel_destroy(cgrpc_channel *channel);
 cgrpc_call *cgrpc_channel_create_call(cgrpc_channel *channel,

+ 13 - 25
Sources/CgRPC/shim/channel.c

@@ -23,11 +23,15 @@
 #include <stdlib.h>
 #include <string.h>
 
-cgrpc_channel *cgrpc_channel_create(const char *address) {
+cgrpc_channel *cgrpc_channel_create(const char *address,
+                                    grpc_arg *args,
+                                    int num_args) {
   cgrpc_channel *c = (cgrpc_channel *) malloc(sizeof (cgrpc_channel));
-  // create the channel
+
   grpc_channel_args channel_args;
-  channel_args.num_args = 0;
+  channel_args.args = args;
+  channel_args.num_args = num_args;
+
   c->channel = grpc_insecure_channel_create(address, &channel_args, NULL);
   c->completion_queue = grpc_completion_queue_create_for_next(NULL);
   return c;
@@ -35,32 +39,16 @@ cgrpc_channel *cgrpc_channel_create(const char *address) {
 
 cgrpc_channel *cgrpc_channel_create_secure(const char *address,
                                            const char *pem_root_certs,
-                                           const char *host) {
+                                           grpc_arg *args,
+                                           int num_args) {
   cgrpc_channel *c = (cgrpc_channel *) malloc(sizeof (cgrpc_channel));
-  // create the channel
-
-  int argMax = 2;
-  grpc_channel_args *channelArgs = gpr_malloc(sizeof(grpc_channel_args));
-  channelArgs->args = gpr_malloc(argMax * sizeof(grpc_arg));
 
-  int argCount = 1;
-  grpc_arg *arg = &channelArgs->args[0];
-  arg->type = GRPC_ARG_STRING;
-  arg->key = gpr_strdup("grpc.primary_user_agent");
-  arg->value.string = gpr_strdup("grpc-swift/0.0.1");
-
-  if (host) {
-    argCount++;
-    arg = &channelArgs->args[1];
-    arg->type = GRPC_ARG_STRING;
-    arg->key = gpr_strdup("grpc.ssl_target_name_override");
-    arg->value.string = gpr_strdup(host);
-  }
-
-  channelArgs->num_args = argCount;
+  grpc_channel_args channel_args;
+  channel_args.args = args;
+  channel_args.num_args = num_args;
 
   grpc_channel_credentials *creds = grpc_ssl_credentials_create(pem_root_certs, NULL, NULL);
-  c->channel = grpc_secure_channel_create(creds, address, channelArgs, NULL);
+  c->channel = grpc_secure_channel_create(creds, address, &channel_args, NULL);
   c->completion_queue = grpc_completion_queue_create_for_next(NULL);
   return c;
 }

+ 4 - 2
Sources/Examples/Echo/main.swift

@@ -38,9 +38,11 @@ func buildEchoService(_ ssl: Bool, _ address: String, _ port: String, _: String)
   if ssl {
     let certificateURL = URL(fileURLWithPath: "ssl.crt")
     let certificates = try! String(contentsOf: certificateURL)
+    let arguments: [Channel.Argument] = [.sslTargetNameOverride("example.com")]
+    
     service = Echo_EchoServiceClient(address: address + ":" + port,
-                               certificates: certificates,
-                               host: "example.com")
+                                     certificates: certificates,
+                                     arguments: arguments)
     service.host = "example.com"
   } else {
     service = Echo_EchoServiceClient(address: address + ":" + port, secure: false)

+ 13 - 6
Sources/SwiftGRPC/Core/Channel.swift

@@ -40,13 +40,17 @@ public class Channel {
   ///
   /// - Parameter address: the address of the server to be called
   /// - Parameter secure: if true, use TLS
-  public init(address: String, secure: Bool = true) {
+  /// - Parameter arguments: list of channel configuration options
+  public init(address: String, secure: Bool = true, arguments: [Argument] = []) {
     gRPC.initialize()
     host = address
+    let argumentWrappers = arguments.map { $0.toCArg() }
+    var argumentValues = argumentWrappers.map { $0.wrapped }
+
     if secure {
-      underlyingChannel = cgrpc_channel_create_secure(address, roots_pem(), nil)
+      underlyingChannel = cgrpc_channel_create_secure(address, roots_pem(), &argumentValues, Int32(arguments.count))
     } else {
-      underlyingChannel = cgrpc_channel_create(address)
+      underlyingChannel = cgrpc_channel_create(address, &argumentValues, Int32(arguments.count))
     }
     completionQueue = CompletionQueue(
       underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client")
@@ -57,11 +61,14 @@ public class Channel {
   ///
   /// - Parameter address: the address of the server to be called
   /// - Parameter certificates: a PEM representation of certificates to use
-  /// - Parameter host: an optional hostname override
-  public init(address: String, certificates: String, host: String?) {
+  /// - Parameter arguments: list of channel configuration options
+  public init(address: String, certificates: String, arguments: [Argument] = []) {
     gRPC.initialize()
     self.host = address
-    underlyingChannel = cgrpc_channel_create_secure(address, certificates, host)
+    let argumentWrappers = arguments.map { $0.toCArg() }
+    var argumentValues = argumentWrappers.map { $0.wrapped }
+
+    underlyingChannel = cgrpc_channel_create_secure(address, certificates, &argumentValues, Int32(arguments.count))
     completionQueue = CompletionQueue(
       underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client")
     completionQueue.run() // start a loop that watches the channel's completion queue

+ 156 - 0
Sources/SwiftGRPC/Core/ChannelArgument.swift

@@ -0,0 +1,156 @@
+/*
+ * Copyright 2018, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#if SWIFT_PACKAGE
+  import CgRPC
+#endif
+import Foundation // for String.Encoding
+
+public extension Channel {
+  enum Argument {
+    /// Default authority to pass if none specified on call construction.
+    case defaultAuthority(String)
+
+    /// Primary user agent. Goes at the start of the user-agent metadata sent
+    /// on each request.
+    case primaryUserAgent(String)
+
+    /// Secondary user agent. Goes at the end of the user-agent metadata sent
+    /// on each request.
+    case secondaryUserAgent(String)
+
+    /// After a duration of this time, the client/server pings its peer to see
+    /// if the transport is still alive.
+    case keepAliveTime(TimeInterval)
+
+    /// After waiting for a duration of this time, if the keepalive ping sender does
+    /// not receive the ping ack, it will close the transport.
+    case keepAliveTimeout(TimeInterval)
+
+    /// Is it permissible to send keepalive pings without any outstanding streams?
+    case keepAlivePermitWithoutCalls(Bool)
+
+    /// The time between the first and second connection attempts.
+    case reconnectBackoffInitial(TimeInterval)
+
+    /// The minimum time between subsequent connection attempts.
+    case reconnectBackoffMin(TimeInterval)
+
+    /// The maximum time between subsequent connection attempts.
+    case reconnectBackoffMax(TimeInterval)
+
+    /// Should we allow receipt of true-binary data on http2 connections?
+    /// Defaults to on (true)
+    case http2EnableTrueBinary(Bool)
+
+    /// Minimum time between sending successive ping frames without receiving
+    /// any data frame.
+    case http2MinSentPingInterval(TimeInterval)
+
+    /// Number of pings before needing to send a data frame or header frame.
+    /// `0` indicates that an infinite number of pings can be sent without
+    /// sending a data frame or header frame.
+    case http2MaxPingsWithoutData(UInt)
+
+    /// This *should* be used for testing only.
+    /// Override the target name used for SSL host name checking using this
+    /// channel argument. If this argument is not specified, the name used
+    /// for SSL host name checking will be the target parameter (assuming that the
+    /// secure channel is an SSL channel). If this parameter is specified and the
+    /// underlying is not an SSL channel, it will just be ignored.
+    case sslTargetNameOverride(String)
+  }
+}
+
+extension Channel.Argument {
+  class Wrapper {
+    // Creating a `grpc_arg` allocates memory. This wrapper ensures that the memory is freed after use.
+    let wrapped: grpc_arg
+    
+    init(_ wrapped: grpc_arg) {
+      self.wrapped = wrapped
+    }
+    
+    deinit {
+      gpr_free(wrapped.key)
+      if wrapped.type == GRPC_ARG_STRING {
+        gpr_free(wrapped.value.string)
+      }
+    }
+  }
+  
+  func toCArg() -> Wrapper {
+    switch self {
+    case let .defaultAuthority(value):
+      return makeArgument("grpc.default_authority", value: value)
+    case let .primaryUserAgent(value):
+      return makeArgument("grpc.primary_user_agent", value: value)
+    case let .secondaryUserAgent(value):
+      return makeArgument("grpc.secondary_user_agent", value: value)
+    case let .keepAliveTime(value):
+      return makeArgument("grpc.keepalive_time_ms", value: value * 1_000)
+    case let .keepAliveTimeout(value):
+      return makeArgument("grpc.keepalive_timeout_ms", value: value * 1_000)
+    case let .keepAlivePermitWithoutCalls(value):
+      return makeArgument("grpc.keepalive_permit_without_calls", value: value)
+    case let .reconnectBackoffMin(value):
+      return makeArgument("grpc.min_reconnect_backoff_ms", value: value * 1_000)
+    case let .reconnectBackoffMax(value):
+      return makeArgument("grpc.max_reconnect_backoff_ms", value: value * 1_000)
+    case let .reconnectBackoffInitial(value):
+      return makeArgument("grpc.initial_reconnect_backoff_ms", value: value * 1_000)
+    case let .http2EnableTrueBinary(value):
+      return makeArgument("grpc.http2.true_binary", value: value)
+    case let .http2MinSentPingInterval(value):
+      return makeArgument("grpc.http2.min_time_between_pings_ms", value: value * 1_000)
+    case let .http2MaxPingsWithoutData(value):
+      return makeArgument("grpc.http2.max_pings_without_data", value: value)
+    case let .sslTargetNameOverride(value):
+      return makeArgument("grpc.ssl_target_name_override", value: value)
+    }
+  }
+}
+
+private func makeArgument(_ key: String, value: String) -> Channel.Argument.Wrapper {
+  var arg = grpc_arg()
+  arg.key = gpr_strdup(key)
+  arg.type = GRPC_ARG_STRING
+  arg.value.string = gpr_strdup(value)
+  return Channel.Argument.Wrapper(arg)
+}
+
+private func makeArgument(_ key: String, value: Bool) -> Channel.Argument.Wrapper {
+  return makeArgument(key, value: Int32(value ? 1 : 0))
+}
+
+private func makeArgument(_ key: String, value: Double) -> Channel.Argument.Wrapper {
+  return makeArgument(key, value: Int32(value))
+}
+
+private func makeArgument(_ key: String, value: UInt) -> Channel.Argument.Wrapper {
+  return makeArgument(key, value: Int32(value))
+}
+
+private func makeArgument(_ key: String, value: Int) -> Channel.Argument.Wrapper {
+  return makeArgument(key, value: Int32(value))
+}
+
+private func makeArgument(_ key: String, value: Int32) -> Channel.Argument.Wrapper {
+  var arg = grpc_arg()
+  arg.key = gpr_strdup(key)
+  arg.type = GRPC_ARG_INTEGER
+  arg.value.integer = value
+  return Channel.Argument.Wrapper(arg)
+}

+ 14 - 0
Sources/SwiftGRPC/Core/Metadata.swift

@@ -110,3 +110,17 @@ public class Metadata: CustomStringConvertible {
     return underlyingArray
   }
 }
+
+extension Metadata {
+  public subscript(_ key: String) -> String? {
+    for i in 0..<self.count() {
+      let currentKey = self.key(i)
+      guard currentKey == key
+        else { continue }
+      
+      return self.value(i)
+    }
+    
+    return nil
+  }
+}

+ 7 - 8
Sources/SwiftGRPC/Runtime/ServiceClient.swift

@@ -49,23 +49,22 @@ open class ServiceClientBase: ServiceClient {
   }
 
   /// Create a client.
-  public init(address: String, secure: Bool = true) {
+  public init(address: String, secure: Bool = true, arguments: [Channel.Argument] = []) {
     gRPC.initialize()
-    channel = Channel(address: address, secure: secure)
+    channel = Channel(address: address, secure: secure, arguments: arguments)
     metadata = Metadata()
   }
 
   /// Create a client using a pre-defined channel.
   public init(channel: Channel) {
-    gRPC.initialize()
     self.channel = channel
-    self.metadata = Metadata()
+    metadata = Metadata()
   }
-
-  /// Create a client that makes secure connections with a custom certificate and (optional) hostname.
-  public init(address: String, certificates: String, host: String?) {
+  
+  /// Create a client that makes secure connections with a custom certificate.
+  public init(address: String, certificates: String, arguments: [Channel.Argument] = []) {
     gRPC.initialize()
-    channel = Channel(address: address, certificates: certificates, host: host)
+    channel = Channel(address: address, certificates: certificates, arguments: arguments)
     metadata = Metadata()
   }
 }

+ 1 - 0
Tests/LinuxMain.swift

@@ -18,6 +18,7 @@ import XCTest
 
 XCTMain([
   testCase(gRPCTests.allTests),
+  testCase(ChannelArgumentTests.allTests),
   testCase(ClientCancellingTests.allTests),
   testCase(ClientTestExample.allTests),
   testCase(ClientTimeoutTests.allTests),

+ 1 - 1
Tests/SwiftGRPCTests/BasicEchoTestCase.swift

@@ -53,7 +53,7 @@ class BasicEchoTestCase: XCTestCase {
                                keyString: String(data: keyForTests, encoding: .utf8)!,
                                provider: provider)
       server.start(queue: DispatchQueue.global())
-      client = Echo_EchoServiceClient(address: address, certificates: certificateString, host: "example.com")
+      client = Echo_EchoServiceClient(address: address, certificates: certificateString, arguments: [.sslTargetNameOverride("example.com")])
       client.host = "example.com"
     } else {
       server = Echo_EchoServer(address: address, provider: provider)

+ 96 - 0
Tests/SwiftGRPCTests/ChannelArgumentTests.swift

@@ -0,0 +1,96 @@
+/*
+ * Copyright 2018, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#if SWIFT_PACKAGE
+import CgRPC
+#endif
+import Foundation
+@testable import SwiftGRPC
+import XCTest
+
+fileprivate class ChannelArgumentTestProvider: Echo_EchoProvider {
+  func get(request: Echo_EchoRequest, session: Echo_EchoGetSession) throws -> Echo_EchoResponse {
+    // We simply return the user agent we received, which can then be inspected by the test code.
+    return Echo_EchoResponse(text: (session as! ServerSessionBase).handler.requestMetadata["user-agent"]!)
+  }
+  
+  func expand(request: Echo_EchoRequest, session: Echo_EchoExpandSession) throws {
+    fatalError("not implemented")
+  }
+  
+  func collect(session: Echo_EchoCollectSession) throws {
+    fatalError("not implemented")
+  }
+  
+  func update(session: Echo_EchoUpdateSession) throws {
+    fatalError("not implemented")
+  }
+}
+
+class ChannelArgumentTests: BasicEchoTestCase {
+  static var allTests: [(String, (ChannelArgumentTests) -> () throws -> Void)] {
+    return [
+      ("testArgumentKey", testArgumentKey),
+      ("testStringArgument", testStringArgument),
+      ("testIntegerArgument", testIntegerArgument),
+      ("testBoolArgument", testBoolArgument),
+      ("testTimeIntervalArgument", testTimeIntervalArgument),
+    ]
+  }
+  
+  fileprivate func makeClient(_ arguments: [Channel.Argument]) -> Echo_EchoServiceClient {
+    let client = Echo_EchoServiceClient(address: address, secure: false, arguments: arguments)
+    client.timeout = defaultTimeout
+    return client
+  }
+  
+  override func makeProvider() -> Echo_EchoProvider { return ChannelArgumentTestProvider() }
+}
+
+extension ChannelArgumentTests {
+  func testArgumentKey() {
+    let argument = Channel.Argument.defaultAuthority("default")
+    XCTAssertEqual(String(cString: argument.toCArg().wrapped.key), "grpc.default_authority")
+  }
+
+  func testStringArgument() {
+    let argument = Channel.Argument.primaryUserAgent("Primary/0.1")
+    XCTAssertEqual(String(cString: argument.toCArg().wrapped.value.string), "Primary/0.1")
+  }
+
+  func testIntegerArgument() {
+    let argument = Channel.Argument.http2MaxPingsWithoutData(5)
+    XCTAssertEqual(argument.toCArg().wrapped.value.integer, 5)
+  }
+
+  func testBoolArgument() {
+    let argument = Channel.Argument.keepAlivePermitWithoutCalls(true)
+    XCTAssertEqual(argument.toCArg().wrapped.value.integer, 1)
+  }
+
+  func testTimeIntervalArgument() {
+    let argument = Channel.Argument.keepAliveTime(2.5)
+    XCTAssertEqual(argument.toCArg().wrapped.value.integer, 2500) // in ms
+  }
+}
+
+extension ChannelArgumentTests {
+  func testPracticalUse() {
+    let client = makeClient([.primaryUserAgent("FOO"), .secondaryUserAgent("BAR")])
+    let responseText = try! client.get(Echo_EchoRequest(text: "")).text
+    XCTAssertTrue(responseText.hasPrefix("FOO "), "user agent \(responseText) should begin with 'FOO '")
+    XCTAssertTrue(responseText.hasSuffix(" BAR"), "user agent \(responseText) should end with ' BAR'")
+  }
+}

+ 1 - 1
Tests/SwiftGRPCTests/GRPCTests.swift

@@ -146,7 +146,7 @@ func runClient(useSSL: Bool) throws {
   if useSSL {
     channel = Channel(address: address,
                       certificates: String(data: certificateForTests, encoding: .utf8)!,
-                      host: host)
+                      arguments: [.sslTargetNameOverride(host)])
   } else {
     channel = Channel(address: address, secure: false)
   }