Browse Source

Introduce new CountDownLatch class and simplify unit test.

Tim Burks 9 years ago
parent
commit
141dec29ae

+ 65 - 0
Sources/gRPC/CountDownLatch.swift

@@ -0,0 +1,65 @@
+/*
+ *
+ * Copyright 2016, Google Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ *     * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ *     * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+import Foundation
+
+/// A synchronization primitive used to synchronize gRPC operations
+/// Initialize it with a count, a call to wait() will block until
+/// countDown() has been called the specified number of times.
+public class CountDownLatch {
+  private var condition : NSCondition
+  private var count : Int
+
+  public init(_ count : Int) {
+    self.condition = NSCondition()
+    self.count = count
+  }
+
+  public func countDown() {
+    condition.lock()
+    self.count = self.count - 1
+    self.condition.signal()
+    self.condition.unlock()
+  }
+
+  public func wait() {
+    var running = true
+    while (running) {
+      self.condition.lock()
+      self.condition.wait()
+      if (self.count == 0) {
+        running = false
+      }
+      self.condition.unlock()
+    }
+  }
+}

+ 6 - 3
SwiftGRPC.xcodeproj/project.pbxproj

@@ -7,6 +7,7 @@
 	objects = {
 
 /* Begin PBXBuildFile section */
+		D30D29181E3B0AC9004A414B /* CountDownLatch.swift in Sources */ = {isa = PBXBuildFile; fileRef = D30D29171E3B0AC9004A414B /* CountDownLatch.swift */; };
 		OBJ_1000 /* stack.c in Sources */ = {isa = PBXBuildFile; fileRef = OBJ_231 /* stack.c */; };
 		OBJ_1001 /* a_digest.c in Sources */ = {isa = PBXBuildFile; fileRef = OBJ_233 /* a_digest.c */; };
 		OBJ_1002 /* a_sign.c in Sources */ = {isa = PBXBuildFile; fileRef = OBJ_234 /* a_sign.c */; };
@@ -641,6 +642,7 @@
 /* End PBXContainerItemProxy section */
 
 /* Begin PBXFileReference section */
+		D30D29171E3B0AC9004A414B /* CountDownLatch.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = CountDownLatch.swift; sourceTree = "<group>"; };
 		OBJ_100 /* buf.c */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.c; path = buf.c; sourceTree = "<group>"; };
 		OBJ_102 /* asn1_compat.c */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.c; path = asn1_compat.c; sourceTree = "<group>"; };
 		OBJ_103 /* ber.c */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.c; path = ber.c; sourceTree = "<group>"; };
@@ -2238,7 +2240,7 @@
 			path = sockaddr;
 			sourceTree = "<group>";
 		};
-		OBJ_5 /*  */ = {
+		OBJ_5 = {
 			isa = PBXGroup;
 			children = (
 				OBJ_6 /* Package.swift */,
@@ -2250,7 +2252,6 @@
 				OBJ_774 /* third_party */,
 				OBJ_775 /* Products */,
 			);
-			name = "";
 			sourceTree = "<group>";
 		};
 		OBJ_500 /* dns */ = {
@@ -2821,6 +2822,7 @@
 				OBJ_758 /* Call.swift */,
 				OBJ_759 /* Channel.swift */,
 				OBJ_760 /* CompletionQueue.swift */,
+				D30D29171E3B0AC9004A414B /* CountDownLatch.swift */,
 				OBJ_761 /* gRPC.swift */,
 				OBJ_762 /* Handler.swift */,
 				OBJ_763 /* Metadata.swift */,
@@ -2996,7 +2998,7 @@
 			knownRegions = (
 				en,
 			);
-			mainGroup = OBJ_5 /*  */;
+			mainGroup = OBJ_5;
 			productRefGroup = OBJ_775 /* Products */;
 			projectDirPath = "";
 			projectRoot = "";
@@ -3313,6 +3315,7 @@
 			isa = PBXSourcesBuildPhase;
 			buildActionMask = 0;
 			files = (
+				D30D29181E3B0AC9004A414B /* CountDownLatch.swift in Sources */,
 				OBJ_1369 /* ByteBuffer.swift in Sources */,
 				OBJ_1370 /* Call.swift in Sources */,
 				OBJ_1371 /* Channel.swift in Sources */,

+ 45 - 89
Tests/gRPCTests/GRPCTests.swift

@@ -8,39 +8,27 @@ func Log(_ message : String) {
 }
 
 class gRPCTests: XCTestCase {
-  var count : Int = 0
-
+  
   func testBasicSanity() {
     gRPC.initialize()
-    let done = NSCondition()
-    self.count = 2
+    let latch = CountDownLatch(2)
     DispatchQueue.global().async() {
-      server()
-      Log("server finished")
-      done.lock()
-      self.count = self.count - 1
-      done.signal()
-      done.unlock()
+      do {
+        try server()
+      } catch (let error) {
+        XCTFail("server error \(error)")
+      }
+      latch.countDown()
     }
     DispatchQueue.global().async() {
-      client()
-      Log("client finished")
-      done.lock()
-      self.count = self.count - 1
-      done.signal()
-      done.unlock()
-    }
-    var running = true
-    while (running) {
-      Log("waiting")
-      done.lock()
-      done.wait()
-      if (self.count == 0) {
-        running = false
+      do {
+        try client()
+      } catch (let error) {
+        XCTFail("client error \(error)")
       }
-      Log("count \(self.count)")
-      done.unlock()
+      latch.countDown()
     }
+    latch.wait()
   }
 }
 
@@ -83,61 +71,43 @@ func verify_metadata(_ metadata: Metadata, expected: [String:String]) {
   }
 }
 
-func client() {
+func client() throws {
   let message = clientText.data(using: .utf8)
-  let c = gRPC.Channel(address:address)
-  c.host = host
-  let done = NSCondition()
-  var running = true
+  let channel = gRPC.Channel(address:address)
+  channel.host = host
   for i in 0..<steps {
-    var call : Call
-    do {
-      let method = (i < steps-1) ? hello : goodbye
-      call = c.makeCall(method)
-      let metadata = Metadata(initialClientMetadata)
-      try call.start(.unary, metadata:metadata, message:message) {
-        (response) in
-        // verify the basic response from the server
-        XCTAssertEqual(response.statusCode, statusCode)
-        XCTAssertEqual(response.statusMessage, statusMessage)
-        // verify the message from the server
-        let resultData = response.resultData
-        let messageString = String(data: resultData!, encoding: .utf8)
-        XCTAssertEqual(messageString, serverText)
-        // verify the initial metadata from the server
-        let initialMetadata = response.initialMetadata!
-        verify_metadata(initialMetadata, expected: initialServerMetadata)
-        // verify the trailing metadata from the server
-        let trailingMetadata = response.trailingMetadata!
-        verify_metadata(trailingMetadata, expected: trailingServerMetadata)
-        done.lock()
-        running = false
-        done.signal()
-        done.unlock()
-      }
-    } catch (let error) {
-        XCTFail("error \(error)")
+    let latch = CountDownLatch(1)
+    let method = (i < steps-1) ? hello : goodbye
+    let call = channel.makeCall(method)
+    let metadata = Metadata(initialClientMetadata)
+    try call.start(.unary, metadata:metadata, message:message) {
+      (response) in
+      // verify the basic response from the server
+      XCTAssertEqual(response.statusCode, statusCode)
+      XCTAssertEqual(response.statusMessage, statusMessage)
+      // verify the message from the server
+      let resultData = response.resultData
+      let messageString = String(data: resultData!, encoding: .utf8)
+      XCTAssertEqual(messageString, serverText)
+      // verify the initial metadata from the server
+      let initialMetadata = response.initialMetadata!
+      verify_metadata(initialMetadata, expected: initialServerMetadata)
+      // verify the trailing metadata from the server
+      let trailingMetadata = response.trailingMetadata!
+      verify_metadata(trailingMetadata, expected: trailingServerMetadata)
+      // report completion
+      latch.countDown()
     }
     // wait for the call to complete
-    var finished = false
-    while (!finished) {
-      done.lock()
-      done.wait()
-      if (!running) {
-        finished = true
-      }
-      done.unlock()
-    }
-    Log("finished client call \(i)")
+    latch.wait()
   }
-  Log("client done")
+  usleep(500) // temporarily delay calls to the channel destructor
 }
 
-func server() {
+func server() throws {
   let server = gRPC.Server(address:address)
   var requestCount = 0
-  let done = NSCondition()
-  var running = true
+  let latch = CountDownLatch(1)
   server.run() {(requestHandler) in
     do {
       requestCount += 1
@@ -169,23 +139,9 @@ func server() {
     }
   }
   server.onCompletion() {
-      // exit the server thread
-      Log("signaling completion")
-      done.lock()
-      running = false
-      done.signal()
-      done.unlock()
+    // exit the server thread
+    latch.countDown()
   }
   // wait for the server to exit
-  var finished = false
-  while !finished {
-    done.lock()
-    done.wait()
-    if (!running) {
-      finished = true
-    }
-    done.unlock()
-  }
-  Log("server done")
+  latch.wait()
 }
-