浏览代码

Add Server.Configuration (#490)

Motivation:

We added Configuration to the ClientConnection to make it simpler to
create new connections. We should have the equivalent for Server.

Modification:

Added a Configuration for Server similar to
ClientConnection.Configuration. This also made it possible to add
support for NIOTS for the server.

Result:

Configuring a server is easier and they can now use NIOTS.
George Barnett 6 年之前
父节点
当前提交
515e53347b

+ 16 - 13
Sources/Examples/Echo/main.swift

@@ -36,11 +36,8 @@ func makeClientSSLContext() throws -> NIOSSLContext {
   return try NIOSSLContext(configuration: makeClientTLSConfiguration())
 }
 
-func makeServerTLS(enabled: Bool) throws -> Server.TLSMode {
-  guard enabled else {
-    return .none
-  }
-  return .custom(try NIOSSLContext(configuration: makeServerTLSConfiguration()))
+func makeServerSSLContext() throws -> NIOSSLContext {
+  return try NIOSSLContext(configuration: makeServerTLSConfiguration())
 }
 
 func makeClientTLSConfiguration() -> TLSConfiguration {
@@ -100,20 +97,26 @@ Group {
              addressOption("localhost"),
              portOption,
              description: "Run an echo server.") { ssl, address, port in
-    let sem = DispatchSemaphore(value: 0)
     let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
 
-    print(ssl ? "starting secure server" : "starting insecure server")
-    _ = try! Server.start(hostname: address,
-                              port: port,
-                              eventLoopGroup: eventLoopGroup,
-                              serviceProviders: [EchoProvider()],
-                              tls: makeServerTLS(enabled: ssl))
+    var configuration = Server.Configuration(
+      target: .hostAndPort(address, port),
+      eventLoopGroup: eventLoopGroup,
+      serviceProviders: [EchoProvider()])
+
+    if ssl {
+      print("starting secure server")
+      configuration.tlsConfiguration = .init(sslContext: try makeServerSSLContext())
+    } else {
+      print("starting insecure server")
+    }
+
+    let server = try! Server.start(configuration: configuration)
       .wait()
 
     // This blocks to keep the main thread from finishing while the server runs,
     // but the server never exits. Kill the process to stop it.
-    _ = sem.wait()
+    try server.onClose.wait()
   }
 
   $0.command(

+ 3 - 1
Sources/GRPC/LoggingServerErrorDelegate.swift

@@ -16,7 +16,9 @@
 import Foundation
 
 public class LoggingServerErrorDelegate: ServerErrorDelegate {
-  public init() {}
+  public static let shared = LoggingServerErrorDelegate()
+
+  private init() {}
 
   public func observeLibraryError(_ error: Error) {
     print("[grpc-server][\(Date())] library: \(error)")

+ 132 - 52
Sources/GRPC/Server.swift

@@ -1,3 +1,18 @@
+/*
+ * Copyright 2019, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 import Foundation
 import NIO
 import NIOHTTP1
@@ -76,63 +91,54 @@ import NIOSSL
 ///                             │                       ▼
 ///
 public final class Server {
-  /// Starts up a server that serves the given providers.
-  ///
-  /// - Returns: A future that is completed when the server has successfully started up.
-  public static func start(
-    hostname: String,
-    port: Int,
-    eventLoopGroup: EventLoopGroup,
-    serviceProviders: [CallHandlerProvider],
-    errorDelegate: ServerErrorDelegate? = LoggingServerErrorDelegate(),
-    tls tlsMode: TLSMode = .none
-  ) throws -> EventLoopFuture<Server> {
-    let servicesByName = Dictionary(uniqueKeysWithValues: serviceProviders.map { ($0.serviceName, $0) })
-    let bootstrap = ServerBootstrap(group: eventLoopGroup)
+  /// Makes and configures a `ServerBootstrap` using the provided configuration.
+  public class func makeBootstrap(configuration: Configuration) -> ServerBootstrapProtocol {
+    let bootstrap = GRPCNIO.makeServerBootstrap(group: configuration.eventLoopGroup)
+
+    // Backlog is only available on `ServerBootstrap`.
+    if bootstrap is ServerBootstrap {
       // Specify a backlog to avoid overloading the server.
-      .serverChannelOption(ChannelOptions.backlog, value: 256)
+      _ = bootstrap.serverChannelOption(ChannelOptions.backlog, value: 256)
+    }
+
+    return bootstrap
       // Enable `SO_REUSEADDR` to avoid "address already in use" error.
       .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
       // Set the handlers that are applied to the accepted Channels
       .childChannelInitializer { channel in
-        let protocolSwitcherHandler = HTTPProtocolSwitcher(errorDelegate: errorDelegate) { channel -> EventLoopFuture<Void> in
-          channel.pipeline.addHandlers(HTTP1ToRawGRPCServerCodec(),
-                                       GRPCChannelHandler(servicesByName: servicesByName, errorDelegate: errorDelegate))
+        let protocolSwitcher = HTTPProtocolSwitcher(errorDelegate: configuration.errorDelegate) { channel -> EventLoopFuture<Void> in
+          let handlers: [ChannelHandler] = [
+            HTTP1ToRawGRPCServerCodec(),
+            GRPCChannelHandler(
+              servicesByName: configuration.serviceProvidersByName,
+              errorDelegate: configuration.errorDelegate
+            )
+          ]
+          return channel.pipeline.addHandlers(handlers)
         }
 
-        return configureTLS(mode: tlsMode, channel: channel).flatMap {
-          channel.pipeline.addHandler(protocolSwitcherHandler)
+        if let tlsConfiguration = configuration.tlsConfiguration {
+          return channel.configureTLS(configuration: tlsConfiguration).flatMap {
+            channel.pipeline.addHandler(protocolSwitcher)
+          }
+        } else {
+          return channel.pipeline.addHandler(protocolSwitcher)
         }
       }
 
       // Enable TCP_NODELAY and SO_REUSEADDR for the accepted Channels
       .childChannelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1)
       .childChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
-
-    return bootstrap.bind(host: hostname, port: port)
-      .map { Server(channel: $0, errorDelegate: errorDelegate) }
   }
 
-  /// Configure an SSL handler on the channel, if one is provided.
-  ///
-  /// - Parameters:
-  ///   - mode: TLS mode to run the server in.
-  ///   - channel: The channel on which to add the SSL handler.
-  /// - Returns: A future which will be succeeded when the pipeline has been configured.
-  private static func configureTLS(mode: TLSMode, channel: Channel) -> EventLoopFuture<Void> {
-    guard let sslContext = mode.sslContext else {
-      return channel.eventLoop.makeSucceededFuture(())
-    }
-
-    let handlerAddedPromise: EventLoopPromise<Void> = channel.eventLoop.makePromise()
-
-    do {
-      channel.pipeline.addHandler(try NIOSSLServerHandler(context: sslContext)).cascade(to: handlerAddedPromise)
-    } catch {
-      handlerAddedPromise.fail(error)
-    }
-
-    return handlerAddedPromise.futureResult
+  /// Starts a server with the given configuration. See `Server.Configuration` for the options
+  /// available to configure the server.
+  public static func start(configuration: Configuration) -> EventLoopFuture<Server> {
+    return makeBootstrap(configuration: configuration)
+      .bind(to: configuration.target)
+      .map { channel in
+        Server(channel: channel, errorDelegate: configuration.errorDelegate)
+      }
   }
 
   public let channel: Channel
@@ -161,19 +167,93 @@ public final class Server {
   }
 }
 
+public typealias BindTarget = ConnectionTarget
+
 extension Server {
-  public enum TLSMode {
-    case none
-    case custom(NIOSSLContext)
+  /// The configuration for a server.
+  public struct Configuration {
+    /// The target to bind to.
+    public var target: BindTarget
 
-    var sslContext: NIOSSLContext? {
-      switch self {
-      case .none:
-        return nil
+    /// The event loop group to run the connection on.
+    public var eventLoopGroup: EventLoopGroup
 
-      case .custom(let context):
-        return context
-      }
+    /// Providers the server should use to handle gRPC requests.
+    public var serviceProviders: [CallHandlerProvider]
+
+    /// An error delegate which is called when errors are caught. Provided delegates **must not
+    /// maintain a strong reference to this `Server`**. Doing so will cause a retain cycle.
+    public var errorDelegate: ServerErrorDelegate?
+
+    /// TLS configuration for this connection. `nil` if TLS is not desired.
+    public var tlsConfiguration: TLSConfiguration?
+
+    /// Create a `Configuration` with some pre-defined defaults.
+    ///
+    /// - Parameter target: The target to bind to.
+    /// - Parameter eventLoopGroup: The event loop group to run the server on.
+    /// - Parameter serviceProviders: An array of `CallHandlerProvider`s which the server should use
+    ///     to handle requests.
+    /// - Parameter errorDelegate: The error delegate, defaulting to a logging delegate.
+    /// - Parameter tlsConfiguration: TLS configuration, defaulting to `nil`.
+    public init(
+      target: BindTarget,
+      eventLoopGroup: EventLoopGroup,
+      serviceProviders: [CallHandlerProvider],
+      errorDelegate: ServerErrorDelegate? = LoggingServerErrorDelegate.shared,
+      tlsConfiguration: TLSConfiguration? = nil
+    ) {
+      self.target = target
+      self.eventLoopGroup = eventLoopGroup
+      self.serviceProviders = serviceProviders
+      self.errorDelegate = errorDelegate
+      self.tlsConfiguration = tlsConfiguration
+    }
+  }
+
+  /// The TLS configuration for a connection.
+  public struct TLSConfiguration {
+    /// The SSL context to use.
+    public var sslContext: NIOSSLContext
+
+    public init(sslContext: NIOSSLContext) {
+      self.sslContext = sslContext
+    }
+  }
+}
+
+fileprivate extension Server.Configuration {
+  var serviceProvidersByName: [String: CallHandlerProvider] {
+    return Dictionary(uniqueKeysWithValues: self.serviceProviders.map { ($0.serviceName, $0) })
+  }
+}
+
+fileprivate extension Channel {
+  /// Configure an SSL handler on the channel.
+  ///
+  /// - Parameters:
+  ///   - configuration: The configuration to use when creating the handler.
+  /// - Returns: A future which will be succeeded when the pipeline has been configured.
+  func configureTLS(configuration: Server.TLSConfiguration) -> EventLoopFuture<Void> {
+    do {
+      return self.pipeline.addHandler(try NIOSSLServerHandler(context: configuration.sslContext))
+    } catch {
+      return self.pipeline.eventLoop.makeFailedFuture(error)
+    }
+  }
+}
+
+fileprivate extension ServerBootstrapProtocol {
+  func bind(to target: BindTarget) -> EventLoopFuture<Channel> {
+    switch target {
+    case .hostAndPort(let host, let port):
+      return self.bind(host: host, port: port)
+
+    case .unixDomainSocket(let path):
+      return self.bind(unixDomainSocketPath: path)
+
+    case .socketAddress(let address):
+      return self.bind(to: address)
     }
   }
 }

+ 7 - 11
Sources/GRPCInteroperabilityTests/InteroperabilityTestServer.swift

@@ -37,7 +37,10 @@ public func makeInteroperabilityTestServer(
   serviceProviders: [CallHandlerProvider] = [TestServiceProvider()],
   useTLS: Bool
 ) throws -> EventLoopFuture<Server> {
-  let tlsMode: Server.TLSMode
+  var configuration = Server.Configuration(
+    target: .hostAndPort(host, port),
+    eventLoopGroup: eventLoopGroup,
+    serviceProviders: serviceProviders)
 
   if useTLS {
     print("Using the gRPC interop testing CA for TLS; clients should expect the host to be '*.test.google.fr'")
@@ -53,16 +56,9 @@ public func makeInteroperabilityTestServer(
       applicationProtocols: ["h2"]
     )
 
-    tlsMode = .custom(try NIOSSLContext(configuration: tlsConfiguration))
-  } else {
-    tlsMode = .none
+    let sslContext = try NIOSSLContext(configuration: tlsConfiguration)
+    configuration.tlsConfiguration = .init(sslContext: sslContext)
   }
 
-  return try Server.start(
-    hostname: host,
-    port: port,
-    eventLoopGroup: eventLoopGroup,
-    serviceProviders: serviceProviders,
-    tls: tlsMode
-  )
+  return Server.start(configuration: configuration)
 }

+ 7 - 6
Sources/GRPCPerformanceTests/main.swift

@@ -363,15 +363,16 @@ Group { group in
       privateKeyPath: privateKeyPath,
       server: true)
 
+    let configuration = Server.Configuration(
+      target: .hostAndPort(host, port),
+      eventLoopGroup: group,
+      serviceProviders: [EchoProvider()],
+      tlsConfiguration: sslContext.map { .init(sslContext: $0) })
+
     let server: Server
 
     do {
-      server = try Server.start(
-        hostname: host,
-        port: port,
-        eventLoopGroup: group,
-        serviceProviders: [EchoProvider()],
-        tls: sslContext.map { .custom($0) } ?? .none).wait()
+      server = try Server.start(configuration: configuration).wait()
     } catch {
       print("unable to start server: \(error)")
       exit(1)

+ 35 - 23
Tests/GRPCTests/BasicEchoTestCase.swift

@@ -17,7 +17,7 @@ import Dispatch
 import Foundation
 import NIO
 import NIOSSL
-@testable import GRPC
+import GRPC
 import GRPCSampleData
 import XCTest
 
@@ -60,8 +60,13 @@ extension TransportSecurity {
 }
 
 extension TransportSecurity {
-  func makeServerTLS() throws -> Server.TLSMode {
-    return try makeServerTLSConfiguration().map { .custom(try NIOSSLContext(configuration: $0)) } ?? .none
+  func makeServerConfiguration() throws -> Server.TLSConfiguration? {
+    guard let config = try self.makeServerTLSConfiguration() else {
+      return nil
+    }
+
+    let context = try NIOSSLContext(configuration: config)
+    return .init(sslContext: context)
   }
 
   func makeServerTLSConfiguration() throws -> TLSConfiguration? {
@@ -77,7 +82,7 @@ extension TransportSecurity {
     }
   }
 
-  func makeConfiguration() throws -> ClientConnection.TLSConfiguration? {
+  func makeClientConfiguration() throws -> ClientConnection.TLSConfiguration? {
     guard let config = try self.makeClientTLSConfiguration() else {
       return nil
     }
@@ -116,6 +121,7 @@ class EchoTestCaseBase: XCTestCase {
 
   var server: Server!
   var client: Echo_EchoServiceClient!
+  var port: Int!
 
   // Prefer POSIX: subclasses can override this and add availability checks to ensure NIOTS
   // variants run where possible.
@@ -123,52 +129,57 @@ class EchoTestCaseBase: XCTestCase {
     return .userDefined(.posix)
   }
 
-  func makeClientConfiguration() throws -> ClientConnection.Configuration {
+  func makeClientConfiguration(port: Int) throws -> ClientConnection.Configuration {
     return .init(
-      target: .hostAndPort("localhost", 5050),
+      target: .hostAndPort("localhost", port),
       eventLoopGroup: self.clientEventLoopGroup,
-      tlsConfiguration: try self.transportSecurity.makeConfiguration())
+      tlsConfiguration: try self.transportSecurity.makeClientConfiguration())
   }
 
-  func makeServer() throws -> Server {
-    return try Server.start(
-      hostname: "localhost",
-      port: 5050,
+  func makeServerConfiguration() throws -> Server.Configuration {
+    return .init(
+      target: .hostAndPort("localhost", 0),
       eventLoopGroup: self.serverEventLoopGroup,
       serviceProviders: [makeEchoProvider()],
-      errorDelegate: makeErrorDelegate(),
-      tls: try self.transportSecurity.makeServerTLS()
-    ).wait()
+      errorDelegate: self.makeErrorDelegate(),
+      tlsConfiguration: try self.transportSecurity.makeServerConfiguration())
+  }
+
+  func makeServer() throws -> Server {
+    return try Server.start(configuration: self.makeServerConfiguration()).wait()
   }
 
-  func makeClientConnection() throws -> ClientConnection {
-    return try ClientConnection.start(self.makeClientConfiguration()).wait()
+  func makeClientConnection(port: Int) throws -> ClientConnection {
+    return try ClientConnection.start(self.makeClientConfiguration(port: port)).wait()
   }
 
   func makeEchoProvider() -> Echo_EchoProvider { return EchoProvider() }
 
   func makeErrorDelegate() -> ServerErrorDelegate? { return nil }
 
-  func makeEchoClient() throws -> Echo_EchoServiceClient {
-    return Echo_EchoServiceClient(connection: try self.makeClientConnection())
+  func makeEchoClient(port: Int) throws -> Echo_EchoServiceClient {
+    return Echo_EchoServiceClient(connection: try self.makeClientConnection(port: port))
   }
 
   override func setUp() {
     super.setUp()
-    self.serverEventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
+    self.serverEventLoopGroup = GRPCNIO.makeEventLoopGroup(
+      loopCount: 1,
+      networkPreference: self.networkPreference)
     self.server = try! self.makeServer()
 
+    self.port = self.server.channel.localAddress!.port!
+
     self.clientEventLoopGroup = GRPCNIO.makeEventLoopGroup(
       loopCount: 1,
       networkPreference: self.networkPreference)
-    self.client = try! self.makeEchoClient()
+    self.client = try! self.makeEchoClient(port: self.port)
   }
 
   override func tearDown() {
     // Some tests close the channel, so would throw here if called twice.
-    if self.client.connection.channel.isActive {
-      XCTAssertNoThrow(try self.client.connection.close().wait())
-    }
+    try? self.client.connection.close().wait()
+
     XCTAssertNoThrow(try self.clientEventLoopGroup.syncShutdownGracefully())
     self.client = nil
     self.clientEventLoopGroup = nil
@@ -177,6 +188,7 @@ class EchoTestCaseBase: XCTestCase {
     XCTAssertNoThrow(try self.serverEventLoopGroup.syncShutdownGracefully())
     self.server = nil
     self.serverEventLoopGroup = nil
+    self.port = nil
 
     super.tearDown()
   }

+ 7 - 6
Tests/GRPCTests/ClientConnectionBackoffTests.swift

@@ -32,19 +32,20 @@ class ClientConnectionBackoffTests: XCTestCase {
     }
 
     // We don't always expect a client (since we deliberately timeout the connection in some cases).
-    if let client = try? self.client.wait() {
+    if let client = try? self.client.wait(), client.channel.isActive {
       XCTAssertNoThrow(try client.channel.close().wait())
     }
 
     XCTAssertNoThrow(try self.group.syncShutdownGracefully())
   }
 
-  func makeServer() throws -> EventLoopFuture<Server> {
-    return try Server.start(
-      hostname: "localhost",
-      port: self.port,
+  func makeServer() -> EventLoopFuture<Server> {
+    let configuration = Server.Configuration(
+      target: .hostAndPort("localhost", self.port),
       eventLoopGroup: self.group,
       serviceProviders: [])
+
+    return Server.start(configuration: configuration)
   }
 
   func makeClientConfiguration() -> ClientConnection.Configuration {
@@ -81,7 +82,7 @@ class ClientConnectionBackoffTests: XCTestCase {
     // Sleep for a little bit to make sure we hit the backoff.
     Thread.sleep(forTimeInterval: 0.2)
 
-    self.server = try self.makeServer()
+    self.server = self.makeServer()
     self.server.assertSuccess(fulfill: serverStarted)
 
     self.wait(for: [serverStarted, clientConnected], timeout: 2.0, enforceOrder: true)

+ 7 - 5
Tests/GRPCTests/ClientTLSFailureTests.swift

@@ -60,14 +60,16 @@ class ClientTLSFailureTests: XCTestCase {
 
   override func setUp() {
     self.serverEventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
-    self.server = try! Server.start(
-      hostname: "localhost",
-      port: 0,
+    let sslContext = try! NIOSSLContext(configuration: self.defaultServerTLSConfiguration)
+
+    let configuration = Server.Configuration(
+      target: .hostAndPort("localhost", 0),
       eventLoopGroup: self.serverEventLoopGroup,
       serviceProviders: [EchoProvider()],
       errorDelegate: nil,
-      tls: .custom(try NIOSSLContext(configuration: defaultServerTLSConfiguration))
-    ).wait()
+      tlsConfiguration: .init(sslContext: sslContext))
+
+    self.server = try! Server.start(configuration: configuration).wait()
 
     self.port = self.server.channel.localAddress?.port
 

+ 1 - 1
Tests/GRPCTests/ServerWebTests.swift

@@ -45,7 +45,7 @@ class ServerWebTests: EchoTestCaseBase {
   }
 
   private func sendOverHTTP1(rpcMethod: String, message: String?, handler: @escaping (Data?, Error?) -> Void) {
-    let serverURL = URL(string: "http://localhost:5050/echo.Echo/\(rpcMethod)")!
+    let serverURL = URL(string: "http://localhost:\(self.port!)/echo.Echo/\(rpcMethod)")!
     var request = URLRequest(url: serverURL)
     request.httpMethod = "POST"
     request.setValue("application/grpc-web-text", forHTTPHeaderField: "content-type")