Browse Source

Refactor to Configuration value.

Jon Shier 2 years ago
parent
commit
121fadc768
5 changed files with 239 additions and 53 deletions
  1. 40 23
      Source/Request.swift
  2. 51 7
      Source/Session.swift
  3. 15 2
      Tests/ConcurrencyTests.swift
  4. 26 8
      Tests/TestHelpers.swift
  5. 107 13
      Tests/WebSocketTests.swift

+ 40 - 23
Source/Request.swift

@@ -1686,6 +1686,32 @@ public final class WebSocketRequest: Request {
         public let error: AFError?
     }
 
+    public struct Configuration {
+        public static var `default`: Self { Self() }
+
+        public static func `protocol`(_ protocol: String) -> Self {
+            Self(protocol: `protocol`)
+        }
+
+        public static func maximumMessageSize(_ maximumMessageSize: Int) -> Self {
+            Self(maximumMessageSize: maximumMessageSize)
+        }
+
+        public static func pingInterval(_ pingInterval: TimeInterval) -> Self {
+            Self(pingInterval: pingInterval)
+        }
+
+        public let `protocol`: String?
+        public let maximumMessageSize: Int
+        public let pingInterval: TimeInterval?
+
+        init(protocol: String? = nil, maximumMessageSize: Int = 1_048_576, pingInterval: TimeInterval? = nil) {
+            self.protocol = `protocol`
+            self.maximumMessageSize = maximumMessageSize
+            self.pingInterval = pingInterval
+        }
+    }
+
     /// Response to a sent ping.
     public enum PingResponse {
         public struct Pong {
@@ -1717,24 +1743,18 @@ public final class WebSocketRequest: Request {
     }
 
     public let convertible: URLRequestConvertible
-    public let `protocol`: String?
-    public let maximumMessageSize: Int
-    public let pingInterval: TimeInterval?
+    public let configuration: Configuration
 
     init(id: UUID = UUID(),
          convertible: URLRequestConvertible,
-         protocol: String? = nil,
-         maximumMessageSize: Int,
-         pingInterval: TimeInterval?,
+         configuration: Configuration,
          underlyingQueue: DispatchQueue,
          serializationQueue: DispatchQueue,
          eventMonitor: EventMonitor?,
          interceptor: RequestInterceptor?,
          delegate: RequestDelegate) {
         self.convertible = convertible
-        self.protocol = `protocol`
-        self.maximumMessageSize = maximumMessageSize
-        self.pingInterval = pingInterval
+        self.configuration = configuration
 
         super.init(id: id,
                    underlyingQueue: underlyingQueue,
@@ -1747,13 +1767,13 @@ public final class WebSocketRequest: Request {
     override func task(for request: URLRequest, using session: URLSession) -> URLSessionTask {
         var copiedRequest = request
         let task: URLSessionWebSocketTask
-        if let `protocol` = `protocol` {
+        if let `protocol` = configuration.protocol {
             copiedRequest.headers.update(.websocketProtocol(`protocol`))
             task = session.webSocketTask(with: copiedRequest)
         } else {
             task = session.webSocketTask(with: copiedRequest)
         }
-        task.maximumMessageSize = maximumMessageSize
+        task.maximumMessageSize = configuration.maximumMessageSize
 
         return task
     }
@@ -1846,7 +1866,7 @@ public final class WebSocketRequest: Request {
             }
         }
 
-        if let pingInterval = pingInterval {
+        if let pingInterval = configuration.pingInterval {
             startAutomaticPing(onInterval: pingInterval)
         }
     }
@@ -1881,6 +1901,7 @@ public final class WebSocketRequest: Request {
     func startAutomaticPing(onInterval pingInterval: TimeInterval) {
         socketMutableState.write { mutableState in
             guard isResumed else {
+                // Defer out of lock.
                 defer { cancelAutomaticPing() }
                 return
             }
@@ -2048,31 +2069,27 @@ public final class WebSocketRequest: Request {
     public func send(_ message: URLSessionWebSocketTask.Message,
                      queue: DispatchQueue = .main,
                      completionHandler: @escaping (Result<Void, Error>) -> Void) {
-        guard !(isCancelled || isFinished) else {
-            // Error for attempting send while cancelled or finished?
-            // Probably just silently ignore.
-            return
-        }
+        guard !(isCancelled || isFinished) else { return }
 
         guard let socket = socket else {
+            // URLSessionWebSocketTask note created yet, enqueue the send.
             socketMutableState.write { mutableState in
                 mutableState.enqueuedSends.append((message, queue, completionHandler))
             }
+
             return
         }
 
-        underlyingQueue.async {
-            socket.send(message) { error in
-                queue.async {
-                    completionHandler(Result(value: (), error: error))
-                }
+        socket.send(message) { error in
+            queue.async {
+                completionHandler(Result(value: (), error: error))
             }
         }
     }
 }
 
 @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
-public protocol WebSocketMessageSerializer {
+public protocol WebSocketMessageSerializer<Output, Failure> {
     associatedtype Output
     associatedtype Failure: Error = Error
 

+ 51 - 7
Source/Session.swift

@@ -463,15 +463,59 @@ open class Session {
 
     #if canImport(Darwin) && !canImport(FoundationNetworking)
     @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
-    open func websocketRequest(to convertible: URLRequestConvertible,
-                               protocol: String? = nil,
-                               maximumMessageSize: Int = 1_048_576,
-                               pingInterval: TimeInterval? = nil,
+    open func webSocketRequest(
+        to url: URLConvertible,
+        configuration: WebSocketRequest.Configuration = .default,
+        headers: HTTPHeaders? = nil,
+        interceptor: RequestInterceptor? = nil,
+        requestModifier: RequestModifier? = nil
+    ) -> WebSocketRequest {
+        webSocketRequest(
+            to: url,
+            configuration: configuration,
+            parameters: Empty?.none,
+            encoder: URLEncodedFormParameterEncoder.default,
+            headers: headers,
+            interceptor: interceptor,
+            requestModifier: requestModifier
+        )
+    }
+
+    @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+    open func webSocketRequest<Parameters>(
+        to url: URLConvertible,
+        configuration: WebSocketRequest.Configuration = .default,
+        parameters: Parameters? = nil,
+        encoder: ParameterEncoder = URLEncodedFormParameterEncoder.default,
+        headers: HTTPHeaders? = nil,
+        interceptor: RequestInterceptor? = nil,
+        requestModifier: RequestModifier? = nil
+    ) -> WebSocketRequest where Parameters: Encodable {
+        let convertible = RequestEncodableConvertible(url: url,
+                                                      method: .get,
+                                                      parameters: parameters,
+                                                      encoder: encoder,
+                                                      headers: headers,
+                                                      requestModifier: requestModifier)
+        let request = WebSocketRequest(convertible: convertible,
+                                       configuration: configuration,
+                                       underlyingQueue: rootQueue,
+                                       serializationQueue: serializationQueue,
+                                       eventMonitor: eventMonitor,
+                                       interceptor: interceptor,
+                                       delegate: self)
+
+        perform(request)
+
+        return request
+    }
+
+    @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+    open func webSocketRequest(performing convertible: URLRequestConvertible,
+                               configuration: WebSocketRequest.Configuration = .default,
                                interceptor: RequestInterceptor? = nil) -> WebSocketRequest {
         let request = WebSocketRequest(convertible: convertible,
-                                       protocol: `protocol`,
-                                       maximumMessageSize: maximumMessageSize,
-                                       pingInterval: pingInterval,
+                                       configuration: configuration,
                                        underlyingQueue: rootQueue,
                                        serializationQueue: serializationQueue,
                                        eventMonitor: eventMonitor,

+ 15 - 2
Tests/ConcurrencyTests.swift

@@ -42,6 +42,19 @@ final class DataRequestConcurrencyTests: BaseTestCase {
         XCTAssertNotNil(value)
     }
 
+    func testThat500ResponseCanBeRetried() async throws {
+        // Given
+        let session = stored(Session())
+
+        // When
+        let value = try await session.request(.endpoints(.status(500), .method(.get)), interceptor: .retryPolicy)
+            .serializingResponse(using: .data)
+            .value
+
+        // Then
+        XCTAssertNotNil(value)
+    }
+
     func testThatDataTaskSerializesDecodable() async throws {
         // Given
         let session = stored(Session())
@@ -756,7 +769,7 @@ final class WebSocketConcurrencyTests: BaseTestCase {
         receivedEvent.expectedFulfillmentCount = 4
 
         // When
-        for await _ in session.websocketRequest(.websocket()).webSocketTask().streamingMessageEvents() {
+        for await _ in session.webSocketRequest(.websocket()).webSocketTask().streamingMessageEvents() {
             receivedEvent.fulfill()
         }
 
@@ -770,7 +783,7 @@ final class WebSocketConcurrencyTests: BaseTestCase {
         let session = stored(Session())
 
         // When
-        let messages = await session.websocketRequest(.websocket()).webSocketTask().streamingMessages().collect()
+        let messages = await session.webSocketRequest(.websocket()).webSocketTask().streamingMessages().collect()
 
         // Then
         XCTAssertTrue(messages.count == 1)

+ 26 - 8
Tests/TestHelpers.swift

@@ -323,6 +323,28 @@ extension Endpoint: URLConvertible {
     }
 }
 
+final class EndpointSequence: URLRequestConvertible {
+    enum Error: Swift.Error { case noRemainingEndpoints }
+
+    private var remainingEndpoints: [Endpoint]
+
+    init(endpoints: [Endpoint]) {
+        remainingEndpoints = endpoints
+    }
+
+    func asURLRequest() throws -> URLRequest {
+        guard !remainingEndpoints.isEmpty else { throw Error.noRemainingEndpoints }
+
+        return try remainingEndpoints.removeFirst().asURLRequest()
+    }
+}
+
+extension URLRequestConvertible where Self == EndpointSequence {
+    static func endpoints(_ endpoints: Endpoint...) -> Self {
+        EndpointSequence(endpoints: endpoints)
+    }
+}
+
 extension Session {
     func request(_ endpoint: Endpoint,
                  parameters: Parameters? = nil,
@@ -381,15 +403,11 @@ extension Session {
 
     #if canImport(Darwin) && !canImport(FoundationNetworking)
     @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
-    func websocketRequest(_ endpoint: Endpoint,
-                          protocol: String? = nil,
-                          maximumMessageSize: Int = 1_048_576,
-                          pingInterval: TimeInterval? = nil,
+    func webSocketRequest(_ endpoint: Endpoint,
+                          configuration: WebSocketRequest.Configuration = .default,
                           interceptor: RequestInterceptor? = nil) -> WebSocketRequest {
-        websocketRequest(to: endpoint as URLRequestConvertible,
-                         protocol: `protocol`,
-                         maximumMessageSize: maximumMessageSize,
-                         pingInterval: pingInterval,
+        webSocketRequest(performing: endpoint as URLRequestConvertible,
+                         configuration: configuration,
                          interceptor: interceptor)
     }
     #endif

+ 107 - 13
Tests/WebSocketTests.swift

@@ -29,7 +29,52 @@ final class WebSocketTests: BaseTestCase {
         var receivedCompletion: WebSocketRequest.Completion?
 
         // When
-        session.websocketRequest(.websocket()).streamMessageEvents { event in
+        session.webSocketRequest(.websocket()).streamMessageEvents { event in
+            switch event.kind {
+            case let .connected(`protocol`):
+                connectedProtocol = `protocol`
+                didConnect.fulfill()
+            case let .receivedMessage(receivedMessage):
+                message = receivedMessage
+                didReceiveMessage.fulfill()
+            case let .disconnected(code, reason):
+                closeCode = code
+                closeReason = reason
+                didDisconnect.fulfill()
+            case let .completed(completion):
+                receivedCompletion = completion
+                didComplete.fulfill()
+            }
+        }
+
+        wait(for: [didConnect, didReceiveMessage, didDisconnect, didComplete],
+             timeout: timeout,
+             enforceOrder: true)
+
+        // Then
+        XCTAssertNil(connectedProtocol)
+        XCTAssertNotNil(message)
+        XCTAssertEqual(closeCode, .normalClosure)
+        XCTAssertNil(closeReason)
+        XCTAssertNil(receivedCompletion?.error)
+    }
+
+    func testThatWebSocketsCanReceiveMessageEventsWithParameters() {
+        // Given
+        let didConnect = expectation(description: "didConnect")
+        let didReceiveMessage = expectation(description: "didReceiveMessage")
+        let didDisconnect = expectation(description: "didDisconnect")
+        let didComplete = expectation(description: "didComplete")
+        let session = stored(Session())
+
+        var connectedProtocol: String?
+        var message: URLSessionWebSocketTask.Message?
+        var closeCode: URLSessionWebSocketTask.CloseCode?
+        var closeReason: Data?
+        var receivedCompletion: WebSocketRequest.Completion?
+
+        // When
+        session.webSocketRequest(.websocket()).streamMessageEvents { event in
             switch event.kind {
             case let .connected(`protocol`):
                 connectedProtocol = `protocol`
@@ -68,7 +113,7 @@ final class WebSocketTests: BaseTestCase {
         var receivedMessage: URLSessionWebSocketTask.Message?
 
         // When
-        session.websocketRequest(.websocket()).streamMessages { message in
+        session.webSocketRequest(.websocket()).streamMessages { message in
             receivedMessage = message
             didReceiveMessage.fulfill()
         }
@@ -98,7 +143,7 @@ final class WebSocketTests: BaseTestCase {
         var receivedCompletion: WebSocketRequest.Completion?
 
         // When
-        session.websocketRequest(.websocketCount(1)).streamDecodableEvents(TestResponse.self) { event in
+        session.webSocketRequest(.websocketCount(1)).streamDecodableEvents(TestResponse.self) { event in
             switch event.kind {
             case let .connected(`protocol`):
                 connectedProtocol = `protocol`
@@ -140,7 +185,7 @@ final class WebSocketTests: BaseTestCase {
         var receivedValue: TestResponse?
 
         // When
-        session.websocketRequest(.websocket()).streamDecodable(TestResponse.self) { value in
+        session.webSocketRequest(.websocket()).streamDecodable(TestResponse.self) { value in
             receivedValue = value
             didReceiveValue.fulfill()
         }
@@ -170,7 +215,7 @@ final class WebSocketTests: BaseTestCase {
         var receivedCompletion: WebSocketRequest.Completion?
 
         // When
-        session.websocketRequest(.websocket(), protocol: `protocol`).streamMessageEvents { event in
+        session.webSocketRequest(.websocket(), configuration: .protocol(`protocol`)).streamMessageEvents { event in
             switch event.kind {
             case let .connected(`protocol`):
                 connectedProtocol = `protocol`
@@ -218,7 +263,7 @@ final class WebSocketTests: BaseTestCase {
         var receivedCompletion: WebSocketRequest.Completion?
 
         // When
-        session.websocketRequest(.websocketCount(count)).streamMessageEvents { event in
+        session.webSocketRequest(.websocketCount(count)).streamMessageEvents { event in
             switch event.kind {
             case let .connected(`protocol`):
                 connectedProtocol = `protocol`
@@ -263,7 +308,7 @@ final class WebSocketTests: BaseTestCase {
         var receivedCompletion: WebSocketRequest.Completion?
 
         // When
-        let request = session.websocketRequest(.websocketEcho)
+        let request = session.webSocketRequest(.websocketEcho)
         request.streamMessageEvents { event in
             switch event.kind {
             case let .connected(`protocol`):
@@ -307,7 +352,7 @@ final class WebSocketTests: BaseTestCase {
         var receivedCompletion: WebSocketRequest.Completion?
 
         // When
-        let request = session.websocketRequest(.websocketEcho)
+        let request = session.webSocketRequest(.websocketEcho)
         request.streamMessageEvents { event in
             switch event.kind {
             case let .connected(`protocol`):
@@ -352,7 +397,7 @@ final class WebSocketTests: BaseTestCase {
         var receivedCompletion: WebSocketRequest.Completion?
 
         // When
-        let request = session.websocketRequest(.websocketEcho)
+        let request = session.webSocketRequest(.websocketEcho)
         request.streamMessageEvents { event in
             switch event.kind {
             case let .connected(`protocol`):
@@ -414,7 +459,7 @@ final class WebSocketTests: BaseTestCase {
         var receivedCompletion: WebSocketRequest.Completion?
 
         // When
-        let request = session.websocketRequest(.websocketPings(), pingInterval: 0.01)
+        let request = session.webSocketRequest(.websocketPings(), configuration: .pingInterval(0.01))
         request.streamMessageEvents { event in
             switch event.kind {
             case let .connected(`protocol`):
@@ -452,7 +497,7 @@ final class WebSocketTests: BaseTestCase {
         var receivedCompletion: WebSocketRequest.Completion?
 
         // When
-        session.websocketRequest(.websocket(), maximumMessageSize: 1).streamMessageEvents { event in
+        session.webSocketRequest(.websocket(), configuration: .maximumMessageSize(1)).streamMessageEvents { event in
             switch event.kind {
             case let .connected(`protocol`):
                 connectedProtocol = `protocol`
@@ -487,7 +532,7 @@ final class WebSocketTests: BaseTestCase {
         var receivedCompletion: WebSocketRequest.Completion?
 
         // When
-        session.websocketRequest(.websocket(closeCode: .goingAway)).streamMessageEvents { event in
+        session.webSocketRequest(.websocket(closeCode: .goingAway)).streamMessageEvents { event in
             switch event.kind {
             case let .connected(`protocol`):
                 connectedProtocol = `protocol`
@@ -541,7 +586,7 @@ final class WebSocketTests: BaseTestCase {
         var secondReceivedCompletion: WebSocketRequest.Completion?
 
         // When
-        session.websocketRequest(.websocket(closeCode: .goingAway)).streamMessageEvents { event in
+        session.webSocketRequest(.websocket(closeCode: .goingAway)).streamMessageEvents { event in
             switch event.kind {
             case let .connected(`protocol`):
                 firstConnectedProtocol = `protocol`
@@ -594,6 +639,55 @@ final class WebSocketTests: BaseTestCase {
     }
 }
 
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+final class WebSocketIntegrationTests: BaseTestCase {
+    func testThatWebSocketsCanReceiveMessageEventsAfterRetry() {
+        // Given
+        let didConnect = expectation(description: "didConnect")
+        let didReceiveMessage = expectation(description: "didReceiveMessage")
+        let didDisconnect = expectation(description: "didDisconnect")
+        let didComplete = expectation(description: "didComplete")
+        let session = stored(Session())
+
+        var connectedProtocol: String?
+        var message: URLSessionWebSocketTask.Message?
+        var closeCode: URLSessionWebSocketTask.CloseCode?
+        var closeReason: Data?
+        var receivedCompletion: WebSocketRequest.Completion?
+
+        // When
+        session.webSocketRequest(performing: .endpoints(.status(501), .websocket()), interceptor: .retryPolicy)
+            .streamMessageEvents { event in
+                switch event.kind {
+                case let .connected(`protocol`):
+                    connectedProtocol = `protocol`
+                    didConnect.fulfill()
+                case let .receivedMessage(receivedMessage):
+                    message = receivedMessage
+                    didReceiveMessage.fulfill()
+                case let .disconnected(code, reason):
+                    closeCode = code
+                    closeReason = reason
+                    didDisconnect.fulfill()
+                case let .completed(completion):
+                    receivedCompletion = completion
+                    didComplete.fulfill()
+                }
+            }
+
+        wait(for: [didConnect, didReceiveMessage, didDisconnect, didComplete],
+             timeout: 100,
+             enforceOrder: true)
+
+        // Then
+        XCTAssertNil(connectedProtocol)
+        XCTAssertNotNil(message)
+        XCTAssertEqual(closeCode, .normalClosure)
+        XCTAssertNil(closeReason)
+        XCTAssertNil(receivedCompletion?.error)
+    }
+}
+
 @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
 extension WebSocketRequest {
     @discardableResult