Browse Source

Alter the documented requirements of `MessageSerializer` (#888)

Motivation:

requirements that implementations must prefix their serialized messages
with a compression byte and `UInt32` length. The motivation behind this
was to avoid a reallocation later on when the message was being framed.
The result was that in some cases the framing would be done by the
serialier and in other cases by the framer (or
`LengthPrefixedMessageWriter`).

Modifications:

- Remove the requirement for the serialier to frame the message
- Optimize the `LengthPrefixedMessageWriter` so that it checks whether
  there is enough space _before_ the serialized bytes in which the
  compression byte and length could be written.
- gRPC-Web: Replace the single response buffer with a circular buffer of
  buffers; this avoids writing into a buffer and simplifies the
  `LengthPrefixedMessageWriter`.

Result:

- Better separation of concerns.
George Barnett 5 years ago
parent
commit
b29b16ba06

+ 29 - 20
Sources/GRPC/HTTP1ToGRPCServerCodec.swift

@@ -87,7 +87,7 @@ public final class HTTP1ToGRPCServerCodec {
   // TODO(kaipi): Extract all gRPC Web processing logic into an independent handler only added on
   // the HTTP1.1 pipeline, as it's starting to get in the way of readability.
   private var requestTextBuffer: NIO.ByteBuffer!
-  private var responseTextBuffer: NIO.ByteBuffer!
+  private var responseTextBuffers: CircularBuffer<ByteBuffer> = []
 
   var inboundState = InboundState.expectingHeaders {
     willSet {
@@ -311,10 +311,6 @@ extension HTTP1ToGRPCServerCodec: ChannelOutboundHandler {
         }
       }
 
-      if self.contentType == .webTextProtobuf {
-        responseTextBuffer = context.channel.allocator.buffer(capacity: 0)
-      }
-
       // Are we compressing responses?
       if let responseEncoding = self.responseEncodingHeader {
         headers.add(name: GRPCHeaderName.encoding, value: responseEncoding)
@@ -340,12 +336,12 @@ extension HTTP1ToGRPCServerCodec: ChannelOutboundHandler {
           // Store the response into an independent buffer. We can't return the message directly as
           // it needs to be aggregated with all the responses plus the trailers, in order to have
           // the base64 response properly encoded in a single byte stream.
-          precondition(self.responseTextBuffer != nil)
-          try self.messageWriter.write(
+          let buffer = try self.messageWriter.write(
             buffer: messageContext.message,
-            into: &self.responseTextBuffer,
+            allocator: context.channel.allocator,
             compressed: messageContext.compressed
           )
+          self.responseTextBuffers.append(buffer)
 
           // Since we stored the written data, mark the write promise as successful so that the
           // ServerStreaming provider continues sending the data.
@@ -383,25 +379,38 @@ extension HTTP1ToGRPCServerCodec: ChannelOutboundHandler {
       }
 
       if contentType == .webTextProtobuf {
-        precondition(responseTextBuffer != nil)
-
         // Encode the trailers into the response byte stream as a length delimited message, as per
         // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-WEB.md
         let textTrailers = trailers.map { name, value in "\(name): \(value)" }.joined(separator: "\r\n")
-        responseTextBuffer.writeInteger(UInt8(0x80))
-        responseTextBuffer.writeInteger(UInt32(textTrailers.utf8.count))
-        responseTextBuffer.writeString(textTrailers)
+        var trailersBuffer = context.channel.allocator.buffer(capacity: 5 + textTrailers.utf8.count)
+        trailersBuffer.writeInteger(UInt8(0x80))
+        trailersBuffer.writeInteger(UInt32(textTrailers.utf8.count))
+        trailersBuffer.writeString(textTrailers)
+        self.responseTextBuffers.append(trailersBuffer)
+
+        // The '!' is fine, we know it's not empty since we just added a buffer.
+        var responseTextBuffer = self.responseTextBuffers.popFirst()!
+
+        // Read the data from the first buffer.
+        var accumulatedData = responseTextBuffer.readData(length: responseTextBuffer.readableBytes)!
+
+        // Reserve enough capacity and append the remaining buffers.
+        let requiredExtraCapacity = self.responseTextBuffers.lazy.map { $0.readableBytes }.reduce(0, +)
+        accumulatedData.reserveCapacity(accumulatedData.count + requiredExtraCapacity)
+        while let buffer = self.responseTextBuffers.popFirst() {
+          accumulatedData.append(contentsOf: buffer.readableBytesView)
+        }
 
         // TODO: Binary responses that are non multiples of 3 will end = or == when encoded in
         // base64. Investigate whether this might have any effect on the transport mechanism and
-        // client decoding. Initial results say that they are inocuous, but we might have to keep
+        // client decoding. Initial results say that they are innocuous, but we might have to keep
         // an eye on this in case something trips up.
-        if let binaryData = responseTextBuffer.readData(length: responseTextBuffer.readableBytes) {
-          let encodedData = binaryData.base64EncodedString()
-          responseTextBuffer.clear()
-          responseTextBuffer.reserveCapacity(encodedData.utf8.count)
-          responseTextBuffer.writeString(encodedData)
-        }
+        let encodedData = accumulatedData.base64EncodedString()
+
+        // Reuse our first buffer.
+        responseTextBuffer.clear(minimumCapacity: numericCast(encodedData.utf8.count))
+        responseTextBuffer.writeString(encodedData)
+
         // After collecting all response for gRPC Web connections, send one final aggregated
         // response.
         context.write(self.wrapOutboundOut(.body(.byteBuffer(responseTextBuffer))), promise: promise)

+ 49 - 36
Sources/GRPC/LengthPrefixedMessageWriter.swift

@@ -43,29 +43,25 @@ internal struct LengthPrefixedMessageWriter {
 
   private func compress(
     buffer: ByteBuffer,
-    into output: inout ByteBuffer,
-    using compressor: Zlib.Deflate
-  ) throws {
-    let save = output
+    using compressor: Zlib.Deflate,
+    allocator: ByteBufferAllocator
+  ) throws -> ByteBuffer {
+    // The compressor will allocate the correct size. For now the leading 5 bytes will do.
+    var output = allocator.buffer(capacity: 5)
 
     // Set the compression byte.
     output.writeInteger(UInt8(1))
 
-    // Leave a gap for the length, we'll set it in a moment.
+    // Set the length to zero; we'll write the actual value in a moment.
     let payloadSizeIndex = output.writerIndex
-    output.moveWriterIndex(forwardBy: MemoryLayout<UInt32>.size)
-
-    // Compress the message. We know that we need to drop the first 5 bytes, and we know that these
-    // bytes must exist.
-    var buffer = buffer
-    buffer.moveReaderIndex(forwardBy: 5)
+    output.writeInteger(UInt32(0))
 
     let bytesWritten: Int
     
     do {
+      var buffer = buffer
       bytesWritten = try compressor.deflate(&buffer, into: &output)
     } catch {
-      output = save
       throw error
     }
 
@@ -74,37 +70,54 @@ internal struct LengthPrefixedMessageWriter {
 
     // Finally, the compression context should be reset between messages.
     compressor.reset()
-  }
 
-  func write(buffer: ByteBuffer, into output: inout ByteBuffer, compressed: Bool = true) throws {
-    // We expect the message to be prefixed with the compression flag and length. Let's double check.
-    assert(buffer.readableBytes >= 5, "Buffer does not contain the 5-byte head (compression byte and length)")
-    assert(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self) == 0, "Compression byte was unexpectedly non-zero")
-    assert(Int(buffer.getInteger(at: buffer.readerIndex + 1, as: UInt32.self)!) + 5 == buffer.readableBytes, "Incorrect message length")
+    return output
+  }
 
+  /// Writes the readable bytes of `buffer` as a gRPC length-prefixed message.
+  ///
+  /// - Parameters:
+  ///   - buffer: The bytes to compress and length-prefix.
+  ///   - allocator: A `ByteBufferAllocator`.
+  ///   - compressed: Whether the bytes should be compressed. This is ignored if not compression
+  ///     mechanism was configured on this writer.
+  /// - Returns: A buffer containing the length prefixed bytes.
+  func write(buffer: ByteBuffer, allocator: ByteBufferAllocator, compressed: Bool = true) throws -> ByteBuffer {
     if compressed, let compressor = self.compressor {
-      try self.compress(buffer: buffer, into: &output, using: compressor)
-    } else {
-      // A straight copy.
+      return try self.compress(buffer: buffer, using: compressor, allocator: allocator)
+    } else if buffer.readerIndex >= 5 {
+      // We're not compressing and we have enough bytes before the reader index that we can write
+      // over with the compression byte and length.
       var buffer = buffer
-      output.writeBuffer(&buffer)
-    }
-  }
 
-  func write(buffer: ByteBuffer, allocator: ByteBufferAllocator, compressed: Bool = true) throws -> ByteBuffer {
-    // We expect the message to be prefixed with the compression flag and length. Let's double check.
-    assert(buffer.readableBytes >= 5, "Buffer does not contain the 5-byte preamble (compression byte and length)")
-    assert(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self) == 0, "Compression byte was unexpectedly non-zero")
-    assert(Int(buffer.getInteger(at: buffer.readerIndex + 1, as: UInt32.self)!) + 5 == buffer.readableBytes, "Incorrect message length")
+      // Get the size of the message.
+      let messageSize = buffer.readableBytes
 
-    if compressed, let compressor = self.compressor {
-      // Darn, we need another buffer. We'll assume it'll need to at least the size of the input buffer.
-      var compressed = allocator.buffer(capacity: buffer.readableBytes)
-      try self.compress(buffer: buffer, into: &compressed, using: compressor)
-      return compressed
-    } else {
-      // We're not using compression and our preamble is already in place; easy!
+      // Move the reader index back 5 bytes. This is okay: we validated the `readerIndex` above.
+      buffer.moveReaderIndex(to: buffer.readerIndex - 5)
+
+      // Fill in the compression byte and message length.
+      buffer.setInteger(UInt8(0), at: buffer.readerIndex)
+      buffer.setInteger(UInt32(messageSize), at: buffer.readerIndex + 1)
+
+      // The message bytes are already in place, we're done.
       return buffer
+    } else {
+      // We're not compressing and we don't have enough space before the message bytes passed in.
+      // We need a new buffer.
+      var lengthPrefixed = allocator.buffer(capacity: 5 + buffer.readableBytes)
+
+      // Write the compression byte.
+      lengthPrefixed.writeInteger(UInt8(0))
+
+      // Write the message length.
+      lengthPrefixed.writeInteger(UInt32(buffer.readableBytes))
+
+      // Write the message.
+      var buffer = buffer
+      lengthPrefixed.writeBuffer(&buffer)
+
+      return lengthPrefixed
     }
   }
 

+ 12 - 15
Sources/GRPC/Serialization.swift

@@ -22,10 +22,6 @@ internal protocol MessageSerializer {
 
   /// Serializes `input` into a `ByteBuffer` allocated using the provided `allocator`.
   ///
-  /// The serialized buffer should have 5 leading bytes: the first must be zero, the following
-  /// four bytes are the `UInt32` encoded length of the serialized message. The bytes of the
-  /// serialized message follow.
-  ///
   /// - Parameters:
   ///   - input: The element to serialize.
   ///   - allocator: A `ByteBufferAllocator`.
@@ -48,15 +44,14 @@ internal struct ProtobufSerializer<Message: SwiftProtobuf.Message>: MessageSeria
     // Serialize the message.
     let serialized = try message.serializedData()
 
+    // Allocate enough space and an extra 5 leading bytes. This a minor optimisation win: the length
+    // prefixed message writer can re-use the leading 5 bytes without needing to allocate a new
+    // buffer and copy over the serialized message.
     var buffer = allocator.buffer(capacity: serialized.count + 5)
+    buffer.writeBytes(Array(repeating: 0, count: 5))
+    buffer.moveReaderIndex(forwardBy: 5)
 
-    // The compression byte. This will be modified later, if necessary.
-    buffer.writeInteger(UInt8(0))
-
-    // The length of the serialized message.
-    buffer.writeInteger(UInt32(serialized.count))
-
-    // The serialized message.
+    // Now write the serialized message.
     buffer.writeBytes(serialized)
 
     return buffer
@@ -76,7 +71,9 @@ internal struct ProtobufDeserializer<Message: SwiftProtobuf.Message>: MessageDes
 
 internal struct GRPCPayloadSerializer<Message: GRPCPayload>: MessageSerializer {
   internal func serialize(_ message: Message, allocator: ByteBufferAllocator) throws -> ByteBuffer {
-    // Reserve 5 leading bytes.
+    // Reserve 5 leading bytes. This a minor optimisation win: the length prefixed message writer
+    // can re-use the leading 5 bytes without needing to allocate a new buffer and copy over the
+    // serialized message.
     var buffer = allocator.buffer(repeating: 0, count: 5)
 
     let readerIndex = buffer.readerIndex
@@ -91,9 +88,9 @@ internal struct GRPCPayloadSerializer<Message: GRPCPayload>: MessageSerializer {
     assert(buffer.getBytes(at: readerIndex, length: 5) == Array(repeating: 0, count: 5),
            "serialize(into:) must not write over existing written bytes")
 
-    // The first byte is already zero. Set the length.
-    let messageSize = buffer.writerIndex - writerIndex
-    buffer.setInteger(UInt32(messageSize), at: readerIndex + 1)
+    // 'read' the first 5 bytes so that the buffer's readable bytes are only the bytes of the
+    // serialized message.
+    buffer.moveReaderIndex(forwardBy: 5)
 
     return buffer
   }

+ 75 - 0
Tests/GRPCTests/LengthPrefixedMessageWriterTests.swift

@@ -0,0 +1,75 @@
+/*
+ * Copyright 2020, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+@testable import GRPC
+import NIO
+import XCTest
+
+class LengthPrefixedMessageWriterTests: GRPCTestCase {
+  func testWriteBytesWithNoLeadingSpaceOrCompression() throws {
+    let writer = LengthPrefixedMessageWriter()
+    let allocator = ByteBufferAllocator()
+    let buffer = allocator.buffer(bytes: [1, 2, 3])
+
+    var prefixed = try writer.write(buffer: buffer, allocator: allocator)
+    XCTAssertEqual(prefixed.readInteger(as: UInt8.self), 0)
+    XCTAssertEqual(prefixed.readInteger(as: UInt32.self), 3)
+    XCTAssertEqual(prefixed.readBytes(length: 3), [1, 2, 3])
+    XCTAssertEqual(prefixed.readableBytes, 0)
+  }
+
+  func testWriteBytesWithLeadingSpaceAndNoCompression() throws {
+    let writer = LengthPrefixedMessageWriter()
+    let allocator = ByteBufferAllocator()
+
+    var buffer = allocator.buffer(bytes: Array(repeating: 0, count: 5) +  [1, 2, 3])
+    buffer.moveReaderIndex(forwardBy: 5)
+
+    var prefixed = try writer.write(buffer: buffer, allocator: allocator)
+    XCTAssertEqual(prefixed.readInteger(as: UInt8.self), 0)
+    XCTAssertEqual(prefixed.readInteger(as: UInt32.self), 3)
+    XCTAssertEqual(prefixed.readBytes(length: 3), [1, 2, 3])
+    XCTAssertEqual(prefixed.readableBytes, 0)
+  }
+
+  func testWriteBytesWithNoLeadingSpaceAndCompression() throws {
+    let writer = LengthPrefixedMessageWriter(compression: .gzip)
+    let allocator = ByteBufferAllocator()
+
+    let buffer = allocator.buffer(bytes: [1, 2, 3])
+    var prefixed = try writer.write(buffer: buffer, allocator: allocator)
+
+    XCTAssertEqual(prefixed.readInteger(as: UInt8.self), 1)
+    let size = prefixed.readInteger(as: UInt32.self)!
+    XCTAssertGreaterThanOrEqual(size, 0)
+    XCTAssertNotNil(prefixed.readBytes(length: Int(size)))
+    XCTAssertEqual(prefixed.readableBytes, 0)
+  }
+
+  func testWriteBytesWithLeadingSpaceAndCompression() throws {
+    let writer = LengthPrefixedMessageWriter(compression: .gzip)
+    let allocator = ByteBufferAllocator()
+
+    var buffer = allocator.buffer(bytes: Array(repeating: 0, count: 5) +  [1, 2, 3])
+    buffer.moveReaderIndex(forwardBy: 5)
+    var prefixed = try writer.write(buffer: buffer, allocator: allocator)
+
+    XCTAssertEqual(prefixed.readInteger(as: UInt8.self), 1)
+    let size = prefixed.readInteger(as: UInt32.self)!
+    XCTAssertGreaterThanOrEqual(size, 0)
+    XCTAssertNotNil(prefixed.readBytes(length: Int(size)))
+    XCTAssertEqual(prefixed.readableBytes, 0)
+  }
+}

+ 13 - 0
Tests/GRPCTests/XCTestManifests.swift

@@ -683,6 +683,18 @@ extension LengthPrefixedMessageReaderTests {
     ]
 }
 
+extension LengthPrefixedMessageWriterTests {
+    // DO NOT MODIFY: This is autogenerated, use:
+    //   `swift test --generate-linuxmain`
+    // to regenerate.
+    static let __allTests__LengthPrefixedMessageWriterTests = [
+        ("testWriteBytesWithLeadingSpaceAndCompression", testWriteBytesWithLeadingSpaceAndCompression),
+        ("testWriteBytesWithLeadingSpaceAndNoCompression", testWriteBytesWithLeadingSpaceAndNoCompression),
+        ("testWriteBytesWithNoLeadingSpaceAndCompression", testWriteBytesWithNoLeadingSpaceAndCompression),
+        ("testWriteBytesWithNoLeadingSpaceOrCompression", testWriteBytesWithNoLeadingSpaceOrCompression),
+    ]
+}
+
 extension MessageCompressionTests {
     // DO NOT MODIFY: This is autogenerated, use:
     //   `swift test --generate-linuxmain`
@@ -905,6 +917,7 @@ public func __allTests() -> [XCTestCaseEntry] {
         testCase(ImmediatelyFailingProviderTests.__allTests__ImmediatelyFailingProviderTests),
         testCase(LazyEventLoopPromiseTests.__allTests__LazyEventLoopPromiseTests),
         testCase(LengthPrefixedMessageReaderTests.__allTests__LengthPrefixedMessageReaderTests),
+        testCase(LengthPrefixedMessageWriterTests.__allTests__LengthPrefixedMessageWriterTests),
         testCase(MessageCompressionTests.__allTests__MessageCompressionTests),
         testCase(MessageEncodingHeaderValidatorTests.__allTests__MessageEncodingHeaderValidatorTests),
         testCase(PlatformSupportTests.__allTests__PlatformSupportTests),