Browse Source

Enforce request cardinality for unary-request calls also for the case of zero request messages being sent. (#392)

Otherwise, the server will never respond to a call that gets closed without the client sending a response.

In addition, we introduce a method `sendErrorStatus` (happy to discuss naming) on `BaseCallHandler` that sends an error status to the client while ensuring that all call context promises are fulfilled. This method is required (and needs to be overridden) because only the concrete call subclass knows which promises need to be fulfilled.
Daniel Alm 6 years ago
parent
commit
d4a6366bf7

+ 13 - 3
Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift

@@ -19,7 +19,7 @@ public class BaseCallHandler<RequestMessage: Message, ResponseMessage: Message>:
   /// Called when the client has half-closed the stream, indicating that they won't send any further data.
   ///
   /// Overridden by subclasses if the "end-of-stream" event is relevant.
-  public func endOfStreamReceived() { }
+  public func endOfStreamReceived() throws { }
 
   /// Whether this handler can still write messages to the client.
   private var serverCanWrite = true
@@ -30,6 +30,12 @@ public class BaseCallHandler<RequestMessage: Message, ResponseMessage: Message>:
   public init(errorDelegate: ServerErrorDelegate?) {
     self.errorDelegate = errorDelegate
   }
+  
+  /// Sends an error status to the client while ensuring that all call context promises are fulfilled.
+  /// Because only the concrete call subclass knows which promises need to be fulfilled, this method needs to be overridden.
+  func sendErrorStatus(_ status: GRPCStatus) {
+    fatalError("needs to be overridden")
+  }
 }
 
 extension BaseCallHandler: ChannelInboundHandler {
@@ -43,7 +49,7 @@ extension BaseCallHandler: ChannelInboundHandler {
 
     let transformed = errorDelegate?.transform(error) ?? error
     let status = (transformed as? GRPCStatusTransformable)?.asGRPCStatus() ?? GRPCStatus.processingError
-    self.write(ctx: ctx, data: NIOAny(GRPCServerResponsePart<ResponseMessage>.status(status)), promise: nil)
+    sendErrorStatus(status)
   }
 
   public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
@@ -60,7 +66,11 @@ extension BaseCallHandler: ChannelInboundHandler {
       }
 
     case .end:
-      endOfStreamReceived()
+      do {
+        try endOfStreamReceived()
+      } catch {
+        self.errorCaught(ctx: ctx, error: error)
+      }
     }
   }
 }

+ 5 - 1
Sources/SwiftGRPCNIO/CallHandlers/BidirectionalStreamingCallHandler.swift

@@ -36,9 +36,13 @@ public class BidirectionalStreamingCallHandler<RequestMessage: Message, Response
     }
   }
 
-  public override func endOfStreamReceived() {
+  public override func endOfStreamReceived() throws {
     eventObserver?.whenSuccess { observer in
       observer(.end)
     }
   }
+  
+  override func sendErrorStatus(_ status: GRPCStatus) {
+    context?.statusPromise.fail(error: status)
+  }
 }

+ 5 - 1
Sources/SwiftGRPCNIO/CallHandlers/ClientStreamingCallHandler.swift

@@ -35,9 +35,13 @@ public class ClientStreamingCallHandler<RequestMessage: Message, ResponseMessage
     }
   }
   
-  public override func endOfStreamReceived() {
+  public override func endOfStreamReceived() throws {
     eventObserver?.whenSuccess { observer in
       observer(.end)
     }
   }
+  
+  override func sendErrorStatus(_ status: GRPCStatus) {
+    context?.responsePromise.fail(error: status)
+  }
 }

+ 11 - 1
Sources/SwiftGRPCNIO/CallHandlers/ServerStreamingCallHandler.swift

@@ -28,7 +28,7 @@ public class ServerStreamingCallHandler<RequestMessage: Message, ResponseMessage
   public override func processMessage(_ message: RequestMessage) throws {
     guard let eventObserver = self.eventObserver,
       let context = self.context else {
-        throw GRPCError.server(.requestCardinalityViolation)
+        throw GRPCError.server(.tooManyRequests)
     }
 
     let resultFuture = eventObserver(message)
@@ -37,4 +37,14 @@ public class ServerStreamingCallHandler<RequestMessage: Message, ResponseMessage
       .cascade(promise: context.statusPromise)
     self.eventObserver = nil
   }
+  
+  public override func endOfStreamReceived() throws {
+    if self.eventObserver != nil {
+      throw GRPCError.server(.noRequestsButOneExpected)
+    }
+  }
+  
+  override func sendErrorStatus(_ status: GRPCStatus) {
+    context?.statusPromise.fail(error: status)
+  }
 }

+ 11 - 1
Sources/SwiftGRPCNIO/CallHandlers/UnaryCallHandler.swift

@@ -29,7 +29,7 @@ public class UnaryCallHandler<RequestMessage: Message, ResponseMessage: Message>
   public override func processMessage(_ message: RequestMessage) throws {
     guard let eventObserver = self.eventObserver,
       let context = self.context else {
-      throw GRPCError.server(.requestCardinalityViolation)
+      throw GRPCError.server(.tooManyRequests)
     }
     
     let resultFuture = eventObserver(message)
@@ -38,4 +38,14 @@ public class UnaryCallHandler<RequestMessage: Message, ResponseMessage: Message>
       .cascade(promise: context.responsePromise)
     self.eventObserver = nil
   }
+  
+  public override func endOfStreamReceived() throws {
+    if self.eventObserver != nil {
+      throw GRPCError.server(.noRequestsButOneExpected)
+    }
+  }
+  
+  override func sendErrorStatus(_ status: GRPCStatus) {
+    context?.responsePromise.fail(error: status)
+  }
 }

+ 8 - 2
Sources/SwiftGRPCNIO/GRPCError.swift

@@ -87,8 +87,11 @@ public enum GRPCServerError: Error, Equatable {
   /// It was not possible to serialize the response protobuf.
   case responseProtoSerializationFailure
 
+  /// Zero requests were sent for a unary-request call.
+  case noRequestsButOneExpected
+  
   /// More than one request was sent for a unary-request call.
-  case requestCardinalityViolation
+  case tooManyRequests
 
   /// The server received a message when it was not in a writable state.
   case serverNotWritable
@@ -143,7 +146,10 @@ extension GRPCServerError: GRPCStatusTransformable {
     case .responseProtoSerializationFailure:
       return GRPCStatus(code: .internalError, message: "could not serialize response proto")
 
-    case .requestCardinalityViolation:
+    case .noRequestsButOneExpected:
+      return GRPCStatus(code: .unimplemented, message: "request cardinality violation; method requires exactly one request but client sent none")
+      
+    case .tooManyRequests:
       return GRPCStatus(code: .unimplemented, message: "request cardinality violation; method requires exactly one request but client sent more")
 
     case .serverNotWritable:

+ 32 - 7
Tests/SwiftGRPCNIOTests/NIOServerWebTests.swift

@@ -25,6 +25,7 @@ class NIOServerWebTests: NIOBasicEchoTestCase {
   static var allTests: [(String, (NIOServerWebTests) -> () throws -> Void)] {
     return [
       ("testUnary", testUnary),
+      ("testUnaryWithoutRequestMessage", testUnaryWithoutRequestMessage),
       //! FIXME: Broken on Linux: https://github.com/grpc/grpc-swift/issues/382
       // ("testUnaryLotsOfRequests", testUnaryLotsOfRequests),
       ("testServerStreaming", testServerStreaming),
@@ -43,8 +44,8 @@ class NIOServerWebTests: NIOBasicEchoTestCase {
     return data
   }
 
-  private func gRPCWebOKTrailers() -> Data {
-    var data = "grpc-status: 0\r\ngrpc-message: OK".data(using: .utf8)!
+  private func gRPCWebTrailers(status: Int = 0, message: String = "OK") -> Data {
+    var data = "grpc-status: \(status)\r\ngrpc-message: \(message)".data(using: .utf8)!
     // Add the gRPC prefix with the compression byte and the 4 length bytes.
     for i in 0..<4 {
       data.insert(UInt8((data.count >> (i * 8)) & 0xFF), at: 0)
@@ -53,13 +54,15 @@ class NIOServerWebTests: NIOBasicEchoTestCase {
     return data
   }
 
-  private func sendOverHTTP1(rpcMethod: String, message: String, handler: @escaping (Data?, Error?) -> Void) {
+  private func sendOverHTTP1(rpcMethod: String, message: String?, handler: @escaping (Data?, Error?) -> Void) {
     let serverURL = URL(string: "http://localhost:5050/echo.Echo/\(rpcMethod)")!
     var request = URLRequest(url: serverURL)
     request.httpMethod = "POST"
     request.setValue("application/grpc-web-text", forHTTPHeaderField: "content-type")
 
-    request.httpBody = gRPCEncodedEchoRequest(message).base64EncodedData()
+    if let message = message {
+      request.httpBody = gRPCEncodedEchoRequest(message).base64EncodedData()
+    }
 
     let sem = DispatchSemaphore(value: 0)
     URLSession.shared.dataTask(with: request) { (data, response, error) in
@@ -73,7 +76,7 @@ class NIOServerWebTests: NIOBasicEchoTestCase {
 extension NIOServerWebTests {
   func testUnary() {
     let message = "hello, world!"
-    let expectedData = gRPCEncodedEchoRequest("Swift echo get: \(message)") + gRPCWebOKTrailers()
+    let expectedData = gRPCEncodedEchoRequest("Swift echo get: \(message)") + gRPCWebTrailers()
     let expectedResponse = expectedData.base64EncodedString()
 
     let completionHandlerExpectation = expectation(description: "completion handler called")
@@ -83,6 +86,28 @@ extension NIOServerWebTests {
       if let data = data {
         XCTAssertEqual(String(data: data, encoding: .utf8), expectedResponse)
         completionHandlerExpectation.fulfill()
+      } else {
+        XCTFail("no data returned")
+      }
+    }
+
+    waitForExpectations(timeout: defaultTestTimeout)
+  }
+  
+  func testUnaryWithoutRequestMessage() {
+    let expectedData = gRPCWebTrailers(
+      status: 12, message: "request cardinality violation; method requires exactly one request but client sent none")
+    let expectedResponse = expectedData.base64EncodedString()
+
+    let completionHandlerExpectation = expectation(description: "completion handler called")
+
+    sendOverHTTP1(rpcMethod: "Get", message: nil) { data, error in
+      XCTAssertNil(error)
+      if let data = data {
+        XCTAssertEqual(String(data: data, encoding: .utf8), expectedResponse)
+        completionHandlerExpectation.fulfill()
+      } else {
+        XCTFail("no data returned")
       }
     }
 
@@ -104,7 +129,7 @@ extension NIOServerWebTests {
 
     for i in 0..<numberOfRequests {
       let message = "foo \(i)"
-      let expectedData = gRPCEncodedEchoRequest("Swift echo get: \(message)") + gRPCWebOKTrailers()
+      let expectedData = gRPCEncodedEchoRequest("Swift echo get: \(message)") + gRPCWebTrailers()
       let expectedResponse = expectedData.base64EncodedString()
       sendOverHTTP1(rpcMethod: "Get", message: message) { data, error in
         XCTAssertNil(error)
@@ -132,7 +157,7 @@ extension NIOServerWebTests {
       expectedData.append(gRPCEncodedEchoRequest("Swift echo expand (\(index)): \(component)"))
       index += 1
     }
-    expectedData.append(gRPCWebOKTrailers())
+    expectedData.append(gRPCWebTrailers())
     let expectedResponse = expectedData.base64EncodedString()
     let completionHandlerExpectation = expectation(description: "completion handler called")