Browse Source

allow client/server to be initialised with a connected socket (#1385)

Motivation: https://github.com/grpc/grpc-swift/issues/1353

Modifications: 

- A new entry on ConnectionTarget - New API on the ClientConnection builder
- Some validation that we're not trying to use this API with a
  NIOTSEventLoopGroup or NIOTSEventLoop (as far as I can tell, using a file
  descriptor directly is not possible with Network.framework)
- Tests

Result:

This allows greater flexibility in spawning the client/server; in particular, 
it allows unix domain sockets for sandboxed apps on macOS.
Vimal 3 years ago
parent
commit
e1a0025a60

+ 9 - 1
Sources/GRPC/ClientConnection.swift

@@ -266,6 +266,7 @@ public struct ConnectionTarget {
     case hostAndPort(String, Int)
     case unixDomainSocket(String)
     case socketAddress(SocketAddress)
+    case connectedSocket(NIOBSDSocket.Handle)
   }
 
   internal var wrapped: Wrapped
@@ -293,6 +294,11 @@ public struct ConnectionTarget {
     return ConnectionTarget(.socketAddress(address))
   }
 
+  /// A connected NIO socket.
+  public static func connectedSocket(_ socket: NIOBSDSocket.Handle) -> ConnectionTarget {
+    return ConnectionTarget(.connectedSocket(socket))
+  }
+
   @usableFromInline
   var host: String {
     switch self.wrapped {
@@ -302,7 +308,7 @@ public struct ConnectionTarget {
       return address.host
     case let .socketAddress(.v6(address)):
       return address.host
-    case .unixDomainSocket, .socketAddress(.unixDomainSocket):
+    case .unixDomainSocket, .socketAddress(.unixDomainSocket), .connectedSocket:
       return "localhost"
     }
   }
@@ -540,6 +546,8 @@ extension ClientBootstrapProtocol {
 
     case let .socketAddress(address):
       return self.connect(to: address)
+    case let .connectedSocket(socket):
+      return self.withConnectedSocket(socket)
     }
   }
 }

+ 12 - 0
Sources/GRPC/GRPCChannel/GRPCChannelBuilder.swift

@@ -85,6 +85,18 @@ extension ClientConnection {
       self.configuration.tlsConfiguration = self.maybeTLS
       return ClientConnection(configuration: self.configuration)
     }
+
+    public func withConnectedSocket(_ socket: NIOBSDSocket.Handle) -> ClientConnection {
+      precondition(
+        !PlatformSupport.isTransportServicesEventLoopGroup(self.configuration.eventLoopGroup),
+        "'\(#function)' requires 'group' to not be a 'NIOTransportServices.NIOTSEventLoopGroup' or 'NIOTransportServices.QoSEventLoop' (but was '\(type(of: self.configuration.eventLoopGroup))'"
+      )
+      self.configuration.target = .connectedSocket(socket)
+      self.configuration.connectionBackoff =
+        self.connectionBackoffIsEnabled ? self.connectionBackoff : nil
+      self.configuration.tlsConfiguration = self.maybeTLS
+      return ClientConnection(configuration: self.configuration)
+    }
   }
 }
 

+ 24 - 2
Sources/GRPC/PlatformSupport.swift

@@ -117,17 +117,28 @@ public protocol ClientBootstrapProtocol {
   func connect(to: SocketAddress) -> EventLoopFuture<Channel>
   func connect(host: String, port: Int) -> EventLoopFuture<Channel>
   func connect(unixDomainSocketPath: String) -> EventLoopFuture<Channel>
+  func withConnectedSocket(_ socket: NIOBSDSocket.Handle) -> EventLoopFuture<Channel>
 
   func connectTimeout(_ timeout: TimeAmount) -> Self
   func channelOption<T>(_ option: T, value: T.Value) -> Self where T: ChannelOption
   func channelInitializer(_ handler: @escaping (Channel) -> EventLoopFuture<Void>) -> Self
 }
 
+extension ClientBootstrapProtocol {
+  public func withConnectedSocket(_ socket: NIOBSDSocket.Handle) -> EventLoopFuture<Channel> {
+    preconditionFailure("withConnectedSocket(_:) is not implemented")
+  }
+}
+
 extension ClientBootstrap: ClientBootstrapProtocol {}
 
 #if canImport(Network)
 @available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *)
-extension NIOTSConnectionBootstrap: ClientBootstrapProtocol {}
+extension NIOTSConnectionBootstrap: ClientBootstrapProtocol {
+  public func withConnectedSocket(_ socket: NIOBSDSocket.Handle) -> EventLoopFuture<Channel> {
+    preconditionFailure("NIOTSConnectionBootstrap does not support withConnectedSocket(_:)")
+  }
+}
 #endif
 
 /// This protocol is intended as a layer of abstraction over `ServerBootstrap` and
@@ -136,6 +147,7 @@ public protocol ServerBootstrapProtocol {
   func bind(to: SocketAddress) -> EventLoopFuture<Channel>
   func bind(host: String, port: Int) -> EventLoopFuture<Channel>
   func bind(unixDomainSocketPath: String) -> EventLoopFuture<Channel>
+  func withBoundSocket(_ connectedSocket: NIOBSDSocket.Handle) -> EventLoopFuture<Channel>
 
   func serverChannelInitializer(_ initializer: @escaping (Channel) -> EventLoopFuture<Void>) -> Self
   func serverChannelOption<T>(_ option: T, value: T.Value) -> Self where T: ChannelOption
@@ -144,11 +156,21 @@ public protocol ServerBootstrapProtocol {
   func childChannelOption<T>(_ option: T, value: T.Value) -> Self where T: ChannelOption
 }
 
+extension ServerBootstrapProtocol {
+  public func withBoundSocket(_ connectedSocket: NIOBSDSocket.Handle) -> EventLoopFuture<Channel> {
+    preconditionFailure("withBoundSocket(_:) is not implemented")
+  }
+}
+
 extension ServerBootstrap: ServerBootstrapProtocol {}
 
 #if canImport(Network)
 @available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *)
-extension NIOTSListenerBootstrap: ServerBootstrapProtocol {}
+extension NIOTSListenerBootstrap: ServerBootstrapProtocol {
+  public func withBoundSocket(_ connectedSocket: NIOBSDSocket.Handle) -> EventLoopFuture<Channel> {
+    preconditionFailure("NIOTSListenerBootstrap does not support withConnectedSocket(_:)")
+  }
+}
 #endif
 
 // MARK: - Bootstrap / EventLoopGroup helpers

+ 3 - 0
Sources/GRPC/Server.swift

@@ -462,6 +462,9 @@ extension ServerBootstrapProtocol {
 
     case let .socketAddress(address):
       return self.bind(to: address)
+
+    case let .connectedSocket(socket):
+      return self.withBoundSocket(socket)
     }
   }
 }

+ 6 - 0
Sources/GRPC/ServerBuilder.swift

@@ -54,6 +54,12 @@ extension Server {
       self.configuration.tlsConfiguration = self.maybeTLS
       return Server.start(configuration: self.configuration)
     }
+
+    public func bind(unixDomainSocketPath path: String) -> EventLoopFuture<Server> {
+      self.configuration.target = .unixDomainSocket(path)
+      self.configuration.tlsConfiguration = self.maybeTLS
+      return Server.start(configuration: self.configuration)
+    }
   }
 }
 

+ 69 - 0
Tests/GRPCTests/WithConnectedSocketTests.swift

@@ -0,0 +1,69 @@
+/*
+ * Copyright 2022, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import EchoImplementation
+import EchoModel
+@testable import GRPC
+import NIOCore
+import NIOPosix
+import XCTest
+
+class WithConnectedSockettests: GRPCTestCase {
+  func testWithConnectedSocket() throws {
+    let group = NIOPosix.MultiThreadedEventLoopGroup(numberOfThreads: 1)
+    defer {
+      XCTAssertNoThrow(try group.syncShutdownGracefully())
+    }
+
+    let path = "/tmp/grpc-\(getpid()).sock"
+    // Setup a server.
+    let server = try Server.insecure(group: group)
+      .withServiceProviders([EchoProvider()])
+      .withLogger(self.serverLogger)
+      .bind(unixDomainSocketPath: path)
+      .wait()
+    defer {
+      XCTAssertNoThrow(try server.close().wait())
+    }
+
+    #if os(Linux)
+    let sockStream = CInt(SOCK_STREAM.rawValue)
+    #else
+    let sockStream = SOCK_STREAM
+    #endif
+    let clientSocket = socket(AF_UNIX, sockStream, 0)
+
+    XCTAssert(clientSocket != -1)
+    let addr = try SocketAddress(unixDomainSocketPath: path)
+    addr.withSockAddr { addr, size in
+      let ret = connect(clientSocket, addr, UInt32(size))
+      XCTAssert(ret != -1)
+    }
+    let flags = fcntl(clientSocket, F_GETFL, 0)
+    XCTAssert(flags != -1)
+    XCTAssert(fcntl(clientSocket, F_SETFL, flags | O_NONBLOCK) == 0)
+
+    let connection = ClientConnection.insecure(group: group)
+      .withBackgroundActivityLogger(self.clientLogger)
+      .withConnectedSocket(clientSocket)
+    defer {
+      XCTAssertNoThrow(try connection.close().wait())
+    }
+
+    let client = Echo_EchoClient(channel: connection)
+    let resp = try client.get(Echo_EchoRequest(text: "Hello")).response.wait()
+    XCTAssertEqual(resp.text, "Swift echo get: Hello")
+  }
+}