فهرست منبع

Merge pull request #314 from p4checo/improve-errors-and-allow-configuring-notification-queue

Improve error handling and allow configuring notification queue
Ashley Mills 6 سال پیش
والد
کامیت
19fc460126
2فایلهای تغییر یافته به همراه56 افزوده شده و 36 حذف شده
  1. 54 34
      Sources/Reachability.swift
  2. 2 2
      Tests/ReachabilityTests.swift

+ 54 - 34
Sources/Reachability.swift

@@ -29,11 +29,11 @@ import SystemConfiguration
 import Foundation
 
 public enum ReachabilityError: Error {
-    case FailedToCreateWithAddress(sockaddr_in)
-    case FailedToCreateWithHostname(String)
-    case UnableToSetCallback
-    case UnableToSetDispatchQueue
-    case UnableToGetInitialFlags
+    case failedToCreateWithAddress(sockaddr, Int32)
+    case failedToCreateWithHostname(String, Int32)
+    case unableToSetCallback(Int32)
+    case unableToSetDispatchQueue(Int32)
+    case unableToGetFlags(Int32)
 }
 
 @available(*, unavailable, renamed: "Notification.Name.reachabilityChanged")
@@ -113,35 +113,49 @@ public class Reachability {
         #endif
     }()
 
-    fileprivate var notifierRunning = false
+    fileprivate(set) var notifierRunning = false
     fileprivate let reachabilityRef: SCNetworkReachability
     fileprivate let reachabilitySerialQueue: DispatchQueue
+    fileprivate let notificationQueue: DispatchQueue?
     fileprivate(set) var flags: SCNetworkReachabilityFlags? {
         didSet {
             guard flags != oldValue else { return }
-            reachabilityChanged()
+            notifyReachabilityChanged()
         }
     }
 
-    required public init(reachabilityRef: SCNetworkReachability, queueQoS: DispatchQoS = .default, targetQueue: DispatchQueue? = nil) {
+    required public init(reachabilityRef: SCNetworkReachability,
+                         queueQoS: DispatchQoS = .default,
+                         targetQueue: DispatchQueue? = nil,
+                         notificationQueue: DispatchQueue? = .main) {
         self.allowsCellularConnection = true
         self.reachabilityRef = reachabilityRef
         self.reachabilitySerialQueue = DispatchQueue(label: "uk.co.ashleymills.reachability", qos: queueQoS, target: targetQueue)
+        self.notificationQueue = notificationQueue
     }
 
-    public convenience init?(hostname: String, queueQoS: DispatchQoS = .default, targetQueue: DispatchQueue? = nil) {
-        guard let ref = SCNetworkReachabilityCreateWithName(nil, hostname) else { return nil }
-        self.init(reachabilityRef: ref, queueQoS: queueQoS, targetQueue: targetQueue)
+    public convenience init(hostname: String,
+                            queueQoS: DispatchQoS = .default,
+                            targetQueue: DispatchQueue? = nil,
+                            notificationQueue: DispatchQueue? = .main) throws {
+        guard let ref = SCNetworkReachabilityCreateWithName(nil, hostname) else {
+            throw ReachabilityError.failedToCreateWithHostname(hostname, SCError())
+        }
+        self.init(reachabilityRef: ref, queueQoS: queueQoS, targetQueue: targetQueue, notificationQueue: notificationQueue)
     }
 
-    public convenience init?(queueQoS: DispatchQoS = .default, targetQueue: DispatchQueue? = nil) {
+    public convenience init(queueQoS: DispatchQoS = .default,
+                            targetQueue: DispatchQueue? = nil,
+                            notificationQueue: DispatchQueue? = .main) throws {
         var zeroAddress = sockaddr()
         zeroAddress.sa_len = UInt8(MemoryLayout<sockaddr>.size)
         zeroAddress.sa_family = sa_family_t(AF_INET)
 
-        guard let ref = SCNetworkReachabilityCreateWithAddress(nil, &zeroAddress) else { return nil }
+        guard let ref = SCNetworkReachabilityCreateWithAddress(nil, &zeroAddress) else {
+            throw ReachabilityError.failedToCreateWithAddress(zeroAddress, SCError())
+        }
 
-        self.init(reachabilityRef: ref, queueQoS: queueQoS, targetQueue: targetQueue)
+        self.init(reachabilityRef: ref, queueQoS: queueQoS, targetQueue: targetQueue, notificationQueue: notificationQueue)
     }
 
     deinit {
@@ -163,15 +177,16 @@ public extension Reachability {
         }
 
         var context = SCNetworkReachabilityContext(version: 0, info: nil, retain: nil, release: nil, copyDescription: nil)
-        context.info = UnsafeMutableRawPointer(Unmanaged<Reachability>.passUnretained(self).toOpaque())
+        context.info = Unmanaged.passUnretained(self).toOpaque()
+
         if !SCNetworkReachabilitySetCallback(reachabilityRef, callback, &context) {
             stopNotifier()
-            throw ReachabilityError.UnableToSetCallback
+            throw ReachabilityError.unableToSetCallback(SCError())
         }
 
         if !SCNetworkReachabilitySetDispatchQueue(reachabilityRef, reachabilitySerialQueue) {
             stopNotifier()
-            throw ReachabilityError.UnableToSetDispatchQueue
+            throw ReachabilityError.unableToSetDispatchQueue(SCError())
         }
 
         // Perform an initial check
@@ -205,18 +220,7 @@ public extension Reachability {
     }
 
     var description: String {
-        guard let flags = flags else { return "unavailable flags" }
-        let W = isRunningOnDevice ? (flags.isOnWWANFlagSet ? "W" : "-") : "X"
-        let R = flags.isReachableFlagSet ? "R" : "-"
-        let c = flags.isConnectionRequiredFlagSet ? "c" : "-"
-        let t = flags.isTransientConnectionFlagSet ? "t" : "-"
-        let i = flags.isInterventionRequiredFlagSet ? "i" : "-"
-        let C = flags.isConnectionOnTrafficFlagSet ? "C" : "-"
-        let D = flags.isConnectionOnDemandFlagSet ? "D" : "-"
-        let l = flags.isLocalAddressFlagSet ? "l" : "-"
-        let d = flags.isDirectFlagSet ? "d" : "-"
-
-        return "\(W)\(R) \(c)\(t)\(i)\(C)\(D)\(l)\(d)"
+        return flags?.description ?? "unavailable flags"
     }
 }
 
@@ -227,21 +231,23 @@ fileprivate extension Reachability {
             var flags = SCNetworkReachabilityFlags()
             if !SCNetworkReachabilityGetFlags(self.reachabilityRef, &flags) {
                 self.stopNotifier()
-                throw ReachabilityError.UnableToGetInitialFlags
+                throw ReachabilityError.unableToGetFlags(SCError())
             }
             
             self.flags = flags
         }
     }
     
-    func reachabilityChanged() {
-        let block = connection != .none ? whenReachable : whenUnreachable
 
-        DispatchQueue.main.async { [weak self] in
+    func notifyReachabilityChanged() {
+        let notify = { [weak self] in
             guard let self = self else { return }
-            block?(self)
+            self.connection != .none ? self.whenReachable?(self) : self.whenUnreachable?(self)
             self.notificationCenter.post(name: .reachabilityChanged, object: self)
         }
+
+        // notify on the configured `notificationQueue`, or the caller's (i.e. `reachabilitySerialQueue`)
+        notificationQueue?.async(execute: notify) ?? notify()
     }
 }
 
@@ -313,4 +319,18 @@ extension SCNetworkReachabilityFlags {
     var isConnectionRequiredAndTransientFlagSet: Bool {
         return intersection([.connectionRequired, .transientConnection]) == [.connectionRequired, .transientConnection]
     }
+
+    var description: String {
+        let W = isOnWWANFlagSet ? "W" : "-"
+        let R = isReachableFlagSet ? "R" : "-"
+        let c = isConnectionRequiredFlagSet ? "c" : "-"
+        let t = isTransientConnectionFlagSet ? "t" : "-"
+        let i = isInterventionRequiredFlagSet ? "i" : "-"
+        let C = isConnectionOnTrafficFlagSet ? "C" : "-"
+        let D = isConnectionOnDemandFlagSet ? "D" : "-"
+        let l = isLocalAddressFlagSet ? "l" : "-"
+        let d = isDirectFlagSet ? "d" : "-"
+
+        return "\(W)\(R) \(c)\(t)\(i)\(C)\(D)\(l)\(d)"
+    }
 }

+ 2 - 2
Tests/ReachabilityTests.swift

@@ -14,7 +14,7 @@ class ReachabilityTests: XCTestCase {
     func testValidHost() {
         let validHostName = "google.com"
         
-        guard let reachability = Reachability(hostname: validHostName) else {
+        guard let reachability = try? Reachability(hostname: validHostName) else {
             return XCTFail("Unable to create reachability")
         }
         
@@ -47,7 +47,7 @@ class ReachabilityTests: XCTestCase {
 
         let invalidHostName = "invalidhost"
 
-        guard let reachability = Reachability(hostname: invalidHostName) else {
+        guard let reachability = try? Reachability(hostname: invalidHostName) else {
             return XCTFail("Unable to create reachability")
         }