Browse Source

Add completion events to tests.

Jon Shier 2 years ago
parent
commit
75506e8679
2 changed files with 23 additions and 3 deletions
  1. 1 1
      Source/SessionDelegate.swift
  2. 22 2
      Tests/WebSocketTests.swift

+ 1 - 1
Source/SessionDelegate.swift

@@ -46,7 +46,7 @@ open class SessionDelegate: NSObject {
     ///   - type: The `Request` subclass type to cast any `Request` associate with `task`.
     ///   - type: The `Request` subclass type to cast any `Request` associate with `task`.
     func request<R: Request>(for task: URLSessionTask, as type: R.Type) -> R? {
     func request<R: Request>(for task: URLSessionTask, as type: R.Type) -> R? {
         guard let provider = stateProvider else {
         guard let provider = stateProvider else {
-            assertionFailure("StateProvider is nil.")
+            assertionFailure("StateProvider is nil for task \(task.taskIdentifier).")
             return nil
             return nil
         }
         }
 
 

+ 22 - 2
Tests/WebSocketTests.swift

@@ -62,6 +62,7 @@ final class WebSocketTests: BaseTestCase {
     func testThatWebSocketsCanReceiveAMessage() {
     func testThatWebSocketsCanReceiveAMessage() {
         // Given
         // Given
         let didReceiveMessage = expectation(description: "didReceiveMessage")
         let didReceiveMessage = expectation(description: "didReceiveMessage")
+        let didComplete = expectation(description: "didComplete")
         let session = stored(Session())
         let session = stored(Session())
 
 
         var receivedMessage: URLSessionWebSocketTask.Message?
         var receivedMessage: URLSessionWebSocketTask.Message?
@@ -71,8 +72,11 @@ final class WebSocketTests: BaseTestCase {
             receivedMessage = message
             receivedMessage = message
             didReceiveMessage.fulfill()
             didReceiveMessage.fulfill()
         }
         }
+        .onCompletion {
+            didComplete.fulfill()
+        }
 
 
-        waitForExpectations(timeout: timeout)
+        wait(for: [didReceiveMessage, didComplete], timeout: timeout, enforceOrder: true)
 
 
         // Then
         // Then
         XCTAssertNotNil(receivedMessage)
         XCTAssertNotNil(receivedMessage)
@@ -129,6 +133,7 @@ final class WebSocketTests: BaseTestCase {
     func testThatWebSocketsCanReceiveADecodableValue() {
     func testThatWebSocketsCanReceiveADecodableValue() {
         // Given
         // Given
         let didReceiveValue = expectation(description: "didReceiveMessage")
         let didReceiveValue = expectation(description: "didReceiveMessage")
+        let didComplete = expectation(description: "didComplete")
 
 
         let session = stored(Session())
         let session = stored(Session())
 
 
@@ -139,8 +144,11 @@ final class WebSocketTests: BaseTestCase {
             receivedValue = value
             receivedValue = value
             didReceiveValue.fulfill()
             didReceiveValue.fulfill()
         }
         }
+        .onCompletion {
+            didComplete.fulfill()
+        }
 
 
-        waitForExpectations(timeout: timeout)
+        wait(for: [didReceiveValue, didComplete], timeout: timeout, enforceOrder: true)
 
 
         // Then
         // Then
         XCTAssertNotNil(receivedValue)
         XCTAssertNotNil(receivedValue)
@@ -586,6 +594,18 @@ final class WebSocketTests: BaseTestCase {
     }
     }
 }
 }
 
 
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+extension WebSocketRequest {
+    @discardableResult
+    func onCompletion(queue: DispatchQueue = .main, handler: @escaping () -> Void) -> Self {
+        streamMessageEvents(on: queue) { event in
+            guard case .completed = event.kind else { return }
+
+            handler()
+        }
+    }
+}
+
 @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
 @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
 extension URLSessionWebSocketTask.Message: Equatable {
 extension URLSessionWebSocketTask.Message: Equatable {
     public static func ==(lhs: URLSessionWebSocketTask.Message, rhs: URLSessionWebSocketTask.Message) -> Bool {
     public static func ==(lhs: URLSessionWebSocketTask.Message, rhs: URLSessionWebSocketTask.Message) -> Bool {