Browse Source

Add a server streaming server handler. (#1097)

Motivation:

Following from #1093 and #1095; we need a server streaming server handler.

Modifications:

- Add a server streaming server handler and tests.

Result:

We'll be able to do server streaming RPCs with new codegen.
George Barnett 4 years ago
parent
commit
d7a58b6d34

+ 317 - 0
Sources/GRPC/CallHandlers/ServerStreamingServerHandler.swift

@@ -0,0 +1,317 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import NIO
+import NIOHPACK
+
+public final class ServerStreamingServerHandler<
+  Serializer: MessageSerializer,
+  Deserializer: MessageDeserializer
+>: GRPCServerHandlerProtocol {
+  public typealias Request = Deserializer.Output
+  public typealias Response = Serializer.Input
+
+  /// A response serializer.
+  @usableFromInline
+  internal let serializer: Serializer
+
+  /// A request deserializer.
+  @usableFromInline
+  internal let deserializer: Deserializer
+
+  /// A pipeline of user provided interceptors.
+  @usableFromInline
+  internal var interceptors: ServerInterceptorPipeline<Request, Response>!
+
+  /// The context required in order create the function.
+  @usableFromInline
+  internal let context: CallHandlerContext
+
+  /// A reference to a `UserInfo`.
+  @usableFromInline
+  internal let userInfoRef: Ref<UserInfo>
+
+  /// The user provided function to execute.
+  @usableFromInline
+  internal let userFunction: (Request, StreamingResponseCallContext<Response>)
+    -> EventLoopFuture<GRPCStatus>
+
+  /// The state of the handler.
+  @usableFromInline
+  internal var state: State = .idle
+
+  @usableFromInline
+  internal enum State {
+    // Initial state. Nothing has happened yet.
+    case idle
+    // Headers have been received and now we're holding a context with which to invoke the user
+    // function when we receive a message.
+    case createdContext(_StreamingResponseCallContext<Request, Response>)
+    // The user function has been invoked, we're waiting for the status promise to be completed.
+    case invokedFunction(_StreamingResponseCallContext<Request, Response>)
+    // The function has completed or we are no longer proceeding with execution (because of an error
+    // or unexpected closure).
+    case completed
+  }
+
+  @inlinable
+  public init(
+    context: CallHandlerContext,
+    requestDeserializer: Deserializer,
+    responseSerializer: Serializer,
+    interceptors: [ServerInterceptor<Request, Response>],
+    userFunction: @escaping (Request, StreamingResponseCallContext<Response>)
+      -> EventLoopFuture<GRPCStatus>
+  ) {
+    self.serializer = responseSerializer
+    self.deserializer = requestDeserializer
+    self.context = context
+    self.userFunction = userFunction
+
+    let userInfoRef = Ref(UserInfo())
+    self.userInfoRef = userInfoRef
+    self.interceptors = ServerInterceptorPipeline(
+      logger: context.logger,
+      eventLoop: context.eventLoop,
+      path: context.path,
+      callType: .serverStreaming,
+      remoteAddress: context.remoteAddress,
+      userInfoRef: userInfoRef,
+      interceptors: interceptors,
+      onRequestPart: self.receiveInterceptedPart(_:),
+      onResponsePart: self.sendInterceptedPart(_:promise:)
+    )
+  }
+
+  // MARK: Public API; gRPC to Handler
+
+  @inlinable
+  public func receiveMetadata(_ headers: HPACKHeaders) {
+    self.interceptors.receive(.metadata(headers))
+  }
+
+  @inlinable
+  public func receiveMessage(_ bytes: ByteBuffer) {
+    do {
+      let message = try self.deserializer.deserialize(byteBuffer: bytes)
+      self.interceptors.receive(.message(message))
+    } catch {
+      self.handleError(error)
+    }
+  }
+
+  @inlinable
+  public func receiveEnd() {
+    self.interceptors.receive(.end)
+  }
+
+  @inlinable
+  public func receiveError(_ error: Error) {
+    self._finish(error: error)
+  }
+
+  @inlinable
+  public func finish() {
+    self._finish(error: nil)
+  }
+
+  @inlinable
+  internal func _finish(error: Error?) {
+    switch self.state {
+    case .idle:
+      self.interceptors = nil
+      self.state = .completed
+
+    case let .createdContext(context),
+         let .invokedFunction(context):
+      let error = error ?? GRPCStatus(code: .unavailable, message: nil)
+      context.statusPromise.fail(error)
+
+    case .completed:
+      self.interceptors = nil
+    }
+  }
+
+  // MARK: - Interceptors to User Function
+
+  @inlinable
+  internal func receiveInterceptedPart(_ part: GRPCServerRequestPart<Request>) {
+    switch part {
+    case let .metadata(headers):
+      self.receiveInterceptedMetadata(headers)
+    case let .message(message):
+      self.receiveInterceptedMessage(message)
+    case .end:
+      // Ignored.
+      ()
+    }
+  }
+
+  @inlinable
+  internal func receiveInterceptedMetadata(_ headers: HPACKHeaders) {
+    switch self.state {
+    case .idle:
+      // Make a context to invoke the observer block factory with.
+      let context = _StreamingResponseCallContext<Request, Response>(
+        eventLoop: self.context.eventLoop,
+        headers: headers,
+        logger: self.context.logger,
+        userInfoRef: self.userInfoRef,
+        sendResponse: self.interceptResponse(_:metadata:promise:)
+      )
+
+      // Move to the next state.
+      self.state = .createdContext(context)
+
+      // Register a callback on the status future.
+      context.statusPromise.futureResult.whenComplete(self.userFunctionCompletedWithResult(_:))
+
+      // Send response headers back via the interceptors.
+      self.interceptors.send(.metadata([:]), promise: nil)
+
+    case .createdContext, .invokedFunction:
+      self.handleError(GRPCError.InvalidState("Protocol violation: already received headers"))
+
+    case .completed:
+      // We may receive headers from the interceptor pipeline if we have already finished (i.e. due
+      // to an error or otherwise) and an interceptor doing some async work later emitting headers.
+      // Dropping them is fine.
+      ()
+    }
+  }
+
+  @inlinable
+  internal func receiveInterceptedMessage(_ request: Request) {
+    switch self.state {
+    case .idle:
+      self.handleError(
+        GRPCError.InvalidState("Protocol violation: message received before headers")
+      )
+    case let .createdContext(context):
+      self.state = .invokedFunction(context)
+      // Complete the status promise with the function outcome.
+      context.statusPromise.completeWith(self.userFunction(request, context))
+    case .invokedFunction:
+      self.handleError(GRPCError.InvalidState("Protocol violation: already received message"))
+    case .completed:
+      // We received a message but we're already done: this may happen if we terminate the RPC
+      // due to a channel error, for example.
+      ()
+    }
+  }
+
+  // MARK: - User Function To Interceptors
+
+  @inlinable
+  internal func interceptResponse(
+    _ response: Response,
+    metadata: MessageMetadata,
+    promise: EventLoopPromise<Void>?
+  ) {
+    switch self.state {
+    case .idle:
+      // The observer block can't send responses if it doesn't exist.
+      preconditionFailure()
+
+    case .createdContext, .invokedFunction:
+      // The user has access to the response context before returning a future observer,
+      // so 'createdContext' is valid here (if a little strange).
+      self.interceptors.send(.message(response, metadata), promise: promise)
+
+    case .completed:
+      promise?.fail(GRPCError.AlreadyComplete())
+    }
+  }
+
+  @inlinable
+  internal func userFunctionCompletedWithResult(_ result: Result<GRPCStatus, Error>) {
+    switch self.state {
+    case .idle:
+      // Invalid state: the user function can only completed if it was created.
+      preconditionFailure()
+
+    case let .createdContext(context),
+         let .invokedFunction(context):
+      self.state = .completed
+
+      switch result {
+      case let .success(status):
+        self.interceptors.send(.end(status, context.trailers), promise: nil)
+
+      case let .failure(error):
+        let (status, trailers) = ServerErrorProcessor.processObserverError(
+          error,
+          headers: context.headers,
+          trailers: context.trailers,
+          delegate: self.context.errorDelegate
+        )
+        self.interceptors.send(.end(status, trailers), promise: nil)
+      }
+
+    case .completed:
+      // We've already completed. Ignore this.
+      ()
+    }
+  }
+
+  @inlinable
+  internal func sendInterceptedPart(
+    _ part: GRPCServerResponsePart<Response>,
+    promise: EventLoopPromise<Void>?
+  ) {
+    switch part {
+    case let .metadata(headers):
+      self.context.responseWriter.sendMetadata(headers, promise: promise)
+
+    case let .message(message, metadata):
+      do {
+        let bytes = try self.serializer.serialize(message, allocator: ByteBufferAllocator())
+        self.context.responseWriter.sendMessage(bytes, metadata: metadata, promise: promise)
+      } catch {
+        // Serialization failed: fail the promise and send end.
+        promise?.fail(error)
+        let (status, trailers) = ServerErrorProcessor.processLibraryError(
+          error,
+          delegate: self.context.errorDelegate
+        )
+        // Loop back via the interceptors.
+        self.interceptors.send(.end(status, trailers), promise: nil)
+      }
+
+    case let .end(status, trailers):
+      self.context.responseWriter.sendEnd(status: status, trailers: trailers, promise: promise)
+    }
+  }
+
+  @inlinable
+  internal func handleError(_ error: Error) {
+    switch self.state {
+    case .idle:
+      // We don't have a promise to fail. Just send back end.
+      let (status, trailers) = ServerErrorProcessor.processLibraryError(
+        error,
+        delegate: self.context.errorDelegate
+      )
+      self.interceptors.send(.end(status, trailers), promise: nil)
+
+    case let .createdContext(context),
+         let .invokedFunction(context):
+      context.statusPromise.fail(error)
+
+    case .completed:
+      ()
+    }
+  }
+}

+ 9 - 5
Sources/GRPC/Compression/MessageEncoding.swift

@@ -16,15 +16,18 @@
 
 
 /// Whether compression should be enabled for the message.
 /// Whether compression should be enabled for the message.
 public struct Compression: Hashable {
 public struct Compression: Hashable {
-  private enum Wrapped: Hashable {
+  @usableFromInline
+  internal enum _Wrapped: Hashable {
     case enabled
     case enabled
     case disabled
     case disabled
     case deferToCallDefault
     case deferToCallDefault
   }
   }
 
 
-  private var wrapped: Wrapped
-  private init(_ wrapped: Wrapped) {
-    self.wrapped = wrapped
+  @usableFromInline
+  internal var _wrapped: _Wrapped
+
+  private init(_ wrapped: _Wrapped) {
+    self._wrapped = wrapped
   }
   }
 
 
   /// Enable compression. Note that this will be ignored if compression has not been enabled or is
   /// Enable compression. Note that this will be ignored if compression has not been enabled or is
@@ -40,8 +43,9 @@ public struct Compression: Hashable {
 }
 }
 
 
 extension Compression {
 extension Compression {
+  @inlinable
   internal func isEnabled(callDefault: Bool) -> Bool {
   internal func isEnabled(callDefault: Bool) -> Bool {
-    switch self.wrapped {
+    switch self._wrapped {
     case .enabled:
     case .enabled:
       return callDefault
       return callDefault
     case .disabled:
     case .disabled:

+ 9 - 2
Sources/GRPC/ServerCallContexts/StreamingResponseCallContext.swift

@@ -40,6 +40,7 @@ open class StreamingResponseCallContext<ResponsePayload>: ServerCallContextBase
     self.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: .init(userInfo))
     self.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: .init(userInfo))
   }
   }
 
 
+  @inlinable
   override internal init(
   override internal init(
     eventLoop: EventLoop,
     eventLoop: EventLoop,
     headers: HPACKHeaders,
     headers: HPACKHeaders,
@@ -119,10 +120,13 @@ open class StreamingResponseCallContext<ResponsePayload>: ServerCallContextBase
   }
   }
 }
 }
 
 
+@usableFromInline
 internal final class _StreamingResponseCallContext<Request, Response>:
 internal final class _StreamingResponseCallContext<Request, Response>:
   StreamingResponseCallContext<Response> {
   StreamingResponseCallContext<Response> {
-  private let _sendResponse: (Response, MessageMetadata, EventLoopPromise<Void>?) -> Void
+  @usableFromInline
+  internal let _sendResponse: (Response, MessageMetadata, EventLoopPromise<Void>?) -> Void
 
 
+  @inlinable
   internal init(
   internal init(
     eventLoop: EventLoop,
     eventLoop: EventLoop,
     headers: HPACKHeaders,
     headers: HPACKHeaders,
@@ -134,6 +138,7 @@ internal final class _StreamingResponseCallContext<Request, Response>:
     super.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: userInfoRef)
     super.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: userInfoRef)
   }
   }
 
 
+  @inlinable
   override func sendResponse(
   override func sendResponse(
     _ message: Response,
     _ message: Response,
     compression: Compression = .deferToCallDefault,
     compression: Compression = .deferToCallDefault,
@@ -149,6 +154,7 @@ internal final class _StreamingResponseCallContext<Request, Response>:
     }
     }
   }
   }
 
 
+  @inlinable
   override func sendResponses<Messages: Sequence>(
   override func sendResponses<Messages: Sequence>(
     _ messages: Messages,
     _ messages: Messages,
     compression: Compression = .deferToCallDefault,
     compression: Compression = .deferToCallDefault,
@@ -163,7 +169,8 @@ internal final class _StreamingResponseCallContext<Request, Response>:
     }
     }
   }
   }
 
 
-  private func _sendResponses<Messages: Sequence>(
+  @inlinable
+  internal func _sendResponses<Messages: Sequence>(
     _ messages: Messages,
     _ messages: Messages,
     compression: Compression,
     compression: Compression,
     promise: EventLoopPromise<Void>?
     promise: EventLoopPromise<Void>?

+ 188 - 0
Tests/GRPCTests/UnaryServerHandlerTests.swift

@@ -476,3 +476,191 @@ class ClientStreamingServerHandlerTests: GRPCTestCase, ServerHandlerTestCase {
     assertThat(self.recorder.trailers, .is([:]))
     assertThat(self.recorder.trailers, .is([:]))
   }
   }
 }
 }
+
+class ServerStreamingServerHandlerTests: GRPCTestCase, ServerHandlerTestCase {
+  let eventLoop = EmbeddedEventLoop()
+  let allocator = ByteBufferAllocator()
+  let recorder = ResponseRecorder()
+
+  private func makeHandler(
+    userFunction: @escaping (String, StreamingResponseCallContext<String>)
+      -> EventLoopFuture<GRPCStatus>
+  ) -> ServerStreamingServerHandler<StringSerializer, StringDeserializer> {
+    return ServerStreamingServerHandler(
+      context: self.makeCallHandlerContext(),
+      requestDeserializer: StringDeserializer(),
+      responseSerializer: StringSerializer(),
+      interceptors: [],
+      userFunction: userFunction
+    )
+  }
+
+  private func breakOnSpaces(
+    _ request: String,
+    context: StreamingResponseCallContext<String>
+  ) -> EventLoopFuture<GRPCStatus> {
+    let parts = request.components(separatedBy: " ")
+    context.sendResponses(parts, promise: nil)
+    return context.eventLoop.makeSucceededFuture(.ok)
+  }
+
+  private func neverCalled(
+    _ request: String,
+    context: StreamingResponseCallContext<String>
+  ) -> EventLoopFuture<GRPCStatus> {
+    XCTFail("Unexpected invocation")
+    return context.eventLoop.makeSucceededFuture(.processingError)
+  }
+
+  private func neverComplete(
+    _ request: String,
+    context: StreamingResponseCallContext<String>
+  ) -> EventLoopFuture<GRPCStatus> {
+    return context.eventLoop.scheduleTask(deadline: .distantFuture) {
+      return .processingError
+    }.futureResult
+  }
+
+  func testHappyPath() {
+    let handler = self.makeHandler(userFunction: self.breakOnSpaces(_:context:))
+
+    handler.receiveMetadata([:])
+    assertThat(self.recorder.metadata, .is([:]))
+
+    handler.receiveMessage(ByteBuffer(string: "a b"))
+    handler.receiveEnd()
+    handler.finish()
+
+    assertThat(
+      self.recorder.messages,
+      .is([ByteBuffer(string: "a"), ByteBuffer(string: "b")])
+    )
+    assertThat(self.recorder.status, .notNil(.hasCode(.ok)))
+    assertThat(self.recorder.trailers, .is([:]))
+  }
+
+  func testThrowingDeserializer() {
+    let handler = ServerStreamingServerHandler(
+      context: self.makeCallHandlerContext(),
+      requestDeserializer: ThrowingStringDeserializer(),
+      responseSerializer: StringSerializer(),
+      interceptors: [],
+      userFunction: self.neverCalled(_:context:)
+    )
+
+    handler.receiveMetadata([:])
+    assertThat(self.recorder.metadata, .is([:]))
+
+    let buffer = ByteBuffer(string: "hello")
+    handler.receiveMessage(buffer)
+
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
+  }
+
+  func testThrowingSerializer() {
+    let handler = ServerStreamingServerHandler(
+      context: self.makeCallHandlerContext(),
+      requestDeserializer: StringDeserializer(),
+      responseSerializer: ThrowingStringSerializer(),
+      interceptors: [],
+      userFunction: self.breakOnSpaces(_:context:)
+    )
+
+    handler.receiveMetadata([:])
+    assertThat(self.recorder.metadata, .is([:]))
+
+    let buffer = ByteBuffer(string: "1 2 3")
+    handler.receiveMessage(buffer)
+    handler.receiveEnd()
+
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
+  }
+
+  func testUserFunctionReturnsFailedFuture() {
+    let handler = self.makeHandler { _, context in
+      return context.eventLoop.makeFailedFuture(GRPCStatus(code: .unavailable, message: ":("))
+    }
+
+    handler.receiveMetadata([:])
+    assertThat(self.recorder.metadata, .is([:]))
+
+    let buffer = ByteBuffer(string: "hello")
+    handler.receiveMessage(buffer)
+
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.unavailable)))
+    assertThat(self.recorder.status?.message, .is(":("))
+  }
+
+  func testReceiveMessageBeforeHeaders() {
+    let handler = self.makeHandler(userFunction: self.neverCalled(_:context:))
+
+    handler.receiveMessage(ByteBuffer(string: "foo"))
+    assertThat(self.recorder.metadata, .is(.nil()))
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
+  }
+
+  func testReceiveMultipleHeaders() {
+    let handler = self.makeHandler(userFunction: self.neverCalled(_:context:))
+
+    handler.receiveMetadata([:])
+    assertThat(self.recorder.metadata, .is([:]))
+
+    handler.receiveMetadata([:])
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
+  }
+
+  func testReceiveMultipleMessages() {
+    let handler = self.makeHandler(userFunction: self.neverComplete(_:context:))
+
+    handler.receiveMetadata([:])
+    assertThat(self.recorder.metadata, .is([:]))
+
+    let buffer = ByteBuffer(string: "hello")
+    handler.receiveMessage(buffer)
+    handler.receiveEnd()
+    // Send another message before the function completes.
+    handler.receiveMessage(buffer)
+
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
+  }
+
+  func testFinishBeforeStarting() {
+    let handler = self.makeHandler(userFunction: self.neverCalled(_:context:))
+
+    handler.finish()
+    assertThat(self.recorder.metadata, .is(.nil()))
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .is(.nil()))
+    assertThat(self.recorder.trailers, .is(.nil()))
+  }
+
+  func testFinishAfterHeaders() {
+    let handler = self.makeHandler(userFunction: self.neverCalled(_:context:))
+    handler.receiveMetadata([:])
+    assertThat(self.recorder.metadata, .is([:]))
+
+    handler.finish()
+
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.unavailable)))
+    assertThat(self.recorder.trailers, .is([:]))
+  }
+
+  func testFinishAfterMessage() {
+    let handler = self.makeHandler(userFunction: self.neverComplete(_:context:))
+
+    handler.receiveMetadata([:])
+    handler.receiveMessage(ByteBuffer(string: "hello"))
+    handler.finish()
+
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.unavailable)))
+    assertThat(self.recorder.trailers, .is([:]))
+  }
+}