Explorar o código

Add compression logic to GRPCMessageFramer (#1771)

Gustavo Cairo hai 1 ano
pai
achega
995b13aec6

+ 6 - 0
Sources/GRPCHTTP2Core/Compression/Zlib.swift

@@ -38,6 +38,9 @@ extension Zlib {
   /// Creates a new compressor for the given compression format.
   ///
   /// This compressor is only suitable for compressing whole messages at a time.
+  ///
+  /// - Important: ``Compressor/end()`` must be called when the compressor is not needed
+  /// anymore, to deallocate any resources allocated by `Zlib`.
   struct Compressor {
     // TODO: Make this ~Copyable when 5.9 is the lowest supported Swift version.
 
@@ -86,6 +89,9 @@ extension Zlib {
   /// Creates a new decompressor for the given compression format.
   ///
   /// This decompressor is only suitable for compressing whole messages at a time.
+  ///
+  /// - Important: ``Decompressor/end()`` must be called when the compressor is not needed
+  /// anymore, to deallocate any resources allocated by `Zlib`.
   struct Decompressor {
     // TODO: Make this ~Copyable when 5.9 is the lowest supported Swift version.
 

+ 21 - 17
Sources/GRPCHTTP2Core/GRPCMessageFramer.swift

@@ -32,15 +32,11 @@ struct GRPCMessageFramer {
   /// reserves capacity in powers of 2. This way, we can take advantage of the whole buffer.
   static let maximumWriteBufferLength = 65_536
 
-  private var pendingMessages: OneOrManyQueue<PendingMessage>
-
-  private struct PendingMessage {
-    let bytes: [UInt8]
-    let compress: Bool
-  }
+  private var pendingMessages: OneOrManyQueue<[UInt8]>
 
   private var writeBuffer: ByteBuffer
 
+  /// Create a new ``GRPCMessageFramer``.
   init() {
     self.pendingMessages = OneOrManyQueue()
     self.writeBuffer = ByteBuffer()
@@ -48,15 +44,16 @@ struct GRPCMessageFramer {
 
   /// Queue the given bytes to be framed and potentially coalesced alongside other messages in a `ByteBuffer`.
   /// The resulting data will be returned when calling ``GRPCMessageFramer/next()``.
-  /// If `compress` is true, then the given bytes will be compressed using the configured compression algorithm.
-  mutating func append(_ bytes: [UInt8], compress: Bool) {
-    self.pendingMessages.append(PendingMessage(bytes: bytes, compress: compress))
+  mutating func append(_ bytes: [UInt8]) {
+    self.pendingMessages.append(bytes)
   }
 
   /// If there are pending messages to be framed, a `ByteBuffer` will be returned with the framed data.
   /// Data may also be compressed (if configured) and multiple frames may be coalesced into the same `ByteBuffer`.
+  /// - Parameter compressor: An optional compressor: if present, payloads will be compressed; otherwise
+  /// they'll be framed as-is.
   /// - Throws: If an error is encountered, such as a compression failure, an error will be thrown.
-  mutating func next() throws -> ByteBuffer? {
+  mutating func next(compressor: Zlib.Compressor? = nil) throws -> ByteBuffer? {
     if self.pendingMessages.isEmpty {
       // Nothing pending: exit early.
       return nil
@@ -72,27 +69,34 @@ struct GRPCMessageFramer {
 
     var requiredCapacity = 0
     for message in self.pendingMessages {
-      requiredCapacity += message.bytes.count + Self.metadataLength
+      requiredCapacity += message.count + Self.metadataLength
     }
     self.writeBuffer.clear(minimumCapacity: requiredCapacity)
 
     while let message = self.pendingMessages.pop() {
-      try self.encode(message)
+      try self.encode(message, compressor: compressor)
     }
 
     return self.writeBuffer
   }
 
-  private mutating func encode(_ message: PendingMessage) throws {
-    if message.compress {
+  private mutating func encode(_ message: [UInt8], compressor: Zlib.Compressor?) throws {
+    if let compressor {
       self.writeBuffer.writeInteger(UInt8(1))  // Set compression flag
-      // TODO: compress message and write the compressed message length + bytes
+
+      // Write zeroes as length - we'll write the actual compressed size after compression.
+      let lengthIndex = self.writeBuffer.writerIndex
+      self.writeBuffer.writeInteger(UInt32(0))
+
+      // Compress and overwrite the payload length field with the right length.
+      let writtenBytes = try compressor.compress(message, into: &self.writeBuffer)
+      self.writeBuffer.setInteger(UInt32(writtenBytes), at: lengthIndex)
     } else {
       self.writeBuffer.writeMultipleIntegers(
         UInt8(0),  // Clear compression flag
-        UInt32(message.bytes.count)  // Set message length
+        UInt32(message.count)  // Set message length
       )
-      self.writeBuffer.writeBytes(message.bytes)
+      self.writeBuffer.writeBytes(message)
     }
   }
 }

+ 40 - 2
Tests/GRPCHTTP2CoreTests/GRPCMessageFramerTests.swift

@@ -22,7 +22,7 @@ import XCTest
 final class GRPCMessageFramerTests: XCTestCase {
   func testSingleWrite() throws {
     var framer = GRPCMessageFramer()
-    framer.append(Array(repeating: 42, count: 128), compress: false)
+    framer.append(Array(repeating: 42, count: 128))
 
     var buffer = try XCTUnwrap(framer.next())
     let (compressed, length) = try XCTUnwrap(buffer.readMessageHeader())
@@ -35,12 +35,49 @@ final class GRPCMessageFramerTests: XCTestCase {
     XCTAssertNil(try framer.next())
   }
 
+  private func testSingleWrite(compressionMethod: Zlib.Method) throws {
+    let compressor = Zlib.Compressor(method: compressionMethod)
+    defer {
+      compressor.end()
+    }
+    var framer = GRPCMessageFramer()
+
+    let message = [UInt8](repeating: 42, count: 128)
+    framer.append(message)
+
+    var buffer = ByteBuffer()
+    let testCompressor = Zlib.Compressor(method: compressionMethod)
+    let compressedSize = try testCompressor.compress(message, into: &buffer)
+    let compressedMessage = buffer.readSlice(length: compressedSize)
+    defer {
+      testCompressor.end()
+    }
+
+    buffer = try XCTUnwrap(framer.next(compressor: compressor))
+    let (compressed, length) = try XCTUnwrap(buffer.readMessageHeader())
+    XCTAssertTrue(compressed)
+    XCTAssertEqual(length, UInt32(compressedSize))
+    XCTAssertEqual(buffer.readSlice(length: Int(length)), compressedMessage)
+    XCTAssertEqual(buffer.readableBytes, 0)
+
+    // No more bufers.
+    XCTAssertNil(try framer.next())
+  }
+
+  func testSingleWriteDeflateCompressed() throws {
+    try self.testSingleWrite(compressionMethod: .deflate)
+  }
+
+  func testSingleWriteGZIPCompressed() throws {
+    try self.testSingleWrite(compressionMethod: .gzip)
+  }
+
   func testMultipleWrites() throws {
     var framer = GRPCMessageFramer()
 
     let messages = 100
     for _ in 0 ..< messages {
-      framer.append(Array(repeating: 42, count: 128), compress: false)
+      framer.append(Array(repeating: 42, count: 128))
     }
 
     var buffer = try XCTUnwrap(framer.next())
@@ -56,6 +93,7 @@ final class GRPCMessageFramerTests: XCTestCase {
     // No more bufers.
     XCTAssertNil(try framer.next())
   }
+
 }
 
 extension ByteBuffer {