Browse Source

Add VirtualSocket resolver (#1815)

Motivation:

Users should be able to connect to Virtual Sockets. To do these they
need a VSOCK target and resolver.

Modifications:

- Add a VSOCK target and resolver
- Update the registry to include it by default

Result:

Can resolve VSOCK targets
George Barnett 1 year ago
parent
commit
a6a4539aa0

+ 64 - 0
Sources/GRPCHTTP2Core/Client/Resolver/NameResolver+VSOCK.swift

@@ -0,0 +1,64 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import GRPCCore
+
+extension ResolvableTargets {
+  /// A resolvable target for Virtual Socket addresses.
+  ///
+  /// ``VirtualSocket`` addresses can be resolved by the ``NameResolvers/VirtualSocket``
+  /// resolver which creates a single ``Endpoint`` for target address.
+  public struct VirtualSocket: ResolvableTarget {
+    public var address: SocketAddress.VirtualSocket
+
+    public init(address: SocketAddress.VirtualSocket) {
+      self.address = address
+    }
+  }
+}
+
+extension ResolvableTarget where Self == ResolvableTargets.VirtualSocket {
+  /// Creates a new resolvable Virtual Socket target.
+  /// - Parameters:
+  ///   - contextID: The context ID ('cid') of the service.
+  ///   - port: The port to connect to.
+  public static func vsock(
+    contextID: SocketAddress.VirtualSocket.ContextID,
+    port: SocketAddress.VirtualSocket.Port
+  ) -> Self {
+    let address = SocketAddress.VirtualSocket(contextID: contextID, port: port)
+    return ResolvableTargets.VirtualSocket(address: address)
+  }
+}
+
+@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
+extension NameResolvers {
+  /// A ``NameResolverFactory`` for ``ResolvableTargets/VirtualSocket`` targets.
+  ///
+  /// The name resolver for a given target always produces the same values, with a single endpoint.
+  /// This resolver doesn't support fetching service configuration.
+  public struct VirtualSocket: NameResolverFactory {
+    public typealias Target = ResolvableTargets.VirtualSocket
+
+    public init() {}
+
+    public func resolver(for target: Target) -> NameResolver {
+      let endpoint = Endpoint(addresses: [.vsock(target.address)])
+      let resolutionResult = NameResolutionResult(endpoints: [endpoint], serviceConfiguration: nil)
+      return NameResolver(names: .constant(resolutionResult), updateMode: .pull)
+    }
+  }
+}

+ 14 - 3
Sources/GRPCHTTP2Core/Client/Resolver/NameResolverRegistry.swift

@@ -28,9 +28,9 @@
 /// // type `CustomResolver.ResolvableTarget`.
 /// registry.registerFactory(CustomResolver())
 ///
-/// // Remove the Unix Domain Socket and VSOCK resolvers, if they exist.
+/// // Remove the Unix Domain Socket and Virtual Socket resolvers, if they exist.
 /// registry.removeFactory(ofType: NameResolvers.UnixDomainSocket.self)
-/// registry.removeFactory(ofType: NameResolvers.VSOCK.self)
+/// registry.removeFactory(ofType: NameResolvers.VirtualSocket.self)
 ///
 /// // Resolve an IPv4 target
 /// if let resolver = registry.makeResolver(for: .ipv4(host: "localhost", port: 80)) {
@@ -43,6 +43,7 @@ public struct NameResolverRegistry {
     case ipv4(NameResolvers.IPv4)
     case ipv6(NameResolvers.IPv6)
     case unix(NameResolvers.UnixDomainSocket)
+    case vsock(NameResolvers.VirtualSocket)
     case other(any NameResolverFactory)
 
     init(_ factory: some NameResolverFactory) {
@@ -52,6 +53,8 @@ public struct NameResolverRegistry {
         self = .ipv6(ipv6)
       } else if let unix = factory as? NameResolvers.UnixDomainSocket {
         self = .unix(unix)
+      } else if let vsock = factory as? NameResolvers.VirtualSocket {
+        self = .vsock(vsock)
       } else {
         self = .other(factory)
       }
@@ -65,6 +68,8 @@ public struct NameResolverRegistry {
         return factory.makeResolverIfCompatible(target)
       case .unix(let factory):
         return factory.makeResolverIfCompatible(target)
+      case .vsock(let factory):
+        return factory.makeResolverIfCompatible(target)
       case .other(let factory):
         return factory.makeResolverIfCompatible(target)
       }
@@ -78,6 +83,8 @@ public struct NameResolverRegistry {
         return factory.isCompatible(withTarget: target)
       case .unix(let factory):
         return factory.isCompatible(withTarget: target)
+      case .vsock(let factory):
+        return factory.isCompatible(withTarget: target)
       case .other(let factory):
         return factory.isCompatible(withTarget: target)
       }
@@ -91,6 +98,8 @@ public struct NameResolverRegistry {
         return NameResolvers.IPv6.self == factoryType
       case .unix:
         return NameResolvers.UnixDomainSocket.self == factoryType
+      case .vsock:
+        return NameResolvers.VirtualSocket.self == factoryType
       case .other(let factory):
         return type(of: factory) == factoryType
       }
@@ -109,12 +118,14 @@ public struct NameResolverRegistry {
   /// The default resolvers include:
   /// - ``NameResolvers/IPv4``,
   /// - ``NameResolvers/IPv6``,
-  /// - ``NameResolvers/UnixDomainSocket``.
+  /// - ``NameResolvers/UnixDomainSocket``,
+  /// - ``NameResolvers/VirtualSocket``.
   public static var defaults: Self {
     var resolvers = NameResolverRegistry()
     resolvers.registerFactory(NameResolvers.IPv4())
     resolvers.registerFactory(NameResolvers.IPv6())
     resolvers.registerFactory(NameResolvers.UnixDomainSocket())
+    resolvers.registerFactory(NameResolvers.VirtualSocket())
     return resolvers
   }
 

+ 17 - 1
Tests/GRPCHTTP2CoreTests/Client/Resolver/NameResolverRegistryTests.swift

@@ -135,7 +135,8 @@ final class NameResolverRegistryTests: XCTestCase {
     XCTAssert(resolvers.containsFactory(ofType: NameResolvers.IPv4.self))
     XCTAssert(resolvers.containsFactory(ofType: NameResolvers.IPv6.self))
     XCTAssert(resolvers.containsFactory(ofType: NameResolvers.UnixDomainSocket.self))
-    XCTAssertEqual(resolvers.count, 3)
+    XCTAssert(resolvers.containsFactory(ofType: NameResolvers.VirtualSocket.self))
+    XCTAssertEqual(resolvers.count, 4)
   }
 
   func testMakeResolver() {
@@ -252,4 +253,19 @@ final class NameResolverRegistryTests: XCTestCase {
       XCTAssertNil(result.serviceConfiguration)
     }
   }
+
+  func testVSOCKResolver() async throws {
+    let factory = NameResolvers.VirtualSocket()
+    let resolver = factory.resolver(for: .vsock(contextID: .any, port: .any))
+
+    XCTAssertEqual(resolver.updateMode, .pull)
+
+    // The VSOCK resolver always returns the same values.
+    var iterator = resolver.names.makeAsyncIterator()
+    for _ in 0 ..< 1000 {
+      let result = try await XCTUnwrapAsync { try await iterator.next() }
+      XCTAssertEqual(result.endpoints, [Endpoint(addresses: [.vsock(contextID: .any, port: .any)])])
+      XCTAssertNil(result.serviceConfiguration)
+    }
+  }
 }