Browse Source

Rewrite LengthPrefixedMessageReader and tests (#397)

* Rewrite LengthPrefixedMessageReader, add tests

* Switch order of expected and actual in LengthPrefixedMessageReaderTests
George Barnett 6 years ago
parent
commit
143255e986

+ 2 - 2
Sources/SwiftGRPCNIO/CompressionMechanism.swift

@@ -40,7 +40,7 @@ public enum CompressionMechanism: String {
   /// Whether the compression flag in gRPC length-prefixed messages should be set or not.
   ///
   /// See `LengthPrefixedMessageReader` for the message format.
-  var requiresFlag: Bool {
+  public var requiresFlag: Bool {
     switch self {
     case .none:
       return false
@@ -51,7 +51,7 @@ public enum CompressionMechanism: String {
   }
 
   /// Whether the given compression is supported.
-  var supported: Bool {
+  public var supported: Bool {
     switch self {
     case .identity, .none:
       return true

+ 8 - 5
Sources/SwiftGRPCNIO/HTTP1ToRawGRPCClientCodec.swift

@@ -52,7 +52,7 @@ public final class HTTP1ToRawGRPCClientCodec {
   }
 
   private var state: State = .expectingHeaders
-  private let messageReader = LengthPrefixedMessageReader(mode: .client)
+  private let messageReader = LengthPrefixedMessageReader(mode: .client, compressionMechanism: .none)
   private let messageWriter = LengthPrefixedMessageWriter()
   private var inboundCompression: CompressionMechanism = .none
 }
@@ -93,14 +93,16 @@ extension HTTP1ToRawGRPCClientCodec: ChannelInboundHandler {
       throw GRPCError.client(.HTTPStatusNotOk(head.status))
     }
 
-    if let encodingType = head.headers["grpc-encoding"].first {
-      self.inboundCompression = CompressionMechanism(rawValue: encodingType) ?? .unknown
-    }
+    let inboundCompression: CompressionMechanism = head.headers["grpc-encoding"]
+      .first
+      .map { CompressionMechanism(rawValue: $0) ?? .unknown } ?? .none
 
     guard inboundCompression.supported else {
       throw GRPCError.client(.unsupportedCompressionMechanism(inboundCompression.rawValue))
     }
 
+    self.messageReader.compressionMechanism = inboundCompression
+
     ctx.fireChannelRead(self.wrapInboundOut(.headers(head.headers)))
     return .expectingBodyOrTrailers
   }
@@ -114,7 +116,8 @@ extension HTTP1ToRawGRPCClientCodec: ChannelInboundHandler {
       throw GRPCError.client(.invalidState("received body while in state \(state)"))
     }
 
-    for message in try self.messageReader.consume(messageBuffer: &messageBuffer, compression: inboundCompression) {
+    self.messageReader.append(buffer: &messageBuffer)
+    while let message = try self.messageReader.nextMessage() {
       ctx.fireChannelRead(self.wrapInboundOut(.message(message)))
     }
 

+ 3 - 2
Sources/SwiftGRPCNIO/HTTP1ToRawGRPCServerCodec.swift

@@ -53,7 +53,7 @@ public final class HTTP1ToRawGRPCServerCodec {
   var outboundState = OutboundState.expectingHeaders
 
   var messageWriter = LengthPrefixedMessageWriter()
-  var messageReader = LengthPrefixedMessageReader(mode: .server)
+  var messageReader = LengthPrefixedMessageReader(mode: .server, compressionMechanism: .none)
 }
 
 extension HTTP1ToRawGRPCServerCodec {
@@ -148,7 +148,8 @@ extension HTTP1ToRawGRPCServerCodec: ChannelInboundHandler {
       body.write(bytes: decodedData)
     }
 
-    for message in try messageReader.consume(messageBuffer: &body, compression: .none) {
+    self.messageReader.append(buffer: &body)
+    while let message = try self.messageReader.nextMessage() {
       ctx.fireChannelRead(self.wrapInboundOut(.message(message)))
     }
 

+ 89 - 92
Sources/SwiftGRPCNIO/LengthPrefixedMessageReader.swift

@@ -15,7 +15,6 @@
  */
 import Foundation
 import NIO
-import NIOHTTP1
 
 /// This class reads and decodes length-prefixed gRPC messages.
 ///
@@ -32,117 +31,115 @@ import NIOHTTP1
 public class LengthPrefixedMessageReader {
   public typealias Mode = GRPCError.Origin
 
-  private let mode: Mode
-  private var buffer: ByteBuffer!
-  private var state: State = .expectingCompressedFlag
+  /// The mechanism that messages will be compressed with.
+  public var compressionMechanism: CompressionMechanism
 
-  private enum State {
-    case expectingCompressedFlag
-    case expectingMessageLength
-    case receivedMessageLength(Int)
-    case willBuffer(requiredBytes: Int)
-    case isBuffering(requiredBytes: Int)
-  }
-
-  public init(mode: Mode) {
+  public init(mode: Mode, compressionMechanism: CompressionMechanism) {
     self.mode = mode
+    self.compressionMechanism = compressionMechanism
   }
 
-  /// Consumes all readable bytes from given buffer and returns all messages which could be read.
+  /// The result of trying to parse a message with the bytes we currently have.
   ///
-  /// - SeeAlso: `read(messageBuffer:compression:)`
-  public func consume(messageBuffer: inout ByteBuffer, compression: CompressionMechanism) throws -> [ByteBuffer] {
-    var messages: [ByteBuffer] = []
+  /// - needMoreData: More data is required to continue reading a message.
+  /// - continue: Continue reading a message.
+  /// - message: A message was read.
+  internal enum ParseResult {
+    case needMoreData
+    case `continue`
+    case message(ByteBuffer)
+  }
 
-    while messageBuffer.readableBytes > 0 {
-      if let message = try self.read(messageBuffer: &messageBuffer, compression: compression) {
-        messages.append(message)
-      }
-    }
+  /// The parsing state; what we expect to be reading next.
+  internal enum ParseState {
+    case expectingCompressedFlag
+    case expectingMessageLength
+    case expectingMessage(UInt32)
+  }
 
-    return messages
+  private let mode: Mode
+  private var buffer: ByteBuffer!
+  private var state: ParseState = .expectingCompressedFlag
+
+  /// Appends data to the buffer from which messages will be read.
+  public func append(buffer: inout ByteBuffer) {
+    if self.buffer == nil {
+      self.buffer = buffer.slice()
+      // mark the bytes as "read"
+      buffer.moveReaderIndex(forwardBy: buffer.readableBytes)
+    } else {
+      self.buffer.write(buffer: &buffer)
+    }
   }
 
-  /// Reads bytes from the given buffer until it is exhausted or a message has been read.
-  ///
-  /// Length prefixed messages may be split across multiple input buffers in any of the
-  /// following places:
-  /// 1. after the compression flag,
-  /// 2. after the message length field,
-  /// 3. at any point within the message.
-  ///
-  /// It is possible for the message length field to be split across multiple `ByteBuffer`s,
-  /// this is unlikely to happen in practice.
+  /// Reads bytes from the buffer until it is exhausted or a message has been read.
   ///
-  /// - Note:
-  /// This method relies on state; if a message is _not_ returned then the next time this
-  /// method is called it expects to read the bytes which follow the most recently read bytes.
-  ///
-  /// - Parameters:
-  ///   - messageBuffer: buffer to read from.
-  ///   - compression: compression mechanism to decode message with.
   /// - Returns: A buffer containing a message if one has been read, or `nil` if not enough
   ///   bytes have been consumed to return a message.
   /// - Throws: Throws an error if the compression algorithm is not supported.
-  public func read(messageBuffer: inout ByteBuffer, compression: CompressionMechanism) throws -> ByteBuffer? {
-    while true {
-      switch state {
-      case .expectingCompressedFlag:
-        guard let compressionFlag: Int8 = messageBuffer.readInteger() else { return nil }
-        try handleCompressionFlag(enabled: compressionFlag != 0, mechanism: compression)
-        self.state = .expectingMessageLength
-
-      case .expectingMessageLength:
-        //! FIXME: Support the message length being split across multiple byte buffers.
-        guard let messageLength: UInt32 = messageBuffer.readInteger() else { return nil }
-        self.state = .receivedMessageLength(numericCast(messageLength))
-
-      case .receivedMessageLength(let messageLength):
-        // If this holds true, we can skip buffering and return a slice.
-        guard messageLength <= messageBuffer.readableBytes else {
-          self.state = .willBuffer(requiredBytes: messageLength)
-          continue
-        }
-
-        self.state = .expectingCompressedFlag
-        // We know messageBuffer.readableBytes >= messageLength, so it's okay to force unwrap here.
-        return messageBuffer.readSlice(length: messageLength)!
-
-      case .willBuffer(let requiredBytes):
-        messageBuffer.reserveCapacity(requiredBytes)
-        self.buffer = messageBuffer
-
-        let readableBytes = messageBuffer.readableBytes
-        // Move the reader index to avoid reading the bytes again.
-        messageBuffer.moveReaderIndex(forwardBy: readableBytes)
-
-        self.state = .isBuffering(requiredBytes: requiredBytes - readableBytes)
-        return nil
-
-      case .isBuffering(let requiredBytes):
-        guard requiredBytes <= messageBuffer.readableBytes else {
-          self.state = .isBuffering(requiredBytes: requiredBytes - self.buffer.write(buffer: &messageBuffer))
-          return nil
-        }
-
-        // We know messageBuffer.readableBytes >= requiredBytes, so it's okay to force unwrap here.
-        var slice = messageBuffer.readSlice(length: requiredBytes)!
-        self.buffer.write(buffer: &slice)
-        self.state = .expectingCompressedFlag
-
-        defer { self.buffer = nil }
-        return buffer
+  public func nextMessage() throws -> ByteBuffer? {
+    switch try self.processNextState() {
+    case .needMoreData:
+      self.nilBufferIfPossible()
+      return nil
+
+    case .continue:
+      return try nextMessage()
+
+    case .message(let message):
+      self.nilBufferIfPossible()
+      return message
+    }
+  }
+
+  /// `nil`s out `buffer` if it exists and has no readable bytes.
+  ///
+  /// This allows the next call to `append` to avoid writing the contents of the appended buffer.
+  private func nilBufferIfPossible() {
+    if self.buffer?.readableBytes == 0 {
+      self.buffer = nil
+    }
+  }
+
+  private func processNextState() throws -> ParseResult {
+    guard self.buffer != nil else { return .needMoreData }
+
+    switch self.state {
+    case .expectingCompressedFlag:
+      guard let compressionFlag: Int8 = self.buffer.readInteger() else {
+        return .needMoreData
+      }
+      try self.handleCompressionFlag(enabled: compressionFlag != 0)
+      self.state = .expectingMessageLength
+
+    case .expectingMessageLength:
+      guard let messageLength: UInt32 = self.buffer.readInteger() else {
+        return .needMoreData
+      }
+      self.state = .expectingMessage(messageLength)
+
+    case .expectingMessage(let length):
+      guard let message = self.buffer.readSlice(length: numericCast(length)) else {
+        return .needMoreData
       }
+      self.state = .expectingCompressedFlag
+      return .message(message)
     }
+
+    return .continue
   }
 
-  private func handleCompressionFlag(enabled flagEnabled: Bool, mechanism: CompressionMechanism) throws {
-    guard flagEnabled == mechanism.requiresFlag else {
+  private func handleCompressionFlag(enabled flagEnabled: Bool) throws {
+    guard flagEnabled else {
+      return
+    }
+
+    guard self.compressionMechanism.requiresFlag else {
       throw GRPCError.common(.unexpectedCompression, origin: mode)
     }
 
-    guard mechanism.supported else {
-      throw GRPCError.common(.unsupportedCompressionMechanism(mechanism.rawValue), origin: mode)
+    guard self.compressionMechanism.supported else {
+      throw GRPCError.common(.unsupportedCompressionMechanism(compressionMechanism.rawValue), origin: mode)
     }
   }
 }

+ 2 - 1
Tests/LinuxMain.swift

@@ -46,5 +46,6 @@ XCTMain([
   testCase(NIOClientTimeoutTests.allTests),
   testCase(NIOServerWebTests.allTests),
   testCase(GRPCChannelHandlerTests.allTests),
-  testCase(HTTP1ToRawGRPCServerCodecTests.allTests)
+  testCase(HTTP1ToRawGRPCServerCodecTests.allTests),
+  testCase(LengthPrefixedMessageReaderTests.allTests),
 ])

+ 251 - 0
Tests/SwiftGRPCNIOTests/LengthPrefixedMessageReaderTests.swift

@@ -0,0 +1,251 @@
+import Foundation
+import XCTest
+import SwiftGRPCNIO
+import NIO
+
+class LengthPrefixedMessageReaderTests: XCTestCase {
+  static var allTests: [(String, (LengthPrefixedMessageReaderTests) -> () throws -> Void)] {
+    return [
+      ("testNextMessageReturnsNilWhenNoBytesAppended", testNextMessageReturnsNilWhenNoBytesAppended),
+      ("testNextMessageReturnsMessageIsAppendedInOneBuffer", testNextMessageReturnsMessageIsAppendedInOneBuffer),
+      ("testNextMessageReturnsMessageForZeroLengthMessage", testNextMessageReturnsMessageForZeroLengthMessage),
+      ("testNextMessageDeliveredAcrossMultipleByteBuffers", testNextMessageDeliveredAcrossMultipleByteBuffers),
+      ("testNextMessageWhenMultipleMessagesAreBuffered", testNextMessageWhenMultipleMessagesAreBuffered),
+      ("testNextMessageReturnsNilWhenNoMessageLengthIsAvailable", testNextMessageReturnsNilWhenNoMessageLengthIsAvailable),
+      ("testNextMessageReturnsNilWhenNotAllMessageLengthIsAvailable", testNextMessageReturnsNilWhenNotAllMessageLengthIsAvailable),
+      ("testNextMessageReturnsNilWhenNoMessageBytesAreAvailable", testNextMessageReturnsNilWhenNoMessageBytesAreAvailable),
+      ("testNextMessageReturnsNilWhenNotAllMessageBytesAreAvailable", testNextMessageReturnsNilWhenNotAllMessageBytesAreAvailable),
+      ("testNextMessageThrowsWhenCompressionMechanismIsNotSupported", testNextMessageThrowsWhenCompressionMechanismIsNotSupported),
+      ("testNextMessageThrowsWhenCompressionFlagIsSetButNotExpected", testNextMessageThrowsWhenCompressionFlagIsSetButNotExpected),
+      ("testNextMessageDoesNotThrowWhenCompressionFlagIsExpectedButNotSet", testNextMessageDoesNotThrowWhenCompressionFlagIsExpectedButNotSet),
+      ("testAppendReadsAllBytes", testAppendReadsAllBytes),
+    ]
+  }
+
+  var reader = LengthPrefixedMessageReader(mode: .client, compressionMechanism: .none)
+
+  var allocator = ByteBufferAllocator()
+
+  func byteBuffer(withBytes bytes: [UInt8]) -> ByteBuffer {
+    var buffer = allocator.buffer(capacity: bytes.count)
+    buffer.write(bytes: bytes)
+    return buffer
+  }
+
+  final let twoByteMessage: [UInt8] = [0x01, 0x02]
+  func lengthPrefixedTwoByteMessage(withCompression compression: Bool = false) -> [UInt8] {
+    return [
+      compression ? 0x01 : 0x00,  // 1-byte compression flag
+      0x00, 0x00, 0x00, 0x02,     // 4-byte message length (2)
+    ] + twoByteMessage
+  }
+
+  func assertMessagesEqual(expected expectedBytes: [UInt8], actual buffer: ByteBuffer?, file: StaticString = #file, line: UInt = #line) {
+    guard let buffer = buffer else {
+      XCTFail("buffer is nil", file: file, line: line)
+      return
+    }
+
+    guard let bytes = buffer.getBytes(at: buffer.readerIndex, length: expectedBytes.count) else {
+      XCTFail("Expected \(expectedBytes.count) bytes, but only \(buffer.readableBytes) bytes are readable", file: file, line: line)
+      return
+    }
+
+    XCTAssertEqual(expectedBytes, bytes, file: file, line: line)
+  }
+
+  func testNextMessageReturnsNilWhenNoBytesAppended() throws {
+    XCTAssertNil(try reader.nextMessage())
+  }
+
+  func testNextMessageReturnsMessageIsAppendedInOneBuffer() throws {
+    var buffer = byteBuffer(withBytes: lengthPrefixedTwoByteMessage())
+    reader.append(buffer: &buffer)
+
+    self.assertMessagesEqual(expected: twoByteMessage, actual: try reader.nextMessage())
+  }
+
+  func testNextMessageReturnsMessageForZeroLengthMessage() throws {
+    let bytes: [UInt8] = [
+      0x00,                    // 1-byte compression flag
+      0x00, 0x00, 0x00, 0x00,  // 4-byte message length (0)
+                               // 0-byte message
+    ]
+
+    var buffer = byteBuffer(withBytes: bytes)
+    reader.append(buffer: &buffer)
+
+    self.assertMessagesEqual(expected: [], actual: try reader.nextMessage())
+  }
+
+  func testNextMessageDeliveredAcrossMultipleByteBuffers() throws {
+    let firstBytes: [UInt8] = [
+      0x00,              // 1-byte compression flag
+      0x00, 0x00, 0x00,  // first 3 bytes of 4-byte message length
+    ]
+
+    let secondBytes: [UInt8] = [
+      0x02,              // fourth byte of 4-byte message length (2)
+      0xf0, 0xba,        // 2-byte message
+    ]
+
+    var firstBuffer = byteBuffer(withBytes: firstBytes)
+    reader.append(buffer: &firstBuffer)
+    var secondBuffer = byteBuffer(withBytes: secondBytes)
+    reader.append(buffer: &secondBuffer)
+
+    self.assertMessagesEqual(expected: [0xf0, 0xba], actual: try reader.nextMessage())
+  }
+
+  func testNextMessageWhenMultipleMessagesAreBuffered() throws {
+    let bytes: [UInt8] = [
+      // 1st message
+      0x00,                    // 1-byte compression flag
+      0x00, 0x00, 0x00, 0x02,  // 4-byte message length (2)
+      0x0f, 0x00,              // 2-byte message
+      // 2nd message
+      0x00,                    // 1-byte compression flag
+      0x00, 0x00, 0x00, 0x04,  // 4-byte message length (4)
+      0xde, 0xad, 0xbe, 0xef,  // 4-byte message
+      // 3rd message
+      0x00,                    // 1-byte compression flag
+      0x00, 0x00, 0x00, 0x01,  // 4-byte message length (1)
+      0x01,                    // 1-byte message
+    ]
+
+    var buffer = byteBuffer(withBytes: bytes)
+    reader.append(buffer: &buffer)
+
+    self.assertMessagesEqual(expected: [0x0f, 0x00], actual: try reader.nextMessage())
+    self.assertMessagesEqual(expected: [0xde, 0xad, 0xbe, 0xef], actual: try reader.nextMessage())
+    self.assertMessagesEqual(expected: [0x01], actual: try reader.nextMessage())
+  }
+
+  func testNextMessageReturnsNilWhenNoMessageLengthIsAvailable() throws {
+    let bytes: [UInt8] = [
+      0x00,  // 1-byte compression flag
+    ]
+
+    var buffer = byteBuffer(withBytes: bytes)
+    reader.append(buffer: &buffer)
+
+    XCTAssertNil(try reader.nextMessage())
+
+    // Ensure we can read a message when the rest of the bytes are delivered
+    let restOfBytes: [UInt8] = [
+      0x00, 0x00, 0x00, 0x01,  // 4-byte message length (1)
+      0x00,                    // 1-byte message
+    ]
+
+    var secondBuffer = byteBuffer(withBytes: restOfBytes)
+    reader.append(buffer: &secondBuffer)
+    self.assertMessagesEqual(expected: [0x00], actual: try reader.nextMessage())
+  }
+
+  func testNextMessageReturnsNilWhenNotAllMessageLengthIsAvailable() throws {
+    let bytes: [UInt8] = [
+      0x00,        // 1-byte compression flag
+      0x00, 0x00,  // 2-bytes of message length (should be 4)
+    ]
+
+    var buffer = byteBuffer(withBytes: bytes)
+    reader.append(buffer: &buffer)
+
+    XCTAssertNil(try reader.nextMessage())
+
+    // Ensure we can read a message when the rest of the bytes are delivered
+    let restOfBytes: [UInt8] = [
+      0x00, 0x01,  // 4-byte message length (1)
+      0x00,        // 1-byte message
+    ]
+
+    var secondBuffer = byteBuffer(withBytes: restOfBytes)
+    reader.append(buffer: &secondBuffer)
+    self.assertMessagesEqual(expected: [0x00], actual: try reader.nextMessage())
+  }
+
+
+  func testNextMessageReturnsNilWhenNoMessageBytesAreAvailable() throws {
+    let bytes: [UInt8] = [
+      0x00,                    // 1-byte compression flag
+      0x00, 0x00, 0x00, 0x02,  // 4-byte message length (2)
+    ]
+
+    var buffer = byteBuffer(withBytes: bytes)
+    reader.append(buffer: &buffer)
+
+    XCTAssertNil(try reader.nextMessage())
+
+    // Ensure we can read a message when the rest of the bytes are delivered
+    var secondBuffer = byteBuffer(withBytes: twoByteMessage)
+    reader.append(buffer: &secondBuffer)
+    self.assertMessagesEqual(expected: twoByteMessage, actual: try reader.nextMessage())
+  }
+
+  func testNextMessageReturnsNilWhenNotAllMessageBytesAreAvailable() throws {
+    let bytes: [UInt8] = [
+      0x00,                    // 1-byte compression flag
+      0x00, 0x00, 0x00, 0x02,  // 4-byte message length (2)
+      0x00,                    // 1-byte of message
+    ]
+
+    var buffer = byteBuffer(withBytes: bytes)
+    reader.append(buffer: &buffer)
+
+    XCTAssertNil(try reader.nextMessage())
+
+    // Ensure we can read a message when the rest of the bytes are delivered
+    let restOfBytes: [UInt8] = [
+      0x01  // final byte of message
+    ]
+
+    var secondBuffer = byteBuffer(withBytes: restOfBytes)
+    reader.append(buffer: &secondBuffer)
+    self.assertMessagesEqual(expected: [0x00, 0x01], actual: try reader.nextMessage())
+  }
+
+  func testNextMessageThrowsWhenCompressionMechanismIsNotSupported() throws {
+    // Unknown should never be supported.
+    reader.compressionMechanism = .unknown
+    XCTAssertFalse(reader.compressionMechanism.supported)
+
+    var buffer = byteBuffer(withBytes: lengthPrefixedTwoByteMessage(withCompression: true))
+    reader.append(buffer: &buffer)
+
+    XCTAssertThrowsError(try reader.nextMessage()) { error in
+      XCTAssertEqual(.unsupportedCompressionMechanism("unknown"), (error as? GRPCError)?.error as? GRPCCommonError)
+    }
+  }
+
+  func testNextMessageThrowsWhenCompressionFlagIsSetButNotExpected() throws {
+    // Default compression mechanism is `.none` which requires that no
+    // compression flag is set as it indicates a lack of message encoding header.
+    XCTAssertFalse(reader.compressionMechanism.requiresFlag)
+
+    var buffer = byteBuffer(withBytes: lengthPrefixedTwoByteMessage(withCompression: true))
+    reader.append(buffer: &buffer)
+
+    XCTAssertThrowsError(try reader.nextMessage()) { error in
+      XCTAssertEqual(.unexpectedCompression, (error as? GRPCError)?.error as? GRPCCommonError)
+    }
+  }
+
+  func testNextMessageDoesNotThrowWhenCompressionFlagIsExpectedButNotSet() throws {
+    // `.identity` should always be supported and requires a flag.
+    reader.compressionMechanism = .identity
+    XCTAssertTrue(reader.compressionMechanism.supported)
+    XCTAssertTrue(reader.compressionMechanism.requiresFlag)
+
+    var buffer = byteBuffer(withBytes: lengthPrefixedTwoByteMessage())
+    reader.append(buffer: &buffer)
+
+    self.assertMessagesEqual(expected: twoByteMessage, actual: try reader.nextMessage())
+  }
+
+  func testAppendReadsAllBytes() throws {
+    var buffer = byteBuffer(withBytes: lengthPrefixedTwoByteMessage())
+    reader.append(buffer: &buffer)
+
+    XCTAssertEqual(0, buffer.readableBytes)
+  }
+}