main.swift 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. import Foundation
  2. import GRPC
  3. import NIO
  4. import NIOSSL
  5. import Commander
  6. struct ConnectionFactory {
  7. var configuration: ClientConnection.Configuration
  8. func makeConnection() -> ClientConnection {
  9. return ClientConnection(configuration: self.configuration)
  10. }
  11. func makeEchoClient() -> Echo_EchoServiceClient {
  12. return Echo_EchoServiceClient(connection: self.makeConnection())
  13. }
  14. }
  15. protocol Benchmark: class {
  16. func setUp() throws
  17. func tearDown() throws
  18. func run() throws
  19. }
  20. /// Tests unary throughput by sending requests on a single connection.
  21. ///
  22. /// Requests are sent in batches of (up-to) 100 requests. This is due to
  23. /// https://github.com/apple/swift-nio-http2/issues/87#issuecomment-483542401.
  24. class UnaryThroughput: Benchmark {
  25. let factory: ConnectionFactory
  26. let requests: Int
  27. let requestLength: Int
  28. var client: Echo_EchoServiceClient!
  29. var request: String!
  30. init(factory: ConnectionFactory, requests: Int, requestLength: Int) {
  31. self.factory = factory
  32. self.requests = requests
  33. self.requestLength = requestLength
  34. }
  35. func setUp() throws {
  36. self.client = self.factory.makeEchoClient()
  37. self.request = String(repeating: "0", count: self.requestLength)
  38. }
  39. func run() throws {
  40. let batchSize = 100
  41. for lowerBound in stride(from: 0, to: self.requests, by: batchSize) {
  42. let upperBound = min(lowerBound + batchSize, self.requests)
  43. let requests = (lowerBound..<upperBound).map { _ in
  44. client.get(Echo_EchoRequest.with { $0.text = self.request }).response
  45. }
  46. try EventLoopFuture.andAllSucceed(requests, on: self.client.connection.eventLoop).wait()
  47. }
  48. }
  49. func tearDown() throws {
  50. try self.client.connection.close().wait()
  51. }
  52. }
  53. /// Tests bidirectional throughput by sending requests over a single stream.
  54. ///
  55. /// Requests are sent in batches of (up-to) 100 requests. This is due to
  56. /// https://github.com/apple/swift-nio-http2/issues/87#issuecomment-483542401.
  57. class BidirectionalThroughput: UnaryThroughput {
  58. override func run() throws {
  59. let update = self.client.update { _ in }
  60. for _ in 0..<self.requests {
  61. update.sendMessage(Echo_EchoRequest.with { $0.text = self.request }, promise: nil)
  62. }
  63. update.sendEnd(promise: nil)
  64. _ = try update.status.wait()
  65. }
  66. }
  67. /// Tests the number of connections that can be created.
  68. final class ConnectionCreationThroughput: Benchmark {
  69. let factory: ConnectionFactory
  70. let connections: Int
  71. var createdConnections: [ClientConnection] = []
  72. class ConnectionReadinessDelegate: ConnectivityStateDelegate {
  73. let promise: EventLoopPromise<Void>
  74. var ready: EventLoopFuture<Void> {
  75. return promise.futureResult
  76. }
  77. init(promise: EventLoopPromise<Void>) {
  78. self.promise = promise
  79. }
  80. func connectivityStateDidChange(from oldState: ConnectivityState, to newState: ConnectivityState) {
  81. switch newState {
  82. case .ready:
  83. promise.succeed(())
  84. case .shutdown:
  85. promise.fail(GRPCStatus(code: .unavailable, message: nil))
  86. default:
  87. break
  88. }
  89. }
  90. }
  91. init(factory: ConnectionFactory, connections: Int) {
  92. self.factory = factory
  93. self.connections = connections
  94. }
  95. func setUp() throws { }
  96. func run() throws {
  97. let connectionsAndDelegates: [(ClientConnection, ConnectionReadinessDelegate)] = (0..<connections).map { _ in
  98. let promise = self.factory.configuration.eventLoopGroup.next().makePromise(of: Void.self)
  99. var configuration = self.factory.configuration
  100. let delegate = ConnectionReadinessDelegate(promise: promise)
  101. configuration.connectivityStateDelegate = delegate
  102. return (ClientConnection(configuration: configuration), delegate)
  103. }
  104. self.createdConnections = connectionsAndDelegates.map { connection, _ in connection }
  105. let futures = connectionsAndDelegates.map { _, delegate in delegate.ready }
  106. try EventLoopFuture.andAllSucceed(
  107. futures,
  108. on: self.factory.configuration.eventLoopGroup.next()
  109. ).wait()
  110. }
  111. func tearDown() throws {
  112. let connectionClosures = self.createdConnections.map {
  113. $0.close()
  114. }
  115. try EventLoopFuture.andAllSucceed(
  116. connectionClosures,
  117. on: self.factory.configuration.eventLoopGroup.next()).wait()
  118. }
  119. }
  120. /// The results of a benchmark.
  121. struct BenchmarkResults {
  122. let benchmarkDescription: String
  123. let durations: [TimeInterval]
  124. /// Returns the results as a comma separated string.
  125. ///
  126. /// The format of the string is as such:
  127. /// <name>, <number of results> [, <duration>]
  128. var asCSV: String {
  129. let items = [self.benchmarkDescription, String(self.durations.count)] + self.durations.map { String($0) }
  130. return items.joined(separator: ", ")
  131. }
  132. }
  133. /// Runs the given benchmark multiple times, recording the wall time for each iteration.
  134. ///
  135. /// - Parameter description: A description of the benchmark.
  136. /// - Parameter benchmark: The benchmark to run.
  137. /// - Parameter repeats: The number of times to run the benchmark.
  138. func measure(description: String, benchmark: Benchmark, repeats: Int) -> BenchmarkResults {
  139. var durations: [TimeInterval] = []
  140. for _ in 0..<repeats {
  141. do {
  142. try benchmark.setUp()
  143. let start = Date()
  144. try benchmark.run()
  145. let end = Date()
  146. durations.append(end.timeIntervalSince(start))
  147. } catch {
  148. // If tearDown fails now then there's not a lot we can do!
  149. try? benchmark.tearDown()
  150. return BenchmarkResults(benchmarkDescription: description, durations: [])
  151. }
  152. do {
  153. try benchmark.tearDown()
  154. } catch {
  155. return BenchmarkResults(benchmarkDescription: description, durations: [])
  156. }
  157. }
  158. return BenchmarkResults(benchmarkDescription: description, durations: durations)
  159. }
  160. /// Makes an SSL context if one is required. Note that the CLI tool doesn't support optional values,
  161. /// so we use empty strings for the paths if we don't require SSL.
  162. ///
  163. /// This function will terminate the program if it is not possible to create an SSL context.
  164. ///
  165. /// - Parameter caCertificatePath: The path to the CA certificate PEM file.
  166. /// - Parameter certificatePath: The path to the certificate.
  167. /// - Parameter privateKeyPath: The path to the private key.
  168. /// - Parameter server: Whether this is for the server or not.
  169. private func makeSSLContext(caCertificatePath: String, certificatePath: String, privateKeyPath: String, server: Bool) -> NIOSSLContext? {
  170. // Commander doesn't have Optional options; we use empty strings to indicate no value.
  171. guard certificatePath.isEmpty == privateKeyPath.isEmpty &&
  172. privateKeyPath.isEmpty == caCertificatePath.isEmpty else {
  173. print("Paths for CA certificate, certificate and private key must be provided")
  174. exit(1)
  175. }
  176. // No need to check them all because of the guard statement above.
  177. if caCertificatePath.isEmpty {
  178. return nil
  179. }
  180. let configuration: TLSConfiguration
  181. if server {
  182. configuration = .forServer(
  183. certificateChain: [.file(certificatePath)],
  184. privateKey: .file(privateKeyPath),
  185. trustRoots: .file(caCertificatePath),
  186. applicationProtocols: ["h2"]
  187. )
  188. } else {
  189. configuration = .forClient(
  190. trustRoots: .file(caCertificatePath),
  191. certificateChain: [.file(certificatePath)],
  192. privateKey: .file(privateKeyPath),
  193. applicationProtocols: ["h2"]
  194. )
  195. }
  196. do {
  197. return try NIOSSLContext(configuration: configuration)
  198. } catch {
  199. print("Unable to create SSL context: \(error)")
  200. exit(1)
  201. }
  202. }
  203. enum Benchmarks: String, CaseIterable {
  204. case unaryThroughputSmallRequests = "unary_throughput_small"
  205. case unaryThroughputLargeRequests = "unary_throughput_large"
  206. case bidirectionalThroughputSmallRequests = "bidi_throughput_small"
  207. case bidirectionalThroughputLargeRequests = "bidi_throughput_large"
  208. case connectionThroughput = "connection_throughput"
  209. static let smallRequest = 8
  210. static let largeRequest = 1 << 16
  211. var description: String {
  212. switch self {
  213. case .unaryThroughputSmallRequests:
  214. return "10k unary requests of size \(Benchmarks.smallRequest)"
  215. case .unaryThroughputLargeRequests:
  216. return "10k unary requests of size \(Benchmarks.largeRequest)"
  217. case .bidirectionalThroughputSmallRequests:
  218. return "20k bidirectional messages of size \(Benchmarks.smallRequest)"
  219. case .bidirectionalThroughputLargeRequests:
  220. return "10k bidirectional messages of size \(Benchmarks.largeRequest)"
  221. case .connectionThroughput:
  222. return "100 connections created"
  223. }
  224. }
  225. func makeBenchmark(factory: ConnectionFactory) -> Benchmark {
  226. switch self {
  227. case .unaryThroughputSmallRequests:
  228. return UnaryThroughput(factory: factory, requests: 10_000, requestLength: Benchmarks.smallRequest)
  229. case .unaryThroughputLargeRequests:
  230. return UnaryThroughput(factory: factory, requests: 10_000, requestLength: Benchmarks.largeRequest)
  231. case .bidirectionalThroughputSmallRequests:
  232. return BidirectionalThroughput(factory: factory, requests: 20_000, requestLength: Benchmarks.smallRequest)
  233. case .bidirectionalThroughputLargeRequests:
  234. return BidirectionalThroughput(factory: factory, requests: 10_000, requestLength: Benchmarks.largeRequest)
  235. case .connectionThroughput:
  236. return ConnectionCreationThroughput(factory: factory, connections: 100)
  237. }
  238. }
  239. func run(using factory: ConnectionFactory, repeats: Int = 10) -> BenchmarkResults {
  240. let benchmark = self.makeBenchmark(factory: factory)
  241. return measure(description: self.description, benchmark: benchmark, repeats: repeats)
  242. }
  243. }
  244. let hostOption = Option(
  245. "host",
  246. // Use IPv4 to avoid the happy eyeballs delay, this is important when we test the
  247. // connection throughput.
  248. default: "127.0.0.1",
  249. description: "The host to connect to.")
  250. let portOption = Option(
  251. "port",
  252. default: 8080,
  253. description: "The port on the host to connect to.")
  254. let benchmarkOption = Option(
  255. "benchmarks",
  256. default: Benchmarks.allCases.map { $0.rawValue }.joined(separator: ","),
  257. description: "A comma separated list of benchmarks to run. Defaults to all benchmarks.")
  258. let caCertificateOption = Option(
  259. "ca_certificate",
  260. default: "",
  261. description: "The path to the CA certificate to use.")
  262. let certificateOption = Option(
  263. "certificate",
  264. default: "",
  265. description: "The path to the certificate to use.")
  266. let privateKeyOption = Option(
  267. "private_key",
  268. default: "",
  269. description: "The path to the private key to use.")
  270. let hostOverrideOption = Option(
  271. "hostname_override",
  272. default: "",
  273. description: "The expected name of the server to use for TLS.")
  274. Group { group in
  275. group.command(
  276. "run_benchmarks",
  277. benchmarkOption,
  278. hostOption,
  279. portOption,
  280. caCertificateOption,
  281. certificateOption,
  282. privateKeyOption,
  283. hostOverrideOption
  284. ) { benchmarkNames, host, port, caCertificatePath, certificatePath, privateKeyPath, hostOverride in
  285. var configuration = ClientConnection.Configuration(
  286. target: .hostAndPort(host, port),
  287. eventLoopGroup: MultiThreadedEventLoopGroup(numberOfThreads: 1))
  288. let sslContext = makeSSLContext(
  289. caCertificatePath: caCertificatePath,
  290. certificatePath: certificatePath,
  291. privateKeyPath: privateKeyPath,
  292. server: false)
  293. if let sslContext = sslContext {
  294. configuration.tlsConfiguration = .init(sslContext: sslContext, hostnameOverride: hostOverride)
  295. }
  296. let factory = ConnectionFactory(configuration: configuration)
  297. let names = benchmarkNames.components(separatedBy: ",")
  298. // validate the benchmarks exist before running any
  299. let benchmarks = names.map { name -> Benchmarks in
  300. guard let benchnark = Benchmarks(rawValue: name) else {
  301. print("unknown benchmark: \(name)")
  302. exit(1)
  303. }
  304. return benchnark
  305. }
  306. benchmarks.forEach { benchmark in
  307. let results = benchmark.run(using: factory)
  308. print(results.asCSV)
  309. }
  310. }
  311. group.command(
  312. "start_server",
  313. hostOption,
  314. portOption,
  315. caCertificateOption,
  316. certificateOption,
  317. privateKeyOption
  318. ) { host, port, caCertificatePath, certificatePath, privateKeyPath in
  319. let group = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount)
  320. let sslContext = makeSSLContext(
  321. caCertificatePath: caCertificatePath,
  322. certificatePath: certificatePath,
  323. privateKeyPath: privateKeyPath,
  324. server: true)
  325. let configuration = Server.Configuration(
  326. target: .hostAndPort(host, port),
  327. eventLoopGroup: group,
  328. serviceProviders: [EchoProvider()],
  329. tlsConfiguration: sslContext.map { .init(sslContext: $0) })
  330. let server: Server
  331. do {
  332. server = try Server.start(configuration: configuration).wait()
  333. } catch {
  334. print("unable to start server: \(error)")
  335. exit(1)
  336. }
  337. print("server started on port: \(server.channel.localAddress?.port ?? port)")
  338. // Stop the program from exiting.
  339. try? server.onClose.wait()
  340. }
  341. }.run()