Quellcode durchsuchen

Add remote peer info to the server context (#2136)

Motivation:

It's often useful to know the identity of the remote peer when handling
RPCs.

Modifications:

- Add a 'peer' to the server context
- Implement this for the in-process transport
- Make some in-process inits `package`, these should never have been
`public`

Result:

Server RPCs have some idea what the address of remote peer is
George Barnett vor 1 Jahr
Ursprung
Commit
0fc49565a3

+ 20 - 1
Sources/GRPCCore/Call/Server/ServerContext.swift

@@ -19,6 +19,19 @@ public struct ServerContext: Sendable {
   /// A description of the method being called.
   public var descriptor: MethodDescriptor
 
+  /// A description of the remote peer.
+  ///
+  /// The format of the description should follow the pattern "<transport>:<address>" where
+  /// "<transport>" indicates the underlying network transport (such as "ipv4", "unix", or
+  /// "in-process"). This is a guideline for how descriptions should be formatted; different
+  /// implementations may not follow this format so you shouldn't make assumptions based on it.
+  ///
+  /// Some examples include:
+  /// - "ipv4:127.0.0.1:31415",
+  /// - "ipv6:[::1]:443",
+  /// - "in-process:27182".
+  public var peer: String
+
   /// A handle for checking the cancellation status of an RPC.
   public var cancellation: RPCCancellationHandle
 
@@ -26,10 +39,16 @@ public struct ServerContext: Sendable {
   ///
   /// - Parameters:
   ///   - descriptor: A description of the method being called.
+  ///   - peer: A description of the remote peer.
   ///   - cancellation: A cancellation handle. You can create a cancellation handle
   ///     using ``withServerContextRPCCancellationHandle(_:)``.
-  public init(descriptor: MethodDescriptor, cancellation: RPCCancellationHandle) {
+  public init(
+    descriptor: MethodDescriptor,
+    peer: String,
+    cancellation: RPCCancellationHandle
+  ) {
     self.descriptor = descriptor
+    self.peer = peer
     self.cancellation = cancellation
   }
 }

+ 1 - 1
Sources/GRPCInProcessTransport/InProcessTransport+Client.swift

@@ -109,7 +109,7 @@ extension InProcessTransport {
     /// - Parameters:
     ///   - server: The in-process server transport to connect to.
     ///   - serviceConfig: Service configuration.
-    public init(
+    package init(
       server: InProcessTransport.Server,
       serviceConfig: ServiceConfig = ServiceConfig()
     ) {

+ 8 - 2
Sources/GRPCInProcessTransport/InProcessTransport+Server.swift

@@ -34,6 +34,7 @@ extension InProcessTransport {
 
     private let newStreams: AsyncStream<RPCStream<Inbound, Outbound>>
     private let newStreamsContinuation: AsyncStream<RPCStream<Inbound, Outbound>>.Continuation
+    private let peer: String
 
     private struct State: Sendable {
       private var _nextID: UInt64
@@ -73,9 +74,10 @@ extension InProcessTransport {
     private let handles: Mutex<State>
 
     /// Creates a new instance of ``Server``.
-    public init() {
+    package init(peer: String) {
       (self.newStreams, self.newStreamsContinuation) = AsyncStream.makeStream()
       self.handles = Mutex(State())
+      self.peer = peer
     }
 
     /// Publish a new ``RPCStream``, which will be returned by the transport's ``events``
@@ -115,7 +117,11 @@ extension InProcessTransport {
                 handle.cancel()
               }
 
-              let context = ServerContext(descriptor: stream.descriptor, cancellation: handle)
+              let context = ServerContext(
+                descriptor: stream.descriptor,
+                peer: self.peer,
+                cancellation: handle
+              )
               await streamHandler(stream, context)
             }
           }

+ 2 - 1
Sources/GRPCInProcessTransport/InProcessTransport.swift

@@ -25,7 +25,8 @@ public struct InProcessTransport: Sendable {
   /// - Parameters:
   ///   - serviceConfig: Configuration describing how methods should be executed.
   public init(serviceConfig: ServiceConfig = ServiceConfig()) {
-    self.server = Self.Server()
+    let peer = "in-process:\(System.pid())"
+    self.server = Self.Server(peer: peer)
     self.client = Self.Client(server: self.server, serviceConfig: serviceConfig)
   }
 }

+ 38 - 0
Sources/GRPCInProcessTransport/Syscalls.swift

@@ -0,0 +1,38 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#if canImport(Darwin)
+import Darwin
+#elseif canImport(Glibc)
+import Glibc
+#elseif canImport(Musl)
+import Musl
+#endif
+
+enum System {
+  static func pid() -> Int {
+    #if canImport(Darwin)
+    let pid = Darwin.getpid()
+    return Int(pid)
+    #elseif canImport(Glibc)
+    let pid = Glibc.getpid()
+    return Int(pid)
+    #elseif canImport(Musl)
+    let pid = Musl.getpid()
+    return Int(pid)
+    #endif
+  }
+}

+ 2 - 2
Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTestSupport/ClientRPCExecutorTestHarness.swift

@@ -47,13 +47,13 @@ struct ClientRPCExecutorTestHarness {
 
     switch transport {
     case .inProcess:
-      let server = InProcessTransport.Server()
+      let server = InProcessTransport.Server(peer: "in-process:1234")
       let client = server.spawnClientTransport()
       self.serverTransport = StreamCountingServerTransport(wrapping: server)
       self.clientTransport = StreamCountingClientTransport(wrapping: client)
 
     case .throwsOnStreamCreation(let code):
-      let server = InProcessTransport.Server()  // Will never be called.
+      let server = InProcessTransport.Server(peer: "in-process:1234")  // Will never be called.
       let client = ThrowOnStreamCreationTransport(code: code)
       self.serverTransport = StreamCountingServerTransport(wrapping: server)
       self.clientTransport = StreamCountingClientTransport(wrapping: client)

+ 1 - 0
Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTestSupport/ServerRPCExecutorTestHarness.swift

@@ -102,6 +102,7 @@ struct ServerRPCExecutorTestHarness {
         await withServerContextRPCCancellationHandle { cancellation in
           let context = ServerContext(
             descriptor: MethodDescriptor(fullyQualifiedService: "foo", method: "bar"),
+            peer: "tests",
             cancellation: cancellation
           )
 

+ 4 - 1
Tests/GRPCCoreTests/GRPCServerTests.swift

@@ -334,7 +334,10 @@ final class GRPCServerTests: XCTestCase {
   }
 
   func testTestRunStoppedServer() async throws {
-    let server = GRPCServer(transport: InProcessTransport.Server(), services: [])
+    let server = GRPCServer(
+      transport: InProcessTransport.Server(peer: "in-process:1234"),
+      services: []
+    )
     // Run the server.
     let task = Task { try await server.serve() }
     task.cancel()

+ 5 - 5
Tests/GRPCInProcessTransportTests/InProcessClientTransportTests.swift

@@ -142,7 +142,7 @@ final class InProcessClientTransportTests: XCTestCase {
   }
 
   func testOpenStreamSuccessfullyAndThenClose() async throws {
-    let server = InProcessTransport.Server()
+    let server = InProcessTransport.Server(peer: "in-process:1234")
     let client = makeClient(server: server)
 
     try await withThrowingTaskGroup(of: Void.self) { group in
@@ -199,7 +199,7 @@ final class InProcessClientTransportTests: XCTestCase {
     )
 
     var client = InProcessTransport.Client(
-      server: InProcessTransport.Server(),
+      server: InProcessTransport.Server(peer: "in-process:1234"),
       serviceConfig: serviceConfig
     )
 
@@ -223,7 +223,7 @@ final class InProcessClientTransportTests: XCTestCase {
     )
     serviceConfig.methodConfig.append(overrideConfiguration)
     client = InProcessTransport.Client(
-      server: InProcessTransport.Server(),
+      server: InProcessTransport.Server(peer: "in-process:1234"),
       serviceConfig: serviceConfig
     )
 
@@ -239,7 +239,7 @@ final class InProcessClientTransportTests: XCTestCase {
   }
 
   func testOpenMultipleStreamsThenClose() async throws {
-    let server = InProcessTransport.Server()
+    let server = InProcessTransport.Server(peer: "in-process:1234")
     let client = makeClient(server: server)
 
     try await withThrowingTaskGroup(of: Void.self) { group in
@@ -269,7 +269,7 @@ final class InProcessClientTransportTests: XCTestCase {
   }
 
   func makeClient(
-    server: InProcessTransport.Server = InProcessTransport.Server()
+    server: InProcessTransport.Server = InProcessTransport.Server(peer: "in-process:1234")
   ) -> InProcessTransport.Client {
     let defaultPolicy = RetryPolicy(
       maxAttempts: 10,

+ 2 - 2
Tests/GRPCInProcessTransportTests/InProcessServerTransportTests.swift

@@ -21,7 +21,7 @@ import XCTest
 
 final class InProcessServerTransportTests: XCTestCase {
   func testStartListening() async throws {
-    let transport = InProcessTransport.Server()
+    let transport = InProcessTransport.Server(peer: "in-process:1234")
 
     let outbound = GRPCAsyncThrowingStream.makeStream(of: RPCResponsePart.self)
     let stream = RPCStream<
@@ -53,7 +53,7 @@ final class InProcessServerTransportTests: XCTestCase {
   }
 
   func testStopListening() async throws {
-    let transport = InProcessTransport.Server()
+    let transport = InProcessTransport.Server(peer: "in-process:1234")
 
     let firstStreamOutbound = GRPCAsyncThrowingStream.makeStream(of: RPCResponsePart.self)
     let firstStream = RPCStream<

+ 59 - 0
Tests/GRPCInProcessTransportTests/InProcessTransportTests.swift

@@ -64,6 +64,29 @@ struct InProcessTransportTests {
       client.beginGracefulShutdown()
     }
   }
+
+  @Test("Peer info")
+  func peerInfo() async throws {
+    try await self.withTestServerAndClient { server, client in
+      defer {
+        client.beginGracefulShutdown()
+        server.beginGracefulShutdown()
+      }
+
+      let peerInfo = try await client.unary(
+        request: ClientRequest(message: ()),
+        descriptor: .peerInfo,
+        serializer: VoidSerializer(),
+        deserializer: UTF8Deserializer(),
+        options: .defaults
+      ) {
+        try $0.message
+      }
+
+      let match = peerInfo.wholeMatch(of: /in-process:\d+/)
+      #expect(match != nil)
+    }
+  }
 }
 
 private struct TestService: RegistrableRPCService {
@@ -96,6 +119,13 @@ private struct TestService: RegistrableRPCService {
     }
   }
 
+  func peerInfo(
+    request: ServerRequest<Void>,
+    context: ServerContext
+  ) async throws -> ServerResponse<String> {
+    return ServerResponse(message: context.peer)
+  }
+
   func registerMethods(with router: inout RPCRouter) {
     router.registerHandler(
       forMethod: .testCancellation,
@@ -105,6 +135,19 @@ private struct TestService: RegistrableRPCService {
         try await self.cancellation(request: ServerRequest(stream: $0), context: $1)
       }
     )
+
+    router.registerHandler(
+      forMethod: .peerInfo,
+      deserializer: VoidDeserializer(),
+      serializer: UTF8Serializer(),
+      handler: {
+        let response = try await self.peerInfo(
+          request: ServerRequest<Void>(stream: $0),
+          context: $1
+        )
+        return StreamingServerResponse(single: response)
+      }
+    )
   }
 }
 
@@ -113,6 +156,11 @@ extension MethodDescriptor {
     fullyQualifiedService: "test",
     method: "cancellation"
   )
+
+  fileprivate static let peerInfo = Self(
+    fullyQualifiedService: "test",
+    method: "peerInfo"
+  )
 }
 
 private struct UTF8Serializer: MessageSerializer {
@@ -126,3 +174,14 @@ private struct UTF8Deserializer: MessageDeserializer {
     String(decoding: serializedMessageBytes, as: UTF8.self)
   }
 }
+
+private struct VoidSerializer: MessageSerializer {
+  func serialize(_ message: Void) throws -> [UInt8] {
+    []
+  }
+}
+
+private struct VoidDeserializer: MessageDeserializer {
+  func deserialize(_ serializedMessageBytes: [UInt8]) throws {
+  }
+}