Browse Source

Make Zlib.Compressor/Decompressor classes (#1769)

Motivation:

z_stream stores a pointer to itself in its internal state which it
checks against in inflate/deflate. As we hold these within structs, and
call through to C functions which take a pointer to a z_stream, this
address can change as the struct is copied about. This results in errors
when calling deflate/inflate.

Modifications:

- Hold a pointer to the z_stream

Result:

Harder to misue compressor/decompressor
George Barnett 1 year ago
parent
commit
36851ff7fe

+ 73 - 88
Sources/GRPCHTTP2Core/Compression/Zlib.swift

@@ -37,29 +37,18 @@ enum Zlib {
 extension Zlib {
   /// Creates a new compressor for the given compression format.
   ///
-  /// This compressor is only suitable for compressing whole messages at a time. Callers
-  /// must ``initialize()`` the compressor before using it.
+  /// This compressor is only suitable for compressing whole messages at a time.
   struct Compressor {
-    private var stream: z_stream
+    // TODO: Make this ~Copyable when 5.9 is the lowest supported Swift version.
+
+    private var stream: UnsafeMutablePointer<z_stream>
     private let method: Method
-    private var isInitialized = false
 
     init(method: Method) {
       self.method = method
-      self.stream = z_stream()
-    }
-
-    /// Initialize the compressor.
-    mutating func initialize() {
-      precondition(!self.isInitialized)
+      self.stream = .allocate(capacity: 1)
+      self.stream.initialize(to: z_stream())
       self.stream.deflateInit(windowBits: self.method.windowBits)
-      self.isInitialized = true
-    }
-
-    static func initialized(_ method: Method) -> Self {
-      var compressor = Compressor(method: method)
-      compressor.initialize()
-      return compressor
     }
 
     /// Compresses the data in `input` into the `output` buffer.
@@ -68,27 +57,27 @@ extension Zlib {
     /// - Parameter output: The `ByteBuffer` into which the compressed message should be written.
     /// - Returns: The number of bytes written into the `output` buffer.
     @discardableResult
-    mutating func compress(_ input: [UInt8], into output: inout ByteBuffer) throws -> Int {
-      precondition(self.isInitialized)
+    func compress(_ input: [UInt8], into output: inout ByteBuffer) throws -> Int {
       defer { self.reset() }
       let upperBound = self.stream.deflateBound(inputBytes: input.count)
       return try self.stream.deflate(input, into: &output, upperBound: upperBound)
     }
 
     /// Resets compression state.
-    private mutating func reset() {
+    private func reset() {
       do {
         try self.stream.deflateReset()
       } catch {
         self.end()
-        self.stream = z_stream()
+        self.stream.initialize(to: z_stream())
         self.stream.deflateInit(windowBits: self.method.windowBits)
       }
     }
 
     /// Deallocates any resources allocated by Zlib.
-    mutating func end() {
+    func end() {
       self.stream.deflateEnd()
+      self.stream.deallocate()
     }
   }
 }
@@ -96,22 +85,18 @@ extension Zlib {
 extension Zlib {
   /// Creates a new decompressor for the given compression format.
   ///
-  /// This decompressor is only suitable for compressing whole messages at a time. Callers
-  /// must ``initialize()`` the decompressor before using it.
+  /// This decompressor is only suitable for compressing whole messages at a time.
   struct Decompressor {
-    private var stream: z_stream
+    // TODO: Make this ~Copyable when 5.9 is the lowest supported Swift version.
+
+    private var stream: UnsafeMutablePointer<z_stream>
     private let method: Method
-    private var isInitialized = false
 
     init(method: Method) {
       self.method = method
-      self.stream = z_stream()
-    }
-
-    mutating func initialize() {
-      precondition(!self.isInitialized)
+      self.stream = UnsafeMutablePointer.allocate(capacity: 1)
+      self.stream.initialize(to: z_stream())
       self.stream.inflateInit(windowBits: self.method.windowBits)
-      self.isInitialized = true
     }
 
     /// Returns the decompressed bytes from ``input``.
@@ -119,26 +104,26 @@ extension Zlib {
     /// - Parameters:
     ///   - input: The buffer read compressed bytes from.
     ///   - limit: The largest size a decompressed payload may be.
-    mutating func decompress(_ input: inout ByteBuffer, limit: Int) throws -> [UInt8] {
-      precondition(self.isInitialized)
+    func decompress(_ input: inout ByteBuffer, limit: Int) throws -> [UInt8] {
       defer { self.reset() }
       return try self.stream.inflate(input: &input, limit: limit)
     }
 
     /// Resets decompression state.
-    private mutating func reset() {
+    private func reset() {
       do {
         try self.stream.inflateReset()
       } catch {
         self.end()
-        self.stream = z_stream()
+        self.stream.initialize(to: z_stream())
         self.stream.inflateInit(windowBits: self.method.windowBits)
       }
     }
 
     /// Deallocates any resources allocated by Zlib.
-    mutating func end() {
+    func end() {
       self.stream.inflateEnd()
+      self.stream.deallocate()
     }
   }
 }
@@ -155,13 +140,13 @@ struct ZlibError: Error, Hashable {
   }
 }
 
-extension z_stream {
-  mutating func inflateInit(windowBits: Int32) {
-    self.zfree = nil
-    self.zalloc = nil
-    self.opaque = nil
+extension UnsafeMutablePointer<z_stream> {
+  func inflateInit(windowBits: Int32) {
+    self.pointee.zfree = nil
+    self.pointee.zalloc = nil
+    self.pointee.opaque = nil
 
-    let rc = CGRPCZlib_inflateInit2(&self, windowBits)
+    let rc = CGRPCZlib_inflateInit2(self, windowBits)
     // Possible return codes:
     // - Z_OK
     // - Z_MEM_ERROR: not enough memory
@@ -171,8 +156,8 @@ extension z_stream {
     precondition(rc == Z_OK, "inflateInit2 failed with error (\(rc)) \(self.lastError ?? "")")
   }
 
-  mutating func inflateReset() throws {
-    let rc = CGRPCZlib_inflateReset(&self)
+  func inflateReset() throws {
+    let rc = CGRPCZlib_inflateReset(self)
 
     // Possible return codes:
     // - Z_OK
@@ -187,17 +172,17 @@ extension z_stream {
     }
   }
 
-  mutating func inflateEnd() {
-    _ = CGRPCZlib_inflateEnd(&self)
+  func inflateEnd() {
+    _ = CGRPCZlib_inflateEnd(self)
   }
 
-  mutating func deflateInit(windowBits: Int32) {
-    self.zfree = nil
-    self.zalloc = nil
-    self.opaque = nil
+  func deflateInit(windowBits: Int32) {
+    self.pointee.zfree = nil
+    self.pointee.zalloc = nil
+    self.pointee.opaque = nil
 
     let rc = CGRPCZlib_deflateInit2(
-      &self,
+      self,
       Z_DEFAULT_COMPRESSION,  // compression level
       Z_DEFLATED,  // compression method (this must be Z_DEFLATED)
       windowBits,  // window size, i.e. deflate/gzip
@@ -215,8 +200,8 @@ extension z_stream {
     precondition(rc == Z_OK, "deflateInit2 failed with error (\(rc)) \(self.lastError ?? "")")
   }
 
-  mutating func deflateReset() throws {
-    let rc = CGRPCZlib_deflateReset(&self)
+  func deflateReset() throws {
+    let rc = CGRPCZlib_deflateReset(self)
 
     // Possible return codes:
     // - Z_OK
@@ -231,87 +216,87 @@ extension z_stream {
     }
   }
 
-  mutating func deflateEnd() {
-    _ = CGRPCZlib_deflateEnd(&self)
+  func deflateEnd() {
+    _ = CGRPCZlib_deflateEnd(self)
   }
 
-  mutating func deflateBound(inputBytes: Int) -> Int {
-    let bound = CGRPCZlib_deflateBound(&self, UInt(inputBytes))
+  func deflateBound(inputBytes: Int) -> Int {
+    let bound = CGRPCZlib_deflateBound(self, UInt(inputBytes))
     return Int(bound)
   }
 
-  mutating func setNextInputBuffer(_ buffer: UnsafeMutableBufferPointer<UInt8>) {
+  func setNextInputBuffer(_ buffer: UnsafeMutableBufferPointer<UInt8>) {
     if let baseAddress = buffer.baseAddress {
-      self.next_in = baseAddress
-      self.avail_in = UInt32(buffer.count)
+      self.pointee.next_in = baseAddress
+      self.pointee.avail_in = UInt32(buffer.count)
     } else {
-      self.next_in = nil
-      self.avail_in = 0
+      self.pointee.next_in = nil
+      self.pointee.avail_in = 0
     }
   }
 
-  mutating func setNextInputBuffer(_ buffer: UnsafeMutableRawBufferPointer?) {
+  func setNextInputBuffer(_ buffer: UnsafeMutableRawBufferPointer?) {
     if let buffer = buffer, let baseAddress = buffer.baseAddress {
-      self.next_in = CGRPCZlib_castVoidToBytefPointer(baseAddress)
-      self.avail_in = UInt32(buffer.count)
+      self.pointee.next_in = CGRPCZlib_castVoidToBytefPointer(baseAddress)
+      self.pointee.avail_in = UInt32(buffer.count)
     } else {
-      self.next_in = nil
-      self.avail_in = 0
+      self.pointee.next_in = nil
+      self.pointee.avail_in = 0
     }
   }
 
-  mutating func setNextOutputBuffer(_ buffer: UnsafeMutableBufferPointer<UInt8>) {
+  func setNextOutputBuffer(_ buffer: UnsafeMutableBufferPointer<UInt8>) {
     if let baseAddress = buffer.baseAddress {
-      self.next_out = baseAddress
-      self.avail_out = UInt32(buffer.count)
+      self.pointee.next_out = baseAddress
+      self.pointee.avail_out = UInt32(buffer.count)
     } else {
-      self.next_out = nil
-      self.avail_out = 0
+      self.pointee.next_out = nil
+      self.pointee.avail_out = 0
     }
   }
 
-  mutating func setNextOutputBuffer(_ buffer: UnsafeMutableRawBufferPointer?) {
+  func setNextOutputBuffer(_ buffer: UnsafeMutableRawBufferPointer?) {
     if let buffer = buffer, let baseAddress = buffer.baseAddress {
-      self.next_out = CGRPCZlib_castVoidToBytefPointer(baseAddress)
-      self.avail_out = UInt32(buffer.count)
+      self.pointee.next_out = CGRPCZlib_castVoidToBytefPointer(baseAddress)
+      self.pointee.avail_out = UInt32(buffer.count)
     } else {
-      self.next_out = nil
-      self.avail_out = 0
+      self.pointee.next_out = nil
+      self.pointee.avail_out = 0
     }
   }
 
   /// Number of bytes available to read `self.nextInputBuffer`. See also: `z_stream.avail_in`.
   var availableInputBytes: Int {
     get {
-      Int(self.avail_in)
+      Int(self.pointee.avail_in)
     }
     set {
-      self.avail_in = UInt32(newValue)
+      self.pointee.avail_in = UInt32(newValue)
     }
   }
 
   /// The remaining writable space in `nextOutputBuffer`. See also: `z_stream.avail_out`.
   var availableOutputBytes: Int {
     get {
-      Int(self.avail_out)
+      Int(self.pointee.avail_out)
     }
     set {
-      self.avail_out = UInt32(newValue)
+      self.pointee.avail_out = UInt32(newValue)
     }
   }
 
   /// The total number of bytes written to the output buffer. See also: `z_stream.total_out`.
   var totalOutputBytes: Int {
-    Int(self.total_out)
+    Int(self.pointee.total_out)
   }
 
   /// The last error message that zlib wrote. No message is guaranteed on error, however, `nil` is
   /// guaranteed if there is no error. See also `z_stream.msg`.
   var lastError: String? {
-    self.msg.map { String(cString: $0) }
+    self.pointee.msg.map { String(cString: $0) }
   }
 
-  mutating func inflate(input: inout ByteBuffer, limit: Int) throws -> [UInt8] {
+  func inflate(input: inout ByteBuffer, limit: Int) throws -> [UInt8] {
     return try input.readWithUnsafeMutableReadableBytes { inputPointer in
       self.setNextInputBuffer(inputPointer)
       defer {
@@ -342,7 +327,7 @@ extension z_stream {
           //
           // Note that Z_OK is not okay here since we always flush with Z_FINISH and therefore
           // use Z_STREAM_END as our success criteria.
-          let rc = CGRPCZlib_inflate(&self, Z_FINISH)
+          let rc = CGRPCZlib_inflate(self, Z_FINISH)
           switch rc {
           case Z_STREAM_END:
             finished = true
@@ -377,7 +362,7 @@ extension z_stream {
     }
   }
 
-  mutating func deflate(
+  func deflate(
     _ input: [UInt8],
     into output: inout ByteBuffer,
     upperBound: Int
@@ -394,7 +379,7 @@ extension z_stream {
       return try output.writeWithUnsafeMutableBytes(minimumWritableBytes: upperBound) { output in
         self.setNextOutputBuffer(output)
 
-        let rc = CGRPCZlib_deflate(&self, Z_FINISH)
+        let rc = CGRPCZlib_deflate(self, Z_FINISH)
 
         // Possible return codes:
         // - Z_OK: some progress has been made

+ 5 - 10
Tests/GRPCHTTP2CoreTests/Server/Compression/ZlibTests.swift

@@ -31,8 +31,7 @@ final class ZlibTests: XCTestCase {
     """
 
   private func compress(_ input: [UInt8], method: Zlib.Method) throws -> ByteBuffer {
-    var compressor = Zlib.Compressor(method: method)
-    compressor.initialize()
+    let compressor = Zlib.Compressor(method: method)
     defer { compressor.end() }
 
     var buffer = ByteBuffer()
@@ -45,8 +44,7 @@ final class ZlibTests: XCTestCase {
     method: Zlib.Method,
     limit: Int = .max
   ) throws -> [UInt8] {
-    var decompressor = Zlib.Decompressor(method: method)
-    decompressor.initialize()
+    let decompressor = Zlib.Decompressor(method: method)
     defer { decompressor.end() }
 
     var input = input
@@ -69,8 +67,7 @@ final class ZlibTests: XCTestCase {
 
   func testRepeatedCompresses() throws {
     let original = Array(self.text.utf8)
-    var compressor = Zlib.Compressor(method: .deflate)
-    compressor.initialize()
+    let compressor = Zlib.Compressor(method: .deflate)
     defer { compressor.end() }
 
     var compressed = ByteBuffer()
@@ -86,8 +83,7 @@ final class ZlibTests: XCTestCase {
 
   func testRepeatedDecompresses() throws {
     let original = Array(self.text.utf8)
-    var decompressor = Zlib.Decompressor(method: .deflate)
-    decompressor.initialize()
+    let decompressor = Zlib.Decompressor(method: .deflate)
     defer { decompressor.end() }
 
     let compressed = try self.compress(original, method: .deflate)
@@ -123,8 +119,7 @@ final class ZlibTests: XCTestCase {
   }
 
   func testCompressAppendsToBuffer() throws {
-    var compressor = Zlib.Compressor(method: .deflate)
-    compressor.initialize()
+    let compressor = Zlib.Compressor(method: .deflate)
     defer { compressor.end() }
 
     var buffer = ByteBuffer()