Browse Source

Adopt `ClientContext` changes (#56)

This PR adopts the changes introduced in
https://github.com/grpc/grpc-swift/pull/2158, requiring client
transports to provide a `ClientContext` alongside the stream.
Gus Cairo 1 year ago
parent
commit
b8a3c3b319

+ 1 - 1
Package.swift

@@ -35,7 +35,7 @@ let products: [Product] = [
 let dependencies: [Package.Dependency] = [
   .package(
     url: "https://github.com/grpc/grpc-swift.git",
-    exact: "2.0.0-beta.2"
+    branch: "main"
   ),
   .package(
     url: "https://github.com/apple/swift-nio.git",

+ 18 - 6
Sources/GRPCNIOTransportCore/Client/Connection/Connection.swift

@@ -202,10 +202,10 @@ package final class Connection: Sendable {
     descriptor: MethodDescriptor,
     options: CallOptions
   ) async throws -> Stream {
-    let (multiplexer, scheme) = try self.state.withLock { state in
+    let (multiplexer, scheme, remotePeer, localPeer) = try self.state.withLock { state in
       switch state {
       case .connected(let connected):
-        return (connected.multiplexer, connected.scheme)
+        return (connected.multiplexer, connected.scheme, connected.remotePeer, connected.localPeer)
       case .notConnected, .closing, .closed:
         throw RPCError(code: .unavailable, message: "subchannel isn't ready")
       }
@@ -246,7 +246,13 @@ package final class Connection: Sendable {
         }
       }
 
-      return Stream(wrapping: stream, descriptor: descriptor)
+      let context = ClientContext(
+        descriptor: descriptor,
+        remotePeer: remotePeer,
+        localPeer: localPeer
+      )
+
+      return Stream(wrapping: stream, context: context)
     } catch {
       throw RPCError(code: .unavailable, message: "subchannel is unavailable", cause: error)
     }
@@ -417,16 +423,16 @@ extension Connection {
       }
     }
 
-    let descriptor: MethodDescriptor
+    let context: ClientContext
 
     private let http2Stream: NIOAsyncChannel<RPCResponsePart, RPCRequestPart>
 
     init(
       wrapping stream: NIOAsyncChannel<RPCResponsePart, RPCRequestPart>,
-      descriptor: MethodDescriptor
+      context: ClientContext
     ) {
       self.http2Stream = stream
-      self.descriptor = descriptor
+      self.context = context
     }
 
     package func execute<T>(
@@ -457,6 +463,10 @@ extension Connection {
     struct Connected: Sendable {
       /// The connection channel.
       var channel: NIOAsyncChannel<ClientConnectionEvent, Void>
+      /// The connection's remote peer information.
+      var remotePeer: String
+      /// The connection's local peer information.
+      var localPeer: String
       /// Multiplexer for creating HTTP/2 streams.
       var multiplexer: NIOHTTP2Handler.AsyncStreamMultiplexer<Void>
       /// Whether the connection is plaintext, `false` implies TLS is being used.
@@ -464,6 +474,8 @@ extension Connection {
 
       init(_ connection: HTTP2Connection) {
         self.channel = connection.channel
+        self.remotePeer = connection.channel.remoteAddressInfo
+        self.localPeer = connection.channel.localAddressInfo
         self.multiplexer = connection.multiplexer
         self.scheme = connection.isPlaintext ? .http : .https
       }

+ 4 - 4
Sources/GRPCNIOTransportCore/Client/Connection/GRPCChannel.swift

@@ -198,11 +198,11 @@ package final class GRPCChannel: ClientTransport {
     self.input.continuation.yield(.close)
   }
 
-  /// Opens a stream using the transport, and uses it as input into a user-provided closure.
+  /// Opens a stream using the transport, and uses it as input into a user-provided closure, alongside the client's context.
   package func withStream<T: Sendable>(
     descriptor: MethodDescriptor,
     options: CallOptions,
-    _ closure: (_ stream: RPCStream<Inbound, Outbound>) async throws -> T
+    _ closure: (_ stream: RPCStream<Inbound, Outbound>, _ context: ClientContext) async throws -> T
   ) async throws -> T {
     // Merge options from the call with those from the service config.
     let methodConfig = self.config(forMethod: descriptor)
@@ -214,11 +214,11 @@ package final class GRPCChannel: ClientTransport {
       case .created(let stream):
         return try await stream.execute { inbound, outbound in
           let rpcStream = RPCStream(
-            descriptor: stream.descriptor,
+            descriptor: stream.context.descriptor,
             inbound: RPCAsyncSequence<RPCResponsePart, any Error>(wrapping: inbound),
             outbound: RPCWriter.Closable(wrapping: outbound)
           )
-          return try await closure(rpcStream)
+          return try await closure(rpcStream, stream.context)
         }
 
       case .tryAgain(let error):

+ 93 - 0
Sources/GRPCNIOTransportCore/Internal/Channel+AddressInfo.swift

@@ -0,0 +1,93 @@
+/*
+ * Copyright 2025, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+internal import NIOCore
+
+extension NIOAsyncChannel {
+  var remoteAddressInfo: String {
+    guard let remote = self.channel.remoteAddress else {
+      return "<unknown>"
+    }
+
+    switch remote {
+    case .v4(let address):
+      // '!' is safe, v4 always has a port.
+      return "ipv4:\(address.host):\(remote.port!)"
+
+    case .v6(let address):
+      // '!' is safe, v6 always has a port.
+      return "ipv6:[\(address.host)]:\(remote.port!)"
+
+    case .unixDomainSocket:
+      // '!' is safe, UDS always has a path.
+      if remote.pathname!.isEmpty {
+        guard let local = self.channel.localAddress else {
+          return "unix:<unknown>"
+        }
+
+        switch local {
+        case .unixDomainSocket:
+          // '!' is safe, UDS always has a path.
+          return "unix:\(local.pathname!)"
+
+        case .v4, .v6:
+          // Remote address is UDS but local isn't. This shouldn't ever happen.
+          return "unix:<unknown>"
+        }
+      } else {
+        // '!' is safe, UDS always has a path.
+        return "unix:\(remote.pathname!)"
+      }
+    }
+  }
+
+  var localAddressInfo: String {
+    guard let local = self.channel.localAddress else {
+      return "<unknown>"
+    }
+
+    switch local {
+    case .v4(let address):
+      // '!' is safe, v4 always has a port.
+      return "ipv4:\(address.host):\(local.port!)"
+
+    case .v6(let address):
+      // '!' is safe, v6 always has a port.
+      return "ipv6:[\(address.host)]:\(local.port!)"
+
+    case .unixDomainSocket:
+      // '!' is safe, UDS always has a path.
+      if local.pathname!.isEmpty {
+        guard let remote = self.channel.remoteAddress else {
+          return "unix:<unknown>"
+        }
+
+        switch remote {
+        case .unixDomainSocket:
+          // '!' is safe, UDS always has a path.
+          return "unix:\(remote.pathname!)"
+
+        case .v4, .v6:
+          // Remote address is UDS but local isn't. This shouldn't ever happen.
+          return "unix:<unknown>"
+        }
+      } else {
+        // '!' is safe, UDS always has a path.
+        return "unix:\(local.pathname!)"
+      }
+    }
+  }
+}

+ 1 - 35
Sources/GRPCNIOTransportCore/Server/CommonHTTP2ServerTransport.swift

@@ -191,40 +191,6 @@ package final class CommonHTTP2ServerTransport<
     }
   }
 
-  private func peerInfo(channel: any Channel) -> String {
-    guard let remote = channel.remoteAddress else {
-      return "<unknown>"
-    }
-
-    switch remote {
-    case .v4(let address):
-      // '!' is safe, v4 always has a port.
-      return "ipv4:\(address.host):\(remote.port!)"
-
-    case .v6(let address):
-      // '!' is safe, v6 always has a port.
-      return "ipv6:[\(address.host)]:\(remote.port!)"
-
-    case .unixDomainSocket:
-      // The pathname will be on the local address.
-      guard let local = channel.localAddress else {
-        // UDS but no local address; this shouldn't ever happen but at least note the transport
-        // as being UDS.
-        return "unix:<unknown>"
-      }
-
-      switch local {
-      case .unixDomainSocket:
-        // '!' is safe, UDS always has a path.
-        return "unix:\(local.pathname!)"
-
-      case .v4, .v6:
-        // Remote address is UDS but local isn't. This shouldn't ever happen.
-        return "unix:<unknown>"
-      }
-    }
-  }
-
   private func handleConnection(
     _ connection: NIOAsyncChannel<HTTP2Frame, HTTP2Frame>,
     multiplexer: ChannelPipeline.SynchronousOperations.HTTP2StreamMultiplexer,
@@ -233,7 +199,7 @@ package final class CommonHTTP2ServerTransport<
       _ context: ServerContext
     ) async -> Void
   ) async throws {
-    let peer = self.peerInfo(channel: connection.channel)
+    let peer = connection.remoteAddressInfo
     try await connection.executeThenClose { inbound, _ in
       await withDiscardingTaskGroup { group in
         group.addTask {

+ 1 - 1
Sources/GRPCNIOTransportHTTP2Posix/HTTP2ClientTransport+Posix.swift

@@ -120,7 +120,7 @@ extension HTTP2ClientTransport {
     public func withStream<T: Sendable>(
       descriptor: MethodDescriptor,
       options: CallOptions,
-      _ closure: (RPCStream<Inbound, Outbound>) async throws -> T
+      _ closure: (RPCStream<Inbound, Outbound>, ClientContext) async throws -> T
     ) async throws -> T {
       try await self.channel.withStream(descriptor: descriptor, options: options, closure)
     }

+ 1 - 1
Sources/GRPCNIOTransportHTTP2TransportServices/HTTP2ClientTransport+TransportServices.swift

@@ -118,7 +118,7 @@ extension HTTP2ClientTransport {
     public func withStream<T: Sendable>(
       descriptor: MethodDescriptor,
       options: CallOptions,
-      _ closure: (RPCStream<Inbound, Outbound>) async throws -> T
+      _ closure: (RPCStream<Inbound, Outbound>, ClientContext) async throws -> T
     ) async throws -> T {
       try await self.channel.withStream(descriptor: descriptor, options: options, closure)
     }

+ 6 - 6
Tests/GRPCNIOTransportCoreTests/Client/Connection/GRPCChannelTests.swift

@@ -353,7 +353,7 @@ final class GRPCChannelTests: XCTestCase {
         await channel.connect()
       }
 
-      try await channel.withStream(descriptor: .echoGet, options: .defaults) { stream in
+      try await channel.withStream(descriptor: .echoGet, options: .defaults) { stream, _ in
         try await stream.outbound.write(.metadata([:]))
 
         var iterator = stream.inbound.makeAsyncIterator()
@@ -441,7 +441,7 @@ final class GRPCChannelTests: XCTestCase {
       // be queued though.
       for _ in 1 ... 100 {
         group.addTask {
-          try await channel.withStream(descriptor: .echoGet, options: .defaults) { stream in
+          try await channel.withStream(descriptor: .echoGet, options: .defaults) { stream, _ in
             try await stream.outbound.write(.metadata([:]))
             await stream.outbound.finish()
 
@@ -510,7 +510,7 @@ final class GRPCChannelTests: XCTestCase {
           options.waitForReady = false
 
           await XCTAssertThrowsErrorAsync(ofType: RPCError.self) {
-            try await channel.withStream(descriptor: .echoGet, options: options) { _ in
+            try await channel.withStream(descriptor: .echoGet, options: options) { _, _ in
               XCTFail("Unexpected stream")
             }
           } errorHandler: { error in
@@ -780,7 +780,7 @@ final class GRPCChannelTests: XCTestCase {
 
           // Try to open a new stream.
           await XCTAssertThrowsErrorAsync(ofType: RPCError.self) {
-            try await channel.withStream(descriptor: .echoGet, options: .defaults) { stream in
+            try await channel.withStream(descriptor: .echoGet, options: .defaults) { stream, _ in
               XCTFail("Unexpected new stream")
             }
           } errorHandler: { error in
@@ -823,7 +823,7 @@ final class GRPCChannelTests: XCTestCase {
       }
 
       func doAnRPC() async throws {
-        try await channel.withStream(descriptor: .echoGet, options: .defaults) { stream in
+        try await channel.withStream(descriptor: .echoGet, options: .defaults) { stream, _ in
           try await stream.outbound.write(.metadata([:]))
           await stream.outbound.finish()
 
@@ -873,7 +873,7 @@ extension GRPCChannel {
     let values: Metadata.StringValues? = try await self.withStream(
       descriptor: .echoGet,
       options: .defaults
-    ) { stream in
+    ) { stream, _ in
       try await stream.outbound.write(.metadata([:]))
       await stream.outbound.finish()
 

+ 1 - 1
Tests/GRPCNIOTransportHTTP2Tests/ControlClient.swift

@@ -109,7 +109,7 @@ internal struct ControlClient {
   internal func peerInfo<R>(
     options: GRPCCore.CallOptions = .defaults,
     _ body: @Sendable @escaping (
-      _ response: GRPCCore.ClientResponse<String>
+      _ response: GRPCCore.ClientResponse<ControlService.PeerInfoResponse>
     ) async throws -> R = { try $0.message }
   ) async throws -> R where R: Sendable {
     try await self.client.unary(

+ 56 - 5
Tests/GRPCNIOTransportHTTP2Tests/ControlService.swift

@@ -65,11 +65,22 @@ struct ControlService: RegistrableRPCService {
     router.registerHandler(
       forMethod: MethodDescriptor(fullyQualifiedService: "Control", method: "PeerInfo"),
       deserializer: JSONDeserializer<String>(),
-      serializer: JSONSerializer<String>()
+      serializer: JSONSerializer<PeerInfoResponse>()
     ) { request, context in
       return StreamingServerResponse { response in
-        let info = try await self.peerInfo(context: context)
-        try await response.write(info)
+        let peerInfo = PeerInfoResponse(
+          client: PeerInfoResponse.PeerInfo(
+            local: clientLocalPeerInfo(request: request),
+            remote: clientRemotePeerInfo(request: request)
+          ),
+          server: PeerInfoResponse.PeerInfo(
+            local: serverLocalPeerInfo(context: context),
+            remote: serverRemotePeerInfo(context: context)
+          )
+        )
+
+        try await response.write(peerInfo)
+
         return [:]
       }
     }
@@ -101,8 +112,20 @@ extension ControlService {
     }
   }
 
-  private func peerInfo(context: ServerContext) async throws -> String {
-    return context.peer
+  private func serverRemotePeerInfo(context: ServerContext) -> String {
+    context.peer
+  }
+
+  private func serverLocalPeerInfo(context: ServerContext) -> String {
+    "<not yet implemented>"
+  }
+
+  private func clientRemotePeerInfo<T>(request: StreamingServerRequest<T>) -> String {
+    request.metadata[stringValues: "remotePeer"].first(where: { _ in true })!
+  }
+
+  private func clientLocalPeerInfo<T>(request: StreamingServerRequest<T>) -> String {
+    request.metadata[stringValues: "localPeer"].first(where: { _ in true })!
   }
 
   private func handle(
@@ -235,6 +258,18 @@ extension ControlService {
   }
 }
 
+extension ControlService {
+  struct PeerInfoResponse: Codable {
+    struct PeerInfo: Codable {
+      var local: String
+      var remote: String
+    }
+
+    var client: PeerInfo
+    var server: PeerInfo
+  }
+}
+
 extension Metadata {
   fileprivate func echo() -> Self {
     var copy = Metadata()
@@ -264,3 +299,19 @@ private struct UnsafeTransfer<Wrapped> {
 }
 
 extension UnsafeTransfer: @unchecked Sendable {}
+
+struct PeerInfoClientInterceptor: ClientInterceptor {
+  func intercept<Input, Output>(
+    request: StreamingClientRequest<Input>,
+    context: ClientContext,
+    next: (
+      StreamingClientRequest<Input>,
+      ClientContext
+    ) async throws -> StreamingClientResponse<Output>
+  ) async throws -> StreamingClientResponse<Output> where Input: Sendable, Output: Sendable {
+    var request = request
+    request.metadata.addString(context.localPeer, forKey: "localPeer")
+    request.metadata.addString(context.remotePeer, forKey: "remotePeer")
+    return try await next(request, context)
+  }
+}

+ 35 - 6
Tests/GRPCNIOTransportHTTP2Tests/HTTP2TransportTests.swift

@@ -228,7 +228,7 @@ final class HTTP2TransportTests: XCTestCase {
       #endif
     }
 
-    return GRPCClient(transport: transport)
+    return GRPCClient(transport: transport, interceptors: [PeerInfoClientInterceptor()])
   }
 
   func testUnaryOK() async throws {
@@ -1632,8 +1632,20 @@ final class HTTP2TransportTests: XCTestCase {
       serverAddress: .ipv4(host: "127.0.0.1", port: 0)
     ) { control, _, _ in
       let peerInfo = try await control.peerInfo()
-      let matches = peerInfo.matches(of: /ipv4:127.0.0.1:\d+/)
-      XCTAssertNotNil(matches)
+
+      let serverRemotePeerMatches = peerInfo.server.remote.wholeMatch(of: /ipv4:127\.0\.0\.1:(\d+)/)
+      let clientPort = try XCTUnwrap(serverRemotePeerMatches).1
+
+      // TODO: Uncomment when server local peer info is implemented
+
+      //      let serverLocalPeerMatches = peerInfo.server.local.wholeMatch(of: /<not yet implemented>/)
+      //      let serverPort = XCTUnwrap(serverLocalPeerMatches).1
+
+      //      let clientRemotePeerMatches = peerInfo.client.remote.wholeMatch(of: /ipv4:127.0.0.1:(\d+)/)
+      //      XCTAssertEqual(try XCTUnwrap(clientRemotePeerMatches).1, serverPort)
+
+      let clientLocalPeerMatches = peerInfo.client.local.wholeMatch(of: /ipv4:127\.0\.0\.1:(\d+)/)
+      XCTAssertEqual(try XCTUnwrap(clientLocalPeerMatches).1, clientPort)
     }
   }
 
@@ -1642,8 +1654,20 @@ final class HTTP2TransportTests: XCTestCase {
       serverAddress: .ipv6(host: "::1", port: 0)
     ) { control, _, _ in
       let peerInfo = try await control.peerInfo()
-      let matches = peerInfo.matches(of: /ipv6:[::1]:\d+/)
-      XCTAssertNotNil(matches)
+
+      let serverRemotePeerMatches = peerInfo.server.remote.wholeMatch(of: /ipv6:\[::1\]:(\d+)/)
+      let clientPort = try XCTUnwrap(serverRemotePeerMatches).1
+
+      // TODO: Uncomment when server local peer info is implemented
+
+      //      let serverLocalPeerMatches = peerInfo.server.local.wholeMatch(of: /<not yet implemented>/)
+      //      let serverPort = XCTUnwrap(serverLocalPeerMatches).1
+
+      //      let clientRemotePeerMatches = peerInfo.client.remote.wholeMatch(of: /ipv6:\[::1\]:(\d+)/)
+      //      XCTAssertEqual(try XCTUnwrap(clientRemotePeerMatches).1, serverPort)
+
+      let clientLocalPeerMatches = peerInfo.client.local.wholeMatch(of: /ipv6:\[::1\]:(\d+)/)
+      XCTAssertEqual(try XCTUnwrap(clientLocalPeerMatches).1, clientPort)
     }
   }
 
@@ -1653,7 +1677,12 @@ final class HTTP2TransportTests: XCTestCase {
       serverAddress: .unixDomainSocket(path: path)
     ) { control, _, _ in
       let peerInfo = try await control.peerInfo()
-      XCTAssertEqual(peerInfo, "unix:peer-info-uds")
+
+      XCTAssertNotNil(peerInfo.server.remote.wholeMatch(of: /unix:peer-info-uds/))
+      XCTAssertNotNil(peerInfo.server.local.wholeMatch(of: /<not yet implemented>/))
+
+      XCTAssertNotNil(peerInfo.client.remote.wholeMatch(of: /unix:peer-info-uds/))
+      XCTAssertNotNil(peerInfo.client.local.wholeMatch(of: /unix:peer-info-uds/))
     }
   }
 }

+ 1 - 1
dev/license-check.sh

@@ -88,7 +88,7 @@ check_copyright_headers() {
 
     actual_sha=$(head -n "$((drop_first + expected_lines))" "$filename" \
       | tail -n "$expected_lines" \
-      | sed -e 's/201[56789]-20[12][0-9]/YEARS/' -e 's/20[12][0-9]/YEARS/' \
+      | sed -e 's/20[12][0-9]-20[12][0-9]/YEARS/' -e 's/20[12][0-9]/YEARS/' \
       | shasum \
       | awk '{print $1}')