Просмотр исходного кода

Buffer in the server pipeline configuration (#1564)

Motivation:

When not using TLS, the server pipeline configurator inspects the first
bytes on a connection to determine whether HTTP1 or HTTP2 is being used
and closes the connection if it is determined that neither are. It does
this by only parsing the first packet, which may not have enough bytes
to make a correct determination.

Modifications:

- Buffer bytes in the configurator.
- Parse the buffered bytes and only close if enough bytes have been
  received.

Result:

Better version determination.
George Barnett 2 лет назад
Родитель
Сommit
c12f59f38e

+ 110 - 30
Sources/GRPC/GRPCServerPipelineConfigurator.swift

@@ -33,8 +33,8 @@ final class GRPCServerPipelineConfigurator: ChannelInboundHandler, RemovableChan
   /// The server configuration.
   private let configuration: Server.Configuration
 
-  /// Reads which we're holding on to before the pipeline is configured.
-  private var bufferedReads = CircularBuffer<NIOAny>()
+  /// A buffer containing the buffered bytes.
+  private var buffer: ByteBuffer?
 
   /// The current state.
   private var state: State
@@ -212,13 +212,17 @@ final class GRPCServerPipelineConfigurator: ChannelInboundHandler, RemovableChan
     buffer: ByteBuffer,
     context: ChannelHandlerContext
   ) {
-    if HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer) {
+    switch HTTPVersionParser.determineHTTPVersion(buffer) {
+    case .http2:
       self.configureHTTP2(context: context)
-    } else if HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer) {
+    case .http1:
       self.configureHTTP1(context: context)
-    } else {
+    case .unknown:
+      // Neither H2 nor H1 or the length limit has been exceeded.
       self.configuration.logger.error("Unable to determine http version, closing")
       context.close(mode: .all, promise: nil)
+    case .notEnoughBytes:
+      () // Try again with more bytes.
     }
   }
 
@@ -268,13 +272,9 @@ final class GRPCServerPipelineConfigurator: ChannelInboundHandler, RemovableChan
 
   /// Try to parse the buffered data to determine whether or not HTTP/2 or HTTP/1 should be used.
   private func tryParsingBufferedData(context: ChannelHandlerContext) {
-    guard let first = self.bufferedReads.first else {
-      // No data buffered yet. We'll try when we read.
-      return
+    if let buffer = self.buffer {
+      self.determineHTTPVersionAndConfigurePipeline(buffer: buffer, context: context)
     }
-
-    let buffer = self.unwrapInboundIn(first)
-    self.determineHTTPVersionAndConfigurePipeline(buffer: buffer, context: context)
   }
 
   // MARK: - Channel Handler
@@ -312,7 +312,8 @@ final class GRPCServerPipelineConfigurator: ChannelInboundHandler, RemovableChan
   }
 
   internal func channelRead(context: ChannelHandlerContext, data: NIOAny) {
-    self.bufferedReads.append(data)
+    var buffer = self.unwrapInboundIn(data)
+    self.buffer.setOrWriteBuffer(&buffer)
 
     switch self.state {
     case .notConfigured(alpn: .notExpected),
@@ -335,8 +336,9 @@ final class GRPCServerPipelineConfigurator: ChannelInboundHandler, RemovableChan
     removalToken: ChannelHandlerContext.RemovalToken
   ) {
     // Forward any buffered reads.
-    while let read = self.bufferedReads.popFirst() {
-      context.fireChannelRead(read)
+    if let buffer = self.buffer {
+      self.buffer = nil
+      context.fireChannelRead(self.wrapInboundOut(buffer))
     }
     context.leavePipeline(removalToken: removalToken)
   }
@@ -375,16 +377,64 @@ struct HTTPVersionParser {
 
   /// Determines whether the bytes in the `ByteBuffer` are prefixed with the HTTP/2 client
   /// connection preface.
-  static func prefixedWithHTTP2ConnectionPreface(_ buffer: ByteBuffer) -> Bool {
+  static func prefixedWithHTTP2ConnectionPreface(_ buffer: ByteBuffer) -> SubParseResult {
     let view = buffer.readableBytesView
 
     guard view.count >= HTTPVersionParser.http2ClientMagic.count else {
       // Not enough bytes.
-      return false
+      return .notEnoughBytes
     }
 
     let slice = view[view.startIndex ..< view.startIndex.advanced(by: self.http2ClientMagic.count)]
-    return slice.elementsEqual(HTTPVersionParser.http2ClientMagic)
+    return slice.elementsEqual(HTTPVersionParser.http2ClientMagic) ? .accepted : .rejected
+  }
+
+  enum ParseResult: Hashable {
+    case http1
+    case http2
+    case unknown
+    case notEnoughBytes
+  }
+
+  enum SubParseResult: Hashable {
+    case accepted
+    case rejected
+    case notEnoughBytes
+  }
+
+  private static let maxLengthToCheck = 1024
+
+  static func determineHTTPVersion(_ buffer: ByteBuffer) -> ParseResult {
+    switch Self.prefixedWithHTTP2ConnectionPreface(buffer) {
+    case .accepted:
+      return .http2
+
+    case .notEnoughBytes:
+      switch Self.prefixedWithHTTP1RequestLine(buffer) {
+      case .accepted:
+        // Not enough bytes to check H2, but enough to confirm H1.
+        return .http1
+      case .notEnoughBytes:
+        // Not enough bytes to check H2 or H1.
+        return .notEnoughBytes
+      case .rejected:
+        // Not enough bytes to check H2 and definitely not H1.
+        return .notEnoughBytes
+      }
+
+    case .rejected:
+      switch Self.prefixedWithHTTP1RequestLine(buffer) {
+      case .accepted:
+        // Not H2, but H1 is confirmed.
+        return .http1
+      case .notEnoughBytes:
+        // Not H2, but not enough bytes to reject H1 yet.
+        return .notEnoughBytes
+      case .rejected:
+        // Not H2 or H1.
+        return .unknown
+      }
+    }
   }
 
   private static let http1_1 = [
@@ -399,29 +449,59 @@ struct HTTPVersionParser {
   ]
 
   /// Determines whether the bytes in the `ByteBuffer` are prefixed with an HTTP/1.1 request line.
-  static func prefixedWithHTTP1RequestLine(_ buffer: ByteBuffer) -> Bool {
+  static func prefixedWithHTTP1RequestLine(_ buffer: ByteBuffer) -> SubParseResult {
     var readableBytesView = buffer.readableBytesView
 
+    // We don't need to validate the request line, only determine whether we think it's an HTTP1
+    // request line. Another handler will parse it properly.
+
     // From RFC 2616 § 5.1:
     //   Request-Line = Method SP Request-URI SP HTTP-Version CRLF
 
-    // Read off the Method and Request-URI (and spaces).
-    guard readableBytesView.trimPrefix(to: UInt8(ascii: " ")) != nil,
-          readableBytesView.trimPrefix(to: UInt8(ascii: " ")) != nil else {
-      return false
+    // Get through the first space.
+    guard readableBytesView.dropPrefix(through: UInt8(ascii: " ")) != nil else {
+      let tooLong = buffer.readableBytes > Self.maxLengthToCheck
+      return tooLong ? .rejected : .notEnoughBytes
+    }
+
+    // Get through the second space.
+    guard readableBytesView.dropPrefix(through: UInt8(ascii: " ")) != nil else {
+      let tooLong = buffer.readableBytes > Self.maxLengthToCheck
+      return tooLong ? .rejected : .notEnoughBytes
+    }
+
+    // +2 for \r\n
+    guard readableBytesView.count >= (Self.http1_1.count + 2) else {
+      return .notEnoughBytes
     }
 
-    // Read off the HTTP-Version and CR.
-    guard let versionView = readableBytesView.trimPrefix(to: UInt8(ascii: "\r")) else {
-      return false
+    guard let version = readableBytesView.dropPrefix(through: UInt8(ascii: "\r")),
+          readableBytesView.first == UInt8(ascii: "\n") else {
+      // If we didn't drop the prefix OR we did and the next byte wasn't '\n', then we had enough
+      // bytes but the '\r\n' wasn't present: reject this as being HTTP1.
+      return .rejected
+    }
+
+    return version.elementsEqual(Self.http1_1) ? .accepted : .rejected
+  }
+}
+
+extension Collection where Self == Self.SubSequence, Self.Element: Equatable {
+  /// Drops the prefix off the collection up to and including the first `separator`
+  /// only if that separator appears in the collection.
+  ///
+  /// Returns the prefix up to but not including the separator if it was found, nil otherwise.
+  mutating func dropPrefix(through separator: Element) -> SubSequence? {
+    if self.isEmpty {
+      return nil
     }
 
-    // Check that the LF followed the CR.
-    guard readableBytesView.first == UInt8(ascii: "\n") else {
-      return false
+    guard let separatorIndex = self.firstIndex(of: separator) else {
+      return nil
     }
 
-    // Now check the HTTP version.
-    return versionView.elementsEqual(HTTPVersionParser.http1_1)
+    let prefix = self[..<separatorIndex]
+    self = self[self.index(after: separatorIndex)...]
+    return prefix
   }
 }

+ 51 - 0
Tests/GRPCTests/GRPCServerPipelineConfiguratorTests.swift

@@ -135,6 +135,24 @@ class GRPCServerPipelineConfiguratorTests: GRPCTestCase {
     self.assertHTTP2Handler(isPresent: true)
   }
 
+  func testHTTP2SetupViaBytesDripFed() {
+    self.setUp(tls: false)
+    var bytes = ByteBuffer(staticString: "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
+    var head = bytes.readSlice(length: bytes.readableBytes - 1)!
+    let tail = bytes.readSlice(length: 1)!
+
+    while let slice = head.readSlice(length: 1) {
+      assertThat(try self.channel.writeInbound(slice), .doesNotThrow())
+      self.assertConfigurator(isPresent: true)
+      self.assertHTTP2Handler(isPresent: false)
+    }
+
+    // Final byte.
+    assertThat(try self.channel.writeInbound(tail), .doesNotThrow())
+    self.assertConfigurator(isPresent: false)
+    self.assertHTTP2Handler(isPresent: true)
+  }
+
   func testHTTP1Dot1SetupViaBytes() {
     self.setUp(tls: false)
     let bytes = ByteBuffer(staticString: "GET http://www.foo.bar HTTP/1.1\r\n")
@@ -143,6 +161,39 @@ class GRPCServerPipelineConfiguratorTests: GRPCTestCase {
     self.assertGRPCWebToHTTP2Handler(isPresent: true)
   }
 
+  func testHTTP1Dot1SetupViaBytesDripFed() {
+    self.setUp(tls: false)
+    var bytes = ByteBuffer(staticString: "GET http://www.foo.bar HTTP/1.1\r\n")
+    var head = bytes.readSlice(length: bytes.readableBytes - 1)!
+    let tail = bytes.readSlice(length: 1)!
+
+    while let slice = head.readSlice(length: 1) {
+      assertThat(try self.channel.writeInbound(slice), .doesNotThrow())
+      self.assertConfigurator(isPresent: true)
+      self.assertGRPCWebToHTTP2Handler(isPresent: false)
+    }
+
+    // Final byte.
+    assertThat(try self.channel.writeInbound(tail), .doesNotThrow())
+    self.assertConfigurator(isPresent: false)
+    self.assertGRPCWebToHTTP2Handler(isPresent: true)
+  }
+
+  func testUnexpectedInputClosesEventuallyWhenDripFed() {
+    self.setUp(tls: false)
+    var bytes = ByteBuffer(repeating: UInt8(ascii: "a"), count: 2048)
+
+    while let slice = bytes.readSlice(length: 1) {
+      assertThat(try self.channel.writeInbound(slice), .doesNotThrow())
+      self.assertConfigurator(isPresent: true)
+      self.assertHTTP2Handler(isPresent: false)
+      self.assertGRPCWebToHTTP2Handler(isPresent: false)
+    }
+
+    self.channel.embeddedEventLoop.run()
+    assertThat(try self.channel.closeFuture.wait(), .doesNotThrow())
+  }
+
   func testReadsAreUnbufferedAfterConfiguration() throws {
     self.setUp(tls: false)
 

+ 17 - 11
Tests/GRPCTests/HTTPVersionParserTests.swift

@@ -22,58 +22,64 @@ class HTTPVersionParserTests: GRPCTestCase {
 
   func testHTTP2ExactlyTheRightBytes() {
     let buffer = ByteBuffer(string: self.preface)
-    XCTAssertTrue(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer))
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer), .accepted)
   }
 
   func testHTTP2TheRightBytesAndMore() {
     var buffer = ByteBuffer(string: self.preface)
     buffer.writeRepeatingByte(42, count: 1024)
-    XCTAssertTrue(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer))
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer), .accepted)
   }
 
   func testHTTP2NoBytes() {
     let empty = ByteBuffer()
-    XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(empty))
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(empty), .notEnoughBytes)
   }
 
   func testHTTP2NotEnoughBytes() {
     var buffer = ByteBuffer(string: self.preface)
     buffer.moveWriterIndex(to: buffer.writerIndex - 1)
-    XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer))
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer), .notEnoughBytes)
   }
 
   func testHTTP2EnoughOfTheWrongBytes() {
     let buffer = ByteBuffer(string: String(self.preface.reversed()))
-    XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer))
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer), .rejected)
   }
 
   func testHTTP1RequestLine() {
     let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html HTTP/1.1\r\n")
-    XCTAssertTrue(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer))
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer), .accepted)
   }
 
   func testHTTP1RequestLineAndMore() {
     let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html HTTP/1.1\r\nMore")
-    XCTAssertTrue(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer))
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer), .accepted)
   }
 
   func testHTTP1RequestLineWithoutCRLF() {
     let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html HTTP/1.1")
-    XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer))
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer), .notEnoughBytes)
   }
 
   func testHTTP1NoBytes() {
     let empty = ByteBuffer()
-    XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP1RequestLine(empty))
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP1RequestLine(empty), .notEnoughBytes)
   }
 
   func testHTTP1IncompleteRequestLine() {
     let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html")
-    XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer))
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer), .notEnoughBytes)
   }
 
   func testHTTP1MalformedVersion() {
     let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html ptth/1.1\r\n")
-    XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer))
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer), .rejected)
+  }
+
+  func testTooManyIncorrectBytes() {
+    let buffer = ByteBuffer(repeating: UInt8(ascii: "\r"), count: 2048)
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer), .rejected)
+    XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer), .rejected)
   }
 }