Kaynağa Gözat

Fix Double Lock in AuthenticationInterceptor (#3322)

* Fix double lock in AuthenticationInterceptor.

* Refactor test name.

* Get rid of additional Xcode checks.
Jon Shier 5 yıl önce
ebeveyn
işleme
19df306645

+ 1 - 1
Alamofire.xcodeproj/project.pbxproj

@@ -1092,7 +1092,7 @@
 			isa = PBXProject;
 			attributes = {
 				LastSwiftUpdateCheck = 0700;
-				LastUpgradeCheck = 1200;
+				LastUpgradeCheck = 1220;
 				ORGANIZATIONNAME = Alamofire;
 				TargetAttributes = {
 					4CF626EE1BA7CB3E0011A099 = {

+ 1 - 1
Alamofire.xcodeproj/xcshareddata/xcschemes/Alamofire iOS.xcscheme

@@ -1,6 +1,6 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <Scheme
-   LastUpgradeVersion = "1200"
+   LastUpgradeVersion = "1220"
    version = "1.3">
    <BuildAction
       parallelizeBuildables = "YES"

+ 1 - 1
Alamofire.xcodeproj/xcshareddata/xcschemes/Alamofire macOS.xcscheme

@@ -1,6 +1,6 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <Scheme
-   LastUpgradeVersion = "1200"
+   LastUpgradeVersion = "1220"
    version = "1.3">
    <BuildAction
       parallelizeBuildables = "YES"

+ 1 - 1
Alamofire.xcodeproj/xcshareddata/xcschemes/Alamofire tvOS.xcscheme

@@ -1,6 +1,6 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <Scheme
-   LastUpgradeVersion = "1200"
+   LastUpgradeVersion = "1220"
    version = "1.3">
    <BuildAction
       parallelizeBuildables = "YES"

+ 1 - 1
Alamofire.xcodeproj/xcshareddata/xcschemes/Alamofire watchOS.xcscheme

@@ -1,6 +1,6 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <Scheme
-   LastUpgradeVersion = "1200"
+   LastUpgradeVersion = "1220"
    version = "1.3">
    <BuildAction
       parallelizeBuildables = "YES"

+ 1 - 1
Example/iOS Example.xcodeproj/project.pbxproj

@@ -208,7 +208,7 @@
 			isa = PBXProject;
 			attributes = {
 				LastSwiftUpdateCheck = 0720;
-				LastUpgradeCheck = 1200;
+				LastUpgradeCheck = 1220;
 				ORGANIZATIONNAME = Alamofire;
 				TargetAttributes = {
 					F8111E0419A951050040E7D1 = {

+ 1 - 1
Example/iOS Example.xcodeproj/xcshareddata/xcschemes/iOS Example.xcscheme

@@ -1,6 +1,6 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <Scheme
-   LastUpgradeVersion = "1200"
+   LastUpgradeVersion = "1220"
    version = "1.3">
    <BuildAction
       parallelizeBuildables = "YES"

+ 11 - 9
Source/AuthenticationInterceptor.swift

@@ -210,7 +210,7 @@ public class AuthenticationInterceptor<AuthenticatorType>: RequestInterceptor wh
         var refreshWindow: RefreshWindow?
 
         var adaptOperations: [AdaptOperation] = []
-        var requestsToRetry: [(Alamofire.RetryResult) -> Void] = []
+        var requestsToRetry: [(RetryResult) -> Void] = []
     }
 
     // MARK: Properties
@@ -338,14 +338,16 @@ public class AuthenticationInterceptor<AuthenticatorType>: RequestInterceptor wh
         mutableState.refreshTimestamps.append(ProcessInfo.processInfo.systemUptime)
         mutableState.isRefreshing = true
 
-        authenticator.refresh(credential, for: session) { result in
-            self.$mutableState.write { mutableState in
-                switch result {
-                case let .success(credential):
-                    self.handleRefreshSuccess(credential, insideLock: &mutableState)
-
-                case let .failure(error):
-                    self.handleRefreshFailure(error, insideLock: &mutableState)
+        // Dispatch to queue to hop out of the lock in case authenticator.refresh is implemented synchronously.
+        queue.async {
+            self.authenticator.refresh(credential, for: session) { result in
+                self.$mutableState.write { mutableState in
+                    switch result {
+                    case let .success(credential):
+                        self.handleRefreshSuccess(credential, insideLock: &mutableState)
+                    case let .failure(error):
+                        self.handleRefreshFailure(error, insideLock: &mutableState)
+                    }
                 }
             }
         }

+ 106 - 51
Tests/AuthenticationInterceptorTests.swift

@@ -26,10 +26,10 @@
 import Foundation
 import XCTest
 
-class AuthenticationInterceptorTestCase: BaseTestCase {
+final class AuthenticationInterceptorTestCase: BaseTestCase {
     // MARK: - Helper Types
 
-    struct OAuthCredential: AuthenticationCredential {
+    struct TestCredential: AuthenticationCredential {
         let accessToken: String
         let refreshToken: String
         let userID: String
@@ -50,24 +50,26 @@ class AuthenticationInterceptorTestCase: BaseTestCase {
         }
     }
 
-    enum OAuthError: Error {
+    enum TestAuthError: Error {
         case refreshNetworkFailure
     }
 
-    class OAuthAuthenticator: Authenticator {
+    final class TestAuthenticator: Authenticator {
         private(set) var applyCount = 0
         private(set) var refreshCount = 0
         private(set) var didRequestFailDueToAuthErrorCount = 0
         private(set) var isRequestAuthenticatedWithCredentialCount = 0
 
-        let refreshResult: Result<OAuthCredential, Error>?
+        let shouldRefreshAsynchronously: Bool
+        let refreshResult: Result<TestCredential, Error>?
         let lock = NSLock()
 
-        init(refreshResult: Result<OAuthCredential, Error>? = nil) {
+        init(shouldRefreshAsynchronously: Bool = true, refreshResult: Result<TestCredential, Error>? = nil) {
+            self.shouldRefreshAsynchronously = shouldRefreshAsynchronously
             self.refreshResult = refreshResult
         }
 
-        func apply(_ credential: OAuthCredential, to urlRequest: inout URLRequest) {
+        func apply(_ credential: TestCredential, to urlRequest: inout URLRequest) {
             lock.lock(); defer { lock.unlock() }
 
             applyCount += 1
@@ -75,22 +77,28 @@ class AuthenticationInterceptorTestCase: BaseTestCase {
             urlRequest.headers.add(.authorization(bearerToken: credential.accessToken))
         }
 
-        func refresh(_ credential: OAuthCredential,
+        func refresh(_ credential: TestCredential,
                      for session: Session,
-                     completion: @escaping (Result<OAuthCredential, Error>) -> Void) {
-            lock.lock(); defer { lock.unlock() }
+                     completion: @escaping (Result<TestCredential, Error>) -> Void) {
+            lock.lock()
 
             refreshCount += 1
 
-            let refreshResult = self.refreshResult ?? .success(
-                OAuthCredential(accessToken: "a\(refreshCount)",
-                                refreshToken: "a\(refreshCount)",
-                                userID: "u1",
-                                expiration: Date())
+            let result = refreshResult ?? .success(
+                TestCredential(accessToken: "a\(refreshCount)",
+                               refreshToken: "a\(refreshCount)",
+                               userID: "u1",
+                               expiration: Date())
             )
 
-            // The 100 ms delay here is important to allow multiple requests to queue up while refreshing
-            DispatchQueue.global(qos: .utility).asyncAfter(deadline: .now() + 0.1) { completion(refreshResult) }
+            if shouldRefreshAsynchronously {
+                // The 10 ms delay here is important to allow multiple requests to queue up while refreshing.
+                DispatchQueue.global(qos: .utility).asyncAfter(deadline: .now() + 0.01) { completion(result) }
+                lock.unlock()
+            } else {
+                lock.unlock()
+                completion(result)
+            }
         }
 
         func didRequest(_ urlRequest: URLRequest,
@@ -104,7 +112,7 @@ class AuthenticationInterceptorTestCase: BaseTestCase {
             return response.statusCode == 401
         }
 
-        func isRequest(_ urlRequest: URLRequest, authenticatedWith credential: OAuthCredential) -> Bool {
+        func isRequest(_ urlRequest: URLRequest, authenticatedWith credential: TestCredential) -> Bool {
             lock.lock(); defer { lock.unlock() }
 
             isRequestAuthenticatedWithCredentialCount += 1
@@ -115,7 +123,7 @@ class AuthenticationInterceptorTestCase: BaseTestCase {
         }
     }
 
-    class PathAdapter: RequestAdapter {
+    final class PathAdapter: RequestAdapter {
         var paths: [String]
 
         init(paths: [String]) {
@@ -138,8 +146,8 @@ class AuthenticationInterceptorTestCase: BaseTestCase {
 
     func testThatInterceptorCanAdaptURLRequest() {
         // Given
-        let credential = OAuthCredential()
-        let authenticator = OAuthAuthenticator()
+        let credential = TestCredential()
+        let authenticator = TestAuthenticator()
         let interceptor = AuthenticationInterceptor(authenticator: authenticator, credential: credential)
 
         let urlRequest = URLRequest.makeHTTPBinRequest()
@@ -170,8 +178,8 @@ class AuthenticationInterceptorTestCase: BaseTestCase {
 
     func testThatInterceptorQueuesAdaptOperationWhenRefreshing() {
         // Given
-        let credential = OAuthCredential(requiresRefresh: true)
-        let authenticator = OAuthAuthenticator()
+        let credential = TestCredential(requiresRefresh: true)
+        let authenticator = TestAuthenticator()
         let interceptor = AuthenticationInterceptor(authenticator: authenticator, credential: credential)
 
         let urlRequest1 = URLRequest.makeHTTPBinRequest(path: "/status/200")
@@ -214,7 +222,7 @@ class AuthenticationInterceptorTestCase: BaseTestCase {
 
     func testThatInterceptorThrowsMissingCredentialErrorWhenCredentialIsNil() {
         // Given
-        let authenticator = OAuthAuthenticator()
+        let authenticator = TestAuthenticator()
         let interceptor = AuthenticationInterceptor(authenticator: authenticator)
 
         let urlRequest = URLRequest.makeHTTPBinRequest()
@@ -248,8 +256,8 @@ class AuthenticationInterceptorTestCase: BaseTestCase {
 
     func testThatInterceptorRethrowsRefreshErrorFromAdapt() {
         // Given
-        let credential = OAuthCredential(requiresRefresh: true)
-        let authenticator = OAuthAuthenticator(refreshResult: .failure(OAuthError.refreshNetworkFailure))
+        let credential = TestCredential(requiresRefresh: true)
+        let authenticator = TestAuthenticator(refreshResult: .failure(TestAuthError.refreshNetworkFailure))
         let interceptor = AuthenticationInterceptor(authenticator: authenticator, credential: credential)
 
         let session = Session()
@@ -271,7 +279,7 @@ class AuthenticationInterceptorTestCase: BaseTestCase {
 
         XCTAssertEqual(response?.result.isFailure, true)
         XCTAssertEqual(response?.result.failure?.asAFError?.isRequestAdaptationError, true)
-        XCTAssertEqual(response?.result.failure?.asAFError?.underlyingError as? OAuthError, .refreshNetworkFailure)
+        XCTAssertEqual(response?.result.failure?.asAFError?.underlyingError as? TestAuthError, .refreshNetworkFailure)
 
         if case let .requestRetryFailed(_, originalError) = response?.result.failure {
             XCTAssertEqual(originalError.asAFError?.isResponseValidationError, true)
@@ -290,8 +298,8 @@ class AuthenticationInterceptorTestCase: BaseTestCase {
 
     func testThatInterceptorDoesNotRetryWithoutResponse() {
         // Given
-        let credential = OAuthCredential()
-        let authenticator = OAuthAuthenticator()
+        let credential = TestCredential()
+        let authenticator = TestAuthenticator()
         let interceptor = AuthenticationInterceptor(authenticator: authenticator, credential: credential)
 
         let urlRequest = URLRequest(url: URL(string: "/invalid/path")!)
@@ -324,8 +332,8 @@ class AuthenticationInterceptorTestCase: BaseTestCase {
 
     func testThatInterceptorDoesNotRetryWhenRequestDoesNotFailDueToAuthError() {
         // Given
-        let credential = OAuthCredential()
-        let authenticator = OAuthAuthenticator()
+        let credential = TestCredential()
+        let authenticator = TestAuthenticator()
         let interceptor = AuthenticationInterceptor(authenticator: authenticator, credential: credential)
 
         let urlRequest = URLRequest.makeHTTPBinRequest(path: "status/500")
@@ -359,8 +367,8 @@ class AuthenticationInterceptorTestCase: BaseTestCase {
 
     func testThatInterceptorThrowsMissingCredentialErrorWhenCredentialIsNilAndRequestShouldBeRetried() {
         // Given
-        let credential = OAuthCredential()
-        let authenticator = OAuthAuthenticator()
+        let credential = TestCredential()
+        let authenticator = TestAuthenticator()
         let interceptor = AuthenticationInterceptor(authenticator: authenticator, credential: credential)
 
         let eventMonitor = ClosureEventMonitor()
@@ -403,18 +411,18 @@ class AuthenticationInterceptorTestCase: BaseTestCase {
 
     func testThatInterceptorRetriesRequestThatFailedWithOutdatedCredential() {
         // Given
-        let credential = OAuthCredential()
-        let authenticator = OAuthAuthenticator()
+        let credential = TestCredential()
+        let authenticator = TestAuthenticator()
         let interceptor = AuthenticationInterceptor(authenticator: authenticator, credential: credential)
 
         let eventMonitor = ClosureEventMonitor()
 
         eventMonitor.requestDidCreateTask = { _, _ in
-            interceptor.credential = OAuthCredential(accessToken: "a1",
-                                                     refreshToken: "r1",
-                                                     userID: "u0",
-                                                     expiration: Date(),
-                                                     requiresRefresh: false)
+            interceptor.credential = TestCredential(accessToken: "a1",
+                                                    refreshToken: "r1",
+                                                    userID: "u0",
+                                                    expiration: Date(),
+                                                    requiresRefresh: false)
         }
 
         let session = Session(eventMonitors: [eventMonitor])
@@ -447,10 +455,57 @@ class AuthenticationInterceptorTestCase: BaseTestCase {
         XCTAssertEqual(request.retryCount, 1)
     }
 
+    // Produces double lock reported in https://github.com/Alamofire/Alamofire/issues/3294#issuecomment-703241558
+    func testThatInterceptorDoesNotDeadlockWhenAuthenticatorCallsRefreshCompletionSynchronouslyOnCallingQueue() {
+        // Given
+        let credential = TestCredential(requiresRefresh: true)
+        let authenticator = TestAuthenticator(shouldRefreshAsynchronously: false)
+        let interceptor = AuthenticationInterceptor(authenticator: authenticator, credential: credential)
+
+        let eventMonitor = ClosureEventMonitor()
+
+        eventMonitor.requestDidCreateTask = { _, _ in
+            interceptor.credential = TestCredential(accessToken: "a1",
+                                                    refreshToken: "r1",
+                                                    userID: "u0",
+                                                    expiration: Date(),
+                                                    requiresRefresh: false)
+        }
+
+        let session = Session(eventMonitors: [eventMonitor])
+
+        let pathAdapter = PathAdapter(paths: ["/status/200"])
+        let compositeInterceptor = Interceptor(adapters: [pathAdapter, interceptor], retriers: [interceptor])
+
+        let urlRequest = URLRequest.makeHTTPBinRequest()
+
+        let expect = expectation(description: "request should complete")
+        var response: AFDataResponse<Data?>?
+
+        // When
+        let request = session.request(urlRequest, interceptor: compositeInterceptor).validate().response {
+            response = $0
+            expect.fulfill()
+        }
+
+        waitForExpectations(timeout: timeout)
+
+        // Then
+        XCTAssertEqual(response?.request?.headers["Authorization"], "Bearer a1")
+        XCTAssertEqual(response?.result.isSuccess, true)
+
+        XCTAssertEqual(authenticator.applyCount, 1)
+        XCTAssertEqual(authenticator.refreshCount, 1)
+        XCTAssertEqual(authenticator.didRequestFailDueToAuthErrorCount, 0)
+        XCTAssertEqual(authenticator.isRequestAuthenticatedWithCredentialCount, 0)
+
+        XCTAssertEqual(request.retryCount, 0)
+    }
+
     func testThatInterceptorRetriesRequestAfterRefresh() {
         // Given
-        let credential = OAuthCredential()
-        let authenticator = OAuthAuthenticator()
+        let credential = TestCredential()
+        let authenticator = TestAuthenticator()
         let interceptor = AuthenticationInterceptor(authenticator: authenticator, credential: credential)
 
         let pathAdapter = PathAdapter(paths: ["/status/401", "/status/200"])
@@ -485,8 +540,8 @@ class AuthenticationInterceptorTestCase: BaseTestCase {
 
     func testThatInterceptorRethrowsRefreshErrorFromRetry() {
         // Given
-        let credential = OAuthCredential()
-        let authenticator = OAuthAuthenticator(refreshResult: .failure(OAuthError.refreshNetworkFailure))
+        let credential = TestCredential()
+        let authenticator = TestAuthenticator(refreshResult: .failure(TestAuthError.refreshNetworkFailure))
         let interceptor = AuthenticationInterceptor(authenticator: authenticator, credential: credential)
 
         let session = Session()
@@ -508,7 +563,7 @@ class AuthenticationInterceptorTestCase: BaseTestCase {
 
         XCTAssertEqual(response?.result.isFailure, true)
         XCTAssertEqual(response?.result.failure?.asAFError?.isRequestRetryError, true)
-        XCTAssertEqual(response?.result.failure?.asAFError?.underlyingError as? OAuthError, .refreshNetworkFailure)
+        XCTAssertEqual(response?.result.failure?.asAFError?.underlyingError as? TestAuthError, .refreshNetworkFailure)
 
         if case let .requestRetryFailed(_, originalError) = response?.result.failure {
             XCTAssertEqual(originalError.asAFError?.isResponseValidationError, true)
@@ -525,8 +580,8 @@ class AuthenticationInterceptorTestCase: BaseTestCase {
 
     func testThatInterceptorTriggersRefreshWithMultipleParallelRequestsReturning401Responses() {
         // Given
-        let credential = OAuthCredential()
-        let authenticator = OAuthAuthenticator()
+        let credential = TestCredential()
+        let authenticator = TestAuthenticator()
         let interceptor = AuthenticationInterceptor(authenticator: authenticator, credential: credential)
 
         let requestCount = 6
@@ -574,8 +629,8 @@ class AuthenticationInterceptorTestCase: BaseTestCase {
 
     func testThatInterceptorIgnoresExcessiveRefreshWhenRefreshWindowIsNil() {
         // Given
-        let credential = OAuthCredential()
-        let authenticator = OAuthAuthenticator()
+        let credential = TestCredential()
+        let authenticator = TestAuthenticator()
         let interceptor = AuthenticationInterceptor(authenticator: authenticator, credential: credential)
 
         let pathAdapter = PathAdapter(paths: ["/status/401",
@@ -615,8 +670,8 @@ class AuthenticationInterceptorTestCase: BaseTestCase {
 
     func testThatInterceptorThrowsExcessiveRefreshErrorWhenExcessiveRefreshOccurs() {
         // Given
-        let credential = OAuthCredential()
-        let authenticator = OAuthAuthenticator()
+        let credential = TestCredential()
+        let authenticator = TestAuthenticator()
         let interceptor = AuthenticationInterceptor(authenticator: authenticator,
                                                     credential: credential,
                                                     refreshWindow: .init(interval: 30, maximumAttempts: 2))

+ 3 - 3
Tests/RequestTests.swift

@@ -1204,7 +1204,7 @@ class RequestInvalidURLTestCase: BaseTestCase {
         // Then
         XCTAssertEqual(response?.error?.isInvalidURLError, true)
     }
-    
+
     func testThatDownloadRequestWithFileURLThrowsError() {
         // Given
         let fileURL = url(forResource: "valid_data", withExtension: "json")
@@ -1223,7 +1223,7 @@ class RequestInvalidURLTestCase: BaseTestCase {
         // Then
         XCTAssertEqual(response?.error?.isInvalidURLError, true)
     }
-    
+
     func testThatDataStreamRequestWithFileURLThrowsError() {
         // Given
         let fileURL = url(forResource: "valid_data", withExtension: "json")
@@ -1234,7 +1234,7 @@ class RequestInvalidURLTestCase: BaseTestCase {
         AF.streamRequest(fileURL)
             .responseStream { stream in
                 guard case let .complete(completion) = stream.event else { return }
-                
+
                 response = completion
                 expectation.fulfill()
             }

+ 1 - 1
watchOS Example/watchOS Example.xcodeproj/xcshareddata/xcschemes/watchOS Example WatchKit App.xcscheme

@@ -1,6 +1,6 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <Scheme
-   LastUpgradeVersion = "1200"
+   LastUpgradeVersion = "1220"
    version = "1.3">
    <BuildAction
       parallelizeBuildables = "YES"