Browse Source

Lock publick methods to prevent concurrency issues

Evgenii Neumerzhitckii 6 years ago
parent
commit
dc77fd8591
2 changed files with 112 additions and 7 deletions
  1. 34 5
      Sources/KeychainSwift.swift
  2. 78 2
      Tests/KeychainSwiftTests/ConcurrencyTests.swift

+ 34 - 5
Sources/KeychainSwift.swift

@@ -33,7 +33,8 @@ open class KeychainSwift {
   */
   open var synchronizable: Bool = false
 
-  private let readLock = NSLock()
+  private let lock = NSLock()
+
   
   /// Instantiate a KeychainSwift object
   public init() { }
@@ -84,7 +85,12 @@ open class KeychainSwift {
   open func set(_ value: Data, forKey key: String,
     withAccess access: KeychainSwiftAccessOptions? = nil) -> Bool {
     
-    delete(key) // Delete any existing key before saving it
+    // The lock prevents the code to be run simlultaneously
+    // from multiple threads which may result in crashing
+    lock.lock()
+    defer { lock.unlock() }
+    
+    deleteNoLock(key) // Delete any existing key before saving it
 
     let accessible = access?.value ?? KeychainSwiftAccessOptions.defaultOption.value
       
@@ -160,8 +166,8 @@ open class KeychainSwift {
   open func getData(_ key: String, asReference: Bool = false) -> Data? {
     // The lock prevents the code to be run simlultaneously
     // from multiple threads which may result in crashing
-    readLock.lock()
-    defer { readLock.unlock() }
+    lock.lock()
+    defer { lock.unlock() }
     
     let prefixedKey = keyWithPrefix(key)
     
@@ -218,8 +224,26 @@ open class KeychainSwift {
   */
   @discardableResult
   open func delete(_ key: String) -> Bool {
+    // The lock prevents the code to be run simlultaneously
+    // from multiple threads which may result in crashing
+    lock.lock()
+    defer { lock.unlock() }
+    
+    return deleteNoLock(key)
+  }
+  
+  /**
+   
+  Same as delete but is only accessed internally, since it is not thread safe.
+   
+   - parameter key: The key that is used to delete the keychain item.
+   - returns: True if the item was successfully deleted.
+   
+   */
+  @discardableResult
+  func deleteNoLock(_ key: String) -> Bool {
     let prefixedKey = keyWithPrefix(key)
-
+    
     var query: [String: Any] = [
       KeychainSwiftConstants.klass       : kSecClassGenericPassword,
       KeychainSwiftConstants.attrAccount : prefixedKey
@@ -243,6 +267,11 @@ open class KeychainSwift {
   */
   @discardableResult
   open func clear() -> Bool {
+    // The lock prevents the code to be run simlultaneously
+    // from multiple threads which may result in crashing
+    lock.lock()
+    defer { lock.unlock() }
+    
     var query: [String: Any] = [ kSecClass as String : kSecClassGenericPassword ]
     query = addAccessGroupWhenPresent(query)
     query = addSynchronizableIfRequired(query, addingItems: false)

+ 78 - 2
Tests/KeychainSwiftTests/ConcurrencyTests.swift

@@ -26,6 +26,8 @@ class ConcurrencyTests: XCTestCase {
     func testConcurrencyDoesntCrash() {
 
         let expectation = self.expectation(description: "Wait for write loop")
+        let expectation2 = self.expectation(description: "Wait for write loop")
+
 
         let dataToWrite = "{ asdf ñlk BNALSKDJFÑLAKSJDFÑLKJ ZÑCLXKJ ÑALSKDFJÑLKASJDFÑLKJASDÑFLKJAÑSDLKFJÑLKJ}"
         obj.set(dataToWrite, forKey: "test-key")
@@ -71,6 +73,62 @@ class ConcurrencyTests: XCTestCase {
                 }, timeoutWith: nil)
             }
         }
+      
+        let deleteQueue = DispatchQueue(label: "deleteQueue", attributes: [])
+        deleteQueue.async {
+          for _ in 0..<400 {
+            let _: Bool = synchronize( { completion in
+              let result = self.obj.delete("test-key")
+              DispatchQueue.global(qos: .background).async {
+                DispatchQueue.main.asyncAfter(deadline: .now() + .milliseconds(5)) {
+                  completion(result)
+                }
+              }
+            }, timeoutWith: false)
+          }
+        }
+      
+        let deleteQueue2 = DispatchQueue(label: "deleteQueue2", attributes: [])
+        deleteQueue2.async {
+          for _ in 0..<400 {
+            let _: Bool = synchronize( { completion in
+              let result = self.obj.delete("test-key")
+              DispatchQueue.global(qos: .background).async {
+                DispatchQueue.main.asyncAfter(deadline: .now() + .milliseconds(5)) {
+                  completion(result)
+                }
+              }
+            }, timeoutWith: false)
+          }
+        }
+      
+        let clearQueue = DispatchQueue(label: "clearQueue", attributes: [])
+        clearQueue.async {
+          for _ in 0..<400 {
+            let _: Bool = synchronize( { completion in
+              let result = self.obj.clear()
+              DispatchQueue.global(qos: .background).async {
+                DispatchQueue.main.asyncAfter(deadline: .now() + .milliseconds(5)) {
+                  completion(result)
+                }
+              }
+            }, timeoutWith: false)
+          }
+        }
+      
+        let clearQueue2 = DispatchQueue(label: "clearQueue2", attributes: [])
+        clearQueue2.async {
+          for _ in 0..<400 {
+            let _: Bool = synchronize( { completion in
+              let result = self.obj.clear()
+              DispatchQueue.global(qos: .background).async {
+                DispatchQueue.main.asyncAfter(deadline: .now() + .milliseconds(5)) {
+                  completion(result)
+                }
+              }
+            }, timeoutWith: false)
+          }
+        }
 
         let writeQueue = DispatchQueue(label: "WriteQueue", attributes: [])
         writeQueue.async {
@@ -89,14 +147,32 @@ class ConcurrencyTests: XCTestCase {
             }
             expectation.fulfill()
         }
+      
+        let writeQueue2 = DispatchQueue(label: "WriteQueue2", attributes: [])
+        writeQueue2.async {
+          for _ in 0..<500 {
+            let written: Bool = synchronize({ completion in
+              DispatchQueue.global(qos: .background).async {
+                DispatchQueue.main.asyncAfter(deadline: .now() + .milliseconds(5)) {
+                  let result = self.obj.set(dataToWrite, forKey: "test-key")
+                  completion(result)
+                }
+              }
+            }, timeoutWith: false)
+            if written {
+              writes = writes + 1
+            }
+          }
+          expectation2.fulfill()
+        }
 
         for _ in 0..<1000 {
             self.obj.set(dataToWrite, forKey: "test-key")
             let _ = self.obj.get("test-key")
         }
-        self.waitForExpectations(timeout: 20, handler: nil)
+        self.waitForExpectations(timeout: 30, handler: nil)
 
-        XCTAssertEqual(500, writes)
+        XCTAssertEqual(1000, writes)
     }
 }