ClientTLSFailureTests.swift 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. /*
  2. * Copyright 2019, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import EchoImplementation
  17. import EchoModel
  18. @testable import GRPC
  19. import GRPCSampleData
  20. import Logging
  21. import NIO
  22. import NIOConcurrencyHelpers
  23. import NIOSSL
  24. import XCTest
  /// A `ClientErrorDelegate` that remembers every error it is given and fulfills `expectation`
  /// once per error, so tests can both wait for errors and inspect them afterwards.
  class ErrorRecordingDelegate: ClientErrorDelegate {
    /// Fulfilled once for every call to `didCatchError`.
    var expectation: XCTestExpectation

    /// Serialises access to `recordedErrors` between `didCatchError` and the `errors` accessor.
    private let lock = Lock()

    /// Every error passed to `didCatchError`, in the order it was received.
    private var recordedErrors: [Error] = []

    /// A copy of the errors recorded so far, taken while holding `lock`.
    internal var errors: [Error] {
      return self.lock.withLock { self.recordedErrors }
    }

    init(expectation: XCTestExpectation) {
      self.expectation = expectation
    }

    func didCatchError(_ error: Error, logger: Logger, file: StaticString, line: Int) {
      self.lock.withLockVoid { self.recordedErrors.append(error) }
      // Fulfill outside the lock; only the error storage needs protecting.
      self.expectation.fulfill()
    }
  }
  /// Verifies that a client connection fails, and surfaces the underlying TLS error through its
  /// error delegate, when the TLS handshake with the local test server cannot succeed.
  class ClientTLSFailureTests: GRPCTestCase {
    let defaultServerTLSConfiguration = Server.Configuration.TLS(
      certificateChain: [.certificate(SampleCertificate.server.certificate)],
      privateKey: .privateKey(SamplePrivateKey.server)
    )

    /// Presents the sample client certificate, trusts the sample CA and expects the sample server's
    /// common name. Each test breaks one of these, or adds a failing verification callback.
    let defaultClientTLSConfiguration = ClientConnection.Configuration.TLS(
      certificateChain: [.certificate(SampleCertificate.client.certificate)],
      privateKey: .privateKey(SamplePrivateKey.client),
      trustRoots: .certificates([SampleCertificate.ca.certificate]),
      hostnameOverride: SampleCertificate.server.commonName
    )

    /// How long to wait for the error delegate to be called the expected number of times.
    var defaultTestTimeout: TimeInterval = 1.0

    var clientEventLoopGroup: EventLoopGroup!
    var serverEventLoopGroup: EventLoopGroup!
    var server: Server!
    /// The port the server bound to in `setUp()`.
    var port: Int!

    /// Returns a configuration targeting the local test server on `self.port`, using `tls` and
    /// never retrying a failed connection attempt.
    func makeClientConfiguration(
      tls: ClientConnection.Configuration.TLS
    ) -> ClientConnection.Configuration {
      return .init(
        target: .hostAndPort("localhost", self.port),
        eventLoopGroup: self.clientEventLoopGroup,
        tls: tls,
        // No need to retry connecting.
        connectionBackoff: nil,
        backgroundActivityLogger: self.clientLogger
      )
    }

    func makeClientConnectionExpectation() -> XCTestExpectation {
      return self.expectation(description: "EventLoopFuture<ClientConnection> resolved")
    }

    override func setUp() {
      super.setUp()
      self.serverEventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
      // Bind to port 0 and read back whichever port was actually assigned.
      self.server = try! Server.secure(
        group: self.serverEventLoopGroup,
        certificateChain: [SampleCertificate.server.certificate],
        privateKey: SamplePrivateKey.server
      ).withServiceProviders([EchoProvider()])
        .withLogger(self.serverLogger)
        .bind(host: "localhost", port: 0)
        .wait()
      self.port = self.server.channel.localAddress?.port
      self.clientEventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
      // Delay the client connection creation until the test.
    }

    override func tearDown() {
      self.port = nil

      XCTAssertNoThrow(try self.clientEventLoopGroup.syncShutdownGracefully())
      self.clientEventLoopGroup = nil

      XCTAssertNoThrow(try self.server.close().wait())
      XCTAssertNoThrow(try self.serverEventLoopGroup.syncShutdownGracefully())
      self.server = nil
      self.serverEventLoopGroup = nil

      super.tearDown()
    }

    // MARK: - Helpers

    /// Starts an RPC on a new connection configured with `tls` and waits for the connection to fail.
    ///
    /// Waits for the error delegate to be called twice. Also waits for the connectivity state to
    /// move `idle -> connecting -> shutdown`, asserting that sequence.
    ///
    /// - Parameters:
    ///   - tls: A client TLS configuration expected to make the handshake fail.
    ///   - file: Call-site file, used to attribute assertion failures to the calling test.
    ///   - line: Call-site line, used to attribute assertion failures to the calling test.
    /// - Returns: The errors caught by the connection's error delegate, in the order received.
    private func startRPCExpectingTLSFailure(
      tls: ClientConnection.Configuration.TLS,
      file: StaticString = #file,
      line: UInt = #line
    ) -> [Error] {
      let errorExpectation = self.expectation(description: "error")
      // 2 errors: one for the failed handshake, and another for failing the ready-channel promise
      // (because the handshake failed).
      errorExpectation.expectedFulfillmentCount = 2

      var configuration = self.makeClientConfiguration(tls: tls)

      let errorRecorder = ErrorRecordingDelegate(expectation: errorExpectation)
      configuration.errorDelegate = errorRecorder

      let stateChangeDelegate = RecordingConnectivityDelegate()
      stateChangeDelegate.expectChanges(2) { changes in
        XCTAssertEqual(changes, [
          Change(from: .idle, to: .connecting),
          Change(from: .connecting, to: .shutdown),
        ], file: file, line: line)
      }
      configuration.connectivityStateDelegate = stateChangeDelegate

      // Start an RPC to trigger creating a channel.
      let echo = Echo_EchoClient(channel: ClientConnection(configuration: configuration))
      _ = echo.get(.with { $0.text = "foo" })

      self.wait(for: [errorExpectation], timeout: self.defaultTestTimeout)
      stateChangeDelegate.waitForExpectedChanges(timeout: .seconds(5))

      return errorRecorder.errors
    }

    /// Asserts that the first of `errors` is `NIOSSLError.handshakeFailed` with an `.sslError` cause.
    /// On failure, the message includes every error that was actually recorded.
    private func assertFirstErrorIsSSLHandshakeFailure(
      _ errors: [Error],
      file: StaticString = #file,
      line: UInt = #line
    ) {
      if let nioSSLError = errors.first as? NIOSSLError,
        case .handshakeFailed(.sslError) = nioSSLError {
        // Expected case.
      } else {
        XCTFail(
          "Expected NIOSSLError.handshakeFailed(BoringSSL.sslError) but got \(errors)",
          file: file,
          line: line
        )
      }
    }

    // MARK: - Tests

    func testClientConnectionFailsWhenServerIsUnknown() throws {
      var tls = self.defaultClientTLSConfiguration
      // Trust no roots, so the server's certificate chain cannot be verified.
      tls.trustRoots = .certificates([])

      let errors = self.startRPCExpectingTLSFailure(tls: tls)
      self.assertFirstErrorIsSSLHandshakeFailure(errors)
    }

    func testClientConnectionFailsWhenHostnameIsNotValid() throws {
      var tls = self.defaultClientTLSConfiguration
      // Expect a hostname the server's certificate was not issued for.
      tls.hostnameOverride = "not-the-server-hostname"

      let errors = self.startRPCExpectingTLSFailure(tls: tls)
      if let nioSSLError = errors.first as? NIOSSLExtraError {
        XCTAssertEqual(nioSSLError, .failedToValidateHostname)
      } else {
        XCTFail("Expected NIOSSLExtraError.failedToValidateHostname but got \(errors)")
      }
    }

    func testClientConnectionFailsWhenCertificateValidationDenied() throws {
      let tlsConfiguration = ClientConnection.Configuration.TLS(
        certificateChain: [.certificate(SampleCertificate.client.certificate)],
        privateKey: .privateKey(SamplePrivateKey.client),
        trustRoots: .certificates([SampleCertificate.ca.certificate]),
        hostnameOverride: SampleCertificate.server.commonName,
        customVerificationCallback: { _, promise in
          // The certificate validation is forced to fail
          promise.fail(NIOSSLError.unableToValidateCertificate)
        }
      )

      let errors = self.startRPCExpectingTLSFailure(tls: tlsConfiguration)
      self.assertFirstErrorIsSSLHandshakeFailure(errors)
    }
  }