Browse Source

Verify TLS/ALPN were successful before returning a connection (#443)

* Verify TLS/ALPN were successful before returning a connection

* Add TODO, remove unused import, clarify error message

* Clarify the use of the TLS verification future
George Barnett 6 years ago
parent
commit
69b48e12ad

+ 4 - 2
Sources/Examples/EchoNIO/main.swift

@@ -55,7 +55,8 @@ func makeClientTLSConfiguration() throws -> TLSConfiguration {
   return .forClient(certificateVerification: .noHostnameVerification,
                     trustRoots: .certificates([caCert.certificate]),
                     certificateChain: [.certificate(clientCert.certificate)],
-                    privateKey: .privateKey(SamplePrivateKey.client))
+                    privateKey: .privateKey(SamplePrivateKey.client),
+                    applicationProtocols: GRPCApplicationProtocolIdentifier.allCases.map { $0.rawValue })
 }
 
 func makeServerTLSConfiguration() throws -> TLSConfiguration {
@@ -66,7 +67,8 @@ func makeServerTLSConfiguration() throws -> TLSConfiguration {
 
   return .forServer(certificateChain: [.certificate(serverCert.certificate)],
                     privateKey: .privateKey(SamplePrivateKey.server),
-                    trustRoots: .certificates([caCert.certificate]))
+                    trustRoots: .certificates([caCert.certificate]),
+                    applicationProtocols: GRPCApplicationProtocolIdentifier.allCases.map { $0.rawValue })
 }
 
 /// Create en `EchoClient` and wait for it to initialize. Returns nil if initialisation fails.

+ 30 - 3
Sources/SwiftGRPCNIO/GRPCClientConnection.swift

@@ -17,6 +17,7 @@ import Foundation
 import NIO
 import NIOHTTP2
 import NIOSSL
+import NIOTLS
 
 /// Underlying channel and HTTP/2 stream multiplexer.
 ///
@@ -49,8 +50,30 @@ open class GRPCClientConnection {
       }
 
     return bootstrap.connect(host: host, port: port).flatMap { channel in
-      channel.pipeline.context(handlerType: HTTP2StreamMultiplexer.self).map { context in
-        context.handler as! HTTP2StreamMultiplexer
+      // Check the handshake succeeded and a valid protocol was negotiated via ALPN.
+      let tlsVerified: EventLoopFuture<Void>
+
+      if case .none = tlsMode {
+        tlsVerified = channel.eventLoop.makeSucceededFuture(())
+      } else {
+        // TODO: Use `handler(type:)` introduced in https://github.com/apple/swift-nio/pull/974
+        // once it has been released.
+        tlsVerified = channel.pipeline.context(handlerType: GRPCTLSVerificationHandler.self).map {
+          $0.handler as! GRPCTLSVerificationHandler
+        }.flatMap {
+          // Use the result of the verification future to determine whether we should return a
+          // connection to the caller. Note that even though it contains a `Void` it may also
+          // contain an `Error`, which is what we are interested in here.
+          $0.verification
+        }
+      }
+
+      return tlsVerified.flatMap {
+        // TODO: Use `handler(type:)` introduced in https://github.com/apple/swift-nio/pull/974
+        // once it has been released.
+        channel.pipeline.context(handlerType: HTTP2StreamMultiplexer.self)
+      }.map {
+        $0.handler as! HTTP2StreamMultiplexer
       }.map { multiplexer in
         GRPCClientConnection(channel: channel, multiplexer: multiplexer, host: host, httpProtocol: tlsMode.httpProtocol)
       }
@@ -72,7 +95,11 @@ open class GRPCClientConnection {
         handlerAddedPromise.succeed(())
         return handlerAddedPromise.futureResult
       }
-      channel.pipeline.addHandler(try NIOSSLClientHandler(context: sslContext, serverHostname: host)).cascade(to: handlerAddedPromise)
+
+      let sslHandler = try NIOSSLClientHandler(context: sslContext, serverHostname: host)
+      let verificationHandler = GRPCTLSVerificationHandler()
+
+      channel.pipeline.addHandlers(sslHandler, verificationHandler).cascade(to: handlerAddedPromise)
     } catch {
       handlerAddedPromise.fail(error)
     }

+ 6 - 0
Sources/SwiftGRPCNIO/GRPCError.swift

@@ -116,6 +116,9 @@ public enum GRPCClientError: Error, Equatable {
 
   /// The call deadline was exceeded.
   case deadlineExceeded(GRPCTimeout)
+
+  /// The protocol negotiated via ALPN was not valid.
+  case applicationLevelProtocolNegotiationFailed
 }
 
 /// An error which should be thrown by either the client or server.
@@ -178,6 +181,9 @@ extension GRPCClientError: GRPCStatusTransformable {
 
     case .deadlineExceeded(let timeout):
       return GRPCStatus(code: .deadlineExceeded, message: "call exceeded timeout of \(timeout)")
+
+    case .applicationLevelProtocolNegotiationFailed:
+      return GRPCStatus(code: .invalidArgument, message: "failed to negotiate application level protocol")
     }
   }
 }

+ 71 - 0
Sources/SwiftGRPCNIO/GRPCTLSVerificationHandler.swift

@@ -0,0 +1,71 @@
+import Foundation
+import NIO
+import NIOSSL
+import NIOTLS
+
+/// Application protocol identifiers for ALPN.
+public enum GRPCApplicationProtocolIdentifier: String, CaseIterable {
+  // This is not in the IANA ALPN protocol ID registry, but may be used by servers to indicate that
+  // they serve only gRPC traffic. It is part of the gRPC core implementation.
+  case gRPC = "grpc-ext"
+  case h2 = "h2"
+}
+
+/// A helper `ChannelInboundHandler` to verify that a TLS handshake was completed successfully
+/// and that the negotiated application protocol is valid.
+///
+/// The handler holds a promise which is succeeded on successful verification of the negotiated
+/// application protocol and failed if any error is received by this handler or an invalid
+/// application protocol was negotiated.
+///
+/// Users of this handler should rely on the `verification` future held by this instance.
+///
+/// On fulfillment of the promise this handler is removed from the channel pipeline.
+public class GRPCTLSVerificationHandler: ChannelInboundHandler, RemovableChannelHandler {
+  public typealias InboundIn = Any
+
+  private var verificationPromise: EventLoopPromise<Void>!
+
+  /// A future which is fulfilled when the state of the TLS handshake is known. If the handshake
+  /// was successful and the negotiated application protocol is valid then the future is succeeded.
+  /// If an error occured or the application protocol is not valid then the future will have been
+  /// failed.
+  ///
+  /// - Important: The promise associated with this future is created in `handlerAdded(context:)`,
+  ///   and as such must _not_ be accessed before the handler has be added to a pipeline.
+  public var verification: EventLoopFuture<Void>! {
+    return verificationPromise.futureResult
+  }
+
+  public init() { }
+
+  public func handlerAdded(context: ChannelHandlerContext) {
+    self.verificationPromise = context.eventLoop.makePromise()
+    // Remove ourselves from the pipeline when the promise gets fulfilled.
+    self.verificationPromise.futureResult.whenComplete { _ in
+      context.pipeline.removeHandler(self, promise: nil)
+    }
+  }
+
+  public func errorCaught(context: ChannelHandlerContext, error: Error) {
+    precondition(self.verificationPromise != nil, "handler has not been added to the pipeline")
+
+    verificationPromise.fail(error)
+  }
+
+  public func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
+    precondition(self.verificationPromise != nil, "handler has not been added to the pipeline")
+
+    guard let tlsEvent = event as? TLSUserEvent,
+      case .handshakeCompleted(negotiatedProtocol: let negotiatedProtocol) = tlsEvent else {
+        context.fireUserInboundEventTriggered(event)
+        return
+    }
+
+    if let proto = negotiatedProtocol, GRPCApplicationProtocolIdentifier(rawValue: proto) != nil {
+      self.verificationPromise.succeed(())
+    } else {
+      self.verificationPromise.fail(GRPCError.client(.applicationLevelProtocolNegotiationFailed))
+    }
+  }
+}

+ 3 - 1
Sources/SwiftGRPCNIO/HTTPProtocolSwitcher.swift

@@ -134,7 +134,9 @@ extension HTTPProtocolSwitcher: ChannelInboundHandler, RemovableChannelHandler {
   private func protocolVersion(_ preamble: String) -> HTTPProtocolVersion? {
     let range = NSRange(location: 0, length: preamble.utf16.count)
     let regex = try! NSRegularExpression(pattern: "^.*HTTP/(\\d)\\.\\d$")
-    let result = regex.firstMatch(in: preamble, options: [], range: range)!
+    guard let result = regex.firstMatch(in: preamble, options: [], range: range) else {
+      return nil
+    }
 
     let versionRange = result.range(at: 1)
 

+ 26 - 0
Tests/SwiftGRPCNIOTests/EventLoopFuture+Assertions.swift

@@ -40,3 +40,29 @@ extension EventLoopFuture where Value: Equatable {
     }
   }
 }
+
+extension EventLoopFuture {
+  /// Registers a callback which asserts that this future is fulfilled with an error. Causes a test
+  /// failure if the future is not fulfilled with an error.
+  ///
+  /// Callers can additionally verify the error by providing an error handler.
+  ///
+  /// - Parameters:
+  ///   - expectation: A test expectation to fulfill once the future has completed.
+  ///   - handler: A block to run additional verification on the error. Defaults to no-op.
+  func assertError(fulfill expectation: XCTestExpectation, file: StaticString = #file, line: UInt = #line, handler: @escaping (Error) -> Void = { _ in }) {
+    self.whenComplete { result in
+      defer {
+        expectation.fulfill()
+      }
+
+      switch result {
+      case .success:
+        XCTFail("Unexpectedly received \(Value.self), expected an error", file: file, line: line)
+
+      case .failure(let error):
+        handler(error)
+      }
+    }
+  }
+}

+ 6 - 3
Tests/SwiftGRPCNIOTests/NIOBasicEchoTestCase.swift

@@ -72,7 +72,8 @@ extension TransportSecurity {
     case .anonymousClient, .mutualAuthentication:
       return .forServer(certificateChain: [.certificate(self.serverCert)],
                         privateKey: .privateKey(SamplePrivateKey.server), 
-                        trustRoots: .certificates ([self.caCert]))
+                        trustRoots: .certificates ([self.caCert]),
+                        applicationProtocols: GRPCApplicationProtocolIdentifier.allCases.map { $0.rawValue })
     }
   }
 
@@ -87,13 +88,15 @@ extension TransportSecurity {
 
     case .anonymousClient:
       return .forClient(certificateVerification: .noHostnameVerification,
-                        trustRoots: .certificates([self.caCert]))
+                        trustRoots: .certificates([self.caCert]),
+                        applicationProtocols: GRPCApplicationProtocolIdentifier.allCases.map { $0.rawValue })
 
     case .mutualAuthentication:
       return .forClient(certificateVerification: .noHostnameVerification,
                         trustRoots: .certificates([self.caCert]),
                         certificateChain: [.certificate(self.clientCert)],
-                        privateKey: .privateKey(SamplePrivateKey.client))
+                        privateKey: .privateKey(SamplePrivateKey.client),
+                        applicationProtocols: GRPCApplicationProtocolIdentifier.allCases.map { $0.rawValue })
     }
   }
 }

+ 135 - 0
Tests/SwiftGRPCNIOTests/NIOClientTLSFailureTests.swift

@@ -0,0 +1,135 @@
+/*
+ * Copyright 2019, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import Foundation
+import SwiftGRPCNIO
+import SwiftGRPCNIOSampleData
+import NIO
+import NIOSSL
+import XCTest
+
+class NIOClientTLSFailureTests: XCTestCase {
+  let defaultServerTLSConfiguration = TLSConfiguration.forServer(
+    certificateChain: [.certificate(SampleCertificate.server.certificate)],
+    privateKey: .privateKey(SamplePrivateKey.server),
+    applicationProtocols: GRPCApplicationProtocolIdentifier.allCases.map { $0.rawValue })
+
+  let defaultClientTLSConfiguration = TLSConfiguration.forClient(
+    trustRoots: .certificates([SampleCertificate.ca.certificate]),
+    certificateChain: [.certificate(SampleCertificate.client.certificate)],
+    privateKey: .privateKey(SamplePrivateKey.client),
+    applicationProtocols: GRPCApplicationProtocolIdentifier.allCases.map { $0.rawValue })
+
+  var defaultTestTimeout: TimeInterval = 1.0
+
+  var clientEventLoopGroup: EventLoopGroup!
+  var serverEventLoopGroup: EventLoopGroup!
+  var server: GRPCServer!
+  var port: Int!
+
+  func makeClientConnection(
+    configuration: TLSConfiguration,
+    hostOverride: String? = SampleCertificate.server.commonName
+  ) throws -> EventLoopFuture<GRPCClientConnection> {
+    return try GRPCClientConnection.start(
+      host: "localhost",
+      port: self.port,
+      eventLoopGroup: self.clientEventLoopGroup,
+      tls: .custom(try NIOSSLContext(configuration: configuration)),
+      hostOverride: hostOverride)
+  }
+
+  func makeClientConnectionExpectation() -> XCTestExpectation {
+    return self.expectation(description: "EventLoopFuture<GRPCClientConnection> resolved")
+  }
+
+  override func setUp() {
+    self.serverEventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
+    self.server = try! GRPCServer.start(
+      hostname: "localhost",
+      port: 0,
+      eventLoopGroup: self.serverEventLoopGroup,
+      serviceProviders: [EchoProviderNIO()],
+      errorDelegate: nil,
+      tls: .custom(try NIOSSLContext(configuration: defaultServerTLSConfiguration))
+    ).wait()
+
+    self.port = self.server.channel.localAddress?.port
+
+    self.clientEventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
+    // Delay the client connection creation until the test.
+  }
+
+  override func tearDown() {
+    self.port = nil
+
+    XCTAssertNoThrow(try self.clientEventLoopGroup.syncShutdownGracefully())
+    self.clientEventLoopGroup = nil
+
+    XCTAssertNoThrow(try self.server.close().wait())
+    XCTAssertNoThrow(try self.serverEventLoopGroup.syncShutdownGracefully())
+    self.server = nil
+    self.serverEventLoopGroup = nil
+  }
+
+  func testClientConnectionFailsWhenProtocolCanNotBeNegotiated() throws {
+    var configuration = defaultClientTLSConfiguration
+    configuration.applicationProtocols = ["not-h2", "not-grpc-ext"]
+
+    let connection = try self.makeClientConnection(configuration: configuration)
+    let connectionExpectation = self.makeClientConnectionExpectation()
+
+    connection.assertError(fulfill: connectionExpectation) { error in
+      let clientError = (error as? GRPCError)?.error as? GRPCClientError
+      XCTAssertEqual(clientError, .applicationLevelProtocolNegotiationFailed)
+    }
+
+    self.wait(for: [connectionExpectation], timeout: self.defaultTestTimeout)
+  }
+
+  func testClientConnectionFailsWhenServerIsUnknown() throws {
+    var configuration = defaultClientTLSConfiguration
+    configuration.trustRoots = .certificates([])
+
+    let connection = try self.makeClientConnection(configuration: configuration)
+    let connectionExpectation = self.makeClientConnectionExpectation()
+
+    connection.assertError(fulfill: connectionExpectation) { error in
+      guard case .some(.handshakeFailed(.sslError)) = error as? NIOSSLError else {
+        XCTFail("Expected NIOSSLError.handshakeFailed(BoringSSL.sslError) but got \(error)")
+        return
+      }
+    }
+
+    self.wait(for: [connectionExpectation], timeout: self.defaultTestTimeout)
+  }
+
+  func testClientConnectionFailsWhenHostnameIsNotValid() throws {
+    let connection = try self.makeClientConnection(
+      configuration: self.defaultClientTLSConfiguration,
+      hostOverride: "not-the-server-hostname")
+
+    let connectionExpectation = self.makeClientConnectionExpectation()
+
+    connection.assertError(fulfill: connectionExpectation) { error in
+      guard case .some(.unableToValidateCertificate) = error as? NIOSSLError else {
+        XCTFail("Expected NIOSSLError.unableToValidateCertificate but got \(error)")
+        return
+      }
+    }
+
+    self.wait(for: [connectionExpectation], timeout: self.defaultTestTimeout)
+  }
+}

+ 12 - 0
Tests/SwiftGRPCNIOTests/XCTestManifests.swift

@@ -116,6 +116,17 @@ extension NIOClientCancellingTests {
     ]
 }
 
+extension NIOClientTLSFailureTests {
+    // DO NOT MODIFY: This is autogenerated, use:
+    //   `swift test --generate-linuxmain`
+    // to regenerate.
+    static let __allTests__NIOClientTLSFailureTests = [
+        ("testClientConnectionFailsWhenHostnameIsNotValid", testClientConnectionFailsWhenHostnameIsNotValid),
+        ("testClientConnectionFailsWhenProtocolCanNotBeNegotiated", testClientConnectionFailsWhenProtocolCanNotBeNegotiated),
+        ("testClientConnectionFailsWhenServerIsUnknown", testClientConnectionFailsWhenServerIsUnknown),
+    ]
+}
+
 extension NIOClientTimeoutTests {
     // DO NOT MODIFY: This is autogenerated, use:
     //   `swift test --generate-linuxmain`
@@ -247,6 +258,7 @@ public func __allTests() -> [XCTestCaseEntry] {
         testCase(HTTP1ToRawGRPCServerCodecTests.__allTests__HTTP1ToRawGRPCServerCodecTests),
         testCase(LengthPrefixedMessageReaderTests.__allTests__LengthPrefixedMessageReaderTests),
         testCase(NIOClientCancellingTests.__allTests__NIOClientCancellingTests),
+        testCase(NIOClientTLSFailureTests.__allTests__NIOClientTLSFailureTests),
         testCase(NIOClientTimeoutTests.__allTests__NIOClientTimeoutTests),
         testCase(NIOFunctionalTestsAnonymousClient.__allTests__NIOFunctionalTestsAnonymousClient),
         testCase(NIOFunctionalTestsInsecureTransport.__allTests__NIOFunctionalTestsInsecureTransport),