2
0

main.swift 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. import Foundation
  2. import GRPC
  3. import NIO
  4. import NIOSSL
  5. import Commander
  /// Starts gRPC client connections, and Echo clients on top of them, from one shared configuration.
  struct ConnectionFactory {
    /// The configuration every connection started by this factory uses.
    var configuration: GRPCClientConnection.Configuration

    /// Begins connecting with `configuration`.
    ///
    /// - Returns: A future for the connection.
    func makeConnection() throws -> EventLoopFuture<GRPCClientConnection> {
      let pending = GRPCClientConnection.start(self.configuration)
      return pending
    }

    /// Begins connecting and wraps the resulting connection in an Echo service client.
    ///
    /// - Returns: A future for the Echo client.
    func makeEchoClient() throws -> EventLoopFuture<Echo_EchoServiceClient> {
      let pending = try makeConnection()
      return pending.map { established in
        Echo_EchoServiceClient(connection: established)
      }
    }
  }
  /// A unit of work that can be timed repeatedly.
  ///
  /// `setUp()` and `tearDown()` bracket each iteration and `run()` performs the work being
  /// measured. Conformers are restricted to classes (`AnyObject`, the modern spelling of a
  /// class-only protocol); the concrete benchmarks keep mutable state between these calls.
  protocol Benchmark: AnyObject {
    /// Prepares whatever a single iteration of `run()` needs.
    func setUp() throws
    /// Releases resources acquired by `setUp()` or `run()`.
    func tearDown() throws
    /// Performs the work being measured.
    func run() throws
  }
  /// Tests unary throughput by sending requests on a single connection.
  ///
  /// Requests are sent in batches of (up-to) 100 requests. This is due to
  /// https://github.com/apple/swift-nio-http2/issues/87#issuecomment-483542401.
  class UnaryThroughput: Benchmark {
    let factory: ConnectionFactory
    /// Total number of unary RPCs issued by each `run()`.
    let requests: Int
    /// Number of characters in each request's `text` payload.
    let requestLength: Int
    /// The client created by `setUp()`; `nil` before `setUp()` succeeds and after `tearDown()`.
    var client: Echo_EchoServiceClient!
    /// The request payload built by `setUp()`.
    var request: String!

    init(factory: ConnectionFactory, requests: Int, requestLength: Int) {
      self.factory = factory
      self.requests = requests
      self.requestLength = requestLength
    }

    func setUp() throws {
      self.client = try self.factory.makeEchoClient().wait()
      self.request = String(repeating: "0", count: self.requestLength)
    }

    func run() throws {
      let batchSize = 100
      // Every RPC carries the same message, so build it once rather than once per request.
      let message = Echo_EchoRequest.with { $0.text = self.request }
      for lowerBound in stride(from: 0, to: self.requests, by: batchSize) {
        let upperBound = min(lowerBound + batchSize, self.requests)
        let responses = (lowerBound..<upperBound).map { _ in
          self.client.get(message).response
        }
        // Wait for the whole batch before issuing the next one (see the issue linked above).
        try EventLoopFuture.andAllSucceed(responses, on: self.client.connection.channel.eventLoop).wait()
      }
    }

    func tearDown() throws {
      // tearDown() may be called after setUp() failed, in which case no client exists yet.
      guard let client = self.client else {
        return
      }
      // Forget the client before closing so a later tearDown() never closes it a second time.
      self.client = nil
      try client.connection.close().wait()
    }
  }
  /// Tests bidirectional throughput by sending requests over a single stream.
  ///
  /// All `requests` messages are written to one bidirectional-streaming `update` call without
  /// batching, the stream is then ended, and the run completes when the call's final status
  /// arrives. A non-OK status is thrown so that a failed call is not reported as a valid timing.
  /// Client setup and teardown are inherited from `UnaryThroughput`.
  final class BidirectionalThroughput: UnaryThroughput {
    override func run() throws {
      let update = self.client.update { _ in }
      // Every message carries the same payload, so build it once.
      let message = Echo_EchoRequest.with { $0.text = self.request }
      for _ in 0..<self.requests {
        update.sendMessage(message, promise: nil)
      }
      update.sendEnd(promise: nil)
      let status = try update.status.wait()
      guard status.code == .ok else {
        throw status
      }
    }
  }
  /// Tests the number of connections that can be created.
  ///
  /// Each `run()` starts `connections` connections and waits until all of them succeed;
  /// `tearDown()` closes every connection the preceding `run()` started.
  final class ConnectionCreationThroughput: Benchmark {
    let factory: ConnectionFactory
    /// Number of connections started by each `run()`.
    let connections: Int
    /// Connections started by the most recent `run()`; emptied by `tearDown()`.
    var createdConnections: [EventLoopFuture<GRPCClientConnection>] = []

    init(factory: ConnectionFactory, connections: Int) {
      self.factory = factory
      self.connections = connections
    }

    func setUp() throws { }

    func run() throws {
      self.createdConnections.removeAll(keepingCapacity: true)
      // Record each connection as soon as it is started so that, if starting a later one
      // throws, tearDown() can still close the ones already in flight.
      for _ in 0..<self.connections {
        self.createdConnections.append(try self.factory.makeConnection())
      }
      try EventLoopFuture.andAllSucceed(
        self.createdConnections,
        on: self.factory.configuration.eventLoopGroup.next()
      ).wait()
    }

    func tearDown() throws {
      // Take ownership of the pending connections and clear the list up front, so that a later
      // tearDown() never attempts to close these (already closed) connections again.
      let pending = self.createdConnections
      self.createdConnections.removeAll()
      let closures = pending.map { future in
        future.flatMap { connection in
          connection.close()
        }
      }
      try EventLoopFuture.andAllSucceed(
        closures,
        on: self.factory.configuration.eventLoopGroup.next()
      ).wait()
    }
  }
  /// The outcome of measuring one benchmark.
  struct BenchmarkResults {
    /// Human readable description of the benchmark that produced these results.
    let benchmarkDescription: String
    /// The duration of each measured iteration.
    let durations: [TimeInterval]

    /// The results rendered as a single comma-separated line:
    /// `<name>, <number of results>[, <duration>]...`
    var asCSV: String {
      var fields: [String] = [benchmarkDescription, "\(durations.count)"]
      for duration in durations {
        fields.append("\(duration)")
      }
      return fields.joined(separator: ", ")
    }
  }
  /// Runs the given benchmark multiple times, recording the elapsed time of each `run()`.
  ///
  /// Every iteration calls `setUp()`, times `run()` and then calls `tearDown()`. If any of these
  /// throws, the error is written to standard error (keeping standard output free for CSV
  /// results), `tearDown()` is attempted on a best-effort basis, and results with no durations
  /// are returned, discarding any earlier iterations.
  ///
  /// - Parameter description: A description of the benchmark.
  /// - Parameter benchmark: The benchmark to run.
  /// - Parameter repeats: The number of times to run the benchmark; values below 1 run nothing.
  /// - Returns: One duration (in seconds) per iteration, or no durations if any iteration failed.
  func measure(description: String, benchmark: Benchmark, repeats: Int) -> BenchmarkResults {
    /// Explains why the benchmark was abandoned without polluting the CSV on standard output.
    func reportFailure(during stage: String, _ error: Error) {
      let message = "\(description): \(stage) failed: \(error)\n"
      FileHandle.standardError.write(Data(message.utf8))
    }

    var durations: [TimeInterval] = []
    for _ in 0..<max(repeats, 0) {
      do {
        try benchmark.setUp()
        // `Date()` follows the wall clock, which may be adjusted (e.g. by NTP) mid-run;
        // system uptime is not affected by such adjustments.
        let start = ProcessInfo.processInfo.systemUptime
        try benchmark.run()
        durations.append(ProcessInfo.processInfo.systemUptime - start)
      } catch {
        reportFailure(during: "setUp/run", error)
        // If tearDown fails now then there's not a lot we can do!
        try? benchmark.tearDown()
        return BenchmarkResults(benchmarkDescription: description, durations: [])
      }

      do {
        try benchmark.tearDown()
      } catch {
        reportFailure(during: "tearDown", error)
        return BenchmarkResults(benchmarkDescription: description, durations: [])
      }
    }
    return BenchmarkResults(benchmarkDescription: description, durations: durations)
  }
  /// Builds the TLS context for one side of the benchmark, or returns `nil` when TLS is not used.
  ///
  /// Commander has no optional options, so an empty path means "not provided". Either all three
  /// paths are given (TLS) or none are (plaintext). Any other combination, or a failure to build
  /// the context, prints a message and terminates the program with exit code 1.
  ///
  /// - Parameter caCertificatePath: The path to the CA certificate PEM file.
  /// - Parameter certificatePath: The path to the certificate.
  /// - Parameter privateKeyPath: The path to the private key.
  /// - Parameter server: Whether this is for the server or not.
  /// - Returns: A context advertising the "h2" application protocol, or `nil` if no paths were given.
  private func makeSSLContext(caCertificatePath: String, certificatePath: String, privateKeyPath: String, server: Bool) -> NIOSSLContext? {
    // The three "is empty" flags must all agree: a one-element set means they do.
    let emptiness = Set([caCertificatePath.isEmpty, certificatePath.isEmpty, privateKeyPath.isEmpty])
    guard emptiness.count == 1 else {
      print("Paths for CA certificate, certificate and private key must be provided")
      exit(1)
    }

    // The flags agree, so inspecting a single path tells us whether TLS was requested.
    guard !caCertificatePath.isEmpty else {
      return nil
    }

    let configuration: TLSConfiguration = server
      ? TLSConfiguration.forServer(
          certificateChain: [.file(certificatePath)],
          privateKey: .file(privateKeyPath),
          trustRoots: .file(caCertificatePath),
          applicationProtocols: ["h2"])
      : TLSConfiguration.forClient(
          trustRoots: .file(caCertificatePath),
          certificateChain: [.file(certificatePath)],
          privateKey: .file(privateKeyPath),
          applicationProtocols: ["h2"])

    do {
      return try NIOSSLContext(configuration: configuration)
    } catch {
      print("Unable to create SSL context: \(error)")
      exit(1)
    }
  }
  /// The benchmarks this tool can run, keyed by their command-line names.
  enum Benchmarks: String, CaseIterable {
    case unaryThroughputSmallRequests = "unary_throughput_small"
    case unaryThroughputLargeRequests = "unary_throughput_large"
    case bidirectionalThroughputSmallRequests = "bidi_throughput_small"
    case bidirectionalThroughputLargeRequests = "bidi_throughput_large"
    case connectionThroughput = "connection_throughput"

    /// Payload length, in characters, of a "small" request.
    static let smallRequest = 8
    /// Payload length, in characters, of a "large" request (65,536).
    static let largeRequest = 1 << 16

    /// How many requests, messages or connections one iteration of this benchmark performs.
    ///
    /// Both `description` and `makeBenchmark(factory:)` read this, so the reported
    /// description cannot drift from the work actually performed.
    private var operationCount: Int {
      switch self {
      case .unaryThroughputSmallRequests,
           .unaryThroughputLargeRequests,
           .bidirectionalThroughputLargeRequests:
        return 10_000
      case .bidirectionalThroughputSmallRequests:
        return 20_000
      case .connectionThroughput:
        return 100
      }
    }

    /// Formats a count compactly: whole thousands become e.g. "10k", anything else is printed in full.
    private static func abbreviated(_ count: Int) -> String {
      if count >= 1_000 && count % 1_000 == 0 {
        return "\(count / 1_000)k"
      }
      return String(count)
    }

    /// Human readable description used as the benchmark's name in the CSV output.
    var description: String {
      let count = Benchmarks.abbreviated(self.operationCount)
      switch self {
      case .unaryThroughputSmallRequests:
        return "\(count) unary requests of size \(Benchmarks.smallRequest)"
      case .unaryThroughputLargeRequests:
        return "\(count) unary requests of size \(Benchmarks.largeRequest)"
      case .bidirectionalThroughputSmallRequests:
        return "\(count) bidirectional messages of size \(Benchmarks.smallRequest)"
      case .bidirectionalThroughputLargeRequests:
        return "\(count) bidirectional messages of size \(Benchmarks.largeRequest)"
      case .connectionThroughput:
        return "\(count) connections created"
      }
    }

    /// Creates a fresh benchmark instance that uses `factory` for its connections.
    func makeBenchmark(factory: ConnectionFactory) -> Benchmark {
      let count = self.operationCount
      switch self {
      case .unaryThroughputSmallRequests:
        return UnaryThroughput(factory: factory, requests: count, requestLength: Benchmarks.smallRequest)
      case .unaryThroughputLargeRequests:
        return UnaryThroughput(factory: factory, requests: count, requestLength: Benchmarks.largeRequest)
      case .bidirectionalThroughputSmallRequests:
        return BidirectionalThroughput(factory: factory, requests: count, requestLength: Benchmarks.smallRequest)
      case .bidirectionalThroughputLargeRequests:
        return BidirectionalThroughput(factory: factory, requests: count, requestLength: Benchmarks.largeRequest)
      case .connectionThroughput:
        return ConnectionCreationThroughput(factory: factory, connections: count)
      }
    }

    /// Builds and measures this benchmark `repeats` times.
    func run(using factory: ConnectionFactory, repeats: Int = 10) -> BenchmarkResults {
      let benchmark = self.makeBenchmark(factory: factory)
      return measure(description: self.description, benchmark: benchmark, repeats: repeats)
    }
  }
  // MARK: - Command line options
  //
  // Commander has no optional options, so the TLS-related options below default to an empty
  // string, which the rest of the tool treats as "not provided".

  // Use IPv4 to avoid the happy eyeballs delay, this is important when we test the
  // connection throughput.
  let hostOption = Option("host", default: "127.0.0.1", description: "The host to connect to.")

  let portOption = Option("port", default: 8080, description: "The port on the host to connect to.")

  // By default every known benchmark is selected.
  let benchmarkOption = Option(
    "benchmarks",
    default: Benchmarks.allCases.map { benchmark in benchmark.rawValue }.joined(separator: ","),
    description: "A comma separated list of benchmarks to run. Defaults to all benchmarks.")

  let caCertificateOption = Option("ca_certificate", default: "", description: "The path to the CA certificate to use.")

  let certificateOption = Option("certificate", default: "", description: "The path to the certificate to use.")

  let privateKeyOption = Option("private_key", default: "", description: "The path to the private key to use.")

  let hostOverrideOption = Option(
    "hostname_override",
    default: "",
    description: "The expected name of the server to use for TLS.")
  // Entry point: `run_benchmarks` acts as the client, `start_server` hosts the Echo service.
  Group { group in
    group.command(
      "run_benchmarks",
      benchmarkOption,
      hostOption,
      portOption,
      caCertificateOption,
      certificateOption,
      privateKeyOption,
      hostOverrideOption
    ) { benchmarkNames, host, port, caCertificatePath, certificatePath, privateKeyPath, hostOverride in
      // Validate every requested benchmark before doing any work, so a typo fails fast.
      // Entries are trimmed so "a, b" works like "a,b"; empty entries are ignored.
      let benchmarks = benchmarkNames
        .split(separator: ",")
        .map { $0.trimmingCharacters(in: .whitespaces) }
        .filter { !$0.isEmpty }
        .map { name -> Benchmarks in
          guard let benchmark = Benchmarks(rawValue: name) else {
            print("unknown benchmark: \(name)")
            exit(1)
          }
          return benchmark
        }

      let sslContext = makeSSLContext(
        caCertificatePath: caCertificatePath,
        certificatePath: certificatePath,
        privateKeyPath: privateKeyPath,
        server: false)

      let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
      // Release the event loop thread once every benchmark has finished.
      defer {
        try? eventLoopGroup.syncShutdownGracefully()
      }

      var configuration = GRPCClientConnection.Configuration(
        target: .hostAndPort(host, port),
        eventLoopGroup: eventLoopGroup)
      if let sslContext = sslContext {
        // The option defaults to "" to mean "not provided" (Commander has no optional options);
        // map that to nil rather than passing an empty string as the hostname override.
        configuration.tlsConfiguration = .init(
          sslContext: sslContext,
          hostnameOverride: hostOverride.isEmpty ? nil : hostOverride)
      }

      let factory = ConnectionFactory(configuration: configuration)
      for benchmark in benchmarks {
        print(benchmark.run(using: factory).asCSV)
      }
    }

    group.command(
      "start_server",
      hostOption,
      portOption,
      caCertificateOption,
      certificateOption,
      privateKeyOption
    ) { host, port, caCertificatePath, certificatePath, privateKeyPath in
      // Resolve TLS first: makeSSLContext exits on bad input, before any threads are started.
      let sslContext = makeSSLContext(
        caCertificatePath: caCertificatePath,
        certificatePath: certificatePath,
        privateKeyPath: privateKeyPath,
        server: true)

      let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount)
      // Release the event loop threads once the server has closed.
      defer {
        try? eventLoopGroup.syncShutdownGracefully()
      }

      let server: GRPCServer
      do {
        server = try GRPCServer.start(
          hostname: host,
          port: port,
          eventLoopGroup: eventLoopGroup,
          serviceProviders: [EchoProvider()],
          tls: sslContext.map { .custom($0) } ?? .none).wait()
      } catch {
        print("unable to start server: \(error)")
        exit(1)
      }

      print("server started on port: \(server.channel.localAddress?.port ?? port)")
      // Stop the program from exiting: block until the server's channel closes.
      try? server.onClose.wait()
    }
  }.run()