Browse Source

Fix Protected Value Property Access Safety (#3505)

* Use the wrapper for property access, add tests.

* Note the change that triggers the sanitizer.

* Fix inappropriate wrapper access; local wrapper use.
Jon Shier 4 years ago
parent
commit
c039ac798b

+ 4 - 5
Source/AuthenticationInterceptor.swift

@@ -217,15 +217,15 @@ public class AuthenticationInterceptor<AuthenticatorType>: RequestInterceptor wh
 
     /// The `Credential` used to authenticate requests.
     public var credential: Credential? {
-        get { mutableState.credential }
-        set { mutableState.credential = newValue }
+        get { $mutableState.credential }
+        set { $mutableState.credential = newValue }
     }
 
     let authenticator: AuthenticatorType
     let queue = DispatchQueue(label: "org.alamofire.authentication.inspector")
 
     @Protected
-    private var mutableState = MutableState()
+    private var mutableState: MutableState
 
     // MARK: Initialization
 
@@ -242,8 +242,7 @@ public class AuthenticationInterceptor<AuthenticatorType>: RequestInterceptor wh
                 credential: Credential? = nil,
                 refreshWindow: RefreshWindow? = RefreshWindow()) {
         self.authenticator = authenticator
-        mutableState.credential = credential
-        mutableState.refreshWindow = refreshWindow
+        mutableState = MutableState(credential: credential, refreshWindow: refreshWindow)
     }
 
     // MARK: Adapt

+ 3 - 3
Source/MultipartUpload.swift

@@ -45,8 +45,8 @@ final class MultipartUpload {
 
     func build() throws -> UploadRequest.Uploadable {
         let uploadable: UploadRequest.Uploadable
-        if multipartFormData.contentLength < encodingMemoryThreshold {
-            let data = try multipartFormData.encode()
+        if $multipartFormData.contentLength < encodingMemoryThreshold {
+            let data = try $multipartFormData.read { try $0.encode() }
 
             uploadable = .data(data)
         } else {
@@ -58,7 +58,7 @@ final class MultipartUpload {
             try fileManager.createDirectory(at: directoryURL, withIntermediateDirectories: true, attributes: nil)
 
             do {
-                try multipartFormData.writeEncodedData(to: fileURL)
+                try $multipartFormData.read { try $0.writeEncodedData(to: fileURL) }
             } catch {
                 // Cleanup after attempted write if it fails.
                 try? fileManager.removeItem(at: fileURL)

+ 12 - 8
Source/Protected.swift

@@ -35,17 +35,17 @@ extension Lock {
     /// - Parameter closure: The closure to run.
     ///
     /// - Returns:           The value the closure generated.
-    func around<T>(_ closure: () -> T) -> T {
+    func around<T>(_ closure: () throws -> T) rethrows -> T {
         lock(); defer { unlock() }
-        return closure()
+        return try closure()
     }
 
     /// Execute a closure while acquiring the lock.
     ///
     /// - Parameter closure: The closure to run.
-    func around(_ closure: () -> Void) {
+    func around(_ closure: () throws -> Void) rethrows {
         lock(); defer { unlock() }
-        closure()
+        try closure()
     }
 }
 
@@ -112,8 +112,8 @@ final class Protected<T> {
     /// - Parameter closure: The closure to execute.
     ///
     /// - Returns:           The return value of the closure passed.
-    func read<U>(_ closure: (T) -> U) -> U {
-        lock.around { closure(self.value) }
+    func read<U>(_ closure: (T) throws -> U) rethrows -> U {
+        try lock.around { try closure(self.value) }
     }
 
     /// Synchronously modify the protected value.
@@ -122,14 +122,18 @@ final class Protected<T> {
     ///
     /// - Returns:           The modified value.
     @discardableResult
-    func write<U>(_ closure: (inout T) -> U) -> U {
-        lock.around { closure(&self.value) }
+    func write<U>(_ closure: (inout T) throws -> U) rethrows -> U {
+        try lock.around { try closure(&self.value) }
     }
 
     subscript<Property>(dynamicMember keyPath: WritableKeyPath<T, Property>) -> Property {
         get { lock.around { value[keyPath: keyPath] } }
         set { lock.around { value[keyPath: keyPath] = newValue } }
     }
+
+    subscript<Property>(dynamicMember keyPath: KeyPath<T, Property>) -> Property {
+        lock.around { value[keyPath: keyPath] }
+    }
 }
 
 extension Protected where T: RangeReplaceableCollection {

+ 28 - 28
Source/Request.swift

@@ -127,7 +127,7 @@ public class Request {
     fileprivate var mutableState = MutableState()
 
     /// `State` of the `Request`.
-    public var state: State { mutableState.state }
+    public var state: State { $mutableState.state }
     /// Returns whether `state` is `.initialized`.
     public var isInitialized: Bool { state == .initialized }
     /// Returns whether `state is `.resumed`.
@@ -150,38 +150,38 @@ public class Request {
     public let downloadProgress = Progress(totalUnitCount: 0)
     /// `ProgressHandler` called when `uploadProgress` is updated, on the provided `DispatchQueue`.
     private var uploadProgressHandler: (handler: ProgressHandler, queue: DispatchQueue)? {
-        get { mutableState.uploadProgressHandler }
-        set { mutableState.uploadProgressHandler = newValue }
+        get { $mutableState.uploadProgressHandler }
+        set { $mutableState.uploadProgressHandler = newValue }
     }
 
     /// `ProgressHandler` called when `downloadProgress` is updated, on the provided `DispatchQueue`.
     fileprivate var downloadProgressHandler: (handler: ProgressHandler, queue: DispatchQueue)? {
-        get { mutableState.downloadProgressHandler }
-        set { mutableState.downloadProgressHandler = newValue }
+        get { $mutableState.downloadProgressHandler }
+        set { $mutableState.downloadProgressHandler = newValue }
     }
 
     // MARK: Redirect Handling
 
     /// `RedirectHandler` set on the instance.
     public private(set) var redirectHandler: RedirectHandler? {
-        get { mutableState.redirectHandler }
-        set { mutableState.redirectHandler = newValue }
+        get { $mutableState.redirectHandler }
+        set { $mutableState.redirectHandler = newValue }
     }
 
     // MARK: Cached Response Handling
 
     /// `CachedResponseHandler` set on the instance.
     public private(set) var cachedResponseHandler: CachedResponseHandler? {
-        get { mutableState.cachedResponseHandler }
-        set { mutableState.cachedResponseHandler = newValue }
+        get { $mutableState.cachedResponseHandler }
+        set { $mutableState.cachedResponseHandler = newValue }
     }
 
     // MARK: URLCredential
 
     /// `URLCredential` used for authentication challenges. Created by calling one of the `authenticate` methods.
     public private(set) var credential: URLCredential? {
-        get { mutableState.credential }
-        set { mutableState.credential = newValue }
+        get { $mutableState.credential }
+        set { $mutableState.credential = newValue }
     }
 
     // MARK: Validators
@@ -193,7 +193,7 @@ public class Request {
     // MARK: URLRequests
 
     /// All `URLRequests` created on behalf of the `Request`, including original and adapted requests.
-    public var requests: [URLRequest] { mutableState.requests }
+    public var requests: [URLRequest] { $mutableState.requests }
     /// First `URLRequest` created on behalf of the `Request`. May not be the first one actually executed.
     public var firstRequest: URLRequest? { requests.first }
     /// Last `URLRequest` created on behalf of the `Request`.
@@ -214,7 +214,7 @@ public class Request {
     // MARK: Tasks
 
     /// All `URLSessionTask`s created on behalf of the `Request`.
-    public var tasks: [URLSessionTask] { mutableState.tasks }
+    public var tasks: [URLSessionTask] { $mutableState.tasks }
     /// First `URLSessionTask` created on behalf of the `Request`.
     public var firstTask: URLSessionTask? { tasks.first }
     /// Last `URLSessionTask` crated on behalf of the `Request`.
@@ -225,7 +225,7 @@ public class Request {
     // MARK: Metrics
 
     /// All `URLSessionTaskMetrics` gathered on behalf of the `Request`. Should correspond to the `tasks` created.
-    public var allMetrics: [URLSessionTaskMetrics] { mutableState.metrics }
+    public var allMetrics: [URLSessionTaskMetrics] { $mutableState.metrics }
     /// First `URLSessionTaskMetrics` gathered on behalf of the `Request`.
     public var firstMetrics: URLSessionTaskMetrics? { allMetrics.first }
     /// Last `URLSessionTaskMetrics` gathered on behalf of the `Request`.
@@ -236,14 +236,14 @@ public class Request {
     // MARK: Retry Count
 
     /// Number of times the `Request` has been retried.
-    public var retryCount: Int { mutableState.retryCount }
+    public var retryCount: Int { $mutableState.retryCount }
 
     // MARK: Error
 
     /// `Error` returned from Alamofire internally, from the network request directly, or any validators executed.
     public fileprivate(set) var error: AFError? {
-        get { mutableState.error }
-        set { mutableState.error = newValue }
+        get { $mutableState.error }
+        set { $mutableState.error = newValue }
     }
 
     /// Default initializer for the `Request` superclass.
@@ -511,9 +511,9 @@ public class Request {
     func finish(error: AFError? = nil) {
         dispatchPrecondition(condition: .onQueue(underlyingQueue))
 
-        guard !mutableState.isFinishing else { return }
+        guard !$mutableState.isFinishing else { return }
 
-        mutableState.isFinishing = true
+        $mutableState.isFinishing = true
 
         if let error = error { self.error = error }
 
@@ -752,7 +752,7 @@ public class Request {
     /// - Returns:              The instance.
     @discardableResult
     public func authenticate(with credential: URLCredential) -> Self {
-        mutableState.credential = credential
+        $mutableState.credential = credential
 
         return self
     }
@@ -768,7 +768,7 @@ public class Request {
     /// - Returns:   The instance.
     @discardableResult
     public func downloadProgress(queue: DispatchQueue = .main, closure: @escaping ProgressHandler) -> Self {
-        mutableState.downloadProgressHandler = (handler: closure, queue: queue)
+        $mutableState.downloadProgressHandler = (handler: closure, queue: queue)
 
         return self
     }
@@ -784,7 +784,7 @@ public class Request {
     /// - Returns:   The instance.
     @discardableResult
     public func uploadProgress(queue: DispatchQueue = .main, closure: @escaping ProgressHandler) -> Self {
-        mutableState.uploadProgressHandler = (handler: closure, queue: queue)
+        $mutableState.uploadProgressHandler = (handler: closure, queue: queue)
 
         return self
     }
@@ -1538,14 +1538,14 @@ public class DownloadRequest: Request {
     /// - Note: For more information about `resumeData`, see [Apple's documentation](https://developer.apple.com/documentation/foundation/urlsessiondownloadtask/1411634-cancel).
     public var resumeData: Data? {
         #if !(os(Linux) || os(Windows))
-        return mutableDownloadState.resumeData ?? error?.downloadResumeData
+        return $mutableDownloadState.resumeData ?? error?.downloadResumeData
         #else
-        return mutableDownloadState.resumeData
+        return $mutableDownloadState.resumeData
         #endif
     }
 
     /// If the download is successful, the `URL` where the file was downloaded.
-    public var fileURL: URL? { mutableDownloadState.fileURL }
+    public var fileURL: URL? { $mutableDownloadState.fileURL }
 
     // MARK: Initial State
 
@@ -1603,7 +1603,7 @@ public class DownloadRequest: Request {
         eventMonitor?.request(self, didFinishDownloadingUsing: task, with: result)
 
         switch result {
-        case let .success(url): mutableDownloadState.fileURL = url
+        case let .success(url): $mutableDownloadState.fileURL = url
         case let .failure(error): self.error = error
         }
     }
@@ -1697,7 +1697,7 @@ public class DownloadRequest: Request {
                 // Resume to ensure metrics are gathered.
                 task.resume()
                 task.cancel { resumeData in
-                    self.mutableDownloadState.resumeData = resumeData
+                    self.$mutableDownloadState.resumeData = resumeData
                     self.underlyingQueue.async { self.didCancelTask(task) }
                     completionHandler(resumeData)
                 }
@@ -1865,7 +1865,7 @@ public class UploadRequest: DataRequest {
         defer { super.cleanup() }
 
         guard
-            let uploadable = self.uploadable,
+            let uploadable = uploadable,
             case let .file(url, shouldRemove) = uploadable,
             shouldRemove
         else { return }

+ 200 - 2
Tests/ProtectedTests.swift

@@ -27,7 +27,7 @@ import Alamofire
 
 import XCTest
 
-final class ProtectedTests: XCTestCase {
+final class ProtectedTests: BaseTestCase {
     func testThatProtectedValuesAreAccessedSafely() {
         // Given
         let initialValue = "value"
@@ -59,7 +59,7 @@ final class ProtectedTests: XCTestCase {
     }
 }
 
-final class ProtectedWrapperTests: XCTestCase {
+final class ProtectedWrapperTests: BaseTestCase {
     @Protected var value = "value"
 
     override func setUp() {
@@ -109,6 +109,36 @@ final class ProtectedWrapperTests: XCTestCase {
         XCTAssertEqual(count.wrappedValue, 5)
     }
 
+    func testThatDynamicMemberPropertiesAreAccessedSafely() {
+        // Given
+        let string = Protected<String>("test")
+        let count = Protected<Int>(0)
+
+        // When
+        DispatchQueue.concurrentPerform(iterations: 10_000) { _ in
+            count.wrappedValue = string.wrappedValue.count
+        }
+
+        // Then
+        XCTAssertEqual(string.wrappedValue.count, count.wrappedValue)
+    }
+
+    #if swift(>=5.5)
+    func testThatLocalWrapperInstanceWorkCorrectly() {
+        // Given
+        @Protected var string = "test"
+        @Protected var count = 0
+
+        // When
+        DispatchQueue.concurrentPerform(iterations: 10_000) { _ in
+            count = string.count
+        }
+
+        // Then
+        XCTAssertEqual(string.count, count)
+    }
+    #endif
+
     func testThatDynamicMembersAreSetSafely() {
         // Given
         struct Mutable { var value = "value" }
@@ -123,3 +153,171 @@ final class ProtectedWrapperTests: XCTestCase {
         XCTAssertNotEqual(mutable.wrappedValue.value, "value")
     }
 }
+
+final class ProtectedHighContentionTests: BaseTestCase {
+    final class StringContainer {
+        var totalStrings: Int = 10
+        var stringArray = ["this", "is", "a", "simple", "set", "of", "test", "strings", "to", "use"]
+    }
+
+    struct StringContainerWriteState {
+        var results: [Int] = []
+        var completedWrites = 0
+
+        var queue1Complete = false
+        var queue2Complete = false
+    }
+
+    struct StringContainerReadState {
+        var results1: [Int] = []
+        var results2: [Int] = []
+
+        var queue1Complete = false
+        var queue2Complete = false
+    }
+
+    // MARK: - Properties
+
+    @Protected var stringContainer = StringContainer()
+    @Protected var stringContainerWrite = StringContainerWriteState()
+    @Protected var stringContainerRead = StringContainerReadState()
+
+    func testConcurrentReadWriteBlocks() {
+        // Given
+        let totalWrites = 4000
+        let totalReads = 10_000
+
+        let writeExpectation = expectation(description: "all parallel writes should complete before timeout")
+        let readExpectation = expectation(description: "all parallel reads should complete before timeout")
+
+        var writerQueueResults: [Int] = []
+        var completedWritesCount = 0
+
+        var readerQueueResults1: [Int] = []
+        var readerQueueResults2: [Int] = []
+
+        // When
+        executeWriteOperationsInParallel(totalOperationsToExecute: totalWrites) { results, completedOperationCount in
+            writerQueueResults = results
+            completedWritesCount = completedOperationCount
+            writeExpectation.fulfill()
+        }
+
+        executeReadOperationsInParallel(totalOperationsToExecute: totalReads) { results1, results2 in
+            readerQueueResults1 = results1
+            readerQueueResults2 = results2
+            readExpectation.fulfill()
+        }
+
+        waitForExpectations(timeout: timeout, handler: nil)
+
+        // Then
+        XCTAssertEqual(readerQueueResults1.count, totalReads)
+        XCTAssertEqual(readerQueueResults2.count, totalReads)
+        XCTAssertEqual(writerQueueResults.count, totalWrites)
+        XCTAssertEqual(completedWritesCount, totalWrites)
+
+        readerQueueResults1.forEach { XCTAssertEqual($0, 10) }
+        readerQueueResults2.forEach { XCTAssertEqual($0, 10) }
+        writerQueueResults.forEach { XCTAssertEqual($0, 10) }
+    }
+
+    private func executeWriteOperationsInParallel(totalOperationsToExecute totalOperations: Int,
+                                                  completion: @escaping ([Int], Int) -> Void) {
+        let queue1 = DispatchQueue(label: "com.alamofire.testWriterQueue1")
+        let queue2 = DispatchQueue(label: "com.alamofire.testWriterQueue2")
+
+        for _ in 1...totalOperations {
+            queue1.async {
+                // Moves the last string element to the beginning of the string array
+                let result: Int = self.$stringContainer.write { stringContainer in
+                    let lastElement = stringContainer.stringArray.removeLast()
+                    stringContainer.totalStrings = stringContainer.stringArray.count
+
+                    stringContainer.stringArray.insert(lastElement, at: 0)
+                    stringContainer.totalStrings = stringContainer.stringArray.count
+
+                    return stringContainer.totalStrings
+                }
+
+                self.$stringContainerWrite.write { mutableState in
+                    mutableState.results.append(result)
+
+                    if mutableState.results.count == totalOperations {
+                        mutableState.queue1Complete = true
+
+                        if mutableState.queue2Complete {
+                            completion(mutableState.results, mutableState.completedWrites)
+                        }
+                    }
+                }
+            }
+
+            queue2.async {
+                // Moves the first string element to the end of the string array
+                self.$stringContainer.write { stringContainer in
+                    let firstElement = stringContainer.stringArray.remove(at: 0)
+                    stringContainer.totalStrings = stringContainer.stringArray.count
+
+                    stringContainer.stringArray.append(firstElement)
+                    stringContainer.totalStrings = stringContainer.stringArray.count
+                }
+
+                self.$stringContainerWrite.write { mutableState in
+                    mutableState.completedWrites += 1
+
+                    if mutableState.completedWrites == totalOperations {
+                        mutableState.queue2Complete = true
+
+                        if mutableState.queue1Complete {
+                            completion(mutableState.results, mutableState.completedWrites)
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    private func executeReadOperationsInParallel(totalOperationsToExecute totalOperations: Int,
+                                                 completion: @escaping ([Int], [Int]) -> Void) {
+        let queue1 = DispatchQueue(label: "com.alamofire.testReaderQueue1")
+        let queue2 = DispatchQueue(label: "com.alamofire.testReaderQueue1")
+
+        for _ in 1...totalOperations {
+            queue1.async {
+                // Reads the total string count in the string array
+                // Using the wrapped value (no $) instead of the wrapper itself triggers the thread sanitizer.
+                let result = self.$stringContainer.totalStrings
+
+                self.$stringContainerRead.write {
+                    $0.results1.append(result)
+
+                    if $0.results1.count == totalOperations {
+                        $0.queue1Complete = true
+
+                        if $0.queue2Complete {
+                            completion($0.results1, $0.results2)
+                        }
+                    }
+                }
+            }
+
+            queue2.async {
+                // Reads the total string count in the string array
+                let result = self.$stringContainer.read { $0.totalStrings }
+
+                self.$stringContainerRead.write {
+                    $0.results2.append(result)
+
+                    if $0.results2.count == totalOperations {
+                        $0.queue2Complete = true
+
+                        if $0.queue1Complete {
+                            completion($0.results1, $0.results2)
+                        }
+                    }
+                }
+            }
+        }
+    }
+}