2
0
Эх сурвалжийг харах

Delay creating event observers for client streaming calls (#523)

* Delay creating event observers for client streaming calls

Motivation:

Providers for client and bidirectional streaming calls require
the user provide a future stream-event handler to handle requests
from the client. However, these methods get called as the pipeline
handling an incoming call is being configured, as these methods
also expose promises for response (for client streaming) and
call status (for bidirectional streaming) it is possible for these
to be fulfilled before the pipeline has been configured. Since
no handler is in place to deal with the promised types the server
will fatal error as the first handler in place will fail to unwrap
the promised type.

Modifications:

Delay the creation of event observers for client and
bidirectional streaming calls until their handlers have been added
to the pipeline.

Result:

Client streaming and bidirectional streaming calls can fulfill
their response and status promises outside of their stream handlers.

* Track the event observer and factory with state
George Barnett 6 жил өмнө
parent
commit
d4bcd92215

+ 26 - 11
Sources/GRPC/CallHandlers/BidirectionalStreamingCallHandler.swift

@@ -25,29 +25,42 @@ import NIOHTTP1
 ///   they can fail the observer block future.
 /// - To close the call and send the status, complete `context.statusPromise`.
 public class BidirectionalStreamingCallHandler<RequestMessage: Message, ResponseMessage: Message>: BaseCallHandler<RequestMessage, ResponseMessage> {
+  public typealias Context = StreamingResponseCallContext<ResponseMessage>
   public typealias EventObserver = (StreamEvent<RequestMessage>) -> Void
-  private var eventObserver: EventLoopFuture<EventObserver>?
+  public typealias EventObserverFactory = (Context) -> EventLoopFuture<EventObserver>
 
-  private var callContext: StreamingResponseCallContext<ResponseMessage>?
+  private var observerState: ClientStreamingHandlerObserverState<EventObserverFactory, EventObserver>
+  private var callContext: Context?
 
   // We ask for a future of type `EventObserver` to allow the framework user to e.g. asynchronously authenticate a call.
   // If authentication fails, they can simply fail the observer future, which causes the call to be terminated.
-  public init(channel: Channel, request: HTTPRequestHead, errorDelegate: ServerErrorDelegate?, eventObserverFactory: (StreamingResponseCallContext<ResponseMessage>) -> EventLoopFuture<EventObserver>) {
-    super.init(errorDelegate: errorDelegate)
+  public init(channel: Channel, request: HTTPRequestHead, errorDelegate: ServerErrorDelegate?, eventObserverFactory: @escaping (StreamingResponseCallContext<ResponseMessage>) -> EventLoopFuture<EventObserver>) {
+    // Delay the creation of the event observer until `handlerAdded(context:)`, otherwise it is
+    // possible for the service to write into the pipeline (by fulfilling the status promise
+    // of the call context outside of the observer) before it has been configured.
+    self.observerState = .pendingCreation(eventObserverFactory)
+
     let context = StreamingResponseCallContextImpl<ResponseMessage>(channel: channel, request: request, errorDelegate: errorDelegate)
     self.callContext = context
-    let eventObserver = eventObserverFactory(context)
-    self.eventObserver = eventObserver
+
+    super.init(errorDelegate: errorDelegate)
+
     context.statusPromise.futureResult.whenComplete { _ in
       // When done, reset references to avoid retain cycles.
-      self.eventObserver = nil
       self.callContext = nil
+      self.observerState = .notRequired
     }
   }
 
   public override func handlerAdded(context: ChannelHandlerContext) {
-    guard let eventObserver = self.eventObserver,
-      let callContext = self.callContext else { return }
+    guard let callContext = self.callContext,
+      case let .pendingCreation(factory) = self.observerState else {
+      return
+    }
+
+    let eventObserver = factory(callContext)
+    self.observerState = .created(eventObserver)
+
     // Terminate the call if the future providing an observer fails.
     // This is being done _after_ we have been added as a handler to ensure that the `GRPCServerCodec` required to
     // translate our outgoing `GRPCServerResponsePart<ResponseMessage>` message is already present on the channel.
@@ -57,13 +70,15 @@ public class BidirectionalStreamingCallHandler<RequestMessage: Message, Response
 
 
   public override func processMessage(_ message: RequestMessage) {
-    self.eventObserver?.whenSuccess { observer in
+    guard case .created(let eventObserver) = self.observerState else { return }
+    eventObserver.whenSuccess { observer in
       observer(.message(message))
     }
   }
 
   public override func endOfStreamReceived() throws {
-    self.eventObserver?.whenSuccess { observer in
+    guard case .created(let eventObserver) = self.observerState else { return }
+    eventObserver.whenSuccess { observer in
       observer(.end)
     }
   }

+ 33 - 10
Sources/GRPC/CallHandlers/ClientStreamingCallHandler.swift

@@ -18,6 +18,14 @@ import SwiftProtobuf
 import NIO
 import NIOHTTP1
 
+/// For calls which support client streaming we need to delay the creation of the event observer
+/// until the handler has been added to the pipeline.
+enum ClientStreamingHandlerObserverState<Factory, Observer> {
+  case pendingCreation(Factory)
+  case created(EventLoopFuture<Observer>)
+  case notRequired
+}
+
 /// Handles client-streaming calls. Forwards incoming messages and end-of-stream events to the observer block.
 ///
 /// - The observer block is implemented by the framework user and fulfills `context.responsePromise` when done.
@@ -25,29 +33,42 @@ import NIOHTTP1
 ///   they can fail the observer block future.
 /// - To close the call and send the response, complete `context.responsePromise`.
 public class ClientStreamingCallHandler<RequestMessage: Message, ResponseMessage: Message>: BaseCallHandler<RequestMessage, ResponseMessage> {
+  public typealias Context = UnaryResponseCallContext<ResponseMessage>
   public typealias EventObserver = (StreamEvent<RequestMessage>) -> Void
-  private var eventObserver: EventLoopFuture<EventObserver>?
+  public typealias EventObserverFactory = (Context) -> EventLoopFuture<EventObserver>
 
+  private var observerState: ClientStreamingHandlerObserverState<EventObserverFactory, EventObserver>
   private var callContext: UnaryResponseCallContext<ResponseMessage>?
 
   // We ask for a future of type `EventObserver` to allow the framework user to e.g. asynchronously authenticate a call.
   // If authentication fails, they can simply fail the observer future, which causes the call to be terminated.
-  public init(channel: Channel, request: HTTPRequestHead, errorDelegate: ServerErrorDelegate?, eventObserverFactory: (UnaryResponseCallContext<ResponseMessage>) -> EventLoopFuture<EventObserver>) {
-    super.init(errorDelegate: errorDelegate)
+  public init(channel: Channel, request: HTTPRequestHead, errorDelegate: ServerErrorDelegate?, eventObserverFactory: @escaping EventObserverFactory) {
+    // Delay the creation of the event observer until `handlerAdded(context:)`, otherwise it is
+    // possible for the service to write into the pipeline (by fulfilling the response promise
+    // of the call context outside of the observer) before it has been configured.
+    self.observerState = .pendingCreation(eventObserverFactory)
+
     let callContext = UnaryResponseCallContextImpl<ResponseMessage>(channel: channel, request: request, errorDelegate: errorDelegate)
     self.callContext = callContext
-    let eventObserver = eventObserverFactory(callContext)
-    self.eventObserver = eventObserver
+
+    super.init(errorDelegate: errorDelegate)
+
     callContext.responsePromise.futureResult.whenComplete { _ in
       // When done, reset references to avoid retain cycles.
-      self.eventObserver = nil
       self.callContext = nil
+      self.observerState = .notRequired
     }
   }
 
   public override func handlerAdded(context: ChannelHandlerContext) {
-    guard let eventObserver = self.eventObserver,
-      let callContext = self.callContext else { return }
+    guard let callContext = self.callContext,
+      case let .pendingCreation(factory) = self.observerState else {
+      return
+    }
+
+    let eventObserver = factory(callContext)
+    self.observerState = .created(eventObserver)
+
     // Terminate the call if the future providing an observer fails.
     // This is being done _after_ we have been added as a handler to ensure that the `GRPCServerCodec` required to
     // translate our outgoing `GRPCServerResponsePart<ResponseMessage>` message is already present on the channel.
@@ -56,13 +77,15 @@ public class ClientStreamingCallHandler<RequestMessage: Message, ResponseMessage
   }
 
   public override func processMessage(_ message: RequestMessage) {
-    self.eventObserver?.whenSuccess { observer in
+    guard case .created(let eventObserver) = self.observerState else { return }
+    eventObserver.whenSuccess { observer in
       observer(.message(message))
     }
   }
 
   public override func endOfStreamReceived() throws {
-    self.eventObserver?.whenSuccess { observer in
+    guard case .created(let eventObserver) = self.observerState else { return }
+    eventObserver.whenSuccess { observer in
       observer(.end)
     }
   }

+ 97 - 0
Tests/GRPCTests/ImmediateServerFailureTests.swift

@@ -0,0 +1,97 @@
+/*
+ * Copyright 2019, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import Foundation
+import GRPC
+import NIO
+import XCTest
+
+class ImmediatelyFailingEchoProvider: Echo_EchoProvider {
+  static let status: GRPCStatus = .init(code: .unavailable, message: nil)
+
+  func get(
+    request: Echo_EchoRequest,
+    context: StatusOnlyCallContext
+  ) -> EventLoopFuture<Echo_EchoResponse> {
+    return context.eventLoop.makeFailedFuture(ImmediatelyFailingEchoProvider.status)
+  }
+
+  func expand(
+    request: Echo_EchoRequest,
+    context: StreamingResponseCallContext<Echo_EchoResponse>
+  ) -> EventLoopFuture<GRPCStatus> {
+    return context.eventLoop.makeFailedFuture(ImmediatelyFailingEchoProvider.status)
+  }
+
+  func collect(
+    context: UnaryResponseCallContext<Echo_EchoResponse>
+  ) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
+    context.responsePromise.fail(ImmediatelyFailingEchoProvider.status)
+    return context.eventLoop.makeSucceededFuture({ _ in
+      // no-op
+    })
+  }
+
+  func update(
+    context: StreamingResponseCallContext<Echo_EchoResponse>
+  ) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
+    context.statusPromise.fail(ImmediatelyFailingEchoProvider.status)
+    return context.eventLoop.makeSucceededFuture({ _ in
+      // no-op
+    })
+  }
+}
+
+class ImmediatelyFailingProviderTests: EchoTestCaseBase {
+  override func makeEchoProvider() -> Echo_EchoProvider {
+    return ImmediatelyFailingEchoProvider()
+  }
+
+  func testUnary() throws {
+    let expcectation = self.makeStatusExpectation()
+    let call = self.client.get(Echo_EchoRequest(text: "foo"))
+    call.status.map { $0.code }.assertEqual(.unavailable, fulfill: expcectation)
+
+    self.wait(for: [expcectation], timeout: self.defaultTestTimeout)
+  }
+
+  func testServerStreaming() throws {
+    let expcectation = self.makeStatusExpectation()
+    let call = self.client.expand(Echo_EchoRequest(text: "foo")) { response in
+      XCTFail("unexpected response: \(response)")
+    }
+
+    call.status.map { $0.code }.assertEqual(.unavailable, fulfill: expcectation)
+    self.wait(for: [expcectation], timeout: self.defaultTestTimeout)
+  }
+
+  func testClientStreaming() throws {
+    let expcectation = self.makeStatusExpectation()
+    let call = self.client.collect()
+
+    call.status.map { $0.code }.assertEqual(.unavailable, fulfill: expcectation)
+    self.wait(for: [expcectation], timeout: self.defaultTestTimeout)
+  }
+
+  func testBidirectionalStreaming() throws {
+    let expcectation = self.makeStatusExpectation()
+    let call = self.client.update { response in
+      XCTFail("unexpected response: \(response)")
+    }
+
+    call.status.map { $0.code }.assertEqual(.unavailable, fulfill: expcectation)
+    self.wait(for: [expcectation], timeout: self.defaultTestTimeout)
+  }
+}

+ 13 - 0
Tests/GRPCTests/XCTestManifests.swift

@@ -342,6 +342,18 @@ extension HTTP1ToRawGRPCServerCodecTests {
     ]
 }
 
+extension ImmediatelyFailingProviderTests {
+    // DO NOT MODIFY: This is autogenerated, use:
+    //   `swift test --generate-linuxmain`
+    // to regenerate.
+    static let __allTests__ImmediatelyFailingProviderTests = [
+        ("testBidirectionalStreaming", testBidirectionalStreaming),
+        ("testClientStreaming", testClientStreaming),
+        ("testServerStreaming", testServerStreaming),
+        ("testUnary", testUnary),
+    ]
+}
+
 extension LengthPrefixedMessageReaderTests {
     // DO NOT MODIFY: This is autogenerated, use:
     //   `swift test --generate-linuxmain`
@@ -435,6 +447,7 @@ public func __allTests() -> [XCTestCaseEntry] {
         testCase(GRPCStatusMessageMarshallerTests.__allTests__GRPCStatusMessageMarshallerTests),
         testCase(GRPCTypeSizeTests.__allTests__GRPCTypeSizeTests),
         testCase(HTTP1ToRawGRPCServerCodecTests.__allTests__HTTP1ToRawGRPCServerCodecTests),
+        testCase(ImmediatelyFailingProviderTests.__allTests__ImmediatelyFailingProviderTests),
         testCase(LengthPrefixedMessageReaderTests.__allTests__LengthPrefixedMessageReaderTests),
         testCase(ServerDelayedThrowingTests.__allTests__ServerDelayedThrowingTests),
         testCase(ServerErrorTransformingTests.__allTests__ServerErrorTransformingTests),