瀏覽代碼

Improve test synchronization.

Tim Burks 7 年之前
父節點
當前提交
3e00aa5569

+ 0 - 1
Examples/Echo/EchoProvider.swift

@@ -36,7 +36,6 @@ class EchoProvider: Echo_EchoProvider {
       try session.send(response) { _ in sem.signal() }
       _ = sem.wait(timeout: DispatchTime.distantFuture)
       i += 1
-      sleep(1)
     }
   }
 

+ 2 - 2
Examples/Echo/Generated/echo.grpc.swift

@@ -49,7 +49,7 @@ class Echo_EchoExpandCallTestStub: ClientCallServerStreamingTestStub<Echo_EchoRe
 
 internal protocol Echo_EchoCollectCall: ClientCallClientStreaming {
   /// Call this to send each message in the request stream. Nonblocking.
-  func send(_ message: Echo_EchoRequest, errorHandler: @escaping (Error) -> Void) throws
+  func send(_ message: Echo_EchoRequest, completion: @escaping (Error?) -> Void) throws
 
   /// Call this to close the connection and wait for a response. Blocking.
   func closeAndReceive() throws -> Echo_EchoResponse
@@ -74,7 +74,7 @@ internal protocol Echo_EchoUpdateCall: ClientCallBidirectionalStreaming {
   func receive(completion: @escaping (Echo_EchoResponse?, ClientError?) -> Void) throws
 
   /// Call this to send each message in the request stream.
-  func send(_ message: Echo_EchoRequest, errorHandler: @escaping (Error) -> Void) throws
+  func send(_ message: Echo_EchoRequest, completion: @escaping (Error?) -> Void) throws
 
   /// Call this to close the sending connection. Blocking.
   func closeSend() throws

+ 69 - 10
Examples/Echo/PackageManager/Sources/main.swift

@@ -25,7 +25,7 @@ func addressOption(_ address: String) -> Option<String> {
 }
 
 let portOption = Option("port",
-                        default: "8081",
+                        default: "8080",
                         description: "port of server")
 let messageOption = Option("message",
                            default: "Testing 1 2 3",
@@ -102,8 +102,9 @@ Group {
     requestMessage.text = message
     print("expand sending: " + requestMessage.text)
     let sem = DispatchSemaphore(value: 0)
+    var callResult : CallResult?
     let expandCall = try service.expand(requestMessage) { result in
-      print("expand completed with result \(result)")
+      callResult = result
       sem.signal()
     }
     var running = true
@@ -116,61 +117,119 @@ Group {
       }
     }
     _ = sem.wait(timeout: DispatchTime.distantFuture)
+    if let statusCode = callResult?.statusCode {
+      print("expand completed with code \(statusCode)")
+    }
   }
 
   $0.command("collect", sslFlag, addressOption("localhost"), portOption, messageOption,
              description: "Perform a client-streaming collect().") { ssl, address, port, message in
     let service = buildEchoService(ssl, address, port, message)
     let sem = DispatchSemaphore(value: 0)
+    var callResult : CallResult?
     let collectCall = try service.collect { result in
-      print("collect completed with result \(result)")
+      callResult = result
       sem.signal()
     }
+
+    let sendCountMutex = Mutex()
+    var sendCount = 0
+
     let parts = message.components(separatedBy: " ")
     for part in parts {
       var requestMessage = Echo_EchoRequest()
       requestMessage.text = part
       print("collect sending: " + part)
-      try collectCall.send(requestMessage) { error in print(error) }
-      sleep(1)
+      try collectCall.send(requestMessage) {
+        error in
+        sendCountMutex.synchronize {
+          sendCount = sendCount + 1
+        }
+	if let error = error {	
+          print("collect send error \(error)")
+	}
+      }
+    }
+    // don't close until all sends have completed
+    var waiting = true
+    while (waiting) {
+      sendCountMutex.synchronize {
+        if sendCount == parts.count {
+          waiting = false
+        }
+      }
     }
     let responseMessage = try collectCall.closeAndReceive()
     print("collect received: \(responseMessage.text)")
     _ = sem.wait(timeout: DispatchTime.distantFuture)
+    if let statusCode = callResult?.statusCode {
+      print("collect completed with status \(statusCode)")
+    }
   }
 
   $0.command("update", sslFlag, addressOption("localhost"), portOption, messageOption,
              description: "Perform a bidirectional-streaming update().") { ssl, address, port, message in
     let service = buildEchoService(ssl, address, port, message)
     let sem = DispatchSemaphore(value: 0)
+    var callResult : CallResult?
     let updateCall = try service.update { result in
-      print("update completed with result \(result)")
+      callResult = result
       sem.signal()
     }
 
+    let responsesMutex = Mutex()
+    var responses : [String] = []
+
     DispatchQueue.global().async {
       var running = true
       while running {
         do {
           let responseMessage = try updateCall.receive()
-          print("update received: \(responseMessage.text)")
+          responsesMutex.synchronize {
+            responses.append("update received: \(responseMessage.text)")
+          }
         } catch ClientError.endOfStream {
           running = false
         } catch (let error) {
-          print("error: \(error)")
+          responsesMutex.synchronize {
+            responses.append("update receive error: \(error)")
+          }
         }
       }
     }
+
     let parts = message.components(separatedBy: " ")
     for part in parts {
       var requestMessage = Echo_EchoRequest()
       requestMessage.text = part
       print("update sending: " + requestMessage.text)
-      try updateCall.send(requestMessage) { error in print(error) }
-      sleep(1)
+      try updateCall.send(requestMessage) {
+        error in
+        if let error = error {
+          print("update send error: \(error)")
+        }
+      }
+    }
+
+    // don't close until last update is received
+    var waiting = true
+    while (waiting) {
+      responsesMutex.synchronize {
+        if responses.count == parts.count {
+          waiting = false
+        }
+      }
     }
     try updateCall.closeSend()
+
     _ = sem.wait(timeout: DispatchTime.distantFuture)
+
+    for response in responses {
+      print(response)
+    }
+    if let statusCode = callResult?.statusCode {
+      print("update completed with status \(statusCode)")
+    }
   }
 
 }.run()

+ 6 - 12
Examples/Echo/PackageManager/test.gold

@@ -5,25 +5,19 @@ expand received: Swift echo expand (0): Testing
 expand received: Swift echo expand (1): 1
 expand received: Swift echo expand (2): 2
 expand received: Swift echo expand (3): 3
-expand completed with result status ok: OK
-
-
+expand completed with code ok
 collect sending: Testing
 collect sending: 1
 collect sending: 2
 collect sending: 3
 collect received: Swift echo collect: Testing 1 2 3
-collect completed with result status ok: OK
-
-
+collect completed with status ok
 update sending: Testing
-update received: Swift echo update (1): Testing
 update sending: 1
-update received: Swift echo update (2): 1
 update sending: 2
-update received: Swift echo update (3): 2
 update sending: 3
+update received: Swift echo update (1): Testing
+update received: Swift echo update (2): 1
+update received: Swift echo update (3): 2
 update received: Swift echo update (4): 3
-update completed with result status ok: OK
-
-
+update completed with status ok

+ 32 - 16
Examples/Echo2/Sources/main.swift

@@ -60,25 +60,29 @@ Group {
              description: "Run an echo server.") { ssl, address, port in
     let sem = DispatchSemaphore(value: 0)
     let echoProvider = EchoProvider()
+    var echoServer: Echo_EchoServer?
     if ssl {
       print("starting secure server")
       let certificateURL = URL(fileURLWithPath: "ssl.crt")
       let keyURL = URL(fileURLWithPath: "ssl.key")
-      if let echoServer = Echo_EchoServer(address: address + ":" + port,
-                                          certificateURL: certificateURL,
-                                          keyURL: keyURL,
-                                          provider: echoProvider) {
-        echoServer.start()
-      }
+      echoServer = Echo_EchoServer(address: address + ":" + port,
+                                   certificateURL: certificateURL,
+                                   keyURL: keyURL,
+                                   provider: echoProvider)
+      echoServer?.start()
     } else {
       print("starting insecure server")
-      let echoServer = Echo_EchoServer(address: address + ":" + port,
-                                       provider: echoProvider)
-      echoServer.start()
+      echoServer = Echo_EchoServer(address: address + ":" + port,
+                                   provider: echoProvider)
+      echoServer?.start()
     }
     // This blocks to keep the main thread from finishing while the server runs,
     // but the server never exits. Kill the process to stop it.
     _ = sem.wait(timeout: DispatchTime.distantFuture)
+    // This suppresses a "variable echoServer was written to, but never read" warning.
+    _ = echoServer
+    // And this ensures that echoServer doesn't get deallocated right after it is created.
+    echoServer = nil
   }
 
   $0.command("get", sslFlag, addressOption("localhost"), portOption, messageOption,
@@ -98,8 +102,9 @@ Group {
     requestMessage.text = message
     print("expand sending: " + requestMessage.text)
     let sem = DispatchSemaphore(value: 0)
+    var callResult : CallResult?
     let expandCall = try service.expand(requestMessage) { result in
-      print("expand completed with result \(result)")
+      callResult = result
       sem.signal()
     }
     var running = true
@@ -107,19 +112,23 @@ Group {
       do {
         let responseMessage = try expandCall.receive()
         print("expand received: \(responseMessage.text)")
-      } catch Echo_EchoClientError.endOfStream {
+      } catch ClientError.endOfStream {
         running = false
       }
     }
     _ = sem.wait(timeout: DispatchTime.distantFuture)
+    if let statusMessage = callResult?.statusMessage {
+      print("expand completed with result \(statusMessage)")
+    }
   }
 
   $0.command("collect", sslFlag, addressOption("localhost"), portOption, messageOption,
              description: "Perform a client-streaming collect().") { ssl, address, port, message in
     let service = buildEchoService(ssl, address, port, message)
     let sem = DispatchSemaphore(value: 0)
+    var callResult : CallResult?
     let collectCall = try service.collect { result in
-      print("collect completed with result \(result)")
+      callResult = result
       sem.signal()
     }
     let parts = message.components(separatedBy: " ")
@@ -128,19 +137,23 @@ Group {
       requestMessage.text = part
       print("collect sending: " + part)
       try collectCall.send(requestMessage) { error in print(error) }
-      sleep(1)
+      usleep(100000)
     }
     let responseMessage = try collectCall.closeAndReceive()
     print("collect received: \(responseMessage.text)")
     _ = sem.wait(timeout: DispatchTime.distantFuture)
+    if let statusMessage = callResult?.statusMessage {
+      print("collect completed with result \(statusMessage)")
+    }
   }
 
   $0.command("update", sslFlag, addressOption("localhost"), portOption, messageOption,
              description: "Perform a bidirectional-streaming update().") { ssl, address, port, message in
     let service = buildEchoService(ssl, address, port, message)
     let sem = DispatchSemaphore(value: 0)
+    var callResult : CallResult?
     let updateCall = try service.update { result in
-      print("update completed with result \(result)")
+      callResult = result
       sem.signal()
     }
 
@@ -150,7 +163,7 @@ Group {
         do {
           let responseMessage = try updateCall.receive()
           print("update received: \(responseMessage.text)")
-        } catch Echo_EchoClientError.endOfStream {
+        } catch ClientError.endOfStream {
           running = false
         } catch (let error) {
           print("error: \(error)")
@@ -163,10 +176,13 @@ Group {
       requestMessage.text = part
       print("update sending: " + requestMessage.text)
       try updateCall.send(requestMessage) { error in print(error) }
-      sleep(1)
+      usleep(100000)
     }
     try updateCall.closeSend()
     _ = sem.wait(timeout: DispatchTime.distantFuture)
+    if let statusMessage = callResult?.statusMessage {
+      print("update completed with result \(statusMessage)")
+    }
   }
 
 }.run()

+ 2 - 2
Examples/Simple/PackageManager/main.swift

@@ -125,13 +125,13 @@ func server() throws {
 Group {
   $0.command("server") {
     gRPC.initialize()
-    print("gRPC version", gRPC.version!)
+    print("gRPC version", gRPC.version)
     try server()
   }
 
   $0.command("client") {
     gRPC.initialize()
-    print("gRPC version", gRPC.version!)
+    print("gRPC version", gRPC.version)
     try client()
   }
 

+ 1 - 1
Plugin/Templates/client-call-bidistreaming.swift

@@ -5,7 +5,7 @@
   func receive(completion: @escaping ({{ method|output }}?, ClientError?) -> Void) throws
 
   /// Call this to send each message in the request stream.
-  func send(_ message: {{ method|input }}, errorHandler: @escaping (Error) -> Void) throws
+  func send(_ message: {{ method|input }}, completion: @escaping (Error?) -> Void) throws
 
   /// Call this to close the sending connection. Blocking.
   func closeSend() throws

+ 1 - 1
Plugin/Templates/client-call-clientstreaming.swift

@@ -1,6 +1,6 @@
 {{ access }} protocol {{ .|call:file,service,method }}: ClientCallClientStreaming {
   /// Call this to send each message in the request stream. Nonblocking.
-  func send(_ message: {{ method|input }}, errorHandler: @escaping (Error) -> Void) throws
+  func send(_ message: {{ method|input }}, completion: @escaping (Error?) -> Void) throws
 
   /// Call this to close the connection and wait for a response. Blocking.
   func closeAndReceive() throws -> {{ method|output }}

+ 24 - 9
Sources/gRPC/Call.swift

@@ -155,7 +155,7 @@ public class Call {
   private let owned: Bool
 
   /// A queue of pending messages to send over the call
-  private var messageQueue: [(dataToSend: Data, errorHandler: (Error) -> Void)] = []
+  private var messageQueue: [(dataToSend: Data, completion: (Error?) -> Void)] = []
 
   /// True if a message write operation is underway
   private var writing: Bool
@@ -253,23 +253,23 @@ public class Call {
   ///
   /// Parameter data: the message data to send
   /// - Throws: `CallError` if fails to call. `CallWarning` if blocked.
-  public func sendMessage(data: Data, errorHandler: @escaping (Error) -> Void) throws {
+  public func sendMessage(data: Data, completion: @escaping (Error?) -> Void) throws {
     try sendMutex.synchronize {
       if writing {
         if (Call.messageQueueMaxLength > 0) && // if max length is <= 0, consider it infinite
           (messageQueue.count == Call.messageQueueMaxLength) {
           throw CallWarning.blocked
         }
-        messageQueue.append((dataToSend: data, errorHandler: errorHandler))
+        messageQueue.append((dataToSend: data, completion: completion))
       } else {
         writing = true
-        try sendWithoutBlocking(data: data, errorHandler: errorHandler)
+        try sendWithoutBlocking(data: data, completion: completion)
       }  
     }
   }
 
   /// helper for sending queued messages
-  private func sendWithoutBlocking(data: Data, errorHandler: @escaping (Error) -> Void) throws {
+  private func sendWithoutBlocking(data: Data, completion: @escaping (Error?) -> Void) throws {
     try perform(OperationGroup(call: self,
                                operations: [.sendMessage(ByteBuffer(data: data))]) { operationGroup in
         if operationGroup.success {
@@ -277,11 +277,11 @@ public class Call {
             self.sendMutex.synchronize {
               // if there are messages pending, send the next one
               if self.messageQueue.count > 0 {
-                let (nextMessage, nextErrorHandler) = self.messageQueue.removeFirst()
+                let (nextMessage, nextCompletionHandler) = self.messageQueue.removeFirst()
                 do {
-                  try self.sendWithoutBlocking(data: nextMessage, errorHandler: nextErrorHandler)
+                  try self.sendWithoutBlocking(data: nextMessage, completion: nextCompletionHandler)
                 } catch (let callError) {
-                  errorHandler(callError)
+                  nextCompletionHandler(callError)
                 }
               } else {
                 // otherwise, we are finished writing
@@ -289,14 +289,29 @@ public class Call {
               }
             }
           }
+          completion(nil)
         } else {
           // if the event failed, shut down
           self.writing = false
-          errorHandler(CallError.unknown)
+          completion(CallError.unknown)
         }
     })
   }
 
+  // Receive a message over a streaming connection.
+  /// - Throws: `CallError` if fails to call.
+  public func closeAndReceiveMessage(callback: @escaping (Data?) throws -> Void) throws {
+    try perform(OperationGroup(call: self, operations: [.sendCloseFromClient, .receiveMessage]) { operationGroup in
+      if operationGroup.success {
+        if let messageBuffer = operationGroup.receivedMessage() {
+          try callback(messageBuffer.data())
+        } else {
+          try callback(nil) // an empty response signals the end of a connection
+        }
+      }
+    })
+  }
+
   // Receive a message over a streaming connection.
   /// - Throws: `CallError` if fails to call.
   public func receiveMessage(callback: @escaping (Data?) throws -> Void) throws {

+ 3 - 3
Sources/gRPC/GenCodeSupport/ClientCallBidirectionalStreaming.swift

@@ -65,9 +65,9 @@ open class ClientCallBidirectionalStreamingBase<InputType: Message, OutputType:
     return returnMessage
   }
 
-  public func send(_ message: InputType, errorHandler: @escaping (Error) -> Void) throws {
+  public func send(_ message: InputType, completion: @escaping (Error?) -> Void) throws {
     let messageData = try message.serializedData()
-    try call.sendMessage(data: messageData, errorHandler: errorHandler)
+    try call.sendMessage(data: messageData, completion: completion)
   }
 
   public func closeSend(completion: (() -> Void)?) throws {
@@ -115,7 +115,7 @@ open class ClientCallBidirectionalStreamingTestStub<InputType: Message, OutputTy
     }
   }
 
-  open func send(_ message: InputType, errorHandler _: @escaping (Error) -> Void) throws {
+  open func send(_ message: InputType, completion _: @escaping (Error?) -> Void) throws {
     inputs.append(message)
   }
 

+ 4 - 5
Sources/gRPC/GenCodeSupport/ClientCallClientStreaming.swift

@@ -33,14 +33,14 @@ open class ClientCallClientStreamingBase<InputType: Message, OutputType: Message
     return self
   }
 
-  public func send(_ message: InputType, errorHandler: @escaping (Error) -> Void) throws {
+  public func send(_ message: InputType, completion: @escaping (Error?) -> Void) throws {
     let messageData = try message.serializedData()
-    try call.sendMessage(data: messageData, errorHandler: errorHandler)
+    try call.sendMessage(data: messageData, completion: completion)
   }
 
   public func closeAndReceive(completion: @escaping (OutputType?, ClientError?) -> Void) throws {
     do {
-      try call.receiveMessage { responseData in
+      try call.closeAndReceiveMessage { responseData in
         if let responseData = responseData,
           let response = try? OutputType(serializedData: responseData) {
           completion(response, nil)
@@ -48,7 +48,6 @@ open class ClientCallClientStreamingBase<InputType: Message, OutputType: Message
           completion(nil, .invalidMessageReceived)
         }
       }
-      try call.close(completion: {})
     } catch (let error) {
       throw error
     }
@@ -89,7 +88,7 @@ open class ClientCallClientStreamingTestStub<InputType: Message, OutputType: Mes
   
   public init() {}
 
-  open func send(_ message: InputType, errorHandler _: @escaping (Error) -> Void) throws {
+  open func send(_ message: InputType, completion _: @escaping (Error?) -> Void) throws {
     inputs.append(message)
   }