Browse Source

Support serializers, update naming.

Jon Shier 2 years ago
parent
commit
be0db535d8
3 changed files with 296 additions and 57 deletions
  1. 3 3
      Source/Concurrency.swift
  2. 158 42
      Source/Request.swift
  3. 135 12
      Tests/WebSocketTests.swift

+ 3 - 3
Source/Concurrency.swift

@@ -780,7 +780,7 @@ public struct WebSocketTask {
     ) -> StreamOf<WebSocketRequest.Event<URLSessionWebSocketTask.Message, Never>> {
         createStream(automaticallyCancelling: shouldAutomaticallyCancel,
                      bufferingPolicy: bufferingPolicy) { onEvent in
-            request.responseMessage(on: .streamCompletionQueue(forRequestID: request.id), handler: onEvent)
+            request.streamMessageEvents(on: .streamCompletionQueue(forRequestID: request.id), handler: onEvent)
         }
     }
 
@@ -813,8 +813,8 @@ public struct WebSocketTask {
     }
 
     /// Cancel the underlying `WebSocketRequest`.
-    public func cancel(with closeCode: URLSessionWebSocketTask.CloseCode, reason: Data? = nil) {
-        request.cancel(with: closeCode, reason: reason)
+    public func close(sending closeCode: URLSessionWebSocketTask.CloseCode, reason: Data? = nil) {
+        request.close(sending: closeCode, reason: reason)
     }
 
     public func cancel() {

+ 158 - 42
Source/Request.swift

@@ -1630,6 +1630,7 @@ public final class WebSocketRequest: Request {
         public enum Kind {
             case connected(protocol: String?)
             case receivedMessage(Success)
+            case serializerFailed(Failure)
             // Only received if the server disconnects or we cancel with code, not if we do a simple cancel or error.
             case disconnected(closeCode: URLSessionWebSocketTask.CloseCode, reason: Data?)
             case completed(Completion)
@@ -1649,8 +1650,8 @@ public final class WebSocketRequest: Request {
             self.kind = kind
         }
 
-        public func cancel(with closeCode: URLSessionWebSocketTask.CloseCode, reason: Data?) {
-            socket?.cancel(with: closeCode, reason: reason)
+        public func close(sending closeCode: URLSessionWebSocketTask.CloseCode, reason: Data? = nil) {
+            socket?.close(sending: closeCode, reason: reason)
         }
 
         public func cancel() {
@@ -1795,10 +1796,27 @@ public final class WebSocketRequest: Request {
 //        eventMonitor?.requestDidCancel(self)
 //    }
 
+    func didClose() {
+        dispatchPrecondition(condition: .onQueue(underlyingQueue))
+
+        mutableState.write { mutableState in
+            // Check whether error is cancellation or other websocket closing error.
+            // If so, remove it.
+            // Otherwise keep it.
+            if case let .sessionTaskFailed(error) = mutableState.error, (error as? URLError)?.code == .cancelled {
+                mutableState.error = nil
+            }
+//            mutableState.error = mutableState.error ?? AFError.explicitlyCancelled
+        }
+
+        // TODO: Still issue this event?
+        eventMonitor?.requestDidCancel(self)
+    }
+
     // TODO: Distinguish between cancellation and close behavior?
     // TODO: Reexamine cancellation behavior.
     @discardableResult
-    public func cancel(with closeCode: URLSessionWebSocketTask.CloseCode, reason: Data? = nil) -> Self {
+    public func close(sending closeCode: URLSessionWebSocketTask.CloseCode, reason: Data? = nil) -> Self {
         cancelTimedPing()
 
         mutableState.write { mutableState in
@@ -1806,7 +1824,7 @@ public final class WebSocketRequest: Request {
 
             mutableState.state = .cancelled
 
-            underlyingQueue.async { self.didCancel() }
+            underlyingQueue.async { self.didClose() }
 
             guard let task = mutableState.tasks.last, task.state != .completed else {
                 underlyingQueue.async { self.finish() }
@@ -1815,7 +1833,7 @@ public final class WebSocketRequest: Request {
 
             // Resume to ensure metrics are gathered.
             task.resume()
-            // Cast from state directly, otherwise the lock is recursive.
+            // Cast from state directly, not the property, otherwise the lock is recursive.
             (mutableState.tasks.last as? URLSessionWebSocketTask)?.cancel(with: closeCode, reason: reason)
             underlyingQueue.async { self.didCancelTask(task) }
         }
@@ -1823,6 +1841,13 @@ public final class WebSocketRequest: Request {
         return self
     }
 
+    @discardableResult
+    override public func cancel() -> Self {
+        cancelTimedPing()
+
+        return super.cancel()
+    }
+
     func didConnect(protocol: String?) {
         dispatchPrecondition(condition: .onQueue(underlyingQueue))
 
@@ -1918,6 +1943,7 @@ public final class WebSocketRequest: Request {
                         handler.queue.async { handler.handler(.receivedMessage(message)) }
                     }
                 }
+
                 self.listen(to: task)
             case let .failure(error):
                 NSLog("Receive for task: \(task), didFailWithError: \(error)")
@@ -1926,37 +1952,105 @@ public final class WebSocketRequest: Request {
     }
 
     @discardableResult
-    public func responseMessage(on queue: DispatchQueue = .main,
-                                handler: @escaping (Event<URLSessionWebSocketTask.Message, Never>) -> Void) -> Self {
+    public func streamSerializer<Serializer>(
+        _ serializer: Serializer,
+        on queue: DispatchQueue = .main,
+        handler: @escaping (_ event: Event<Serializer.Output, Serializer.Failure>) -> Void
+    ) -> Self where Serializer: WebSocketMessageSerializer, Serializer.Failure == Error {
+        forIncomingEvent(on: queue) { incomingEvent in
+            let event: Event<Serializer.Output, Serializer.Failure>
+            switch incomingEvent {
+            case let .connected(`protocol`):
+                event = .init(socket: self, kind: .connected(protocol: `protocol`))
+            case let .receivedMessage(message):
+                do {
+                    let serializedMessage = try serializer.decode(message)
+                    event = .init(socket: self, kind: .receivedMessage(serializedMessage))
+                } catch {
+                    event = .init(socket: self, kind: .serializerFailed(error))
+                }
+            case let .disconnected(closeCode, reason):
+                event = .init(socket: self, kind: .disconnected(closeCode: closeCode, reason: reason))
+            case let .completed(completion):
+                event = .init(socket: self, kind: .completed(completion))
+            }
+
+            queue.async { handler(event) }
+        }
+    }
+
+    @discardableResult
+    public func streamDecodableEvents<Value>(
+        _ type: Value.Type = Value.self,
+        on queue: DispatchQueue = .main,
+        using decoder: DataDecoder = JSONDecoder(),
+        handler: @escaping (_ event: Event<Value, Error>) -> Void
+    ) -> Self where Value: Decodable {
+        streamSerializer(DecodableWebSocketMessageDecoder<Value>(decoder: decoder), on: queue, handler: handler)
+    }
+
+    @discardableResult
+    public func streamDecodable<Value>(
+        _ type: Value.Type = Value.self,
+        on queue: DispatchQueue = .main,
+        using decoder: DataDecoder = JSONDecoder(),
+        handler: @escaping (_ value: Value) -> Void
+    ) -> Self where Value: Decodable {
+        streamDecodableEvents(Value.self, on: queue) { event in
+            event.message.map(handler)
+        }
+    }
+
+    @discardableResult
+    public func streamMessageEvents(
+        on queue: DispatchQueue = .main,
+        handler: @escaping (_ event: Event<URLSessionWebSocketTask.Message, Never>) -> Void
+    ) -> Self {
+        forIncomingEvent(on: queue) { incomingEvent in
+            let event: Event<URLSessionWebSocketTask.Message, Never>
+            switch incomingEvent {
+            case let .connected(`protocol`):
+                event = .init(socket: self, kind: .connected(protocol: `protocol`))
+            case let .receivedMessage(message):
+                event = .init(socket: self, kind: .receivedMessage(message))
+            case let .disconnected(closeCode, reason):
+                event = .init(socket: self, kind: .disconnected(closeCode: closeCode, reason: reason))
+            case let .completed(completion):
+                event = .init(socket: self, kind: .completed(completion))
+            }
+
+            queue.async { handler(event) }
+        }
+    }
+
+    @discardableResult
+    public func streamMessages(
+        on queue: DispatchQueue = .main,
+        handler: @escaping (_ message: URLSessionWebSocketTask.Message) -> Void
+    ) -> Self {
+        streamMessageEvents(on: queue) { event in
+            event.message.map(handler)
+        }
+    }
+
+    func forIncomingEvent(on queue: DispatchQueue, handler: @escaping (IncomingEvent) -> Void) -> Self {
         socketMutableState.write { state in
             state.handlers.append((queue: queue, handler: { incomingEvent in
-                queue.async {
-                    switch incomingEvent {
-                    case let .connected(`protocol`):
-                        handler(.init(socket: self, kind: .connected(protocol: `protocol`)))
-                    case let .receivedMessage(message):
-                        // TODO: Call serializers.
-                        handler(.init(socket: self, kind: .receivedMessage(message)))
-                    case let .disconnected(closeCode, reason):
-                        handler(.init(socket: self, kind: .disconnected(closeCode: closeCode, reason: reason)))
-                    case let .completed(completion):
-                        handler(.init(socket: self, kind: .completed(completion)))
-                    }
+                self.serializationQueue.async {
+                    handler(incomingEvent)
                 }
             }))
         }
 
         appendResponseSerializer {
-            self.underlyingQueue.async {
-                self.responseSerializerDidComplete {
-                    self.socketMutableState.read { state in
-                        state.handlers.forEach { handler in
-                            handler.queue.async {
-                                handler.handler(.completed(.init(request: self.request,
-                                                                 response: self.response,
-                                                                 metrics: self.metrics,
-                                                                 error: self.error)))
-                            }
+            self.responseSerializerDidComplete {
+                self.socketMutableState.read { state in
+                    state.handlers.forEach { storedHandler in
+                        storedHandler.queue.async {
+                            storedHandler.handler(.completed(.init(request: self.request,
+                                                                   response: self.response,
+                                                                   metrics: self.metrics,
+                                                                   error: self.error)))
                         }
                     }
                 }
@@ -1966,15 +2060,6 @@ public final class WebSocketRequest: Request {
         return self
     }
 
-//    @discardableResult
-//    public func responseDecodable<Value: Decodable>(on queue: DispatchQueue = .main,
-//                                                    using decoder: WebSocketMessageDecoder,
-//                                                    handler: @escaping (Event<Value, Error>) -> Void) -> Self {
-//        responseMessage(on: queue) { event in
-//            event
-//        }
-//    }
-
     public func send(_ message: URLSessionWebSocketTask.Message,
                      queue: DispatchQueue = .main,
                      completionHandler: @escaping (Result<Void, Error>) -> Void) {
@@ -2002,19 +2087,50 @@ public final class WebSocketRequest: Request {
 }
 
 @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
-protocol WebSocketMessageDecoder {
+public protocol WebSocketMessageSerializer {
     associatedtype Output
+    associatedtype Failure = Error
 
     func decode(_ message: URLSessionWebSocketTask.Message) throws -> Output
 }
 
 @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
-public struct DecodableWebSocketMessageDecoder<Value: Decodable>: WebSocketMessageDecoder {
-    struct UnknownMessage: Error {}
+extension WebSocketMessageSerializer {
+    public static func json<Value>(
+        decoder: JSONDecoder = JSONDecoder()
+    ) -> DecodableWebSocketMessageDecoder<Value> where Self == DecodableWebSocketMessageDecoder<Value> {
+        .json(decoder: decoder)
+    }
+
+    static var passthrough: PassthroughWebSocketMessageDecoder {
+        PassthroughWebSocketMessageDecoder()
+    }
+}
+
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+struct PassthroughWebSocketMessageDecoder: WebSocketMessageSerializer {
+    public typealias Failure = Never
+
+    public func decode(_ message: URLSessionWebSocketTask.Message) -> URLSessionWebSocketTask.Message {
+        message
+    }
+}
+
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+public struct DecodableWebSocketMessageDecoder<Value: Decodable>: WebSocketMessageSerializer {
+    public static func json(decoder: JSONDecoder = JSONDecoder()) -> DecodableWebSocketMessageDecoder<Value> {
+        DecodableWebSocketMessageDecoder(decoder: decoder)
+    }
+
+    public struct UnknownMessage: Error {}
 
     public let decoder: DataDecoder
 
-    func decode(_ message: URLSessionWebSocketTask.Message) throws -> Value {
+    public init(decoder: DataDecoder) {
+        self.decoder = decoder
+    }
+
+    public func decode(_ message: URLSessionWebSocketTask.Message) throws -> Value {
         switch message {
         case let .data(data):
             return try decoder.decode(Value.self, from: data)

+ 135 - 12
Tests/WebSocketTests.swift

@@ -16,7 +16,7 @@ import XCTest
 final class WebSocketTests: BaseTestCase {
 //    override var skipVersion: SkipVersion { .twenty }
 
-    func testThatWebSocketsCanReceiveAMessage() {
+    func testThatWebSocketsCanReceiveMessageEvents() {
         // Given
         let didConnect = expectation(description: "didConnect")
         let didReceiveMessage = expectation(description: "didReceiveMessage")
@@ -31,7 +31,7 @@ final class WebSocketTests: BaseTestCase {
         var receivedCompletion: WebSocketRequest.Completion?
 
         // When
-        session.websocketRequest(.websocket()).responseMessage { event in
+        session.websocketRequest(.websocket()).streamMessageEvents { event in
             switch event.kind {
             case let .connected(`protocol`):
                 connectedProtocol = `protocol`
@@ -61,6 +61,94 @@ final class WebSocketTests: BaseTestCase {
         XCTAssertNil(receivedCompletion?.error)
     }
 
+    func testThatWebSocketsCanReceiveAMessage() {
+        // Given
+        let didReceiveMessage = expectation(description: "didReceiveMessage")
+
+        let session = stored(Session())
+
+        var receivedMessage: URLSessionWebSocketTask.Message?
+
+        // When
+        session.websocketRequest(.websocket()).streamMessages { message in
+            receivedMessage = message
+            didReceiveMessage.fulfill()
+        }
+
+        waitForExpectations(timeout: timeout)
+
+        // Then
+        XCTAssertNotNil(receivedMessage)
+        XCTAssertNotNil(receivedMessage?.data)
+    }
+
+    func testThatWebSocketsCanReceiveADecodableMessage() {
+        // Given
+        let didConnect = expectation(description: "didConnect")
+        let didReceiveMessage = expectation(description: "didReceiveMessage")
+        let didDisconnect = expectation(description: "didDisconnect")
+        let didComplete = expectation(description: "didComplete")
+        let session = stored(Session())
+
+        var connectedProtocol: String?
+        var message: TestResponse?
+        var closeCode: URLSessionWebSocketTask.CloseCode?
+        var closeReason: Data?
+        var receivedCompletion: WebSocketRequest.Completion?
+
+        // When
+        session.websocketRequest(.websocketCount(1)).streamDecodableEvents(TestResponse.self) { event in
+            switch event.kind {
+            case let .connected(`protocol`):
+                connectedProtocol = `protocol`
+                didConnect.fulfill()
+            case let .receivedMessage(receivedMessage):
+                message = receivedMessage
+                didReceiveMessage.fulfill()
+            case let .serializerFailed(error):
+                XCTFail("websocket message serialization failed with error: \(error)")
+            case let .disconnected(code, reason):
+                closeCode = code
+                closeReason = reason
+                didDisconnect.fulfill()
+            case let .completed(completion):
+                receivedCompletion = completion
+                didComplete.fulfill()
+            }
+        }
+
+        wait(for: [didConnect, didReceiveMessage, didDisconnect, didComplete],
+             timeout: timeout,
+             enforceOrder: false)
+
+        // Then
+        XCTAssertNil(connectedProtocol)
+        XCTAssertNotNil(message)
+        XCTAssertEqual(closeCode, .normalClosure)
+        XCTAssertNil(closeReason)
+        XCTAssertNil(receivedCompletion?.error)
+    }
+
+    func testThatWebSocketsCanReceiveADecodableValue() {
+        // Given
+        let didReceiveValue = expectation(description: "didReceiveMessage")
+
+        let session = stored(Session())
+
+        var receivedValue: TestResponse?
+
+        // When
+        session.websocketRequest(.websocket()).streamDecodable(TestResponse.self) { value in
+            receivedValue = value
+            didReceiveValue.fulfill()
+        }
+
+        waitForExpectations(timeout: timeout)
+
+        // Then
+        XCTAssertNotNil(receivedValue)
+    }
+
     func testThatWebSocketsCanReceiveAMessageWithAProtocol() {
         // Given
         let didConnect = expectation(description: "didConnect")
@@ -77,7 +165,7 @@ final class WebSocketTests: BaseTestCase {
         var receivedCompletion: WebSocketRequest.Completion?
 
         // When
-        session.websocketRequest(.websocket(), protocol: `protocol`).responseMessage { event in
+        session.websocketRequest(.websocket(), protocol: `protocol`).streamMessageEvents { event in
             switch event.kind {
             case let .connected(`protocol`):
                 connectedProtocol = `protocol`
@@ -125,7 +213,7 @@ final class WebSocketTests: BaseTestCase {
         var receivedCompletion: WebSocketRequest.Completion?
 
         // When
-        session.websocketRequest(.websocketCount(count)).responseMessage { event in
+        session.websocketRequest(.websocketCount(count)).streamMessageEvents { event in
             switch event.kind {
             case let .connected(`protocol`):
                 connectedProtocol = `protocol`
@@ -171,7 +259,7 @@ final class WebSocketTests: BaseTestCase {
 
         // When
         let request = session.websocketRequest(.websocketEcho)
-        request.responseMessage { event in
+        request.streamMessageEvents { event in
             switch event.kind {
             case let .connected(`protocol`):
                 connectedProtocol = `protocol`
@@ -179,7 +267,7 @@ final class WebSocketTests: BaseTestCase {
                 request.send(sentMessage) { _ in didSend.fulfill() }
             case let .receivedMessage(receivedMessage):
                 message = receivedMessage
-                event.cancel(with: .normalClosure, reason: nil)
+                event.close(sending: .normalClosure)
                 didReceiveMessage.fulfill()
             case let .disconnected(code, reason):
                 closeCode = code
@@ -201,7 +289,42 @@ final class WebSocketTests: BaseTestCase {
         XCTAssertEqual(sentMessage, message)
         XCTAssertEqual(closeCode, .normalClosure)
         XCTAssertNil(closeReason)
-//        XCTAssertNil(receivedCompletion?.error)
+        XCTAssertNil(receivedCompletion?.error)
+    }
+
+    func testThatWebSocketsCanBeCancelled() {
+        // Given
+        let didConnect = expectation(description: "didConnect")
+        let didComplete = expectation(description: "didComplete")
+        let session = stored(Session())
+
+        var connectedProtocol: String?
+        var receivedCompletion: WebSocketRequest.Completion?
+
+        // When
+        let request = session.websocketRequest(.websocketEcho)
+        request.streamMessageEvents { event in
+            switch event.kind {
+            case let .connected(`protocol`):
+                connectedProtocol = `protocol`
+                didConnect.fulfill()
+                request.cancel()
+            case let .receivedMessage(receivedMessage):
+                XCTFail("cancelled socket received message: \(receivedMessage)")
+            case .disconnected:
+                XCTFail("cancelled socket shouldn't receive disconnected event")
+            case let .completed(completion):
+                receivedCompletion = completion
+                didComplete.fulfill()
+            }
+        }
+
+        wait(for: [didConnect, didComplete], timeout: timeout, enforceOrder: true)
+
+        // Then
+        XCTAssertNil(connectedProtocol)
+        XCTAssertTrue(receivedCompletion?.error?.isExplicitlyCancelledError == true)
+        XCTAssertTrue(request.error?.isExplicitlyCancelledError == true)
     }
 
     func testOnePingOnly() {
@@ -225,7 +348,7 @@ final class WebSocketTests: BaseTestCase {
 
         // When
         let request = session.websocketRequest(.websocketEcho)
-        request.responseMessage { event in
+        request.streamMessageEvents { event in
             switch event.kind {
             case let .connected(`protocol`):
                 connectedProtocol = `protocol`
@@ -244,7 +367,7 @@ final class WebSocketTests: BaseTestCase {
                         }
                         didReceivePong.fulfill()
                         if count == 99 {
-                            request.cancel(with: .normalClosure, reason: nil)
+                            request.close(sending: .normalClosure)
                         }
                     }
                 }
@@ -287,7 +410,7 @@ final class WebSocketTests: BaseTestCase {
 
         // When
         let request = session.websocketRequest(.websocketPings(), pingInterval: 0.01)
-        request.responseMessage { event in
+        request.streamMessageEvents { event in
             switch event.kind {
             case let .connected(`protocol`):
                 connectedProtocol = `protocol`
@@ -324,7 +447,7 @@ final class WebSocketTests: BaseTestCase {
         var receivedCompletion: WebSocketRequest.Completion?
 
         // When
-        session.websocketRequest(.websocket(), maximumMessageSize: 1).responseMessage { event in
+        session.websocketRequest(.websocket(), maximumMessageSize: 1).streamMessageEvents { event in
             switch event.kind {
             case let .connected(`protocol`):
                 connectedProtocol = `protocol`
@@ -359,7 +482,7 @@ final class WebSocketTests: BaseTestCase {
         var receivedCompletion: WebSocketRequest.Completion?
 
         // When
-        session.websocketRequest(.websocket(closeCode: .goingAway)).responseMessage { event in
+        session.websocketRequest(.websocket(closeCode: .goingAway)).streamMessageEvents { event in
             switch event.kind {
             case let .connected(`protocol`):
                 connectedProtocol = `protocol`