فهرست منبع

Added support for NIOSSLCustomVerificationCallback for client connection (#1107)

This allows client apps to perform SSL Public Key Pinning, or override the certificate verification logic
Franck CLEMENT 4 سال پیش
والد
کامیت
e70c2cff9c

+ 16 - 6
Sources/GRPC/ClientConnection.swift

@@ -411,7 +411,8 @@ extension Channel {
     connectionIdleTimeout: TimeAmount,
     errorDelegate: ClientErrorDelegate?,
     requiresZeroLengthWriteWorkaround: Bool,
-    logger: Logger
+    logger: Logger,
+    customVerificationCallback: NIOSSLCustomVerificationCallback?
   ) -> EventLoopFuture<Void> {
     // We add at most 8 handlers to the pipeline.
     var handlers: [ChannelHandler] = []
@@ -427,11 +428,20 @@ extension Channel {
 
     if let tlsConfiguration = tlsConfiguration {
       do {
-        let sslClientHandler = try NIOSSLClientHandler(
-          context: try NIOSSLContext(configuration: tlsConfiguration),
-          serverHostname: tlsServerHostname
-        )
-        handlers.append(sslClientHandler)
+        if let customVerificationCallback = customVerificationCallback {
+          let sslClientHandler = try NIOSSLClientHandler(
+            context: try NIOSSLContext(configuration: tlsConfiguration),
+            serverHostname: tlsServerHostname,
+            customVerificationCallback: customVerificationCallback
+          )
+          handlers.append(sslClientHandler)
+        } else {
+          let sslClientHandler = try NIOSSLClientHandler(
+            context: try NIOSSLContext(configuration: tlsConfiguration),
+            serverHostname: tlsServerHostname
+          )
+          handlers.append(sslClientHandler)
+        }
         handlers.append(TLSVerificationHandler(logger: logger))
       } catch {
         return self.eventLoop.makeFailedFuture(error)

+ 2 - 1
Sources/GRPC/ConnectionManager.swift

@@ -853,7 +853,8 @@ extension ConnectionManager {
             group: self.eventLoop,
             hasTLS: self.configuration.tls != nil
           ),
-          logger: self.logger
+          logger: self.logger,
+          customVerificationCallback: self.configuration.tls?.customVerificationCallback
         )
 
         // Run the debug initializer, if there is one.

+ 9 - 0
Sources/GRPC/GRPCChannel/GRPCChannelBuilder.swift

@@ -235,6 +235,15 @@ extension ClientConnection.Builder.Secure {
     self.tls.certificateVerification = certificateVerification
     return self
   }
+
+  /// A custom verification callback that allows completely overriding the certificate verification logic.
+  @discardableResult
+  public func withTLSCustomVerificationCallback(
+    _ callback: @escaping NIOSSLCustomVerificationCallback
+  ) -> Self {
+    self.tls.customVerificationCallback = callback
+    return self
+  }
 }
 
 extension ClientConnection.Builder {

+ 8 - 1
Sources/GRPC/TLSConfiguration.swift

@@ -68,6 +68,9 @@ extension ClientConnection.Configuration {
       }
     }
 
+    /// A custom verification callback that allows completely overriding the certificate verification logic for this connection.
+    public var customVerificationCallback: NIOSSLCustomVerificationCallback?
+
     /// TLS Configuration with suitable defaults for clients.
     ///
     /// This is a wrapper around `NIOSSL.TLSConfiguration` to restrict input to values which comply
@@ -83,12 +86,15 @@ extension ClientConnection.Configuration {
     ///     `.fullVerification`.
     /// - Parameter hostnameOverride: Value to use for TLS SNI extension; this must not be an IP
     ///     address, defaults to `nil`.
+    /// - Parameter customVerificationCallback: A callback to provide to override the certificate verification logic,
+    ///     defaults to `nil`.
     public init(
       certificateChain: [NIOSSLCertificateSource] = [],
       privateKey: NIOSSLPrivateKeySource? = nil,
       trustRoots: NIOSSLTrustRoots = .default,
       certificateVerification: CertificateVerification = .fullVerification,
-      hostnameOverride: String? = nil
+      hostnameOverride: String? = nil,
+      customVerificationCallback: NIOSSLCustomVerificationCallback? = nil
     ) {
       self.configuration = .forClient(
         minimumTLSVersion: .tlsv12,
@@ -99,6 +105,7 @@ extension ClientConnection.Configuration {
         applicationProtocols: GRPCApplicationProtocolIdentifier.client
       )
       self.hostnameOverride = hostnameOverride
+      self.customVerificationCallback = customVerificationCallback
     }
 
     /// Creates a TLS Configuration using the given `NIOSSL.TLSConfiguration`.

+ 45 - 0
Tests/GRPCTests/ClientTLSFailureTests.swift

@@ -180,4 +180,49 @@ class ClientTLSFailureTests: GRPCTestCase {
       XCTFail("Expected NIOSSLExtraError.failedToValidateHostname")
     }
   }
+
+  func testClientConnectionFailsWhenCertificateValidationDenied() throws {
+    let errorExpectation = self.expectation(description: "error")
+    // 2 errors: one for the failed handshake, and another for failing the ready-channel promise
+    // (because the handshake failed).
+    errorExpectation.expectedFulfillmentCount = 2
+
+    let tlsConfiguration = ClientConnection.Configuration.TLS(
+      certificateChain: [.certificate(SampleCertificate.client.certificate)],
+      privateKey: .privateKey(SamplePrivateKey.client),
+      trustRoots: .certificates([SampleCertificate.ca.certificate]),
+      hostnameOverride: SampleCertificate.server.commonName,
+      customVerificationCallback: { _, promise in
+        // The certificate validation is forced to fail
+        promise.fail(NIOSSLError.unableToValidateCertificate)
+      }
+    )
+
+    var configuration = self.makeClientConfiguration(tls: tlsConfiguration)
+    let errorRecorder = ErrorRecordingDelegate(expectation: errorExpectation)
+    configuration.errorDelegate = errorRecorder
+
+    let stateChangeDelegate = RecordingConnectivityDelegate()
+    stateChangeDelegate.expectChanges(2) { changes in
+      XCTAssertEqual(changes, [
+        Change(from: .idle, to: .connecting),
+        Change(from: .connecting, to: .shutdown),
+      ])
+    }
+    configuration.connectivityStateDelegate = stateChangeDelegate
+
+    // Start an RPC to trigger creating a channel.
+    let echo = Echo_EchoClient(channel: ClientConnection(configuration: configuration))
+    _ = echo.get(.with { $0.text = "foo" })
+
+    self.wait(for: [errorExpectation], timeout: self.defaultTestTimeout)
+    stateChangeDelegate.waitForExpectedChanges(timeout: .seconds(5))
+
+    if let nioSSLError = errorRecorder.errors.first as? NIOSSLError,
+      case .handshakeFailed(.sslError) = nioSSLError {
+      // Expected case.
+    } else {
+      XCTFail("Expected NIOSSLError.handshakeFailed(BoringSSL.sslError)")
+    }
+  }
 }