Browse Source

Fix crashes due to mismatching responses sent to the channel when event observer factories fail. (#395)

* Fix crashes due to mismatching responses sent to the channel when event observer factories fail.

* Tweak `newFailedFuture`.

* PR fixes.

* Minor comment improvements.

* PR fixes.
Daniel Alm 6 years ago
parent
commit
772b78ebce

+ 6 - 0
Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift

@@ -16,6 +16,12 @@ public class BaseCallHandler<RequestMessage: Message, ResponseMessage: Message>:
     fatalError("needs to be overridden")
   }
 
+  /// Needs to be implemented by this class so that subclasses can override it.
+  ///
+  /// Otherwise, the subclass's implementation will simply never be called (probably because the protocol's default
+  /// implementation in an extension is being used instead).
+  public func handlerAdded(ctx: ChannelHandlerContext) { }
+  
   /// Called when the client has half-closed the stream, indicating that they won't send any further data.
   ///
   /// Overridden by subclasses if the "end-of-stream" event is relevant.

+ 14 - 3
Sources/SwiftGRPCNIO/CallHandlers/BidirectionalStreamingCallHandler.swift

@@ -6,7 +6,9 @@ import NIOHTTP1
 /// Handles bidirectional streaming calls. Forwards incoming messages and end-of-stream events to the observer block.
 ///
 /// - The observer block is implemented by the framework user and calls `context.sendResponse` as needed.
-/// - To close the call and send the status, fulfill `context.statusPromise`.
+///   If the framework user wants to return a call error (e.g. in case of authentication failure),
+///   they can fail the observer block future.
+/// - To close the call and send the status, complete `context.statusPromise`.
 public class BidirectionalStreamingCallHandler<RequestMessage: Message, ResponseMessage: Message>: BaseCallHandler<RequestMessage, ResponseMessage> {
   public typealias EventObserver = (StreamEvent<RequestMessage>) -> Void
   private var eventObserver: EventLoopFuture<EventObserver>?
@@ -21,14 +23,23 @@ public class BidirectionalStreamingCallHandler<RequestMessage: Message, Response
     self.context = context
     let eventObserver = eventObserverFactory(context)
     self.eventObserver = eventObserver
-    // Terminate the call if no observer is provided.
-    eventObserver.cascadeFailure(promise: context.statusPromise)
     context.statusPromise.futureResult.whenComplete {
       // When done, reset references to avoid retain cycles.
       self.eventObserver = nil
       self.context = nil
     }
   }
+  
+  public override func handlerAdded(ctx: ChannelHandlerContext) {
+    guard let eventObserver = eventObserver,
+      let context = context else { return }
+    // Terminate the call if the future providing an observer fails.
+    // This is being done _after_ we have been added as a handler to ensure that the `GRPCServerCodec` required to
+    // translate our outgoing `GRPCServerResponsePart<ResponseMessage>` message is already present on the channel.
+    // Otherwise, our `OutboundOut` type would not match the `OutboundIn` type of the next handler on the channel.
+    eventObserver.cascadeFailure(promise: context.statusPromise)
+  }
+  
 
   public override func processMessage(_ message: RequestMessage) {
     eventObserver?.whenSuccess { observer in

+ 13 - 2
Sources/SwiftGRPCNIO/CallHandlers/ClientStreamingCallHandler.swift

@@ -6,6 +6,9 @@ import NIOHTTP1
 /// Handles client-streaming calls. Forwards incoming messages and end-of-stream events to the observer block.
 ///
 /// - The observer block is implemented by the framework user and fulfills `context.responsePromise` when done.
+///   If the framework user wants to return a call error (e.g. in case of authentication failure),
+///   they can fail the observer block future.
+/// - To close the call and send the response, complete `context.responsePromise`.
 public class ClientStreamingCallHandler<RequestMessage: Message, ResponseMessage: Message>: BaseCallHandler<RequestMessage, ResponseMessage> {
   public typealias EventObserver = (StreamEvent<RequestMessage>) -> Void
   private var eventObserver: EventLoopFuture<EventObserver>?
@@ -20,8 +23,6 @@ public class ClientStreamingCallHandler<RequestMessage: Message, ResponseMessage
     self.context = context
     let eventObserver = eventObserverFactory(context)
     self.eventObserver = eventObserver
-    // Terminate the call if no observer is provided.
-    eventObserver.cascadeFailure(promise: context.responsePromise)
     context.responsePromise.futureResult.whenComplete {
       // When done, reset references to avoid retain cycles.
       self.eventObserver = nil
@@ -29,6 +30,16 @@ public class ClientStreamingCallHandler<RequestMessage: Message, ResponseMessage
     }
   }
   
+  public override func handlerAdded(ctx: ChannelHandlerContext) {
+    guard let eventObserver = eventObserver,
+      let context = context else { return }
+    // Terminate the call if the future providing an observer fails.
+    // This is being done _after_ we have been added as a handler to ensure that the `GRPCServerCodec` required to
+    // translate our outgoing `GRPCServerResponsePart<ResponseMessage>` message is already present on the channel.
+    // Otherwise, our `OutboundOut` type would not match the `OutboundIn` type of the next handler on the channel.
+    eventObserver.cascadeFailure(promise: context.responsePromise)
+  }
+  
   public override func processMessage(_ message: RequestMessage) {
     eventObserver?.whenSuccess { observer in
       observer(.message(message))

+ 1 - 1
Sources/SwiftGRPCNIO/CallHandlers/UnaryCallHandler.swift

@@ -7,7 +7,7 @@ import NIOHTTP1
 ///
 /// - The observer block is implemented by the framework user and returns a future containing the call result.
 /// - To return a response to the client, the framework user should complete that future
-/// (similar to e.g. serving regular HTTP requests in frameworks such as Vapor).
+///   (similar to e.g. serving regular HTTP requests in frameworks such as Vapor).
 public class UnaryCallHandler<RequestMessage: Message, ResponseMessage: Message>: BaseCallHandler<RequestMessage, ResponseMessage> {
   public typealias EventObserver = (RequestMessage) -> EventLoopFuture<ResponseMessage>
   private var eventObserver: EventObserver?

+ 4 - 1
Tests/LinuxMain.swift

@@ -34,11 +34,14 @@ XCTMain([
   testCase(MetadataTests.allTests),
   testCase(ServerCancellingTests.allTests),
   testCase(ServerTestExample.allTests),
-  testCase(ServerThrowingTests.allTests),
+  testCase(SwiftGRPCTests.ServerThrowingTests.allTests),
   testCase(ServerTimeoutTests.allTests),
 
   // SwiftGRPCNIO
   testCase(NIOServerTests.allTests),
+  testCase(SwiftGRPCNIOTests.ServerThrowingTests.allTests),
+  testCase(SwiftGRPCNIOTests.ServerDelayedThrowingTests.allTests),
+  testCase(SwiftGRPCNIOTests.ClientThrowingWhenServerReturningErrorTests.allTests),
   testCase(NIOClientCancellingTests.allTests),
   testCase(NIOClientTimeoutTests.allTests),
   testCase(NIOServerWebTests.allTests),

+ 3 - 1
Tests/SwiftGRPCNIOTests/NIOBasicEchoTestCase.swift

@@ -39,13 +39,15 @@ class NIOBasicEchoTestCase: XCTestCase {
 
   var clientEventLoopGroup: EventLoopGroup!
   var client: Echo_EchoService_NIOClient!
+  
+  func makeEchoProvider() -> Echo_EchoProvider_NIO { return EchoProviderNIO() }
 
   override func setUp() {
     super.setUp()
 
     self.serverEventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
     self.server = try! GRPCServer.start(
-      hostname: "localhost", port: 5050, eventLoopGroup: self.serverEventLoopGroup, serviceProviders: [EchoProviderNIO()])
+      hostname: "localhost", port: 5050, eventLoopGroup: self.serverEventLoopGroup, serviceProviders: [makeEchoProvider()])
       .wait()
 
     self.clientEventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)

+ 2 - 0
Tests/SwiftGRPCNIOTests/NIOServerTests.swift

@@ -26,6 +26,8 @@ class NIOServerTests: NIOBasicEchoTestCase {
     return [
       ("testUnary", testUnary),
       ("testUnaryLotsOfRequests", testUnaryLotsOfRequests),
+      ("testUnaryWithLargeData", testUnaryWithLargeData),
+      ("testUnaryEmptyRequest", testUnaryEmptyRequest),
       ("testClientStreaming", testClientStreaming),
       ("testClientStreamingLotsOfMessages", testClientStreamingLotsOfMessages),
       ("testServerStreaming", testServerStreaming),

+ 152 - 0
Tests/SwiftGRPCNIOTests/ServerThrowingTests.swift

@@ -0,0 +1,152 @@
+/*
+ * Copyright 2018, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import Dispatch
+import Foundation
+import NIO
+import NIOHTTP1
+import NIOHTTP2
+@testable import SwiftGRPCNIO
+import XCTest
+
+private let expectedError = GRPCStatus(code: .internalError, message: "expected error")
+
+// Motivation for two different providers: Throwing immediately causes the event observer future (in the
+// client-streaming and bidi-streaming cases) to throw immediately, _before_ the corresponding handler has even added
+// to the channel. We want to test that case as well as the one where we throw only _after_ the handler has been added
+// to the channel.
+private class ImmediateThrowingEchoProviderNIO: Echo_EchoProvider_NIO {
+  func get(request: Echo_EchoRequest, context: StatusOnlyCallContext) -> EventLoopFuture<Echo_EchoResponse> {
+    return context.eventLoop.newFailedFuture(error: expectedError)
+  }
+  
+  func expand(request: Echo_EchoRequest, context: StreamingResponseCallContext<Echo_EchoResponse>) -> EventLoopFuture<GRPCStatus> {
+    return context.eventLoop.newFailedFuture(error: expectedError)
+  }
+  
+  func collect(context: UnaryResponseCallContext<Echo_EchoResponse>) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
+    return context.eventLoop.newFailedFuture(error: expectedError)
+  }
+  
+  func update(context: StreamingResponseCallContext<Echo_EchoResponse>) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
+    return context.eventLoop.newFailedFuture(error: expectedError)
+  }
+}
+
+private extension EventLoop {
+  func newFailedFuture<T>(error: Error, delay: TimeInterval) -> EventLoopFuture<T> {
+    return self.scheduleTask(in: .nanoseconds(TimeAmount.Value(delay * 1000 * 1000 * 1000))) { () }.futureResult
+      .thenThrowing { _ -> T in throw error }
+  }
+}
+
+/// See `ImmediateThrowingEchoProviderNIO`.
+private class DelayedThrowingEchoProviderNIO: Echo_EchoProvider_NIO {
+  func get(request: Echo_EchoRequest, context: StatusOnlyCallContext) -> EventLoopFuture<Echo_EchoResponse> {
+    return context.eventLoop.newFailedFuture(error: expectedError, delay: 0.01)
+  }
+  
+  func expand(request: Echo_EchoRequest, context: StreamingResponseCallContext<Echo_EchoResponse>) -> EventLoopFuture<GRPCStatus> {
+    return context.eventLoop.newFailedFuture(error: expectedError, delay: 0.01)
+  }
+  
+  func collect(context: UnaryResponseCallContext<Echo_EchoResponse>) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
+    return context.eventLoop.newFailedFuture(error: expectedError, delay: 0.01)
+  }
+  
+  func update(context: StreamingResponseCallContext<Echo_EchoResponse>) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
+    return context.eventLoop.newFailedFuture(error: expectedError, delay: 0.01)
+  }
+}
+
+/// Ensures that fulfilling the status promise (where possible) with an error yields the same result as failing the future.
+private class ErrorReturningEchoProviderNIO: ImmediateThrowingEchoProviderNIO {
+  // There's no status promise to fulfill for unary calls (only the response promise), so that case is omitted.
+  
+  override func expand(request: Echo_EchoRequest, context: StreamingResponseCallContext<Echo_EchoResponse>) -> EventLoopFuture<GRPCStatus> {
+    return context.eventLoop.newSucceededFuture(result: expectedError)
+  }
+  
+  override func collect(context: UnaryResponseCallContext<Echo_EchoResponse>) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
+    return context.eventLoop.newSucceededFuture(result: { _ in
+      context.responseStatus = expectedError
+      context.responsePromise.succeed(result: Echo_EchoResponse())
+    })
+  }
+  
+  override func update(context: StreamingResponseCallContext<Echo_EchoResponse>) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
+    return context.eventLoop.newSucceededFuture(result: { _ in
+      context.statusPromise.succeed(result: expectedError)
+    })
+  }
+}
+
+class ServerThrowingTests: NIOBasicEchoTestCase {
+  override func makeEchoProvider() -> Echo_EchoProvider_NIO { return ImmediateThrowingEchoProviderNIO() }
+  
+  static var allTests: [(String, (ServerThrowingTests) -> () throws -> Void)] {
+    return [
+      ("testUnary", testUnary),
+      ("testClientStreaming", testClientStreaming),
+      ("testServerStreaming", testServerStreaming),
+      ("testBidirectionalStreaming", testBidirectionalStreaming),
+    ]
+  }
+}
+
+class ServerDelayedThrowingTests: ServerThrowingTests {
+  override func makeEchoProvider() -> Echo_EchoProvider_NIO { return DelayedThrowingEchoProviderNIO() }
+}
+
+class ClientThrowingWhenServerReturningErrorTests: ServerThrowingTests {
+  override func makeEchoProvider() -> Echo_EchoProvider_NIO { return ErrorReturningEchoProviderNIO() }
+}
+
+extension ServerThrowingTests {
+  func testUnary() throws {
+    let call = client.get(Echo_EchoRequest(text: "foo"))
+    XCTAssertEqual(expectedError, try call.status.wait())
+    XCTAssertThrowsError(try call.response.wait()) {
+      XCTAssertEqual(expectedError, $0 as? GRPCStatus)
+    }
+  }
+  
+  func testClientStreaming() {
+    let call = client.collect()
+    XCTAssertNoThrow(try call.sendEnd().wait())
+    XCTAssertEqual(expectedError, try call.status.wait())
+    
+    if type(of: makeEchoProvider()) != ErrorReturningEchoProviderNIO.self {
+      // With `ErrorReturningEchoProviderNIO` we actually _return_ a response, which means that the `response` future
+      // will _not_ fail, so in that case this test doesn't apply.
+      XCTAssertThrowsError(try call.response.wait()) {
+        XCTAssertEqual(expectedError, $0 as? GRPCStatus)
+      }
+    }
+  }
+  
+  func testServerStreaming() {
+    let call = client.expand(Echo_EchoRequest(text: "foo")) { XCTFail("no message expected, got \($0)") }
+    // Nothing to throw here, but the `status` should be the expected error.
+    XCTAssertEqual(expectedError, try call.status.wait())
+  }
+  
+  func testBidirectionalStreaming() {
+    let call = client.update() { XCTFail("no message expected, got \($0)") }
+    XCTAssertNoThrow(try call.sendEnd().wait())
+    // Nothing to throw here, but the `status` should be the expected error.
+    XCTAssertEqual(expectedError, try call.status.wait())
+  }
+}