浏览代码

Add a DNS NameResolver (#5)

Motivation:

Many users will rely on DNS to resolve the IP addresses of servers to
connect to, we should therefore provide a DNS name resolver.

Modifications:

- Add a DNS name resolver factory capable of resolving IP addresses
- Add the resolver to the registry defaults

Result:

Can resolve DNS targets
George Barnett 1 年之前
父节点
当前提交
ba557f238b

+ 127 - 0
Sources/GRPCNIOTransportCore/Client/Resolver/NameResolver+DNS.swift

@@ -0,0 +1,127 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+private import GRPCCore
+
+extension ResolvableTargets {
+  /// A resolvable target for addresses which can be resolved via DNS.
+  ///
+  /// If you already have an IPv4 or IPv6 address use ``ResolvableTargets/IPv4`` and
+  /// ``ResolvableTargets/IPv6`` respectively.
+  public struct DNS: ResolvableTarget, Sendable {
+    /// The host to resolve via DNS.
+    public var host: String
+
+    /// The port to use with resolved addresses.
+    public var port: Int
+
+    /// Create a new DNS target.
+    /// - Parameters:
+    ///   - host: The host to resolve via DNS.
+    ///   - port: The port to use with resolved addresses.
+    public init(host: String, port: Int) {
+      self.host = host
+      self.port = port
+    }
+  }
+}
+
+extension ResolvableTarget where Self == ResolvableTargets.DNS {
+  /// Creates a new resolvable DNS target.
+  /// - Parameters:
+  ///   - host: The host address to resolve.
+  ///   - port: The port to use for each resolved address.
+  /// - Returns: A ``ResolvableTarget``.
+  public static func dns(host: String, port: Int = 443) -> Self {
+    return Self(host: host, port: port)
+  }
+}
+
+@available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *)
+extension NameResolvers {
+  /// A ``NameResolverFactory`` for ``ResolvableTargets/DNS`` targets.
+  public struct DNS: NameResolverFactory {
+    public typealias Target = ResolvableTargets.DNS
+
+    /// Create a new DNS name resolver factory.
+    public init() {}
+
+    public func resolver(for target: Target) -> NameResolver {
+      let resolver = Self.Resolver(target: target)
+      return NameResolver(names: RPCAsyncSequence(wrapping: resolver), updateMode: .pull)
+    }
+  }
+}
+
+@available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *)
+extension NameResolvers.DNS {
+  struct Resolver: Sendable {
+    var target: ResolvableTargets.DNS
+
+    init(target: ResolvableTargets.DNS) {
+      self.target = target
+    }
+
+    func resolve(
+      isolation actor: isolated (any Actor)? = nil
+    ) async throws -> NameResolutionResult {
+      let addresses: [SocketAddress]
+
+      do {
+        addresses = try await DNSResolver.resolve(host: self.target.host, port: self.target.port)
+      } catch let error as CancellationError {
+        throw error
+      } catch {
+        throw RPCError(
+          code: .internalError,
+          message: "Couldn't resolve address for \(self.target.host):\(self.target.port)",
+          cause: error
+        )
+      }
+
+      return NameResolutionResult(endpoints: [Endpoint(addresses: addresses)], serviceConfig: nil)
+    }
+  }
+}
+
+@available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *)
+extension NameResolvers.DNS.Resolver: AsyncSequence {
+  typealias Element = NameResolutionResult
+
+  func makeAsyncIterator() -> AsyncIterator {
+    return AsyncIterator(resolver: self)
+  }
+
+  struct AsyncIterator: AsyncIteratorProtocol {
+    typealias Element = NameResolutionResult
+
+    private let resolver: NameResolvers.DNS.Resolver
+
+    init(resolver: NameResolvers.DNS.Resolver) {
+      self.resolver = resolver
+    }
+
+    func next() async throws -> NameResolutionResult? {
+      return try await self.next(isolation: nil)
+    }
+
+    func next(
+      isolation actor: isolated (any Actor)?
+    ) async throws(any Error) -> NameResolutionResult? {
+      return try await self.resolver.resolve(isolation: actor)
+    }
+  }
+}

+ 12 - 1
Sources/GRPCNIOTransportCore/Client/Resolver/NameResolverRegistry.swift

@@ -40,6 +40,7 @@
 @available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *)
 public struct NameResolverRegistry {
   private enum Factory {
+    case dns(NameResolvers.DNS)
     case ipv4(NameResolvers.IPv4)
     case ipv6(NameResolvers.IPv6)
     case unix(NameResolvers.UnixDomainSocket)
@@ -47,7 +48,9 @@ public struct NameResolverRegistry {
     case other(any NameResolverFactory)
 
     init(_ factory: some NameResolverFactory) {
-      if let ipv4 = factory as? NameResolvers.IPv4 {
+      if let dns = factory as? NameResolvers.DNS {
+        self = .dns(dns)
+      } else if let ipv4 = factory as? NameResolvers.IPv4 {
         self = .ipv4(ipv4)
       } else if let ipv6 = factory as? NameResolvers.IPv6 {
         self = .ipv6(ipv6)
@@ -62,6 +65,8 @@ public struct NameResolverRegistry {
 
     func makeResolverIfCompatible<Target: ResolvableTarget>(_ target: Target) -> NameResolver? {
       switch self {
+      case .dns(let factory):
+        return factory.makeResolverIfCompatible(target)
       case .ipv4(let factory):
         return factory.makeResolverIfCompatible(target)
       case .ipv6(let factory):
@@ -77,6 +82,8 @@ public struct NameResolverRegistry {
 
     func hasTarget<Target: ResolvableTarget>(_ target: Target) -> Bool {
       switch self {
+      case .dns(let factory):
+        return factory.isCompatible(withTarget: target)
       case .ipv4(let factory):
         return factory.isCompatible(withTarget: target)
       case .ipv6(let factory):
@@ -92,6 +99,8 @@ public struct NameResolverRegistry {
 
     func `is`<Factory: NameResolverFactory>(ofType factoryType: Factory.Type) -> Bool {
       switch self {
+      case .dns:
+        return NameResolvers.DNS.self == factoryType
       case .ipv4:
         return NameResolvers.IPv4.self == factoryType
       case .ipv6:
@@ -116,12 +125,14 @@ public struct NameResolverRegistry {
   /// Returns a new name resolver registry with the default factories registered.
   ///
   /// The default resolvers include:
+  /// - ``NameResolvers/DNS``,
   /// - ``NameResolvers/IPv4``,
   /// - ``NameResolvers/IPv6``,
   /// - ``NameResolvers/UnixDomainSocket``,
   /// - ``NameResolvers/VirtualSocket``.
   public static var defaults: Self {
     var resolvers = NameResolverRegistry()
+    resolvers.registerFactory(NameResolvers.DNS())
     resolvers.registerFactory(NameResolvers.IPv4())
     resolvers.registerFactory(NameResolvers.IPv6())
     resolvers.registerFactory(NameResolvers.UnixDomainSocket())

+ 24 - 1
Tests/GRPCNIOTransportCoreTests/Client/Resolver/NameResolverRegistryTests.swift

@@ -132,11 +132,12 @@ final class NameResolverRegistryTests: XCTestCase {
 
   func testDefaultResolvers() {
     let resolvers = NameResolverRegistry.defaults
+    XCTAssert(resolvers.containsFactory(ofType: NameResolvers.DNS.self))
     XCTAssert(resolvers.containsFactory(ofType: NameResolvers.IPv4.self))
     XCTAssert(resolvers.containsFactory(ofType: NameResolvers.IPv6.self))
     XCTAssert(resolvers.containsFactory(ofType: NameResolvers.UnixDomainSocket.self))
     XCTAssert(resolvers.containsFactory(ofType: NameResolvers.VirtualSocket.self))
-    XCTAssertEqual(resolvers.count, 4)
+    XCTAssertEqual(resolvers.count, 5)
   }
 
   func testMakeResolver() {
@@ -167,6 +168,28 @@ final class NameResolverRegistryTests: XCTestCase {
     }
   }
 
+  func testDNSResolverForIPv4() async throws {
+    let factory = NameResolvers.DNS()
+    let resolver = factory.resolver(for: .dns(host: "127.0.0.1", port: 1234))
+    XCTAssertEqual(resolver.updateMode, .pull)
+
+    var iterator = resolver.names.makeAsyncIterator()
+    let result = try await XCTUnwrapAsync { try await iterator.next() }
+    XCTAssertEqual(result.endpoints, [Endpoint(.ipv4(host: "127.0.0.1", port: 1234))])
+    XCTAssertNil(result.serviceConfig)
+  }
+
+  func testDNSResolverForIPv6() async throws {
+    let factory = NameResolvers.DNS()
+    let resolver = factory.resolver(for: .dns(host: "::1", port: 1234))
+    XCTAssertEqual(resolver.updateMode, .pull)
+
+    var iterator = resolver.names.makeAsyncIterator()
+    let result = try await XCTUnwrapAsync { try await iterator.next() }
+    XCTAssertEqual(result.endpoints, [Endpoint(.ipv6(host: "::1", port: 1234))])
+    XCTAssertNil(result.serviceConfig)
+  }
+
   func testIPv4ResolverForSingleHost() async throws {
     let factory = NameResolvers.IPv4()
     let resolver = factory.resolver(for: .ipv4(host: "foo", port: 1234))