Browse Source

Add tests for mutual auth

Nate Armstrong 7 years ago
parent
commit
18af90f2e0

+ 3 - 1
Sources/CgRPC/shim/cgrpc.h

@@ -190,7 +190,9 @@ void cgrpc_channel_watch_connectivity_state(cgrpc_channel *channel,
 cgrpc_server *cgrpc_server_create(const char *address);
 cgrpc_server *cgrpc_server_create_secure(const char *address,
                                          const char *private_key,
-                                         const char *cert_chain);
+                                         const char *cert_chain,
+                                         const char *root_certs,
+                                         int force_client_auth);
 void cgrpc_server_stop(cgrpc_server *server);
 void cgrpc_server_destroy(cgrpc_server *s);
 void cgrpc_server_start(cgrpc_server *s);

+ 5 - 3
Sources/CgRPC/shim/server.c

@@ -33,7 +33,9 @@ cgrpc_server *cgrpc_server_create(const char *address) {
 
 cgrpc_server *cgrpc_server_create_secure(const char *address,
                                          const char *private_key,
-                                         const char *cert_chain) {
+                                         const char *cert_chain,
+                                         const char *root_certs,
+                                         int force_client_auth) {
   cgrpc_server *server = (cgrpc_server *) malloc(sizeof (cgrpc_server));
   server->server = grpc_server_create(NULL, NULL);
   server->completion_queue = grpc_completion_queue_create_for_next(NULL);
@@ -44,10 +46,10 @@ cgrpc_server *cgrpc_server_create_secure(const char *address,
   server_credentials.cert_chain = cert_chain;
 
   grpc_server_credentials *credentials = grpc_ssl_server_credentials_create
-  (NULL,
+  (root_certs,
    &server_credentials,
    1,
-   0,
+   force_client_auth,
    NULL);
   
   // prepare the server to listen

+ 4 - 4
Sources/Examples/Echo/Generated/echo.grpc.swift

@@ -313,14 +313,14 @@ internal final class Echo_EchoServer: ServiceServer {
     super.init(address: address)
   }
 
-  internal init?(address: String, certificateURL: URL, keyURL: URL, provider: Echo_EchoProvider) {
+  internal init?(address: String, certificateURL: URL, keyURL: URL, rootCertsURL: URL? = nil, provider: Echo_EchoProvider) {
     self.provider = provider
-    super.init(address: address, certificateURL: certificateURL, keyURL: keyURL)
+    super.init(address: address, certificateURL: certificateURL, keyURL: keyURL, rootCertsURL: rootCertsURL)
   }
 
-  internal init?(address: String, certificateString: String, keyString: String, provider: Echo_EchoProvider) {
+  internal init?(address: String, certificateString: String, keyString: String, rootCerts: String? = nil, provider: Echo_EchoProvider) {
     self.provider = provider
-    super.init(address: address, certificateString: certificateString, keyString: keyString)
+    super.init(address: address, certificateString: certificateString, keyString: keyString, rootCerts: rootCerts)
   }
 
   /// Determines and calls the appropriate request handler, depending on the request's method.

+ 2 - 2
Sources/SwiftGRPC/Core/Server.swift

@@ -50,8 +50,8 @@ public class Server {
   /// - Parameter address: the address where the server will listen
   /// - Parameter key: the private key for the server's certificates
   /// - Parameter certs: the server's certificates
-  public init(address: String, key: String, certs: String) {
-    underlyingServer = cgrpc_server_create_secure(address, key, certs)
+  public init(address: String, key: String, certs: String, rootCerts: String? = nil) {
+    underlyingServer = cgrpc_server_create_secure(address, key, certs, rootCerts, rootCerts == nil ? 0 : 1)
     completionQueue = CompletionQueue(
       underlyingCompletionQueue: cgrpc_server_get_completion_queue(underlyingServer), name: "Server " + address)
   }

+ 11 - 4
Sources/SwiftGRPC/Runtime/ServiceServer.swift

@@ -32,20 +32,27 @@ open class ServiceServer {
   }
 
   /// Create a server that accepts secure connections.
-  public init(address: String, certificateString: String, keyString: String) {
+  public init(address: String, certificateString: String, keyString: String, rootCerts: String? = nil) {
     gRPC.initialize()
     self.address = address
-    server = Server(address: address, key: keyString, certs: certificateString)
+    server = Server(address: address, key: keyString, certs: certificateString, rootCerts: rootCerts)
   }
 
   /// Create a server that accepts secure connections.
-  public init?(address: String, certificateURL: URL, keyURL: URL) {
+  public init?(address: String, certificateURL: URL, keyURL: URL, rootCertsURL: URL?) {
     guard let certificate = try? String(contentsOf: certificateURL, encoding: .utf8),
       let key = try? String(contentsOf: keyURL, encoding: .utf8)
       else { return nil }
+    var rootCerts: String?
+    if let rootCertsURL = rootCertsURL {
+      guard let rootCertsString = try? String(contentsOf: rootCertsURL, encoding: .utf8) else {
+        return nil
+      }
+      rootCerts = rootCertsString
+    }
     gRPC.initialize()
     self.address = address
-    server = Server(address: address, key: key, certs: certificate)
+    server = Server(address: address, key: key, certs: certificate, rootCerts: rootCerts)
   }
 
   public enum HandleMethodError: Error {

+ 4 - 4
Sources/protoc-gen-swiftgrpc/Generator-Server.swift

@@ -76,17 +76,17 @@ extension Generator {
     outdent()
     println("}")
     println()
-    println("\(access) init?(address: String, certificateURL: URL, keyURL: URL, provider: \(providerName)) {")
+    println("\(access) init?(address: String, certificateURL: URL, keyURL: URL, rootCertsURL: URL? = nil, provider: \(providerName)) {")
     indent()
     println("self.provider = provider")
-    println("super.init(address: address, certificateURL: certificateURL, keyURL: keyURL)")
+    println("super.init(address: address, certificateURL: certificateURL, keyURL: keyURL, rootCertsURL: rootCertsURL)")
     outdent()
     println("}")
     println()
-    println("\(access) init?(address: String, certificateString: String, keyString: String, provider: \(providerName)) {")
+    println("\(access) init?(address: String, certificateString: String, keyString: String, rootCerts: String? = nil, provider: \(providerName)) {")
     indent()
     println("self.provider = provider")
-    println("super.init(address: address, certificateString: certificateString, keyString: keyString)")
+    println("super.init(address: address, certificateString: certificateString, keyString: keyString, rootCerts: rootCerts)")
     outdent()
     println("}")
     println()

+ 25 - 8
Tests/SwiftGRPCTests/BasicEchoTestCase.swift

@@ -31,6 +31,12 @@ extension Echo_EchoResponse {
 }
 
 class BasicEchoTestCase: XCTestCase {
+  enum Security {
+    case none
+    case ssl
+    case tlsMutualAuth
+  }
+
   func makeProvider() -> Echo_EchoProvider { return EchoProvider() }
 
   var provider: Echo_EchoProvider!
@@ -38,29 +44,40 @@ class BasicEchoTestCase: XCTestCase {
   var client: Echo_EchoServiceClient!
   
   var defaultTimeout: TimeInterval { return 1.0 }
-  var secure: Bool { return false }
+  var security: Security { return .none }
   var address: String { return "localhost:5050" }
 
   override func setUp() {
     super.setUp()
     
     provider = makeProvider()
-    
-    if secure {
-      let certificateString = String(data: certificateForTests, encoding: .utf8)!
+
+    let certificateString = String(data: certificateForTests, encoding: .utf8)!
+    let keyString = String(data: keyForTests, encoding: .utf8)!
+    let rootCerts = String(data: trustCollectionCertificateForTests, encoding: .utf8)!
+    let clientCertificateString = String(data: clientCertificateForTests, encoding: .utf8)!
+    let clientKeyString = String(data: clientKeyForTests, encoding: .utf8)!
+
+    switch security {
+    case .ssl:
       server = Echo_EchoServer(address: address,
                                certificateString: certificateString,
-                               keyString: String(data: keyForTests, encoding: .utf8)!,
+                               keyString: keyString,
                                provider: provider)
       server.start()
-      client = Echo_EchoServiceClient(address: address, certificates: certificateString, arguments: [.sslTargetNameOverride("example.com")])
+      client = Echo_EchoServiceClient(address: address, certificates: rootCerts, arguments: [.sslTargetNameOverride("example.com")])
       client.host = "example.com"
-    } else {
+    case .tlsMutualAuth:
+      server = Echo_EchoServer(address: address, certificateString: certificateString, keyString: keyString, rootCerts: rootCerts, provider: provider)
+      server.start()
+      client = Echo_EchoServiceClient(address: address, certificates: rootCerts, clientCertificates: clientCertificateString, clientKey: clientKeyString, arguments: [.sslTargetNameOverride("example.com")])
+      client.host = "example.com"
+    case .none:
       server = Echo_EchoServer(address: address, provider: provider)
       server.start()
       client = Echo_EchoServiceClient(address: address, secure: false)
     }
-    
+
     client.timeout = defaultTimeout
   }
   

+ 5 - 1
Tests/SwiftGRPCTests/EchoTests.swift

@@ -38,7 +38,11 @@ class EchoTests: BasicEchoTestCase {
 }
 
 class EchoTestsSecure: EchoTests {
-  override var secure: Bool { return true }
+  override var security: Security { return .ssl }
+}
+
+class EchoTestsMutualAuth: EchoTests {
+  override var security: Security { return .tlsMutualAuth }
 }
 
 extension EchoTests {

+ 1 - 1
Tests/SwiftGRPCTests/GRPCTests.swift

@@ -145,7 +145,7 @@ func runClient(useSSL: Bool) throws {
 
   if useSSL {
     channel = Channel(address: address,
-                      certificates: String(data: certificateForTests, encoding: .utf8)!,
+                      certificates: String(data: trustCollectionCertificateForTests, encoding: .utf8)!,
                       arguments: [.sslTargetNameOverride(host)])
   } else {
     channel = Channel(address: address, secure: false)

File diff suppressed because it is too large
+ 0 - 1
Tests/SwiftGRPCTests/TestKeys.swift


+ 33 - 7
scripts/makecert

@@ -1,8 +1,34 @@
-#!/bin/sh
-# 
-# Use this script to create a self-signed certificate (ssl.crt) and key file (ssl.key)
+#!/bin/bash
+#
+# Creates a trust collection certificate (ca.crt)
+# and self-signed server certificate (server.crt) and private key (server.pem)
+# and client certificate (client.crt) and key file (client.pem) for mutual TLS.
 # Replace "example.com" with the host name you'd like for your certificate.
-# 
-openssl genrsa -out ssl.key 2048
-openssl req -new -key ssl.key -out ssl.csr -subj "/CN=example.com"
-openssl x509 -req -days 365 -in ssl.csr -signkey ssl.key -out ssl.crt
+#
+# https://github.com/grpc/grpc-java/tree/master/examples
+#
+SIZE=2048
+
+CN_CA=foo
+CN_SERVER=example.com
+CN_CLIENT=localhost
+
+# CA
+openssl genrsa -out ca.key $SIZE
+openssl req -new -x509 -days 365 -key ca.key -out ca.crt -subj "/CN=${CN_CA}"
+
+# Server
+openssl genrsa -out server.key $SIZE
+openssl req -new -key server.key -out server.csr -subj "/CN=${CN_SERVER}"
+openssl x509 -req -days 365 -in server.csr -CA ca.crt -CAkey ca.key -set_serial 01 -out server.crt
+
+# Client
+openssl genrsa -out client.key $SIZE
+openssl req -new -key client.key -out client.csr -subj "/CN=${CN_CLIENT}"
+openssl x509 -req -days 365 -in client.csr -CA ca.crt -CAkey ca.key -set_serial 01 -out client.crt
+
+# netty only supports PKCS8 keys. openssl is used to convert from PKCS1 to PKCS8
+# http://netty.io/wiki/sslcontextbuilder-and-private-key.html
+# Generates client.pem which is the clientPrivateKeyFile for the Client (needed for mutual TLS only)
+openssl pkcs8 -topk8 -nocrypt -in client.key -out client.pem
+openssl pkcs8 -topk8 -nocrypt -in server.key -out server.pem

Some files were not shown because too many files changed in this diff