Browse Source

Expose listening address on NIOTS server transport (#1939)

Gustavo Cairo 1 year ago
parent
commit
4f2e2e0fd7

+ 1 - 1
Sources/GRPCHTTP2TransportNIOPosix/HTTP2ServerTransport+Posix.swift

@@ -34,7 +34,7 @@ extension HTTP2ServerTransport {
       case listening(EventLoopFuture<GRPCHTTP2Core.SocketAddress>)
       case closedOrInvalidAddress(RuntimeError)
 
-      public var listeningAddressFuture: EventLoopFuture<GRPCHTTP2Core.SocketAddress> {
+      var listeningAddressFuture: EventLoopFuture<GRPCHTTP2Core.SocketAddress> {
         get throws {
           switch self {
           case .idle(let eventLoopPromise):

+ 128 - 1
Sources/GRPCHTTP2TransportNIOTransportServices/HTTP2ServerTransport+TransportServices.swift

@@ -30,6 +30,111 @@ extension HTTP2ServerTransport {
     private let eventLoopGroup: NIOTSEventLoopGroup
     private let serverQuiescingHelper: ServerQuiescingHelper
 
+    private enum State {
+      case idle(EventLoopPromise<GRPCHTTP2Core.SocketAddress>)
+      case listening(EventLoopFuture<GRPCHTTP2Core.SocketAddress>)
+      case closedOrInvalidAddress(RuntimeError)
+
+      var listeningAddressFuture: EventLoopFuture<GRPCHTTP2Core.SocketAddress> {
+        get throws {
+          switch self {
+          case .idle(let eventLoopPromise):
+            return eventLoopPromise.futureResult
+          case .listening(let eventLoopFuture):
+            return eventLoopFuture
+          case .closedOrInvalidAddress(let runtimeError):
+            throw runtimeError
+          }
+        }
+      }
+
+      enum OnBound {
+        case succeedPromise(
+          _ promise: EventLoopPromise<GRPCHTTP2Core.SocketAddress>,
+          address: GRPCHTTP2Core.SocketAddress
+        )
+        case failPromise(
+          _ promise: EventLoopPromise<GRPCHTTP2Core.SocketAddress>,
+          error: RuntimeError
+        )
+      }
+
+      mutating func addressBound(_ address: NIOCore.SocketAddress?) -> OnBound {
+        switch self {
+        case .idle(let listeningAddressPromise):
+          if let address {
+            self = .listening(listeningAddressPromise.futureResult)
+            return .succeedPromise(
+              listeningAddressPromise,
+              address: GRPCHTTP2Core.SocketAddress(address)
+            )
+
+          } else {
+            assertionFailure("Unknown address type")
+            let invalidAddressError = RuntimeError(
+              code: .transportError,
+              message: "Unknown address type returned by transport."
+            )
+            self = .closedOrInvalidAddress(invalidAddressError)
+            return .failPromise(listeningAddressPromise, error: invalidAddressError)
+          }
+
+        case .listening, .closedOrInvalidAddress:
+          fatalError(
+            "Invalid state: addressBound should only be called once and when in idle state"
+          )
+        }
+      }
+
+      enum OnClose {
+        case failPromise(
+          EventLoopPromise<GRPCHTTP2Core.SocketAddress>,
+          error: RuntimeError
+        )
+        case doNothing
+      }
+
+      mutating func close() -> OnClose {
+        let serverStoppedError = RuntimeError(
+          code: .serverIsStopped,
+          message: """
+            There is no listening address bound for this server: there may have been \
+            an error which caused the transport to close, or it may have shut down.
+            """
+        )
+
+        switch self {
+        case .idle(let listeningAddressPromise):
+          self = .closedOrInvalidAddress(serverStoppedError)
+          return .failPromise(listeningAddressPromise, error: serverStoppedError)
+
+        case .listening:
+          self = .closedOrInvalidAddress(serverStoppedError)
+          return .doNothing
+
+        case .closedOrInvalidAddress:
+          return .doNothing
+        }
+      }
+    }
+
+    private let listeningAddressState: _LockedValueBox<State>
+
+    /// The listening address for this server transport.
+    ///
+    /// It is an `async` property because it will only return once the address has been successfully bound.
+    ///
+    /// - Throws: A runtime error will be thrown if the address could not be bound or is not bound any
+    /// longer, because the transport isn't listening anymore. It can also throw if the transport returned an
+    /// invalid address.
+    public var listeningAddress: GRPCHTTP2Core.SocketAddress {
+      get async throws {
+        try await self.listeningAddressState
+          .withLockedValue { try $0.listeningAddressFuture }
+          .get()
+      }
+    }
+
     /// Create a new `TransportServices` transport.
     ///
     /// - Parameters:
@@ -45,11 +150,23 @@ extension HTTP2ServerTransport {
       self.config = config
       self.eventLoopGroup = eventLoopGroup
       self.serverQuiescingHelper = ServerQuiescingHelper(group: self.eventLoopGroup)
+
+      let eventLoop = eventLoopGroup.any()
+      self.listeningAddressState = _LockedValueBox(.idle(eventLoop.makePromise()))
     }
 
     public func listen(
       _ streamHandler: @escaping (RPCStream<Inbound, Outbound>) async -> Void
     ) async throws {
+      defer {
+        switch self.listeningAddressState.withLockedValue({ $0.close() }) {
+        case .failPromise(let promise, let error):
+          promise.fail(error)
+        case .doNothing:
+          ()
+        }
+      }
+
       let serverChannel = try await NIOTSListenerBootstrap(group: self.eventLoopGroup)
         .serverChannelInitializer { channel in
           let quiescingHandler = self.serverQuiescingHelper.makeServerChannelHandler(
@@ -70,6 +187,16 @@ extension HTTP2ServerTransport {
           }
         }
 
+      let action = self.listeningAddressState.withLockedValue {
+        $0.addressBound(serverChannel.channel.localAddress)
+      }
+      switch action {
+      case .succeedPromise(let promise, let address):
+        promise.succeed(address)
+      case .failPromise(let promise, let error):
+        promise.fail(error)
+      }
+
       try await serverChannel.executeThenClose { inbound in
         try await withThrowingDiscardingTaskGroup { serverTaskGroup in
           for try await (connectionChannel, streamMultiplexer) in inbound {
@@ -203,7 +330,7 @@ extension NIOTSListenerBootstrap {
     to address: GRPCHTTP2Core.SocketAddress,
     childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output>
   ) async throws -> NIOAsyncChannel<Output, Never> {
-    if let virtualSocket = address.virtualSocket {
+    if address.virtualSocket != nil {
       throw RuntimeError(
         code: .transportError,
         message: """

+ 5 - 3
Tests/GRPCHTTP2TransportTests/HTTP2TransportNIOPosixTests.swift

@@ -41,7 +41,9 @@ final class HTTP2TransportNIOPosixTests: XCTestCase {
   }
 
   func testGetListeningAddress_IPv6() async throws {
-    let transport = GRPCHTTP2Core.HTTP2ServerTransport.Posix(address: .ipv6(host: "::1", port: 0))
+    let transport = GRPCHTTP2Core.HTTP2ServerTransport.Posix(
+      address: .ipv6(host: "::1", port: 0)
+    )
 
     try await withThrowingDiscardingTaskGroup { group in
       group.addTask {
@@ -59,7 +61,7 @@ final class HTTP2TransportNIOPosixTests: XCTestCase {
 
   func testGetListeningAddress_UnixDomainSocket() async throws {
     let transport = GRPCHTTP2Core.HTTP2ServerTransport.Posix(
-      address: .unixDomainSocket(path: "/tmp/test")
+      address: .unixDomainSocket(path: "/tmp/posix-uds-test")
     )
 
     try await withThrowingDiscardingTaskGroup { group in
@@ -71,7 +73,7 @@ final class HTTP2TransportNIOPosixTests: XCTestCase {
         let address = try await transport.listeningAddress
         XCTAssertEqual(
           address.unixDomainSocket,
-          GRPCHTTP2Core.SocketAddress.UnixDomainSocket(path: "/tmp/test")
+          GRPCHTTP2Core.SocketAddress.UnixDomainSocket(path: "/tmp/posix-uds-test")
         )
         transport.stopListening()
       }

+ 144 - 0
Tests/GRPCHTTP2TransportTests/HTTP2TransportNIOTransportServicesTests.swift

@@ -0,0 +1,144 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#if canImport(Network)
+import GRPCCore
+import GRPCHTTP2Core
+import GRPCHTTP2TransportNIOTransportServices
+import XCTest
+
+@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
+final class HTTP2TransportNIOTransportServicesTests: XCTestCase {
+  func testGetListeningAddress_IPv4() async throws {
+    let transport = GRPCHTTP2Core.HTTP2ServerTransport.TransportServices(
+      address: .ipv4(host: "0.0.0.0", port: 0)
+    )
+
+    try await withThrowingDiscardingTaskGroup { group in
+      group.addTask {
+        try await transport.listen { _ in }
+      }
+
+      group.addTask {
+        let address = try await transport.listeningAddress
+        let ipv4Address = try XCTUnwrap(address.ipv4)
+        XCTAssertNotEqual(ipv4Address.port, 0)
+        transport.stopListening()
+      }
+    }
+  }
+
+  func testGetListeningAddress_IPv6() async throws {
+    let transport = GRPCHTTP2Core.HTTP2ServerTransport.TransportServices(
+      address: .ipv6(host: "::1", port: 0)
+    )
+
+    try await withThrowingDiscardingTaskGroup { group in
+      group.addTask {
+        try await transport.listen { _ in }
+      }
+
+      group.addTask {
+        let address = try await transport.listeningAddress
+        let ipv6Address = try XCTUnwrap(address.ipv6)
+        XCTAssertNotEqual(ipv6Address.port, 0)
+        transport.stopListening()
+      }
+    }
+  }
+
+  func testGetListeningAddress_UnixDomainSocket() async throws {
+    let transport = GRPCHTTP2Core.HTTP2ServerTransport.TransportServices(
+      address: .unixDomainSocket(path: "/tmp/niots-uds-test")
+    )
+
+    try await withThrowingDiscardingTaskGroup { group in
+      group.addTask {
+        try await transport.listen { _ in }
+      }
+
+      group.addTask {
+        let address = try await transport.listeningAddress
+        XCTAssertEqual(
+          address.unixDomainSocket,
+          GRPCHTTP2Core.SocketAddress.UnixDomainSocket(path: "/tmp/niots-uds-test")
+        )
+        transport.stopListening()
+      }
+    }
+  }
+
+  func testGetListeningAddress_InvalidAddress() async {
+    let transport = GRPCHTTP2Core.HTTP2ServerTransport.TransportServices(
+      address: .unixDomainSocket(path: "/this/should/be/an/invalid/path")
+    )
+
+    try? await withThrowingDiscardingTaskGroup { group in
+      group.addTask {
+        try await transport.listen { _ in }
+      }
+
+      group.addTask {
+        do {
+          _ = try await transport.listeningAddress
+          XCTFail("Should have thrown a RuntimeError")
+        } catch let error as RuntimeError {
+          XCTAssertEqual(error.code, .serverIsStopped)
+          XCTAssertEqual(
+            error.message,
+            """
+            There is no listening address bound for this server: there may have \
+            been an error which caused the transport to close, or it may have shut down.
+            """
+          )
+        }
+      }
+    }
+  }
+
+  func testGetListeningAddress_StoppedListening() async throws {
+    let transport = GRPCHTTP2Core.HTTP2ServerTransport.TransportServices(
+      address: .ipv4(host: "0.0.0.0", port: 0)
+    )
+
+    try? await withThrowingDiscardingTaskGroup { group in
+      group.addTask {
+        try await transport.listen { _ in }
+
+        do {
+          _ = try await transport.listeningAddress
+          XCTFail("Should have thrown a RuntimeError")
+        } catch let error as RuntimeError {
+          XCTAssertEqual(error.code, .serverIsStopped)
+          XCTAssertEqual(
+            error.message,
+            """
+            There is no listening address bound for this server: there may have \
+            been an error which caused the transport to close, or it may have shut down.
+            """
+          )
+        }
+      }
+
+      group.addTask {
+        let address = try await transport.listeningAddress
+        XCTAssertNotNil(address.ipv4)
+        transport.stopListening()
+      }
+    }
+  }
+}
+#endif