Browse Source

Replace Commander usage with our own CLI parsing (#570)

Motivation:

We're relying on third-party code for parsing CLI arguments, however,
this is only being used for interop tests and the echo example and isn't
required to use the GRPC module. However when another package depends on
GRPC, SwiftPM will also bring in the third-party code (and all of its
transitive dependencies) even if no targets used by the other package
rely on that code.

Modifications:

Remove usages of "Commander" and replace with some basic CLI parsing.

Result:

Our CLIs are less colorful, but we drop a dependency.
George Barnett 6 years ago
parent
commit
589611a11a

+ 0 - 18
Package.resolved

@@ -1,24 +1,6 @@
 {
   "object": {
     "pins": [
-      {
-        "package": "Commander",
-        "repositoryURL": "https://github.com/kylef/Commander.git",
-        "state": {
-          "branch": null,
-          "revision": "dc97e80a1cf5df6c5768528f039c65ad8c564712",
-          "version": "0.9.0"
-        }
-      },
-      {
-        "package": "Spectre",
-        "repositoryURL": "https://github.com/kylef/Spectre.git",
-        "state": {
-          "branch": null,
-          "revision": "f14ff47f45642aa5703900980b014c2e9394b6e5",
-          "version": "0.9.0"
-        }
-      },
       {
         "package": "swift-log",
         "repositoryURL": "https://github.com/apple/swift-log",

+ 0 - 6
Package.swift

@@ -42,9 +42,6 @@ let package = Package(
 
     // Logging API.
     .package(url: "https://github.com/apple/swift-log", from: "1.0.0"),
-
-    // Command line argument parser for our auxiliary command line tools.
-    .package(url: "https://github.com/kylef/Commander.git", from: "0.8.0"),
   ],
   targets: [
     // The main GRPC module.
@@ -107,7 +104,6 @@ let package = Package(
       name: "GRPCInteroperabilityTests",
       dependencies: [
         "GRPCInteroperabilityTestsImplementation",
-        "Commander",
         "Logging",
       ]
     ),
@@ -131,7 +127,6 @@ let package = Package(
         "EchoImplementation",
         "NIO",
         "NIOSSL",
-        "Commander",
       ]
     ),
 
@@ -150,7 +145,6 @@ let package = Package(
         "GRPC",
         "GRPCSampleData",
         "SwiftProtobuf",
-        "Commander"
       ],
       path: "Sources/Examples/Echo/Runtime"
     ),

+ 246 - 176
Sources/Examples/Echo/Runtime/main.swift

@@ -13,8 +13,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-import Commander
-import Dispatch
 import Foundation
 import NIO
 import NIOSSL
@@ -22,210 +20,282 @@ import GRPC
 import GRPCSampleData
 import EchoImplementation
 import EchoModel
+import Logging
 
-// Common flags and options
-let sslFlag = Flag("ssl", description: "if true, use SSL for connections")
-func addressOption(_ address: String) -> Option<String> {
-  return Option("address", default: address, description: "address of server")
-}
-
-let portOption = Option("port", default: 8080)
-let messageOption = Option("message",
-                           default: "Testing 1 2 3",
-                           description: "message to send")
-
-func makeClientTLSConfiguration() -> ClientConnection.Configuration.TLS {
-  let caCert = SampleCertificate.ca
-  let clientCert = SampleCertificate.client
-  precondition(!caCert.isExpired && !clientCert.isExpired,
-               "SSL certificates are expired. Please submit an issue at https://github.com/grpc/grpc-swift.")
-
-  return .init(certificateChain: [.certificate(clientCert.certificate)],
-               privateKey: .privateKey(SamplePrivateKey.client),
-               trustRoots: .certificates([caCert.certificate]),
-               certificateVerification: .noHostnameVerification)
-}
-
-func makeServerTLSConfiguration() -> Server.Configuration.TLS {
-  let caCert = SampleCertificate.ca
-  let serverCert = SampleCertificate.server
-  precondition(!caCert.isExpired && !serverCert.isExpired,
-               "SSL certificates are expired. Please submit an issue at https://github.com/grpc/grpc-swift.")
+// MARK: - Argument parsing
 
-  return .init(certificateChain: [.certificate(serverCert.certificate)],
-               privateKey: .privateKey(SamplePrivateKey.server),
-               trustRoots: .certificates([caCert.certificate]))
+enum RPC: String {
+  case get
+  case collect
+  case expand
+  case update
 }
 
-/// Create en `EchoClient` and wait for it to initialize. Returns nil if initialisation fails.
-func makeEchoClient(address: String, port: Int, ssl: Bool) -> Echo_EchoServiceClient? {
-  let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
+enum Command {
+  case server(port: Int, useTLS: Bool)
+  case client(host: String, port: Int, useTLS: Bool, rpc: RPC, message: String)
 
-  let configuration = ClientConnection.Configuration(
-    target: .hostAndPort(address, port),
-    eventLoopGroup: eventLoopGroup,
-    tls: ssl ? makeClientTLSConfiguration() : nil)
+  init?(from args: [String]) {
+    guard !args.isEmpty else {
+      return nil
+    }
 
-  return Echo_EchoServiceClient(connection: ClientConnection(configuration: configuration))
-}
+    var args = args
+    switch args.removeFirst() {
+    case "server":
+      guard (args.count == 1 || args.count == 2),
+        let port = args.popLast().flatMap(Int.init),
+        let useTLS = Command.parseTLSArg(args.popLast())
+        else {
+          return nil
+      }
+      self = .server(port: port, useTLS: useTLS)
+
+    case "client":
+      guard (args.count == 4 || args.count == 5),
+        let message = args.popLast(),
+        let rpc = args.popLast().flatMap(RPC.init),
+        let port = args.popLast().flatMap(Int.init),
+        let host = args.popLast(),
+        let useTLS = Command.parseTLSArg(args.popLast())
+        else {
+          return nil
+      }
+      self = .client(host: host, port: port, useTLS: useTLS, rpc: rpc, message: message)
+
+    default:
+      return nil
+    }
+  }
 
-Group {
-  $0.command("serve",
-             sslFlag,
-             addressOption("localhost"),
-             portOption,
-             description: "Run an echo server.") { ssl, address, port in
-    let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
-
-    var configuration = Server.Configuration(
-      target: .hostAndPort(address, port),
-      eventLoopGroup: eventLoopGroup,
-      serviceProviders: [EchoProvider()])
-
-    if ssl {
-      print("starting secure server")
-      configuration.tls = makeServerTLSConfiguration()
-    } else {
-      print("starting insecure server")
+  private static func parseTLSArg(_ arg: String?) -> Bool? {
+    switch arg {
+    case .some("--tls"):
+      return true
+    case .none, .some("--notls"):
+      return false
+    default:
+      return nil
     }
+  }
+}
 
-    let server = try! Server.start(configuration: configuration)
-      .wait()
+func printUsageAndExit(program: String) -> Never {
+  print("""
+    Usage: \(program) COMMAND [OPTIONS...]
+
+    Commands:
+      server [--tls|--notls] PORT                     Starts the echo server on the given port.
+
+      client [--tls|--notls] HOST PORT RPC MESSAGE    Connects to the echo server on the given host
+                                                      host and port and calls the RPC with the
+                                                      provided message. See below for a list of
+                                                      possible RPCs.
+
+    RPCs:
+      * get      (unary)
+      * collect  (client streaming)
+      * expand   (server streaming)
+      * update   (bidirectional streaming)
+    """)
+  exit(1)
+}
 
-    // This blocks to keep the main thread from finishing while the server runs,
-    // but the server never exits. Kill the process to stop it.
-    try server.onClose.wait()
+func main(args: [String]) {
+  var args = args
+  let program = args.removeFirst()
+  guard let command = Command(from: args) else {
+    printUsageAndExit(program: program)
   }
 
-  $0.command(
-    "get",
-    sslFlag,
-    addressOption("localhost"),
-    portOption,
-    messageOption,
-    description: "Perform a unary get()."
-  ) { ssl, address, port, message in
-    print("calling get")
-    guard let echo = makeEchoClient(address: address, port: port, ssl: ssl) else { return }
-
-    var requestMessage = Echo_EchoRequest()
-    requestMessage.text = message
-
-    print("get sending: \(requestMessage.text)")
-    let get = echo.get(requestMessage)
-    get.response.whenSuccess { response in
-      print("get received: \(response.text)")
-    }
+  // Reduce the logging verbosity.
+  LoggingSystem.bootstrap {
+    var handler = StreamLogHandler.standardOutput(label: $0)
+    handler.logLevel = .warning
+    return handler
+  }
 
-    get.response.whenFailure { error in
-      print("get response failed with error: \(error)")
-    }
+  // Okay, we're nearly ready to start, create an `EventLoopGroup` most suitable for our platform.
+  let group = PlatformSupport.makeEventLoopGroup(loopCount: 1)
+  defer {
+    try! group.syncShutdownGracefully()
+  }
 
-    // wait() on the status to stop the program from exiting.
+  // Now run the server/client.
+  switch command {
+  case let .server(port: port, useTLS: useTLS):
     do {
-      let status = try get.status.wait()
-      print("get completed with status: \(status)")
+      try startEchoServer(group: group, port: port, useTLS: useTLS)
     } catch {
-      print("get status failed with error: \(error)")
+      print("Error running server: \(error)")
     }
+
+  case let .client(host: host, port: port, useTLS: useTLS, rpc: rpc, message: message):
+    let client = makeClient(group: group, host: host, port: port, useTLS: useTLS)
+    callRPC(rpc, using: client, message: message)
   }
+}
 
-  $0.command(
-    "expand",
-    sslFlag,
-    addressOption("localhost"),
-    portOption,
-    messageOption,
-    description: "Perform a server-streaming expand()."
-  ) { ssl, address, port, message in
-    print("calling expand")
-    guard let echo = makeEchoClient(address: address, port: port, ssl: ssl) else { return }
-
-    let requestMessage = Echo_EchoRequest.with { $0.text = message }
-
-    print("expand sending: \(requestMessage.text)")
-    let expand = echo.expand(requestMessage) { response in
-      print("expand received: \(response.text)")
-    }
+// MARK: - Server / Client
+
+func startEchoServer(group: EventLoopGroup, port: Int, useTLS: Bool) throws {
+  // Configure the server:
+  var configuration = Server.Configuration(
+    target: .hostAndPort("localhost", port),
+    eventLoopGroup: group,
+    serviceProviders: [EchoProvider()]
+  )
+
+  if useTLS {
+    // We're using some self-signed certs here: check they aren't expired.
+    let caCert = SampleCertificate.ca
+    let serverCert = SampleCertificate.server
+    precondition(
+      !caCert.isExpired && !serverCert.isExpired,
+      "SSL certificates are expired. Please submit an issue at https://github.com/grpc/grpc-swift."
+    )
+
+    configuration.tls = .init(
+      certificateChain: [.certificate(serverCert.certificate)],
+      privateKey: .privateKey(SamplePrivateKey.server),
+      trustRoots: .certificates([caCert.certificate])
+    )
+    print("starting secure server")
+  } else {
+    print("starting insecure server")
+  }
 
-    // wait() on the status to stop the program from exiting.
-    do {
-      let status = try expand.status.wait()
-      print("expand completed with status: \(status)")
-    } catch {
-      print("expand status failed with error: \(error)")
-    }
+  let server = try Server.start(configuration: configuration).wait()
+  print("started server: \(server.channel.localAddress!)")
+
+  // This blocks to keep the main thread from finishing while the server runs,
+  // but the server never exits. Kill the process to stop it.
+  try server.onClose.wait()
+}
+
+func makeClient(group: EventLoopGroup, host: String, port: Int, useTLS: Bool) -> Echo_EchoServiceClient {
+  // Configure the connection:
+  var configuration = ClientConnection.Configuration(
+    target: .hostAndPort(host, port),
+    eventLoopGroup: group
+  )
+
+  if useTLS {
+    // We're using some self-signed certs here: check they aren't expired.
+    let caCert = SampleCertificate.ca
+    let clientCert = SampleCertificate.client
+    precondition(
+      !caCert.isExpired && !clientCert.isExpired,
+      "SSL certificates are expired. Please submit an issue at https://github.com/grpc/grpc-swift."
+    )
+
+    configuration.tls = .init(
+      certificateChain: [.certificate(clientCert.certificate)],
+      privateKey: .privateKey(SamplePrivateKey.client),
+      trustRoots: .certificates([caCert.certificate])
+    )
   }
 
-  $0.command(
-    "collect",
-    sslFlag,
-    addressOption("localhost"),
-    portOption,
-    messageOption,
-    description: "Perform a client-streaming collect()."
-  ) { ssl, address, port, message in
-    print("calling collect")
-    guard let echo = makeEchoClient(address: address, port: port, ssl: ssl) else { return }
-
-    let collect = echo.collect()
-
-    var queue = collect.newMessageQueue()
-    for part in message.components(separatedBy: " ") {
-      var requestMessage = Echo_EchoRequest()
-      requestMessage.text = part
-      print("collect sending: \(requestMessage.text)")
-      queue = queue.flatMap { collect.sendMessage(requestMessage) }
-    }
-    queue.whenSuccess { collect.sendEnd(promise: nil) }
+  // Start the connection and create the client:
+  let connection = ClientConnection(configuration: configuration)
+  return Echo_EchoServiceClient(connection: connection)
+}
 
-    collect.response.whenSuccess { respone in
-      print("collect received: \(respone.text)")
+func callRPC(_ rpc: RPC, using client: Echo_EchoServiceClient, message: String) {
+  do {
+    switch rpc {
+    case .get:
+      try echoGet(client: client, message: message)
+    case .collect:
+      try echoCollect(client: client, message: message)
+    case .expand:
+      try echoExpand(client: client, message: message)
+    case .update:
+      try echoUpdate(client: client, message: message)
     }
+  } catch {
+    print("\(rpc) RPC failed: \(error)")
+  }
+}
 
-    collect.response.whenFailure { error in
-      print("collect response failed with error: \(error)")
+func echoGet(client: Echo_EchoServiceClient, message: String) throws {
+  // Get is a unary call.
+  let get = client.get(.with { $0.text = message })
+
+  // Register a callback for the response:
+  get.response.whenComplete { result in
+    switch result {
+    case .success(let response):
+      print("get receieved: \(response.text)")
+    case .failure(let error):
+      print("get failed with error: \(error)")
     }
+  }
 
-    // wait() on the status to stop the program from exiting.
-    do {
-      let status = try collect.status.wait()
-      print("collect completed with status: \(status)")
-    } catch {
-      print("collect status failed with error: \(error)")
-    }
+  // wait() for the call to terminate
+  let status = try get.status.wait()
+  print("get completed with status: \(status.code)")
+}
+
+func echoCollect(client: Echo_EchoServiceClient, message: String) throws {
+  // Collect is a client streaming call
+  let collect = client.collect()
+
+  // Split the messages and map them into requests
+  let messages = message.components(separatedBy: " ").map { part in
+    Echo_EchoRequest.with { $0.text = part }
   }
 
-  $0.command(
-    "update",
-    sslFlag,
-    addressOption("localhost"),
-    portOption,
-    messageOption,
-    description: "Perform a bidirectional-streaming update()."
-  ) { ssl, address, port, message in
-    print("calling update")
-    guard let echo = makeEchoClient(address: address, port: port, ssl: ssl) else { return }
-
-    let update = echo.update { response in
-      print("update received: \(response.text)")
+  // Stream the to the service (this can also be done on individual requests using `sendMessage`).
+  collect.sendMessages(messages, promise: nil)
+  // Close the request stream.
+  collect.sendEnd(promise: nil)
+
+  // Register a callback for the response:
+  collect.response.whenComplete { result in
+    switch result {
+    case .success(let response):
+      print("collect receieved: \(response.text)")
+    case .failure(let error):
+      print("collect failed with error: \(error)")
     }
+  }
 
-    var queue = update.newMessageQueue()
-    for part in message.components(separatedBy: " ") {
-      var requestMessage = Echo_EchoRequest()
-      requestMessage.text = part
-      print("update sending: \(requestMessage.text)")
-      queue = queue.flatMap { update.sendMessage(requestMessage) }
-    }
-    queue.whenSuccess { update.sendEnd(promise: nil) }
+  // wait() for the call to terminate
+  let status = try collect.status.wait()
+  print("collect completed with status: \(status.code)")
+}
 
-    // wait() on the status to stop the program from exiting.
-    do {
-      let status = try update.status.wait()
-      print("update completed with status: \(status)")
-    } catch {
-      print("update status failed with error: \(error)")
-    }
+func echoExpand(client: Echo_EchoServiceClient, message: String) throws {
+  // Expand is a server streaming call; provide a response handler.
+  let expand = client.expand(.with { $0.text = message}) { response in
+    print("expand received: \(response.text)")
+  }
+
+  // wait() for the call to terminate
+  let status = try expand.status.wait()
+  print("expand completed with status: \(status.code)")
+}
+
+func echoUpdate(client: Echo_EchoServiceClient, message: String) throws {
+  // Update is a bidirectional streaming call; provide a response handler.
+  let update = client.update { response in
+    print("update received: \(response.text)")
   }
-}.run()
+
+  // Split the messages and map them into requests
+  let messages = message.components(separatedBy: " ").map { part in
+    Echo_EchoRequest.with { $0.text = part }
+  }
+
+  // Stream the to the service (this can also be done on individual requests using `sendMessage`).
+  update.sendMessages(messages, promise: nil)
+  // Close the request stream.
+  update.sendEnd(promise: nil)
+
+  // wait() for the call to terminate
+  let status = try update.status.wait()
+  print("update completed with status: \(status.code)")
+}
+
+main(args: CommandLine.arguments)

+ 106 - 93
Sources/GRPCInteroperabilityTests/main.swift

@@ -18,7 +18,6 @@ import GRPC
 import NIO
 import NIOSSL
 import GRPCInteroperabilityTestsImplementation
-import Commander
 import Logging
 
 // Reduce stdout noise.
@@ -70,116 +69,130 @@ func makeRunnableTest(name: String) throws -> InteroperabilityTest {
   return testCase.makeTest()
 }
 
-/// Runs the given block and exits with code 1 if the block throws an error.
-///
-/// The "Commander" CLI elides thrown errors in favour of its own. This function is intended purely
-/// to work around this limitation by printing any errors before exiting.
-func exitOnThrow<T>(block: () throws -> T) -> T {
-  do {
-    return try block()
-  } catch {
-    print(error)
-    exit(1)
-  }
-}
-
 // MARK: - Command line options and "main".
 
-let serverHostOption = Option(
-  "server_host",
-  default: "localhost",
-  description: "The server host to connect to.")
-
-let serverPortOption = Option(
-  "server_port",
-  default: 8080,
-  description: "The server port to connect to.")
-
-let testCaseOption = Option(
-  "test_case",
-  default: InteroperabilityTestCase.emptyUnary.name,
-  description: "The name of the test case to execute.")
-
-/// The spec requires a string (as opposed to having a flag) to indicate whether TLS is enabled or
-/// disabled.
-let useTLSOption = Option(
-  "use_tls",
-  default: "false",
-  description: "Whether to use an encrypted or plaintext connection (true|false).") { value in
-  let lowercased = value.lowercased()
-  switch lowercased {
-  case "true", "false":
-    return lowercased
-  default:
-    throw ArgumentError.invalidType(value: value, type: "boolean", argument: "use_tls")
-  }
+func printUsageAndExit(program: String) -> Never {
+  print("""
+    Usage: \(program) COMMAND [OPTIONS...]
+
+    Commands:
+      start_server [--tls|--notls] PORT         Starts the interoperability test server.
+
+      run_test [--tls|--notls] HOST PORT NAME   Run an interoperability test.
+
+      list_tests                                List all interoperability test names.
+    """)
+  exit(1)
 }
 
-let portOption = Option(
-  "port",
-  default: 8080,
-  description: "The port to listen on.")
-
-let group = Group { group in
-  group.command(
-    "run_test",
-    serverHostOption,
-    serverPortOption,
-    useTLSOption,
-    testCaseOption,
-    description: "Run a single test. See 'list_tests' for available test names."
-  ) { host, port, useTLS, testCaseName in
-    let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
-    defer {
-      try? eventLoopGroup.syncShutdownGracefully()
+enum Command {
+  case startServer(port: Int, useTLS: Bool)
+  case runTest(name: String, host: String, port: Int, useTLS: Bool)
+  case listTests
+
+  init?(from args: [String]) {
+    guard !args.isEmpty else {
+      return nil
     }
 
-    exitOnThrow {
-      let instance = try makeRunnableTest(name: testCaseName)
-      let connection = try makeInteroperabilityTestClientConnection(
-        host: host,
-        port: port,
-        eventLoopGroup: eventLoopGroup,
-        useTLS: useTLS == "true")
-      try runTest(instance, name: testCaseName, connection: connection)
+    var args = args
+    let command = args.removeFirst()
+    switch command {
+    case "start_server":
+      guard (args.count == 2 || args.count == 3),
+        let port = args.popLast().flatMap(Int.init),
+        let useTLS = Command.parseTLSArg(args.popLast())
+        else {
+          return nil
+      }
+      self = .startServer(port: port, useTLS: useTLS)
+
+    case "run_test":
+      guard (args.count == 3 || args.count == 4),
+        let name = args.popLast(),
+        let port = args.popLast().flatMap(Int.init),
+        let host = args.popLast(),
+        let useTLS = Command.parseTLSArg(args.popLast())
+        else {
+          return nil
+      }
+      self = .runTest(name: name, host: host, port: port, useTLS: useTLS)
+
+    case "list_tests":
+      self = .listTests
+
+    default:
+      return nil
     }
   }
 
-  group.command(
-    "start_server",
-    portOption,
-    useTLSOption,
-    description: "Starts the test server."
-  ) { port, useTls in
-    let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
+  private static func parseTLSArg(_ arg: String?) -> Bool? {
+    switch arg {
+    case .some("--tls"):
+      return true
+    case .none, .some("--notls"):
+      return false
+    default:
+      return nil
+    }
+  }
+}
+
+func main(args: [String]) {
+  let program = args.first ?? "GRPC Interoperability Tests"
+  guard let command = Command(from: .init(args.dropFirst())) else {
+    printUsageAndExit(program: program)
+  }
+
+  switch command {
+  case .listTests:
+    InteroperabilityTestCase.allCases.forEach {
+      print($0.name)
+    }
+
+  case let .startServer(port: port, useTLS: useTLS):
+    let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
     defer {
-      try? eventLoopGroup.syncShutdownGracefully()
+      try! group.syncShutdownGracefully()
     }
 
-    let server = exitOnThrow {
-      return try makeInteroperabilityTestServer(
-        host: "localhost",
-        port: port,
-        eventLoopGroup: eventLoopGroup,
-        useTLS: useTls == "true")
+    do {
+      let server = try makeInteroperabilityTestServer(port: port, eventLoopGroup: group, useTLS: useTLS).wait()
+      print("server started: \(server.channel.localAddress!)")
+
+      // We never call close; run until we get killed.
+      try server.onClose.wait()
+    } catch {
+      print("unable to start interoperability test server")
     }
 
-    server.map { $0.channel.localAddress?.port }.whenSuccess {
-      print("Server started on port \($0!)")
+  case let .runTest(name: name, host: host, port: port, useTLS: useTLS):
+    let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
+    defer {
+      try! group.syncShutdownGracefully()
     }
 
-    // We never call close; run until we get killed.
-    try server.flatMap { $0.onClose }.wait()
-  }
+    let test: InteroperabilityTest
+    do {
+      test = try makeRunnableTest(name: name)
+    } catch {
+      print("\(error)")
+      exit(1)
+    }
 
-  group.command(
-    "list_tests",
-    description: "List available test case names."
-  ) {
-    InteroperabilityTestCase.allCases.forEach {
-      print($0.name)
+    do {
+      let connection = try makeInteroperabilityTestClientConnection(
+        host: host,
+        port: port,
+        eventLoopGroup: group,
+        useTLS: useTLS
+      )
+      try runTest(test, name: name, connection: connection)
+    } catch {
+      print("Error running test: \(error)")
+      exit(1)
     }
   }
 }
 
-group.run()
+main(args: CommandLine.arguments)

+ 172 - 97
Sources/GRPCPerformanceTests/main.swift

@@ -17,9 +17,9 @@ import Foundation
 import GRPC
 import NIO
 import NIOSSL
-import Commander
 import EchoImplementation
 import EchoModel
+import Logging
 
 struct ConnectionFactory {
   var configuration: ClientConnection.Configuration
@@ -317,116 +317,191 @@ enum Benchmarks: String, CaseIterable {
   }
 }
 
-let hostOption = Option(
-  "host",
-  // Use IPv4 to avoid the happy eyeballs delay, this is important when we test the
-  // connection throughput.
-  default: "127.0.0.1",
-  description: "The host to connect to.")
-
-let portOption = Option(
-  "port",
-  default: 8080,
-  description: "The port on the host to connect to.")
-
-let benchmarkOption = Option(
-  "benchmarks",
-  default: Benchmarks.allCases.map { $0.rawValue }.joined(separator: ","),
-  description: "A comma separated list of benchmarks to run. Defaults to all benchmarks.")
-
-let caCertificateOption = Option(
-  "ca_certificate",
-  default: "",
-  description: "The path to the CA certificate to use.")
-
-let certificateOption = Option(
-  "certificate",
-  default: "",
-  description: "The path to the certificate to use.")
-
-let privateKeyOption = Option(
-  "private_key",
-  default: "",
-  description: "The path to the private key to use.")
-
-let hostOverrideOption = Option(
-  "hostname_override",
-  default: "",
-  description: "The expected name of the server to use for TLS.")
-
-Group { group in
-  group.command(
-    "run_benchmarks",
-    benchmarkOption,
-    hostOption,
-    portOption,
-    caCertificateOption,
-    certificateOption,
-    privateKeyOption,
-    hostOverrideOption
-  ) { benchmarkNames, host, port, caCertificatePath, certificatePath, privateKeyPath, hostOverride in
-    let tlsConfiguration = try makeClientTLSConfiguration(
-      caCertificatePath: caCertificatePath,
-      certificatePath: certificatePath,
-      privateKeyPath: privateKeyPath)
-
-    let configuration = ClientConnection.Configuration(
-      target: .hostAndPort(host, port),
-      eventLoopGroup: MultiThreadedEventLoopGroup(numberOfThreads: 1),
-      tls: tlsConfiguration)
-
-    let factory = ConnectionFactory(configuration: configuration)
-
-    let names = benchmarkNames.components(separatedBy: ",")
-
-    // validate the benchmarks exist before running any
-    let benchmarks = names.map { name -> Benchmarks in
-      guard let benchnark = Benchmarks(rawValue: name) else {
-        print("unknown benchmark: \(name)")
-        exit(1)
-      }
-      return benchnark
+enum Command {
+  case listBenchmarks
+  case benchmark(name: String, host: String, port: Int, tls: (ca: String, cert: String)?)
+  case server(port: Int, tls: (ca: String, cert: String, key: String)?)
+
+  init?(from args: [String]) {
+    guard !args.isEmpty else {
+      return nil
     }
 
-    benchmarks.forEach { benchmark in
-      let results = benchmark.run(using: factory)
-      print(results.asCSV)
+    var args = args
+    let command = args.removeFirst()
+    switch command {
+    case "server":
+      guard let port = args.popLast().flatMap(Int.init) else {
+        return nil
+      }
+
+      let caPath = args.suffixOfFirst(prefixedWith: "--caPath=")
+      let certPath = args.suffixOfFirst(prefixedWith: "--certPath=")
+      let keyPath = args.suffixOfFirst(prefixedWith: "--keyPath=")
+
+      // We need all or nothing here:
+      switch (caPath, certPath, keyPath) {
+      case let (.some(ca), .some(cert), .some(key)):
+        self = .server(port: port, tls: (ca: ca, cert: cert, key: key))
+      case (.none, .none, .none):
+        self = .server(port: port, tls: nil)
+      default:
+        return nil
+      }
+
+    case "benchmark":
+      guard let name = args.popLast(),
+        let port = args.popLast().flatMap(Int.init),
+        let host = args.popLast()
+        else {
+          return nil
+      }
+
+      let caPath = args.suffixOfFirst(prefixedWith: "--caPath=")
+      let certPath = args.suffixOfFirst(prefixedWith: "--certPath=")
+      // We need all or nothing here:
+      switch (caPath, certPath) {
+      case let (.some(ca), .some(cert)):
+        self = .benchmark(name: name, host: host, port: port, tls: (ca: ca, cert: cert))
+      case (.none, .none):
+        self = .benchmark(name: name, host: host, port: port, tls: nil)
+      default:
+        return nil
+      }
+
+    case "list_benchmarks":
+      self = .listBenchmarks
+
+    default:
+      return nil
     }
   }
+}
 
-  group.command(
-    "start_server",
-    hostOption,
-    portOption,
-    caCertificateOption,
-    certificateOption,
-    privateKeyOption
-  ) { host, port, caCertificatePath, certificatePath, privateKeyPath in
-    let group = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount)
+func printUsageAndExit(program: String) -> Never {
+  print("""
+  Usage: \(program) COMMAND [OPTIONS...]
+
+  benchmark:
+    Run the given benchmark (see 'list_benchmarks' for possible options) against a server on the
+    specified host and port. TLS may be used by spefifying the path to the PEM formatted
+    certificate and CA certificate.
 
-    let tlsConfiguration = try makeServerTLSConfiguration(
-      caCertificatePath: caCertificatePath,
-      certificatePath: certificatePath,
-      privateKeyPath: privateKeyPath)
+      benchmark [--ca=CA --cert=CERT] HOST PORT BENCHMARK_NAME
 
-    let configuration = Server.Configuration(
-      target: .hostAndPort(host, port),
-      eventLoopGroup: group,
-      serviceProviders: [EchoProvider()],
-      tls: tlsConfiguration)
+    Note: eiether all or none of CA and CERT must be provided.
 
-    let server: Server
+  list_benchmarks:
+    List the available benchmarks to run.
+
+  server:
+    Start the server on the given PORT. TLS may be used by specifying the paths to the PEM formatted
+    certificate, private key and CA certificate.
+
+      server [--ca=CA --cert=CERT --key=KEY] PORT
+
+    Note: eiether all or none of CA, CERT and KEY must be provided.
+  """)
+  exit(1)
+}
+
+fileprivate extension Array where Element == String {
+  func suffixOfFirst(prefixedWith prefix: String) -> String? {
+    return self.first {
+      $0.hasPrefix(prefix)
+    }.map {
+      String($0.dropFirst(prefix.count))
+    }
+  }
+}
+
+func main(args: [String]) {
+  var args = args
+  let program = args.removeFirst()
+  guard let command = Command(from: args) else {
+    printUsageAndExit(program: program)
+  }
+
+  switch command {
+  case let .server(port: port, tls: tls):
+    let group = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount)
+    defer {
+      try! group.syncShutdownGracefully()
+    }
+
+    // Quieten the logs.
+    LoggingSystem.bootstrap {
+      var handler = StreamLogHandler.standardOutput(label: $0)
+      handler.logLevel = .warning
+      return handler
+    }
 
     do {
-      server = try Server.start(configuration: configuration).wait()
+      let configuration = try Server.Configuration(
+        target: .hostAndPort("localhost", port),
+        eventLoopGroup: group,
+        serviceProviders: [EchoProvider()],
+        tls: tls.map { tlsArgs in
+          return .init(
+            certificateChain: try NIOSSLCertificate.fromPEMFile(tlsArgs.cert).map { .certificate($0) },
+            privateKey: .file(tlsArgs.key),
+            trustRoots: .file(tlsArgs.ca)
+          )
+        }
+      )
+
+      let server = try Server.start(configuration: configuration).wait()
+      print("server started on port: \(server.channel.localAddress?.port ?? port)")
+
+      // Stop the program from exiting.
+      try? server.onClose.wait()
     } catch {
       print("unable to start server: \(error)")
       exit(1)
     }
 
-    print("server started on port: \(server.channel.localAddress?.port ?? port)")
+  case let .benchmark(name: name, host: host, port: port, tls: tls):
+    guard let benchmark = Benchmarks(rawValue: name) else {
+      printUsageAndExit(program: program)
+    }
+
+    // Quieten the logs.
+    LoggingSystem.bootstrap {
+      var handler = StreamLogHandler.standardOutput(label: $0)
+      handler.logLevel = .critical
+      return handler
+    }
+
+    let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
+    defer {
+      try! group.syncShutdownGracefully()
+    }
+
+    do {
+      let configuration = try ClientConnection.Configuration(
+        target: .hostAndPort(host, port),
+        eventLoopGroup: group,
+        tls: tls.map { tlsArgs in
+          return .init(
+            certificateChain: try NIOSSLCertificate.fromPEMFile(tlsArgs.cert).map { .certificate($0) },
+            trustRoots: .file(tlsArgs.ca)
+          )
+        }
+      )
+
+      let factory = ConnectionFactory(configuration: configuration)
+      let results = benchmark.run(using: factory)
+      print(results.asCSV)
+    } catch {
+      print("unable to run benchmark: \(error)")
+      exit(1)
+    }
 
-    // Stop the program from exiting.
-    try? server.onClose.wait()
+  case .listBenchmarks:
+    Benchmarks.allCases.forEach {
+      print($0.rawValue)
+    }
   }
-}.run()
+}
+
+main(args: CommandLine.arguments)