Jelajahi Sumber

Make trying to perform an operation on a shut-down completion queue throw an error, and add tests for that.

Daniel Alm 7 tahun lalu
induk
melakukan
abb9cf2dc2

+ 9 - 6
Sources/SwiftGRPC/Core/Call.swift

@@ -84,12 +84,15 @@ public class Call {
   /// - Returns: the result of initiating the call
   /// - Throws: `CallError` if fails to call.
   func perform(_ operations: OperationGroup) throws {
-    completionQueue.register(operations)
-    Call.callMutex.lock()
-    let error = cgrpc_call_perform(underlyingCall, operations.underlyingOperations, operations.tag)
-    Call.callMutex.unlock()
-    if error != GRPC_CALL_OK {
-      throw CallError.callError(grpcCallError: error)
+    try completionQueue.register(operations) {
+      Call.callMutex.lock()
+      // We need to do the perform *inside* the `completionQueue.register` call, to ensure that the queue can't get
+      // shutdown in between registering the operation group and calling `cgrpc_call_perform`.
+      let error = cgrpc_call_perform(underlyingCall, operations.underlyingOperations, operations.tag)
+      Call.callMutex.unlock()
+      if error != GRPC_CALL_OK {
+        throw CallError.callError(grpcCallError: error)
+      }
     }
   }
 

+ 2 - 0
Sources/SwiftGRPC/Core/CallError.swift

@@ -36,6 +36,8 @@ public enum CallError: Error {
   case batchTooBig
   case payloadTypeMismatch
   
+  case completionQueueShutdown
+  
   static func callError(grpcCallError error: grpc_call_error) -> CallError {
     switch error {
     case GRPC_CALL_OK:

+ 15 - 14
Sources/SwiftGRPC/Core/CompletionQueue.swift

@@ -79,6 +79,10 @@ class CompletionQueue {
   }
   
   deinit {
+    operationGroupsMutex.synchronize {
+      hasBeenShutdown = true
+    }
+    cgrpc_completion_queue_shutdown(underlyingCompletionQueue)
     cgrpc_completion_queue_drain(underlyingCompletionQueue)
     grpc_completion_queue_destroy(underlyingCompletionQueue)
   }
@@ -92,21 +96,16 @@ class CompletionQueue {
     return CompletionQueueEvent(event)
   }
 
-  /// Register an operation group for handling upon completion
+  /// Register an operation group for handling upon completion. Returns true if the operation group was registered
+  /// successfully.
   ///
   /// - Parameter operationGroup: the operation group to handle
-  func register(_ operationGroup: OperationGroup) {
-    operationGroupsMutex.synchronize {
-      if !hasBeenShutdown {
-        operationGroups[operationGroup.tag] = operationGroup
-      } else {
-        // The queue has been shut down already, so there's no spinloop to call the operation group's completion handler
-        // on. To guarantee that the completion handler gets called, we'll enqueue it right now.
-        DispatchQueue.global().async {
-          operationGroup.success = false
-          operationGroup.completion?(operationGroup)
-        }
-      }
+  func register(_ operationGroup: OperationGroup, onSuccess: () throws -> Void) rethrows {
+    try operationGroupsMutex.synchronize {
+      guard !hasBeenShutdown
+        else { throw CallError.completionQueueShutdown }
+      operationGroups[operationGroup.tag] = operationGroup
+      try onSuccess()
     }
   }
 
@@ -138,7 +137,6 @@ class CompletionQueue {
           self.operationGroupsMutex.lock()
           let currentOperationGroups = self.operationGroups
           self.operationGroups = [:]
-          self.hasBeenShutdown = true
           self.operationGroupsMutex.unlock()
           
           for operationGroup in currentOperationGroups.values {
@@ -165,6 +163,9 @@ class CompletionQueue {
 
   /// Shuts down a completion queue
   func shutdown() {
+    operationGroupsMutex.synchronize {
+      hasBeenShutdown = true
+    }
     cgrpc_completion_queue_shutdown(underlyingCompletionQueue)
   }
 }

+ 3 - 4
Sources/SwiftGRPC/Core/Mutex.swift

@@ -53,9 +53,8 @@ public class Mutex {
   /// Runs a block within a locked mutex
   ///
   /// Parameter block: the code to run while the mutex is locked
-  public func synchronize(block: () throws -> Void) rethrows {
-    lock()
-    try block()
-    unlock()
+  public func synchronize<T>(block: () throws -> T) rethrows -> T {
+    lock(); defer { unlock() }
+    return try block()
   }
 }

+ 1 - 0
Tests/LinuxMain.swift

@@ -21,6 +21,7 @@ XCTMain([
   testCase(ClientCancellingTests.allTests),
   testCase(ClientTestExample.allTests),
   testCase(ClientTimeoutTests.allTests),
+  testCase(CompletionQueueTests.allTests),
   testCase(ConnectionFailureTests.allTests),
   testCase(EchoTests.allTests),
   testCase(ServerCancellingTests.allTests),

+ 65 - 0
Tests/SwiftGRPCTests/CompletionQueueTests.swift

@@ -0,0 +1,65 @@
+/*
+ * Copyright 2018, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import Dispatch
+import Foundation
+@testable import SwiftGRPC
+import XCTest
+
+fileprivate class ClosingProvider: Echo_EchoProvider {
+  var doneExpectation: XCTestExpectation!
+  
+  func get(request: Echo_EchoRequest, session: Echo_EchoGetSession) throws -> Echo_EchoResponse {
+    return Echo_EchoResponse()
+  }
+  
+  func expand(request: Echo_EchoRequest, session: Echo_EchoExpandSession) throws {
+    let closeSem = DispatchSemaphore(value: 0)
+    try! session.close(withStatus: .ok) {
+      closeSem.signal()
+    }
+    XCTAssertThrowsError(try session.send(Echo_EchoResponse()))
+    doneExpectation.fulfill()
+  }
+  
+  func collect(session: Echo_EchoCollectSession) throws { }
+  
+  func update(session: Echo_EchoUpdateSession) throws { }
+}
+
+class CompletionQueueTests: BasicEchoTestCase {
+  static var allTests: [(String, (CompletionQueueTests) -> () throws -> Void)] {
+    return [
+      ("testCompletionQueueThrowsAfterShutdown", testCompletionQueueThrowsAfterShutdown)
+    ]
+  }
+  
+  override func makeProvider() -> Echo_EchoProvider { return ClosingProvider() }
+}
+
+extension CompletionQueueTests {
+  func testCompletionQueueThrowsAfterShutdown() {
+    (self.provider as! ClosingProvider).doneExpectation = expectation(description: "end of server-side request handler reached")
+    
+    let completionHandlerExpectation = expectation(description: "completion handler called")
+    _ = try! client.expand(Echo_EchoRequest(text: "foo bar baz")) { callResult in
+      XCTAssertEqual(.ok, callResult.statusCode)
+      XCTAssertEqual("OK", callResult.statusMessage)
+      completionHandlerExpectation.fulfill()
+    }
+    
+    waitForExpectations(timeout: defaultTimeout)
+  }
+}