Browse Source

Fix gRPC Web trailers encoding (#1582)

Motivation:

In gRPC Web HTTP trailers are encoded as a 'regular' gRPC message; that
is a length prefixed message. For grpc-web we sent trailers back as
regular trailers.

Modifications:

- Send trailers back as a length prefixed body part.
- Update tests.

Result:

Resolves #1580
George Barnett 2 years ago
parent
commit
b625430f5a

+ 27 - 10
Sources/GRPC/GRPCWebToHTTP2ServerCodec.swift

@@ -453,9 +453,17 @@ extension GRPCWebToHTTP2ServerCodec.StateMachine.State {
         )
       )
     } else {
-      // No response buffer; plain gRPC Web.
-      let trailers = HTTPHeaders(hpackHeaders: trailers)
-      return .write(.init(part: .end(trailers), promise: promise, closeChannel: closeChannel))
+      // No response buffer; plain gRPC Web. Trailers are encoded into the body as a regular
+      // length-prefixed message.
+      let buffer = GRPCWebToHTTP2ServerCodec.formatTrailers(trailers, allocator: allocator)
+      return .write(
+        .init(
+          part: .body(.byteBuffer(buffer)),
+          additionalPart: .end(nil),
+          promise: promise,
+          closeChannel: closeChannel
+        )
+      )
     }
   }
 
@@ -671,18 +679,27 @@ extension GRPCWebToHTTP2ServerCodec {
     allocator: ByteBufferAllocator
   ) -> ByteBuffer {
     // See: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-WEB.md
-    let encodedTrailers = trailers.map { name, value, _ in
-      "\(name): \(value)"
-    }.joined(separator: "\r\n")
+    let length = trailers.reduce(0) { partial, trailer in
+      // +4 for: ":", " ", "\r", "\n"
+      return partial + trailer.name.utf8.count + trailer.value.utf8.count + 4
+    }
+    var buffer = allocator.buffer(capacity: 5 + length)
 
-    var buffer = allocator.buffer(capacity: 5 + encodedTrailers.utf8.count)
     // Uncompressed trailer byte.
     buffer.writeInteger(UInt8(0x80))
     // Length.
-    buffer.writeInteger(UInt32(encodedTrailers.utf8.count))
-    // Uncompressed trailers.
-    buffer.writeString(encodedTrailers)
+    let lengthIndex = buffer.writerIndex
+    buffer.writeInteger(UInt32(0))
+
+    var bytesWritten = 0
+    for (name, value, _) in trailers {
+      bytesWritten += buffer.writeString(name)
+      bytesWritten += buffer.writeString(": ")
+      bytesWritten += buffer.writeString(value)
+      bytesWritten += buffer.writeString("\r\n")
+    }
 
+    buffer.setInteger(UInt32(bytesWritten), at: lengthIndex)
     return buffer
   }
 

+ 13 - 6
Tests/GRPCTests/GRPCWebToHTTP2ServerCodecTests.swift

@@ -24,10 +24,14 @@ import XCTest
 
 class GRPCWebToHTTP2ServerCodecTests: GRPCTestCase {
   private func writeTrailers(_ trailers: HPACKHeaders, into buffer: inout ByteBuffer) {
-    let encoded = trailers.map { "\($0.name): \($0.value)" }.joined(separator: "\r\n")
     buffer.writeInteger(UInt8(0x80))
-    buffer.writeInteger(UInt32(encoded.utf8.count))
-    buffer.writeString(encoded)
+    try! buffer.writeLengthPrefixed(as: UInt32.self) {
+      var length = 0
+      for (name, value, _) in trailers {
+        length += $0.writeString("\(name): \(value)\r\n")
+      }
+      return length
+    }
   }
 
   private func receiveHead(
@@ -106,7 +110,7 @@ class GRPCWebToHTTP2ServerCodecTests: GRPCTestCase {
     on channel: EmbeddedChannel,
     expectedBytes: ByteBuffer? = nil
   ) throws {
-    let headers: HPACKHeaders = ["grpc-status": "\(status)"]
+    let headers: HPACKHeaders = ["grpc-status": "\(status.rawValue)"]
     let headersPayload: HTTP2Frame.FramePayload = .headers(.init(headers: headers, endStream: true))
     assertThat(try channel.writeOutbound(headersPayload), .doesNotThrow())
 
@@ -128,7 +132,10 @@ class GRPCWebToHTTP2ServerCodecTests: GRPCTestCase {
     // Outbound
     try self.sendResponseHeaders(on: channel)
     try self.sendBytes([1, 2, 3], on: channel, expectedBytes: [1, 2, 3])
-    try self.sendEnd(status: .ok, on: channel)
+
+    var buffer = ByteBuffer()
+    self.writeTrailers(["grpc-status": "0"], into: &buffer)
+    try self.sendEnd(status: .ok, on: channel, expectedBytes: buffer)
   }
 
   func testWebTextHappyPath() throws {
@@ -150,7 +157,7 @@ class GRPCWebToHTTP2ServerCodecTests: GRPCTestCase {
     // Build up the expected response, i.e. the response bytes and the trailers, base64 encoded.
     var expectedBodyBuffer = ByteBuffer(bytes: [1, 2, 3])
     let status = GRPCStatus.Code.ok
-    self.writeTrailers(["grpc-status": "\(status)"], into: &expectedBodyBuffer)
+    self.writeTrailers(["grpc-status": "\(status.rawValue)"], into: &expectedBodyBuffer)
     try self.sendEnd(status: status, on: channel, expectedBytes: expectedBodyBuffer.base64Encoded())
   }
 

+ 18 - 8
Tests/GRPCTests/GRPCWebToHTTP2StateMachineTests.swift

@@ -223,12 +223,11 @@ final class GRPCWebToHTTP2StateMachineTests: GRPCTestCase {
       promise: nil,
       allocator: self.allocator
     ).assertWrite { write in
-      write.part.assertEnd {
-        $0.assertSome { trailers in
-          XCTAssertEqual(trailers[canonicalForm: "grpc-status"], ["0"])
-        }
+      write.part.assertBody { buffer in
+        var buffer = buffer
+        let trailers = buffer.readLengthPrefixedMessage().map { String(buffer: $0) }
+        XCTAssertEqual(trailers, "grpc-status: 0\r\n")
       }
-
       XCTAssertEqual(write.closeChannel, expectChannelClose)
     }
   }
@@ -330,15 +329,15 @@ final class GRPCWebToHTTP2StateMachineTests: GRPCTestCase {
       write.part.assertBody { buffer in
         var buffer = buffer
         let base64Encoded = buffer.readString(length: buffer.readableBytes)!
-        XCTAssertEqual(base64Encoded, "aGVsbG8sIHdvcmxkIYAAAAAOZ3JwYy1zdGF0dXM6IDA=")
+        XCTAssertEqual(base64Encoded, "aGVsbG8sIHdvcmxkIYAAAAAQZ3JwYy1zdGF0dXM6IDANCg==")
 
         let data = Data(base64Encoded: base64Encoded)!
         buffer.writeData(data)
 
         XCTAssertEqual(buffer.readString(length: 13), "hello, world!")
         XCTAssertEqual(buffer.readInteger(), UInt8(0x80))
-        XCTAssertEqual(buffer.readInteger(), UInt32(14))
-        XCTAssertEqual(buffer.readString(length: 14), "grpc-status: 0")
+        XCTAssertEqual(buffer.readInteger(), UInt32(16))
+        XCTAssertEqual(buffer.readString(length: 16), "grpc-status: 0\r\n")
         XCTAssertEqual(buffer.readableBytes, 0)
       }
 
@@ -672,3 +671,14 @@ extension Optional {
     }
   }
 }
+
+extension ByteBuffer {
+  mutating func readLengthPrefixedMessage() -> ByteBuffer? {
+    // Read off and ignore the compression byte.
+    if self.readInteger(as: UInt8.self) == nil {
+      return nil
+    }
+
+    return self.readLengthPrefixedSlice(as: UInt32.self)
+  }
+}

+ 2 - 2
Tests/GRPCTests/ServerWebTests.swift

@@ -41,9 +41,9 @@ class ServerWebTests: EchoTestCaseBase {
   private func gRPCWebTrailers(status: Int = 0, message: String? = nil) -> Data {
     var data: Data
     if let message = message {
-      data = "grpc-status: \(status)\r\ngrpc-message: \(message)".data(using: .utf8)!
+      data = "grpc-status: \(status)\r\ngrpc-message: \(message)\r\n".data(using: .utf8)!
     } else {
-      data = "grpc-status: \(status)".data(using: .utf8)!
+      data = "grpc-status: \(status)\r\n".data(using: .utf8)!
     }
 
     // Add the gRPC prefix with the compression byte and the 4 length bytes.