Browse Source

Try cast to UploadRequest before DataRequest. (#3127)

Jon Shier 5 years ago
parent
commit
1bc958baee
4 changed files with 61 additions and 1 deletions
  1. 7 0
      Source/Request.swift
  2. 2 1
      Source/Session.swift
  3. 21 0
      Tests/RequestInterceptorTests.swift
  4. 31 0
      Tests/UploadTests.swift

+ 7 - 0
Source/Request.swift

@@ -1457,6 +1457,13 @@ public class UploadRequest: DataRequest {
         }
     }
 
+    override func reset() {
+        // Uploadable must be recreated on every retry.
+        uploadable = nil
+
+        super.reset()
+    }
+
     /// Produces the `InputStream` from `uploadable`, if it can.
     ///
     /// - Note: Calling this method with a non-`.stream` `Uploadable` is a logic error and will crash.

+ 2 - 1
Source/Session.swift

@@ -813,9 +813,10 @@ open class Session {
     ///
     /// - Parameter request: The `Request` to perform.
     func perform(_ request: Request) {
+        // Leaf types must come first, otherwise they will cast as their superclass.
         switch request {
+        case let r as UploadRequest: perform(r) // UploadRequest must come before DataRequest due to subtype relationship.
         case let r as DataRequest: perform(r)
-        case let r as UploadRequest: perform(r)
         case let r as DownloadRequest: perform(r)
         default: fatalError("Attempted to perform unsupported Request subclass: \(type(of: request))")
         }

+ 21 - 0
Tests/RequestInterceptorTests.swift

@@ -559,6 +559,27 @@ final class InspectorInterceptor<Interceptor: RequestInterceptor>: RequestInterc
     }
 }
 
+/// Retry a request once, allowing the second to succeed using the method path.
+final class SingleRetrier: RequestInterceptor {
+    private var hasRetried = false
+
+    func adapt(_ urlRequest: URLRequest, for session: Session, completion: @escaping (Result<URLRequest, Error>) -> Void) {
+        if hasRetried {
+            var request = URLRequest.makeHTTPBinRequest(path: "\(urlRequest.httpMethod?.lowercased() ?? "get")")
+            request.method = urlRequest.method
+            request.headers = urlRequest.headers
+            completion(.success(request))
+        } else {
+            completion(.success(urlRequest))
+        }
+    }
+
+    func retry(_ request: Request, for session: Session, dueTo error: Error, completion: @escaping (RetryResult) -> Void) {
+        completion(hasRetried ? .doNotRetry : .retry)
+        hasRetried = true
+    }
+}
+
 extension RetryResult: Equatable {
     public static func ==(lhs: RetryResult, rhs: RetryResult) -> Bool {
         switch (lhs, rhs) {

+ 31 - 0
Tests/UploadTests.swift

@@ -641,6 +641,37 @@ class UploadMultipartFormDataTestCase: BaseTestCase {
     }
 }
 
+final class UploadRetryTests: BaseTestCase {
+    func testThatDataUploadRetriesCorrectly() {
+        // Given
+        let request = URLRequest.makeHTTPBinRequest(path: "delay/1",
+                                                    method: .post,
+                                                    headers: [.contentType("text/plain")],
+                                                    timeout: 0.1)
+        let retrier = InspectorInterceptor(SingleRetrier())
+        let didRetry = expectation(description: "request did retry")
+        retrier.onRetry = { _ in didRetry.fulfill() }
+        let session = Session(interceptor: retrier)
+        let body = "body"
+        let data = Data(body.utf8)
+        var response: AFDataResponse<HTTPBinResponse>?
+        let completion = expectation(description: "upload should complete")
+
+        // When
+        session.upload(data, with: request).responseDecodable(of: HTTPBinResponse.self) {
+            response = $0
+            completion.fulfill()
+        }
+
+        waitForExpectations(timeout: timeout)
+
+        // Then
+        XCTAssertEqual(retrier.retryCalledCount, 1)
+        XCTAssertTrue(response?.result.isSuccess == true)
+        XCTAssertEqual(response?.value?.data, body)
+    }
+}
+
 final class UploadRequestEventsTestCase: BaseTestCase {
     func testThatUploadRequestTriggersAllAppropriateLifetimeEvents() {
         // Given