Browse Source

Merge pull request #201 from MrMage/wait-timeout

Add a `timeout` option to blocking `send()` and `receive()` calls
Tim Burks 7 years ago
parent
commit
1a9f7e22a7

+ 2 - 1
Makefile

@@ -24,7 +24,8 @@ test-echo:	all
 	kill -9 `cat echo.pid`
 	diff -u test.out Sources/Examples/Echo/test.gold
 
-test-plugin: all
+test-plugin:
+	swift build -v $(CFLAGS) --product protoc-gen-swiftgrpc
 	protoc Sources/Examples/Echo/echo.proto --proto_path=Sources/Examples/Echo --plugin=.build/debug/protoc-gen-swift --plugin=.build/debug/protoc-gen-swiftgrpc --swiftgrpc_out=/tmp --swiftgrpc_opt=TestStubs=true
 	diff -u /tmp/echo.grpc.swift Sources/Examples/Echo/Generated/echo.grpc.swift
 

+ 56 - 16
Sources/Examples/Echo/Generated/echo.grpc.swift

@@ -32,12 +32,17 @@ fileprivate final class Echo_EchoGetCallBase: ClientCallUnaryBase<Echo_EchoReque
 }
 
 internal protocol Echo_EchoExpandCall: ClientCallServerStreaming {
-  /// Call this to wait for a result. Blocking.
-  func receive() throws -> Echo_EchoResponse?
+  /// Do not call this directly, call `receive()` in the protocol extension below instead.
+  func _receive(timeout: DispatchTime) throws -> Echo_EchoResponse?
   /// Call this to wait for a result. Nonblocking.
   func receive(completion: @escaping (ResultOrRPCError<Echo_EchoResponse?>) -> Void) throws
 }
 
+internal extension Echo_EchoExpandCall {
+  /// Call this to wait for a result. Blocking.
+  func receive(timeout: DispatchTime = .distantFuture) throws -> Echo_EchoResponse? { return try self._receive(timeout: timeout) }
+}
+
 fileprivate final class Echo_EchoExpandCallBase: ClientCallServerStreamingBase<Echo_EchoRequest, Echo_EchoResponse>, Echo_EchoExpandCall {
   override class var method: String { return "/echo.Echo/Expand" }
 }
@@ -49,8 +54,8 @@ class Echo_EchoExpandCallTestStub: ClientCallServerStreamingTestStub<Echo_EchoRe
 internal protocol Echo_EchoCollectCall: ClientCallClientStreaming {
   /// Send a message to the stream. Nonblocking.
   func send(_ message: Echo_EchoRequest, completion: @escaping (Error?) -> Void) throws
-  /// Send a message to the stream and wait for the send operation to finish. Blocking.
-  func send(_ message: Echo_EchoRequest) throws
+  /// Do not call this directly, call `send()` in the protocol extension below instead.
+  func _send(_ message: Echo_EchoRequest, timeout: DispatchTime) throws
 
   /// Call this to close the connection and wait for a response. Blocking.
   func closeAndReceive() throws -> Echo_EchoResponse
@@ -58,6 +63,11 @@ internal protocol Echo_EchoCollectCall: ClientCallClientStreaming {
   func closeAndReceive(completion: @escaping (ResultOrRPCError<Echo_EchoResponse>) -> Void) throws
 }
 
+internal extension Echo_EchoCollectCall {
+  /// Send a message to the stream and wait for the send operation to finish. Blocking.
+  func send(_ message: Echo_EchoRequest, timeout: DispatchTime = .distantFuture) throws { try self._send(message, timeout: timeout) }
+}
+
 fileprivate final class Echo_EchoCollectCallBase: ClientCallClientStreamingBase<Echo_EchoRequest, Echo_EchoResponse>, Echo_EchoCollectCall {
   override class var method: String { return "/echo.Echo/Collect" }
 }
@@ -69,15 +79,15 @@ class Echo_EchoCollectCallTestStub: ClientCallClientStreamingTestStub<Echo_EchoR
 }
 
 internal protocol Echo_EchoUpdateCall: ClientCallBidirectionalStreaming {
-  /// Call this to wait for a result. Blocking.
-  func receive() throws -> Echo_EchoResponse?
+  /// Do not call this directly, call `receive()` in the protocol extension below instead.
+  func _receive(timeout: DispatchTime) throws -> Echo_EchoResponse?
   /// Call this to wait for a result. Nonblocking.
   func receive(completion: @escaping (ResultOrRPCError<Echo_EchoResponse?>) -> Void) throws
 
   /// Send a message to the stream. Nonblocking.
   func send(_ message: Echo_EchoRequest, completion: @escaping (Error?) -> Void) throws
-  /// Send a message to the stream and wait for the send operation to finish. Blocking.
-  func send(_ message: Echo_EchoRequest) throws
+  /// Do not call this directly, call `send()` in the protocol extension below instead.
+  func _send(_ message: Echo_EchoRequest, timeout: DispatchTime) throws
 
   /// Call this to close the sending connection. Blocking.
   func closeSend() throws
@@ -85,6 +95,16 @@ internal protocol Echo_EchoUpdateCall: ClientCallBidirectionalStreaming {
   func closeSend(completion: (() -> Void)?) throws
 }
 
+internal extension Echo_EchoUpdateCall {
+  /// Call this to wait for a result. Blocking.
+  func receive(timeout: DispatchTime = .distantFuture) throws -> Echo_EchoResponse? { return try self._receive(timeout: timeout) }
+}
+
+internal extension Echo_EchoUpdateCall {
+  /// Send a message to the stream and wait for the send operation to finish. Blocking.
+  func send(_ message: Echo_EchoRequest, timeout: DispatchTime = .distantFuture) throws { try self._send(message, timeout: timeout) }
+}
+
 fileprivate final class Echo_EchoUpdateCallBase: ClientCallBidirectionalStreamingBase<Echo_EchoRequest, Echo_EchoResponse>, Echo_EchoUpdateCall {
   override class var method: String { return "/echo.Echo/Update" }
 }
@@ -207,21 +227,26 @@ class Echo_EchoGetSessionTestStub: ServerSessionUnaryTestStub, Echo_EchoGetSessi
 internal protocol Echo_EchoExpandSession: ServerSessionServerStreaming {
   /// Send a message to the stream. Nonblocking.
   func send(_ message: Echo_EchoResponse, completion: @escaping (Error?) -> Void) throws
-  /// Send a message to the stream and wait for the send operation to finish. Blocking.
-  func send(_ message: Echo_EchoResponse) throws
+  /// Do not call this directly, call `send()` in the protocol extension below instead.
+  func _send(_ message: Echo_EchoResponse, timeout: DispatchTime) throws
 
   /// Close the connection and send the status. Non-blocking.
   /// You MUST call this method once you are done processing the request.
   func close(withStatus status: ServerStatus, completion: (() -> Void)?) throws
 }
 
+internal extension Echo_EchoExpandSession {
+  /// Send a message to the stream and wait for the send operation to finish. Blocking.
+  func send(_ message: Echo_EchoResponse, timeout: DispatchTime = .distantFuture) throws { try self._send(message, timeout: timeout) }
+}
+
 fileprivate final class Echo_EchoExpandSessionBase: ServerSessionServerStreamingBase<Echo_EchoRequest, Echo_EchoResponse>, Echo_EchoExpandSession {}
 
 class Echo_EchoExpandSessionTestStub: ServerSessionServerStreamingTestStub<Echo_EchoResponse>, Echo_EchoExpandSession {}
 
 internal protocol Echo_EchoCollectSession: ServerSessionClientStreaming {
-  /// Call this to wait for a result. Blocking.
-  func receive() throws -> Echo_EchoRequest?
+  /// Do not call this directly, call `receive()` in the protocol extension below instead.
+  func _receive(timeout: DispatchTime) throws -> Echo_EchoRequest?
   /// Call this to wait for a result. Nonblocking.
   func receive(completion: @escaping (ResultOrRPCError<Echo_EchoRequest?>) -> Void) throws
 
@@ -234,26 +259,41 @@ internal protocol Echo_EchoCollectSession: ServerSessionClientStreaming {
   func sendErrorAndClose(status: ServerStatus, completion: (() -> Void)?) throws
 }
 
+internal extension Echo_EchoCollectSession {
+  /// Call this to wait for a result. Blocking.
+  func receive(timeout: DispatchTime = .distantFuture) throws -> Echo_EchoRequest? { return try self._receive(timeout: timeout) }
+}
+
 fileprivate final class Echo_EchoCollectSessionBase: ServerSessionClientStreamingBase<Echo_EchoRequest, Echo_EchoResponse>, Echo_EchoCollectSession {}
 
 class Echo_EchoCollectSessionTestStub: ServerSessionClientStreamingTestStub<Echo_EchoRequest, Echo_EchoResponse>, Echo_EchoCollectSession {}
 
 internal protocol Echo_EchoUpdateSession: ServerSessionBidirectionalStreaming {
-  /// Call this to wait for a result. Blocking.
-  func receive() throws -> Echo_EchoRequest?
+  /// Do not call this directly, call `receive()` in the protocol extension below instead.
+  func _receive(timeout: DispatchTime) throws -> Echo_EchoRequest?
   /// Call this to wait for a result. Nonblocking.
   func receive(completion: @escaping (ResultOrRPCError<Echo_EchoRequest?>) -> Void) throws
 
   /// Send a message to the stream. Nonblocking.
   func send(_ message: Echo_EchoResponse, completion: @escaping (Error?) -> Void) throws
-  /// Send a message to the stream and wait for the send operation to finish. Blocking.
-  func send(_ message: Echo_EchoResponse) throws
+  /// Do not call this directly, call `send()` in the protocol extension below instead.
+  func _send(_ message: Echo_EchoResponse, timeout: DispatchTime) throws
 
   /// Close the connection and send the status. Non-blocking.
   /// You MUST call this method once you are done processing the request.
   func close(withStatus status: ServerStatus, completion: (() -> Void)?) throws
 }
 
+internal extension Echo_EchoUpdateSession {
+  /// Call this to wait for a result. Blocking.
+  func receive(timeout: DispatchTime = .distantFuture) throws -> Echo_EchoRequest? { return try self._receive(timeout: timeout) }
+}
+
+internal extension Echo_EchoUpdateSession {
+  /// Send a message to the stream and wait for the send operation to finish. Blocking.
+  func send(_ message: Echo_EchoResponse, timeout: DispatchTime = .distantFuture) throws { try self._send(message, timeout: timeout) }
+}
+
 fileprivate final class Echo_EchoUpdateSessionBase: ServerSessionBidirectionalStreamingBase<Echo_EchoRequest, Echo_EchoResponse>, Echo_EchoUpdateSession {}
 
 class Echo_EchoUpdateSessionTestStub: ServerSessionBidirectionalStreamingTestStub<Echo_EchoRequest, Echo_EchoResponse>, Echo_EchoUpdateSession {}

+ 3 - 3
Sources/SwiftGRPC/Runtime/ClientCallBidirectionalStreaming.swift

@@ -59,20 +59,20 @@ open class ClientCallBidirectionalStreamingTestStub<InputType: Message, OutputTy
   
   public init() {}
 
-  open func receive() throws -> OutputType? {
+  open func _receive(timeout: DispatchTime) throws -> OutputType? {
     defer { if !outputs.isEmpty { outputs.removeFirst() } }
     return outputs.first
   }
   
   open func receive(completion: @escaping (ResultOrRPCError<OutputType?>) -> Void) throws {
-    completion(.result(try self.receive()))
+    completion(.result(try self._receive(timeout: .distantFuture)))
   }
 
   open func send(_ message: InputType, completion _: @escaping (Error?) -> Void) throws {
     inputs.append(message)
   }
   
-  open func send(_ message: InputType) throws {
+  open func _send(_ message: InputType, timeout: DispatchTime) throws {
     inputs.append(message)
   }
 

+ 1 - 1
Sources/SwiftGRPC/Runtime/ClientCallClientStreaming.swift

@@ -76,7 +76,7 @@ open class ClientCallClientStreamingTestStub<InputType: Message, OutputType: Mes
     inputs.append(message)
   }
   
-  open func send(_ message: InputType) throws {
+  open func _send(_ message: InputType, timeout: DispatchTime) throws {
     inputs.append(message)
   }
 

+ 2 - 2
Sources/SwiftGRPC/Runtime/ClientCallServerStreaming.swift

@@ -45,13 +45,13 @@ open class ClientCallServerStreamingTestStub<OutputType: Message>: ClientCallSer
   
   public init() {}
   
-  open func receive() throws -> OutputType? {
+  open func _receive(timeout: DispatchTime) throws -> OutputType? {
     defer { if !outputs.isEmpty { outputs.removeFirst() } }
     return outputs.first
   }
   
   open func receive(completion: @escaping (ResultOrRPCError<OutputType?>) -> Void) throws {
-    completion(.result(try self.receive()))
+    completion(.result(try self._receive(timeout: .distantFuture)))
   }
 
   open func cancel() {}

+ 2 - 1
Sources/SwiftGRPC/Runtime/RPCError.swift

@@ -19,13 +19,14 @@ import Foundation
 /// Type for errors thrown from generated client code.
 public enum RPCError: Error {
   case invalidMessageReceived
+  case timedOut
   case callError(CallResult)
 }
 
 public extension RPCError {
   var callResult: CallResult? {
     switch self {
-    case .invalidMessageReceived: return nil
+    case .invalidMessageReceived, .timedOut: return nil
     case .callError(let callResult): return callResult
     }
   }

+ 3 - 3
Sources/SwiftGRPC/Runtime/ServerSessionBidirectionalStreaming.swift

@@ -69,20 +69,20 @@ open class ServerSessionBidirectionalStreamingTestStub<InputType: Message, Outpu
   open var outputs: [OutputType] = []
   open var status: ServerStatus?
 
-  open func receive() throws -> InputType? {
+  open func _receive(timeout: DispatchTime) throws -> InputType? {
     defer { if !inputs.isEmpty { inputs.removeFirst() } }
     return inputs.first
   }
   
   open func receive(completion: @escaping (ResultOrRPCError<InputType?>) -> Void) throws {
-    completion(.result(try self.receive()))
+    completion(.result(try self._receive(timeout: .distantFuture)))
   }
 
   open func send(_ message: OutputType, completion _: @escaping (Error?) -> Void) throws {
     outputs.append(message)
   }
 
-  open func send(_ message: OutputType) throws {
+  open func _send(_ message: OutputType, timeout: DispatchTime) throws {
     outputs.append(message)
   }
 

+ 2 - 2
Sources/SwiftGRPC/Runtime/ServerSessionClientStreaming.swift

@@ -75,13 +75,13 @@ open class ServerSessionClientStreamingTestStub<InputType: Message, OutputType:
   open var output: OutputType?
   open var status: ServerStatus?
 
-  open func receive() throws -> InputType? {
+  open func _receive(timeout: DispatchTime) throws -> InputType? {
     defer { if !inputs.isEmpty { inputs.removeFirst() } }
     return inputs.first
   }
   
   open func receive(completion: @escaping (ResultOrRPCError<InputType?>) -> Void) throws {
-    completion(.result(try self.receive()))
+    completion(.result(try self._receive(timeout: .distantFuture)))
   }
 
   open func sendAndClose(response: OutputType, status: ServerStatus, completion: (() -> Void)?) throws {

+ 1 - 1
Sources/SwiftGRPC/Runtime/ServerSessionServerStreaming.swift

@@ -72,7 +72,7 @@ open class ServerSessionServerStreamingTestStub<OutputType: Message>: ServerSess
     outputs.append(message)
   }
 
-  open func send(_ message: OutputType) throws {
+  open func _send(_ message: OutputType, timeout: DispatchTime) throws {
     outputs.append(message)
   }
 

+ 4 - 2
Sources/SwiftGRPC/Runtime/StreamReceiving.swift

@@ -43,14 +43,16 @@ extension StreamReceiving {
     }
   }
   
-  public func receive() throws -> ReceivedType? {
+  public func _receive(timeout: DispatchTime) throws -> ReceivedType? {
     var result: ResultOrRPCError<ReceivedType?>?
     let sem = DispatchSemaphore(value: 0)
     try receive {
       result = $0
       sem.signal()
     }
-    _ = sem.wait()
+    if sem.wait(timeout: timeout) == .timedOut {
+      throw RPCError.timedOut
+    }
     switch result! {
     case .result(let response): return response
     case .error(let error): throw error

+ 4 - 2
Sources/SwiftGRPC/Runtime/StreamSending.swift

@@ -29,14 +29,16 @@ extension StreamSending {
     try call.sendMessage(data: message.serializedData(), completion: completion)
   }
   
-  public func send(_ message: SentType) throws {
+  public func _send(_ message: SentType, timeout: DispatchTime) throws {
     var resultError: Error?
     let sem = DispatchSemaphore(value: 0)
     try send(message) {
       resultError = $0
       sem.signal()
     }
-    _ = sem.wait()
+    if sem.wait(timeout: timeout) == .timedOut {
+      throw RPCError.timedOut
+    }
     if let resultError = resultError {
       throw resultError
     }

+ 8 - 0
Sources/protoc-gen-swiftgrpc/Generator-Client.swift

@@ -60,6 +60,8 @@ extension Generator {
     outdent()
     println("}")
     println()
+    printStreamReceiveExtension(extendedType: callName, receivedType: methodOutputName)
+    println()
     println("fileprivate final class \(callName)Base: ClientCallServerStreamingBase<\(methodInputName), \(methodOutputName)>, \(callName) {")
     indent()
     println("override class var method: String { return \(methodPath) }")
@@ -88,6 +90,8 @@ extension Generator {
     outdent()
     println("}")
     println()
+    printStreamSendExtension(extendedType: callName, sentType: methodInputName)
+    println()
     println("fileprivate final class \(callName)Base: ClientCallClientStreamingBase<\(methodInputName), \(methodOutputName)>, \(callName) {")
     indent()
     println("override class var method: String { return \(methodPath) }")
@@ -120,6 +124,10 @@ extension Generator {
     outdent()
     println("}")
     println()
+    printStreamReceiveExtension(extendedType: callName, receivedType: methodOutputName)
+    println()
+    printStreamSendExtension(extendedType: callName, sentType: methodInputName)
+    println()
     println("fileprivate final class \(callName)Base: ClientCallBidirectionalStreamingBase<\(methodInputName), \(methodOutputName)>, \(callName) {")
     indent()
     println("override class var method: String { return \(methodPath) }")

+ 21 - 3
Sources/protoc-gen-swiftgrpc/Generator-Methods.swift

@@ -19,16 +19,34 @@ import SwiftProtobufPluginLibrary
 
 extension Generator {
   func printStreamReceiveMethods(receivedType: String) {
-    println("/// Call this to wait for a result. Blocking.")
-    println("func receive() throws -> \(receivedType)?")
+    println("/// Do not call this directly, call `receive()` in the protocol extension below instead.")
+    println("func _receive(timeout: DispatchTime) throws -> \(receivedType)?")
     println("/// Call this to wait for a result. Nonblocking.")
     println("func receive(completion: @escaping (ResultOrRPCError<\(receivedType)?>) -> Void) throws")
   }
   
+  func printStreamReceiveExtension(extendedType: String, receivedType: String) {
+    println("\(access) extension \(extendedType) {")
+    indent()
+    println("/// Call this to wait for a result. Blocking.")
+    println("func receive(timeout: DispatchTime = .distantFuture) throws -> \(receivedType)? { return try self._receive(timeout: timeout) }")
+    outdent()
+    println("}")
+  }
+  
   func printStreamSendMethods(sentType: String) {
     println("/// Send a message to the stream. Nonblocking.")
     println("func send(_ message: \(sentType), completion: @escaping (Error?) -> Void) throws")
+    println("/// Do not call this directly, call `send()` in the protocol extension below instead.")
+    println("func _send(_ message: \(sentType), timeout: DispatchTime) throws")
+  }
+  
+  func printStreamSendExtension(extendedType: String,sentType: String) {
+    println("\(access) extension \(extendedType) {")
+    indent()
     println("/// Send a message to the stream and wait for the send operation to finish. Blocking.")
-    println("func send(_ message: \(sentType)) throws")
+    println("func send(_ message: \(sentType), timeout: DispatchTime = .distantFuture) throws { try self._send(message, timeout: timeout) }")
+    outdent()
+    println("}")
   }
 }

+ 8 - 0
Sources/protoc-gen-swiftgrpc/Generator-Server.swift

@@ -161,6 +161,8 @@ extension Generator {
     outdent()
     println("}")
     println()
+    printStreamReceiveExtension(extendedType: methodSessionName, receivedType: methodInputName)
+    println()
     println("fileprivate final class \(methodSessionName)Base: ServerSessionClientStreamingBase<\(methodInputName), \(methodOutputName)>, \(methodSessionName) {}")
     if options.generateTestStubs {
       println()
@@ -183,6 +185,8 @@ extension Generator {
     outdent()
     println("}")
     println()
+    printStreamSendExtension(extendedType: methodSessionName, sentType: methodOutputName)
+    println()
     println("fileprivate final class \(methodSessionName)Base: ServerSessionServerStreamingBase<\(methodInputName), \(methodOutputName)>, \(methodSessionName) {}")
     if options.generateTestStubs {
       println()
@@ -201,6 +205,10 @@ extension Generator {
     outdent()
     println("}")
     println()
+    printStreamReceiveExtension(extendedType: methodSessionName, receivedType: methodInputName)
+    println()
+    printStreamSendExtension(extendedType: methodSessionName, sentType: methodOutputName)
+    println()
     println("fileprivate final class \(methodSessionName)Base: ServerSessionBidirectionalStreamingBase<\(methodInputName), \(methodOutputName)>, \(methodSessionName) {}")
     if options.generateTestStubs {
       println()

+ 23 - 0
Tests/SwiftGRPCTests/ClientTimeoutTests.swift

@@ -79,6 +79,29 @@ extension ClientTimeoutTests {
     waitForExpectations(timeout: defaultTimeout)
   }
   
+  func testBidirectionalStreamingTimeoutPassedToReceiveMethod() {
+    let completionHandlerExpectation = expectation(description: "final completion handler called")
+    let call = try! client.update { callResult in
+      XCTAssertEqual(.ok, callResult.statusCode)
+      completionHandlerExpectation.fulfill()
+    }
+    
+    do {
+      let result = try call.receive(timeout: .now() + .milliseconds(10))
+      XCTFail("should have thrown, received \(String(describing: result)) instead")
+    } catch let receiveError {
+      if case .timedOut = receiveError as! RPCError {
+        // This is the expected case - we need to formulate this as an if statement to use case-based pattern matching.
+      } else {
+        XCTFail("received error \(receiveError) instead of .timedOut")
+      }
+    }
+	
+	try! call.closeSend()
+    
+    waitForExpectations(timeout: defaultTimeout)
+  }
+  
   // FIXME(danielalm): Add support for setting a maximum timeout on the server, to prevent DoS attacks where clients
   // start a ton of calls, but never finish them (i.e. essentially leaking a connection on the server side).
 }