Răsfoiți Sursa

Add a client streaming server handler (#1095)

Motivation:

Following from #1093; we need a client streaming server handler.

Modifications:

- Add a client streaming server handler and tests.

Result:

We'll be able to do client streaming RPCs with new codegen.
George Barnett 4 ani în urmă
părinte
comite
2e511b0f4a

+ 345 - 0
Sources/GRPC/CallHandlers/ClientStreamingServerHandler.swift

@@ -0,0 +1,345 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import NIO
+import NIOHPACK
+
+public final class ClientStreamingServerHandler<
+  Serializer: MessageSerializer,
+  Deserializer: MessageDeserializer
+>: GRPCServerHandlerProtocol {
+  public typealias Request = Deserializer.Output
+  public typealias Response = Serializer.Input
+
+  /// A response serializer.
+  @usableFromInline
+  internal let serializer: Serializer
+
+  /// A request deserializer.
+  @usableFromInline
+  internal let deserializer: Deserializer
+
+  /// A pipeline of user provided interceptors.
+  @usableFromInline
+  internal var interceptors: ServerInterceptorPipeline<Request, Response>!
+
+  /// Stream events which have arrived before the stream observer future has been resolved.
+  @usableFromInline
+  internal var requestBuffer: CircularBuffer<StreamEvent<Request>> = CircularBuffer()
+
+  /// The context required in order create the function.
+  @usableFromInline
+  internal let context: CallHandlerContext
+
+  /// A reference to a `UserInfo`.
+  @usableFromInline
+  internal let userInfoRef: Ref<UserInfo>
+
+  /// The user provided function to execute.
+  @usableFromInline
+  internal let handlerFactory: (UnaryResponseCallContext<Response>)
+    -> EventLoopFuture<(StreamEvent<Request>) -> Void>
+
+  /// The state of the handler.
+  @usableFromInline
+  internal var state: State = .idle
+
+  @usableFromInline
+  internal enum State {
+    // Nothing has happened yet.
+    case idle
+    // Headers have been received, a context has been created and the user code has been called to
+    // make an observer with. The observer future hasn't completed yet and, as such, the observer
+    // is yet to see any events.
+    case creatingObserver(UnaryResponseCallContext<Response>)
+    // The observer future has succeeded, messages may have been delivered to it.
+    case observing((StreamEvent<Request>) -> Void, UnaryResponseCallContext<Response>)
+    // The observer has completed by completing the status promise.
+    case completed
+  }
+
+  @inlinable
+  public init(
+    context: CallHandlerContext,
+    requestDeserializer: Deserializer,
+    responseSerializer: Serializer,
+    interceptors: [ServerInterceptor<Request, Response>],
+    observerFactory: @escaping (UnaryResponseCallContext<Response>)
+      -> EventLoopFuture<(StreamEvent<Request>) -> Void>
+  ) {
+    self.serializer = responseSerializer
+    self.deserializer = requestDeserializer
+    self.context = context
+    self.handlerFactory = observerFactory
+
+    let userInfoRef = Ref(UserInfo())
+    self.userInfoRef = userInfoRef
+    self.interceptors = ServerInterceptorPipeline(
+      logger: context.logger,
+      eventLoop: context.eventLoop,
+      path: context.path,
+      callType: .clientStreaming,
+      remoteAddress: context.remoteAddress,
+      userInfoRef: userInfoRef,
+      interceptors: interceptors,
+      onRequestPart: self.receiveInterceptedPart(_:),
+      onResponsePart: self.sendInterceptedPart(_:promise:)
+    )
+  }
+
+  // MARK: Public API; gRPC to Handler
+
+  @inlinable
+  public func receiveMetadata(_ headers: HPACKHeaders) {
+    self.interceptors.receive(.metadata(headers))
+  }
+
+  @inlinable
+  public func receiveMessage(_ bytes: ByteBuffer) {
+    do {
+      let message = try self.deserializer.deserialize(byteBuffer: bytes)
+      self.interceptors.receive(.message(message))
+    } catch {
+      self.handleError(error)
+    }
+  }
+
+  @inlinable
+  public func receiveEnd() {
+    self.interceptors.receive(.end)
+  }
+
+  @inlinable
+  public func receiveError(_ error: Error) {
+    self._finish(error: error)
+  }
+
+  @inlinable
+  public func finish() {
+    self._finish(error: nil)
+  }
+
+  @inlinable
+  internal func _finish(error: Error?) {
+    switch self.state {
+    case .idle:
+      self.interceptors = nil
+      self.state = .completed
+
+    case let .creatingObserver(context),
+         let .observing(_, context):
+      let error = error ?? GRPCStatus(code: .unavailable, message: nil)
+      context.responsePromise.fail(error)
+
+    case .completed:
+      self.interceptors = nil
+    }
+  }
+
+  // MARK: Interceptors to User Function
+
+  @inlinable
+  internal func receiveInterceptedPart(_ part: GRPCServerRequestPart<Request>) {
+    switch part {
+    case let .metadata(headers):
+      self.receiveInterceptedMetadata(headers)
+    case let .message(message):
+      self.receiveInterceptedMessage(message)
+    case .end:
+      self.receiveInterceptedEnd()
+    }
+  }
+
+  @inlinable
+  internal func receiveInterceptedMetadata(_ headers: HPACKHeaders) {
+    switch self.state {
+    case .idle:
+      // Make a context to invoke the observer block factory with.
+      let context = UnaryResponseCallContext<Response>(
+        eventLoop: self.context.eventLoop,
+        headers: headers,
+        logger: self.context.logger,
+        userInfoRef: self.userInfoRef
+      )
+
+      // Move to the next state.
+      self.state = .creatingObserver(context)
+
+      // Register a callback on the response future.
+      context.responsePromise.futureResult.whenComplete(self.userFunctionCompleted(with:))
+
+      // Make an observer block and register a completion block.
+      self.handlerFactory(context).whenComplete(self.userFunctionResolved(_:))
+
+      // Send response headers back via the interceptors.
+      self.interceptors.send(.metadata([:]), promise: nil)
+
+    case .creatingObserver, .observing:
+      self.handleError(GRPCError.InvalidState("Protocol violation: already received headers"))
+
+    case .completed:
+      // We may receive headers from the interceptor pipeline if we have already finished (i.e. due
+      // to an error or otherwise) and an interceptor doing some async work later emitting headers.
+      // Dropping them is fine.
+      ()
+    }
+  }
+
+  @inlinable
+  internal func receiveInterceptedMessage(_ request: Request) {
+    switch self.state {
+    case .idle:
+      self
+        .handleError(GRPCError.InvalidState("Protocol violation: message received before headers"))
+    case .creatingObserver:
+      self.requestBuffer.append(.message(request))
+    case let .observing(observer, _):
+      observer(.message(request))
+    case .completed:
+      // We received a message but we're already done: this may happen if we terminate the RPC
+      // due to a channel error, for example.
+      ()
+    }
+  }
+
+  @inlinable
+  internal func receiveInterceptedEnd() {
+    switch self.state {
+    case .idle:
+      self.handleError(GRPCError.InvalidState("Protocol violation: 'end received before headers'"))
+    case .creatingObserver:
+      self.requestBuffer.append(.end)
+    case let .observing(observer, _):
+      observer(.end)
+    case .completed:
+      // We received a message but we're already done: this may happen if we terminate the RPC
+      // due to a channel error, for example.
+      ()
+    }
+  }
+
+  // MARK: User Function to Interceptors
+
+  @inlinable
+  internal func userFunctionResolved(_ result: Result<(StreamEvent<Request>) -> Void, Error>) {
+    switch self.state {
+    case .idle, .observing:
+      // The observer block can't resolve if it hasn't been created ('idle') and it can't be
+      // resolved more than once ('created').
+      preconditionFailure()
+
+    case let .creatingObserver(context):
+      switch result {
+      case let .success(observer):
+        // We have an observer block now; unbuffer any requests.
+        self.state = .observing(observer, context)
+        while let request = self.requestBuffer.popFirst() {
+          observer(request)
+        }
+
+      case let .failure(error):
+        self.handleError(error)
+      }
+
+    case .completed:
+      // We've already completed. That's fine.
+      ()
+    }
+  }
+
+  @inlinable
+  internal func userFunctionCompleted(with result: Result<Response, Error>) {
+    switch self.state {
+    case .idle:
+      // Invalid state: the user function can only complete if it exists..
+      preconditionFailure()
+
+    case let .creatingObserver(context),
+         let .observing(_, context):
+      self.state = .completed
+
+      switch result {
+      case let .success(response):
+        let metadata = MessageMetadata(compress: false, flush: false)
+        self.interceptors.send(.message(response, metadata), promise: nil)
+        self.interceptors.send(.end(context.responseStatus, context.trailers), promise: nil)
+
+      case let .failure(error):
+        let (status, trailers) = ServerErrorProcessor.processObserverError(
+          error,
+          headers: context.headers,
+          trailers: context.trailers,
+          delegate: self.context.errorDelegate
+        )
+        self.interceptors.send(.end(status, trailers), promise: nil)
+      }
+
+    case .completed:
+      // We've already completed. Ignore this.
+      ()
+    }
+  }
+
+  @inlinable
+  internal func handleError(_ error: Error) {
+    switch self.state {
+    case .idle:
+      // We don't have a promise to fail. Just send back end.
+      let (status, trailers) = ServerErrorProcessor.processLibraryError(
+        error,
+        delegate: self.context.errorDelegate
+      )
+      self.interceptors.send(.end(status, trailers), promise: nil)
+
+    case let .creatingObserver(context),
+         let .observing(_, context):
+      context.responsePromise.fail(error)
+
+    case .completed:
+      ()
+    }
+  }
+
+  // MARK: Interceptor Glue
+
+  @inlinable
+  internal func sendInterceptedPart(
+    _ part: GRPCServerResponsePart<Response>,
+    promise: EventLoopPromise<Void>?
+  ) {
+    switch part {
+    case let .metadata(headers):
+      self.context.responseWriter.sendMetadata(headers, promise: promise)
+
+    case let .message(message, metadata):
+      do {
+        let bytes = try self.serializer.serialize(message, allocator: ByteBufferAllocator())
+        self.context.responseWriter.sendMessage(bytes, metadata: metadata, promise: promise)
+      } catch {
+        // Serialization failed: fail the promise and send end.
+        promise?.fail(error)
+        let (status, trailers) = ServerErrorProcessor.processLibraryError(
+          error,
+          delegate: self.context.errorDelegate
+        )
+        // Loop back via the interceptors.
+        self.interceptors.send(.end(status, trailers), promise: nil)
+      }
+
+    case let .end(status, trailers):
+      self.context.responseWriter.sendEnd(status: status, trailers: trailers, promise: promise)
+    }
+  }
+}

+ 286 - 61
Tests/GRPCTests/UnaryServerHandlerTests.swift

@@ -18,6 +18,8 @@ import NIO
 import NIOHPACK
 import XCTest
 
+// MARK: - Utils
+
 final class ResponseRecorder: GRPCServerResponseWriter {
   var metadata: HPACKHeaders?
   var messages: [ByteBuffer] = []
@@ -48,11 +50,14 @@ final class ResponseRecorder: GRPCServerResponseWriter {
   }
 }
 
-class UnaryServerHandlerTests: GRPCTestCase {
-  let eventLoop = EmbeddedEventLoop()
-  let allocator = ByteBufferAllocator()
+protocol ServerHandlerTestCase: GRPCTestCase {
+  var eventLoop: EmbeddedEventLoop { get }
+  var allocator: ByteBufferAllocator { get }
+  var recorder: ResponseRecorder { get }
+}
 
-  private func makeCallHandlerContext(writer: GRPCServerResponseWriter) -> CallHandlerContext {
+extension ServerHandlerTestCase {
+  func makeCallHandlerContext() -> CallHandlerContext {
     return CallHandlerContext(
       errorDelegate: nil,
       logger: self.logger,
@@ -60,17 +65,24 @@ class UnaryServerHandlerTests: GRPCTestCase {
       eventLoop: self.eventLoop,
       path: "/ignored",
       remoteAddress: nil,
-      responseWriter: writer,
+      responseWriter: self.recorder,
       allocator: self.allocator
     )
   }
+}
+
+// MARK: - Unary
+
+class UnaryServerHandlerTests: GRPCTestCase, ServerHandlerTestCase {
+  let eventLoop = EmbeddedEventLoop()
+  let allocator = ByteBufferAllocator()
+  let recorder = ResponseRecorder()
 
   private func makeHandler(
-    writer: GRPCServerResponseWriter,
     function: @escaping (String, StatusOnlyCallContext) -> EventLoopFuture<String>
   ) -> UnaryServerHandler<StringSerializer, StringDeserializer> {
     return UnaryServerHandler(
-      context: self.makeCallHandlerContext(writer: writer),
+      context: self.makeCallHandlerContext(),
       requestDeserializer: StringDeserializer(),
       responseSerializer: StringSerializer(),
       interceptors: [],
@@ -101,26 +113,24 @@ class UnaryServerHandlerTests: GRPCTestCase {
   }
 
   func testHappyPath() {
-    let recorder = ResponseRecorder()
-    let handler = self.makeHandler(writer: recorder, function: self.echo(_:context:))
+    let handler = self.makeHandler(function: self.echo(_:context:))
 
     handler.receiveMetadata([:])
-    assertThat(recorder.metadata, .is([:]))
+    assertThat(self.recorder.metadata, .is([:]))
 
     let buffer = ByteBuffer(string: "hello")
     handler.receiveMessage(buffer)
     handler.receiveEnd()
     handler.finish()
 
-    assertThat(recorder.messages.first, .is(buffer))
-    assertThat(recorder.status, .notNil(.hasCode(.ok)))
-    assertThat(recorder.trailers, .is([:]))
+    assertThat(self.recorder.messages.first, .is(buffer))
+    assertThat(self.recorder.status, .notNil(.hasCode(.ok)))
+    assertThat(self.recorder.trailers, .is([:]))
   }
 
   func testThrowingDeserializer() {
-    let recorder = ResponseRecorder()
     let handler = UnaryServerHandler(
-      context: self.makeCallHandlerContext(writer: recorder),
+      context: self.makeCallHandlerContext(),
       requestDeserializer: ThrowingStringDeserializer(),
       responseSerializer: StringSerializer(),
       interceptors: [],
@@ -128,19 +138,18 @@ class UnaryServerHandlerTests: GRPCTestCase {
     )
 
     handler.receiveMetadata([:])
-    assertThat(recorder.metadata, .is([:]))
+    assertThat(self.recorder.metadata, .is([:]))
 
     let buffer = ByteBuffer(string: "hello")
     handler.receiveMessage(buffer)
 
-    assertThat(recorder.messages, .isEmpty())
-    assertThat(recorder.status, .notNil(.hasCode(.internalError)))
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
   }
 
   func testThrowingSerializer() {
-    let recorder = ResponseRecorder()
     let handler = UnaryServerHandler(
-      context: self.makeCallHandlerContext(writer: recorder),
+      context: self.makeCallHandlerContext(),
       requestDeserializer: StringDeserializer(),
       responseSerializer: ThrowingStringSerializer(),
       interceptors: [],
@@ -148,61 +157,57 @@ class UnaryServerHandlerTests: GRPCTestCase {
     )
 
     handler.receiveMetadata([:])
-    assertThat(recorder.metadata, .is([:]))
+    assertThat(self.recorder.metadata, .is([:]))
 
     let buffer = ByteBuffer(string: "hello")
     handler.receiveMessage(buffer)
     handler.receiveEnd()
 
-    assertThat(recorder.messages, .isEmpty())
-    assertThat(recorder.status, .notNil(.hasCode(.internalError)))
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
   }
 
   func testUserFunctionReturnsFailedFuture() {
-    let recorder = ResponseRecorder()
-    let handler = self.makeHandler(writer: recorder) { _, context in
+    let handler = self.makeHandler { _, context in
       return context.eventLoop.makeFailedFuture(GRPCStatus(code: .unavailable, message: ":("))
     }
 
     handler.receiveMetadata([:])
-    assertThat(recorder.metadata, .is([:]))
+    assertThat(self.recorder.metadata, .is([:]))
 
     let buffer = ByteBuffer(string: "hello")
     handler.receiveMessage(buffer)
 
-    assertThat(recorder.messages, .isEmpty())
-    assertThat(recorder.status, .notNil(.hasCode(.unavailable)))
-    assertThat(recorder.status?.message, .is(":("))
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.unavailable)))
+    assertThat(self.recorder.status?.message, .is(":("))
   }
 
   func testReceiveMessageBeforeHeaders() {
-    let recorder = ResponseRecorder()
-    let handler = self.makeHandler(writer: recorder, function: self.neverCalled(_:context:))
+    let handler = self.makeHandler(function: self.neverCalled(_:context:))
 
     handler.receiveMessage(ByteBuffer(string: "foo"))
-    assertThat(recorder.metadata, .is(.nil()))
-    assertThat(recorder.messages, .isEmpty())
-    assertThat(recorder.status, .notNil(.hasCode(.internalError)))
+    assertThat(self.recorder.metadata, .is(.nil()))
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
   }
 
   func testReceiveMultipleHeaders() {
-    let recorder = ResponseRecorder()
-    let handler = self.makeHandler(writer: recorder, function: self.neverCalled(_:context:))
+    let handler = self.makeHandler(function: self.neverCalled(_:context:))
 
     handler.receiveMetadata([:])
-    assertThat(recorder.metadata, .is([:]))
+    assertThat(self.recorder.metadata, .is([:]))
 
     handler.receiveMetadata([:])
-    assertThat(recorder.messages, .isEmpty())
-    assertThat(recorder.status, .notNil(.hasCode(.internalError)))
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
   }
 
   func testReceiveMultipleMessages() {
-    let recorder = ResponseRecorder()
-    let handler = self.makeHandler(writer: recorder, function: self.neverComplete(_:context:))
+    let handler = self.makeHandler(function: self.neverComplete(_:context:))
 
     handler.receiveMetadata([:])
-    assertThat(recorder.metadata, .is([:]))
+    assertThat(self.recorder.metadata, .is([:]))
 
     let buffer = ByteBuffer(string: "hello")
     handler.receiveMessage(buffer)
@@ -210,44 +215,264 @@ class UnaryServerHandlerTests: GRPCTestCase {
     // Send another message before the function completes.
     handler.receiveMessage(buffer)
 
-    assertThat(recorder.messages, .isEmpty())
-    assertThat(recorder.status, .notNil(.hasCode(.internalError)))
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
+  }
+
+  func testFinishBeforeStarting() {
+    let handler = self.makeHandler(function: self.neverCalled(_:context:))
+
+    handler.finish()
+    assertThat(self.recorder.metadata, .is(.nil()))
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .is(.nil()))
+    assertThat(self.recorder.trailers, .is(.nil()))
+  }
+
+  func testFinishAfterHeaders() {
+    let handler = self.makeHandler(function: self.neverCalled(_:context:))
+    handler.receiveMetadata([:])
+    assertThat(self.recorder.metadata, .is([:]))
+
+    handler.finish()
+
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.unavailable)))
+    assertThat(self.recorder.trailers, .is([:]))
+  }
+
+  func testFinishAfterMessage() {
+    let handler = self.makeHandler(function: self.neverComplete(_:context:))
+
+    handler.receiveMetadata([:])
+    handler.receiveMessage(ByteBuffer(string: "hello"))
+    handler.finish()
+
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.unavailable)))
+    assertThat(self.recorder.trailers, .is([:]))
+  }
+}
+
+// MARK: - Client Streaming
+
+class ClientStreamingServerHandlerTests: GRPCTestCase, ServerHandlerTestCase {
+  let eventLoop = EmbeddedEventLoop()
+  let allocator = ByteBufferAllocator()
+  let recorder = ResponseRecorder()
+
+  private func makeHandler(
+    observerFactory: @escaping (UnaryResponseCallContext<String>)
+      -> EventLoopFuture<(StreamEvent<String>) -> Void>
+  ) -> ClientStreamingServerHandler<StringSerializer, StringDeserializer> {
+    return ClientStreamingServerHandler(
+      context: self.makeCallHandlerContext(),
+      requestDeserializer: StringDeserializer(),
+      responseSerializer: StringSerializer(),
+      interceptors: [],
+      observerFactory: observerFactory
+    )
+  }
+
+  private func joinWithSpaces(
+    context: UnaryResponseCallContext<String>
+  ) -> EventLoopFuture<(StreamEvent<String>) -> Void> {
+    var messages: [String] = []
+    func onEvent(_ event: StreamEvent<String>) {
+      switch event {
+      case let .message(message):
+        messages.append(message)
+      case .end:
+        context.responsePromise.succeed(messages.joined(separator: " "))
+      }
+    }
+    return context.eventLoop.makeSucceededFuture(onEvent(_:))
+  }
+
+  private func neverReceivesMessage(
+    context: UnaryResponseCallContext<String>
+  ) -> EventLoopFuture<(StreamEvent<String>) -> Void> {
+    func onEvent(_ event: StreamEvent<String>) {
+      switch event {
+      case let .message(message):
+        XCTFail("Unexpected message: '\(message)'")
+      case .end:
+        context.responsePromise.succeed("")
+      }
+    }
+    return context.eventLoop.makeSucceededFuture(onEvent(_:))
+  }
+
+  private func neverCalled(
+    context: UnaryResponseCallContext<String>
+  ) -> EventLoopFuture<(StreamEvent<String>) -> Void> {
+    XCTFail("This observer factory should never be called")
+    return context.eventLoop.makeFailedFuture(GRPCStatus(code: .aborted, message: nil))
+  }
+
+  func testHappyPath() {
+    let handler = self.makeHandler(observerFactory: self.joinWithSpaces(context:))
+
+    handler.receiveMetadata([:])
+    assertThat(self.recorder.metadata, .is([:]))
+
+    handler.receiveMessage(ByteBuffer(string: "1"))
+    handler.receiveMessage(ByteBuffer(string: "2"))
+    handler.receiveMessage(ByteBuffer(string: "3"))
+    handler.receiveEnd()
+    handler.finish()
+
+    assertThat(self.recorder.messages.first, .is(ByteBuffer(string: "1 2 3")))
+    assertThat(self.recorder.status, .notNil(.hasCode(.ok)))
+    assertThat(self.recorder.trailers, .is([:]))
+  }
+
+  func testThrowingDeserializer() {
+    let handler = ClientStreamingServerHandler(
+      context: self.makeCallHandlerContext(),
+      requestDeserializer: ThrowingStringDeserializer(),
+      responseSerializer: StringSerializer(),
+      interceptors: [],
+      observerFactory: self.neverReceivesMessage(context:)
+    )
+
+    handler.receiveMetadata([:])
+    assertThat(self.recorder.metadata, .is([:]))
+
+    let buffer = ByteBuffer(string: "hello")
+    handler.receiveMessage(buffer)
+
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
+  }
+
+  func testThrowingSerializer() {
+    let handler = ClientStreamingServerHandler(
+      context: self.makeCallHandlerContext(),
+      requestDeserializer: StringDeserializer(),
+      responseSerializer: ThrowingStringSerializer(),
+      interceptors: [],
+      observerFactory: self.joinWithSpaces(context:)
+    )
+
+    handler.receiveMetadata([:])
+    assertThat(self.recorder.metadata, .is([:]))
+
+    let buffer = ByteBuffer(string: "hello")
+    handler.receiveMessage(buffer)
+    handler.receiveEnd()
+
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
+  }
+
+  func testObserverFactoryReturnsFailedFuture() {
+    let handler = self.makeHandler { context in
+      context.eventLoop.makeFailedFuture(GRPCStatus(code: .unavailable, message: ":("))
+    }
+
+    handler.receiveMetadata([:])
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.unavailable)))
+    assertThat(self.recorder.status?.message, .is(":("))
+  }
+
+  func testDelayedObserverFactory() {
+    let promise = self.eventLoop.makePromise(of: Void.self)
+    let handler = self.makeHandler { context in
+      return promise.futureResult.flatMap {
+        self.joinWithSpaces(context: context)
+      }
+    }
+
+    handler.receiveMetadata([:])
+    // Queue up some messages.
+    handler.receiveMessage(ByteBuffer(string: "1"))
+    handler.receiveMessage(ByteBuffer(string: "2"))
+    handler.receiveMessage(ByteBuffer(string: "3"))
+    // Succeed the observer block.
+    promise.succeed(())
+    // A few more messages.
+    handler.receiveMessage(ByteBuffer(string: "4"))
+    handler.receiveMessage(ByteBuffer(string: "5"))
+    handler.receiveEnd()
+
+    assertThat(self.recorder.messages.first, .is(ByteBuffer(string: "1 2 3 4 5")))
+    assertThat(self.recorder.status, .notNil(.hasCode(.ok)))
+  }
+
+  func testDelayedObserverFactoryAllMessagesBeforeSucceeding() {
+    let promise = self.eventLoop.makePromise(of: Void.self)
+    let handler = self.makeHandler { context in
+      return promise.futureResult.flatMap {
+        self.joinWithSpaces(context: context)
+      }
+    }
+
+    handler.receiveMetadata([:])
+    // Queue up some messages.
+    handler.receiveMessage(ByteBuffer(string: "1"))
+    handler.receiveMessage(ByteBuffer(string: "2"))
+    handler.receiveMessage(ByteBuffer(string: "3"))
+    handler.receiveEnd()
+    // Succeed the observer block.
+    promise.succeed(())
+
+    assertThat(self.recorder.messages.first, .is(ByteBuffer(string: "1 2 3")))
+    assertThat(self.recorder.status, .notNil(.hasCode(.ok)))
+  }
+
+  func testReceiveMessageBeforeHeaders() {
+    let handler = self.makeHandler(observerFactory: self.neverCalled(context:))
+
+    handler.receiveMessage(ByteBuffer(string: "foo"))
+    assertThat(self.recorder.metadata, .is(.nil()))
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
+  }
+
+  func testReceiveMultipleHeaders() {
+    let handler = self.makeHandler(observerFactory: self.neverReceivesMessage(context:))
+
+    handler.receiveMetadata([:])
+    assertThat(self.recorder.metadata, .is([:]))
+
+    handler.receiveMetadata([:])
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
   }
 
   func testFinishBeforeStarting() {
-    let recorder = ResponseRecorder()
-    let handler = self.makeHandler(writer: recorder, function: self.neverCalled(_:context:))
+    let handler = self.makeHandler(observerFactory: self.neverCalled(context:))
 
     handler.finish()
-    assertThat(recorder.metadata, .is(.nil()))
-    assertThat(recorder.messages, .isEmpty())
-    assertThat(recorder.status, .is(.nil()))
-    assertThat(recorder.trailers, .is(.nil()))
+    assertThat(self.recorder.metadata, .is(.nil()))
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .is(.nil()))
+    assertThat(self.recorder.trailers, .is(.nil()))
   }
 
   func testFinishAfterHeaders() {
-    let recorder = ResponseRecorder()
-    let handler = self.makeHandler(writer: recorder, function: self.neverCalled(_:context:))
+    let handler = self.makeHandler(observerFactory: self.joinWithSpaces(context:))
     handler.receiveMetadata([:])
-    assertThat(recorder.metadata, .is([:]))
+    assertThat(self.recorder.metadata, .is([:]))
 
     handler.finish()
 
-    assertThat(recorder.messages, .isEmpty())
-    assertThat(recorder.status, .notNil(.hasCode(.unavailable)))
-    assertThat(recorder.trailers, .is([:]))
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.unavailable)))
+    assertThat(self.recorder.trailers, .is([:]))
   }
 
   func testFinishAfterMessage() {
-    let recorder = ResponseRecorder()
-    let handler = self.makeHandler(writer: recorder, function: self.neverComplete(_:context:))
+    let handler = self.makeHandler(observerFactory: self.joinWithSpaces(context:))
 
     handler.receiveMetadata([:])
     handler.receiveMessage(ByteBuffer(string: "hello"))
     handler.finish()
 
-    assertThat(recorder.messages, .isEmpty())
-    assertThat(recorder.status, .notNil(.hasCode(.unavailable)))
-    assertThat(recorder.trailers, .is([:]))
+    assertThat(self.recorder.messages, .isEmpty())
+    assertThat(self.recorder.status, .notNil(.hasCode(.unavailable)))
+    assertThat(self.recorder.trailers, .is([:]))
   }
 }