main.swift 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. import Foundation
  2. import GRPC
  3. import NIO
  4. import NIOSSL
  5. import Commander
  /// Creates gRPC client connections, and Echo clients on top of them, for one fixed endpoint.
  ///
  /// TLS is used only when `context` is non-nil; `serverHostOverride` is passed through to the
  /// connection as its `hostOverride`.
  struct ConnectionFactory {
    /// Host to connect to.
    var host: String
    /// Port on `host` to connect to.
    var port: Int
    /// Event loop group every connection is started on.
    var group: EventLoopGroup
    /// TLS context, or `nil` for a plaintext connection.
    var context: NIOSSLContext?
    /// Optional host override forwarded to `GRPCClientConnection.start`.
    var serverHostOverride: String?

    /// Starts a new connection using this factory's settings.
    ///
    /// - Returns: The future returned by `GRPCClientConnection.start`.
    /// - Throws: Whatever `GRPCClientConnection.start` throws synchronously.
    func makeConnection() throws -> EventLoopFuture<GRPCClientConnection> {
      return try GRPCClientConnection.start(
        host: host,
        port: port,
        eventLoopGroup: group,
        tls: context.map { .custom($0) } ?? .none,
        hostOverride: serverHostOverride
      )
    }

    /// Starts a new connection and wraps it in an `Echo_EchoServiceClient`.
    ///
    /// - Throws: Whatever `makeConnection()` throws synchronously.
    func makeEchoClient() throws -> EventLoopFuture<Echo_EchoServiceClient> {
      let pendingConnection = try makeConnection()
      return pendingConnection.map { connection in
        Echo_EchoServiceClient(connection: connection)
      }
    }
  }
  /// A benchmark driven by `measure(description:benchmark:repeats:)`.
  ///
  /// On each iteration, `setUp()` is called first and then `run()`, which is the only timed part.
  /// `tearDown()` is called last, including after a failed `setUp()` or `run()`. The protocol is
  /// class-constrained because implementations keep per-iteration state (clients, connections)
  /// between these calls.
  protocol Benchmark: AnyObject {
    /// Prepares any state `run()` needs; not included in the measured time.
    func setUp() throws
    /// Releases whatever `setUp()` / `run()` created.
    func tearDown() throws
    /// The measured workload.
    func run() throws
  }
  /// Tests unary throughput by sending requests on a single connection.
  ///
  /// Requests are sent in batches of (up-to) 100 requests. This is due to
  /// https://github.com/apple/swift-nio-http2/issues/87#issuecomment-483542401.
  class UnaryThroughput: Benchmark {
    let factory: ConnectionFactory
    /// Total number of requests sent by one `run()`.
    let requests: Int
    /// Length, in characters, of each request's text payload.
    let requestLength: Int
    /// Client created by `setUp()`; `nil` before the first successful `setUp()` and after `tearDown()`.
    var client: Echo_EchoServiceClient!
    /// Request payload text built by `setUp()`.
    var request: String!

    init(factory: ConnectionFactory, requests: Int, requestLength: Int) {
      self.factory = factory
      self.requests = requests
      self.requestLength = requestLength
    }

    /// Connects a new Echo client and builds the request payload.
    func setUp() throws {
      self.client = try self.factory.makeEchoClient().wait()
      self.request = String(repeating: "0", count: self.requestLength)
    }

    /// Sends `requests` unary `get` calls in batches of at most 100, waiting for each batch to finish.
    func run() throws {
      let batchSize = 100
      let client: Echo_EchoServiceClient = self.client
      // Every call sends the same payload, so build the message once rather than per request.
      let message = Echo_EchoRequest.with { $0.text = self.request }

      for lowerBound in stride(from: 0, to: self.requests, by: batchSize) {
        let upperBound = min(lowerBound + batchSize, self.requests)
        let responses = (lowerBound..<upperBound).map { _ in
          client.get(message).response
        }
        try EventLoopFuture.andAllSucceed(responses, on: client.connection.channel.eventLoop).wait()
      }
    }

    /// Closes the client's connection, if `setUp()` created one.
    ///
    /// `measure(description:benchmark:repeats:)` also calls this after a failed `setUp()`, when
    /// `client` may still be `nil`; force-unwrapping it there would crash the whole tool.
    func tearDown() throws {
      guard let client = self.client else { return }
      // Forget the client first so a later tearDown does not close the same connection again.
      self.client = nil
      try client.connection.close().wait()
    }
  }
  /// Tests bidirectional throughput by sending requests over a single stream.
  ///
  /// Unlike `UnaryThroughput`, messages are not batched. All `requests` messages are written to a
  /// single `update` call, the request stream is ended, and the benchmark waits for the call's
  /// final status.
  class BidirectionalThroughput: UnaryThroughput {
    /// Streams `requests` messages over one call and waits for it to complete.
    ///
    /// - Throws: The call's status if its code is not `.ok`, or any error from waiting on it.
    override func run() throws {
      let update = self.client.update { _ in }
      // Every message carries the same payload, so build it once.
      let message = Echo_EchoRequest.with { $0.text = self.request }
      for _ in 0..<self.requests {
        update.sendMessage(message, promise: nil)
      }
      update.sendEnd(promise: nil)

      let status = try update.status.wait()
      // NOTE(review): grpc-swift's `status` future is expected to succeed with a non-OK status
      // when the call fails. Check the code explicitly so a failed call is not timed and reported
      // as a successful run.
      guard status.code == .ok else {
        throw status
      }
    }
  }
  /// Tests the number of connections that can be created.
  final class ConnectionCreationThroughput: Benchmark {
    let factory: ConnectionFactory
    /// Number of connections started by one `run()`.
    let connections: Int
    /// Connections started by the current `run()`; closed and cleared by `tearDown()`.
    var createdConnections: [EventLoopFuture<GRPCClientConnection>] = []

    init(factory: ConnectionFactory, connections: Int) {
      self.factory = factory
      self.connections = connections
    }

    func setUp() throws { }

    /// Starts `connections` connections and waits until all of them are established.
    func run() throws {
      self.createdConnections = []
      self.createdConnections.reserveCapacity(self.connections)
      // Record each connection as soon as it has been started. If a later `makeConnection()`
      // throws, `tearDown()` can then still close the ones that were already started.
      for _ in 0..<self.connections {
        self.createdConnections.append(try self.factory.makeConnection())
      }
      try EventLoopFuture.andAllSucceed(self.createdConnections, on: self.factory.group.next()).wait()
    }

    /// Closes every connection started by `run()` and forgets them.
    func tearDown() throws {
      let pending = self.createdConnections
      // Clear up-front so these connections are never closed twice by a later tearDown.
      self.createdConnections = []
      let closures = pending.map { connection in
        connection.flatMap { $0.close() }
      }
      try EventLoopFuture.andAllSucceed(closures, on: self.factory.group.next()).wait()
    }
  }
  /// The outcome of a benchmark: its description plus one duration, in seconds, per recorded iteration.
  struct BenchmarkResults {
    let benchmarkDescription: String
    let durations: [TimeInterval]

    /// The results rendered as one comma-separated line.
    ///
    /// The layout is `<name>, <number of results>`, followed by `, <duration>` for each
    /// recorded duration.
    var asCSV: String {
      let header = [benchmarkDescription, String(durations.count)]
      let values = durations.map { duration in String(duration) }
      return (header + values).joined(separator: ", ")
    }
  }
  /// Runs the given benchmark multiple times, recording the wall time for each iteration.
  ///
  /// Each iteration calls `setUp()`, times `run()` and then calls `tearDown()`. If `setUp()` or
  /// `run()` throws, the error is written to standard error and `tearDown()` is attempted. The
  /// function then returns results with no durations. A failing `tearDown()` is reported the same
  /// way. A failed benchmark therefore shows up as a result count of zero, and the reason is
  /// still visible.
  ///
  /// - Parameter description: A description of the benchmark.
  /// - Parameter benchmark: The benchmark to run.
  /// - Parameter repeats: The number of times to run the benchmark. Values below one run nothing.
  /// - Returns: One duration (in seconds) per iteration, or no durations if any iteration failed.
  func measure(description: String, benchmark: Benchmark, repeats: Int) -> BenchmarkResults {
    /// Reports a failure on standard error so it does not mix with the CSV lines printed to stdout.
    func reportFailure(_ stage: String, _ error: Error) {
      let message = "benchmark '\(description)' failed during \(stage): \(error)\n"
      FileHandle.standardError.write(Data(message.utf8))
    }

    // `0..<repeats` traps when `repeats` is negative, so clamp it to zero.
    let iterations = max(repeats, 0)
    var durations: [TimeInterval] = []
    durations.reserveCapacity(iterations)

    for _ in 0..<iterations {
      do {
        try benchmark.setUp()
        let start = Date()
        try benchmark.run()
        let end = Date()
        durations.append(end.timeIntervalSince(start))
      } catch {
        reportFailure("setUp/run", error)
        // If tearDown fails now then there's not a lot we can do!
        try? benchmark.tearDown()
        return BenchmarkResults(benchmarkDescription: description, durations: [])
      }

      do {
        try benchmark.tearDown()
      } catch {
        reportFailure("tearDown", error)
        return BenchmarkResults(benchmarkDescription: description, durations: [])
      }
    }

    return BenchmarkResults(benchmarkDescription: description, durations: durations)
  }
  /// Builds the TLS context for the client or the server, or returns `nil` when TLS is not wanted.
  ///
  /// The CLI tool has no optional values, so "no TLS" is expressed by leaving all three paths
  /// empty. Either all three paths must be given or none of them. Any other combination, or a
  /// failure to build the context, prints a message and terminates the program with `exit(1)`.
  ///
  /// - Parameter caCertificatePath: Path to the CA certificate PEM file, used as the trust roots.
  /// - Parameter certificatePath: Path to the certificate, used as the certificate chain.
  /// - Parameter privateKeyPath: Path to the private key.
  /// - Parameter server: `true` for a server configuration, `false` for a client configuration.
  /// - Returns: The context, or `nil` if all three paths are empty.
  private func makeSSLContext(caCertificatePath: String, certificatePath: String, privateKeyPath: String, server: Bool) -> NIOSSLContext? {
    // Commander doesn't have Optional options; an empty string means "not provided".
    let pathIsEmpty = [caCertificatePath, certificatePath, privateKeyPath].map { $0.isEmpty }

    // A mix of empty and non-empty paths is a usage error.
    if Set(pathIsEmpty).count > 1 {
      print("Paths for CA certificate, certificate and private key must be provided")
      exit(1)
    }

    if pathIsEmpty.allSatisfy({ $0 }) {
      return nil
    }

    let configuration = server
      ? TLSConfiguration.forServer(
          certificateChain: [.file(certificatePath)],
          privateKey: .file(privateKeyPath),
          trustRoots: .file(caCertificatePath),
          applicationProtocols: ["h2"])
      : TLSConfiguration.forClient(
          trustRoots: .file(caCertificatePath),
          certificateChain: [.file(certificatePath)],
          privateKey: .file(privateKeyPath),
          applicationProtocols: ["h2"])

    switch Result(catching: { try NIOSSLContext(configuration: configuration) }) {
    case .success(let context):
      return context
    case .failure(let error):
      print("Unable to create SSL context: \(error)")
      exit(1)
    }
  }
  /// The benchmarks this tool can run. The raw value is the name accepted on the command line.
  enum Benchmarks: String, CaseIterable {
    case unaryThroughputSmallRequests = "unary_throughput_small"
    case unaryThroughputLargeRequests = "unary_throughput_large"
    case bidirectionalThroughputSmallRequests = "bidi_throughput_small"
    case bidirectionalThroughputLargeRequests = "bidi_throughput_large"
    case connectionThroughput = "connection_throughput"

    /// Request payload length, in characters, for the "small" benchmarks.
    static let smallRequest = 8
    /// Request payload length, in characters, for the "large" benchmarks (1 << 16 = 65536).
    static let largeRequest = 1 << 16

    /// How many requests, messages or connections one run of this benchmark performs.
    ///
    /// Both `description` and `makeBenchmark(factory:)` read this value, so the printed
    /// description cannot drift from what is actually run.
    private var operationCount: Int {
      switch self {
      case .unaryThroughputSmallRequests, .unaryThroughputLargeRequests:
        return 10_000
      case .bidirectionalThroughputSmallRequests:
        return 20_000
      case .bidirectionalThroughputLargeRequests:
        return 10_000
      case .connectionThroughput:
        return 100
      }
    }

    /// `operationCount` in the short form used by `description`, e.g. 10_000 -> "10k", 100 -> "100".
    private var operationCountLabel: String {
      let count = self.operationCount
      return count >= 1_000 && count % 1_000 == 0 ? "\(count / 1_000)k" : "\(count)"
    }

    /// A human-readable description of the benchmark, used as the first CSV column.
    var description: String {
      switch self {
      case .unaryThroughputSmallRequests:
        return "\(operationCountLabel) unary requests of size \(Benchmarks.smallRequest)"
      case .unaryThroughputLargeRequests:
        return "\(operationCountLabel) unary requests of size \(Benchmarks.largeRequest)"
      case .bidirectionalThroughputSmallRequests:
        return "\(operationCountLabel) bidirectional messages of size \(Benchmarks.smallRequest)"
      case .bidirectionalThroughputLargeRequests:
        return "\(operationCountLabel) bidirectional messages of size \(Benchmarks.largeRequest)"
      case .connectionThroughput:
        return "\(operationCountLabel) connections created"
      }
    }

    /// Creates a fresh benchmark instance that uses `factory` for its connections.
    func makeBenchmark(factory: ConnectionFactory) -> Benchmark {
      switch self {
      case .unaryThroughputSmallRequests:
        return UnaryThroughput(factory: factory, requests: operationCount, requestLength: Benchmarks.smallRequest)
      case .unaryThroughputLargeRequests:
        return UnaryThroughput(factory: factory, requests: operationCount, requestLength: Benchmarks.largeRequest)
      case .bidirectionalThroughputSmallRequests:
        return BidirectionalThroughput(factory: factory, requests: operationCount, requestLength: Benchmarks.smallRequest)
      case .bidirectionalThroughputLargeRequests:
        return BidirectionalThroughput(factory: factory, requests: operationCount, requestLength: Benchmarks.largeRequest)
      case .connectionThroughput:
        return ConnectionCreationThroughput(factory: factory, connections: operationCount)
      }
    }

    /// Builds the benchmark and measures it `repeats` times.
    func run(using factory: ConnectionFactory, repeats: Int = 10) -> BenchmarkResults {
      let benchmark = self.makeBenchmark(factory: factory)
      return measure(description: self.description, benchmark: benchmark, repeats: repeats)
    }
  }
  // MARK: - Command-line options
  //
  // Commander doesn't support optional values, so the string options that may be left unset
  // default to "" and are treated as "not provided" by the commands that read them.

  /// Host used by both commands: `run_benchmarks` connects to it, and `start_server` passes it as
  /// `hostname`. It defaults to an IPv4 literal to avoid the happy eyeballs delay, which matters
  /// when testing connection throughput.
  let hostOption = Option("host", default: "127.0.0.1", description: "The host to connect to.")

  /// Port used by both commands.
  let portOption = Option("port", default: 8080, description: "The port on the host to connect to.")

  /// Comma-separated benchmark names; the default selects every `Benchmarks` case.
  let benchmarkOption = Option(
    "benchmarks",
    default: Benchmarks.allCases.map { benchmark in benchmark.rawValue }.joined(separator: ","),
    description: "A comma separated list of benchmarks to run. Defaults to all benchmarks.")

  /// TLS inputs; all three must be set together, or all left empty for plaintext.
  let caCertificateOption = Option("ca_certificate", default: "", description: "The path to the CA certificate to use.")
  let certificateOption = Option("certificate", default: "", description: "The path to the certificate to use.")
  let privateKeyOption = Option("private_key", default: "", description: "The path to the private key to use.")

  /// Server name override for TLS (`run_benchmarks` only); empty means no override.
  let hostOverrideOption = Option("hostname_override", default: "", description: "The expected name of the server to use for TLS.")
  // Entry point: two Commander subcommands, `run_benchmarks` (client) and `start_server`.
  Group { group in
    // Runs the selected benchmarks against a server and prints one CSV line per benchmark.
    group.command(
      "run_benchmarks",
      benchmarkOption,
      hostOption,
      portOption,
      caCertificateOption,
      certificateOption,
      privateKeyOption,
      hostOverrideOption
    ) { benchmarkNames, host, port, caCertificatePath, certificatePath, privateKeyPath, hostOverride in
      let sslContext = makeSSLContext(
        caCertificatePath: caCertificatePath,
        certificatePath: certificatePath,
        privateKeyPath: privateKeyPath,
        server: false)

      // Validate the benchmark names before creating any resources or running anything.
      // Whitespace around names is ignored (so "a, b" works) and empty entries are skipped.
      let benchmarks = benchmarkNames
        .split(separator: ",")
        .map { $0.trimmingCharacters(in: .whitespaces) }
        .filter { !$0.isEmpty }
        .map { name -> Benchmarks in
          guard let benchmark = Benchmarks(rawValue: name) else {
            print("unknown benchmark: \(name)")
            exit(1)
          }
          return benchmark
        }

      guard !benchmarks.isEmpty else {
        print("no benchmarks specified")
        exit(1)
      }

      let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
      // Release the event loop thread once every benchmark has finished.
      defer { try? eventLoopGroup.syncShutdownGracefully() }

      let factory = ConnectionFactory(
        host: host,
        port: port,
        group: eventLoopGroup,
        context: sslContext,
        serverHostOverride: hostOverride.isEmpty ? nil : hostOverride)

      for benchmark in benchmarks {
        let results = benchmark.run(using: factory)
        print(results.asCSV)
      }
    }

    // Starts an Echo server and blocks until its channel closes.
    group.command(
      "start_server",
      hostOption,
      portOption,
      caCertificateOption,
      certificateOption,
      privateKeyOption
    ) { host, port, caCertificatePath, certificatePath, privateKeyPath in
      // Named to avoid shadowing Commander's `group` parameter above.
      let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount)
      defer { try? eventLoopGroup.syncShutdownGracefully() }

      let sslContext = makeSSLContext(
        caCertificatePath: caCertificatePath,
        certificatePath: certificatePath,
        privateKeyPath: privateKeyPath,
        server: true)

      let server: GRPCServer
      do {
        server = try GRPCServer.start(
          hostname: host,
          port: port,
          eventLoopGroup: eventLoopGroup,
          serviceProviders: [EchoProvider()],
          tls: sslContext.map { .custom($0) } ?? .none).wait()
      } catch {
        print("unable to start server: \(error)")
        exit(1)
      }

      print("server started on port: \(server.channel.localAddress?.port ?? port)")
      // Stop the program from exiting.
      try? server.onClose.wait()
    }
  }.run()