Explorar o código

Add a server builder (#752)

Motivation:

Configuring a server with err... `Configuration` can be a bit unwieldy
and makes it easy for users to not use TLS. It's also a slight SemVer burden
since we'd want to add a new `init` each time we add a member.

Modifications:

- Add a `Server.Builder`
- Replace existing use of `Server.Configuration` with the builder
- Update docs

Result:

- Server starting API is nicer
- You have to type `insecure` to not use TLS
George Barnett %!s(int64=5) %!d(string=hai) anos
pai
achega
9a8d0239ad

+ 8 - 12
Sources/Examples/Echo/Runtime/main.swift

@@ -139,12 +139,7 @@ func main(args: [String]) {
 // MARK: - Server / Client
 
 func startEchoServer(group: EventLoopGroup, port: Int, useTLS: Bool) throws {
-  // Configure the server:
-  var configuration = Server.Configuration(
-    target: .hostAndPort("localhost", port),
-    eventLoopGroup: group,
-    serviceProviders: [EchoProvider()]
-  )
+  let builder: Server.Builder
 
   if useTLS {
     // We're using some self-signed certs here: check they aren't expired.
@@ -155,17 +150,18 @@ func startEchoServer(group: EventLoopGroup, port: Int, useTLS: Bool) throws {
       "SSL certificates are expired. Please submit an issue at https://github.com/grpc/grpc-swift."
     )
 
-    configuration.tls = .init(
-      certificateChain: [.certificate(serverCert.certificate)],
-      privateKey: .privateKey(SamplePrivateKey.server),
-      trustRoots: .certificates([caCert.certificate])
-    )
+    builder = Server.secure(group: group, certificateChain: [serverCert.certificate], privateKey: SamplePrivateKey.server)
+      .withTLS(trustRoots: .certificates([caCert.certificate]))
     print("starting secure server")
   } else {
     print("starting insecure server")
+    builder = Server.insecure(group: group)
   }
 
-  let server = try Server.start(configuration: configuration).wait()
+  let server = try builder.withServiceProviders([EchoProvider()])
+    .bind(host: "localhost", port: port)
+    .wait()
+
   print("started server: \(server.channel.localAddress!)")
 
   // This blocks to keep the main thread from finishing while the server runs,

+ 4 - 8
Sources/Examples/HelloWorld/Server/main.swift

@@ -30,15 +30,11 @@ defer {
   try! group.syncShutdownGracefully()
 }
 
-// Create some configuration for the server:
-let configuration = Server.Configuration(
-  target: .hostAndPort("localhost", 0),
-  eventLoopGroup: group,
-  serviceProviders: [GreeterProvider()]
-)
-
 // Start the server and print its address once it has started.
-let server = Server.start(configuration: configuration)
+let server = Server.insecure(group: group)
+  .withServiceProviders([GreeterProvider()])
+  .bind(host: "localhost", port: 0)
+
 server.map {
   $0.channel.localAddress
 }.whenSuccess { address in

+ 4 - 8
Sources/Examples/RouteGuide/Server/main.swift

@@ -51,15 +51,11 @@ func main(args: [String]) throws {
   // Create a provider using the features we read.
   let provider = RouteGuideProvider(features: features)
 
-  // Tie these together in some configuration:
-  let configuration = Server.Configuration(
-    target: .hostAndPort("localhost", 0),
-    eventLoopGroup: group,
-    serviceProviders: [provider]
-  )
-
   // Start the server and print its address once it has started.
-  let server = Server.start(configuration: configuration)
+  let server = Server.insecure(group: group)
+    .withServiceProviders([provider])
+    .bind(host: "localhost", port: 0)
+
   server.map {
     $0.channel.localAddress
   }.whenSuccess { address in

+ 123 - 0
Sources/GRPC/ServerBuilder.swift

@@ -0,0 +1,123 @@
+/*
+ * Copyright 2020, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import NIO
+import NIOSSL
+
+extension Server {
+  public class Builder {
+    private let group: EventLoopGroup
+    private var maybeTLS: Server.Configuration.TLS? { return nil }
+    private var providers: [CallHandlerProvider] = []
+    private var errorDelegate: ServerErrorDelegate?
+    private var messageEncoding: ServerMessageEncoding = .disabled
+
+    fileprivate init(group: EventLoopGroup) {
+      self.group = group
+    }
+
+    public class Secure: Builder {
+      private var tls: Server.Configuration.TLS
+      override var maybeTLS: Server.Configuration.TLS? {
+        return self.tls
+      }
+
+      fileprivate init(group: EventLoopGroup, certificateChain: [NIOSSLCertificate], privateKey: NIOSSLPrivateKey) {
+        self.tls = .init(
+          certificateChain: certificateChain.map { .certificate($0) },
+          privateKey: .privateKey(privateKey)
+        )
+        super.init(group: group)
+      }
+    }
+
+    public func bind(host: String, port: Int) -> EventLoopFuture<Server> {
+      let configuration = Server.Configuration(
+        target: .hostAndPort(host, port),
+        eventLoopGroup: self.group,
+        serviceProviders: self.providers,
+        errorDelegate: self.errorDelegate,
+        tls: self.maybeTLS,
+        messageEncoding: self.messageEncoding
+      )
+      return Server.start(configuration: configuration)
+    }
+  }
+}
+
+extension Server.Builder {
+  /// Sets the server error delegate.
+  @discardableResult
+  public func withErrorDelegate(_ delegate: ServerErrorDelegate?) -> Self {
+    self.errorDelegate = delegate
+    return self
+  }
+}
+
+extension Server.Builder {
+  /// Sets the service providers that this server should offer. Note that calling this multiple
+  /// times will override any previously set providers.
+  public func withServiceProviders(_ providers: [CallHandlerProvider]) -> Self {
+    self.providers = providers
+    return self
+  }
+}
+
+extension Server.Builder {
+  /// Sets the message compression configuration. Compression is disabled if this is not configured
+  /// and any RPCs using compression will not be accepted.
+  public func withMessageCompression(_ encoding: ServerMessageEncoding) -> Self {
+    self.messageEncoding = encoding
+    return self
+  }
+}
+
+extension Server.Builder.Secure {
+  /// Sets the trust roots to use to validate certificates. This only needs to be provided if you
+  /// intend to validate certificates. Defaults to the system provided trust store (`.default`) if
+  /// not set.
+  @discardableResult
+  public func withTLS(trustRoots: NIOSSLTrustRoots) -> Self {
+    self.tls.trustRoots = trustRoots
+    return self
+  }
+
+  /// Sets whether certificates should be verified. Defaults to `.fullVerification` if not set.
+  @discardableResult
+  public func withTLS(certificateVerification: CertificateVerification) -> Self {
+    self.tls.certificateVerification = certificateVerification
+    return self
+  }
+}
+
+extension Server {
+  /// Returns an insecure `Server` builder which is *not configured with TLS*.
+  public static func insecure(group: EventLoopGroup) -> Builder {
+    return Builder(group: group)
+  }
+
+  /// Returns a `Server` builder configured with TLS.
+  public static func secure(
+    group: EventLoopGroup,
+    certificateChain: [NIOSSLCertificate],
+    privateKey: NIOSSLPrivateKey
+  ) -> Builder.Secure {
+    return Builder.Secure(
+      group: group,
+      certificateChain: certificateChain,
+      privateKey: privateKey
+    )
+  }
+}

+ 9 - 12
Sources/GRPCInteroperabilityTestsImplementation/InteroperabilityTestServer.swift

@@ -37,12 +37,7 @@ public func makeInteroperabilityTestServer(
   serviceProviders: [CallHandlerProvider] = [TestServiceProvider()],
   useTLS: Bool
 ) throws -> EventLoopFuture<Server> {
-  var configuration = Server.Configuration(
-    target: .hostAndPort(host, port),
-    eventLoopGroup: eventLoopGroup,
-    serviceProviders: serviceProviders,
-    messageEncoding: .enabled(.init(decompressionLimit: .absolute(1024 * 1024)))
-  )
+  let builder: Server.Builder
 
   if useTLS {
     print("Using the gRPC interop testing CA for TLS; clients should expect the host to be '*.test.google.fr'")
@@ -51,12 +46,14 @@ public func makeInteroperabilityTestServer(
     let serverCert = InteroperabilityTestCredentials.server1Certificate
     let serverKey = InteroperabilityTestCredentials.server1Key
 
-    configuration.tls = .init(
-      certificateChain: [.certificate(serverCert)],
-      privateKey: .privateKey(serverKey),
-      trustRoots: .certificates([caCert])
-    )
+    builder = Server.secure(group: eventLoopGroup, certificateChain: [serverCert], privateKey: serverKey)
+      .withTLS(trustRoots: .certificates([caCert]))
+  } else {
+    builder = Server.insecure(group: eventLoopGroup)
   }
 
-  return Server.start(configuration: configuration)
+  return builder
+    .withMessageCompression(.enabled(.init(decompressionLimit: .absolute(1024 * 1024))))
+    .withServiceProviders(serviceProviders)
+    .bind(host: host, port: port)
 }

+ 4 - 6
Sources/GRPCPerformanceTests/Benchmarks/ServerProvidingBenchmark.swift

@@ -30,12 +30,10 @@ class ServerProvidingBenchmark: Benchmark {
 
   func setUp() throws {
     self.group = MultiThreadedEventLoopGroup(numberOfThreads: self.threadCount)
-    let configuration = Server.Configuration(
-      target: .hostAndPort("127.0.0.1", 0),
-      eventLoopGroup: self.group,
-      serviceProviders: self.providers
-    )
-    self.server = try Server.start(configuration: configuration).wait()
+    self.server = try Server.insecure(group: self.group)
+      .withServiceProviders(self.providers)
+      .bind(host: "127.0.0.1", port: 0)
+      .wait()
   }
 
   func tearDown() throws {

+ 18 - 59
Tests/GRPCTests/BasicEchoTestCase.swift

@@ -45,57 +45,6 @@ enum TransportSecurity {
   case mutualAuthentication
 }
 
-extension TransportSecurity {
-  var caCert: NIOSSLCertificate {
-    let cert = SampleCertificate.ca
-    cert.assertNotExpired()
-    return cert.certificate
-  }
-
-  var clientCert: NIOSSLCertificate {
-    let cert = SampleCertificate.client
-    cert.assertNotExpired()
-    return cert.certificate
-  }
-
-  var serverCert: NIOSSLCertificate {
-    let cert = SampleCertificate.server
-    cert.assertNotExpired()
-    return cert.certificate
-  }
-}
-
-extension TransportSecurity {
-  func makeServerTLSConfiguration() -> Server.Configuration.TLS? {
-    switch self {
-    case .none:
-      return nil
-
-    case .anonymousClient, .mutualAuthentication:
-      return .init(certificateChain: [.certificate(self.serverCert)],
-                   privateKey: .privateKey(SamplePrivateKey.server),
-                   trustRoots: .certificates ([self.caCert]))
-    }
-  }
-
-  func makeClientTLSConfiguration() -> ClientConnection.Configuration.TLS? {
-    switch self {
-    case .none:
-      return nil
-
-    case .anonymousClient:
-      return .init(trustRoots: .certificates([self.caCert]))
-
-    case .mutualAuthentication:
-      return .init(
-        certificateChain: [.certificate(self.clientCert)],
-        privateKey: .privateKey(SamplePrivateKey.client),
-        trustRoots: .certificates([self.caCert])
-      )
-    }
-  }
-}
-
 class EchoTestCaseBase: GRPCTestCase {
   var defaultTestTimeout: TimeInterval = 1.0
 
@@ -118,6 +67,7 @@ class EchoTestCaseBase: GRPCTestCase {
     switch self.transportSecurity {
     case .none:
       return ClientConnection.insecure(group: self.clientEventLoopGroup)
+
     case .anonymousClient:
       return ClientConnection.secure(group: self.clientEventLoopGroup)
         .withTLS(trustRoots: .certificates([SampleCertificate.ca.certificate]))
@@ -130,17 +80,26 @@ class EchoTestCaseBase: GRPCTestCase {
     }
   }
 
-  func makeServerConfiguration() throws -> Server.Configuration {
-    return .init(
-      target: .hostAndPort("localhost", 0),
-      eventLoopGroup: self.serverEventLoopGroup,
-      serviceProviders: [makeEchoProvider()],
-      errorDelegate: self.makeErrorDelegate(),
-      tls: self.transportSecurity.makeServerTLSConfiguration())
+  func serverBuilder() -> Server.Builder {
+    switch self.transportSecurity {
+    case .none:
+      return Server.insecure(group: self.serverEventLoopGroup)
+
+    case .anonymousClient, .mutualAuthentication:
+      return Server.secure(
+        group: self.serverEventLoopGroup,
+        certificateChain: [SampleCertificate.server.certificate],
+        privateKey: SamplePrivateKey.server
+      ).withTLS(trustRoots: .certificates([SampleCertificate.ca.certificate]))
+    }
   }
 
   func makeServer() throws -> Server {
-    return try Server.start(configuration: self.makeServerConfiguration()).wait()
+    return try self.serverBuilder()
+      .withErrorDelegate(makeErrorDelegate())
+      .withServiceProviders([makeEchoProvider()])
+      .bind(host: "localhost", port: 0)
+      .wait()
   }
 
   func makeClientConnection(port: Int) throws -> ClientConnection {

+ 3 - 6
Tests/GRPCTests/ClientConnectionBackoffTests.swift

@@ -116,12 +116,9 @@ class ClientConnectionBackoffTests: GRPCTestCase {
   }
 
   func makeServer() -> EventLoopFuture<Server> {
-    let configuration = Server.Configuration(
-      target: .hostAndPort("localhost", self.port),
-      eventLoopGroup: self.serverGroup,
-      serviceProviders: [EchoProvider()])
-
-    return Server.start(configuration: configuration)
+    return Server.insecure(group: self.serverGroup)
+      .withServiceProviders([EchoProvider()])
+      .bind(host: "localhost", port: self.port)
   }
 
   func connectionBuilder() -> ClientConnection.Builder {

+ 7 - 8
Tests/GRPCTests/ClientTLSFailureTests.swift

@@ -72,14 +72,13 @@ class ClientTLSFailureTests: GRPCTestCase {
   override func setUp() {
     self.serverEventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
 
-    let configuration = Server.Configuration(
-      target: .hostAndPort("localhost", 0),
-      eventLoopGroup: self.serverEventLoopGroup,
-      serviceProviders: [EchoProvider()],
-      errorDelegate: nil,
-      tls: self.defaultServerTLSConfiguration)
-
-    self.server = try! Server.start(configuration: configuration).wait()
+    self.server = try! Server.secure(
+      group: self.serverEventLoopGroup,
+      certificateChain: [SampleCertificate.server.certificate],
+      privateKey: SamplePrivateKey.server
+    ).withServiceProviders([EchoProvider()])
+      .bind(host: "localhost", port: 0)
+      .wait()
 
     self.port = self.server.channel.localAddress?.port
 

+ 16 - 22
Tests/GRPCTests/ClientTLSTests.swift

@@ -39,16 +39,6 @@ class ClientTLSHostnameOverrideTests: GRPCTestCase {
     XCTAssertNoThrow(try self.eventLoopGroup.syncShutdownGracefully())
   }
 
-  func makeEchoServer(tls: Server.Configuration.TLS) throws -> Server {
-    let configuration: Server.Configuration = .init(
-      target: .hostAndPort("localhost", 0),
-      eventLoopGroup: self.eventLoopGroup,
-      serviceProviders: [EchoProvider()],
-      tls: tls
-    )
-
-    return try Server.start(configuration: configuration).wait()
-  }
 
   func doTestUnary() throws {
     let client = Echo_EchoClient(channel: self.connection)
@@ -63,13 +53,15 @@ class ClientTLSHostnameOverrideTests: GRPCTestCase {
 
   func testTLSWithHostnameOverride() throws {
     // Run a server presenting a certificate for example.com on localhost.
-    let serverTLS: Server.Configuration.TLS = .init(
-      certificateChain: [.certificate(SampleCertificate.exampleServer.certificate)],
-      privateKey: .privateKey(SamplePrivateKey.exampleServer),
-      trustRoots: .certificates([SampleCertificate.ca.certificate])
-    )
+    let cert = SampleCertificate.exampleServer.certificate
+    let key = SamplePrivateKey.exampleServer
+
+    self.server = try Server.secure(group: self.eventLoopGroup, certificateChain: [cert], privateKey: key)
+      .withTLS(trustRoots: .certificates([SampleCertificate.ca.certificate]))
+      .withServiceProviders([EchoProvider()])
+      .bind(host: "localhost", port: 0)
+      .wait()
 
-    self.server = try makeEchoServer(tls: serverTLS)
     guard let port = self.server.channel.localAddress?.port else {
       XCTFail("could not get server port")
       return
@@ -85,13 +77,15 @@ class ClientTLSHostnameOverrideTests: GRPCTestCase {
 
   func testTLSWithoutHostnameOverride() throws {
     // Run a server presenting a certificate for localhost on localhost.
-    let serverTLS: Server.Configuration.TLS = .init(
-      certificateChain: [.certificate(SampleCertificate.server.certificate)],
-      privateKey: .privateKey(SamplePrivateKey.server),
-      trustRoots: .certificates([SampleCertificate.ca.certificate])
-    )
+    let cert = SampleCertificate.server.certificate
+    let key = SamplePrivateKey.server
+
+    self.server = try Server.secure(group: self.eventLoopGroup, certificateChain: [cert], privateKey: key)
+      .withTLS(trustRoots: .certificates([SampleCertificate.ca.certificate]))
+      .withServiceProviders([EchoProvider()])
+      .bind(host: "localhost", port: 0)
+      .wait()
 
-    self.server = try makeEchoServer(tls: serverTLS)
     guard let port = self.server.channel.localAddress?.port else {
       XCTFail("could not get server port")
       return

+ 5 - 8
Tests/GRPCTests/CompressionTests.swift

@@ -39,14 +39,11 @@ class MessageCompressionTests: GRPCTestCase {
   }
 
   func setupServer(encoding: ServerMessageEncoding) throws {
-    let configuration = Server.Configuration(
-      target: .hostAndPort("localhost", 0),
-      eventLoopGroup: self.group,
-      serviceProviders: [EchoProvider()],
-      messageEncoding: encoding
-    )
-
-    self.server = try Server.start(configuration: configuration).wait()
+    self.server = try Server.insecure(group: self.group)
+      .withServiceProviders([EchoProvider()])
+      .withMessageCompression(encoding)
+      .bind(host: "localhost", port: 0)
+      .wait()
   }
 
   func setupClient(encoding: ClientMessageEncoding) {

+ 4 - 7
Tests/GRPCTests/GRPCCustomPayloadTests.swift

@@ -28,13 +28,10 @@ class GRPCCustomPayloadTests: GRPCTestCase {
     super.setUp()
     self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
 
-    let serverConfig: Server.Configuration = .init(
-      target: .hostAndPort("localhost", 0),
-      eventLoopGroup: self.group,
-      serviceProviders: [CustomPayloadProvider()]
-    )
-
-    self.server = try! Server.start(configuration: serverConfig).wait()
+    self.server = try! Server.insecure(group: self.group)
+      .withServiceProviders([CustomPayloadProvider()])
+      .bind(host: "localhost", port: 0)
+      .wait()
 
     let channel = ClientConnection.insecure(group: self.group)
       .connect(host: "localhost", port: server.channel.localAddress!.port!)

+ 4 - 7
Tests/GRPCTests/HeaderNormalizationTests.swift

@@ -93,13 +93,10 @@ class HeaderNormalizationTests: GRPCTestCase {
 
     self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
 
-    let serverConfig = Server.Configuration(
-      target: .hostAndPort("localhost", 0),
-      eventLoopGroup: self.group,
-      serviceProviders: [EchoMetadataValidator()]
-    )
-
-    self.server = try! Server.start(configuration: serverConfig).wait()
+    self.server = try! Server.insecure(group: self.group)
+      .withServiceProviders([EchoMetadataValidator()])
+      .bind(host: "localhost", port: 0)
+      .wait()
 
     self.channel = ClientConnection.insecure(group: self.group)
       .connect(host: "localhost", port: self.server.channel.localAddress!.port!)

+ 11 - 14
docs/basic-tutorial.md

@@ -426,15 +426,11 @@ let features = try loadFeatures()
 // Create a provider using the features we read.
 let provider = RouteGuideProvider(features: features)
 
-// Tie these together in some configuration:
-let configuration = Server.Configuration(
-  target: .hostAndPort("localhost", 0),
-  eventLoopGroup: group,
-  serviceProviders: [provider]
-)
-
-// Start the server and print its port once it has started.
-let server = Server.start(configuration: configuration)
+// Start the server and print its address once it has started.
+let server = Server.insecure(group: group)
+  .withServiceProviders([provider])
+  .bind(host: "localhost", port: 0)
+
 server.map {
   $0.channel.localAddress
 }.whenSuccess { address in
@@ -446,15 +442,16 @@ _ = try server.flatMap {
   $0.onClose
 }.wait()
 ```
-As you can see, we build and start our server using some `Configuration`.
+As you can see, we configure and start our server using a builder.
 
 To do this, we:
 
-1. Specify the address and port we want to use to listen for client requests
-   using the `target` argument.
+1. Create an insecure server builder; it's insecure because it does not use
+   TLS.
 1. Create an instance of our service implementation class `RouteGuideProvider`
-   and pass it to the configuration's `serviceProviders` argument.
-1. Pass the configuration the `Server` class's static `start` method.
+   and configure the builder to use it with `withServiceProviders(_:)`,
+1. Call `bind(host:port:)` on the builder with the address and port we
+   want to use to listen for client requests, this starts the server.
 
 Once the server has started succesfully we print out the port the server is
 listening on. We then `wait()` on the server's `onClose` future to stop the