Переглянути джерело

Lock Request state access using internal API. (#2814)

Jon Shier 6 роки тому
батько
коміт
9bade216af
4 змінених файлів з 89 додано та 105 видалено
  1. 7 0
      Source/Protector.swift
  2. 63 22
      Source/Request.swift
  3. 17 79
      Source/Session.swift
  4. 2 4
      Tests/DownloadTests.swift

+ 7 - 0
Source/Protector.swift

@@ -143,4 +143,11 @@ extension Protector where T == Request.MutableState {
             return true
         }
     }
+    
+    /// Perform a closure while locked with the provided `Request.State`.
+    ///
+    /// - Parameter perform: The closure to perform while locked.
+    func withState(perform: (Request.State) -> Void) {
+        lock.around { perform(value.state) }
+    }
 }

+ 63 - 22
Source/Request.swift

@@ -105,10 +105,7 @@ public class Request {
     fileprivate let protectedMutableState: Protector<MutableState> = Protector(MutableState())
 
     /// `State` of the `Request`.
-    public fileprivate(set) var state: State {
-        get { return protectedMutableState.directValue.state }
-        set { protectedMutableState.write { $0.state = newValue } }
-    }
+    public var state: State { return protectedMutableState.directValue.state }
     /// Returns whether `state` is `.cancelled`.
     public var isCancelled: Bool { return state == .cancelled }
     /// Returns whether `state is `.resumed`.
@@ -461,6 +458,13 @@ public class Request {
 
         uploadProgressHandler?.queue.async { self.uploadProgressHandler?.handler(self.uploadProgress) }
     }
+    
+    /// Perform a closure on the current `state` while locked.
+    ///
+    /// - Parameter perform: The closure to perform.
+    func withState(perform: (State) -> Void) {
+        protectedMutableState.withState(perform: perform)
+    }
 
     // MARK: Task Creation
 
@@ -480,9 +484,21 @@ public class Request {
     /// - Returns: The `Request`.
     @discardableResult
     public func cancel() -> Self {
-        guard protectedMutableState.attemptToTransitionTo(.cancelled) else { return self }
-
-        delegate?.cancelRequest(self)
+        protectedMutableState.write { (mutableState) in
+            guard mutableState.state.canTransitionTo(.cancelled) else { return }
+            
+            mutableState.state = .cancelled
+            
+            underlyingQueue.async { self.didCancel() }
+            
+            guard let task = mutableState.tasks.last, task.state != .completed else {
+                underlyingQueue.async { self.finish() }
+                return
+            }
+            
+            task.cancel()
+            underlyingQueue.async { self.didCancelTask(task) }
+        }
 
         return self
     }
@@ -492,9 +508,18 @@ public class Request {
     /// - Returns: The `Request`.
     @discardableResult
     public func suspend() -> Self {
-        guard protectedMutableState.attemptToTransitionTo(.suspended) else { return self }
-
-        delegate?.suspendRequest(self)
+        protectedMutableState.write { (mutableState) in
+            guard mutableState.state.canTransitionTo(.suspended) else { return }
+            
+            mutableState.state = .suspended
+            
+            underlyingQueue.async { self.didSuspend() }
+            
+            guard let task = mutableState.tasks.last, task.state != .completed else { return }
+            
+            task.suspend()
+            underlyingQueue.async { self.didSuspendTask(task) }
+        }
 
         return self
     }
@@ -505,9 +530,18 @@ public class Request {
     /// - Returns: The `Request`.
     @discardableResult
     public func resume() -> Self {
-        guard protectedMutableState.attemptToTransitionTo(.resumed) else { return self }
-
-        delegate?.resumeRequest(self)
+        protectedMutableState.write { (mutableState) in
+            guard mutableState.state.canTransitionTo(.resumed) else { return }
+            
+            mutableState.state = .resumed
+            
+            underlyingQueue.async { self.didResume() }
+            
+            guard let task = mutableState.tasks.last, task.state != .completed else { return }
+            
+            task.resume()
+            underlyingQueue.async { self.didResumeTask(task) }
+        }
 
         return self
     }
@@ -724,11 +758,6 @@ public protocol RequestDelegate: AnyObject {
 
     func retryResult(for request: Request, dueTo error: Error, completion: @escaping (RetryResult) -> Void)
     func retryRequest(_ request: Request, withDelay timeDelay: TimeInterval?)
-
-    func cancelRequest(_ request: Request)
-    func cancelDownloadRequest(_ request: DownloadRequest, byProducingResumeData: @escaping (Data?) -> Void)
-    func suspendRequest(_ request: Request)
-    func resumeRequest(_ request: Request)
 }
 
 // MARK: - Subclasses
@@ -950,10 +979,22 @@ public class DownloadRequest: Request {
 
     @discardableResult
     public override func cancel() -> Self {
-        guard protectedMutableState.attemptToTransitionTo(.cancelled) else { return self }
-
-        delegate?.cancelDownloadRequest(self) { (resumeData) in
-            self.protectedDownloadMutableState.write { $0.resumeData = resumeData }
+        protectedMutableState.write { (mutableState) in
+            guard mutableState.state.canTransitionTo(.cancelled) else { return }
+            
+            mutableState.state = .cancelled
+            
+            underlyingQueue.async { self.didCancel() }
+            
+            guard let task = mutableState.tasks.last as? URLSessionDownloadTask, task.state != .completed else {
+                underlyingQueue.async { self.finish() }
+                return
+            }
+            
+            task.cancel { (resumeData) in
+                self.protectedDownloadMutableState.write { $0.resumeData = resumeData }
+                self.underlyingQueue.async { self.didCancelTask(task) }
+            }
         }
 
         return self

+ 17 - 79
Source/Session.swift

@@ -469,21 +469,23 @@ open class Session {
     }
 
     func updateStatesForTask(_ task: URLSessionTask, request: Request) {
-        switch (startRequestsImmediately, request.state) {
-        case (true, .initialized):
-            request.resume()
-        case (false, .initialized):
-            // Do nothing.
-            break
-        case (_, .resumed):
-            task.resume()
-            request.didResumeTask(task)
-        case (_, .suspended):
-            task.suspend()
-            request.didSuspendTask(task)
-        case (_, .cancelled):
-            task.cancel()
-            request.didCancelTask(task)
+        request.withState { (state) in
+            switch (startRequestsImmediately, state) {
+            case (true, .initialized):
+                rootQueue.async { request.resume() }
+            case (false, .initialized):
+                // Do nothing.
+                break
+            case (_, .resumed):
+                task.resume()
+                rootQueue.async { request.didResumeTask(task) }
+            case (_, .suspended):
+                task.suspend()
+                rootQueue.async { request.didSuspendTask(task) }
+            case (_, .cancelled):
+                task.cancel()
+                rootQueue.async { request.didCancelTask(task) }
+            }
         }
     }
 
@@ -551,70 +553,6 @@ extension Session: RequestDelegate {
             }
         }
     }
-
-    public func cancelRequest(_ request: Request) {
-        rootQueue.async {
-            request.didCancel()
-
-            // Cancellation only has an effect on running or suspended tasks.
-            guard let task = self.requestTaskMap[request], [.running, .suspended].contains(task.state) else {
-                request.finish()
-                return
-            }
-
-            task.cancel()
-            request.didCancelTask(task)
-        }
-    }
-
-    public func cancelDownloadRequest(_ request: DownloadRequest, byProducingResumeData: @escaping (Data?) -> Void) {
-        rootQueue.async {
-            request.didCancel()
-
-            // Cancellation only has an effect on running or suspended tasks.
-            guard
-                let downloadTask = self.requestTaskMap[request] as? URLSessionDownloadTask,
-                [.running, .suspended].contains(downloadTask.state) else {
-                request.finish()
-                return
-            }
-
-            downloadTask.cancel { (data) in
-                self.rootQueue.async {
-                    byProducingResumeData(data)
-                    request.didCancelTask(downloadTask)
-                }
-            }
-        }
-    }
-
-    public func suspendRequest(_ request: Request) {
-        rootQueue.async {
-            guard !request.isCancelled else { return }
-
-            request.didSuspend()
-
-            // Tasks can only be suspended if they're running.
-            guard let task = self.requestTaskMap[request], task.state == .running else { return }
-
-            task.suspend()
-            request.didSuspendTask(task)
-        }
-    }
-
-    public func resumeRequest(_ request: Request) {
-        rootQueue.async {
-            guard !request.isCancelled else { return }
-
-            request.didResume()
-
-            // Tasks can only be resumed if they're suspended.
-            guard let task = self.requestTaskMap[request], task.state == .suspended else { return }
-
-            task.resume()
-            request.didResumeTask(task)
-        }
-    }
 }
 
 // MARK: - SessionStateProvider

+ 2 - 4
Tests/DownloadTests.swift

@@ -571,10 +571,8 @@ class DownloadResumeDataTestCase: BaseTestCase {
 
         var progressValues: [Double] = []
         var response2: DownloadResponse<Data>?
-        let destination = DownloadRequest.suggestedDownloadDestination(options: [.removePreviousFile, .createIntermediateDirectories])
-        // TODO: Added destination because temp file was being deleted very quickly.
-        AF.download(resumingWith: resumeData,
-                           to: destination)
+
+        AF.download(resumingWith: resumeData)
             .downloadProgress { progress in
                 progressValues.append(progress.fractionCompleted)
             }