ClientTLSFailureTests.swift 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. /*
  2. * Copyright 2019, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import EchoImplementation
  17. import EchoModel
  18. @testable import GRPC
  19. import GRPCSampleData
  20. import Logging
  21. import NIOConcurrencyHelpers
  22. import NIOCore
  23. import NIOPosix
  24. import NIOSSL
  25. import XCTest
  /// A `ClientErrorDelegate` that keeps every error it is given and fulfills
  /// `expectation` once per error.
  class ErrorRecordingDelegate: ClientErrorDelegate {
    /// Serializes access to `_errors`, which `didCatchError` appends to and
    /// `errors` reads.
    private let lock: Lock
    private var _errors: [Error] = []

    /// Fulfilled once every time `didCatchError` is invoked.
    var expectation: XCTestExpectation

    init(expectation: XCTestExpectation) {
      self.lock = Lock()
      self.expectation = expectation
    }

    /// A copy of the errors recorded so far, in the order they were appended.
    internal var errors: [Error] {
      self.lock.withLock { self._errors }
    }

    func didCatchError(_ error: Error, logger: Logger, file: StaticString, line: Int) {
      // Record under the lock first, then signal the waiting test.
      self.lock.withLockVoid { self._errors.append(error) }
      self.expectation.fulfill()
    }
  }
  /// Tests that a TLS-enabled `ClientConnection` reports an error and moves from `.connecting` to
  /// `.shutdown` when its TLS handshake with the test server cannot succeed.
  class ClientTLSFailureTests: GRPCTestCase {
    // NOTE(review): not referenced by the code in this class; `setUp` builds the server with
    // `Server.usingTLSBackedByNIOSSL` instead.
    let defaultServerTLSConfiguration = GRPCTLSConfiguration.makeServerConfigurationBackedByNIOSSL(
      certificateChain: [.certificate(SampleCertificate.server.certificate)],
      privateKey: .privateKey(SamplePrivateKey.server)
    )

    /// A client TLS configuration that trusts the sample CA and expects the server's common name.
    /// Individual tests mutate a copy of it to provoke a specific handshake failure.
    let defaultClientTLSConfiguration = GRPCTLSConfiguration.makeClientConfigurationBackedByNIOSSL(
      certificateChain: [.certificate(SampleCertificate.client.certificate)],
      privateKey: .privateKey(SamplePrivateKey.client),
      trustRoots: .certificates([SampleCertificate.ca.certificate]),
      hostnameOverride: SampleCertificate.server.commonName
    )

    /// Timeout, in seconds, for the error-delegate expectation.
    var defaultTestTimeout: TimeInterval = 1.0

    var clientEventLoopGroup: EventLoopGroup!
    var serverEventLoopGroup: EventLoopGroup!
    var server: Server!
    /// The port the server bound to in `setUp`.
    var port: Int!

    /// Builds a client configuration targeting the test server with the given TLS settings.
    ///
    /// - Parameter tls: The client TLS configuration to use.
    /// - Returns: A configuration with connection backoff disabled, so a failed connection attempt
    ///   is not retried.
    func makeClientConfiguration(
      tls: GRPCTLSConfiguration
    ) -> ClientConnection.Configuration {
      var configuration = ClientConnection.Configuration.default(
        target: .hostAndPort("localhost", self.port),
        eventLoopGroup: self.clientEventLoopGroup
      )

      configuration.tlsConfiguration = tls
      // No need to retry connecting.
      configuration.connectionBackoff = nil
      configuration.backgroundActivityLogger = self.clientLogger

      return configuration
    }

    func makeClientConnectionExpectation() -> XCTestExpectation {
      return self.expectation(description: "EventLoopFuture<ClientConnection> resolved")
    }

    override func setUp() {
      super.setUp()

      self.serverEventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
      self.server = try! Server.usingTLSBackedByNIOSSL(
        on: self.serverEventLoopGroup,
        certificateChain: [SampleCertificate.server.certificate],
        privateKey: SamplePrivateKey.server
      ).withServiceProviders([EchoProvider()])
        .withLogger(self.serverLogger)
        .bind(host: "localhost", port: 0)
        .wait()

      self.port = self.server.channel.localAddress?.port

      self.clientEventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
      // Delay the client connection creation until the test.
    }

    override func tearDown() {
      self.port = nil

      XCTAssertNoThrow(try self.clientEventLoopGroup.syncShutdownGracefully())
      self.clientEventLoopGroup = nil

      XCTAssertNoThrow(try self.server.close().wait())
      XCTAssertNoThrow(try self.serverEventLoopGroup.syncShutdownGracefully())
      self.server = nil
      self.serverEventLoopGroup = nil

      super.tearDown()
    }

    /// Creates a client connection using `tls` and starts an RPC to force a connection attempt.
    /// Waits for the error delegate to be called twice and for the connectivity state to go
    /// `.idle` -> `.connecting` -> `.shutdown`.
    ///
    /// The connection is closed before returning, so it does not outlive the test.
    ///
    /// - Parameter tls: A client TLS configuration expected to make the handshake fail.
    /// - Returns: The errors recorded by the connection's error delegate, in the order recorded.
    private func connectExpectingHandshakeFailure(tls: GRPCTLSConfiguration) -> [Error] {
      let errorExpectation = self.expectation(description: "error")
      // 2 errors: one for the failed handshake, and another for failing the ready-channel promise
      // (because the handshake failed).
      errorExpectation.expectedFulfillmentCount = 2

      var configuration = self.makeClientConfiguration(tls: tls)

      let errorRecorder = ErrorRecordingDelegate(expectation: errorExpectation)
      configuration.errorDelegate = errorRecorder

      let stateChangeDelegate = RecordingConnectivityDelegate()
      stateChangeDelegate.expectChanges(2) { changes in
        XCTAssertEqual(changes, [
          Change(from: .idle, to: .connecting),
          Change(from: .connecting, to: .shutdown),
        ])
      }
      configuration.connectivityStateDelegate = stateChangeDelegate

      // Keep a reference to the connection so it can be closed once the test has observed the
      // failure; previously it was created inline and never closed.
      let connection = ClientConnection(configuration: configuration)
      defer {
        XCTAssertNoThrow(try connection.close().wait())
      }

      // Start an RPC to trigger creating a channel.
      let echo = Echo_EchoClient(channel: connection)
      _ = echo.get(.with { $0.text = "foo" })

      self.wait(for: [errorExpectation], timeout: self.defaultTestTimeout)
      stateChangeDelegate.waitForExpectedChanges(timeout: .seconds(5))

      return errorRecorder.errors
    }

    func testClientConnectionFailsWhenServerIsUnknown() throws {
      // With no trust roots the server's certificate cannot be verified.
      var tls = self.defaultClientTLSConfiguration
      tls.updateNIOTrustRoots(to: .certificates([]))

      let errors = self.connectExpectingHandshakeFailure(tls: tls)

      if let nioSSLError = errors.first as? NIOSSLError,
        case .handshakeFailed(.sslError) = nioSSLError {
        // Expected case.
      } else {
        XCTFail("Expected NIOSSLError.handshakeFailed(BoringSSL.sslError)")
      }
    }

    func testClientConnectionFailsWhenHostnameIsNotValid() throws {
      // The expected hostname no longer matches the server certificate's common name.
      var tls = self.defaultClientTLSConfiguration
      tls.hostnameOverride = "not-the-server-hostname"

      let errors = self.connectExpectingHandshakeFailure(tls: tls)

      if let nioSSLError = errors.first as? NIOSSLExtraError {
        XCTAssertEqual(nioSSLError, .failedToValidateHostname)
        // Expected case.
      } else {
        XCTFail("Expected NIOSSLExtraError.failedToValidateHostname")
      }
    }

    func testClientConnectionFailsWhenCertificateValidationDenied() throws {
      let tlsConfiguration = GRPCTLSConfiguration.makeClientConfigurationBackedByNIOSSL(
        certificateChain: [.certificate(SampleCertificate.client.certificate)],
        privateKey: .privateKey(SamplePrivateKey.client),
        trustRoots: .certificates([SampleCertificate.ca.certificate]),
        hostnameOverride: SampleCertificate.server.commonName,
        customVerificationCallback: { _, promise in
          // The certificate validation is forced to fail
          promise.fail(NIOSSLError.unableToValidateCertificate)
        }
      )

      let errors = self.connectExpectingHandshakeFailure(tls: tlsConfiguration)

      if let nioSSLError = errors.first as? NIOSSLError,
        case .handshakeFailed(.sslError) = nioSSLError {
        // Expected case.
      } else {
        XCTFail("Expected NIOSSLError.handshakeFailed(BoringSSL.sslError)")
      }
    }
  }