Browse Source

Change the `ServerSession`s such that errors thrown during operation are returned to the client.

If you throw an error of type `ServerStatus`, that is returned to the client. For other errors, `ServerStatus.processingError` is returned.

This change also extends the interface of the `ServerSession`s to accept a custom status and a completion handler when closing the stream.
It also makes closing the stream mandatory for `ServerSessionBidirectionalStreaming`, `ServerSessionClientStreaming`, and `ServerSessionServerStreaming` (in fact it was already mandatory for the first two).

Also adds tests to verify this behavior.
Daniel Alm 7 years ago
parent
commit
da28b9aec2

+ 3 - 2
Sources/Examples/Echo/EchoProvider.swift

@@ -38,6 +38,7 @@ class EchoProvider: Echo_EchoProvider {
       }
     }
     session.waitForSendOperationsToFinish()
+    try session.close(withStatus: .ok, completion: nil)
   }
 
   // collect collects a sequence of messages and returns them concatenated when the caller closes.
@@ -55,7 +56,7 @@ class EchoProvider: Echo_EchoProvider {
     }
     var response = Echo_EchoResponse()
     response.text = "Swift echo collect: " + parts.joined(separator: " ")
-    try session.sendAndClose(response)
+    try session.sendAndClose(response: response, status: .ok, completion: nil)
   }
 
   // update streams back messages as they are received in an input stream.
@@ -79,6 +80,6 @@ class EchoProvider: Echo_EchoProvider {
       }
     }
     session.waitForSendOperationsToFinish()
-    try session.close()
+    try session.close(withStatus: .ok, completion: nil)
   }
 }

+ 10 - 4
Sources/Examples/Echo/Generated/echo.grpc.swift

@@ -203,6 +203,10 @@ class Echo_EchoGetSessionTestStub: ServerSessionUnaryTestStub, Echo_EchoGetSessi
 internal protocol Echo_EchoExpandSession: ServerSessionServerStreaming {
   /// Call this to send each message in the request stream. Nonblocking.
   func send(_ message: Echo_EchoResponse, completion: @escaping (Error?) -> Void) throws
+
+  /// Close the connection and send the status. Non-blocking.
+  /// You MUST call this method once you are done processing the request.
+  func close(withStatus status: ServerStatus, completion: ((CallResult) -> Void)?) throws
 }
 
 fileprivate final class Echo_EchoExpandSessionBase: ServerSessionServerStreamingBase<Echo_EchoRequest, Echo_EchoResponse>, Echo_EchoExpandSession {}
@@ -215,8 +219,9 @@ internal protocol Echo_EchoCollectSession: ServerSessionClientStreaming {
   /// Call this to wait for a result. Nonblocking.
   func receive(completion: @escaping (ResultOrRPCError<Echo_EchoRequest?>) -> Void) throws
 
-  /// Send a response and close the connection.
-  func sendAndClose(_ response: Echo_EchoResponse) throws
+  /// Close the connection and send a single result. Non-blocking.
+  /// You MUST call this method once you are done processing the request.
+  func sendAndClose(response: Echo_EchoResponse, status: ServerStatus, completion: ((CallResult) -> Void)?) throws
 }
 
 fileprivate final class Echo_EchoCollectSessionBase: ServerSessionClientStreamingBase<Echo_EchoRequest, Echo_EchoResponse>, Echo_EchoCollectSession {}
@@ -232,8 +237,9 @@ internal protocol Echo_EchoUpdateSession: ServerSessionBidirectionalStreaming {
   /// Call this to send each message in the request stream. Nonblocking.
   func send(_ message: Echo_EchoResponse, completion: @escaping (Error?) -> Void) throws
 
-  /// Close a connection. Blocks until the connection is closed.
-  func close() throws
+  /// Close the connection and send the status. Non-blocking.
+  /// You MUST call this method once you are done processing the request.
+  func close(withStatus status: ServerStatus, completion: ((CallResult) -> Void)?) throws
 }
 
 fileprivate final class Echo_EchoUpdateSessionBase: ServerSessionBidirectionalStreamingBase<Echo_EchoRequest, Echo_EchoResponse>, Echo_EchoUpdateSession {}

+ 1 - 3
Sources/Examples/Simple/main.swift

@@ -105,9 +105,7 @@ func server() throws {
         "2": "two"
       ])
       try requestHandler.sendResponse(message: replyMessage.data(using: .utf8)!,
-                                      statusCode: .ok,
-                                      statusMessage: "OK",
-                                      trailingMetadata: trailingMetadataToSend)
+                                      status: ServerStatus(code: .ok, message: "OK", trailingMetadata: trailingMetadataToSend))
 
       print("------------------------------")
     } catch {

+ 13 - 0
Sources/SwiftGRPC/Core/CallResult.swift

@@ -41,6 +41,16 @@ public struct CallResult: CustomStringConvertible {
     trailingMetadata = op.receivedTrailingMetadata()
   }
   
+  fileprivate init(success: Bool, statusCode: StatusCode, statusMessage: String?, resultData: Data?,
+                   initialMetadata: Metadata?, trailingMetadata: Metadata?) {
+    self.success = success
+    self.statusCode = statusCode
+    self.statusMessage = statusMessage
+    self.resultData = resultData
+    self.initialMetadata = initialMetadata
+    self.trailingMetadata = trailingMetadata
+  }
+  
   public var description: String {
     var result = "\(success ? "successful" : "unsuccessful"), status \(statusCode)"
     if let statusMessage = self.statusMessage {
@@ -60,4 +70,7 @@ public struct CallResult: CustomStringConvertible {
     }
     return result
   }
+  
+  static let fakeOK = CallResult(success: true, statusCode: .ok, statusMessage: "OK", resultData: nil,
+                                 initialMetadata: nil, trailingMetadata: nil)
 }

+ 31 - 53
Sources/SwiftGRPC/Core/Handler.swift

@@ -95,80 +95,58 @@ public class Handler {
   /// - Parameter completion: a completion handler to call after the metadata has been sent
   public func sendMetadata(initialMetadata: Metadata,
                            completion: ((Bool) -> Void)? = nil) throws {
-    let operations = OperationGroup(call: call,
-                                    operations: [.sendInitialMetadata(initialMetadata)],
-                                    completion: completion != nil
-                                      ? { operationGroup in completion?(operationGroup.success) }
-                                      : nil)
-    try call.perform(operations)
+    try call.perform(OperationGroup(
+      call: call,
+      operations: [.sendInitialMetadata(initialMetadata)],
+      completion: completion != nil
+        ? { operationGroup in completion?(operationGroup.success) }
+        : nil))
   }
 
   /// Receive the message sent with a call
   ///
   public func receiveMessage(initialMetadata: Metadata,
                              completion: @escaping (Data?) -> Void) throws {
-    let operations = OperationGroup(call: call,
-                                    operations: [
-                                      .sendInitialMetadata(initialMetadata),
-                                      .receiveMessage
+    try call.perform(OperationGroup(
+      call: call,
+      operations: [
+        .sendInitialMetadata(initialMetadata),
+        .receiveMessage
     ]) { operationGroup in
       if operationGroup.success {
         completion(operationGroup.receivedMessage()?.data())
       } else {
         completion(nil)
       }
-    }
-    try call.perform(operations)
+    })
   }
 
   /// Sends the response to a request
-  ///
-  /// - Parameter message: the message to send
-  /// - Parameter statusCode: status code to send
-  /// - Parameter statusMessage: status message to send
-  /// - Parameter trailingMetadata: trailing metadata to send
-  public func sendResponse(message: Data,
-                           statusCode: StatusCode,
-                           statusMessage: String,
-                           trailingMetadata: Metadata) throws {
+  public func sendResponse(message: Data, status: ServerStatus,
+                           completion: ((CallResult) -> Void)? = nil) throws {
     let messageBuffer = ByteBuffer(data: message)
-    let operations = OperationGroup(call: call,
-                                    operations: [
-                                      .receiveCloseOnServer,
-                                      .sendStatusFromServer(statusCode, statusMessage, trailingMetadata),
-                                      .sendMessage(messageBuffer)
-    ]) { _ in
+    try call.perform(OperationGroup(
+      call: call,
+      operations: [
+        .sendMessage(messageBuffer),
+        .receiveCloseOnServer,
+        .sendStatusFromServer(status.code, status.message, status.trailingMetadata)
+    ]) { operationGroup in
+      completion?(CallResult(operationGroup))
       self.shutdown()
-    }
-    try call.perform(operations)
+    })
   }
 
   /// Send final status to the client
-  ///
-  /// - Parameter statusCode: status code to send
-  /// - Parameter statusMessage: status message to send
-  /// - Parameter trailingMetadata: trailing metadata to send
-  /// - Parameter completion: a completion handler to call after the status has been sent
-  public func sendStatus(statusCode: StatusCode,
-                         statusMessage: String,
-                         trailingMetadata: Metadata = Metadata(),
-                         completion: ((Bool) -> Void)? = nil) throws {
-    let operations = OperationGroup(call: call,
-                                    operations: [
-                                      .receiveCloseOnServer,
-                                      .sendStatusFromServer(statusCode, statusMessage, trailingMetadata)
+  public func sendStatus(_ status: ServerStatus, completion: ((CallResult) -> Void)? = nil) throws {
+    try call.perform(OperationGroup(
+      call: call,
+      operations: [
+        .receiveCloseOnServer,
+        .sendStatusFromServer(status.code, status.message, status.trailingMetadata)
     ]) { operationGroup in
-      completion?(operationGroup.success)
+      completion?(CallResult(operationGroup))
       self.shutdown()
-    }
-    try call.perform(operations)
-  }
-  
-  public func sendError(_ error: ServerErrorStatus,
-                        completion: ((Bool) -> Void)? = nil) throws {
-    try sendStatus(statusCode: error.statusCode,
-                   statusMessage: error.statusMessage,
-                   trailingMetadata: error.trailingMetadata,
-                   completion: completion)
+    })
   }
 }

+ 11 - 15
Sources/SwiftGRPC/Runtime/ServerSession.swift

@@ -18,35 +18,34 @@ import Dispatch
 import Foundation
 import SwiftProtobuf
 
-public struct ServerErrorStatus: Error {
-  public let statusCode: StatusCode
-  public let statusMessage: String
+public struct ServerStatus: Error {
+  public let code: StatusCode
+  public let message: String
   public let trailingMetadata: Metadata
   
-  public init(statusCode: StatusCode, statusMessage: String, trailingMetadata: Metadata = Metadata()) {
-    self.statusCode = statusCode
-    self.statusMessage = statusMessage
+  public init(code: StatusCode, message: String, trailingMetadata: Metadata = Metadata()) {
+    self.code = code
+    self.message = message
     self.trailingMetadata = trailingMetadata
   }
+  
+  public static let ok = ServerStatus(code: .ok, message: "OK")
+  public static let processingError = ServerStatus(code: .internalError, message: "unknown error processing request")
+  public static let noRequestData = ServerStatus(code: .invalidArgument, message: "no request data received")
+  public static let sendingInitialMetadataFailed = ServerStatus(code: .internalError, message: "sending initial metadata failed")
 }
 
 public protocol ServerSession: class {
   var requestMetadata: Metadata { get }
 
-  var statusCode: StatusCode { get set }
-  var statusMessage: String { get set }
   var initialMetadata: Metadata { get set }
-  var trailingMetadata: Metadata { get set }
 }
 
 open class ServerSessionBase: ServerSession {
   public var handler: Handler
   public var requestMetadata: Metadata { return handler.requestMetadata }
 
-  public var statusCode: StatusCode = .ok
-  public var statusMessage: String = "OK"
   public var initialMetadata: Metadata = Metadata()
-  public var trailingMetadata: Metadata = Metadata()
   
   public var call: Call { return handler.call }
 
@@ -58,10 +57,7 @@ open class ServerSessionBase: ServerSession {
 open class ServerSessionTestStub: ServerSession {
   open var requestMetadata = Metadata()
 
-  open var statusCode = StatusCode.ok
-  open var statusMessage = "OK"
   open var initialMetadata = Metadata()
-  open var trailingMetadata = Metadata()
 
   public init() {}
 }

+ 26 - 15
Sources/SwiftGRPC/Runtime/ServerSessionBidirectionalStreaming.swift

@@ -33,22 +33,29 @@ open class ServerSessionBidirectionalStreamingBase<InputType: Message, OutputTyp
     self.providerBlock = providerBlock
     super.init(handler: handler)
   }
-
-  public func close() throws {
-    let sem = DispatchSemaphore(value: 0)
-    try handler.sendStatus(statusCode: statusCode,
-                           statusMessage: statusMessage,
-                           trailingMetadata: trailingMetadata) { _ in sem.signal() }
-    _ = sem.wait()
-  }
-
+  
   public func run(queue: DispatchQueue) throws {
-    try handler.sendMetadata(initialMetadata: initialMetadata) { _ in
+    try handler.sendMetadata(initialMetadata: initialMetadata) { success in
       queue.async {
-        do {
-          try self.providerBlock(self)
-        } catch {
-          print("error \(error)")
+        var responseStatus: ServerStatus?
+        if success {
+          do {
+            try self.providerBlock(self)
+          } catch {
+            responseStatus = (error as? ServerStatus) ?? .processingError
+          }
+        } else {
+          print("ServerSessionBidirectionalStreamingBase.run sending initial metadata failed")
+          responseStatus = .sendingInitialMetadataFailed
+        }
+        
+        if let responseStatus = responseStatus {
+          // Error encountered, notify the client.
+          do {
+            try self.handler.sendStatus(responseStatus)
+          } catch {
+            print("ServerSessionBidirectionalStreamingBase.run error sending status: \(error)")
+          }
         }
       }
     }
@@ -60,6 +67,7 @@ open class ServerSessionBidirectionalStreamingBase<InputType: Message, OutputTyp
 open class ServerSessionBidirectionalStreamingTestStub<InputType: Message, OutputType: Message>: ServerSessionTestStub, ServerSessionBidirectionalStreaming {
   open var inputs: [InputType] = []
   open var outputs: [OutputType] = []
+  open var status: ServerStatus?
 
   open func receive() throws -> InputType? {
     defer { inputs.removeFirst() }
@@ -75,7 +83,10 @@ open class ServerSessionBidirectionalStreamingTestStub<InputType: Message, Outpu
     outputs.append(message)
   }
 
-  open func close() throws {}
+  open func close(withStatus status: ServerStatus, completion: ((CallResult) -> Void)?) throws {
+    self.status = status
+    completion?(.fakeOK)
+  }
 
   open func waitForSendOperationsToFinish() {}
 }

+ 29 - 13
Sources/SwiftGRPC/Runtime/ServerSessionClientStreaming.swift

@@ -30,21 +30,34 @@ open class ServerSessionClientStreamingBase<InputType: Message, OutputType: Mess
     self.providerBlock = providerBlock
     super.init(handler: handler)
   }
-
-  public func sendAndClose(_ response: OutputType) throws {
-    try handler.sendResponse(message: response.serializedData(),
-                             statusCode: statusCode,
-                             statusMessage: statusMessage,
-                             trailingMetadata: trailingMetadata)
+  
+  public func sendAndClose(response: OutputType, status: ServerStatus = .ok,
+                           completion: ((CallResult) -> Void)? = nil) throws {
+    try handler.sendResponse(message: response.serializedData(), status: status, completion: completion)
   }
 
   public func run(queue: DispatchQueue) throws {
-    try handler.sendMetadata(initialMetadata: initialMetadata) { _ in
+    try handler.sendMetadata(initialMetadata: initialMetadata) { success in
       queue.async {
-        do {
-          try self.providerBlock(self)
-        } catch {
-          print("error \(error)")
+        var responseStatus: ServerStatus?
+        if success {
+          do {
+            try self.providerBlock(self)
+          } catch {
+            responseStatus = (error as? ServerStatus) ?? .processingError
+          }
+        } else {
+          print("ServerSessionClientStreamingBase.run sending initial metadata failed")
+          responseStatus = .sendingInitialMetadataFailed
+        }
+        
+        if let responseStatus = responseStatus {
+          // Error encountered, notify the client.
+          do {
+            try self.handler.sendStatus(responseStatus)
+          } catch {
+            print("ServerSessionClientStreamingBase.run error sending status: \(error)")
+          }
         }
       }
     }
@@ -56,6 +69,7 @@ open class ServerSessionClientStreamingBase<InputType: Message, OutputType: Mess
 open class ServerSessionClientStreamingTestStub<InputType: Message, OutputType: Message>: ServerSessionTestStub, ServerSessionClientStreaming {
   open var inputs: [InputType] = []
   open var output: OutputType?
+  open var status: ServerStatus?
 
   open func receive() throws -> InputType? {
     defer { inputs.removeFirst() }
@@ -67,8 +81,10 @@ open class ServerSessionClientStreamingTestStub<InputType: Message, OutputType:
     inputs.removeFirst()
   }
 
-  open func sendAndClose(_ response: OutputType) throws {
-    output = response
+  open func sendAndClose(response: OutputType, status: ServerStatus, completion: ((CallResult) -> Void)?) throws {
+    self.output = response
+    self.status = status
+    completion?(.fakeOK)
   }
 
   open func close() throws {}

+ 26 - 20
Sources/SwiftGRPC/Runtime/ServerSessionServerStreaming.swift

@@ -32,28 +32,30 @@ open class ServerSessionServerStreamingBase<InputType: Message, OutputType: Mess
     self.providerBlock = providerBlock
     super.init(handler: handler)
   }
-
+  
   public func run(queue: DispatchQueue) throws {
     try handler.receiveMessage(initialMetadata: initialMetadata) { requestData in
-      // TODO(danielalm): Unify this behavior with `ServerSessionBidirectionalStreamingBase.run()`.
-      if let requestData = requestData {
-        do {
-          let requestMessage = try InputType(serializedData: requestData)
-          // to keep providers from blocking the server thread,
-          // we dispatch them to another queue.
-          queue.async {
-            do {
-              try self.providerBlock(requestMessage, self)
-              try self.handler.sendStatus(statusCode: self.statusCode,
-                                          statusMessage: self.statusMessage,
-                                          trailingMetadata: self.trailingMetadata,
-                                          completion: nil)
-            } catch {
-              print("error: \(error)")
-            }
+      queue.async {
+        var responseStatus: ServerStatus?
+        if let requestData = requestData {
+          do {
+            let requestMessage = try InputType(serializedData: requestData)
+            try self.providerBlock(requestMessage, self)
+          } catch {
+            responseStatus = (error as? ServerStatus) ?? .processingError
+          }
+        } else {
+          print("ServerSessionServerStreamingBase.run empty request data")
+          responseStatus = .noRequestData
+        }
+        
+        if let responseStatus = responseStatus {
+          // Error encountered, notify the client.
+          do {
+            try self.handler.sendStatus(responseStatus)
+          } catch {
+            print("ServerSessionServerStreamingBase.run error sending status: \(error)")
           }
-        } catch {
-          print("error: \(error)")
         }
       }
     }
@@ -64,12 +66,16 @@ open class ServerSessionServerStreamingBase<InputType: Message, OutputType: Mess
 /// and stores sent values for later verification.
 open class ServerSessionServerStreamingTestStub<OutputType: Message>: ServerSessionTestStub, ServerSessionServerStreaming {
   open var outputs: [OutputType] = []
+  open var status: ServerStatus?
 
   open func send(_ message: OutputType, completion _: @escaping (Error?) -> Void) throws {
     outputs.append(message)
   }
 
-  open func close() throws {}
+  open func close(withStatus status: ServerStatus, completion: ((CallResult) -> Void)?) throws {
+    self.status = status
+    completion?(.fakeOK)
+  }
 
   open func waitForSendOperationsToFinish() {}
 }

+ 22 - 21
Sources/SwiftGRPC/Runtime/ServerSessionUnary.swift

@@ -21,6 +21,8 @@ import SwiftProtobuf
 public protocol ServerSessionUnary: ServerSession {}
 
 open class ServerSessionUnaryBase<InputType: Message, OutputType: Message>: ServerSessionBase, ServerSessionUnary {
+  public typealias SentType = OutputType
+  
   public typealias ProviderBlock = (InputType, ServerSessionUnaryBase) throws -> OutputType
   private var providerBlock: ProviderBlock
 
@@ -29,31 +31,30 @@ open class ServerSessionUnaryBase<InputType: Message, OutputType: Message>: Serv
     super.init(handler: handler)
   }
   
-  public func run(queue _: DispatchQueue) throws {
+  public func run(queue: DispatchQueue) throws {
     try handler.receiveMessage(initialMetadata: initialMetadata) { requestData in
-      guard let requestData = requestData else {
-        print("ServerSessionUnaryBase.run empty request data")
-        do {
-          try self.handler.sendStatus(statusCode: .invalidArgument,
-                                      statusMessage: "no request data received")
-        } catch {
-          print("ServerSessionUnaryBase.run error sending status: \(error)")
+      queue.async {
+        let responseStatus: ServerStatus
+        if let requestData = requestData {
+          do {
+            let requestMessage = try InputType(serializedData: requestData)
+            let responseMessage = try self.providerBlock(requestMessage, self)
+            try self.handler.call.sendMessage(data: responseMessage.serializedData()) {
+              guard let error = $0
+                else { return }
+              print("ServerSessionUnaryBase.run error sending response: \(error)")
+            }
+            responseStatus = .ok
+          } catch {
+            responseStatus = (error as? ServerStatus) ?? .processingError
+          }
+        } else {
+          print("ServerSessionUnaryBase.run empty request data")
+          responseStatus = .noRequestData
         }
-        return
-      }
-      do {
-        let requestMessage = try InputType(serializedData: requestData)
-        let replyMessage = try self.providerBlock(requestMessage, self)
-        try self.handler.sendResponse(message: replyMessage.serializedData(),
-                                      statusCode: self.statusCode,
-                                      statusMessage: self.statusMessage,
-                                      trailingMetadata: self.trailingMetadata)
-      } catch {
-        print("ServerSessionUnaryBase.run error processing request: \(error)")
         
         do {
-          try self.handler.sendError((error as? ServerErrorStatus)
-            ?? ServerErrorStatus(statusCode: .unknown, statusMessage: "unknown error processing request"))
+          try self.handler.sendStatus(responseStatus)
         } catch {
           print("ServerSessionUnaryBase.run error sending status: \(error)")
         }

+ 6 - 0
Sources/SwiftGRPC/Runtime/StreamSending.swift

@@ -33,3 +33,9 @@ extension StreamSending {
     call.messageQueueEmpty.wait()
   }
 }
+
+extension StreamSending where Self: ServerSessionBase {
+  public func close(withStatus status: ServerStatus = .ok, completion: ((CallResult) -> Void)? = nil) throws {
+    try handler.sendStatus(status, completion: completion)
+  }
+}

+ 16 - 4
Sources/protoc-gen-swiftgrpc/Generator-Server.swift

@@ -141,14 +141,19 @@ extension Generator {
       println("class \(methodSessionName)TestStub: ServerSessionUnaryTestStub, \(methodSessionName) {}")
     }
   }
+  
+  private func printServerMethodSendAndClose(sentType: String) {
+    println("/// Close the connection and send a single result. Non-blocking.")
+    println("/// You MUST call this method once you are done processing the request.")
+    println("func sendAndClose(response: \(sentType), status: ServerStatus, completion: ((CallResult) -> Void)?) throws")
+  }
 
   private func printServerMethodClientStreaming() {
     println("\(access) protocol \(methodSessionName): ServerSessionClientStreaming {")
     indent()
     printStreamReceiveMethods(receivedType: methodInputName)
     println()
-    println("/// Send a response and close the connection.")
-    println("func sendAndClose(_ response: \(methodOutputName)) throws")
+    printServerMethodSendAndClose(sentType: methodOutputName)
     outdent()
     println("}")
     println()
@@ -159,10 +164,18 @@ extension Generator {
     }
   }
 
+  private func printServerMethodClose() {
+    println("/// Close the connection and send the status. Non-blocking.")
+    println("/// You MUST call this method once you are done processing the request.")
+    println("func close(withStatus status: ServerStatus, completion: ((CallResult) -> Void)?) throws")
+  }
+  
   private func printServerMethodServerStreaming() {
     println("\(access) protocol \(methodSessionName): ServerSessionServerStreaming {")
     indent()
     printStreamSendMethods(sentType: methodOutputName)
+    println()
+    printServerMethodClose()
     outdent()
     println("}")
     println()
@@ -180,8 +193,7 @@ extension Generator {
     println()
     printStreamSendMethods(sentType: methodOutputName)
     println()
-    println("/// Close a connection. Blocks until the connection is closed.")
-    println("func close() throws")
+    printServerMethodClose()
     outdent()
     println("}")
     println()

+ 1 - 0
Tests/LinuxMain.swift

@@ -21,5 +21,6 @@ XCTMain([
   testCase(ClientTimeoutTests.allTests),
   testCase(ConnectionFailureTests.allTests),
   testCase(EchoTests.allTests),
+  testCase(ServerThrowingTests.allTests),
   testCase(ServerTimeoutTests.allTests)
 ])

+ 12 - 12
Tests/SwiftGRPCTests/GRPCTests.swift

@@ -322,14 +322,14 @@ func handleUnary(requestHandler: Handler, requestCount: Int) throws {
     let replyMessage = serverText
     let trailingMetadataToSend = Metadata(trailingServerMetadata)
     try requestHandler.sendResponse(message: replyMessage.data(using: .utf8)!,
-                                    statusCode: evenStatusCode,
-                                    statusMessage: eventStatusMessage,
-                                    trailingMetadata: trailingMetadataToSend)
+                                    status: ServerStatus(code: evenStatusCode,
+                                                         message: eventStatusMessage,
+                                                         trailingMetadata: trailingMetadataToSend))
   } else {
     let trailingMetadataToSend = Metadata(trailingServerMetadata)
-    try requestHandler.sendStatus(statusCode: oddStatusCode,
-                                  statusMessage: oddStatusMessage,
-                                  trailingMetadata: trailingMetadataToSend)
+    try requestHandler.sendStatus(ServerStatus(code: oddStatusCode,
+                                               message: oddStatusMessage,
+                                               trailingMetadata: trailingMetadataToSend))
   }
 }
 
@@ -357,9 +357,9 @@ func handleServerStream(requestHandler: Handler) throws {
   }
 
   let trailingMetadataToSend = Metadata(trailingServerMetadata)
-  try requestHandler.sendStatus(statusCode: StatusCode.outOfRange,
-                                statusMessage: "Out of range",
-                                trailingMetadata: trailingMetadataToSend)
+  try requestHandler.sendStatus(ServerStatus(code: .outOfRange,
+                                             message: "Out of range",
+                                             trailingMetadata: trailingMetadataToSend))
 }
 
 func handleBiDiStream(requestHandler: Handler) throws {
@@ -400,8 +400,8 @@ func handleBiDiStream(requestHandler: Handler) throws {
 
   let trailingMetadataToSend = Metadata(trailingServerMetadata)
   let sem = DispatchSemaphore(value: 0)
-  try requestHandler.sendStatus(statusCode: StatusCode.resourceExhausted,
-                                statusMessage: "Resource Exhausted",
-                                trailingMetadata: trailingMetadataToSend) { _ in sem.signal() }
+  try requestHandler.sendStatus(ServerStatus(code: .resourceExhausted,
+                                             message: "Resource Exhausted",
+                                             trailingMetadata: trailingMetadataToSend)) { _ in sem.signal() }
   _ = sem.wait()
 }

+ 137 - 0
Tests/SwiftGRPCTests/ServerThrowingTests.swift

@@ -0,0 +1,137 @@
+/*
+ * Copyright 2018, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import Dispatch
+import Foundation
+@testable import SwiftGRPC
+import XCTest
+
+fileprivate let testStatus = ServerStatus(code: .permissionDenied, message: "custom status message")
+
+fileprivate class StatusThrowingProvider: Echo_EchoProvider {
+  func get(request: Echo_EchoRequest, session _: Echo_EchoGetSession) throws -> Echo_EchoResponse {
+    throw testStatus
+  }
+  
+  func expand(request: Echo_EchoRequest, session: Echo_EchoExpandSession) throws {
+    throw testStatus
+  }
+  
+  func collect(session: Echo_EchoCollectSession) throws {
+    throw testStatus
+  }
+  
+  func update(session: Echo_EchoUpdateSession) throws {
+    throw testStatus
+  }
+}
+
+class ServerThrowingTests: BasicEchoTestCase {
+  static var allTests: [(String, (ServerThrowingTests) -> () throws -> Void)] {
+    return [
+      ("testServerThrowsUnary", testServerThrowsUnary),
+      ("testServerThrowsClientStreaming", testServerThrowsClientStreaming),
+      ("testServerThrowsServerStreaming", testServerThrowsServerStreaming),
+      ("testServerThrowsBidirectionalStreaming", testServerThrowsBidirectionalStreaming)
+    ]
+  }
+  
+  override func makeProvider() -> Echo_EchoProvider { return StatusThrowingProvider() }
+}
+
+extension ServerThrowingTests {
+  func testServerThrowsUnary() {
+    do {
+      _ = try client.get(Echo_EchoRequest(text: "foo")).text
+      XCTFail("should have thrown")
+    } catch {
+      guard case let .callError(callResult) = error as! RPCError
+        else { XCTFail("unexpected error \(error)"); return }
+      XCTAssertEqual(.permissionDenied, callResult.statusCode)
+      XCTAssertEqual("custom status message", callResult.statusMessage)
+    }
+  }
+  
+  func testServerThrowsClientStreaming() {
+    let completionHandlerExpectation = expectation(description: "final completion handler called")
+    let call = try! client.collect { callResult in
+      XCTAssertEqual(.permissionDenied, callResult.statusCode)
+      XCTAssertEqual("custom status message", callResult.statusMessage)
+      completionHandlerExpectation.fulfill()
+    }
+    
+    let sendExpectation = expectation(description: "send completion handler 1 called")
+    try! call.send(Echo_EchoRequest(text: "foo")) { [sendExpectation] in
+      // The server only times out later in its lifecycle, so we shouldn't get an error when trying to send a message.
+      XCTAssertNil($0)
+      sendExpectation.fulfill()
+    }
+    call.waitForSendOperationsToFinish()
+    
+    do {
+      _ = try call.closeAndReceive()
+      XCTFail("should have thrown")
+    } catch let receiveError {
+      XCTAssertEqual(.unknown, (receiveError as! RPCError).callResult!.statusCode)
+    }
+    
+    waitForExpectations(timeout: defaultTimeout)
+  }
+  
+  func testServerThrowsServerStreaming() {
+    let completionHandlerExpectation = expectation(description: "completion handler called")
+    let call = try! client.expand(Echo_EchoRequest(text: "foo bar baz")) { callResult in
+      XCTAssertEqual(.permissionDenied, callResult.statusCode)
+      XCTAssertEqual("custom status message", callResult.statusMessage)
+      completionHandlerExpectation.fulfill()
+    }
+    
+    // TODO(danielalm): Why doesn't `call.receive()` throw once the call times out?
+    XCTAssertNil(try! call.receive())
+    
+    waitForExpectations(timeout: defaultTimeout)
+  }
+  
+  func testServerThrowsBidirectionalStreaming() {
+    let completionHandlerExpectation = expectation(description: "completion handler called")
+    let call = try! client.update { callResult in
+      XCTAssertEqual(.permissionDenied, callResult.statusCode)
+      XCTAssertEqual("custom status message", callResult.statusMessage)
+      completionHandlerExpectation.fulfill()
+    }
+    
+    let sendExpectation = expectation(description: "send completion handler 1 called")
+    try! call.send(Echo_EchoRequest(text: "foo")) { [sendExpectation] in
+      // The server only times out later in its lifecycle, so we shouldn't get an error when trying to send a message.
+      XCTAssertNil($0)
+      sendExpectation.fulfill()
+    }
+    call.waitForSendOperationsToFinish()
+    
+    // FIXME(danielalm): Why does `call.receive()` only throw on Linux (but not macOS) once the call times out?
+    #if os(Linux)
+    do {
+      _ = try call.receive()
+      XCTFail("should have thrown")
+    } catch let receiveError {
+      XCTAssertEqual(.unknown, (receiveError as! RPCError).callResult!.statusCode)
+    }
+    #else
+    XCTAssertNil(try! call.receive())
+    #endif
+    
+    waitForExpectations(timeout: defaultTimeout)
+  }
+}