Browse Source

Merge pull request #292 from p4checo/multiple-improvements-and-fixes

Add multiple improvements and fixes
Ashley Mills 7 years ago
parent
commit
342a47cacb
3 changed files with 108 additions and 142 deletions
  1. 1 1
      ReachabilitySwift.podspec
  2. 94 108
      Sources/Reachability.swift
  3. 13 33
      Tests/ReachabilityTests.swift

+ 1 - 1
ReachabilitySwift.podspec

@@ -17,7 +17,7 @@ Pod::Spec.new do |s|
     :git => 'https://github.com/ashleymills/Reachability.swift.git',
     :tag => 'v'+s.version.to_s
   }
-  s.source_files = 'Reachability/Reachability.swift'
+  s.source_files = 'Sources/Reachability.swift'
   s.framework    = 'SystemConfiguration'
 
   s.requires_arc = true

+ 94 - 108
Sources/Reachability.swift

@@ -33,6 +33,7 @@ public enum ReachabilityError: Error {
     case FailedToCreateWithHostname(String)
     case UnableToSetCallback
     case UnableToSetDispatchQueue
+    case UnableToGetInitialFlags
 }
 
 @available(*, unavailable, renamed: "Notification.Name.reachabilityChanged")
@@ -42,13 +43,6 @@ extension Notification.Name {
     public static let reachabilityChanged = Notification.Name("reachabilityChanged")
 }
 
-func callback(reachability: SCNetworkReachability, flags: SCNetworkReachabilityFlags, info: UnsafeMutableRawPointer?) {
-    guard let info = info else { return }
-
-    let reachability = Unmanaged<Reachability>.fromOpaque(info).takeUnretainedValue()
-    reachability.reachabilityChanged()
-}
-
 public class Reachability {
 
     public typealias NetworkReachable = (Reachability) -> ()
@@ -100,36 +94,13 @@ public class Reachability {
     }
 
     public var connection: Connection {
-        guard isReachableFlagSet else { return .none }
-
-        // If we're reachable, but not on an iOS device (i.e. simulator), we must be on WiFi
-        guard isRunningOnDevice else { return .wifi }
-
-        var connection = Connection.none
-
-        if !isConnectionRequiredFlagSet {
-            connection = .wifi
+        switch flags?.connection {
+        case .none?, nil: return .none
+        case .cellular?: return allowsCellularConnection ? .cellular : .none
+        case .wifi?: return .wifi
         }
-
-        if isConnectionOnTrafficOrDemandFlagSet {
-            if !isInterventionRequiredFlagSet {
-                connection = .wifi
-            }
-        }
-
-        if isOnWWANFlagSet {
-            if !allowsCellularConnection {
-                connection = .none
-            } else {
-                connection = .cellular
-            }
-        }
-
-        return connection
     }
 
-    fileprivate var previousFlags: SCNetworkReachabilityFlags?
-
     fileprivate var isRunningOnDevice: Bool = {
         #if targetEnvironment(simulator)
             return false
@@ -140,30 +111,33 @@ public class Reachability {
 
     fileprivate var notifierRunning = false
     fileprivate let reachabilityRef: SCNetworkReachability
+    fileprivate let reachabilitySerialQueue: DispatchQueue
+    fileprivate(set) var flags: SCNetworkReachabilityFlags? {
+        didSet {
+            guard flags != oldValue else { return }
+            reachabilityChanged()
+        }
+    }
 
-    fileprivate let reachabilitySerialQueue = DispatchQueue(label: "uk.co.ashleymills.reachability")
-
-    fileprivate var usingHostname = false
-
-    required public init(reachabilityRef: SCNetworkReachability, usingHostname: Bool = false) {
-        allowsCellularConnection = true
+    required public init(reachabilityRef: SCNetworkReachability, queueQoS: DispatchQoS = .default, targetQueue: DispatchQueue? = nil) {
+        self.allowsCellularConnection = true
         self.reachabilityRef = reachabilityRef
-        self.usingHostname = usingHostname
+        self.reachabilitySerialQueue = DispatchQueue(label: "uk.co.ashleymills.reachability", qos: queueQoS, target: targetQueue)
     }
 
-    public convenience init?(hostname: String) {
+    public convenience init?(hostname: String, queueQoS: DispatchQoS = .default, targetQueue: DispatchQueue? = nil) {
         guard let ref = SCNetworkReachabilityCreateWithName(nil, hostname) else { return nil }
-        self.init(reachabilityRef: ref, usingHostname: true)
+        self.init(reachabilityRef: ref, queueQoS: queueQoS, targetQueue: targetQueue)
     }
 
-    public convenience init?() {
+    public convenience init?(queueQoS: DispatchQoS = .default, targetQueue: DispatchQueue? = nil) {
         var zeroAddress = sockaddr()
         zeroAddress.sa_len = UInt8(MemoryLayout<sockaddr>.size)
         zeroAddress.sa_family = sa_family_t(AF_INET)
 
         guard let ref = SCNetworkReachabilityCreateWithAddress(nil, &zeroAddress) else { return nil }
 
-        self.init(reachabilityRef: ref)
+        self.init(reachabilityRef: ref, queueQoS: queueQoS, targetQueue: targetQueue)
     }
 
     deinit {
@@ -177,6 +151,13 @@ public extension Reachability {
     func startNotifier() throws {
         guard !notifierRunning else { return }
 
+        let callback: SCNetworkReachabilityCallBack = { (reachability, flags, info) in
+            guard let info = info else { return }
+
+            let reachability = Unmanaged<Reachability>.fromOpaque(info).takeUnretainedValue()
+            reachability.flags = flags
+        }
+
         var context = SCNetworkReachabilityContext(version: 0, info: nil, retain: nil, release: nil, copyDescription: nil)
         context.info = UnsafeMutableRawPointer(Unmanaged<Reachability>.passUnretained(self).toOpaque())
         if !SCNetworkReachabilitySetCallback(reachabilityRef, callback, &context) {
@@ -190,8 +171,14 @@ public extension Reachability {
         }
 
         // Perform an initial check
-        reachabilitySerialQueue.async {
-            self.reachabilityChanged()
+        try reachabilitySerialQueue.sync { [unowned self] in
+            var flags = SCNetworkReachabilityFlags()
+            if !SCNetworkReachabilityGetFlags(self.reachabilityRef, &flags) {
+                self.stopNotifier()
+                throw ReachabilityError.UnableToGetInitialFlags
+            }
+
+            self.flags = flags
         }
 
         notifierRunning = true
@@ -207,116 +194,115 @@ public extension Reachability {
     // MARK: - *** Connection test methods ***
     @available(*, deprecated: 4.0, message: "Please use `connection != .none`")
     var isReachable: Bool {
-        guard isReachableFlagSet else { return false }
-
-        if isConnectionRequiredAndTransientFlagSet {
-            return false
-        }
-
-        if isRunningOnDevice {
-            if isOnWWANFlagSet && !reachableOnWWAN {
-                // We don't want to connect when on cellular connection
-                return false
-            }
-        }
-
-        return true
+        return connection != .none
     }
 
     @available(*, deprecated: 4.0, message: "Please use `connection == .cellular`")
     var isReachableViaWWAN: Bool {
         // Check we're not on the simulator, we're REACHABLE and check we're on WWAN
-        return isRunningOnDevice && isReachableFlagSet && isOnWWANFlagSet
+        return connection == .cellular
     }
 
     @available(*, deprecated: 4.0, message: "Please use `connection == .wifi`")
     var isReachableViaWiFi: Bool {
-        // Check we're reachable
-        guard isReachableFlagSet else { return false }
-
-        // If reachable we're reachable, but not on an iOS device (i.e. simulator), we must be on WiFi
-        guard isRunningOnDevice else { return true }
-
-        // Check we're NOT on WWAN
-        return !isOnWWANFlagSet
+        return connection == .wifi
     }
 
     var description: String {
-        let W = isRunningOnDevice ? (isOnWWANFlagSet ? "W" : "-") : "X"
-        let R = isReachableFlagSet ? "R" : "-"
-        let c = isConnectionRequiredFlagSet ? "c" : "-"
-        let t = isTransientConnectionFlagSet ? "t" : "-"
-        let i = isInterventionRequiredFlagSet ? "i" : "-"
-        let C = isConnectionOnTrafficFlagSet ? "C" : "-"
-        let D = isConnectionOnDemandFlagSet ? "D" : "-"
-        let l = isLocalAddressFlagSet ? "l" : "-"
-        let d = isDirectFlagSet ? "d" : "-"
+        guard let flags = flags else { return "unavailable flags" }
+        let W = isRunningOnDevice ? (flags.isOnWWANFlagSet ? "W" : "-") : "X"
+        let R = flags.isReachableFlagSet ? "R" : "-"
+        let c = flags.isConnectionRequiredFlagSet ? "c" : "-"
+        let t = flags.isTransientConnectionFlagSet ? "t" : "-"
+        let i = flags.isInterventionRequiredFlagSet ? "i" : "-"
+        let C = flags.isConnectionOnTrafficFlagSet ? "C" : "-"
+        let D = flags.isConnectionOnDemandFlagSet ? "D" : "-"
+        let l = flags.isLocalAddressFlagSet ? "l" : "-"
+        let d = flags.isDirectFlagSet ? "d" : "-"
 
         return "\(W)\(R) \(c)\(t)\(i)\(C)\(D)\(l)\(d)"
     }
 }
 
 fileprivate extension Reachability {
-    func reachabilityChanged() {
-        guard previousFlags != flags else { return }
 
+    func reachabilityChanged() {
         let block = connection != .none ? whenReachable : whenUnreachable
 
-        DispatchQueue.main.async {
-            if self.usingHostname {
-                print("USING HOSTNAME ABOUT TO CALL BLOCK")
+        DispatchQueue.main.async { [weak self] in
+            guard let strongSelf = self else { return }
+            block?(strongSelf)
+            strongSelf.notificationCenter.post(name: .reachabilityChanged, object: strongSelf)
+        }
+    }
+}
+
+extension SCNetworkReachabilityFlags {
+
+    typealias Connection = Reachability.Connection
+
+    var connection: Connection {
+        guard isReachableFlagSet else { return .none }
+
+        // If we're reachable, but not on an iOS device (i.e. simulator), we must be on WiFi
+        #if targetEnvironment(simulator)
+        return .wifi
+        #else
+        var connection = Connection.none
+
+        if !isConnectionRequiredFlagSet {
+            connection = .wifi
+        }
+
+        if isConnectionOnTrafficOrDemandFlagSet {
+            if !isInterventionRequiredFlagSet {
+                connection = .wifi
             }
-            block?(self)
-            self.notificationCenter.post(name: .reachabilityChanged, object:self)
         }
 
-        previousFlags = flags
+        if isOnWWANFlagSet {
+            connection = .cellular
+        }
+
+        return connection
+        #endif
     }
 
     var isOnWWANFlagSet: Bool {
         #if os(iOS)
-            return flags.contains(.isWWAN)
+        return contains(.isWWAN)
         #else
-            return false
+        return false
         #endif
     }
     var isReachableFlagSet: Bool {
-        return flags.contains(.reachable)
+        return contains(.reachable)
     }
     var isConnectionRequiredFlagSet: Bool {
-        return flags.contains(.connectionRequired)
+        return contains(.connectionRequired)
     }
     var isInterventionRequiredFlagSet: Bool {
-        return flags.contains(.interventionRequired)
+        return contains(.interventionRequired)
     }
     var isConnectionOnTrafficFlagSet: Bool {
-        return flags.contains(.connectionOnTraffic)
+        return contains(.connectionOnTraffic)
     }
     var isConnectionOnDemandFlagSet: Bool {
-        return flags.contains(.connectionOnDemand)
+        return contains(.connectionOnDemand)
     }
     var isConnectionOnTrafficOrDemandFlagSet: Bool {
-        return !flags.intersection([.connectionOnTraffic, .connectionOnDemand]).isEmpty
+        return !intersection([.connectionOnTraffic, .connectionOnDemand]).isEmpty
     }
     var isTransientConnectionFlagSet: Bool {
-        return flags.contains(.transientConnection)
+        return contains(.transientConnection)
     }
     var isLocalAddressFlagSet: Bool {
-        return flags.contains(.isLocalAddress)
+        return contains(.isLocalAddress)
     }
     var isDirectFlagSet: Bool {
-        return flags.contains(.isDirect)
+        return contains(.isDirect)
     }
     var isConnectionRequiredAndTransientFlagSet: Bool {
-        return flags.intersection([.connectionRequired, .transientConnection]) == [.connectionRequired, .transientConnection]
-    }
-
-    var flags: SCNetworkReachabilityFlags {
-        var flags = SCNetworkReachabilityFlags()
-        if SCNetworkReachabilityGetFlags(reachabilityRef, &flags) {
-            return flags
-        } else {
-            return SCNetworkReachabilityFlags()
-        }
+        return intersection([.connectionRequired, .transientConnection]) == [.connectionRequired, .transientConnection]
     }
 }

+ 13 - 33
Tests/ReachabilityTests.swift

@@ -11,43 +11,29 @@ import XCTest
 
 class ReachabilityTests: XCTestCase {
     
-    override func setUp() {
-        super.setUp()
-    }
-    
-    override func tearDown() {
-        super.tearDown()
-    }
-    
     func testValidHost() {
         let validHostName = "google.com"
         
         guard let reachability = Reachability(hostname: validHostName) else {
-            XCTAssert(false, "Unable to create reachability")
-            return
+            return XCTFail("Unable to create reachability")
         }
         
         let expected = expectation(description: "Check valid host")
         reachability.whenReachable = { reachability in
-            DispatchQueue.main.async {
-                print("Pass: \(validHostName) is reachable - \(reachability)")
-                
-                // Only fulfill the expectation on host reachable
-                expected.fulfill()
-            }
+            print("Pass: \(validHostName) is reachable - \(reachability)")
+
+            // Only fulfill the expectation on host reachable
+            expected.fulfill()
         }
         reachability.whenUnreachable = { reachability in
-            DispatchQueue.main.async {
-                print("\(validHostName) is initially unreachable - \(reachability)")
-                // Expectation isn't fulfilled here, so wait will time out if this is the only closure called
-            }
+            print("\(validHostName) is initially unreachable - \(reachability)")
+            // Expectation isn't fulfilled here, so wait will time out if this is the only closure called
         }
         
         do {
             try reachability.startNotifier()
         } catch {
-            XCTAssert(false, "Unable to start notifier")
-            return
+            return XCTFail("Unable to start notifier")
         }
         
         waitForExpectations(timeout: 5, handler: nil)
@@ -62,29 +48,23 @@ class ReachabilityTests: XCTestCase {
         let invalidHostName = "invalidhost"
 
         guard let reachability = Reachability(hostname: invalidHostName) else {
-            XCTAssert(false, "Unable to create reachability")
-            return
+            return XCTFail("Unable to create reachability")
         }
         
         let expected = expectation(description: "Check invalid host")
         reachability.whenReachable = { reachability in
-            DispatchQueue.main.async {
-                XCTAssert(false, "\(invalidHostName) should never be reachable - \(reachability))")
-            }
+            print("\(invalidHostName) is initially reachable - \(reachability)")
         }
         
         reachability.whenUnreachable = { reachability in
-            DispatchQueue.main.async {
-                print("Pass: \(invalidHostName) is unreachable - \(reachability))")
-                expected.fulfill()
-            }
+            print("Pass: \(invalidHostName) is unreachable - \(reachability))")
+            expected.fulfill()
         }
         
         do {
             try reachability.startNotifier()
         } catch {
-            XCTAssert(false, "Unable to start notifier")
-            return
+            return XCTFail("Unable to start notifier")
         }
         
         waitForExpectations(timeout: 5, handler: nil)