浏览代码

Avoid unnecessary arrays. (#935)

Motivation:

HTTP1ToGRPCServerCodec currently creates a temporary array for parsing
all messages into before it forwards them on. This is both a minor perf
drain (due to the extra allocations) and a correctness problem, as it
makes this channel handler non-reentrant-safe. We should fix both
issues.

Modifications:

- Replace the temporary array with a simple loop.
- Add tests that validates correct behaviour on reentrancy.

Result:

Better re-entrancy behaviour! Verrrry slightly better perf.
Cory Benfield 5 年之前
父节点
当前提交
d3f039d2f3

+ 8 - 8
Sources/GRPC/HTTP1ToGRPCServerCodec.swift

@@ -270,10 +270,12 @@ extension HTTP1ToGRPCServerCodec: ChannelInboundHandler {
     }
 
     self.messageReader.append(buffer: &body)
-    var requests: [ByteBuffer] = []
     do {
-      while let buffer = try self.messageReader.nextMessage() {
-        requests.append(buffer)
+      // We may be re-entrantly called, and that re-entrant call may error. If the state changed for any reason,
+      // stop looping.
+      while self.inboundState == .expectingBody,
+        let buffer = try self.messageReader.nextMessage() {
+        context.fireChannelRead(self.wrapInboundOut(.message(buffer)))
       }
     } catch let grpcError as GRPCError.WithContext {
       context.fireErrorCaught(grpcError)
@@ -283,11 +285,9 @@ extension HTTP1ToGRPCServerCodec: ChannelInboundHandler {
       return .ignore
     }
 
-    requests.forEach {
-      context.fireChannelRead(self.wrapInboundOut(.message($0)))
-    }
-
-    return .expectingBody
+    // We may have been called re-entrantly and transitioned out of the state we were in (e.g. because of an
+    // error). In all cases, if we get here we want to persist the current state.
+    return self.inboundState
   }
 
   private func processEnd(context: ChannelHandlerContext,

+ 142 - 0
Tests/GRPCTests/HTTP1ToGRPCServerCodecTests.swift

@@ -22,6 +22,39 @@ import NIO
 import NIOHTTP1
 import XCTest
 
+/// A trivial channel handler that invokes a callback once, the first time it sees
+/// channelRead.
+final class OnFirstReadHandler: ChannelInboundHandler {
+  typealias InboundIn = Any
+  typealias InboundOut = Any
+
+  private var callback: (() -> Void)?
+
+  init(callback: @escaping () -> Void) {
+    self.callback = callback
+  }
+
+  func channelRead(context: ChannelHandlerContext, data: NIOAny) {
+    context.fireChannelRead(data)
+
+    if let callback = self.callback {
+      self.callback = nil
+      callback()
+    }
+  }
+}
+
+final class ErrorRecordingHandler: ChannelInboundHandler {
+  typealias InboundIn = Any
+
+  var errors: [Error] = []
+
+  func errorCaught(context: ChannelHandlerContext, error: Error) {
+    self.errors.append(error)
+    context.fireErrorCaught(error)
+  }
+}
+
 class HTTP1ToGRPCServerCodecTests: GRPCTestCase {
   var channel: EmbeddedChannel!
 
@@ -127,4 +160,113 @@ class HTTP1ToGRPCServerCodecTests: GRPCTestCase {
       }
     }
   }
+
+  func testReentrantMessageDelivery() throws {
+    XCTAssertNoThrow(
+      try self.channel
+        .writeInbound(HTTPServerRequestPart.head(self.makeRequestHead()))
+    )
+    let requestPart = try self.channel.readInbound(as: _RawGRPCServerRequestPart.self)
+
+    switch requestPart {
+    case .some(.head):
+      ()
+    default:
+      XCTFail("Unexpected request part: \(String(describing: requestPart))")
+    }
+
+    // Write three messages into a single body.
+    var buffer = self.channel.allocator.buffer(capacity: 0)
+    let serializedMessages: [Data] = try ["foo", "bar", "baz"].map { text in
+      Echo_EchoRequest.with { $0.text = text }
+    }.map { request in
+      try request.serializedData()
+    }
+
+    for data in serializedMessages {
+      buffer.writeInteger(UInt8(0))
+      buffer.writeInteger(UInt32(data.count))
+      buffer.writeBytes(data)
+    }
+
+    // Create an OnFirstReadHandler that will _also_ send the data when it sees the first read.
+    // This is try! because it cannot throw.
+    let onFirstRead = OnFirstReadHandler {
+      try! self.channel.writeInbound(HTTPServerRequestPart.body(buffer))
+    }
+    XCTAssertNoThrow(try self.channel.pipeline.addHandler(onFirstRead).wait())
+
+    // Now write.
+    XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body(buffer)))
+
+    // This must not re-order messages.
+    for message in [serializedMessages, serializedMessages].flatMap({ $0 }) {
+      let requestPart = try self.channel.readInbound(as: _RawGRPCServerRequestPart.self)
+      switch requestPart {
+      case var .some(.message(buffer)):
+        XCTAssertEqual(message, buffer.readData(length: buffer.readableBytes)!)
+      default:
+        XCTFail("Unexpected request part: \(String(describing: requestPart))")
+      }
+    }
+  }
+
+  func testErrorsOnlyHappenOnce() throws {
+    XCTAssertNoThrow(
+      try self.channel
+        .writeInbound(HTTPServerRequestPart.head(self.makeRequestHead()))
+    )
+    let requestPart = try self.channel.readInbound(as: _RawGRPCServerRequestPart.self)
+
+    switch requestPart {
+    case .some(.head):
+      ()
+    default:
+      XCTFail("Unexpected request part: \(String(describing: requestPart))")
+    }
+
+    // Write three messages into a single body.
+    var buffer = self.channel.allocator.buffer(capacity: 0)
+    let serializedMessages: [Data] = try ["foo", "bar", "baz"].map { text in
+      Echo_EchoRequest.with { $0.text = text }
+    }.map { request in
+      try request.serializedData()
+    }
+
+    for data in serializedMessages {
+      buffer.writeInteger(UInt8(0))
+      buffer.writeInteger(UInt32(data.count))
+      buffer.writeBytes(data)
+    }
+
+    // Create an OnFirstReadHandler that will _also_ send the data when it sees the first read.
+    // This is try! because it cannot throw.
+    let onFirstRead = OnFirstReadHandler {
+      // Let's create a bad message: we'll turn on compression. We use two bytes here to deal with the fact that
+      // in hitting the error we'll actually consume the first byte (whoops).
+      var badBuffer = self.channel.allocator.buffer(capacity: 0)
+      badBuffer.writeInteger(UInt8(1))
+      badBuffer.writeInteger(UInt8(1))
+      _ = try? self.channel.writeInbound(HTTPServerRequestPart.body(badBuffer))
+    }
+    let errorHandler = ErrorRecordingHandler()
+    XCTAssertNoThrow(try self.channel.pipeline.addHandlers([onFirstRead, errorHandler]).wait())
+
+    // Now write.
+    XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body(buffer)))
+
+    // We should have seen the original three messages
+    for message in serializedMessages {
+      let requestPart = try self.channel.readInbound(as: _RawGRPCServerRequestPart.self)
+      switch requestPart {
+      case var .some(.message(buffer)):
+        XCTAssertEqual(message, buffer.readData(length: buffer.readableBytes)!)
+      default:
+        XCTFail("Unexpected request part: \(String(describing: requestPart))")
+      }
+    }
+
+    // We should have recorded only one error.
+    XCTAssertEqual(errorHandler.errors.count, 1)
+  }
 }

+ 2 - 0
Tests/GRPCTests/XCTestManifests.swift

@@ -646,7 +646,9 @@ extension HTTP1ToGRPCServerCodecTests {
     //   `swift test --generate-linuxmain`
     // to regenerate.
     static let __allTests__HTTP1ToGRPCServerCodecTests = [
+        ("testErrorsOnlyHappenOnce", testErrorsOnlyHappenOnce),
         ("testMultipleMessagesFromSingleBodyPart", testMultipleMessagesFromSingleBodyPart),
+        ("testReentrantMessageDelivery", testReentrantMessageDelivery),
         ("testSingleMessageFromMultipleBodyParts", testSingleMessageFromMultipleBodyParts),
     ]
 }