// ClientTLSTests.swift
  1. import Foundation
  2. import GRPC
  3. import GRPCSampleData
  4. import NIO
  5. import NIOSSL
  6. import XCTest
  /// Exercises a unary Echo RPC over TLS between a client and a server that both run on
  /// localhost, once with a client-side hostname override and once without.
  class ClientTLSHostnameOverrideTests: GRPCTestCase {
    var eventLoopGroup: EventLoopGroup!
    var server: Server!
    var connection: ClientConnection!

    override func setUp() {
      super.setUp()
      self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
    }

    override func tearDown() {
      // `server` and `connection` are assigned part-way through each test, so they are still nil
      // when a test throws or returns before reaching that point. Optional chaining skips the
      // close in that case instead of trapping on the implicitly unwrapped optional.
      XCTAssertNoThrow(try self.server?.close().wait())
      XCTAssertNoThrow(try self.connection?.close().wait())
      XCTAssertNoThrow(try self.eventLoopGroup.syncShutdownGracefully())

      // Release this class's resources first; the superclass tears down last, mirroring `setUp`.
      super.tearDown()
    }

    /// Starts a server on localhost offering `EchoProvider`, blocking until it has started.
    ///
    /// The server is asked to bind to port 0; the tests read the port that was actually bound
    /// from the server channel's local address.
    ///
    /// - Parameter tls: The TLS configuration the server should use.
    /// - Returns: The started server.
    /// - Throws: Any error raised while starting the server.
    func makeEchoServer(tls: Server.Configuration.TLS) throws -> Server {
      let configuration: Server.Configuration = .init(
        target: .hostAndPort("localhost", 0),
        eventLoopGroup: self.eventLoopGroup,
        serviceProviders: [EchoProvider()],
        tls: tls
      )

      return try Server.start(configuration: configuration).wait()
    }

    /// Creates a client connection targeting localhost on the given port.
    ///
    /// - Parameters:
    ///   - port: The port the server is listening on.
    ///   - tls: The TLS configuration the client should use.
    /// - Returns: A connection built from that configuration.
    func makeConnection(port: Int, tls: ClientConnection.Configuration.TLS) -> ClientConnection {
      let configuration: ClientConnection.Configuration = .init(
        target: .hostAndPort("localhost", port),
        eventLoopGroup: self.eventLoopGroup,
        tls: tls
      )

      return ClientConnection(configuration: configuration)
    }

    /// Issues a unary `get` over `self.connection` and asserts on the echoed text and the status.
    ///
    /// - Throws: Any error raised while waiting for the response or the status.
    func doTestUnary() throws {
      let client = Echo_EchoServiceClient(connection: self.connection)
      let get = client.get(.with { $0.text = "foo" })

      let response = try get.response.wait()
      XCTAssertEqual(response.text, "Swift echo get: foo")

      let status = try get.status.wait()
      XCTAssertEqual(status.code, .ok)
    }

    func testTLSWithHostnameOverride() throws {
      // Run a server presenting a certificate for example.com on localhost.
      let serverTLS: Server.Configuration.TLS = .init(
        certificateChain: [.certificate(SampleCertificate.exampleServer.certificate)],
        privateKey: .privateKey(SamplePrivateKey.exampleServer),
        trustRoots: .certificates([SampleCertificate.ca.certificate])
      )

      self.server = try self.makeEchoServer(tls: serverTLS)

      guard let port = self.server.channel.localAddress?.port else {
        XCTFail("could not get server port")
        return
      }

      // The client connects to localhost but sets "example.com" as the hostname override,
      // matching the certificate the server presents.
      let clientTLS: ClientConnection.Configuration.TLS = .init(
        trustRoots: .certificates([SampleCertificate.ca.certificate]),
        hostnameOverride: "example.com"
      )

      self.connection = self.makeConnection(port: port, tls: clientTLS)
      try self.doTestUnary()
    }

    func testTLSWithoutHostnameOverride() throws {
      // Run a server presenting a certificate for localhost on localhost.
      let serverTLS: Server.Configuration.TLS = .init(
        certificateChain: [.certificate(SampleCertificate.server.certificate)],
        privateKey: .privateKey(SamplePrivateKey.server),
        trustRoots: .certificates([SampleCertificate.ca.certificate])
      )

      self.server = try self.makeEchoServer(tls: serverTLS)

      guard let port = self.server.channel.localAddress?.port else {
        XCTFail("could not get server port")
        return
      }

      // No hostname override is configured for this client.
      let clientTLS: ClientConnection.Configuration.TLS = .init(
        trustRoots: .certificates([SampleCertificate.ca.certificate])
      )

      self.connection = self.makeConnection(port: port, tls: clientTLS)
      try self.doTestUnary()
    }
  }