فهرست منبع

Buffer in the server pipeline configuration (#1564)

Motivation:

When not using TLS, the server pipeline configurator inspects the first
bytes on a connection to determine whether HTTP1 or HTTP2 is being used
and closes the connection if it is determined that neither are. It does
this by only parsing the first packet, which may not have enough bytes
to make a correct determination.

Modifications:

- Buffer bytes in the configurator.
- Parse the buffered bytes and only close if enough bytes have been
  received.

Result:

Better version determination.
George Barnett 3 سال پیش
والد
کامیت
c12f59f38e

+ 110 - 30
Sources/GRPC/GRPCServerPipelineConfigurator.swift

@@ -33,8 +33,8 @@ final class GRPCServerPipelineConfigurator: ChannelInboundHandler, RemovableChan
   /// The server configuration.
   /// The server configuration.
   private let configuration: Server.Configuration
   private let configuration: Server.Configuration
 
 
-  /// Reads which we're holding on to before the pipeline is configured.
-  private var bufferedReads = CircularBuffer<NIOAny>()
+  /// A buffer containing the buffered bytes.
+  private var buffer: ByteBuffer?
 
 
   /// The current state.
   /// The current state.
   private var state: State
   private var state: State
@@ -212,13 +212,17 @@ final class GRPCServerPipelineConfigurator: ChannelInboundHandler, RemovableChan
     buffer: ByteBuffer,
     buffer: ByteBuffer,
     context: ChannelHandlerContext
     context: ChannelHandlerContext
   ) {
   ) {
-    if HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer) {
+    switch HTTPVersionParser.determineHTTPVersion(buffer) {
+    case .http2:
       self.configureHTTP2(context: context)
       self.configureHTTP2(context: context)
-    } else if HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer) {
+    case .http1:
       self.configureHTTP1(context: context)
       self.configureHTTP1(context: context)
-    } else {
+    case .unknown:
+      // Neither H2 nor H1 or the length limit has been exceeded.
       self.configuration.logger.error("Unable to determine http version, closing")
       self.configuration.logger.error("Unable to determine http version, closing")
       context.close(mode: .all, promise: nil)
       context.close(mode: .all, promise: nil)
+    case .notEnoughBytes:
+      () // Try again with more bytes.
     }
     }
   }
   }
 
 
@@ -268,13 +272,9 @@ final class GRPCServerPipelineConfigurator: ChannelInboundHandler, RemovableChan
 
 
   /// Try to parse the buffered data to determine whether or not HTTP/2 or HTTP/1 should be used.
   /// Try to parse the buffered data to determine whether or not HTTP/2 or HTTP/1 should be used.
   private func tryParsingBufferedData(context: ChannelHandlerContext) {
   private func tryParsingBufferedData(context: ChannelHandlerContext) {
-    guard let first = self.bufferedReads.first else {
-      // No data buffered yet. We'll try when we read.
-      return
+    if let buffer = self.buffer {
+      self.determineHTTPVersionAndConfigurePipeline(buffer: buffer, context: context)
     }
     }
-
-    let buffer = self.unwrapInboundIn(first)
-    self.determineHTTPVersionAndConfigurePipeline(buffer: buffer, context: context)
   }
   }
 
 
   // MARK: - Channel Handler
   // MARK: - Channel Handler
@@ -312,7 +312,8 @@ final class GRPCServerPipelineConfigurator: ChannelInboundHandler, RemovableChan
   }
   }
 
 
   internal func channelRead(context: ChannelHandlerContext, data: NIOAny) {
   internal func channelRead(context: ChannelHandlerContext, data: NIOAny) {
-    self.bufferedReads.append(data)
+    var buffer = self.unwrapInboundIn(data)
+    self.buffer.setOrWriteBuffer(&buffer)
 
 
     switch self.state {
     switch self.state {
     case .notConfigured(alpn: .notExpected),
     case .notConfigured(alpn: .notExpected),
@@ -335,8 +336,9 @@ final class GRPCServerPipelineConfigurator: ChannelInboundHandler, RemovableChan
     removalToken: ChannelHandlerContext.RemovalToken
     removalToken: ChannelHandlerContext.RemovalToken
   ) {
   ) {
     // Forward any buffered reads.
     // Forward any buffered reads.
-    while let read = self.bufferedReads.popFirst() {
-      context.fireChannelRead(read)
+    if let buffer = self.buffer {
+      self.buffer = nil
+      context.fireChannelRead(self.wrapInboundOut(buffer))
     }
     }
     context.leavePipeline(removalToken: removalToken)
     context.leavePipeline(removalToken: removalToken)
   }
   }
@@ -375,16 +377,64 @@ struct HTTPVersionParser {
 
 
   /// Determines whether the bytes in the `ByteBuffer` are prefixed with the HTTP/2 client
   /// Determines whether the bytes in the `ByteBuffer` are prefixed with the HTTP/2 client
   /// connection preface.
   /// connection preface.
-  static func prefixedWithHTTP2ConnectionPreface(_ buffer: ByteBuffer) -> Bool {
+  static func prefixedWithHTTP2ConnectionPreface(_ buffer: ByteBuffer) -> SubParseResult {
     let view = buffer.readableBytesView
     let view = buffer.readableBytesView
 
 
     guard view.count >= HTTPVersionParser.http2ClientMagic.count else {
     guard view.count >= HTTPVersionParser.http2ClientMagic.count else {
       // Not enough bytes.
       // Not enough bytes.
-      return false
+      return .notEnoughBytes
     }
     }
 
 
     let slice = view[view.startIndex ..< view.startIndex.advanced(by: self.http2ClientMagic.count)]
     let slice = view[view.startIndex ..< view.startIndex.advanced(by: self.http2ClientMagic.count)]
-    return slice.elementsEqual(HTTPVersionParser.http2ClientMagic)
+    return slice.elementsEqual(HTTPVersionParser.http2ClientMagic) ? .accepted : .rejected
+  }
+
+  enum ParseResult: Hashable {
+    case http1
+    case http2
+    case unknown
+    case notEnoughBytes
+  }
+
+  enum SubParseResult: Hashable {
+    case accepted
+    case rejected
+    case notEnoughBytes
+  }
+
+  private static let maxLengthToCheck = 1024
+
+  static func determineHTTPVersion(_ buffer: ByteBuffer) -> ParseResult {
+    switch Self.prefixedWithHTTP2ConnectionPreface(buffer) {
+    case .accepted:
+      return .http2
+
+    case .notEnoughBytes:
+      switch Self.prefixedWithHTTP1RequestLine(buffer) {
+      case .accepted:
+        // Not enough bytes to check H2, but enough to confirm H1.
+        return .http1
+      case .notEnoughBytes:
+        // Not enough bytes to check H2 or H1.
+        return .notEnoughBytes
+      case .rejected:
+        // Not enough bytes to check H2 and definitely not H1.
+        return .notEnoughBytes
+      }
+
+    case .rejected:
+      switch Self.prefixedWithHTTP1RequestLine(buffer) {
+      case .accepted:
+        // Not H2, but H1 is confirmed.
+        return .http1
+      case .notEnoughBytes:
+        // Not H2, but not enough bytes to reject H1 yet.
+        return .notEnoughBytes
+      case .rejected:
+        // Not H2 or H1.
+        return .unknown
+      }
+    }
   }
   }
 
 
   private static let http1_1 = [
   private static let http1_1 = [
@@ -399,29 +449,59 @@ struct HTTPVersionParser {
   ]
   ]
 
 
   /// Determines whether the bytes in the `ByteBuffer` are prefixed with an HTTP/1.1 request line.
   /// Determines whether the bytes in the `ByteBuffer` are prefixed with an HTTP/1.1 request line.
-  static func prefixedWithHTTP1RequestLine(_ buffer: ByteBuffer) -> Bool {
+  static func prefixedWithHTTP1RequestLine(_ buffer: ByteBuffer) -> SubParseResult {
     var readableBytesView = buffer.readableBytesView
     var readableBytesView = buffer.readableBytesView
 
 
+    // We don't need to validate the request line, only determine whether we think it's an HTTP1
+    // request line. Another handler will parse it properly.
+
     // From RFC 2616 § 5.1:
     // From RFC 2616 § 5.1:
     //   Request-Line = Method SP Request-URI SP HTTP-Version CRLF
     //   Request-Line = Method SP Request-URI SP HTTP-Version CRLF
 
 
-    // Read off the Method and Request-URI (and spaces).
-    guard readableBytesView.trimPrefix(to: UInt8(ascii: " ")) != nil,
-          readableBytesView.trimPrefix(to: UInt8(ascii: " ")) != nil else {
-      return false
+    // Get through the first space.
+    guard readableBytesView.dropPrefix(through: UInt8(ascii: " ")) != nil else {
+      let tooLong = buffer.readableBytes > Self.maxLengthToCheck
+      return tooLong ? .rejected : .notEnoughBytes
+    }
+
+    // Get through the second space.
+    guard readableBytesView.dropPrefix(through: UInt8(ascii: " ")) != nil else {
+      let tooLong = buffer.readableBytes > Self.maxLengthToCheck
+      return tooLong ? .rejected : .notEnoughBytes
+    }
+
+    // +2 for \r\n
+    guard readableBytesView.count >= (Self.http1_1.count + 2) else {
+      return .notEnoughBytes
     }
     }
 
 
-    // Read off the HTTP-Version and CR.
-    guard let versionView = readableBytesView.trimPrefix(to: UInt8(ascii: "\r")) else {
-      return false
+    guard let version = readableBytesView.dropPrefix(through: UInt8(ascii: "\r")),
+          readableBytesView.first == UInt8(ascii: "\n") else {
+      // If we didn't drop the prefix OR we did and the next byte wasn't '\n', then we had enough
+      // bytes but the '\r\n' wasn't present: reject this as being HTTP1.
+      return .rejected
+    }
+
+    return version.elementsEqual(Self.http1_1) ? .accepted : .rejected
+  }
+}
+
+extension Collection where Self == Self.SubSequence, Self.Element: Equatable {
+  /// Drops the prefix off the collection up to and including the first `separator`
+  /// only if that separator appears in the collection.
+  ///
+  /// Returns the prefix up to but not including the separator if it was found, nil otherwise.
+  mutating func dropPrefix(through separator: Element) -> SubSequence? {
+    if self.isEmpty {
+      return nil
     }
     }
 
 
-    // Check that the LF followed the CR.
-    guard readableBytesView.first == UInt8(ascii: "\n") else {
-      return false
+    guard let separatorIndex = self.firstIndex(of: separator) else {
+      return nil
     }
     }
 
 
-    // Now check the HTTP version.
-    return versionView.elementsEqual(HTTPVersionParser.http1_1)
+    let prefix = self[..<separatorIndex]
+    self = self[self.index(after: separatorIndex)...]
+    return prefix
   }
   }
 }
 }

+ 51 - 0
Tests/GRPCTests/GRPCServerPipelineConfiguratorTests.swift

@@ -135,6 +135,24 @@ class GRPCServerPipelineConfiguratorTests: GRPCTestCase {
     self.assertHTTP2Handler(isPresent: true)
     self.assertHTTP2Handler(isPresent: true)
   }
   }
 
 
+  func testHTTP2SetupViaBytesDripFed() {
+    self.setUp(tls: false)
+    var bytes = ByteBuffer(staticString: "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
+    var head = bytes.readSlice(length: bytes.readableBytes - 1)!
+    let tail = bytes.readSlice(length: 1)!
+
+    while let slice = head.readSlice(length: 1) {
+      assertThat(try self.channel.writeInbound(slice), .doesNotThrow())
+      self.assertConfigurator(isPresent: true)
+      self.assertHTTP2Handler(isPresent: false)
+    }
+
+    // Final byte.
+    assertThat(try self.channel.writeInbound(tail), .doesNotThrow())
+    self.assertConfigurator(isPresent: false)
+    self.assertHTTP2Handler(isPresent: true)
+  }
+
   func testHTTP1Dot1SetupViaBytes() {
   func testHTTP1Dot1SetupViaBytes() {
     self.setUp(tls: false)
     self.setUp(tls: false)
     let bytes = ByteBuffer(staticString: "GET http://www.foo.bar HTTP/1.1\r\n")
     let bytes = ByteBuffer(staticString: "GET http://www.foo.bar HTTP/1.1\r\n")
@@ -143,6 +161,39 @@ class GRPCServerPipelineConfiguratorTests: GRPCTestCase {
     self.assertGRPCWebToHTTP2Handler(isPresent: true)
     self.assertGRPCWebToHTTP2Handler(isPresent: true)
   }
   }
 
 
+  func testHTTP1Dot1SetupViaBytesDripFed() {
+    self.setUp(tls: false)
+    var bytes = ByteBuffer(staticString: "GET http://www.foo.bar HTTP/1.1\r\n")
+    var head = bytes.readSlice(length: bytes.readableBytes - 1)!
+    let tail = bytes.readSlice(length: 1)!
+
+    while let slice = head.readSlice(length: 1) {
+      assertThat(try self.channel.writeInbound(slice), .doesNotThrow())
+      self.assertConfigurator(isPresent: true)
+      self.assertGRPCWebToHTTP2Handler(isPresent: false)
+    }
+
+    // Final byte.
+    assertThat(try self.channel.writeInbound(tail), .doesNotThrow())
+    self.assertConfigurator(isPresent: false)
+    self.assertGRPCWebToHTTP2Handler(isPresent: true)
+  }
+
+  func testUnexpectedInputClosesEventuallyWhenDripFed() {
+    self.setUp(tls: false)
+    var bytes = ByteBuffer(repeating: UInt8(ascii: "a"), count: 2048)
+
+    while let slice = bytes.readSlice(length: 1) {
+      assertThat(try self.channel.writeInbound(slice), .doesNotThrow())
+      self.assertConfigurator(isPresent: true)
+      self.assertHTTP2Handler(isPresent: false)
+      self.assertGRPCWebToHTTP2Handler(isPresent: false)
+    }
+
+    self.channel.embeddedEventLoop.run()
+    assertThat(try self.channel.closeFuture.wait(), .doesNotThrow())
+  }
+
   func testReadsAreUnbufferedAfterConfiguration() throws {
   func testReadsAreUnbufferedAfterConfiguration() throws {
     self.setUp(tls: false)
     self.setUp(tls: false)
 
 

+ 17 - 11
Tests/GRPCTests/HTTPVersionParserTests.swift

@@ -22,58 +22,64 @@ class HTTPVersionParserTests: GRPCTestCase {
 
 
   func testHTTP2ExactlyTheRightBytes() {
   func testHTTP2ExactlyTheRightBytes() {
     let buffer = ByteBuffer(string: self.preface)
     let buffer = ByteBuffer(string: self.preface)
-    XCTAssertTrue(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer))
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer), .accepted)
   }
   }
 
 
   func testHTTP2TheRightBytesAndMore() {
   func testHTTP2TheRightBytesAndMore() {
     var buffer = ByteBuffer(string: self.preface)
     var buffer = ByteBuffer(string: self.preface)
     buffer.writeRepeatingByte(42, count: 1024)
     buffer.writeRepeatingByte(42, count: 1024)
-    XCTAssertTrue(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer))
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer), .accepted)
   }
   }
 
 
   func testHTTP2NoBytes() {
   func testHTTP2NoBytes() {
     let empty = ByteBuffer()
     let empty = ByteBuffer()
-    XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(empty))
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(empty), .notEnoughBytes)
   }
   }
 
 
   func testHTTP2NotEnoughBytes() {
   func testHTTP2NotEnoughBytes() {
     var buffer = ByteBuffer(string: self.preface)
     var buffer = ByteBuffer(string: self.preface)
     buffer.moveWriterIndex(to: buffer.writerIndex - 1)
     buffer.moveWriterIndex(to: buffer.writerIndex - 1)
-    XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer))
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer), .notEnoughBytes)
   }
   }
 
 
   func testHTTP2EnoughOfTheWrongBytes() {
   func testHTTP2EnoughOfTheWrongBytes() {
     let buffer = ByteBuffer(string: String(self.preface.reversed()))
     let buffer = ByteBuffer(string: String(self.preface.reversed()))
-    XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer))
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer), .rejected)
   }
   }
 
 
   func testHTTP1RequestLine() {
   func testHTTP1RequestLine() {
     let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html HTTP/1.1\r\n")
     let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html HTTP/1.1\r\n")
-    XCTAssertTrue(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer))
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer), .accepted)
   }
   }
 
 
   func testHTTP1RequestLineAndMore() {
   func testHTTP1RequestLineAndMore() {
     let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html HTTP/1.1\r\nMore")
     let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html HTTP/1.1\r\nMore")
-    XCTAssertTrue(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer))
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer), .accepted)
   }
   }
 
 
   func testHTTP1RequestLineWithoutCRLF() {
   func testHTTP1RequestLineWithoutCRLF() {
     let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html HTTP/1.1")
     let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html HTTP/1.1")
-    XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer))
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer), .notEnoughBytes)
   }
   }
 
 
   func testHTTP1NoBytes() {
   func testHTTP1NoBytes() {
     let empty = ByteBuffer()
     let empty = ByteBuffer()
-    XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP1RequestLine(empty))
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP1RequestLine(empty), .notEnoughBytes)
   }
   }
 
 
   func testHTTP1IncompleteRequestLine() {
   func testHTTP1IncompleteRequestLine() {
     let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html")
     let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html")
-    XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer))
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer), .notEnoughBytes)
   }
   }
 
 
   func testHTTP1MalformedVersion() {
   func testHTTP1MalformedVersion() {
     let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html ptth/1.1\r\n")
     let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html ptth/1.1\r\n")
-    XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer))
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer), .rejected)
+  }
+
+  func testTooManyIncorrectBytes() {
+    let buffer = ByteBuffer(repeating: UInt8(ascii: "\r"), count: 2048)
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer), .rejected)
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer), .rejected)
   }
   }
 }
 }