Browse Source

Bugfix: DownloadRequest event duplication and missing events. (#2807)

* Refactor DownloadRequest.cancel to match other requests.

* Add missing DownloadTask eventMonitor calls.

* Fix metrics vs. completion race, duplicate task events.

* Add general tests for all events.

* Update naming of RequestTaskMap methods.

* Switch on the tuple elements separately.
Jon Shier 6 years ago
parent
commit
63cbd535f9

+ 9 - 13
Source/Request.swift

@@ -102,7 +102,7 @@ public class Request {
     }
 
     /// Protected `MutableState` value that provides threadsafe access to state values.
-    private let protectedMutableState: Protector<MutableState> = Protector(MutableState())
+    fileprivate let protectedMutableState: Protector<MutableState> = Protector(MutableState())
 
     /// `State` of the `Request`.
     public fileprivate(set) var state: State {
@@ -891,10 +891,10 @@ public class DownloadRequest: Request {
         var fileURL: URL?
     }
 
-    private let protectedMutableState: Protector<DownloadRequestMutableState> = Protector(DownloadRequestMutableState())
+    private let protectedDownloadMutableState: Protector<DownloadRequestMutableState> = Protector(DownloadRequestMutableState())
 
-    public var resumeData: Data? { return protectedMutableState.directValue.resumeData }
-    public var fileURL: URL? { return protectedMutableState.directValue.fileURL }
+    public var resumeData: Data? { return protectedDownloadMutableState.directValue.resumeData }
+    public var fileURL: URL? { return protectedDownloadMutableState.directValue.fileURL }
 
     // MARK: Init
 
@@ -920,15 +920,15 @@ public class DownloadRequest: Request {
     override func reset() {
         super.reset()
 
-        protectedMutableState.write { $0.resumeData = nil }
-        protectedMutableState.write { $0.fileURL = nil }
+        protectedDownloadMutableState.write { $0.resumeData = nil }
+        protectedDownloadMutableState.write { $0.fileURL = nil }
     }
 
     func didFinishDownloading(using task: URLSessionTask, with result: AFResult<URL>) {
         eventMonitor?.request(self, didFinishDownloadingUsing: task, with: result)
 
         switch result {
-        case .success(let url):   protectedMutableState.write { $0.fileURL = url }
+        case .success(let url):   protectedDownloadMutableState.write { $0.fileURL = url }
         case .failure(let error): self.error = error
         }
     }
@@ -950,16 +950,12 @@ public class DownloadRequest: Request {
 
     @discardableResult
     public override func cancel() -> Self {
-        guard state.canTransitionTo(.cancelled) else { return self }
-
-        state = .cancelled
+        guard protectedMutableState.attemptToTransitionTo(.cancelled) else { return self }
 
         delegate?.cancelDownloadRequest(self) { (resumeData) in
-            self.protectedMutableState.write { $0.resumeData = resumeData }
+            self.protectedDownloadMutableState.write { $0.resumeData = resumeData }
         }
 
-        eventMonitor?.requestDidCancel(self)
-
         return self
     }
 

+ 45 - 1
Source/RequestTaskMap.swift

@@ -28,14 +28,18 @@ import Foundation
 struct RequestTaskMap {
     private var tasksToRequests: [URLSessionTask: Request]
     private var requestsToTasks: [Request: URLSessionTask]
+    private var taskEvents: [URLSessionTask: (completed: Bool, metricsGathered: Bool)]
 
     var requests: [Request] {
         return Array(tasksToRequests.values)
     }
 
-    init(tasksToRequests: [URLSessionTask: Request] = [:], requestsToTasks: [Request: URLSessionTask] = [:]) {
+    init(tasksToRequests: [URLSessionTask: Request] = [:],
+         requestsToTasks: [Request: URLSessionTask] = [:],
+         taskEvents: [URLSessionTask: (completed: Bool, metricsGathered: Bool)] = [:]) {
         self.tasksToRequests = tasksToRequests
         self.requestsToTasks = requestsToTasks
+        self.taskEvents = taskEvents
     }
 
     subscript(_ request: Request) -> URLSessionTask? {
@@ -48,12 +52,14 @@ struct RequestTaskMap {
 
                 requestsToTasks.removeValue(forKey: request)
                 tasksToRequests.removeValue(forKey: task)
+                taskEvents.removeValue(forKey: task)
 
                 return
             }
 
             requestsToTasks[request] = newValue
             tasksToRequests[newValue] = request
+            taskEvents[newValue] = (completed: false, metricsGathered: false)
         }
     }
 
@@ -67,12 +73,14 @@ struct RequestTaskMap {
 
                 tasksToRequests.removeValue(forKey: task)
                 requestsToTasks.removeValue(forKey: request)
+                taskEvents.removeValue(forKey: task)
 
                 return
             }
 
             tasksToRequests[task] = newValue
             requestsToTasks[newValue] = task
+            taskEvents[task] = (completed: false, metricsGathered: false)
         }
     }
 
@@ -83,10 +91,46 @@ struct RequestTaskMap {
         return tasksToRequests.count
     }
 
+    var eventCount: Int {
+        precondition(taskEvents.count == count, "RequestTaskMap.eventCount invalid, count: \(count) != taskEvents.count: \(taskEvents.count)")
+
+        return taskEvents.count
+    }
+
     var isEmpty: Bool {
         precondition(tasksToRequests.isEmpty == requestsToTasks.isEmpty,
                      "RequestTaskMap.isEmpty invalid, requests.isEmpty: \(tasksToRequests.isEmpty) != tasks.isEmpty: \(requestsToTasks.isEmpty)")
 
         return tasksToRequests.isEmpty
     }
+
+    var isEventsEmpty: Bool {
+        precondition(taskEvents.isEmpty == isEmpty, "RequestTaskMap.isEventsEmpty invalid, isEmpty: \(isEmpty) != taskEvents.isEmpty: \(taskEvents.isEmpty)")
+
+        return taskEvents.isEmpty
+    }
+
+    mutating func disassociateIfNecessaryAfterGatheringMetricsForTask(_ task: URLSessionTask) {
+        guard let events = taskEvents[task] else {
+            fatalError("RequestTaskMap consistency error: no events corresponding to task found.")
+        }
+
+        switch (events.completed, events.metricsGathered) {
+        case (_, true): fatalError("RequestTaskMap consistency error: duplicate metricsGatheredForTask call.")
+        case (false, false): taskEvents[task] = (completed: false, metricsGathered: true)
+        case (true, false): self[task] = nil
+        }
+    }
+
+    mutating func disassociateIfNecessaryAfterCompletingTask(_ task: URLSessionTask) {
+        guard let events = taskEvents[task] else {
+            fatalError("RequestTaskMap consistency error: no events corresponding to task found.")
+        }
+
+        switch (events.completed, events.metricsGathered) {
+        case (true, _): fatalError("RequestTaskMap consistency error: duplicate completionReceivedForTask call.")
+        case (false, false): taskEvents[task] = (completed: true, metricsGathered: false)
+        case (false, true): self[task] = nil
+        }
+    }
 }

+ 4 - 0
Source/ResponseSerialization.swift

@@ -220,6 +220,8 @@ extension DownloadRequest {
                                             serializationDuration: 0,
                                             result: result)
 
+            self.eventMonitor?.request(self, didParseResponse: response)
+
             self.responseSerializerDidComplete { queue.async { completionHandler(response) } }
         }
 
@@ -257,6 +259,8 @@ extension DownloadRequest {
                                             serializationDuration: (end - start),
                                             result: result)
 
+            self.eventMonitor?.request(self, didParseResponse: response)
+
             guard let serializerError = result.error, let delegate = self.delegate else {
                 self.responseSerializerDidComplete { queue.async { completionHandler(response) } }
                 return

+ 15 - 5
Source/Session.swift

@@ -556,7 +556,8 @@ extension Session: RequestDelegate {
         rootQueue.async {
             request.didCancel()
 
-            guard let task = self.requestTaskMap[request] else {
+            // Cancellation only has an effect on running or suspended tasks.
+            guard let task = self.requestTaskMap[request], [.running, .suspended].contains(task.state) else {
                 request.finish()
                 return
             }
@@ -570,7 +571,10 @@ extension Session: RequestDelegate {
         rootQueue.async {
             request.didCancel()
 
-            guard let downloadTask = self.requestTaskMap[request] as? URLSessionDownloadTask else {
+            // Cancellation only has an effect on running or suspended tasks.
+            guard
+                let downloadTask = self.requestTaskMap[request] as? URLSessionDownloadTask,
+                [.running, .suspended].contains(downloadTask.state) else {
                 request.finish()
                 return
             }
@@ -590,7 +594,8 @@ extension Session: RequestDelegate {
 
             request.didSuspend()
 
-            guard let task = self.requestTaskMap[request] else { return }
+            // Tasks can only be suspended if they're running.
+            guard let task = self.requestTaskMap[request], task.state == .running else { return }
 
             task.suspend()
             request.didSuspendTask(task)
@@ -603,7 +608,8 @@ extension Session: RequestDelegate {
 
             request.didResume()
 
-            guard let task = self.requestTaskMap[request] else { return }
+            // Tasks can only be resumed if they're suspended.
+            guard let task = self.requestTaskMap[request], task.state == .suspended else { return }
 
             task.resume()
             request.didResumeTask(task)
@@ -618,8 +624,12 @@ extension Session: SessionStateProvider {
         return requestTaskMap[task]
     }
 
+    public func didGatherMetricsForTask(_ task: URLSessionTask) {
+        requestTaskMap.disassociateIfNecessaryAfterGatheringMetricsForTask(task)
+    }
+
     public func didCompleteTask(_ task: URLSessionTask) {
-        requestTaskMap[task] = nil
+        requestTaskMap.disassociateIfNecessaryAfterCompletingTask(task)
     }
 
     public func credential(for task: URLSessionTask, in protectionSpace: URLProtectionSpace) -> URLCredential? {

+ 3 - 0
Source/SessionDelegate.swift

@@ -30,6 +30,7 @@ protocol SessionStateProvider: AnyObject {
     var cachedResponseHandler: CachedResponseHandler? { get }
 
     func request(for task: URLSessionTask) -> Request?
+    func didGatherMetricsForTask(_ task: URLSessionTask)
     func didCompleteTask(_ task: URLSessionTask)
     func credential(for task: URLSessionTask, in protectionSpace: URLProtectionSpace) -> URLCredential?
     func cancelRequestsForSessionInvalidation(with error: Error?)
@@ -162,6 +163,8 @@ extension SessionDelegate: URLSessionTaskDelegate {
         eventMonitor?.urlSession(session, task: task, didFinishCollecting: metrics)
 
         stateProvider?.request(for: task)?.didGatherMetrics(metrics)
+
+        stateProvider?.didGatherMetricsForTask(task)
     }
 
     open func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {

+ 79 - 0
Tests/DownloadTests.swift

@@ -381,6 +381,85 @@ class DownloadResponseTestCase: BaseTestCase {
     }
 }
 
+final class DownloadRequestEventsTestCase: BaseTestCase {
+    func testThatDownloadRequestTriggersAllAppropriateLifetimeEvents() {
+        // Given
+        let eventMonitor = ClosureEventMonitor()
+        let session = Session(eventMonitors: [eventMonitor])
+
+        let expect = expectation(description: "request should receive appropriate lifetime events")
+        expect.expectedFulfillmentCount = 14
+
+        var wroteData = false
+
+        eventMonitor.taskDidFinishCollectingMetrics = { (_, _, _) in expect.fulfill() }
+        eventMonitor.requestDidCreateURLRequest = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidCreateTask = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidGatherMetrics = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidCompleteTaskWithError = { (_, _, _) in expect.fulfill() }
+        eventMonitor.downloadTaskDidWriteData = { (_, _, _, _, _) in
+            guard !wroteData else { return }
+
+            wroteData = true
+            expect.fulfill()
+        }
+        eventMonitor.downloadTaskDidFinishDownloadingToURL = { (_, _, _) in expect.fulfill() }
+        eventMonitor.requestDidFinishDownloadingUsingTaskWithResult = { (_, _, _) in expect.fulfill() }
+        eventMonitor.requestDidCreateDestinationURL = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidFinish = { (_) in expect.fulfill() }
+        eventMonitor.requestDidResume = { (_) in expect.fulfill() }
+        eventMonitor.requestDidResumeTask = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidParseDownloadResponse = { (_, _) in expect.fulfill() }
+
+        // When
+        let request = session.download(URLRequest.makeHTTPBinRequest()).response { response in
+            expect.fulfill()
+        }
+
+        waitForExpectations(timeout: timeout, handler: nil)
+
+        // Then
+        XCTAssertEqual(request.state, .resumed)
+    }
+
+    func testThatCancelledDownloadRequestTriggersAllAppropriateLifetimeEvents() {
+        // Given
+        let eventMonitor = ClosureEventMonitor()
+        let session = Session(startRequestsImmediately: false, eventMonitors: [eventMonitor])
+
+        let expect = expectation(description: "request should receive appropriate lifetime events")
+        expect.expectedFulfillmentCount = 12
+
+        eventMonitor.taskDidFinishCollectingMetrics = { (_, _, _) in expect.fulfill() }
+        eventMonitor.requestDidCreateURLRequest = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidCreateTask = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidGatherMetrics = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidCompleteTaskWithError = { (_, _, _) in expect.fulfill() }
+        eventMonitor.requestDidFinish = { (_) in expect.fulfill() }
+        eventMonitor.requestDidResume = { (_) in expect.fulfill() }
+        eventMonitor.requestDidCancel = { _ in expect.fulfill() }
+        eventMonitor.requestDidCancelTask = { _, _ in expect.fulfill() }
+        eventMonitor.requestDidParseDownloadResponse = { (_, _) in expect.fulfill() }
+
+        // When
+        let request = session.download(URLRequest.makeHTTPBinRequest()).response { response in
+            expect.fulfill()
+        }
+
+        eventMonitor.requestDidResumeTask = { (_, _) in
+            request.cancel()
+            expect.fulfill()
+        }
+
+        request.resume()
+
+        waitForExpectations(timeout: timeout, handler: nil)
+
+        // Then
+        XCTAssertEqual(request.state, .cancelled)
+    }
+}
+
 // MARK: -
 
 class DownloadResumeDataTestCase: BaseTestCase {

+ 83 - 7
Tests/RequestTests.swift

@@ -491,15 +491,15 @@ class RequestResponseTestCase: BaseTestCase {
         // Then
         XCTAssertEqual(request.state, .cancelled)
     }
-    
+
     func testThatRequestManuallyCancelledManyTimesOnManyQueuesOnlyReceivesAppropriateLifetimeEvents() {
         // Given
         let eventMonitor = ClosureEventMonitor()
         let session = Session(eventMonitors: [eventMonitor])
-        
+
         let expect = expectation(description: "request should receive appropriate lifetime events")
         expect.expectedFulfillmentCount = 5
-        
+
         eventMonitor.requestDidCancelTask = { (_, _) in expect.fulfill() }
         eventMonitor.requestDidCancel = { _ in expect.fulfill() }
         eventMonitor.requestDidResume = { _ in expect.fulfill() }
@@ -507,20 +507,96 @@ class RequestResponseTestCase: BaseTestCase {
         // Fulfill other events that would exceed the expected count. Inverted expectations require the full timeout.
         eventMonitor.requestDidSuspend = { _ in expect.fulfill() }
         eventMonitor.requestDidSuspendTask = { (_, _) in expect.fulfill() }
-        
+
         // When
         let request = session.request(URLRequest.makeHTTPBinRequest())
         // Cancellation stops task creation, so don't cancel the request until the task has been created.
         eventMonitor.requestDidCreateTask = { (_, _) in
             DispatchQueue.concurrentPerform(iterations: 100) { i in
                 request.cancel()
-                
+
                 if i == 99 { expect.fulfill() }
             }
         }
-        
+
         waitForExpectations(timeout: timeout, handler: nil)
-        
+
+        // Then
+        XCTAssertEqual(request.state, .cancelled)
+    }
+
+    func testThatRequestTriggersAllAppropriateLifetimeEvents() {
+        // Given
+        let eventMonitor = ClosureEventMonitor()
+        let session = Session(eventMonitors: [eventMonitor])
+
+        let expect = expectation(description: "request should receive appropriate lifetime events")
+        expect.expectedFulfillmentCount = 13
+
+        var dataReceived = false
+
+        eventMonitor.taskDidReceiveChallenge = { (_, _, _) in expect.fulfill() }
+        eventMonitor.taskDidFinishCollectingMetrics = { (_, _, _) in expect.fulfill() }
+        eventMonitor.dataTaskDidReceiveData = { (_, _, _) in
+            guard !dataReceived else { return }
+            // Data may be received many times, fulfill only once.
+            dataReceived = true
+            expect.fulfill()
+        }
+        eventMonitor.dataTaskWillCacheResponse = { (_, _, _) in expect.fulfill() }
+        eventMonitor.requestDidCreateURLRequest = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidCreateTask = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidGatherMetrics = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidCompleteTaskWithError = { (_, _, _) in expect.fulfill() }
+        eventMonitor.requestDidFinish = { (_) in expect.fulfill() }
+        eventMonitor.requestDidResume = { (_) in expect.fulfill() }
+        eventMonitor.requestDidResumeTask = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidParseResponse = { (_, _) in expect.fulfill() }
+
+        // When
+        let request = session.request(URLRequest.makeHTTPBinRequest()).response { _ in
+            expect.fulfill()
+        }
+
+        waitForExpectations(timeout: timeout, handler: nil)
+
+        // Then
+        XCTAssertEqual(request.state, .resumed)
+    }
+
+    func testThatCancelledRequestTriggersAllAppropriateLifetimeEvents() {
+        // Given
+        let eventMonitor = ClosureEventMonitor()
+        let session = Session(startRequestsImmediately: false, eventMonitors: [eventMonitor])
+
+        let expect = expectation(description: "request should receive appropriate lifetime events")
+        expect.expectedFulfillmentCount = 12
+
+        eventMonitor.taskDidFinishCollectingMetrics = { (_, _, _) in expect.fulfill() }
+        eventMonitor.requestDidCreateURLRequest = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidCreateTask = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidGatherMetrics = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidCompleteTaskWithError = { (_, _, _) in expect.fulfill() }
+        eventMonitor.requestDidFinish = { (_) in expect.fulfill() }
+        eventMonitor.requestDidResume = { (_) in expect.fulfill() }
+        eventMonitor.requestDidCancel = { _ in expect.fulfill() }
+        eventMonitor.requestDidCancelTask = { _, _ in expect.fulfill() }
+        eventMonitor.requestDidParseResponse = { (_, _) in expect.fulfill() }
+
+        // When
+        let request = session.request(URLRequest.makeHTTPBinRequest()).response { _ in
+            expect.fulfill()
+        }
+
+        eventMonitor.requestDidResumeTask = { (_, _) in
+            request.cancel()
+            expect.fulfill()
+        }
+
+        request.resume()
+
+        waitForExpectations(timeout: timeout, handler: nil)
+
         // Then
         XCTAssertEqual(request.state, .cancelled)
     }

+ 72 - 0
Tests/UploadTests.swift

@@ -648,3 +648,75 @@ class UploadMultipartFormDataTestCase: BaseTestCase {
         }
     }
 }
+
+final class UploadRequestEventsTestCase: BaseTestCase {
+    func testThatUploadRequestTriggersAllAppropriateLifetimeEvents() {
+        // Given
+        let eventMonitor = ClosureEventMonitor()
+        let session = Session(eventMonitors: [eventMonitor])
+
+        let expect = expectation(description: "request should receive appropriate lifetime events")
+        expect.expectedFulfillmentCount = 11
+
+        eventMonitor.taskDidFinishCollectingMetrics = { (_, _, _) in expect.fulfill() }
+        eventMonitor.requestDidCreateURLRequest = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidCreateTask = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidGatherMetrics = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidCompleteTaskWithError = { (_, _, _) in expect.fulfill() }
+        eventMonitor.requestDidFinish = { (_) in expect.fulfill() }
+        eventMonitor.requestDidResume = { (_) in expect.fulfill() }
+        eventMonitor.requestDidResumeTask = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidCreateUploadable = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidParseResponse = { (_, _) in expect.fulfill() }
+
+        // When
+        let request = session.upload(Data("PAYLOAD".utf8),
+                                     with: URLRequest.makeHTTPBinRequest(path: "post", method: .post)).response { _ in
+            expect.fulfill()
+        }
+
+        waitForExpectations(timeout: timeout, handler: nil)
+
+        // Then
+        XCTAssertEqual(request.state, .resumed)
+    }
+
+    func testThatCancelledUploadRequestTriggersAllAppropriateLifetimeEvents() {
+        // Given
+        let eventMonitor = ClosureEventMonitor()
+        let session = Session(startRequestsImmediately: false, eventMonitors: [eventMonitor])
+
+        let expect = expectation(description: "request should receive appropriate lifetime events")
+        expect.expectedFulfillmentCount = 13
+
+        eventMonitor.taskDidFinishCollectingMetrics = { (_, _, _) in expect.fulfill() }
+        eventMonitor.requestDidCreateURLRequest = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidCreateTask = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidGatherMetrics = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidCompleteTaskWithError = { (_, _, _) in expect.fulfill() }
+        eventMonitor.requestDidFinish = { (_) in expect.fulfill() }
+        eventMonitor.requestDidResume = { (_) in expect.fulfill() }
+        eventMonitor.requestDidCreateUploadable = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidParseResponse = { (_, _) in expect.fulfill() }
+        eventMonitor.requestDidCancel = { (_) in expect.fulfill() }
+        eventMonitor.requestDidCancelTask = { (_, _) in expect.fulfill() }
+
+        // When
+        let request = session.upload(Data("PAYLOAD".utf8),
+                                     with: URLRequest.makeHTTPBinRequest(path: "post", method: .post)).response { _ in
+                                        expect.fulfill()
+        }
+
+        eventMonitor.requestDidResumeTask = { (_, _) in
+            request.cancel()
+            expect.fulfill()
+        }
+
+        request.resume()
+
+        waitForExpectations(timeout: timeout, handler: nil)
+
+        // Then
+        XCTAssertEqual(request.state, .cancelled)
+    }
+}