Browse Source

Add a bidirectional streaming server handler. (#1098)

Motivation:

Following from #1093, #1095 and #1097; we need a bidirectional streaming
server handler.

Modifications:

- Add a bidirectional streaming server handler and tests.

Result:

We'll be able to do bidirectional streaming RPCs with new codegen.
George Barnett 5 years ago
parent
commit
7606b999c7

+ 366 - 0
Sources/GRPC/CallHandlers/BidirectionalStreamingServerHandler.swift

@@ -0,0 +1,366 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import NIO
+import NIOHPACK
+
+public final class BidirectionalStreamingServerHandler<
+  Serializer: MessageSerializer,
+  Deserializer: MessageDeserializer
+>: GRPCServerHandlerProtocol {
+  public typealias Request = Deserializer.Output
+  public typealias Response = Serializer.Input
+
+  /// A response serializer.
+  @usableFromInline
+  internal let serializer: Serializer
+
+  /// A request deserializer.
+  @usableFromInline
+  internal let deserializer: Deserializer
+
+  /// A pipeline of user provided interceptors.
+  @usableFromInline
+  internal var interceptors: ServerInterceptorPipeline<Request, Response>!
+
+  /// Stream events which have arrived before the stream observer future has been resolved.
+  @usableFromInline
+  internal var requestBuffer: CircularBuffer<StreamEvent<Request>> = CircularBuffer()
+
+  /// The context required in order create the function.
+  @usableFromInline
+  internal let context: CallHandlerContext
+
+  /// A reference to a `UserInfo`.
+  @usableFromInline
+  internal let userInfoRef: Ref<UserInfo>
+
+  /// The user provided function to execute.
+  @usableFromInline
+  internal let observerFactory: (_StreamingResponseCallContext<Request, Response>)
+    -> EventLoopFuture<(StreamEvent<Request>) -> Void>
+
+  /// The state of the handler.
+  @usableFromInline
+  internal var state: State = .idle
+
+  @usableFromInline
+  internal enum State {
+    // No headers have been received.
+    case idle
+    // Headers have been received, a context has been created and the user code has been called to
+    // make a stream observer with. The observer is yet to see any messages.
+    case creatingObserver(_StreamingResponseCallContext<Request, Response>)
+    // The observer future has resolved and the observer may have seen messages.
+    case observing((StreamEvent<Request>) -> Void, _StreamingResponseCallContext<Request, Response>)
+    // The observer has completed by completing the status promise.
+    case completed
+  }
+
+  @inlinable
+  public init(
+    context: CallHandlerContext,
+    requestDeserializer: Deserializer,
+    responseSerializer: Serializer,
+    interceptors: [ServerInterceptor<Request, Response>],
+    observerFactory: @escaping (StreamingResponseCallContext<Response>)
+      -> EventLoopFuture<(StreamEvent<Request>) -> Void>
+  ) {
+    self.serializer = responseSerializer
+    self.deserializer = requestDeserializer
+    self.context = context
+    self.observerFactory = observerFactory
+
+    let userInfoRef = Ref(UserInfo())
+    self.userInfoRef = userInfoRef
+    self.interceptors = ServerInterceptorPipeline(
+      logger: context.logger,
+      eventLoop: context.eventLoop,
+      path: context.path,
+      callType: .bidirectionalStreaming,
+      remoteAddress: context.remoteAddress,
+      userInfoRef: userInfoRef,
+      interceptors: interceptors,
+      onRequestPart: self.receiveInterceptedPart(_:),
+      onResponsePart: self.sendInterceptedPart(_:promise:)
+    )
+  }
+
+  // MARK: - Public API: gRPC to Handler
+
+  @inlinable
+  public func receiveMetadata(_ headers: HPACKHeaders) {
+    self.interceptors.receive(.metadata(headers))
+  }
+
+  @inlinable
+  public func receiveMessage(_ bytes: ByteBuffer) {
+    do {
+      let message = try self.deserializer.deserialize(byteBuffer: bytes)
+      self.interceptors.receive(.message(message))
+    } catch {
+      self.handleError(error)
+    }
+  }
+
+  @inlinable
+  public func receiveEnd() {
+    self.interceptors.receive(.end)
+  }
+
+  @inlinable
+  public func receiveError(_ error: Error) {
+    self._finish(error: error)
+  }
+
+  @inlinable
+  public func finish() {
+    self._finish(error: nil)
+  }
+
+  @inlinable
+  internal func _finish(error: Error?) {
+    switch self.state {
+    case .idle:
+      self.interceptors = nil
+      self.state = .completed
+
+    case let .creatingObserver(context),
+         let .observing(_, context):
+      let error = error ?? GRPCStatus(code: .unavailable, message: nil)
+      context.statusPromise.fail(error)
+
+    case .completed:
+      self.interceptors = nil
+    }
+  }
+
+  // MARK: - Interceptors to User Function
+
+  @inlinable
+  internal func receiveInterceptedPart(_ part: GRPCServerRequestPart<Request>) {
+    switch part {
+    case let .metadata(headers):
+      self.receiveInterceptedMetadata(headers)
+    case let .message(message):
+      self.receiveInterceptedMessage(message)
+    case .end:
+      self.receiveInterceptedEnd()
+    }
+  }
+
+  @inlinable
+  internal func receiveInterceptedMetadata(_ headers: HPACKHeaders) {
+    switch self.state {
+    case .idle:
+      // Make a context to invoke the observer block factory with.
+      let context = _StreamingResponseCallContext<Request, Response>(
+        eventLoop: self.context.eventLoop,
+        headers: headers,
+        logger: self.context.logger,
+        userInfoRef: self.userInfoRef,
+        sendResponse: self.interceptResponse(_:metadata:promise:)
+      )
+
+      // Move to the next state.
+      self.state = .creatingObserver(context)
+
+      // Send response headers back via the interceptors.
+      self.interceptors.send(.metadata([:]), promise: nil)
+
+      // Register callbacks on the status future.
+      context.statusPromise.futureResult.whenComplete(self.userFunctionStatusResolved(_:))
+
+      // Make an observer block and register a completion block.
+      self.observerFactory(context).whenComplete(self.userFunctionResolvedWithResult(_:))
+
+    case .creatingObserver, .observing:
+      self.handleError(GRPCError.InvalidState("Protocol violation: already received headers"))
+
+    case .completed:
+      // We may receive headers from the interceptor pipeline if we have already finished (i.e. due
+      // to an error or otherwise) and an interceptor doing some async work later emitting headers.
+      // Dropping them is fine.
+      ()
+    }
+  }
+
+  @inlinable
+  internal func receiveInterceptedMessage(_ request: Request) {
+    switch self.state {
+    case .idle:
+      self.handleError(
+        GRPCError.InvalidState("Protocol violation: message received before headers")
+      )
+    case .creatingObserver:
+      self.requestBuffer.append(.message(request))
+    case let .observing(observer, _):
+      observer(.message(request))
+    case .completed:
+      // We received a message but we're already done: this may happen if we terminate the RPC
+      // due to a channel error, for example.
+      ()
+    }
+  }
+
+  @inlinable
+  internal func receiveInterceptedEnd() {
+    switch self.state {
+    case .idle:
+      self.handleError(
+        GRPCError.InvalidState("Protocol violation: end of stream received before headers")
+      )
+    case .creatingObserver:
+      self.requestBuffer.append(.end)
+    case let .observing(observer, _):
+      observer(.end)
+    case .completed:
+      // We received a message but we're already done: this may happen if we terminate the RPC
+      // due to a channel error, for example.
+      ()
+    }
+  }
+
+  // MARK: - User Function To Interceptors
+
+  @inlinable
+  internal func userFunctionResolvedWithResult(
+    _ result: Result<(StreamEvent<Request>) -> Void, Error>
+  ) {
+    switch self.state {
+    case .idle, .observing:
+      // The observer block can't resolve if it hasn't been created ('idle') and it can't be
+      // resolved more than once ('observing').
+      preconditionFailure()
+
+    case let .creatingObserver(context):
+      switch result {
+      case let .success(observer):
+        // We have an observer block now; unbuffer any requests.
+        self.state = .observing(observer, context)
+        while let request = self.requestBuffer.popFirst() {
+          observer(request)
+        }
+
+      case let .failure(error):
+        self.handleError(error)
+      }
+
+    case .completed:
+      // We've already completed. That's fine.
+      ()
+    }
+  }
+
+  @inlinable
+  internal func interceptResponse(
+    _ response: Response,
+    metadata: MessageMetadata,
+    promise: EventLoopPromise<Void>?
+  ) {
+    switch self.state {
+    case .idle:
+      // The observer block can't end responses if it doesn't exist!
+      preconditionFailure()
+
+    case .creatingObserver, .observing:
+      // The user has access to the response context before returning a future observer,
+      // so 'creatingObserver' is valid here (if a little strange).
+      self.interceptors.send(.message(response, metadata), promise: promise)
+
+    case .completed:
+      promise?.fail(GRPCError.AlreadyComplete())
+    }
+  }
+
+  @inlinable
+  internal func userFunctionStatusResolved(_ result: Result<GRPCStatus, Error>) {
+    switch self.state {
+    case .idle:
+      // The promise can't fail before we create it.
+      preconditionFailure()
+
+    // Making is possible, the user can complete the status before returning a stream handler.
+    case let .creatingObserver(context), let .observing(_, context):
+      switch result {
+      case let .success(status):
+        self.interceptors.send(.end(status, context.trailers), promise: nil)
+
+      case let .failure(error):
+        let (status, trailers) = ServerErrorProcessor.processObserverError(
+          error,
+          headers: context.headers,
+          trailers: context.trailers,
+          delegate: self.context.errorDelegate
+        )
+        self.interceptors.send(.end(status, trailers), promise: nil)
+      }
+
+    case .completed:
+      ()
+    }
+  }
+
+  @inlinable
+  internal func handleError(_ error: Error) {
+    switch self.state {
+    case .idle:
+      // We don't have a promise to fail. Just send back end.
+      let (status, trailers) = ServerErrorProcessor.processLibraryError(
+        error,
+        delegate: self.context.errorDelegate
+      )
+      self.interceptors.send(.end(status, trailers), promise: nil)
+
+    case let .creatingObserver(context):
+      context.statusPromise.fail(error)
+
+    case let .observing(_, context):
+      context.statusPromise.fail(error)
+
+    case .completed:
+      ()
+    }
+  }
+
+  @inlinable
+  internal func sendInterceptedPart(
+    _ part: GRPCServerResponsePart<Response>,
+    promise: EventLoopPromise<Void>?
+  ) {
+    switch part {
+    case let .metadata(headers):
+      self.context.responseWriter.sendMetadata(headers, promise: promise)
+
+    case let .message(message, metadata):
+      do {
+        let bytes = try self.serializer.serialize(message, allocator: ByteBufferAllocator())
+        self.context.responseWriter.sendMessage(bytes, metadata: metadata, promise: promise)
+      } catch {
+        // Serialization failed: fail the promise and send end.
+        promise?.fail(error)
+        let (status, trailers) = ServerErrorProcessor.processLibraryError(
+          error,
+          delegate: self.context.errorDelegate
+        )
+        // Loop back via the interceptors.
+        self.interceptors.send(.end(status, trailers), promise: nil)
+      }
+
+    case let .end(status, trailers):
+      self.context.responseWriter.sendEnd(status: status, trailers: trailers, promise: promise)
+    }
+  }
+}

+ 227 - 0
Tests/GRPCTests/UnaryServerHandlerTests.swift

@@ -664,3 +664,230 @@ class ServerStreamingServerHandlerTests: GRPCTestCase, ServerHandlerTestCase {
     assertThat(self.recorder.trailers, .is([:]))
   }
 }
+
+// MARK: - Bidirectional Streaming
+
+class BidirectionalStreamingServerHandlerTests: GRPCTestCase, ServerHandlerTestCase {
+  let eventLoop = EmbeddedEventLoop()
+  let allocator = ByteBufferAllocator()
+  let recorder = ResponseRecorder()
+
+  private func makeHandler(
+    observerFactory: @escaping (StreamingResponseCallContext<String>)
+      -> EventLoopFuture<(StreamEvent<String>) -> Void>
+  ) -> BidirectionalStreamingServerHandler<StringSerializer, StringDeserializer> {
+    return BidirectionalStreamingServerHandler(
+      context: self.makeCallHandlerContext(),
+      requestDeserializer: StringDeserializer(),
+      responseSerializer: StringSerializer(),
+      interceptors: [],
+      observerFactory: observerFactory
+    )
+  }
+
+  private func echo(
+    context: StreamingResponseCallContext<String>
+  ) -> EventLoopFuture<(StreamEvent<String>) -> Void> {
+    func onEvent(_ event: StreamEvent<String>) {
+      switch event {
+      case let .message(message):
+        context.sendResponse(message, promise: nil)
+      case .end:
+        context.statusPromise.succeed(.ok)
+      }
+    }
+    return context.eventLoop.makeSucceededFuture(onEvent(_:))
+  }
+
+  private func neverReceivesMessage(
+    context: StreamingResponseCallContext<String>
+  ) -> EventLoopFuture<(StreamEvent<String>) -> Void> {
+    func onEvent(_ event: StreamEvent<String>) {
+      switch event {
+      case let .message(message):
+        XCTFail("Unexpected message: '\(message)'")
+      case .end:
+        context.statusPromise.succeed(.ok)
+      }
+    }
+    return context.eventLoop.makeSucceededFuture(onEvent(_:))
+  }
+
+  private func neverCalled(
+    context: StreamingResponseCallContext<String>
+  ) -> EventLoopFuture<(StreamEvent<String>) -> Void> {
+    XCTFail("This observer factory should never be called")
+    return context.eventLoop.makeFailedFuture(GRPCStatus(code: .aborted, message: nil))
+  }
+
+  func testHappyPath() {
+    let handler = self.makeHandler(observerFactory: self.echo(context:))
+
+    handler.receiveMetadata([:])
+    assertThat(self.recorder.metadata, .is([:]))
+
+    handler.receiveMessage(ByteBuffer(string: "1"))
+    handler.receiveMessage(ByteBuffer(string: "2"))
+    handler.receiveMessage(ByteBuffer(string: "3"))
+    handler.receiveEnd()
+    handler.finish()
+
+    assertThat(
+      self.recorder.messages,
+      .is([ByteBuffer(string: "1"), ByteBuffer(string: "2"), ByteBuffer(string: "3")])
+    )
+    assertThat(self.recorder.status, .notNil(.hasCode(.ok)))
+    assertThat(self.recorder.trailers, .is([:]))
+  }
+
+  func testThrowingDeserializer() {
+    let handler = BidirectionalStreamingServerHandler(
+      context: self.makeCallHandlerContext(),
+      requestDeserializer: ThrowingStringDeserializer(),
+      responseSerializer: StringSerializer(),
+      interceptors: [],
+      observerFactory: self.neverReceivesMessage(context:)
+    )
+
+    handler.receiveMetadata([:])
+    assertThat(self.recorder.metadata, .is([:]))
+
+    let buffer = ByteBuffer(string: "hello")
+    handler.receiveMessage(buffer)
+
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
+  }
+
+  func testThrowingSerializer() {
+    let handler = BidirectionalStreamingServerHandler(
+      context: self.makeCallHandlerContext(),
+      requestDeserializer: StringDeserializer(),
+      responseSerializer: ThrowingStringSerializer(),
+      interceptors: [],
+      observerFactory: self.echo(context:)
+    )
+
+    handler.receiveMetadata([:])
+    assertThat(self.recorder.metadata, .is([:]))
+
+    let buffer = ByteBuffer(string: "hello")
+    handler.receiveMessage(buffer)
+    handler.receiveEnd()
+
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
+  }
+
+  func testObserverFactoryReturnsFailedFuture() {
+    let handler = self.makeHandler { context in
+      context.eventLoop.makeFailedFuture(GRPCStatus(code: .unavailable, message: ":("))
+    }
+
+    handler.receiveMetadata([:])
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.unavailable)))
+    assertThat(self.recorder.status?.message, .is(":("))
+  }
+
+  func testDelayedObserverFactory() {
+    let promise = self.eventLoop.makePromise(of: Void.self)
+    let handler = self.makeHandler { context in
+      return promise.futureResult.flatMap {
+        self.echo(context: context)
+      }
+    }
+
+    handler.receiveMetadata([:])
+    // Queue up some messages.
+    handler.receiveMessage(ByteBuffer(string: "1"))
+    // Succeed the observer block.
+    promise.succeed(())
+    // A few more messages.
+    handler.receiveMessage(ByteBuffer(string: "2"))
+    handler.receiveEnd()
+
+    assertThat(
+      self.recorder.messages,
+      .is([ByteBuffer(string: "1"), ByteBuffer(string: "2")])
+    )
+    assertThat(self.recorder.status, .notNil(.hasCode(.ok)))
+  }
+
+  func testDelayedObserverFactoryAllMessagesBeforeSucceeding() {
+    let promise = self.eventLoop.makePromise(of: Void.self)
+    let handler = self.makeHandler { context in
+      return promise.futureResult.flatMap {
+        self.echo(context: context)
+      }
+    }
+
+    handler.receiveMetadata([:])
+    // Queue up some messages.
+    handler.receiveMessage(ByteBuffer(string: "1"))
+    handler.receiveMessage(ByteBuffer(string: "2"))
+    handler.receiveEnd()
+    // Succeed the observer block.
+    promise.succeed(())
+
+    assertThat(
+      self.recorder.messages,
+      .is([ByteBuffer(string: "1"), ByteBuffer(string: "2")])
+    )
+    assertThat(self.recorder.status, .notNil(.hasCode(.ok)))
+  }
+
+  func testReceiveMessageBeforeHeaders() {
+    let handler = self.makeHandler(observerFactory: self.neverCalled(context:))
+
+    handler.receiveMessage(ByteBuffer(string: "foo"))
+    assertThat(self.recorder.metadata, .is(.nil()))
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
+  }
+
+  func testReceiveMultipleHeaders() {
+    let handler = self.makeHandler(observerFactory: self.neverReceivesMessage(context:))
+
+    handler.receiveMetadata([:])
+    assertThat(self.recorder.metadata, .is([:]))
+
+    handler.receiveMetadata([:])
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
+  }
+
+  func testFinishBeforeStarting() {
+    let handler = self.makeHandler(observerFactory: self.neverCalled(context:))
+
+    handler.finish()
+    assertThat(self.recorder.metadata, .is(.nil()))
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .is(.nil()))
+    assertThat(self.recorder.trailers, .is(.nil()))
+  }
+
+  func testFinishAfterHeaders() {
+    let handler = self.makeHandler(observerFactory: self.echo(context:))
+    handler.receiveMetadata([:])
+    assertThat(self.recorder.metadata, .is([:]))
+
+    handler.finish()
+
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.unavailable)))
+    assertThat(self.recorder.trailers, .is([:]))
+  }
+
+  func testFinishAfterMessage() {
+    let handler = self.makeHandler(observerFactory: self.echo(context:))
+
+    handler.receiveMetadata([:])
+    handler.receiveMessage(ByteBuffer(string: "hello"))
+    handler.finish()
+
+    assertThat(self.recorder.messages.first, .is(ByteBuffer(string: "hello")))
+    assertThat(self.recorder.status, .notNil(.hasCode(.unavailable)))
+    assertThat(self.recorder.trailers, .is([:]))
+  }
+}