Browse Source

Avoid protocol check after TLS handshake (#685)

Motivation:

When tls handshake is over there is no need to raise a failure if no
protocol has been negotiated.

Modifications:

The ```TLSVerificationHandler.userInboundEventTriggered``` method has
been updated so that it always succeed. However it will notice when a
protocol could not be negotiated

Result:

This change allows a successful connection even when a protocol is not
negotiated. By doing so we are more consistent with other grpc
implementations.
William Izzo 6 years ago
parent
commit
2d9158bd52

+ 8 - 12
Sources/GRPC/TLSVerificationHandler.swift

@@ -31,8 +31,7 @@ internal enum GRPCApplicationProtocolIdentifier: String, CaseIterable {
 /// and that the negotiated application protocol is valid.
 ///
 /// The handler holds a promise which is succeeded on successful verification of the negotiated
-/// application protocol and failed if any error is received by this handler or an invalid
-/// application protocol was negotiated.
+/// application protocol and failed if any error is received by this handler.
 ///
 /// Users of this handler should rely on the `verification` future held by this instance.
 ///
@@ -44,9 +43,8 @@ internal class TLSVerificationHandler: ChannelInboundHandler, RemovableChannelHa
   private var verificationPromise: EventLoopPromise<Void>!
 
   /// A future which is fulfilled when the state of the TLS handshake is known. If the handshake
-  /// was successful and the negotiated application protocol is valid then the future is succeeded.
-  /// If an error occurred or the application protocol is not valid then the future will have been
-  /// failed.
+  /// was successful then the future is succeeded.
+  /// If an error occurred the future will have been failed.
   ///
   /// - Important: The promise associated with this future is created in `handlerAdded(context:)`,
   ///   and as such must _not_ be accessed before the handler has be added to a pipeline.
@@ -85,14 +83,12 @@ internal class TLSVerificationHandler: ChannelInboundHandler, RemovableChannelHa
         return
     }
 
-    self.logger.debug("TLS handshake completed, negotiated protocol: \(String(describing: negotiatedProtocol))")
-    if let proto = negotiatedProtocol, GRPCApplicationProtocolIdentifier(rawValue: proto) != nil {
-      self.logger.debug("negotiated application protocol is valid")
-      self.verificationPromise.succeed(())
+    if let proto = negotiatedProtocol {
+      self.logger.debug("TLS handshake completed, negotiated protocol: \(proto)")
     } else {
-      self.logger.error("negotiated application protocol is invalid: \(String(describing: negotiatedProtocol))")
-      let error = GRPCError.client(.applicationLevelProtocolNegotiationFailed)
-      self.verificationPromise.fail(error)
+      self.logger.debug("TLS handshake completed, no protocol negotiated")
     }
+    
+    self.verificationPromise.succeed(())
   }
 }

+ 0 - 28
Tests/GRPCTests/ClientTLSFailureTests.swift

@@ -100,34 +100,6 @@ class ClientTLSFailureTests: GRPCTestCase {
     self.serverEventLoopGroup = nil
   }
 
-  func testClientConnectionFailsWhenProtocolCanNotBeNegotiated() throws {
-    let shutdownExpectation = self.expectation(description: "client shutdown")
-    let errorExpectation = self.expectation(description: "error")
-
-    // We use the underlying configuration because `applicationProtocols` is not user-configurable
-    // via `Configuration.TLS`.
-    var tlsConfiguration = self.defaultClientTLSConfiguration.configuration
-    tlsConfiguration.applicationProtocols = ["not-h2", "not-grpc-exp"]
-
-    let tls = ClientConnection.Configuration.TLS(
-      configuration: tlsConfiguration,
-      hostnameOverride: self.defaultClientTLSConfiguration.hostnameOverride
-    )
-
-    var configuration = self.makeClientConfiguration(tls: tls)
-    let errorRecorder = ErrorRecordingDelegate(expectation: errorExpectation)
-    configuration.errorDelegate = errorRecorder
-
-    let stateChangeDelegate = ConnectivityStateCollectionDelegate(shutdown: shutdownExpectation)
-    configuration.connectivityStateDelegate = stateChangeDelegate
-
-    _ = ClientConnection(configuration: configuration)
-
-    self.wait(for: [shutdownExpectation, errorExpectation], timeout: self.defaultTestTimeout)
-
-    let clientErrors = errorRecorder.errors.compactMap { $0 as? GRPCClientError }
-    XCTAssertEqual(clientErrors, [.applicationLevelProtocolNegotiationFailed])
-  }
 
   func testClientConnectionFailsWhenServerIsUnknown() throws {
     let shutdownExpectation = self.expectation(description: "client shutdown")

+ 59 - 0
Tests/GRPCTests/TLSVerificationHandlerTests.swift

@@ -0,0 +1,59 @@
+/*
+ * Copyright 2020, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import Foundation
+@testable import GRPC
+import NIO
+import NIOSSL
+import XCTest
+import NIOTLS
+
+class TLSVerificationHandlerTests: GRPCTestCase {
+  func testTLSValidationSucceededWithUnspecifiedProtocol() {
+    let expectation = self.expectation(description: "tls handshake success")
+    let tlsVerificationHandler = TLSVerificationHandler()
+    let handshakeEvent = TLSUserEvent.handshakeCompleted(negotiatedProtocol: nil)
+    let channel = EmbeddedChannel(handler: tlsVerificationHandler)
+    channel.pipeline.fireUserInboundEventTriggered(handshakeEvent)
+    tlsVerificationHandler.verification.assertSuccess(fulfill: expectation)
+    self.wait(for: [expectation], timeout: 1.0)
+  }
+  
+  func testTLSValidationSucceededWithGRPCApplicationProtocols() {
+    var expectations = [XCTestExpectation]()
+    
+    GRPCApplicationProtocolIdentifier.allCases.forEach {
+      let exp = self.expectation(description: "tls \(String(describing:$0)) protocol success")
+      expectations.append(exp)
+      let tlsVerificationHandler = TLSVerificationHandler()
+      let channel = EmbeddedChannel(handler: tlsVerificationHandler)
+      let handshakeEvent = TLSUserEvent.handshakeCompleted(negotiatedProtocol: $0.rawValue)
+      channel.pipeline.fireUserInboundEventTriggered(handshakeEvent)
+      tlsVerificationHandler.verification.assertSuccess(fulfill: exp)
+    }
+    
+    self.wait(for: expectations, timeout: 1.0)
+  }
+  
+  func testTLSValidationSucceededWithCustomProtocol() {
+    let expectation = self.expectation(description: "tls custom protocol success")
+    let tlsVerificationHandler = TLSVerificationHandler()
+    let handshakeEvent = TLSUserEvent.handshakeCompleted(negotiatedProtocol: "some-protocol")
+    let channel = EmbeddedChannel(handler: tlsVerificationHandler)
+    channel.pipeline.fireUserInboundEventTriggered(handshakeEvent)
+    tlsVerificationHandler.verification.assertSuccess(fulfill: expectation)
+    self.wait(for: [expectation], timeout: 1.0)
+  }
+}

+ 12 - 1
Tests/GRPCTests/XCTestManifests.swift

@@ -57,7 +57,6 @@ extension ClientTLSFailureTests {
     // to regenerate.
     static let __allTests__ClientTLSFailureTests = [
         ("testClientConnectionFailsWhenHostnameIsNotValid", testClientConnectionFailsWhenHostnameIsNotValid),
-        ("testClientConnectionFailsWhenProtocolCanNotBeNegotiated", testClientConnectionFailsWhenProtocolCanNotBeNegotiated),
         ("testClientConnectionFailsWhenServerIsUnknown", testClientConnectionFailsWhenServerIsUnknown),
     ]
 }
@@ -582,6 +581,17 @@ extension StreamingRequestClientCallTests {
     ]
 }
 
+extension TLSVerificationHandlerTests {
+    // DO NOT MODIFY: This is autogenerated, use:
+    //   `swift test --generate-linuxmain`
+    // to regenerate.
+    static let __allTests__TLSVerificationHandlerTests = [
+        ("testTLSValidationSucceededWithCustomProtocol", testTLSValidationSucceededWithCustomProtocol),
+        ("testTLSValidationSucceededWithGRPCApplicationProtocols", testTLSValidationSucceededWithGRPCApplicationProtocols),
+        ("testTLSValidationSucceededWithUnspecifiedProtocol", testTLSValidationSucceededWithUnspecifiedProtocol),
+    ]
+}
+
 public func __allTests() -> [XCTestCaseEntry] {
     return [
         testCase(AnyServiceClientTests.__allTests__AnyServiceClientTests),
@@ -621,6 +631,7 @@ public func __allTests() -> [XCTestCaseEntry] {
         testCase(ServerWebTests.__allTests__ServerWebTests),
         testCase(StopwatchTests.__allTests__StopwatchTests),
         testCase(StreamingRequestClientCallTests.__allTests__StreamingRequestClientCallTests),
+        testCase(TLSVerificationHandlerTests.__allTests__TLSVerificationHandlerTests),
     ]
 }
 #endif