Browse Source

Merge pull request #173 from MrMage/reuse-send-on-call

Make `Handler.sendResponse` and `.receiveMessage` call through to the corresponding methods on `Call`
Tim Burks 7 years ago
parent
commit
a27d840db6

+ 14 - 9
Examples/Echo/EchoProvider.swift

@@ -28,15 +28,16 @@ class EchoProvider: Echo_EchoProvider {
   // expand splits a request into words and returns each word in a separate message.
   // expand splits a request into words and returns each word in a separate message.
   func expand(request: Echo_EchoRequest, session: Echo_EchoExpandSession) throws {
   func expand(request: Echo_EchoRequest, session: Echo_EchoExpandSession) throws {
     let parts = request.text.components(separatedBy: " ")
     let parts = request.text.components(separatedBy: " ")
-    var i = 0
-    for part in parts {
+    for (i, part) in parts.enumerated() {
       var response = Echo_EchoResponse()
       var response = Echo_EchoResponse()
       response.text = "Swift echo expand (\(i)): \(part)"
       response.text = "Swift echo expand (\(i)): \(part)"
-      let sem = DispatchSemaphore(value: 0)
-      try session.send(response) { _ in sem.signal() }
-      _ = sem.wait()
-      i += 1
+      try session.send(response) {
+        if let error = $0 {
+          print("expand error: \(error)")
+        }
+      }
     }
     }
+    session.waitForSendOperationsToFinish()
   }
   }
 
 
   // collect collects a sequence of messages and returns them concatenated when the caller closes.
   // collect collects a sequence of messages and returns them concatenated when the caller closes.
@@ -66,15 +67,19 @@ class EchoProvider: Echo_EchoProvider {
         var response = Echo_EchoResponse()
         var response = Echo_EchoResponse()
         response.text = "Swift echo update (\(count)): \(request.text)"
         response.text = "Swift echo update (\(count)): \(request.text)"
         count += 1
         count += 1
-        let sem = DispatchSemaphore(value: 0)
-        try session.send(response) { _ in sem.signal() }
-        _ = sem.wait()
+        try session.send(response) {
+          if let error = $0 {
+            print("update error: \(error)")
+          }
+        }
       } catch ServerError.endOfStream {
       } catch ServerError.endOfStream {
         break
         break
       } catch (let error) {
       } catch (let error) {
         print("\(error)")
         print("\(error)")
+        break
       }
       }
     }
     }
+    session.waitForSendOperationsToFinish()
     try session.close()
     try session.close()
   }
   }
 }
 }

+ 2 - 2
Examples/Echo/Generated/echo.grpc.swift

@@ -202,7 +202,7 @@ class Echo_EchoGetSessionTestStub: ServerSessionUnaryTestStub, Echo_EchoGetSessi
 
 
 internal protocol Echo_EchoExpandSession: ServerSessionServerStreaming {
 internal protocol Echo_EchoExpandSession: ServerSessionServerStreaming {
   /// Send a message. Nonblocking.
   /// Send a message. Nonblocking.
-  func send(_ response: Echo_EchoResponse, completion: ((Bool) -> Void)?) throws
+  func send(_ response: Echo_EchoResponse, completion: ((Error?) -> Void)?) throws
 }
 }
 
 
 fileprivate final class Echo_EchoExpandSessionBase: ServerSessionServerStreamingBase<Echo_EchoRequest, Echo_EchoResponse>, Echo_EchoExpandSession {}
 fileprivate final class Echo_EchoExpandSessionBase: ServerSessionServerStreamingBase<Echo_EchoRequest, Echo_EchoResponse>, Echo_EchoExpandSession {}
@@ -226,7 +226,7 @@ internal protocol Echo_EchoUpdateSession: ServerSessionBidirectionalStreaming {
   func receive() throws -> Echo_EchoRequest
   func receive() throws -> Echo_EchoRequest
 
 
   /// Send a message. Nonblocking.
   /// Send a message. Nonblocking.
-  func send(_ response: Echo_EchoResponse, completion: ((Bool) -> Void)?) throws
+  func send(_ response: Echo_EchoResponse, completion: ((Error?) -> Void)?) throws
 
 
   /// Close a connection. Blocks until the connection is closed.
   /// Close a connection. Blocks until the connection is closed.
   func close() throws
   func close() throws

+ 18 - 18
Sources/gRPC/Call.swift

@@ -143,7 +143,7 @@ public class Call {
   private static let callMutex = Mutex()
   private static let callMutex = Mutex()
 
 
   /// Maximum number of messages that can be queued
   /// Maximum number of messages that can be queued
-  public static var messageQueueMaxLength = 0
+  public static var messageQueueMaxLength: Int? = nil
 
 
   /// Pointer to underlying C representation
   /// Pointer to underlying C representation
   private let underlyingCall: UnsafeMutableRawPointer
   private let underlyingCall: UnsafeMutableRawPointer
@@ -155,8 +155,10 @@ public class Call {
   private let owned: Bool
   private let owned: Bool
 
 
   /// A queue of pending messages to send over the call
   /// A queue of pending messages to send over the call
-  private var messageQueue: [(dataToSend: Data, completion: (Error?) -> Void)] = []
+  private var messageQueue: [(dataToSend: Data, completion: ((Error?) -> Void)?)] = []
 
 
+  /// A dispatch group that contains all pending send operations.
+  /// You can wait on it to ensure that all currently enqueued messages have been sent.
   public let messageQueueEmpty = DispatchGroup()
   public let messageQueueEmpty = DispatchGroup()
   
   
   /// True if a message write operation is underway
   /// True if a message write operation is underway
@@ -206,7 +208,7 @@ public class Call {
   /// - Parameter style: the style of call to start
   /// - Parameter style: the style of call to start
   /// - Parameter metadata: metadata to send with the call
   /// - Parameter metadata: metadata to send with the call
   /// - Parameter message: data containing the message to send (.unary and .serverStreaming only)
   /// - Parameter message: data containing the message to send (.unary and .serverStreaming only)
-  /// - Parameter callback: a block to call with call results
+  /// - Parameter completion: a block to call with call results
   /// - Throws: `CallError` if fails to call.
   /// - Throws: `CallError` if fails to call.
   public func start(_ style: CallStyle,
   public func start(_ style: CallStyle,
                     metadata: Metadata,
                     metadata: Metadata,
@@ -255,12 +257,12 @@ public class Call {
   ///
   ///
   /// Parameter data: the message data to send
   /// Parameter data: the message data to send
   /// - Throws: `CallError` if fails to call. `CallWarning` if blocked.
   /// - Throws: `CallError` if fails to call. `CallWarning` if blocked.
-  public func sendMessage(data: Data, completion: @escaping (Error?) -> Void) throws {
+  public func sendMessage(data: Data, completion: ((Error?) -> Void)? = nil) throws {
     messageQueueEmpty.enter()
     messageQueueEmpty.enter()
     try sendMutex.synchronize {
     try sendMutex.synchronize {
       if writing {
       if writing {
-        if (Call.messageQueueMaxLength > 0) && // if max length is <= 0, consider it infinite
-          (messageQueue.count == Call.messageQueueMaxLength) {
+        if let messageQueueMaxLength = Call.messageQueueMaxLength,
+          messageQueue.count >= messageQueueMaxLength {
           throw CallWarning.blocked
           throw CallWarning.blocked
         }
         }
         messageQueue.append((dataToSend: data, completion: completion))
         messageQueue.append((dataToSend: data, completion: completion))
@@ -272,7 +274,7 @@ public class Call {
   }
   }
 
 
   /// helper for sending queued messages
   /// helper for sending queued messages
-  private func sendWithoutBlocking(data: Data, completion: @escaping (Error?) -> Void) throws {
+  private func sendWithoutBlocking(data: Data, completion: ((Error?) -> Void)?) throws {
     try perform(OperationGroup(call: self,
     try perform(OperationGroup(call: self,
                                operations: [.sendMessage(ByteBuffer(data: data))]) { operationGroup in
                                operations: [.sendMessage(ByteBuffer(data: data))]) { operationGroup in
         // TODO(timburks, danielalm): Is the `async` dispatch here needed, and/or should we call the completion handler
         // TODO(timburks, danielalm): Is the `async` dispatch here needed, and/or should we call the completion handler
@@ -287,7 +289,7 @@ public class Call {
               do {
               do {
                 try self.sendWithoutBlocking(data: nextMessage, completion: nextCompletionHandler)
                 try self.sendWithoutBlocking(data: nextMessage, completion: nextCompletionHandler)
               } catch (let callError) {
               } catch (let callError) {
-                nextCompletionHandler(callError)
+                nextCompletionHandler?(callError)
               }
               }
             } else {
             } else {
               // otherwise, we are finished writing
               // otherwise, we are finished writing
@@ -295,20 +297,20 @@ public class Call {
             }
             }
           }
           }
         }
         }
-        completion(operationGroup.success ? nil : CallError.unknown)
+        completion?(operationGroup.success ? nil : CallError.unknown)
         self.messageQueueEmpty.leave()
         self.messageQueueEmpty.leave()
     })
     })
   }
   }
 
 
   // Receive a message over a streaming connection.
   // Receive a message over a streaming connection.
   /// - Throws: `CallError` if fails to call.
   /// - Throws: `CallError` if fails to call.
-  public func closeAndReceiveMessage(callback: @escaping (Data?) throws -> Void) throws {
+  public func closeAndReceiveMessage(completion: @escaping (Data?) throws -> Void) throws {
     try perform(OperationGroup(call: self, operations: [.sendCloseFromClient, .receiveMessage]) { operationGroup in
     try perform(OperationGroup(call: self, operations: [.sendCloseFromClient, .receiveMessage]) { operationGroup in
       if operationGroup.success {
       if operationGroup.success {
         if let messageBuffer = operationGroup.receivedMessage() {
         if let messageBuffer = operationGroup.receivedMessage() {
-          try callback(messageBuffer.data())
+          try completion(messageBuffer.data())
         } else {
         } else {
-          try callback(nil) // an empty response signals the end of a connection
+          try completion(nil) // an empty response signals the end of a connection
         }
         }
       }
       }
     })
     })
@@ -316,14 +318,12 @@ public class Call {
 
 
   // Receive a message over a streaming connection.
   // Receive a message over a streaming connection.
   /// - Throws: `CallError` if fails to call.
   /// - Throws: `CallError` if fails to call.
-  public func receiveMessage(callback: @escaping (Data?) throws -> Void) throws {
+  public func receiveMessage(completion: @escaping (Data?) throws -> Void) throws {
     try perform(OperationGroup(call: self, operations: [.receiveMessage]) { operationGroup in
     try perform(OperationGroup(call: self, operations: [.receiveMessage]) { operationGroup in
       if operationGroup.success {
       if operationGroup.success {
-        if let messageBuffer = operationGroup.receivedMessage() {
-          try callback(messageBuffer.data())
-        } else {
-          try callback(nil) // an empty response signals the end of a connection
-        }
+        try completion(operationGroup.receivedMessage()?.data())
+      } else {
+        try completion(nil)
       }
       }
     })
     })
   }
   }

+ 11 - 3
Sources/gRPC/GenCodeSupport/ServerSessionBidirectionalStreaming.swift

@@ -18,7 +18,9 @@ import Dispatch
 import Foundation
 import Foundation
 import SwiftProtobuf
 import SwiftProtobuf
 
 
-public protocol ServerSessionBidirectionalStreaming: ServerSession {}
+public protocol ServerSessionBidirectionalStreaming: ServerSession {
+  func waitForSendOperationsToFinish()
+}
 
 
 open class ServerSessionBidirectionalStreamingBase<InputType: Message, OutputType: Message>: ServerSessionBase, ServerSessionBidirectionalStreaming {
 open class ServerSessionBidirectionalStreamingBase<InputType: Message, OutputType: Message>: ServerSessionBase, ServerSessionBidirectionalStreaming {
   public typealias ProviderBlock = (ServerSessionBidirectionalStreamingBase) throws -> Void
   public typealias ProviderBlock = (ServerSessionBidirectionalStreamingBase) throws -> Void
@@ -50,7 +52,7 @@ open class ServerSessionBidirectionalStreamingBase<InputType: Message, OutputTyp
     }
     }
   }
   }
 
 
-  public func send(_ response: OutputType, completion: ((Bool) -> Void)?) throws {
+  public func send(_ response: OutputType, completion: ((Error?) -> Void)?) throws {
     try handler.sendResponse(message: response.serializedData(), completion: completion)
     try handler.sendResponse(message: response.serializedData(), completion: completion)
   }
   }
 
 
@@ -73,6 +75,10 @@ open class ServerSessionBidirectionalStreamingBase<InputType: Message, OutputTyp
       }
       }
     }
     }
   }
   }
+
+  public func waitForSendOperationsToFinish() {
+    handler.call.messageQueueEmpty.wait()
+  }
 }
 }
 
 
 /// Simple fake implementation of ServerSessionBidirectionalStreaming that returns a previously-defined set of results
 /// Simple fake implementation of ServerSessionBidirectionalStreaming that returns a previously-defined set of results
@@ -90,9 +96,11 @@ open class ServerSessionBidirectionalStreamingTestStub<InputType: Message, Outpu
     }
     }
   }
   }
 
 
-  open func send(_ response: OutputType, completion _: ((Bool) -> Void)?) throws {
+  open func send(_ response: OutputType, completion _: ((Error?) -> Void)?) throws {
     outputs.append(response)
     outputs.append(response)
   }
   }
 
 
   open func close() throws {}
   open func close() throws {}
+
+  open func waitForSendOperationsToFinish() {}
 }
 }

+ 12 - 3
Sources/gRPC/GenCodeSupport/ServerSessionServerStreaming.swift

@@ -18,7 +18,9 @@ import Dispatch
 import Foundation
 import Foundation
 import SwiftProtobuf
 import SwiftProtobuf
 
 
-public protocol ServerSessionServerStreaming: ServerSession {}
+public protocol ServerSessionServerStreaming: ServerSession {
+  func waitForSendOperationsToFinish()
+}
 
 
 open class ServerSessionServerStreamingBase<InputType: Message, OutputType: Message>: ServerSessionBase, ServerSessionServerStreaming {
 open class ServerSessionServerStreamingBase<InputType: Message, OutputType: Message>: ServerSessionBase, ServerSessionServerStreaming {
   public typealias ProviderBlock = (InputType, ServerSessionServerStreamingBase) throws -> Void
   public typealias ProviderBlock = (InputType, ServerSessionServerStreamingBase) throws -> Void
@@ -29,12 +31,13 @@ open class ServerSessionServerStreamingBase<InputType: Message, OutputType: Mess
     super.init(handler: handler)
     super.init(handler: handler)
   }
   }
 
 
-  public func send(_ response: OutputType, completion: ((Bool) -> Void)?) throws {
+  public func send(_ response: OutputType, completion: ((Error?) -> Void)?) throws {
     try handler.sendResponse(message: response.serializedData(), completion: completion)
     try handler.sendResponse(message: response.serializedData(), completion: completion)
   }
   }
 
 
   public func run(queue: DispatchQueue) throws {
   public func run(queue: DispatchQueue) throws {
     try handler.receiveMessage(initialMetadata: initialMetadata) { requestData in
     try handler.receiveMessage(initialMetadata: initialMetadata) { requestData in
+      // TODO(danielalm): Unify this behavior with `ServerSessionBidirectionalStreamingBase.run()`.
       if let requestData = requestData {
       if let requestData = requestData {
         do {
         do {
           let requestMessage = try InputType(serializedData: requestData)
           let requestMessage = try InputType(serializedData: requestData)
@@ -57,6 +60,10 @@ open class ServerSessionServerStreamingBase<InputType: Message, OutputType: Mess
       }
       }
     }
     }
   }
   }
+
+  public func waitForSendOperationsToFinish() {
+    handler.call.messageQueueEmpty.wait()
+  }
 }
 }
 
 
 /// Simple fake implementation of ServerSessionServerStreaming that returns a previously-defined set of results
 /// Simple fake implementation of ServerSessionServerStreaming that returns a previously-defined set of results
@@ -64,9 +71,11 @@ open class ServerSessionServerStreamingBase<InputType: Message, OutputType: Mess
 open class ServerSessionServerStreamingTestStub<OutputType: Message>: ServerSessionTestStub, ServerSessionServerStreaming {
 open class ServerSessionServerStreamingTestStub<OutputType: Message>: ServerSessionTestStub, ServerSessionServerStreaming {
   open var outputs: [OutputType] = []
   open var outputs: [OutputType] = []
 
 
-  open func send(_ response: OutputType, completion _: ((Bool) -> Void)?) throws {
+  open func send(_ response: OutputType, completion _: ((Error?) -> Void)?) throws {
     outputs.append(response)
     outputs.append(response)
   }
   }
 
 
   open func close() throws {}
   open func close() throws {}
+
+  open func waitForSendOperationsToFinish() {}
 }
 }

+ 4 - 16
Sources/gRPC/Handler.swift

@@ -30,7 +30,7 @@ public class Handler {
   public let requestMetadata: Metadata
   public let requestMetadata: Metadata
 
 
   /// A Call object that can be used to respond to the request
   /// A Call object that can be used to respond to the request
-  lazy var call: Call = {
+  private(set) lazy var call: Call = {
     Call(underlyingCall: cgrpc_handler_get_call(self.underlyingHandler),
     Call(underlyingCall: cgrpc_handler_get_call(self.underlyingHandler),
          owned: false,
          owned: false,
          completionQueue: self.completionQueue)
          completionQueue: self.completionQueue)
@@ -166,14 +166,7 @@ public class Handler {
   /// - Parameter completion: a completion handler to call after the message has been received
   /// - Parameter completion: a completion handler to call after the message has been received
   /// - Returns: a tuple containing status codes and a message (if available)
   /// - Returns: a tuple containing status codes and a message (if available)
   public func receiveMessage(completion: @escaping (Data?) throws -> Void) throws {
   public func receiveMessage(completion: @escaping (Data?) throws -> Void) throws {
-    let operations = OperationGroup(call: call, operations: [.receiveMessage]) { operationGroup in
-      if operationGroup.success {
-        try completion(operationGroup.receivedMessage()?.data())
-      } else {
-        try completion(nil)
-      }
-    }
-    try call.perform(operations)
+    try call.receiveMessage(completion: completion)
   }
   }
   
   
   /// Sends the response to a request
   /// Sends the response to a request
@@ -181,13 +174,8 @@ public class Handler {
   /// - Parameter message: the message to send
   /// - Parameter message: the message to send
   /// - Parameter completion: a completion handler to call after the response has been sent
   /// - Parameter completion: a completion handler to call after the response has been sent
   public func sendResponse(message: Data,
   public func sendResponse(message: Data,
-                           completion: ((Bool) throws -> Void)? = nil) throws {
-    let operations = OperationGroup(call: call,
-                                    operations: [.sendMessage(ByteBuffer(data: message))],
-                                    completion: completion != nil
-                                      ? { operationGroup in try completion?(operationGroup.success) }
-                                      : nil)
-    try call.perform(operations)
+                           completion: ((Error?) -> Void)? = nil) throws {
+    try call.sendMessage(data: message, completion: completion)
   }
   }
   
   
   /// Recognize when the client has closed a request
   /// Recognize when the client has closed a request

+ 2 - 2
Sources/protoc-gen-swiftgrpc/Generator-Server.swift

@@ -164,7 +164,7 @@ extension Generator {
     println("\(access) protocol \(methodSessionName): ServerSessionServerStreaming {")
     println("\(access) protocol \(methodSessionName): ServerSessionServerStreaming {")
     indent()
     indent()
     println("/// Send a message. Nonblocking.")
     println("/// Send a message. Nonblocking.")
-    println("func send(_ response: \(methodOutputName), completion: ((Bool) -> Void)?) throws")
+    println("func send(_ response: \(methodOutputName), completion: ((Error?) -> Void)?) throws")
     outdent()
     outdent()
     println("}")
     println("}")
     println()
     println()
@@ -182,7 +182,7 @@ extension Generator {
     println("func receive() throws -> \(methodInputName)")
     println("func receive() throws -> \(methodInputName)")
     println()
     println()
     println("/// Send a message. Nonblocking.")
     println("/// Send a message. Nonblocking.")
-    println("func send(_ response: \(methodOutputName), completion: ((Bool) -> Void)?) throws")
+    println("func send(_ response: \(methodOutputName), completion: ((Error?) -> Void)?) throws")
     println()
     println()
     println("/// Close a connection. Blocks until the connection is closed.")
     println("/// Close a connection. Blocks until the connection is closed.")
     println("func close() throws")
     println("func close() throws")