ClientTLSTests.swift 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
/*
 * Copyright 2019, gRPC Authors All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
  16. import EchoImplementation
  17. import EchoModel
  18. import Foundation
  19. import GRPC
  20. import GRPCSampleData
  21. import NIOCore
  22. import NIOPosix
  23. import NIOSSL
  24. import XCTest
  /// TLS tests for `ClientConnection`: connecting with and without a server hostname override,
  /// connecting with certificate verification disabled, and checking that a hostname override is
  /// used as the ":authority" pseudo-header.
  class ClientTLSHostnameOverrideTests: GRPCTestCase {
    var eventLoopGroup: EventLoopGroup!
    var server: Server!
    var connection: ClientConnection!

    override func setUp() {
      super.setUp()
      self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
    }

    override func tearDown() {
      // A test can fail before `server` or `connection` is assigned (e.g. `bind` throws, or the
      // server port cannot be read), leaving these implicitly unwrapped optionals nil. Optional
      // chaining keeps teardown from trapping and hiding the original failure.
      XCTAssertNoThrow(try self.server?.close().wait())
      XCTAssertNoThrow(try self.connection?.close().wait())
      XCTAssertNoThrow(try self.eventLoopGroup.syncShutdownGracefully())
      super.tearDown()
    }

    /// Starts a TLS server on "localhost" (on an ephemeral port), stores it in `self.server`, and
    /// returns the port it is listening on.
    ///
    /// - Parameters:
    ///   - certificate: The certificate the server presents (used as a single-element chain).
    ///   - privateKey: The private key matching `certificate`.
    ///   - trustRoots: Trust roots for the server's TLS configuration; when `nil`, the builder's
    ///     default is left untouched.
    ///   - provider: The service provider the server hosts; defaults to `EchoProvider`.
    /// - Returns: The port the server bound to.
    /// - Throws: If binding fails, or if the bound port cannot be read from the server channel.
    private func startServer(
      certificate: NIOSSLCertificate,
      privateKey: NIOSSLPrivateKey,
      trustRoots: NIOSSLTrustRoots? = nil,
      provider: CallHandlerProvider = EchoProvider()
    ) throws -> Int {
      var builder = Server.usingTLSBackedByNIOSSL(
        on: self.eventLoopGroup,
        certificateChain: [certificate],
        privateKey: privateKey
      )
      if let trustRoots = trustRoots {
        builder = builder.withTLS(trustRoots: trustRoots)
      }
      self.server = try builder
        .withServiceProviders([provider])
        .withLogger(self.serverLogger)
        .bind(host: "localhost", port: 0)
        .wait()
      return try XCTUnwrap(self.server.channel.localAddress?.port, "could not get server port")
    }

    /// Makes a unary `get` call over `self.connection` and checks that the echoed text and the
    /// final status (`.ok`) are as expected.
    func doTestUnary() throws {
      let client = Echo_EchoClient(
        channel: self.connection,
        defaultCallOptions: self.callOptionsWithLogger
      )
      let get = client.get(.with { $0.text = "foo" })
      let response = try get.response.wait()
      XCTAssertEqual(response.text, "Swift echo get: foo")
      let status = try get.status.wait()
      XCTAssertEqual(status.code, .ok)
    }

    /// The server runs on localhost but presents a certificate for example.com; the client
    /// connects to localhost with a server hostname override of "example.com".
    func testTLSWithHostnameOverride() throws {
      let port = try self.startServer(
        certificate: SampleCertificate.exampleServer.certificate,
        privateKey: SamplePrivateKey.exampleServer,
        trustRoots: .certificates([SampleCertificate.ca.certificate])
      )
      self.connection = ClientConnection.usingTLSBackedByNIOSSL(on: self.eventLoopGroup)
        .withTLS(trustRoots: .certificates([SampleCertificate.ca.certificate]))
        .withTLS(serverHostnameOverride: "example.com")
        .withBackgroundActivityLogger(self.clientLogger)
        .connect(host: "localhost", port: port)
      try self.doTestUnary()
    }

    /// The server runs on localhost and presents a certificate for localhost; the client
    /// connects without any hostname override.
    func testTLSWithoutHostnameOverride() throws {
      let port = try self.startServer(
        certificate: SampleCertificate.server.certificate,
        privateKey: SamplePrivateKey.server,
        trustRoots: .certificates([SampleCertificate.ca.certificate])
      )
      self.connection = ClientConnection.usingTLSBackedByNIOSSL(on: self.eventLoopGroup)
        .withTLS(trustRoots: .certificates([SampleCertificate.ca.certificate]))
        .withBackgroundActivityLogger(self.clientLogger)
        .connect(host: "localhost", port: port)
      try self.doTestUnary()
    }

    /// The client configures an empty set of trust roots and disables certificate verification;
    /// the unary call is still expected to succeed.
    func testTLSWithNoCertificateVerification() throws {
      let port = try self.startServer(
        certificate: SampleCertificate.server.certificate,
        privateKey: SamplePrivateKey.server
      )
      self.connection = ClientConnection.usingTLSBackedByNIOSSL(on: self.eventLoopGroup)
        .withTLS(trustRoots: .certificates([]))
        .withTLS(certificateVerification: .none)
        .withBackgroundActivityLogger(self.clientLogger)
        .connect(host: "localhost", port: port)
      try self.doTestUnary()
    }

    /// Validates that when supplied with a server hostname override, the client uses it as the
    /// ":authority" pseudo-header. The check itself happens server-side in
    /// `AuthorityCheckingEcho`.
    func testAuthorityUsesTLSHostnameOverride() throws {
      let port = try self.startServer(
        certificate: SampleCertificate.exampleServer.certificate,
        privateKey: SamplePrivateKey.exampleServer,
        trustRoots: .certificates([SampleCertificate.ca.certificate]),
        provider: AuthorityCheckingEcho()
      )
      self.connection = ClientConnection.usingTLSBackedByNIOSSL(on: self.eventLoopGroup)
        .withTLS(trustRoots: .certificates([SampleCertificate.ca.certificate]))
        .withTLS(serverHostnameOverride: "example.com")
        .withBackgroundActivityLogger(self.clientLogger)
        .connect(host: "localhost", port: port)
      try self.doTestUnary()
    }
  }
  /// An Echo service whose unary `get` inspects the ":authority" pseudo-header before echoing.
  ///
  /// Only `get` is implemented; the streaming RPCs trap if invoked.
  private class AuthorityCheckingEcho: Echo_EchoProvider {
    var interceptors: Echo_EchoServerInterceptorFactoryProtocol?

    /// Echoes `request.text` prefixed with "Swift echo get: ".
    ///
    /// If the request carries no ":authority" header, the call fails with `.failedPrecondition`.
    /// Otherwise the header is asserted to equal the example server certificate's common name
    /// (and to not be "localhost") before the response is sent.
    func get(
      request: Echo_EchoRequest,
      context: StatusOnlyCallContext
    ) -> EventLoopFuture<Echo_EchoResponse> {
      if let authority = context.headers.first(name: ":authority") {
        XCTAssertEqual(authority, SampleCertificate.exampleServer.commonName)
        XCTAssertNotEqual(authority, "localhost")

        var reply = Echo_EchoResponse()
        reply.text = "Swift echo get: \(request.text)"
        return context.eventLoop.makeSucceededFuture(reply)
      }

      let missingAuthority = GRPCStatus(
        code: .failedPrecondition,
        message: "Missing ':authority' pseudo header"
      )
      return context.eventLoop.makeFailedFuture(missingAuthority)
    }

    /// Server-streaming RPC; unused by these tests.
    func expand(
      request: Echo_EchoRequest,
      context: StreamingResponseCallContext<Echo_EchoResponse>
    ) -> EventLoopFuture<GRPCStatus> {
      preconditionFailure("Not implemented")
    }

    /// Client-streaming RPC; unused by these tests.
    func collect(
      context: UnaryResponseCallContext<Echo_EchoResponse>
    ) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
      preconditionFailure("Not implemented")
    }

    /// Bidirectional-streaming RPC; unused by these tests.
    func update(
      context: StreamingResponseCallContext<Echo_EchoResponse>
    ) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
      preconditionFailure("Not implemented")
    }
  }