Browse Source

Add GRPCClientStreamHandler (#1838)

Gustavo Cairo 1 year ago
parent
commit
847a9348e4

+ 238 - 0
Sources/GRPCHTTP2Core/Client/GRPCClientStreamHandler.swift

@@ -0,0 +1,238 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import GRPCCore
+import NIOCore
+import NIOHTTP2
+
+@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
+final class GRPCClientStreamHandler: ChannelDuplexHandler {
+  typealias InboundIn = HTTP2Frame.FramePayload
+  typealias InboundOut = RPCResponsePart
+
+  typealias OutboundIn = RPCRequestPart
+  typealias OutboundOut = HTTP2Frame.FramePayload
+
+  private var stateMachine: GRPCStreamStateMachine
+
+  private var isReading = false
+  private var flushPending = false
+
+  init(
+    methodDescriptor: MethodDescriptor,
+    scheme: Scheme,
+    outboundEncoding: CompressionAlgorithm,
+    acceptedEncodings: [CompressionAlgorithm],
+    maximumPayloadSize: Int,
+    skipStateMachineAssertions: Bool = false
+  ) {
+    self.stateMachine = .init(
+      configuration: .client(
+        .init(
+          methodDescriptor: methodDescriptor,
+          scheme: scheme,
+          outboundEncoding: outboundEncoding,
+          acceptedEncodings: acceptedEncodings
+        )
+      ),
+      maximumPayloadSize: maximumPayloadSize,
+      skipAssertions: skipStateMachineAssertions
+    )
+  }
+}
+
+// - MARK: ChannelInboundHandler
+
+@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
+extension GRPCClientStreamHandler {
+  func channelRead(context: ChannelHandlerContext, data: NIOAny) {
+    self.isReading = true
+    let frame = self.unwrapInboundIn(data)
+    switch frame {
+    case .data(let frameData):
+      let endStream = frameData.endStream
+      switch frameData.data {
+      case .byteBuffer(let buffer):
+        do {
+          try self.stateMachine.receive(buffer: buffer, endStream: endStream)
+          loop: while true {
+            switch self.stateMachine.nextInboundMessage() {
+            case .receiveMessage(let message):
+              context.fireChannelRead(self.wrapInboundOut(.message(message)))
+            case .awaitMoreMessages:
+              break loop
+            case .noMoreMessages:
+              context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)
+              break loop
+            }
+          }
+        } catch {
+          context.fireErrorCaught(error)
+        }
+
+      case .fileRegion:
+        preconditionFailure("Unexpected IOData.fileRegion")
+      }
+
+    case .headers(let headers):
+      do {
+        let action = try self.stateMachine.receive(
+          headers: headers.headers,
+          endStream: headers.endStream
+        )
+        switch action {
+        case .receivedMetadata(let metadata):
+          context.fireChannelRead(self.wrapInboundOut(.metadata(metadata)))
+
+        case .rejectRPC:
+          throw RPCError(
+            code: .internalError,
+            message: "Client cannot get rejectRPC."
+          )
+
+        case .receivedStatusAndMetadata(let status, let metadata):
+          context.fireChannelRead(self.wrapInboundOut(.status(status, metadata)))
+
+        case .doNothing:
+          ()
+        }
+      } catch {
+        context.fireErrorCaught(error)
+      }
+
+    case .ping, .goAway, .priority, .rstStream, .settings, .pushPromise, .windowUpdate,
+      .alternativeService, .origin:
+      ()
+    }
+  }
+
+  func channelReadComplete(context: ChannelHandlerContext) {
+    self.isReading = false
+    if self.flushPending {
+      self.flushPending = false
+      self.flush(context: context)
+    }
+    context.fireChannelReadComplete()
+  }
+
+  func handlerRemoved(context: ChannelHandlerContext) {
+    self.stateMachine.tearDown()
+  }
+}
+
+// - MARK: ChannelOutboundHandler
+
+@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
+extension GRPCClientStreamHandler {
+  func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
+    switch self.unwrapOutboundIn(data) {
+    case .metadata(let metadata):
+      do {
+        self.flushPending = true
+        let headers = try self.stateMachine.send(metadata: metadata)
+        context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: nil)
+        // TODO: move the promise handling into the state machine
+        promise?.succeed()
+      } catch {
+        context.fireErrorCaught(error)
+        // TODO: move the promise handling into the state machine
+        promise?.fail(error)
+      }
+
+    case .message(let message):
+      do {
+        try self.stateMachine.send(message: message)
+        // TODO: move the promise handling into the state machine
+        promise?.succeed()
+      } catch {
+        context.fireErrorCaught(error)
+        // TODO: move the promise handling into the state machine
+        promise?.fail(error)
+      }
+    }
+  }
+
+  func close(context: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise<Void>?) {
+    switch mode {
+    case .output, .all:
+      do {
+        try self.stateMachine.closeOutbound()
+        // Force a flush by calling _flush
+        // (otherwise, we'd skip flushing if we're in a read loop)
+        self._flush(context: context)
+        context.close(mode: mode, promise: promise)
+      } catch {
+        promise?.fail(error)
+        context.fireErrorCaught(error)
+      }
+
+    case .input:
+      context.close(mode: .input, promise: promise)
+    }
+  }
+
+  func flush(context: ChannelHandlerContext) {
+    if self.isReading {
+      // We don't want to flush yet if we're still in a read loop.
+      self.flushPending = true
+      return
+    }
+
+    self._flush(context: context)
+  }
+
+  private func _flush(context: ChannelHandlerContext) {
+    do {
+      loop: while true {
+        switch try self.stateMachine.nextOutboundMessage() {
+        case .sendMessage(let byteBuffer):
+          self.flushPending = true
+          context.write(
+            self.wrapOutboundOut(.data(.init(data: .byteBuffer(byteBuffer)))),
+            promise: nil
+          )
+
+        case .noMoreMessages:
+          // Write an empty data frame with the EOS flag set, to signal the RPC
+          // request is now finished.
+          context.write(
+            self.wrapOutboundOut(
+              HTTP2Frame.FramePayload.data(
+                .init(
+                  data: .byteBuffer(.init()),
+                  endStream: true
+                )
+              )
+            ),
+            promise: nil
+          )
+
+          context.flush()
+          break loop
+
+        case .awaitMoreMessages:
+          if self.flushPending {
+            self.flushPending = false
+            context.flush()
+          }
+          break loop
+        }
+      }
+    } catch {
+      context.fireErrorCaught(error)
+    }
+  }
+}

+ 21 - 13
Sources/GRPCHTTP2Core/GRPCStreamStateMachine.swift

@@ -373,8 +373,12 @@ struct GRPCStreamStateMachine {
 
   mutating func receive(headers: HPACKHeaders, endStream: Bool) throws -> OnMetadataReceived {
     switch self.configuration {
-    case .client:
-      return try self.clientReceive(headers: headers, endStream: endStream)
+    case .client(let clientConfiguration):
+      return try self.clientReceive(
+        headers: headers,
+        endStream: endStream,
+        configuration: clientConfiguration
+      )
     case .server(let serverConfiguration):
       return try self.serverReceive(
         headers: headers,
@@ -567,9 +571,7 @@ extension GRPCStreamStateMachine {
     case .clientOpenServerClosed(let state):
       self.state = .clientClosedServerClosed(.init(previousState: state))
     case .clientClosedServerIdle, .clientClosedServerOpen, .clientClosedServerClosed:
-      try self.invalidState(
-        "Client is closed, cannot send a message."
-      )
+      try self.invalidState("Client is already closed.")
     }
   }
 
@@ -665,7 +667,7 @@ extension GRPCStreamStateMachine {
         .receivedStatusAndMetadata(
           status: .init(
             code: .internalError,
-            message: "Missing \(GRPCHTTP2Keys.contentType) header"
+            message: "Missing \(GRPCHTTP2Keys.contentType.rawValue) header"
           ),
           metadata: Metadata(headers: metadata)
         )
@@ -680,10 +682,15 @@ extension GRPCStreamStateMachine {
     case success(CompressionAlgorithm)
   }
 
-  private func processInboundEncoding(_ metadata: HPACKHeaders) -> ProcessInboundEncodingResult {
+  private func processInboundEncoding(
+    headers: HPACKHeaders,
+    configuration: GRPCStreamStateMachineConfiguration.ClientConfiguration
+  ) -> ProcessInboundEncodingResult {
     let inboundEncoding: CompressionAlgorithm
-    if let serverEncoding = metadata.first(name: GRPCHTTP2Keys.encoding.rawValue) {
-      guard let parsedEncoding = CompressionAlgorithm(rawValue: serverEncoding) else {
+    if let serverEncoding = headers.first(name: GRPCHTTP2Keys.encoding.rawValue) {
+      guard let parsedEncoding = CompressionAlgorithm(rawValue: serverEncoding),
+        configuration.acceptedEncodings.contains(parsedEncoding)
+      else {
         return .error(
           .receivedStatusAndMetadata(
             status: .init(
@@ -691,7 +698,7 @@ extension GRPCStreamStateMachine {
               message:
                 "The server picked a compression algorithm ('\(serverEncoding)') the client does not know about."
             ),
-            metadata: Metadata(headers: metadata)
+            metadata: Metadata(headers: headers)
           )
         )
       }
@@ -732,7 +739,8 @@ extension GRPCStreamStateMachine {
 
   private mutating func clientReceive(
     headers: HPACKHeaders,
-    endStream: Bool
+    endStream: Bool,
+    configuration: GRPCStreamStateMachineConfiguration.ClientConfiguration
   ) throws -> OnMetadataReceived {
     switch self.state {
     case .clientOpenServerIdle(let state):
@@ -750,7 +758,7 @@ extension GRPCStreamStateMachine {
         self.state = .clientOpenServerClosed(.init(previousState: state))
         return try self.validateAndReturnStatusAndMetadata(headers)
       case (.valid, false):
-        switch self.processInboundEncoding(headers) {
+        switch self.processInboundEncoding(headers: headers, configuration: configuration) {
         case .error(let failure):
           return failure
         case .success(let inboundEncoding):
@@ -798,7 +806,7 @@ extension GRPCStreamStateMachine {
         self.state = .clientClosedServerClosed(.init(previousState: state))
         return try self.validateAndReturnStatusAndMetadata(headers)
       case (.valid, false):
-        switch self.processInboundEncoding(headers) {
+        switch self.processInboundEncoding(headers: headers, configuration: configuration) {
         case .error(let failure):
           return failure
         case .success(let inboundEncoding):

+ 724 - 0
Tests/GRPCHTTP2CoreTests/Client/GRPCClientStreamHandlerTests.swift

@@ -0,0 +1,724 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import GRPCCore
+import NIOCore
+import NIOEmbedded
+import NIOHPACK
+import NIOHTTP1
+import NIOHTTP2
+import XCTest
+
+@testable import GRPCHTTP2Core
+
+@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
+final class GRPCClientStreamHandlerTests: XCTestCase {
+  func testH2FramesAreIgnored() throws {
+    let handler = GRPCClientStreamHandler(
+      methodDescriptor: .init(service: "test", method: "test"),
+      scheme: .http,
+      outboundEncoding: .identity,
+      acceptedEncodings: [],
+      maximumPayloadSize: 1
+    )
+
+    let channel = EmbeddedChannel(handler: handler)
+
+    let framesToBeIgnored: [HTTP2Frame.FramePayload] = [
+      .ping(.init(), ack: false),
+      .goAway(lastStreamID: .rootStream, errorCode: .cancel, opaqueData: nil),
+      // TODO: add .priority(StreamPriorityData) - right now, StreamPriorityData's
+      // initialiser is internal, so I can't create one of these frames.
+      .rstStream(.cancel),
+      .settings(.ack),
+      .pushPromise(.init(pushedStreamID: .maxID, headers: [:])),
+      .windowUpdate(windowSizeIncrement: 4),
+      .alternativeService(origin: nil, field: nil),
+      .origin([]),
+    ]
+
+    for toBeIgnored in framesToBeIgnored {
+      XCTAssertNoThrow(try channel.writeInbound(toBeIgnored))
+      XCTAssertNil(try channel.readInbound(as: HTTP2Frame.FramePayload.self))
+    }
+  }
+
+  func testServerInitialMetadataMissingHTTPStatusCodeResultsInFinishedRPC() throws {
+    let handler = GRPCClientStreamHandler(
+      methodDescriptor: .init(service: "test", method: "test"),
+      scheme: .http,
+      outboundEncoding: .identity,
+      acceptedEncodings: [],
+      maximumPayloadSize: 1,
+      skipStateMachineAssertions: true
+    )
+
+    let channel = EmbeddedChannel(handler: handler)
+
+    // Send client's initial metadata
+    let request = RPCRequestPart.metadata([:])
+    XCTAssertNoThrow(try channel.writeOutbound(request))
+
+    // Receive server's initial metadata without :status
+    let serverInitialMetadata: HPACKHeaders = [
+      GRPCHTTP2Keys.contentType.rawValue: ContentType.grpc.canonicalValue
+    ]
+
+    XCTAssertNoThrow(
+      try channel.writeInbound(
+        HTTP2Frame.FramePayload.headers(.init(headers: serverInitialMetadata))
+      )
+    )
+
+    XCTAssertEqual(
+      try channel.readInbound(as: RPCResponsePart.self),
+      .status(
+        .init(code: .unknown, message: "HTTP Status Code is missing."),
+        Metadata(headers: serverInitialMetadata)
+      )
+    )
+  }
+
+  func testServerInitialMetadata1xxHTTPStatusCodeResultsInNothingRead() throws {
+    let handler = GRPCClientStreamHandler(
+      methodDescriptor: .init(service: "test", method: "test"),
+      scheme: .http,
+      outboundEncoding: .identity,
+      acceptedEncodings: [],
+      maximumPayloadSize: 1,
+      skipStateMachineAssertions: true
+    )
+
+    let channel = EmbeddedChannel(handler: handler)
+
+    // Send client's initial metadata
+    let request = RPCRequestPart.metadata([:])
+    XCTAssertNoThrow(try channel.writeOutbound(request))
+
+    // Receive server's initial metadata with 1xx status
+    let serverInitialMetadata: HPACKHeaders = [
+      GRPCHTTP2Keys.status.rawValue: "104",
+      GRPCHTTP2Keys.contentType.rawValue: ContentType.grpc.canonicalValue,
+    ]
+
+    XCTAssertNoThrow(
+      try channel.writeInbound(
+        HTTP2Frame.FramePayload.headers(.init(headers: serverInitialMetadata))
+      )
+    )
+
+    XCTAssertNil(try channel.readInbound(as: RPCResponsePart.self))
+  }
+
+  func testServerInitialMetadataOtherNon200HTTPStatusCodeResultsInFinishedRPC() throws {
+    let handler = GRPCClientStreamHandler(
+      methodDescriptor: .init(service: "test", method: "test"),
+      scheme: .http,
+      outboundEncoding: .identity,
+      acceptedEncodings: [],
+      maximumPayloadSize: 1,
+      skipStateMachineAssertions: true
+    )
+
+    let channel = EmbeddedChannel(handler: handler)
+
+    // Send client's initial metadata
+    let request = RPCRequestPart.metadata([:])
+    XCTAssertNoThrow(try channel.writeOutbound(request))
+
+    // Receive server's initial metadata with non-200 and non-1xx :status
+    let serverInitialMetadata: HPACKHeaders = [
+      GRPCHTTP2Keys.status.rawValue: String(HTTPResponseStatus.tooManyRequests.code),
+      GRPCHTTP2Keys.contentType.rawValue: ContentType.grpc.canonicalValue,
+    ]
+
+    XCTAssertNoThrow(
+      try channel.writeInbound(
+        HTTP2Frame.FramePayload.headers(.init(headers: serverInitialMetadata))
+      )
+    )
+
+    XCTAssertEqual(
+      try channel.readInbound(as: RPCResponsePart.self),
+      .status(
+        .init(code: .unavailable, message: "Unexpected non-200 HTTP Status Code."),
+        Metadata(headers: serverInitialMetadata)
+      )
+    )
+  }
+
+  func testServerInitialMetadataMissingContentTypeResultsInFinishedRPC() throws {
+    let handler = GRPCClientStreamHandler(
+      methodDescriptor: .init(service: "test", method: "test"),
+      scheme: .http,
+      outboundEncoding: .identity,
+      acceptedEncodings: [],
+      maximumPayloadSize: 1,
+      skipStateMachineAssertions: true
+    )
+
+    let channel = EmbeddedChannel(handler: handler)
+
+    // Send client's initial metadata
+    let request = RPCRequestPart.metadata([:])
+    XCTAssertNoThrow(try channel.writeOutbound(request))
+
+    // Receive server's initial metadata without content-type
+    let serverInitialMetadata: HPACKHeaders = [
+      GRPCHTTP2Keys.status.rawValue: "200"
+    ]
+
+    XCTAssertNoThrow(
+      try channel.writeInbound(
+        HTTP2Frame.FramePayload.headers(.init(headers: serverInitialMetadata))
+      )
+    )
+
+    XCTAssertEqual(
+      try channel.readInbound(as: RPCResponsePart.self),
+      .status(
+        .init(code: .internalError, message: "Missing content-type header"),
+        Metadata(headers: serverInitialMetadata)
+      )
+    )
+  }
+
+  func testNotAcceptedEncodingResultsInFinishedRPC() throws {
+    let handler = GRPCClientStreamHandler(
+      methodDescriptor: .init(service: "test", method: "test"),
+      scheme: .http,
+      outboundEncoding: .deflate,
+      acceptedEncodings: [.deflate],
+      maximumPayloadSize: 1
+    )
+
+    let channel = EmbeddedChannel(handler: handler)
+
+    // Send client's initial metadata
+    XCTAssertNoThrow(
+      try channel.writeOutbound(RPCRequestPart.metadata(Metadata()))
+    )
+
+    // Make sure we have sent right metadata.
+    let writtenMetadata = try channel.assertReadHeadersOutbound()
+
+    XCTAssertEqual(
+      writtenMetadata.headers,
+      [
+        GRPCHTTP2Keys.method.rawValue: "POST",
+        GRPCHTTP2Keys.scheme.rawValue: "http",
+        GRPCHTTP2Keys.path.rawValue: "test/test",
+        GRPCHTTP2Keys.contentType.rawValue: "application/grpc",
+        GRPCHTTP2Keys.te.rawValue: "trailers",
+        GRPCHTTP2Keys.encoding.rawValue: "deflate",
+        GRPCHTTP2Keys.acceptEncoding.rawValue: "deflate",
+      ]
+    )
+
+    // Server sends initial metadata with unsupported encoding
+    let serverInitialMetadata: HPACKHeaders = [
+      GRPCHTTP2Keys.status.rawValue: "200",
+      GRPCHTTP2Keys.contentType.rawValue: ContentType.grpc.canonicalValue,
+      GRPCHTTP2Keys.encoding.rawValue: "gzip",
+    ]
+
+    XCTAssertNoThrow(
+      try channel.writeInbound(
+        HTTP2Frame.FramePayload.headers(.init(headers: serverInitialMetadata))
+      )
+    )
+
+    XCTAssertEqual(
+      try channel.readInbound(as: RPCResponsePart.self),
+      .status(
+        .init(
+          code: .internalError,
+          message:
+            "The server picked a compression algorithm ('gzip') the client does not know about."
+        ),
+        Metadata(headers: serverInitialMetadata)
+      )
+    )
+  }
+
+  func testOverMaximumPayloadSize() throws {
+    let handler = GRPCClientStreamHandler(
+      methodDescriptor: .init(service: "test", method: "test"),
+      scheme: .http,
+      outboundEncoding: .identity,
+      acceptedEncodings: [],
+      maximumPayloadSize: 1,
+      skipStateMachineAssertions: true
+    )
+
+    let channel = EmbeddedChannel(handler: handler)
+
+    // Send client's initial metadata
+    XCTAssertNoThrow(
+      try channel.writeOutbound(RPCRequestPart.metadata(Metadata()))
+    )
+
+    // Make sure we have sent right metadata.
+    let writtenMetadata = try channel.assertReadHeadersOutbound()
+
+    XCTAssertEqual(
+      writtenMetadata.headers,
+      [
+        GRPCHTTP2Keys.method.rawValue: "POST",
+        GRPCHTTP2Keys.scheme.rawValue: "http",
+        GRPCHTTP2Keys.path.rawValue: "test/test",
+        GRPCHTTP2Keys.contentType.rawValue: "application/grpc",
+        GRPCHTTP2Keys.te.rawValue: "trailers",
+      ]
+    )
+
+    // Server sends initial metadata
+    let serverInitialMetadata: HPACKHeaders = [
+      GRPCHTTP2Keys.status.rawValue: "200",
+      GRPCHTTP2Keys.contentType.rawValue: ContentType.grpc.canonicalValue,
+    ]
+    XCTAssertNoThrow(
+      try channel.writeInbound(
+        HTTP2Frame.FramePayload.headers(.init(headers: serverInitialMetadata))
+      )
+    )
+    XCTAssertEqual(
+      try channel.readInbound(as: RPCResponsePart.self),
+      .metadata(Metadata(headers: serverInitialMetadata))
+    )
+
+    // Server sends message over payload limit
+    var buffer = ByteBuffer()
+    buffer.writeInteger(UInt8(0))  // not compressed
+    buffer.writeInteger(UInt32(42))  // message length
+    buffer.writeRepeatingByte(0, count: 42)  // message
+    let clientDataPayload = HTTP2Frame.FramePayload.Data(data: .byteBuffer(buffer), endStream: true)
+    XCTAssertThrowsError(
+      ofType: RPCError.self,
+      try channel.writeInbound(HTTP2Frame.FramePayload.data(clientDataPayload))
+    ) { error in
+      XCTAssertEqual(error.code, .resourceExhausted)
+      XCTAssertEqual(
+        error.message,
+        "Message has exceeded the configured maximum payload size (max: 1, actual: 42)"
+      )
+    }
+
+    // Make sure we didn't read the received message
+    XCTAssertNil(try channel.readInbound(as: RPCRequestPart.self))
+  }
+
+  func testServerEndsStream() throws {
+    let handler = GRPCClientStreamHandler(
+      methodDescriptor: .init(service: "test", method: "test"),
+      scheme: .http,
+      outboundEncoding: .identity,
+      acceptedEncodings: [],
+      maximumPayloadSize: 1,
+      skipStateMachineAssertions: true
+    )
+
+    let channel = EmbeddedChannel(handler: handler)
+
+    // Write client's initial metadata
+    XCTAssertNoThrow(try channel.writeOutbound(RPCRequestPart.metadata(Metadata())))
+    let clientInitialMetadata: HPACKHeaders = [
+      GRPCHTTP2Keys.path.rawValue: "test/test",
+      GRPCHTTP2Keys.scheme.rawValue: "http",
+      GRPCHTTP2Keys.method.rawValue: "POST",
+      GRPCHTTP2Keys.contentType.rawValue: "application/grpc",
+      GRPCHTTP2Keys.te.rawValue: "trailers",
+    ]
+    let writtenInitialMetadata = try channel.assertReadHeadersOutbound()
+    XCTAssertEqual(writtenInitialMetadata.headers, clientInitialMetadata)
+
+    // Receive server's initial metadata with end stream set
+    let serverInitialMetadata: HPACKHeaders = [
+      GRPCHTTP2Keys.status.rawValue: "200",
+      GRPCHTTP2Keys.grpcStatus.rawValue: "0",
+      GRPCHTTP2Keys.contentType.rawValue: "application/grpc",
+    ]
+    XCTAssertNoThrow(
+      try channel.writeInbound(
+        HTTP2Frame.FramePayload.headers(
+          .init(
+            headers: serverInitialMetadata,
+            endStream: true
+          )
+        )
+      )
+    )
+    XCTAssertEqual(
+      try channel.readInbound(as: RPCResponsePart.self),
+      .status(
+        .init(code: .ok, message: ""),
+        [
+          GRPCHTTP2Keys.status.rawValue: "200",
+          GRPCHTTP2Keys.contentType.rawValue: "application/grpc",
+        ]
+      )
+    )
+
+    // We should throw if the server sends another message, since it's closed the stream already.
+    var buffer = ByteBuffer()
+    buffer.writeInteger(UInt8(0))  // not compressed
+    buffer.writeInteger(UInt32(42))  // message length
+    buffer.writeRepeatingByte(0, count: 42)  // message
+    let serverDataPayload = HTTP2Frame.FramePayload.Data(data: .byteBuffer(buffer), endStream: true)
+    XCTAssertThrowsError(
+      ofType: RPCError.self,
+      try channel.writeInbound(HTTP2Frame.FramePayload.data(serverDataPayload))
+    ) { error in
+      XCTAssertEqual(error.code, .internalError)
+      XCTAssertEqual(error.message, "Cannot have received anything from a closed server.")
+    }
+  }
+
+  func testNormalFlow() throws {
+    let handler = GRPCClientStreamHandler(
+      methodDescriptor: .init(service: "test", method: "test"),
+      scheme: .http,
+      outboundEncoding: .identity,
+      acceptedEncodings: [],
+      maximumPayloadSize: 100,
+      skipStateMachineAssertions: true
+    )
+
+    let channel = EmbeddedChannel(handler: handler)
+
+    // Send client's initial metadata
+    let request = RPCRequestPart.metadata([:])
+    XCTAssertNoThrow(try channel.writeOutbound(request))
+
+    // Make sure we have sent the corresponding frame, and that nothing has been written back.
+    let writtenHeaders = try channel.assertReadHeadersOutbound()
+    XCTAssertEqual(
+      writtenHeaders.headers,
+      [
+        GRPCHTTP2Keys.method.rawValue: "POST",
+        GRPCHTTP2Keys.scheme.rawValue: "http",
+        GRPCHTTP2Keys.path.rawValue: "test/test",
+        GRPCHTTP2Keys.contentType.rawValue: "application/grpc",
+        GRPCHTTP2Keys.te.rawValue: "trailers",
+
+      ]
+    )
+    XCTAssertNil(try channel.readInbound(as: RPCResponsePart.self))
+
+    // Receive server's initial metadata
+    let serverInitialMetadata: HPACKHeaders = [
+      GRPCHTTP2Keys.status.rawValue: "200",
+      GRPCHTTP2Keys.contentType.rawValue: "application/grpc",
+      "some-custom-header": "some-custom-value",
+    ]
+    XCTAssertNoThrow(
+      try channel.writeInbound(
+        HTTP2Frame.FramePayload.headers(.init(headers: serverInitialMetadata))
+      )
+    )
+
+    XCTAssertEqual(
+      try channel.readInbound(as: RPCResponsePart.self),
+      RPCResponsePart.metadata(Metadata(headers: serverInitialMetadata))
+    )
+
+    // Send a message
+    XCTAssertNoThrow(
+      try channel.writeOutbound(RPCRequestPart.message(.init(repeating: 1, count: 42)))
+    )
+
+    // Assert we wrote it successfully into the channel
+    let writtenMessage = try channel.assertReadDataOutbound()
+    var expectedBuffer = ByteBuffer()
+    expectedBuffer.writeInteger(UInt8(0))  // not compressed
+    expectedBuffer.writeInteger(UInt32(42))  // message length
+    expectedBuffer.writeRepeatingByte(1, count: 42)  // message
+    XCTAssertEqual(writtenMessage.data, .byteBuffer(expectedBuffer))
+
+    // Half-close the outbound end: this would be triggered by finishing the client's writer.
+    XCTAssertNoThrow(channel.close(mode: .output, promise: nil))
+
+    // Flush to make sure the EOS is written.
+    channel.flush()
+
+    // Make sure the EOS frame was sent
+    let emptyEOSFrame = try channel.assertReadDataOutbound()
+    XCTAssertEqual(emptyEOSFrame.data, .byteBuffer(.init()))
+    XCTAssertTrue(emptyEOSFrame.endStream)
+
+    // Make sure we cannot write anymore because client's closed.
+    XCTAssertThrowsError(
+      ofType: RPCError.self,
+      try channel.writeOutbound(RPCRequestPart.message(.init(repeating: 1, count: 42)))
+    ) { error in
+      XCTAssertEqual(error.code, .internalError)
+      XCTAssertEqual(error.message, "Client is closed, cannot send a message.")
+    }
+
+    // This is needed to clear the EmbeddedChannel's stored error, otherwise
+    // it will be thrown when writing inbound.
+    try? channel.throwIfErrorCaught()
+
+    // Server sends back response message
+    var buffer = ByteBuffer()
+    buffer.writeInteger(UInt8(0))  // not compressed
+    buffer.writeInteger(UInt32(42))  // message length
+    buffer.writeRepeatingByte(0, count: 42)  // message
+    let serverDataPayload = HTTP2Frame.FramePayload.Data(data: .byteBuffer(buffer))
+    XCTAssertNoThrow(try channel.writeInbound(HTTP2Frame.FramePayload.data(serverDataPayload)))
+
+    // Make sure we read the message properly
+    XCTAssertEqual(
+      try channel.readInbound(as: RPCResponsePart.self),
+      RPCResponsePart.message([UInt8](repeating: 0, count: 42))
+    )
+
+    // Server sends status to end RPC
+    XCTAssertNoThrow(
+      try channel.writeInbound(
+        HTTP2Frame.FramePayload.headers(
+          .init(headers: [
+            GRPCHTTP2Keys.grpcStatus.rawValue: String(Status.Code.dataLoss.rawValue),
+            GRPCHTTP2Keys.grpcStatusMessage.rawValue: "Test data loss",
+            "custom-header": "custom-value",
+          ])
+        )
+      )
+    )
+
+    XCTAssertEqual(
+      try channel.readInbound(as: RPCResponsePart.self),
+      .status(.init(code: .dataLoss, message: "Test data loss"), ["custom-header": "custom-value"])
+    )
+  }
+
+  func testReceiveMessageSplitAcrossMultipleBuffers() throws {
+    let handler = GRPCClientStreamHandler(
+      methodDescriptor: .init(service: "test", method: "test"),
+      scheme: .http,
+      outboundEncoding: .identity,
+      acceptedEncodings: [],
+      maximumPayloadSize: 100,
+      skipStateMachineAssertions: true
+    )
+
+    let channel = EmbeddedChannel(handler: handler)
+
+    // Send client's initial metadata
+    let request = RPCRequestPart.metadata([:])
+    XCTAssertNoThrow(try channel.writeOutbound(request))
+
+    // Make sure we have sent the corresponding frame, and that nothing has been written back.
+    let writtenHeaders = try channel.assertReadHeadersOutbound()
+    XCTAssertEqual(
+      writtenHeaders.headers,
+      [
+        GRPCHTTP2Keys.method.rawValue: "POST",
+        GRPCHTTP2Keys.scheme.rawValue: "http",
+        GRPCHTTP2Keys.path.rawValue: "test/test",
+        GRPCHTTP2Keys.contentType.rawValue: "application/grpc",
+        GRPCHTTP2Keys.te.rawValue: "trailers",
+
+      ]
+    )
+    XCTAssertNil(try channel.readInbound(as: RPCResponsePart.self))
+
+    // Receive server's initial metadata
+    let serverInitialMetadata: HPACKHeaders = [
+      GRPCHTTP2Keys.status.rawValue: "200",
+      GRPCHTTP2Keys.contentType.rawValue: "application/grpc",
+      "some-custom-header": "some-custom-value",
+    ]
+    XCTAssertNoThrow(
+      try channel.writeInbound(
+        HTTP2Frame.FramePayload.headers(.init(headers: serverInitialMetadata))
+      )
+    )
+    XCTAssertEqual(
+      try channel.readInbound(as: RPCResponsePart.self),
+      RPCResponsePart.metadata(Metadata(headers: serverInitialMetadata))
+    )
+
+    // Send a message
+    XCTAssertNoThrow(
+      try channel.writeOutbound(RPCRequestPart.message(.init(repeating: 1, count: 42)))
+    )
+
+    // Assert we wrote it successfully into the channel
+    let writtenMessage = try channel.assertReadDataOutbound()
+    var expectedBuffer = ByteBuffer()
+    expectedBuffer.writeInteger(UInt8(0))  // not compressed
+    expectedBuffer.writeInteger(UInt32(42))  // message length
+    expectedBuffer.writeRepeatingByte(1, count: 42)  // message
+    XCTAssertEqual(writtenMessage.data, .byteBuffer(expectedBuffer))
+
+    // Receive server's first message
+    var buffer = ByteBuffer()
+    buffer.writeInteger(UInt8(0))  // not compressed
+    XCTAssertNoThrow(
+      try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer))))
+    )
+    XCTAssertNil(try channel.readInbound(as: RPCResponsePart.self))
+
+    buffer.clear()
+    buffer.writeInteger(UInt32(30))  // message length
+    XCTAssertNoThrow(
+      try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer))))
+    )
+    XCTAssertNil(try channel.readInbound(as: RPCResponsePart.self))
+
+    buffer.clear()
+    buffer.writeRepeatingByte(0, count: 10)  // first part of the message
+    XCTAssertNoThrow(
+      try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer))))
+    )
+    XCTAssertNil(try channel.readInbound(as: RPCResponsePart.self))
+
+    buffer.clear()
+    buffer.writeRepeatingByte(1, count: 10)  // second part of the message
+    XCTAssertNoThrow(
+      try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer))))
+    )
+    XCTAssertNil(try channel.readInbound(as: RPCResponsePart.self))
+
+    buffer.clear()
+    buffer.writeRepeatingByte(2, count: 10)  // third part of the message
+    XCTAssertNoThrow(
+      try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer))))
+    )
+
+    // Make sure we read the message properly
+    XCTAssertEqual(
+      try channel.readInbound(as: RPCResponsePart.self),
+      RPCResponsePart.message(
+        [UInt8](repeating: 0, count: 10) + [UInt8](repeating: 1, count: 10)
+          + [UInt8](repeating: 2, count: 10)
+      )
+    )
+  }
+
+  func testSendMultipleMessagesInSingleBuffer() throws {
+    let handler = GRPCClientStreamHandler(
+      methodDescriptor: .init(service: "test", method: "test"),
+      scheme: .http,
+      outboundEncoding: .identity,
+      acceptedEncodings: [],
+      maximumPayloadSize: 100,
+      skipStateMachineAssertions: true
+    )
+
+    let channel = EmbeddedChannel(handler: handler)
+
+    // Send client's initial metadata
+    let request = RPCRequestPart.metadata([:])
+    XCTAssertNoThrow(try channel.writeOutbound(request))
+
+    // Make sure we have sent the corresponding frame, and that nothing has been written back.
+    let writtenHeaders = try channel.assertReadHeadersOutbound()
+    XCTAssertEqual(
+      writtenHeaders.headers,
+      [
+        GRPCHTTP2Keys.method.rawValue: "POST",
+        GRPCHTTP2Keys.scheme.rawValue: "http",
+        GRPCHTTP2Keys.path.rawValue: "test/test",
+        GRPCHTTP2Keys.contentType.rawValue: "application/grpc",
+        GRPCHTTP2Keys.te.rawValue: "trailers",
+
+      ]
+    )
+    XCTAssertNil(try channel.readInbound(as: RPCResponsePart.self))
+
+    // Receive server's initial metadata
+    let serverInitialMetadata: HPACKHeaders = [
+      GRPCHTTP2Keys.status.rawValue: "200",
+      GRPCHTTP2Keys.contentType.rawValue: "application/grpc",
+      "some-custom-header": "some-custom-value",
+    ]
+    XCTAssertNoThrow(
+      try channel.writeInbound(
+        HTTP2Frame.FramePayload.headers(.init(headers: serverInitialMetadata))
+      )
+    )
+    XCTAssertEqual(
+      try channel.readInbound(as: RPCResponsePart.self),
+      RPCResponsePart.metadata(Metadata(headers: serverInitialMetadata))
+    )
+
+    // This is where this test actually begins. We want to write two messages
+    // without flushing, and make sure that no messages are sent down the pipeline
+    // until we flush. Once we flush, both messages should be sent in the same ByteBuffer.
+
+    // Write back first message and make sure nothing's written in the channel.
+    XCTAssertNoThrow(channel.write(RPCRequestPart.message([UInt8](repeating: 1, count: 4))))
+    XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self))
+
+    // Write back second message and make sure nothing's written in the channel.
+    XCTAssertNoThrow(channel.write(RPCRequestPart.message([UInt8](repeating: 2, count: 4))))
+    XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self))
+
+    // Now flush and check we *do* write the data.
+    channel.flush()
+
+    let writtenMessage = try channel.assertReadDataOutbound()
+
+    // Make sure both messages have been framed together in the ByteBuffer.
+    XCTAssertEqual(
+      writtenMessage.data,
+      .byteBuffer(
+        .init(bytes: [
+          // First message
+          0,  // Compression disabled
+          0, 0, 0, 4,  // Message length
+          1, 1, 1, 1,  // First message data
+
+          // Second message
+          0,  // Compression disabled
+          0, 0, 0, 4,  // Message length
+          2, 2, 2, 2,  // Second message data
+        ])
+      )
+    )
+    XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self))
+  }
+}
+
+extension EmbeddedChannel {
+  fileprivate func assertReadHeadersOutbound() throws -> HTTP2Frame.FramePayload.Headers {
+    guard
+      case .headers(let writtenHeaders) = try XCTUnwrap(
+        try self.readOutbound(as: HTTP2Frame.FramePayload.self)
+      )
+    else {
+      throw TestError.assertionFailure("Expected to write headers")
+    }
+    return writtenHeaders
+  }
+
+  fileprivate func assertReadDataOutbound() throws -> HTTP2Frame.FramePayload.Data {
+    guard
+      case .data(let writtenMessage) = try XCTUnwrap(
+        try self.readOutbound(as: HTTP2Frame.FramePayload.self)
+      )
+    else {
+      throw TestError.assertionFailure("Expected to write data")
+    }
+    return writtenMessage
+  }
+}
+
+private enum TestError: Error {
+  case assertionFailure(String)
+}

+ 12 - 2
Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift

@@ -257,7 +257,6 @@ final class GRPCServerStreamHandlerTests: XCTestCase {
       )
     )
 
-    // Make sure we haven't sent back an error response, and that we read the initial metadata
     // Make sure we have sent a trailers-only response
     let writtenTrailersOnlyResponse = try channel.assertReadHeadersOutbound()
 
@@ -413,7 +412,8 @@ final class GRPCServerStreamHandlerTests: XCTestCase {
     let handler = GRPCServerStreamHandler(
       scheme: .http,
       acceptedEncodings: [],
-      maximumPayloadSize: 100
+      maximumPayloadSize: 100,
+      skipStateMachineAssertions: true
     )
 
     let channel = EmbeddedChannel(handler: handler)
@@ -505,6 +505,16 @@ final class GRPCServerStreamHandlerTests: XCTestCase {
         "custom-header": "custom-value",
       ]
     )
+
+    // Try writing and assert it throws to make sure we don't allow writes
+    // after closing.
+    XCTAssertThrowsError(
+      ofType: RPCError.self,
+      try channel.writeOutbound(trailers)
+    ) { error in
+      XCTAssertEqual(error.code, .internalError)
+      XCTAssertEqual(error.message, "Server can't send anything if closed.")
+    }
   }
 
   func testReceiveMessageSplitAcrossMultipleBuffers() throws {