2
0
Эх сурвалжийг харах

Fix #2682: Cancel all requests on session invalidation or deinit. (#2728)

* Cancel all requests on session invalidation or deinit.

* Fix typos.

* Prevent implicit self capture by capturing early.

* Fix validation leaks for downloads too.

* Separate deinit finish from invalidation finish.

* Fix master build failure.

* Attempt to make notification tests more resilient.

* Refactor notifications and tests.
Jon Shier 6 жил өмнө
parent
commit
9ab2a22e86

+ 41 - 15
Source/AFError.swift

@@ -185,6 +185,8 @@ public enum AFError: Error {
         case publicKeyPinningFailed(host: String, trust: SecTrust, pinnedKeys: [SecKey], serverKeys: [SecKey])
     }
 
+    case sessionDeinitialized
+    case sessionInvalidated(error: Error?)
     case explicitlyCancelled
     case invalidURL(url: URLConvertible)
     case parameterEncodingFailed(reason: ParameterEncodingFailureReason)
@@ -207,66 +209,79 @@ extension Error {
 // MARK: - Error Booleans
 
 extension AFError {
-    /// Returns whether the `AFError` is an explicitly cancelled error.
+    // Returns whether the instance is `.sessionDeinitialized`.
+    public var isSessionDeinitializedError: Bool {
+        if case .sessionDeinitialized = self { return true }
+        return false
+    }
+
+    // Returns whether the instance is `.sessionInvalidated`.
+    public var isSessionInvalidatedError: Bool {
+        if case .sessionInvalidated = self { return true }
+        return false
+    }
+
+    /// Returns whether the instance is `.explicitlyCancelled`.
     public var isExplicitlyCancelledError: Bool {
         if case .explicitlyCancelled = self { return true }
         return false
     }
 
-    /// Returns whether the AFError is an invalid URL error.
+    /// Returns whether the instance is `.invalidURL`.
     public var isInvalidURLError: Bool {
         if case .invalidURL = self { return true }
         return false
     }
 
-    /// Returns whether the AFError is a parameter encoding error. When `true`, the `underlyingError` property will
+    /// Returns whether the instance is `.parameterEncodingFailed`. When `true`, the `underlyingError` property will
     /// contain the associated value.
     public var isParameterEncodingError: Bool {
         if case .parameterEncodingFailed = self { return true }
         return false
     }
 
-    /// Returns whether the instance is a parameter encoder error.
+    /// Returns whether the instance is `.parameterEncoderFailed`. When `true`, the `underlyingError` property will
+    // contain the associated value.
     public var isParameterEncoderError: Bool {
         if case .parameterEncoderFailed = self { return true }
         return false
     }
 
-    /// Returns whether the AFError is a multipart encoding error. When `true`, the `url` and `underlyingError` properties
-    /// will contain the associated values.
+    /// Returns whether the instance is `.multipartEncodingFailed`. When `true`, the `url` and `underlyingError`
+    /// properties will contain the associated values.
     public var isMultipartEncodingError: Bool {
         if case .multipartEncodingFailed = self { return true }
         return false
     }
 
-    /// Returns whether the AFError is a request adaptation error. When `true`, the `underlyingError` property will
+    /// Returns whether the instance is `.requestAdaptationFailed`. When `true`, the `underlyingError` property will
     /// contain the associated value.
     public var isRequestAdaptationError: Bool {
         if case .requestAdaptationFailed = self { return true }
         return false
     }
 
-    /// Returns whether the `AFError` is a response validation error. When `true`, the `acceptableContentTypes`,
+    /// Returns whether the instance is `.responseValidationFailed`. When `true`, the `acceptableContentTypes`,
     /// `responseContentType`, and `responseCode` properties will contain the associated values.
     public var isResponseValidationError: Bool {
         if case .responseValidationFailed = self { return true }
         return false
     }
 
-    /// Returns whether the `AFError` is a response serialization error. When `true`, the `failedStringEncoding` and
+    /// Returns whether the instance is `.responseSerializationFailed`. When `true`, the `failedStringEncoding` and
     /// `underlyingError` properties will contain the associated values.
     public var isResponseSerializationError: Bool {
         if case .responseSerializationFailed = self { return true }
         return false
     }
 
-    /// Returns whether the `AFError` is a server trust evaluation error.
+    /// Returns whether the instance is `.serverTrustEvaluationFailed`.
     public var isServerTrustEvaluationError: Bool {
         if case .serverTrustEvaluationFailed = self { return true }
         return false
     }
 
-    /// Returns whether the AFError is a request retry error. When `true`, the `underlyingError` property will
+    /// Returns whether the instance is `requestRetryFailed`. When `true`, the `underlyingError` property will
     /// contain the associated value.
     public var isRequestRetryError: Bool {
         if case .requestRetryFailed = self { return true }
@@ -297,11 +312,13 @@ extension AFError {
         }
     }
 
-    /// The underlying `Error` responsible for generating the failure associated with `.parameterEncodingFailed`,
-    /// `.parameterEncoderFailed`, `.multipartEncodingFailed`, `.requestAdaptationFailed`,
+    /// The underlying `Error` responsible for generating the failure associated with `.sessionInvalidated`,
+    /// `.parameterEncodingFailed`, `.parameterEncoderFailed`, `.multipartEncodingFailed`, `.requestAdaptationFailed`,
     /// `.responseSerializationFailed`, `.requestRetryFailed` errors.
     public var underlyingError: Error? {
         switch self {
+        case .sessionInvalidated(let error):
+            return error
         case .parameterEncodingFailed(let reason):
             return reason.underlyingError
         case .parameterEncoderFailed(let reason):
@@ -473,6 +490,13 @@ extension AFError.ServerTrustFailureReason {
 extension AFError: LocalizedError {
     public var errorDescription: String? {
         switch self {
+        case .sessionDeinitialized:
+            return """
+                   Session was invalidated without error, so it was likely deinitialized unexpectedly. \
+                   Be sure to retain a reference to your Session for the duration of your requests.
+                   """
+        case .sessionInvalidated(let error):
+            return "Session was invalidated with error: \(error?.localizedDescription ?? "No description.")"
         case .explicitlyCancelled:
             return "Request explicitly cancelled."
         case .invalidURL(let url):
@@ -492,8 +516,10 @@ extension AFError: LocalizedError {
         case .serverTrustEvaluationFailed:
             return "Server trust evaluation failed."
         case .requestRetryFailed(let retryError, let originalError):
-            return "Request retry failed with retry error: \(retryError.localizedDescription), " +
-                "original error: \(originalError.localizedDescription)"
+            return """
+                   Request retry failed with retry error: \(retryError.localizedDescription), \
+                   original error: \(originalError.localizedDescription)
+                   """
         }
     }
 }

+ 6 - 6
Source/Notifications.swift

@@ -25,14 +25,14 @@
 import Foundation
 
 public extension Request {
-    /// Posted when a `Request`'s task is resumed. The `Notification` contains the resumed `Request`.
+    /// Posted when a `Request` is resumed. The `Notification` contains the resumed `Request`.
     static let didResume = Notification.Name(rawValue: "org.alamofire.notification.name.request.didResume")
-    /// Posted when a `Request`'s task is suspended. The `Notification` contains the suspended `Request`.
+    /// Posted when a `Request` is suspended. The `Notification` contains the suspended `Request`.
     static let didSuspend = Notification.Name(rawValue: "org.alamofire.notification.name.request.didSuspend")
     /// Posted when a `Request` is cancelled. The `Notification` contains the cancelled `Request`.
     static let didCancel = Notification.Name(rawValue: "org.alamofire.notification.name.request.didCancel")
-    /// Posted when a `Request`'s task is completed. The `Notification` contains the completed `Request`.
-    static let didComplete = Notification.Name(rawValue: "org.alamofire.notification.name.request.didComplete")
+    /// Posted when a `Request` is finished. The `Notification` contains the completed `Request`.
+    static let didFinish = Notification.Name(rawValue: "org.alamofire.notification.name.request.didFinish")
 }
 
 // MARK: -
@@ -72,8 +72,8 @@ extension String {
 
 /// `EventMonitor` that provides Alamofire's notifications.
 public final class AlamofireNotifications: EventMonitor {
-    public func request(_ request: Request, didCompleteTask task: URLSessionTask, with error: Error?) {
-        NotificationCenter.default.postNotification(named: Request.didComplete, with: request)
+    public func requestDidFinish(_ request: Request) {
+        NotificationCenter.default.postNotification(named: Request.didFinish, with: request)
     }
 
     public func requestDidResume(_ request: Request) {

+ 27 - 23
Source/RequestTaskMap.swift

@@ -26,63 +26,67 @@ import Foundation
 
 /// A type that maintains a two way, one to one map of `URLSessionTask`s to `Request`s.
 struct RequestTaskMap {
-    private var requests: [URLSessionTask: Request]
-    private var tasks: [Request: URLSessionTask]
+    private var tasksToRequests: [URLSessionTask: Request]
+    private var requestsToTasks: [Request: URLSessionTask]
 
-    init(requests: [URLSessionTask: Request] = [:], tasks: [Request: URLSessionTask] = [:]) {
-        self.requests = requests
-        self.tasks = tasks
+    var requests: [Request] {
+        return Array(tasksToRequests.values)
+    }
+
+    init(tasksToRequests: [URLSessionTask: Request] = [:], requestsToTasks: [Request: URLSessionTask] = [:]) {
+        self.tasksToRequests = tasksToRequests
+        self.requestsToTasks = requestsToTasks
     }
 
     subscript(_ request: Request) -> URLSessionTask? {
-        get { return tasks[request] }
+        get { return requestsToTasks[request] }
         set {
             guard let newValue = newValue else {
-                guard let task = tasks[request] else {
+                guard let task = requestsToTasks[request] else {
                     fatalError("RequestTaskMap consistency error: no task corresponding to request found.")
                 }
 
-                tasks.removeValue(forKey: request)
-                requests.removeValue(forKey: task)
+                requestsToTasks.removeValue(forKey: request)
+                tasksToRequests.removeValue(forKey: task)
 
                 return
             }
 
-            tasks[request] = newValue
-            requests[newValue] = request
+            requestsToTasks[request] = newValue
+            tasksToRequests[newValue] = request
         }
     }
 
     subscript(_ task: URLSessionTask) -> Request? {
-        get { return requests[task] }
+        get { return tasksToRequests[task] }
         set {
             guard let newValue = newValue else {
-                guard let request = requests[task] else {
+                guard let request = tasksToRequests[task] else {
                     fatalError("RequestTaskMap consistency error: no request corresponding to task found.")
                 }
 
-                requests.removeValue(forKey: task)
-                tasks.removeValue(forKey: request)
+                tasksToRequests.removeValue(forKey: task)
+                requestsToTasks.removeValue(forKey: request)
 
                 return
             }
 
-            requests[task] = newValue
-            tasks[newValue] = task
+            tasksToRequests[task] = newValue
+            requestsToTasks[newValue] = task
         }
     }
 
     var count: Int {
-        precondition(requests.count == tasks.count,
-                     "RequestTaskMap.count invalid, requests.count: \(requests.count) != tasks.count: \(tasks.count)")
+        precondition(tasksToRequests.count == requestsToTasks.count,
+                     "RequestTaskMap.count invalid, requests.count: \(tasksToRequests.count) != tasks.count: \(requestsToTasks.count)")
 
-        return requests.count
+        return tasksToRequests.count
     }
 
     var isEmpty: Bool {
-        precondition(requests.isEmpty == tasks.isEmpty,
-                     "RequestTaskMap.isEmpty invalid, requests.isEmpty: \(requests.isEmpty) != tasks.isEmpty: \(tasks.isEmpty)")
+        precondition(tasksToRequests.isEmpty == requestsToTasks.isEmpty,
+                     "RequestTaskMap.isEmpty invalid, requests.isEmpty: \(tasksToRequests.isEmpty) != tasks.isEmpty: \(requestsToTasks.isEmpty)")
 
-        return requests.isEmpty
+        return tasksToRequests.isEmpty
     }
 }

+ 0 - 2
Source/Response.swift

@@ -286,7 +286,6 @@ extension DownloadResponse: CustomStringConvertible, CustomDebugStringConvertibl
                    \(sortedHeaders)
                    """
         } ?? "nil"
-        let responseBody = data.map { String(decoding: $0, as: UTF8.self) } ?? "None"
         let metricsDescription = metrics.map { "\($0.taskInterval.duration)s" } ?? "None"
         let resumeDataDescription = resumeData.map { "\($0)" } ?? "None"
 
@@ -294,7 +293,6 @@ extension DownloadResponse: CustomStringConvertible, CustomDebugStringConvertibl
         [Request]: \(requestDescription)
         [Request Body]: \n\(requestBody)
         [Response]: \n\(responseDescription)
-        [Response Body]: \n\(responseBody)
         [File URL]: \(fileURL?.path ?? "nil")
         [ResumeData]: \(resumeDataDescription)
         [Network Duration]: \(metricsDescription)

+ 19 - 8
Source/Session.swift

@@ -43,10 +43,10 @@ open class Session {
     var requestTaskMap = RequestTaskMap()
     public let startRequestsImmediately: Bool
 
-    public init(startRequestsImmediately: Bool = true,
-                session: URLSession,
+    public init(session: URLSession,
                 delegate: SessionDelegate,
                 rootQueue: DispatchQueue,
+                startRequestsImmediately: Bool = true,
                 requestQueue: DispatchQueue? = nil,
                 serializationQueue: DispatchQueue? = nil,
                 interceptor: RequestInterceptor? = nil,
@@ -59,10 +59,10 @@ open class Session {
         precondition(session.delegateQueue.underlyingQueue === rootQueue,
                      "SessionManager(session:) intializer must be passed the DispatchQueue used as the delegateQueue's underlyingQueue as rootQueue.")
 
-        self.startRequestsImmediately = startRequestsImmediately
         self.session = session
         self.delegate = delegate
         self.rootQueue = rootQueue
+        self.startRequestsImmediately = startRequestsImmediately
         self.requestQueue = requestQueue ?? DispatchQueue(label: "\(rootQueue.label).requestQueue", target: rootQueue)
         self.serializationQueue = serializationQueue ?? DispatchQueue(label: "\(rootQueue.label).serializationQueue", target: rootQueue)
         self.interceptor = interceptor
@@ -74,10 +74,10 @@ open class Session {
         delegate.stateProvider = self
     }
 
-    public convenience init(startRequestsImmediately: Bool = true,
-                            configuration: URLSessionConfiguration = .alamofireDefault,
+    public convenience init(configuration: URLSessionConfiguration = .alamofireDefault,
                             delegate: SessionDelegate = SessionDelegate(),
                             rootQueue: DispatchQueue = DispatchQueue(label: "org.alamofire.sessionManager.rootQueue"),
+                            startRequestsImmediately: Bool = true,
                             requestQueue: DispatchQueue? = nil,
                             serializationQueue: DispatchQueue? = nil,
                             interceptor: RequestInterceptor? = nil,
@@ -88,10 +88,10 @@ open class Session {
         let delegateQueue = OperationQueue(maxConcurrentOperationCount: 1, underlyingQueue: rootQueue, name: "org.alamofire.sessionManager.sessionDelegateQueue")
         let session = URLSession(configuration: configuration, delegate: delegate, delegateQueue: delegateQueue)
 
-        self.init(startRequestsImmediately: startRequestsImmediately,
-                  session: session,
+        self.init(session: session,
                   delegate: delegate,
                   rootQueue: rootQueue,
+                  startRequestsImmediately: startRequestsImmediately,
                   requestQueue: requestQueue,
                   serializationQueue: serializationQueue,
                   interceptor: interceptor,
@@ -102,6 +102,7 @@ open class Session {
     }
 
     deinit {
+        finishRequestsForDeinit()
         session.invalidateAndCancel()
     }
 
@@ -497,6 +498,12 @@ open class Session {
             return request.interceptor ?? interceptor
         }
     }
+
+    // MARK: - Invalidation
+
+    func finishRequestsForDeinit() {
+        requestTaskMap.requests.forEach { $0.finish(error: AFError.sessionDeinitialized) }
+    }
 }
 
 // MARK: - RequestDelegate
@@ -605,8 +612,12 @@ extension Session: SessionStateProvider {
         requestTaskMap[task] = nil
     }
 
-    public func credential(for task: URLSessionTask, protectionSpace: URLProtectionSpace) -> URLCredential? {
+    public func credential(for task: URLSessionTask, in protectionSpace: URLProtectionSpace) -> URLCredential? {
         return requestTaskMap[task]?.credential ??
                session.configuration.urlCredentialStorage?.defaultCredential(for: protectionSpace)
     }
+
+    public func cancelRequestsForSessionInvalidation(with error: Error?) {
+        requestTaskMap.requests.forEach { $0.finish(error: AFError.sessionInvalidated(error: error)) }
+    }
 }

+ 8 - 4
Source/SessionStateProvider.swift

@@ -25,12 +25,14 @@
 import Foundation
 
 public protocol SessionStateProvider: AnyObject {
-    func request(for task: URLSessionTask) -> Request?
-    func didCompleteTask(_ task: URLSessionTask)
     var serverTrustManager: ServerTrustManager? { get }
     var redirectHandler: RedirectHandler? { get }
     var cachedResponseHandler: CachedResponseHandler? { get }
-    func credential(for task: URLSessionTask, protectionSpace: URLProtectionSpace) -> URLCredential?
+
+    func request(for task: URLSessionTask) -> Request?
+    func didCompleteTask(_ task: URLSessionTask)
+    func credential(for task: URLSessionTask, in protectionSpace: URLProtectionSpace) -> URLCredential?
+    func cancelRequestsForSessionInvalidation(with error: Error?)
 }
 
 open class SessionDelegate: NSObject {
@@ -47,6 +49,8 @@ open class SessionDelegate: NSObject {
 extension SessionDelegate: URLSessionDelegate {
     open func urlSession(_ session: URLSession, didBecomeInvalidWithError error: Error?) {
         eventMonitor?.urlSession(session, didBecomeInvalidWithError: error)
+
+        stateProvider?.cancelRequestsForSessionInvalidation(with: error)
     }
 }
 
@@ -106,7 +110,7 @@ extension SessionDelegate: URLSessionTaskDelegate {
             return (.rejectProtectionSpace, nil, nil)
         }
 
-        guard let credential = stateProvider?.credential(for: task, protectionSpace: challenge.protectionSpace) else {
+        guard let credential = stateProvider?.credential(for: task, in: challenge.protectionSpace) else {
             return (.performDefaultHandling, nil, nil)
         }
 

+ 8 - 2
Source/Validation.swift

@@ -184,7 +184,10 @@ extension DataRequest {
     /// - returns: The request.
     @discardableResult
     public func validate() -> Self {
-        return validate(statusCode: self.acceptableStatusCodes).validate(contentType: self.acceptableContentTypes)
+        let contentTypes: () -> [String] = { [unowned self] in
+            return self.acceptableContentTypes
+        }
+        return validate(statusCode: acceptableStatusCodes).validate(contentType: contentTypes())
     }
 }
 
@@ -244,6 +247,9 @@ extension DownloadRequest {
     /// - returns: The request.
     @discardableResult
     public func validate() -> Self {
-        return validate(statusCode: self.acceptableStatusCodes).validate(contentType: self.acceptableContentTypes)
+        let contentTypes = { [unowned self] in
+            return self.acceptableContentTypes
+        }
+        return validate(statusCode: acceptableStatusCodes).validate(contentType: contentTypes())
     }
 }

+ 4 - 1
Tests/BaseTestCase.swift

@@ -44,7 +44,10 @@ class BaseTestCase: XCTestCase {
         return bundle.url(forResource: fileName, withExtension: ext)!
     }
 
-    func assertErrorIsAFError(_ error: Error?, file: StaticString = #file, line: UInt = #line, evaluation: (_ error: AFError) -> Void) {
+    func assertErrorIsAFError(_ error: Error?,
+                              file: StaticString = #file,
+                              line: UInt = #line,
+                              evaluation: (_ error: AFError) -> Void) {
         guard let error = error?.asAFError else {
             XCTFail("error is not an AFError", file: file, line: line)
             return

+ 1 - 1
Tests/MultipartFormDataTests.swift

@@ -909,7 +909,7 @@ class MultipartFormDataFailureTestCase: BaseTestCase {
         var writerError: Error?
 
         do {
-            try "dummy data".write(to: fileURL, atomically: true, encoding: String.Encoding.utf8)
+            try "dummy data".write(to: fileURL, atomically: true, encoding: .utf8)
         } catch {
             writerError = error
         }

+ 50 - 16
Tests/SessionDelegateTests.swift

@@ -96,37 +96,71 @@ class SessionDelegateTestCase: BaseTestCase {
 
     func testThatAppropriateNotificationsAreCalledWithRequestForDataRequest() {
         // Given
-        var request: Request?
-        _ = expectation(forNotification: Request.didResume, object: nil, handler: nil)
-        _ = expectation(forNotification: Request.didComplete, object: nil) { (notification) in
-            request = notification.request
-            return (request != nil)
-        }
+        let session = Session(startRequestsImmediately: false)
+        var resumedRequest: Request?
+        var completedRequest: Request?
+        var requestResponse: DataResponse<Data?>?
+        let expect = expectation(description: "request should complete")
 
         // When
-        manager.request("https://httpbin.org/get").response { _ in }
+        let request = session.request("https://httpbin.org/get").response { (response) in
+            requestResponse = response
+            expect.fulfill()
+        }
+        expectation(forNotification: Request.didResume, object: nil) { (notification) in
+            guard let receivedRequest = notification.request, receivedRequest == request else { return false }
+
+            resumedRequest = notification.request
+            return true
+        }
+        expectation(forNotification: Request.didFinish, object: nil) { (notification) in
+            guard let receivedRequest = notification.request, receivedRequest == request else { return false }
+
+            completedRequest = notification.request
+            return true
+        }
+
+        request.resume()
 
         waitForExpectations(timeout: timeout, handler: nil)
 
         // Then
-        XCTAssertEqual(request?.response?.statusCode, 200)
+        XCTAssertEqual(resumedRequest, completedRequest)
+        XCTAssertEqual(requestResponse?.response?.statusCode, 200)
     }
 
     func testThatDidCompleteNotificationIsCalledWithRequestForDownloadRequests() {
         // Given
-        var request: Request?
-        _ = expectation(forNotification: Request.didResume, object: nil, handler: nil)
-        _ = expectation(forNotification: Request.didComplete, object: nil) { (notification) in
-            request = notification.request
-            return (request != nil)
-        }
+        let session = Session(startRequestsImmediately: false)
+        var resumedRequest: Request?
+        var completedRequest: Request?
+        var requestResponse: DownloadResponse<URL?>?
+        let expect = expectation(description: "request should complete")
 
         // When
-        manager.download("https://httpbin.org/get").response { _ in }
+        let request = session.download("https://httpbin.org/get").response { (response) in
+            requestResponse = response
+            expect.fulfill()
+        }
+        expectation(forNotification: Request.didResume, object: nil) { (notification) in
+            guard let receivedRequest = notification.request, receivedRequest == request else { return false }
+
+            resumedRequest = notification.request
+            return true
+        }
+        expectation(forNotification: Request.didFinish, object: nil) { (notification) in
+            guard let receivedRequest = notification.request, receivedRequest == request else { return false }
+
+            completedRequest = notification.request
+            return true
+        }
+
+        request.resume()
 
         waitForExpectations(timeout: timeout, handler: nil)
 
         // Then
-        XCTAssertEqual(request?.response?.statusCode, 200)
+        XCTAssertEqual(resumedRequest, completedRequest)
+        XCTAssertEqual(requestResponse?.response?.statusCode, 200)
     }
 }

+ 26 - 0
Tests/SessionTests.swift

@@ -1136,6 +1136,32 @@ class SessionTestCase: BaseTestCase {
             XCTFail("error should not be nil")
         }
     }
+
+    // MARK: Tests - Session Invalidation
+
+    func testThatSessionIsInvalidatedAndAllRequestsCompleteWhenSessionIsDeinitialized() {
+        // Given
+        let invalidationExpectation = expectation(description: "sessionDidBecomeInvalidWithError should be called")
+        let events = ClosureEventMonitor()
+        events.sessionDidBecomeInvalidWithError = { (_, _) in
+            invalidationExpectation.fulfill()
+        }
+        var session: Session? = Session(startRequestsImmediately: false, eventMonitors: [events])
+        var error: Error?
+        let requestExpectation = expectation(description: "request should complete")
+
+        // When
+        session?.request(URLRequest.makeHTTPBinRequest()).response { (response) in
+            error = response.error
+            requestExpectation.fulfill()
+        }
+        session = nil
+
+        waitForExpectations(timeout: timeout, handler: nil)
+
+        // Then
+        assertErrorIsAFError(error) { XCTAssertTrue($0.isSessionDeinitializedError) }
+    }
 }
 
 // MARK: -