ソースを参照

Add channel arguments

SebastianThiebaud 7 年 前
コミット
94e599bbc0

+ 32 - 2
Sources/CgRPC/shim/cgrpc.h

@@ -118,6 +118,31 @@ typedef struct grpc_event {
   void *tag;
 } grpc_event;
 
+typedef enum grpc_arg_type {
+  GRPC_ARG_STRING,
+  GRPC_ARG_INTEGER,
+  GRPC_ARG_POINTER
+} grpc_arg_type;
+
+typedef struct grpc_arg_pointer_vtable {
+  void* (*copy)(void* p);
+  void (*destroy)(void* p);
+  int (*cmp)(void* p, void* q);
+} grpc_arg_pointer_vtable;
+
+typedef struct grpc_arg {
+  grpc_arg_type type;
+  char* key;
+  union grpc_arg_value {
+    char* string;
+    int integer;
+    struct grpc_arg_pointer {
+      void* p;
+      const grpc_arg_pointer_vtable* vtable;
+    } pointer;
+  } value;
+} grpc_arg;
+
 #endif
 
 // directly expose a few grpc library functions
@@ -126,6 +151,8 @@ void grpc_shutdown(void);
 const char *grpc_version_string(void);
 const char *grpc_g_stands_for(void);
 
+char* gpr_strdup(const char* src);
+
 void cgrpc_completion_queue_drain(cgrpc_completion_queue *cq);
 void grpc_completion_queue_destroy(cgrpc_completion_queue *cq);
 
@@ -133,10 +160,13 @@ void grpc_completion_queue_destroy(cgrpc_completion_queue *cq);
 void cgrpc_free_copied_string(char *string);
 
 // channel support
-cgrpc_channel *cgrpc_channel_create(const char *address);
+cgrpc_channel *cgrpc_channel_create(const char *address, 
+                                    grpc_arg *args,
+                                    int number_args);
 cgrpc_channel *cgrpc_channel_create_secure(const char *address,
                                            const char *pem_root_certs,
-                                           const char *host);
+                                           grpc_arg *args,
+                                           int number_args);
 
 void cgrpc_channel_destroy(cgrpc_channel *channel);
 cgrpc_call *cgrpc_channel_create_call(cgrpc_channel *channel,

+ 15 - 25
Sources/CgRPC/shim/channel.c

@@ -23,44 +23,34 @@
 #include <stdlib.h>
 #include <string.h>
 
-cgrpc_channel *cgrpc_channel_create(const char *address) {
+cgrpc_channel *cgrpc_channel_create(const char *address,
+                                    grpc_arg *args,
+                                    int number_args) {
   cgrpc_channel *c = (cgrpc_channel *) malloc(sizeof (cgrpc_channel));
+
+  grpc_channel_args *channel_args = gpr_malloc(sizeof(grpc_channel_args));
+  channel_args->args = args;
+  channel_args->num_args = number_args;
+
   // create the channel
-  grpc_channel_args channel_args;
-  channel_args.num_args = 0;
-  c->channel = grpc_insecure_channel_create(address, &channel_args, NULL);
+  c->channel = grpc_insecure_channel_create(address, channel_args, NULL);
   c->completion_queue = grpc_completion_queue_create_for_next(NULL);
   return c;
 }
 
 cgrpc_channel *cgrpc_channel_create_secure(const char *address,
                                            const char *pem_root_certs,
-                                           const char *host) {
+                                           grpc_arg *args,
+                                           int number_args) {
   cgrpc_channel *c = (cgrpc_channel *) malloc(sizeof (cgrpc_channel));
   // create the channel
 
-  int argMax = 2;
-  grpc_channel_args *channelArgs = gpr_malloc(sizeof(grpc_channel_args));
-  channelArgs->args = gpr_malloc(argMax * sizeof(grpc_arg));
-
-  int argCount = 1;
-  grpc_arg *arg = &channelArgs->args[0];
-  arg->type = GRPC_ARG_STRING;
-  arg->key = gpr_strdup("grpc.primary_user_agent");
-  arg->value.string = gpr_strdup("grpc-swift/0.0.1");
-
-  if (host) {
-    argCount++;
-    arg = &channelArgs->args[1];
-    arg->type = GRPC_ARG_STRING;
-    arg->key = gpr_strdup("grpc.ssl_target_name_override");
-    arg->value.string = gpr_strdup(host);
-  }
-
-  channelArgs->num_args = argCount;
+  grpc_channel_args *channel_args = gpr_malloc(sizeof(grpc_channel_args));
+  channel_args->args = args;
+  channel_args->num_args = number_args;
 
   grpc_channel_credentials *creds = grpc_ssl_credentials_create(pem_root_certs, NULL, NULL);
-  c->channel = grpc_secure_channel_create(creds, address, channelArgs, NULL);
+  c->channel = grpc_secure_channel_create(creds, address, channel_args, NULL);
   c->completion_queue = grpc_completion_queue_create_for_next(NULL);
   return c;
 }

+ 3 - 1
Sources/Examples/Echo/main.swift

@@ -38,9 +38,11 @@ func buildEchoService(_ ssl: Bool, _ address: String, _ port: String, _: String)
   if ssl {
     let certificateURL = URL(fileURLWithPath: "ssl.crt")
     let certificates = try! String(contentsOf: certificateURL)
+    let args: [Arg] = [.sslTargetNameOverride("example.com")]
+
     service = Echo_EchoServiceClient(address: address + ":" + port,
                                certificates: certificates,
-                               host: "example.com")
+                               args: args)
     service.host = "example.com"
   } else {
     service = Echo_EchoServiceClient(address: address + ":" + port, secure: false)

+ 138 - 0
Sources/SwiftGRPC/Core/Arg.swift

@@ -0,0 +1,138 @@
+/*
+ * Copyright 2016, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#if SWIFT_PACKAGE
+  import CgRPC
+#endif
+import Foundation // for String.Encoding
+
+public enum Arg {
+  /// Default authority to pass if none specified on call construction.
+  case defaultAuthority(String)
+
+  /// Primary user agent. Goes at the start of the user-agent metadata sent
+  /// on each request.
+  case primaryUserAgent(String)
+
+  /// Secondary user agent. Goes at the end of the user-agent metadata sent
+  /// on each request.
+  case secondaryUserAgent(String)
+
+  /// After a duration of this time, the client/server pings its peer to see
+  /// if the transport is still alive.
+  case keepAliveTime(TimeInterval)
+
+  /// After waiting for a duration of this time, if the keepalive ping sender does
+  /// not receive the ping ack, it will close the transport.
+  case keepAliveTimeout(TimeInterval)
+
+  /// Is it permissible to send keepalive pings without any outstanding streams?
+  case keepAlivePermitWithoutCalls(Bool)
+
+  /// The time between the first and second connection attempts.
+  case reconnectBackoffInitial(TimeInterval)
+
+  /// The minimum time between subsequent connection attempts.
+  case reconnectBackoffMin(TimeInterval)
+
+  /// The maximum time between subsequent connection attempts.
+  case reconnectBackoffMax(TimeInterval)
+
+  /// Should we allow receipt of true-binary data on http2 connections?
+  /// Defaults to on (true)
+  case http2EnableTrueBinary(Bool)
+
+  /// Minimum time between sending successive ping frames without receiving
+  /// any data frame.
+  case http2MinSentPingInterval(TimeInterval)
+
+  /// Number of pings before needing to send a data frame or header frame.
+  /// `0` indicates that an infinite number of pings can be sent without
+  /// sending a data frame or header frame.
+  case http2MaxPingsWithoutData(UInt)
+
+  /// This *should* be used for testing only.
+  /// Override the target name used for SSL host name checking using this
+  /// channel argument. If this argument is not specified, the name used
+  /// for SSL host name checking will be the target parameter (assuming that the
+  /// secure channel is an SSL channel). If this parameter is specified and the
+  /// underlying is not an SSL channel, it will just be ignored.
+  case sslTargetNameOverride(String)
+}
+
+extension Arg {
+  func toCArg() -> grpc_arg {
+    switch self {
+    case let .defaultAuthority(value):
+      return arg("grpc.default_authority", value: value)
+    case let .primaryUserAgent(value):
+      return arg("grpc.primary_user_agent", value: value)
+    case let .secondaryUserAgent(value):
+      return arg("grpc.secondary_user_agent", value: value)
+    case let .keepAliveTime(value):
+      return arg("grpc.keepalive_time_ms", value: value * 1_000)
+    case let .keepAliveTimeout(value):
+      return arg("grpc.keepalive_timeout_ms", value: value * 1_000)
+    case let .keepAlivePermitWithoutCalls(value):
+      return arg("grpc.keepalive_permit_without_calls", value: value)
+    case let .reconnectBackoffMin(value):
+      return arg("grpc.min_reconnect_backoff_ms", value: value * 1_000)
+    case let .reconnectBackoffMax(value):
+      return arg("grpc.max_reconnect_backoff_ms", value: value * 1_000)
+    case let .reconnectBackoffInitial(value):
+      return arg("grpc.initial_reconnect_backoff_ms", value: value * 1_000)
+    case let .http2EnableTrueBinary(value):
+      return arg("grpc.http2.true_binary", value: value)
+    case let .http2MinSentPingInterval(value):
+      return arg("grpc.http2.min_time_between_pings_ms", value: value * 1_000)
+    case let .http2MaxPingsWithoutData(value):
+      return arg("grpc.http2.max_pings_without_data", value: value)
+    case let .sslTargetNameOverride(value):
+      return arg("grpc.ssl_target_name_override", value: value)
+    }
+  }
+
+  private func arg(_ key: String, value: String) -> grpc_arg {
+    var arg = grpc_arg()
+    arg.key = gpr_strdup(key)
+    arg.type = GRPC_ARG_STRING
+    arg.value.string = gpr_strdup(value)
+    return arg
+  }
+
+  private func arg(_ key: String, value: Bool) -> grpc_arg {
+    return arg(key, value: Int32(value ? 1 : 0))
+  }
+
+  private func arg(_ key: String, value: Double) -> grpc_arg {
+    return arg(key, value: Int32(value))
+  }
+
+  private func arg(_ key: String, value: UInt) -> grpc_arg {
+    return arg(key, value: Int32(value))
+  }
+
+  private func arg(_ key: String, value: Int) -> grpc_arg {
+    return arg(key, value: Int32(value))
+  }
+
+  private func arg(_ key: String, value: Int32) -> grpc_arg {
+    var arg = grpc_arg()
+    arg.key = gpr_strdup(key)
+    arg.type = GRPC_ARG_INTEGER
+    arg.value.integer = value
+    return arg
+  }
+}

+ 11 - 8
Sources/SwiftGRPC/Core/Channel.swift

@@ -40,13 +40,15 @@ public class Channel {
   ///
   /// - Parameter address: the address of the server to be called
   /// - Parameter secure: if true, use TLS
-  public init(address: String, secure: Bool = true) {
-    gRPC.initialize()
+  /// - Parameter args: list of arguments
+  public init(address: String, secure: Bool = true, args: [Arg] = []) {
     host = address
+    var cargs = args.map({ $0.toCArg() })
+
     if secure {
-      underlyingChannel = cgrpc_channel_create_secure(address, roots_pem(), nil)
+      underlyingChannel = cgrpc_channel_create_secure(address, roots_pem(), &cargs, Int32(args.count))
     } else {
-      underlyingChannel = cgrpc_channel_create(address)
+      underlyingChannel = cgrpc_channel_create(address, &cargs, Int32(args.count))
     }
     completionQueue = CompletionQueue(
       underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client")
@@ -57,11 +59,12 @@ public class Channel {
   ///
   /// - Parameter address: the address of the server to be called
   /// - Parameter certificates: a PEM representation of certificates to use
-  /// - Parameter host: an optional hostname override
-  public init(address: String, certificates: String, host: String?) {
-    gRPC.initialize()
+  /// - Parameter args: list of arguments
+  public init(address: String, certificates: String, args: [Arg] = []) {
     self.host = address
-    underlyingChannel = cgrpc_channel_create_secure(address, certificates, host)
+    var cargs = args.map({ $0.toCArg() })
+
+    underlyingChannel = cgrpc_channel_create_secure(address, certificates, &cargs, Int32(args.count))
     completionQueue = CompletionQueue(
       underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client")
     completionQueue.run() // start a loop that watches the channel's completion queue

+ 5 - 12
Sources/SwiftGRPC/Runtime/ServiceClient.swift

@@ -49,23 +49,16 @@ open class ServiceClientBase: ServiceClient {
   }
 
   /// Create a client.
-  public init(address: String, secure: Bool = true) {
+  public init(address: String, secure: Bool = true, args: [Arg] = []) {
     gRPC.initialize()
-    channel = Channel(address: address, secure: secure)
+    channel = Channel(address: address, secure: secure, args: args)
     metadata = Metadata()
   }
 
-  /// Create a client using a pre-defined channel.
-  public init(channel: Channel) {
+  /// Create a client that makes secure connections with a custom certificate.
+  public init(address: String, certificates: String, args: [Arg] = []) {
     gRPC.initialize()
-    self.channel = channel
-    self.metadata = Metadata()
-  }
-
-  /// Create a client that makes secure connections with a custom certificate and (optional) hostname.
-  public init(address: String, certificates: String, host: String?) {
-    gRPC.initialize()
-    channel = Channel(address: address, certificates: certificates, host: host)
+    channel = Channel(address: address, certificates: certificates, args: args)
     metadata = Metadata()
   }
 }

+ 1 - 1
Tests/SwiftGRPCTests/BasicEchoTestCase.swift

@@ -53,7 +53,7 @@ class BasicEchoTestCase: XCTestCase {
                                keyString: String(data: keyForTests, encoding: .utf8)!,
                                provider: provider)
       server.start(queue: DispatchQueue.global())
-      client = Echo_EchoServiceClient(address: address, certificates: certificateString, host: "example.com")
+      client = Echo_EchoServiceClient(address: address, certificates: certificateString, args: [Arg.sslTargetNameOverride("example.com")])
       client.host = "example.com"
     } else {
       server = Echo_EchoServer(address: address, provider: provider)

+ 1 - 1
Tests/SwiftGRPCTests/GRPCTests.swift

@@ -146,7 +146,7 @@ func runClient(useSSL: Bool) throws {
   if useSSL {
     channel = Channel(address: address,
                       certificates: String(data: certificateForTests, encoding: .utf8)!,
-                      host: host)
+                      args: [Arg.sslTargetNameOverride(host)])
   } else {
     channel = Channel(address: address, secure: false)
   }