Browse Source

Retain the client call in async calls until we have received a response. (#631)

* Retain the client call in async calls until we have received a response.

This means that simply releasing the client will no longer cancel any in-flight requests; instead, the user needs to explicitly call client.channel.shutdown() for that to happen.

* Fix the test.
Daniel Alm 6 years ago
parent
commit
99fc4a8d6b

+ 2 - 0
Sources/SwiftGRPC/Runtime/ClientCall.swift

@@ -26,10 +26,12 @@ open class ClientCallBase {
   open class var method: String { fatalError("needs to be overridden") }
 
   public let call: Call
+  internal let channel: Channel  // To ensure that the channel is not deallocated before this call has finished.
 
   /// Create a call.
   public init(_ channel: Channel) throws {
     self.call = try channel.makeCall(type(of: self).method)
+    self.channel = channel
   }
 }
 

+ 6 - 3
Sources/SwiftGRPC/Runtime/ClientCallBidirectionalStreaming.swift

@@ -30,9 +30,12 @@ open class ClientCallBidirectionalStreamingBase<InputType: Message, OutputType:
   public typealias SentType = InputType
   
   /// Call this to start a call. Nonblocking.
-  public func start(metadata: Metadata, completion: ((CallResult) -> Void)?)
-    throws -> Self {
-    try call.start(.bidiStreaming, metadata: metadata, completion: completion)
+  public func start(metadata: Metadata, completion: ((CallResult) -> Void)?) throws -> Self {
+    try call.start(.bidiStreaming, metadata: metadata) { result in
+      withExtendedLifetime(self) {  // retain `self` (and, transitively, the channel) until the call has finished.
+        completion?(result)
+      }
+    }
     return self
   }
 

+ 5 - 1
Sources/SwiftGRPC/Runtime/ClientCallClientStreaming.swift

@@ -30,7 +30,11 @@ open class ClientCallClientStreamingBase<InputType: Message, OutputType: Message
   
   /// Call this to start a call. Nonblocking.
   public func start(metadata: Metadata, completion: ((CallResult) -> Void)?) throws -> Self {
-    try call.start(.clientStreaming, metadata: metadata, completion: completion)
+    try call.start(.clientStreaming, metadata: metadata) { result in
+      withExtendedLifetime(self) {  // retain `self` (and, transitively, the channel) until the call has finished.
+        completion?(result)
+      }
+    }
     return self
   }
 

+ 5 - 4
Sources/SwiftGRPC/Runtime/ClientCallServerStreaming.swift

@@ -29,10 +29,11 @@ open class ClientCallServerStreamingBase<InputType: Message, OutputType: Message
   /// Call this once with the message to send. Nonblocking.
   public func start(request: InputType, metadata: Metadata, completion: ((CallResult) -> Void)?) throws -> Self {
     let requestData = try request.serializedData()
-    try call.start(.serverStreaming,
-                   metadata: metadata,
-                   message: requestData,
-                   completion: completion)
+    try call.start(.serverStreaming, metadata: metadata, message: requestData) { result in
+      withExtendedLifetime(self) {  // retain `self` (and, transitively, the channel) until the call has finished.
+        completion?(result)
+      }
+    }
     return self
   }
 }

+ 6 - 4
Sources/SwiftGRPC/Runtime/ClientCallUnary.swift

@@ -47,10 +47,12 @@ open class ClientCallUnaryBase<InputType: Message, OutputType: Message>: ClientC
                     completion: @escaping ((OutputType?, CallResult) -> Void)) throws -> Self {
     let requestData = try request.serializedData()
     try call.start(.unary, metadata: metadata, message: requestData) { callResult in
-      if let responseData = callResult.resultData {
-        completion(try? OutputType(serializedData: responseData), callResult)
-      } else {
-        completion(nil, callResult)
+      withExtendedLifetime(self) {  // retain `self` (and, transitively, the channel) until the call has finished.
+        if let responseData = callResult.resultData {
+          completion(try? OutputType(serializedData: responseData), callResult)
+        } else {
+          completion(nil, callResult)
+        }
       }
     }
     return self

+ 1 - 0
Tests/LinuxMain.swift

@@ -20,6 +20,7 @@ import XCTest
 XCTMain([
   // SwiftGRPC
   testCase(gRPCTests.allTests),
+  testCase(AsyncClientTests.allTests),
   testCase(ChannelArgumentTests.allTests),
   testCase(ChannelConnectivityTests.allTests),
   testCase(ChannelShutdownTests.allTests),

+ 95 - 0
Tests/SwiftGRPCTests/AsyncClientTests.swift

@@ -0,0 +1,95 @@
+/*
+ * Copyright 2018, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import Dispatch
+import Foundation
+@testable import SwiftGRPC
+import XCTest
+
+class AsyncClientTests: BasicEchoTestCase {
+  // Using `TimingOutEchoProvider` gives us enough time to release the client before we can expect a result.
+  override func makeProvider() -> Echo_EchoProvider { return TimingOutEchoProvider() }
+
+  static var allTests: [(String, (AsyncClientTests) -> () throws -> Void)] {
+    return [
+      ("testAsyncUnaryRetainsClientUntilCallFinished", testAsyncUnaryRetainsClientUntilCallFinished),
+      ("testClientStreamingRetainsClientUntilCallFinished", testClientStreamingRetainsClientUntilCallFinished),
+      ("testServerStreamingRetainsClientUntilCallFinished", testServerStreamingRetainsClientUntilCallFinished),
+      ("testBidiStreamingRetainsClientUntilCallFinished", testBidiStreamingRetainsClientUntilCallFinished),
+    ]
+  }
+}
+
+extension AsyncClientTests {
+  func testAsyncUnaryRetainsClientUntilCallFinished() {
+    let completionHandlerExpectation = expectation(description: "completion handler called")
+    _ = try! client.get(Echo_EchoRequest(text: "foo")) { response, result in
+      XCTAssertEqual("", response?.text)
+      XCTAssertEqual(.ok, result.statusCode)
+      completionHandlerExpectation.fulfill()
+    }
+    // The call should complete even when the client and call are not retained.
+    client = nil
+
+    waitForExpectations(timeout: 1.0)
+  }
+
+  func testClientStreamingRetainsClientUntilCallFinished() {
+    let callCompletionHandlerExpectation = expectation(description: "call completion handler called")
+    var call: Echo_EchoCollectCall? = try! client.collect { result in
+      XCTAssertEqual(.ok, result.statusCode)
+      callCompletionHandlerExpectation.fulfill()
+    }
+    let responseCompletionHandlerExpectation = expectation(description: "response completion handler called")
+    try! call!.closeAndReceive { response in
+      XCTAssertEqual("", response.result?.text)
+      responseCompletionHandlerExpectation.fulfill()
+    }
+    call = nil
+    // The call should complete even when the client and call are not retained.
+    client = nil
+
+    waitForExpectations(timeout: 1.0)
+  }
+
+  func testServerStreamingRetainsClientUntilCallFinished() {
+    let callCompletionHandlerExpectation = expectation(description: "call completion handler called")
+    _ = try! client.expand(.init()) { result in
+      XCTAssertEqual(.ok, result.statusCode)
+      callCompletionHandlerExpectation.fulfill()
+    }
+    // The call should complete even when the client and call are not retained.
+    client = nil
+
+    waitForExpectations(timeout: 1.0)
+  }
+
+  func testBidiStreamingRetainsClientUntilCallFinished() {
+    let callCompletionHandlerExpectation = expectation(description: "call completion handler called")
+    var call: Echo_EchoUpdateCall? = try! client.update { result in
+      XCTAssertEqual(.ok, result.statusCode)
+      callCompletionHandlerExpectation.fulfill()
+    }
+    let closeSendCompletionHandlerExpectation = expectation(description: "closeSend completion handler called")
+    try! call!.closeSend {
+      closeSendCompletionHandlerExpectation.fulfill()
+    }
+    call = nil
+    // The call should complete even when the client and call are not retained.
+    client = nil
+
+    waitForExpectations(timeout: 1.0)
+  }
+}

+ 1 - 1
Tests/SwiftGRPCTests/ServerTimeoutTests.swift

@@ -18,7 +18,7 @@ import Foundation
 @testable import SwiftGRPC
 import XCTest
 
-fileprivate class TimingOutEchoProvider: Echo_EchoProvider {
+class TimingOutEchoProvider: Echo_EchoProvider {
   func get(request: Echo_EchoRequest, session _: Echo_EchoGetSession) throws -> Echo_EchoResponse {
     Thread.sleep(forTimeInterval: 0.1)
     return Echo_EchoResponse()