Sfoglia il codice sorgente

Expose listening address on HTTP2 server transport (#1933)

Motivation:

We should expose the listening address on the H2 server transports so that our client can connect to it.

Modifications:

Exposed a `listeningAddress` async property. It will block until we bind or error, return the listening address while it's bound and listening, and throw an error when it's not listening anymore.

Result:

We can now know what address our server transport's listening on for requests.
Gustavo Cairo 1 anno fa
parent
commit
9c8a34d36d

+ 22 - 0
Sources/GRPCHTTP2Core/Internal/NIOSocketAddress+GRPCSocketAddress.swift

@@ -16,6 +16,28 @@
 
 import NIOCore
 
+@_spi(Package)
+extension GRPCHTTP2Core.SocketAddress {
+  public init(_ nioSocketAddress: NIOCore.SocketAddress) {
+    switch nioSocketAddress {
+    case .v4(let address):
+      self = .ipv4(
+        host: address.host,
+        port: nioSocketAddress.port ?? 0
+      )
+
+    case .v6(let address):
+      self = .ipv6(
+        host: address.host,
+        port: nioSocketAddress.port ?? 0
+      )
+
+    case .unixDomainSocket:
+      self = .unixDomainSocket(path: nioSocketAddress.pathname ?? "")
+    }
+  }
+}
+
 @_spi(Package)
 extension NIOCore.SocketAddress {
   public init(_ address: GRPCHTTP2Core.SocketAddress.IPv4) throws {

+ 144 - 0
Sources/GRPCHTTP2TransportNIOPosix/HTTP2ServerTransport+Posix.swift

@@ -21,6 +21,7 @@ import NIOExtras
 import NIOPosix
 
 extension HTTP2ServerTransport {
+  /// A NIOPosix-backed implementation of a server transport.
   @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
   public struct Posix: ServerTransport {
     private let address: GRPCHTTP2Core.SocketAddress
@@ -28,6 +29,124 @@ extension HTTP2ServerTransport {
     private let eventLoopGroup: MultiThreadedEventLoopGroup
     private let serverQuiescingHelper: ServerQuiescingHelper
 
+    private enum State {
+      case idle(EventLoopPromise<GRPCHTTP2Core.SocketAddress>)
+      case listening(EventLoopFuture<GRPCHTTP2Core.SocketAddress>)
+      case closedOrInvalidAddress(RuntimeError)
+
+      public var listeningAddressFuture: EventLoopFuture<GRPCHTTP2Core.SocketAddress> {
+        get throws {
+          switch self {
+          case .idle(let eventLoopPromise):
+            return eventLoopPromise.futureResult
+          case .listening(let eventLoopFuture):
+            return eventLoopFuture
+          case .closedOrInvalidAddress(let runtimeError):
+            throw runtimeError
+          }
+        }
+      }
+
+      enum OnBound {
+        case succeedPromise(
+          _ promise: EventLoopPromise<GRPCHTTP2Core.SocketAddress>,
+          address: GRPCHTTP2Core.SocketAddress
+        )
+        case failPromise(
+          _ promise: EventLoopPromise<GRPCHTTP2Core.SocketAddress>,
+          error: RuntimeError
+        )
+      }
+
+      mutating func addressBound(
+        _ address: NIOCore.SocketAddress?,
+        userProvidedAddress: GRPCHTTP2Core.SocketAddress
+      ) -> OnBound {
+        switch self {
+        case .idle(let listeningAddressPromise):
+          if let address {
+            self = .listening(listeningAddressPromise.futureResult)
+            return .succeedPromise(
+              listeningAddressPromise,
+              address: GRPCHTTP2Core.SocketAddress(address)
+            )
+
+          } else if userProvidedAddress.virtualSocket != nil {
+            self = .listening(listeningAddressPromise.futureResult)
+            return .succeedPromise(listeningAddressPromise, address: userProvidedAddress)
+
+          } else {
+            assertionFailure("Unknown address type")
+            let invalidAddressError = RuntimeError(
+              code: .transportError,
+              message: "Unknown address type returned by transport."
+            )
+            self = .closedOrInvalidAddress(invalidAddressError)
+            return .failPromise(listeningAddressPromise, error: invalidAddressError)
+          }
+
+        case .listening, .closedOrInvalidAddress:
+          fatalError(
+            "Invalid state: addressBound should only be called once and when in idle state"
+          )
+        }
+      }
+
+      enum OnClose {
+        case failPromise(
+          EventLoopPromise<GRPCHTTP2Core.SocketAddress>,
+          error: RuntimeError
+        )
+        case doNothing
+      }
+
+      mutating func close() -> OnClose {
+        let serverStoppedError = RuntimeError(
+          code: .serverIsStopped,
+          message: """
+            There is no listening address bound for this server: there may have been \
+            an error which caused the transport to close, or it may have shut down.
+            """
+        )
+
+        switch self {
+        case .idle(let listeningAddressPromise):
+          self = .closedOrInvalidAddress(serverStoppedError)
+          return .failPromise(listeningAddressPromise, error: serverStoppedError)
+
+        case .listening:
+          self = .closedOrInvalidAddress(serverStoppedError)
+          return .doNothing
+
+        case .closedOrInvalidAddress:
+          return .doNothing
+        }
+      }
+    }
+
+    private let listeningAddressState: _LockedValueBox<State>
+
+    /// The listening address for this server transport.
+    ///
+    /// It is an `async` property because it will only return once the address has been successfully bound.
+    ///
+    /// - Throws: A runtime error will be thrown if the address could not be bound or is not bound any
+    /// longer, because the transport isn't listening anymore. It can also throw if the transport returned an
+    /// invalid address.
+    public var listeningAddress: GRPCHTTP2Core.SocketAddress {
+      get async throws {
+        try await self.listeningAddressState
+          .withLockedValue { try $0.listeningAddressFuture }
+          .get()
+      }
+    }
+
+    /// Create a new `Posix` transport.
+    ///
+    /// - Parameters:
+    ///   - address: The address to which the server should be bound.
+    ///   - config: The transport configuration.
+    ///   - eventLoopGroup: The ELG from which to get ELs to run this transport.
     public init(
       address: GRPCHTTP2Core.SocketAddress,
       config: Config = .defaults,
@@ -37,11 +156,23 @@ extension HTTP2ServerTransport {
       self.config = config
       self.eventLoopGroup = eventLoopGroup
       self.serverQuiescingHelper = ServerQuiescingHelper(group: self.eventLoopGroup)
+
+      let eventLoop = eventLoopGroup.any()
+      self.listeningAddressState = _LockedValueBox(.idle(eventLoop.makePromise()))
     }
 
     public func listen(
       _ streamHandler: @escaping (RPCStream<Inbound, Outbound>) async -> Void
     ) async throws {
+      defer {
+        switch self.listeningAddressState.withLockedValue({ $0.close() }) {
+        case .failPromise(let promise, let error):
+          promise.fail(error)
+        case .doNothing:
+          ()
+        }
+      }
+
       let serverChannel = try await ServerBootstrap(group: self.eventLoopGroup)
         .serverChannelInitializer { channel in
           let quiescingHandler = self.serverQuiescingHelper.makeServerChannelHandler(
@@ -62,6 +193,19 @@ extension HTTP2ServerTransport {
           }
         }
 
+      let action = self.listeningAddressState.withLockedValue {
+        $0.addressBound(
+          serverChannel.channel.localAddress,
+          userProvidedAddress: self.address
+        )
+      }
+      switch action {
+      case .succeedPromise(let promise, let address):
+        promise.succeed(address)
+      case .failPromise(let promise, let error):
+        promise.fail(error)
+      }
+
       try await serverChannel.executeThenClose { inbound in
         try await withThrowingDiscardingTaskGroup { serverTaskGroup in
           for try await (connectionChannel, streamMultiplexer) in inbound {

+ 160 - 0
Tests/GRPCHTTP2TransportTests/HTTP2TransportNIOPosixTests.swift

@@ -0,0 +1,160 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import GRPCCore
+import GRPCHTTP2Core
+import GRPCHTTP2TransportNIOPosix
+import XCTest
+
+@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
+final class HTTP2TransportNIOPosixTests: XCTestCase {
+  func testGetListeningAddress_IPv4() async throws {
+    let transport = GRPCHTTP2Core.HTTP2ServerTransport.Posix(
+      address: .ipv4(host: "0.0.0.0", port: 0)
+    )
+
+    try await withThrowingDiscardingTaskGroup { group in
+      group.addTask {
+        try await transport.listen { _ in }
+      }
+
+      group.addTask {
+        let address = try await transport.listeningAddress
+        let ipv4Address = try XCTUnwrap(address.ipv4)
+        XCTAssertNotEqual(ipv4Address.port, 0)
+        transport.stopListening()
+      }
+    }
+  }
+
+  func testGetListeningAddress_IPv6() async throws {
+    let transport = GRPCHTTP2Core.HTTP2ServerTransport.Posix(address: .ipv6(host: "::1", port: 0))
+
+    try await withThrowingDiscardingTaskGroup { group in
+      group.addTask {
+        try await transport.listen { _ in }
+      }
+
+      group.addTask {
+        let address = try await transport.listeningAddress
+        let ipv6Address = try XCTUnwrap(address.ipv6)
+        XCTAssertNotEqual(ipv6Address.port, 0)
+        transport.stopListening()
+      }
+    }
+  }
+
+  func testGetListeningAddress_UnixDomainSocket() async throws {
+    let transport = GRPCHTTP2Core.HTTP2ServerTransport.Posix(
+      address: .unixDomainSocket(path: "/tmp/test")
+    )
+
+    try await withThrowingDiscardingTaskGroup { group in
+      group.addTask {
+        try await transport.listen { _ in }
+      }
+
+      group.addTask {
+        let address = try await transport.listeningAddress
+        XCTAssertEqual(
+          address.unixDomainSocket,
+          GRPCHTTP2Core.SocketAddress.UnixDomainSocket(path: "/tmp/test")
+        )
+        transport.stopListening()
+      }
+    }
+  }
+
+  func testGetListeningAddress_Vsock() async throws {
+    try XCTSkipUnless(self.vsockAvailable(), "Vsock unavailable")
+
+    let transport = GRPCHTTP2Core.HTTP2ServerTransport.Posix(
+      address: .vsock(contextID: .any, port: .any)
+    )
+
+    try await withThrowingDiscardingTaskGroup { group in
+      group.addTask {
+        try await transport.listen { _ in }
+      }
+
+      group.addTask {
+        let address = try await transport.listeningAddress
+        XCTAssertNotNil(address.virtualSocket)
+        transport.stopListening()
+      }
+    }
+  }
+
+  func testGetListeningAddress_InvalidAddress() async {
+    let transport = GRPCHTTP2Core.HTTP2ServerTransport.Posix(
+      address: .unixDomainSocket(path: "/this/should/be/an/invalid/path")
+    )
+
+    try? await withThrowingDiscardingTaskGroup { group in
+      group.addTask {
+        try await transport.listen { _ in }
+      }
+
+      group.addTask {
+        do {
+          _ = try await transport.listeningAddress
+          XCTFail("Should have thrown a RuntimeError")
+        } catch let error as RuntimeError {
+          XCTAssertEqual(error.code, .serverIsStopped)
+          XCTAssertEqual(
+            error.message,
+            """
+            There is no listening address bound for this server: there may have \
+            been an error which caused the transport to close, or it may have shut down.
+            """
+          )
+        }
+      }
+    }
+  }
+
+  func testGetListeningAddress_StoppedListening() async throws {
+    let transport = GRPCHTTP2Core.HTTP2ServerTransport.Posix(
+      address: .ipv4(host: "0.0.0.0", port: 0)
+    )
+
+    try? await withThrowingDiscardingTaskGroup { group in
+      group.addTask {
+        try await transport.listen { _ in }
+
+        do {
+          _ = try await transport.listeningAddress
+          XCTFail("Should have thrown a RuntimeError")
+        } catch let error as RuntimeError {
+          XCTAssertEqual(error.code, .serverIsStopped)
+          XCTAssertEqual(
+            error.message,
+            """
+            There is no listening address bound for this server: there may have \
+            been an error which caused the transport to close, or it may have shut down.
+            """
+          )
+        }
+      }
+
+      group.addTask {
+        let address = try await transport.listeningAddress
+        XCTAssertNotNil(address.ipv4)
+        transport.stopListening()
+      }
+    }
+  }
+}

+ 34 - 0
Tests/GRPCHTTP2TransportTests/XCTestCase+Vsock.swift

@@ -0,0 +1,34 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import NIOPosix
+import XCTest
+
+extension XCTestCase {
+  func vsockAvailable() -> Bool {
+    let fd: CInt
+    #if os(Linux)
+    fd = socket(AF_VSOCK, CInt(SOCK_STREAM.rawValue), 0)
+    #elseif canImport(Darwin)
+    fd = socket(AF_VSOCK, SOCK_STREAM, 0)
+    #else
+    fd = -1
+    #endif
+    if fd == -1 { return false }
+    precondition(close(fd) == 0)
+    return true
+  }
+}