Bläddra i källkod

Always store compression algorithm in stream state machine (#1898)

Gustavo Cairo 1 år sedan
förälder
incheckning
3c66c5f29a

+ 31 - 11
Sources/GRPCHTTP2Core/GRPCStreamStateMachine.swift

@@ -75,7 +75,7 @@ private enum GRPCStreamStateMachineState {
     let maximumPayloadSize: Int
     var framer: GRPCMessageFramer
     var compressor: Zlib.Compressor?
-    var outboundCompression: CompressionAlgorithm?
+    var outboundCompression: CompressionAlgorithm
 
     // The deframer must be optional because the client will not have one configured
     // until the server opens and sends a grpc-encoding header.
@@ -89,12 +89,14 @@ private enum GRPCStreamStateMachineState {
     init(
       previousState: ClientIdleServerIdleState,
       compressor: Zlib.Compressor?,
+      outboundCompression: CompressionAlgorithm,
       framer: GRPCMessageFramer,
       decompressor: Zlib.Decompressor?,
       deframer: NIOSingleStepByteToMessageProcessor<GRPCMessageDeframer>?
     ) {
       self.maximumPayloadSize = previousState.maximumPayloadSize
       self.compressor = compressor
+      self.outboundCompression = outboundCompression
       self.framer = framer
       self.decompressor = decompressor
       self.deframer = deframer
@@ -105,6 +107,7 @@ private enum GRPCStreamStateMachineState {
   struct ClientOpenServerOpenState {
     var framer: GRPCMessageFramer
     var compressor: Zlib.Compressor?
+    var outboundCompression: CompressionAlgorithm
 
     let deframer: NIOSingleStepByteToMessageProcessor<GRPCMessageDeframer>
     var decompressor: Zlib.Decompressor?
@@ -118,6 +121,7 @@ private enum GRPCStreamStateMachineState {
     ) {
       self.framer = previousState.framer
       self.compressor = previousState.compressor
+      self.outboundCompression = previousState.outboundCompression
 
       self.deframer = deframer
       self.decompressor = decompressor
@@ -129,6 +133,7 @@ private enum GRPCStreamStateMachineState {
   struct ClientOpenServerClosedState {
     var framer: GRPCMessageFramer?
     var compressor: Zlib.Compressor?
+    var outboundCompression: CompressionAlgorithm
 
     let deframer: NIOSingleStepByteToMessageProcessor<GRPCMessageDeframer>?
     var decompressor: Zlib.Decompressor?
@@ -145,6 +150,7 @@ private enum GRPCStreamStateMachineState {
     init(previousState: ClientIdleServerIdleState) {
       self.framer = nil
       self.compressor = nil
+      self.outboundCompression = .none
       self.deframer = nil
       self.decompressor = nil
       self.inboundMessageBuffer = .init()
@@ -153,6 +159,7 @@ private enum GRPCStreamStateMachineState {
     init(previousState: ClientOpenServerOpenState) {
       self.framer = previousState.framer
       self.compressor = previousState.compressor
+      self.outboundCompression = previousState.outboundCompression
       self.deframer = previousState.deframer
       self.decompressor = previousState.decompressor
       self.inboundMessageBuffer = previousState.inboundMessageBuffer
@@ -161,6 +168,7 @@ private enum GRPCStreamStateMachineState {
     init(previousState: ClientOpenServerIdleState) {
       self.framer = previousState.framer
       self.compressor = previousState.compressor
+      self.outboundCompression = previousState.outboundCompression
       self.inboundMessageBuffer = previousState.inboundMessageBuffer
       // The server went directly from idle to closed - this means it sent a
       // trailers-only response:
@@ -177,7 +185,7 @@ private enum GRPCStreamStateMachineState {
     let maximumPayloadSize: Int
     var framer: GRPCMessageFramer
     var compressor: Zlib.Compressor?
-    var outboundCompression: CompressionAlgorithm?
+    var outboundCompression: CompressionAlgorithm
 
     let deframer: NIOSingleStepByteToMessageProcessor<GRPCMessageDeframer>?
     var decompressor: Zlib.Decompressor?
@@ -195,9 +203,12 @@ private enum GRPCStreamStateMachineState {
 
       if let zlibMethod = Zlib.Method(encoding: compressionAlgorithm) {
         self.compressor = Zlib.Compressor(method: zlibMethod)
+        self.outboundCompression = compressionAlgorithm
+      } else {
+        self.compressor = nil
+        self.outboundCompression = .none
       }
       self.framer = GRPCMessageFramer()
-      self.outboundCompression = compressionAlgorithm
       // We don't need a deframer since we won't receive any messages from the
       // client: it's closed.
       self.deframer = nil
@@ -218,6 +229,7 @@ private enum GRPCStreamStateMachineState {
   struct ClientClosedServerOpenState {
     var framer: GRPCMessageFramer
     var compressor: Zlib.Compressor?
+    var outboundCompression: CompressionAlgorithm
 
     let deframer: NIOSingleStepByteToMessageProcessor<GRPCMessageDeframer>?
     var decompressor: Zlib.Decompressor?
@@ -227,6 +239,7 @@ private enum GRPCStreamStateMachineState {
     init(previousState: ClientOpenServerOpenState) {
       self.framer = previousState.framer
       self.compressor = previousState.compressor
+      self.outboundCompression = previousState.outboundCompression
       self.deframer = previousState.deframer
       self.decompressor = previousState.decompressor
       self.inboundMessageBuffer = previousState.inboundMessageBuffer
@@ -236,6 +249,7 @@ private enum GRPCStreamStateMachineState {
     init(previousState: ClientClosedServerIdleState) {
       self.framer = previousState.framer
       self.compressor = previousState.compressor
+      self.outboundCompression = previousState.outboundCompression
 
       // In the case of the server, we don't need to deframe/decompress any more
       // messages, since the client's closed.
@@ -252,6 +266,7 @@ private enum GRPCStreamStateMachineState {
     ) {
       self.framer = previousState.framer
       self.compressor = previousState.compressor
+      self.outboundCompression = previousState.outboundCompression
 
       // In the case of the client, it will only be able to set up the deframer
       // after it receives the chosen encoding from the server.
@@ -274,6 +289,7 @@ private enum GRPCStreamStateMachineState {
     // the client.
     var framer: GRPCMessageFramer?
     var compressor: Zlib.Compressor?
+    var outboundCompression: CompressionAlgorithm
 
     // These are already deframed, so we don't need the deframer anymore.
     var inboundMessageBuffer: OneOrManyQueue<[UInt8]>
@@ -288,36 +304,42 @@ private enum GRPCStreamStateMachineState {
     init(previousState: ClientIdleServerIdleState) {
       self.framer = nil
       self.compressor = nil
+      self.outboundCompression = .none
       self.inboundMessageBuffer = .init()
     }
 
     init(previousState: ClientClosedServerOpenState) {
       self.framer = previousState.framer
       self.compressor = previousState.compressor
+      self.outboundCompression = previousState.outboundCompression
       self.inboundMessageBuffer = previousState.inboundMessageBuffer
     }
 
     init(previousState: ClientClosedServerIdleState) {
       self.framer = previousState.framer
       self.compressor = previousState.compressor
+      self.outboundCompression = previousState.outboundCompression
       self.inboundMessageBuffer = previousState.inboundMessageBuffer
     }
 
     init(previousState: ClientOpenServerIdleState) {
       self.framer = previousState.framer
       self.compressor = previousState.compressor
+      self.outboundCompression = previousState.outboundCompression
       self.inboundMessageBuffer = previousState.inboundMessageBuffer
     }
 
     init(previousState: ClientOpenServerOpenState) {
       self.framer = previousState.framer
       self.compressor = previousState.compressor
+      self.outboundCompression = previousState.outboundCompression
       self.inboundMessageBuffer = previousState.inboundMessageBuffer
     }
 
     init(previousState: ClientOpenServerClosedState) {
       self.framer = previousState.framer
       self.compressor = previousState.compressor
+      self.outboundCompression = previousState.outboundCompression
       self.inboundMessageBuffer = previousState.inboundMessageBuffer
     }
   }
@@ -555,6 +577,7 @@ extension GRPCStreamStateMachine {
         .init(
           previousState: state,
           compressor: compressor,
+          outboundCompression: outboundEncoding,
           framer: GRPCMessageFramer(),
           decompressor: nil,
           deframer: nil
@@ -1019,10 +1042,6 @@ extension GRPCStreamStateMachine {
       headers.add(outboundEncoding.name, forKey: .encoding)
     }
 
-    for acceptedEncoding in configuration.acceptedEncodings.elements.filter({ $0 != .none }) {
-      headers.add(acceptedEncoding.name, forKey: .acceptEncoding)
-    }
-
     for metadataPair in customMetadata {
       headers.add(name: metadataPair.key, value: metadataPair.value.encoded())
     }
@@ -1037,6 +1056,7 @@ extension GRPCStreamStateMachine {
     // Server sends initial metadata
     switch self.state {
     case .clientOpenServerIdle(let state):
+      let outboundEncoding = state.outboundCompression
       self.state = .clientOpenServerOpen(
         .init(
           previousState: state,
@@ -1048,14 +1068,15 @@ extension GRPCStreamStateMachine {
         )
       )
       return self.makeResponseHeaders(
-        outboundEncoding: state.outboundCompression,
+        outboundEncoding: outboundEncoding,
         configuration: configuration,
         customMetadata: metadata
       )
     case .clientClosedServerIdle(let state):
+      let outboundEncoding = state.outboundCompression
       self.state = .clientClosedServerOpen(.init(previousState: state))
       return self.makeResponseHeaders(
-        outboundEncoding: state.outboundCompression,
+        outboundEncoding: outboundEncoding,
         configuration: configuration,
         customMetadata: metadata
       )
@@ -1326,8 +1347,6 @@ extension GRPCStreamStateMachine {
         canonicalForm: true
       )
       // Find the preferred encoding and use it to compress responses.
-      // If it's identity, just skip it altogether, since we won't be
-      // compressing.
       for clientAdvertisedEncoding in clientAdvertisedEncodings {
         if let algorithm = CompressionAlgorithm(name: clientAdvertisedEncoding),
           configuration.acceptedEncodings.contains(algorithm)
@@ -1358,6 +1377,7 @@ extension GRPCStreamStateMachine {
           .init(
             previousState: state,
             compressor: compressor,
+            outboundCompression: outboundEncoding,
             framer: GRPCMessageFramer(),
             decompressor: decompressor,
             deframer: NIOSingleStepByteToMessageProcessor(deframer)

+ 56 - 31
Tests/GRPCHTTP2CoreTests/GRPCStreamStateMachineTests.swift

@@ -117,19 +117,16 @@ extension HPACKHeaders {
   fileprivate static let serverInitialMetadata: Self = [
     GRPCHTTP2Keys.status.rawValue: "200",
     GRPCHTTP2Keys.contentType.rawValue: ContentType.grpc.canonicalValue,
-    GRPCHTTP2Keys.acceptEncoding.rawValue: "deflate",
   ]
   fileprivate static let serverInitialMetadataWithDeflateCompression: Self = [
     GRPCHTTP2Keys.status.rawValue: "200",
     GRPCHTTP2Keys.contentType.rawValue: ContentType.grpc.canonicalValue,
     GRPCHTTP2Keys.encoding.rawValue: "deflate",
-    GRPCHTTP2Keys.acceptEncoding.rawValue: "deflate",
   ]
   fileprivate static let serverInitialMetadataWithGZIPCompression: Self = [
     GRPCHTTP2Keys.status.rawValue: "200",
     GRPCHTTP2Keys.contentType.rawValue: ContentType.grpc.canonicalValue,
     GRPCHTTP2Keys.encoding.rawValue: "gzip",
-    GRPCHTTP2Keys.acceptEncoding.rawValue: "deflate",
   ]
   fileprivate static let serverTrailers: Self = [
     GRPCHTTP2Keys.status.rawValue: "200",
@@ -366,7 +363,6 @@ final class GRPCStreamClientStateMachineTests: XCTestCase {
           ":status": "200",
           "content-type": "application/grpc",
           "grpc-encoding": "gzip",
-          "grpc-accept-encoding": "deflate",
         ]
       )
     )
@@ -1010,7 +1006,6 @@ final class GRPCStreamClientStateMachineTests: XCTestCase {
         [
           ":status": "200",
           "content-type": "application/grpc",
-          "grpc-accept-encoding": "deflate",
         ],
         nil
       )
@@ -1113,7 +1108,6 @@ final class GRPCStreamClientStateMachineTests: XCTestCase {
         [
           ":status": "200",
           "content-type": "application/grpc",
-          "grpc-accept-encoding": "deflate",
         ],
         nil
       )
@@ -1200,7 +1194,6 @@ final class GRPCStreamClientStateMachineTests: XCTestCase {
         [
           ":status": "200",
           "content-type": "application/grpc",
-          "grpc-accept-encoding": "deflate",
         ],
         nil
       )
@@ -1255,14 +1248,14 @@ final class GRPCStreamClientStateMachineTests: XCTestCase {
 final class GRPCStreamServerStateMachineTests: XCTestCase {
   private func makeServerStateMachine(
     targetState: TargetStateMachineState,
-    compressionEnabled: Bool = false
+    deflateCompressionEnabled: Bool = false
   ) -> GRPCStreamStateMachine {
 
     var stateMachine = GRPCStreamStateMachine(
       configuration: .server(
         .init(
           scheme: .http,
-          acceptedEncodings: [.deflate]
+          acceptedEncodings: deflateCompressionEnabled ? [.deflate] : []
         )
       ),
       maximumPayloadSize: 100,
@@ -1270,7 +1263,8 @@ final class GRPCStreamServerStateMachineTests: XCTestCase {
     )
 
     let clientMetadata: HPACKHeaders =
-      compressionEnabled ? .clientInitialMetadataWithDeflateCompression : .clientInitialMetadata
+      deflateCompressionEnabled
+      ? .clientInitialMetadataWithDeflateCompression : .clientInitialMetadata
     switch targetState {
     case .clientIdleServerIdle:
       break
@@ -1343,8 +1337,34 @@ final class GRPCStreamServerStateMachineTests: XCTestCase {
   }
 
   func testSendMetadataWhenClientOpenAndServerIdle() throws {
-    var stateMachine = self.makeServerStateMachine(targetState: .clientOpenServerIdle)
-    XCTAssertNoThrow(try stateMachine.send(metadata: .init()))
+    var stateMachine = self.makeServerStateMachine(
+      targetState: .clientOpenServerIdle,
+      deflateCompressionEnabled: false
+    )
+    XCTAssertEqual(
+      try stateMachine.send(metadata: .init()),
+      [
+        ":status": "200",
+        "content-type": "application/grpc",
+      ]
+    )
+  }
+
+  func testSendMetadataWhenClientOpenAndServerIdle_AndCompressionEnabled() {
+    // Enable deflate compression on server
+    var stateMachine = self.makeServerStateMachine(
+      targetState: .clientOpenServerIdle,
+      deflateCompressionEnabled: true
+    )
+
+    XCTAssertEqual(
+      try stateMachine.send(metadata: .init()),
+      [
+        ":status": "200",
+        "content-type": "application/grpc",
+        "grpc-encoding": "deflate",
+      ]
+    )
   }
 
   func testSendMetadataWhenClientOpenAndServerOpen() throws {
@@ -1866,7 +1886,10 @@ final class GRPCStreamServerStateMachineTests: XCTestCase {
   }
 
   func testReceiveMetadataWhenClientIdleAndServerIdle_ServerUnsupportedEncoding() throws {
-    var stateMachine = self.makeServerStateMachine(targetState: .clientIdleServerIdle)
+    var stateMachine = self.makeServerStateMachine(
+      targetState: .clientIdleServerIdle,
+      deflateCompressionEnabled: true
+    )
 
     // Try opening client with a compression algorithm that is not accepted
     // by the server.
@@ -1876,18 +1899,23 @@ final class GRPCStreamServerStateMachineTests: XCTestCase {
     )
 
     self.assertRejectedRPC(action) { trailers in
-      XCTAssertEqual(
-        trailers,
-        [
-          ":status": "200",
-          "content-type": "application/grpc",
-          "grpc-accept-encoding": "deflate",
-          "grpc-status": "12",
-          "grpc-message":
-            "gzip compression is not supported; supported algorithms are listed in grpc-accept-encoding",
-          "grpc-accept-encoding": "identity",
-        ]
-      )
+      let expected: HPACKHeaders = [
+        ":status": "200",
+        "content-type": "application/grpc",
+        "grpc-status": "12",
+        "grpc-message":
+          "gzip compression is not supported; supported algorithms are listed in grpc-accept-encoding",
+        "grpc-accept-encoding": "deflate",
+        "grpc-accept-encoding": "identity",
+      ]
+      XCTAssertEqual(expected.count, trailers.count, "Expected \(expected) but got \(trailers)")
+      for header in trailers {
+        XCTAssertTrue(
+          expected.contains { name, value, _ in
+            header.name == name && header.value == header.value
+          }
+        )
+      }
     }
   }
 
@@ -2016,7 +2044,7 @@ final class GRPCStreamServerStateMachineTests: XCTestCase {
     // Enable deflate compression on server
     var stateMachine = self.makeServerStateMachine(
       targetState: .clientOpenServerOpen,
-      compressionEnabled: true
+      deflateCompressionEnabled: true
     )
 
     let originalMessage = [UInt8]([42, 42, 43, 43])
@@ -2171,7 +2199,7 @@ final class GRPCStreamServerStateMachineTests: XCTestCase {
   func testNextOutboundMessageWhenClientOpenAndServerOpen_WithCompression() throws {
     var stateMachine = self.makeServerStateMachine(
       targetState: .clientOpenServerOpen,
-      compressionEnabled: true
+      deflateCompressionEnabled: true
     )
 
     XCTAssertEqual(try stateMachine.nextOutboundFrame(), .awaitMoreMessages)
@@ -2308,7 +2336,7 @@ final class GRPCStreamServerStateMachineTests: XCTestCase {
   func testNextInboundMessageWhenClientOpenAndServerOpen_WithCompression() throws {
     var stateMachine = self.makeServerStateMachine(
       targetState: .clientOpenServerOpen,
-      compressionEnabled: true
+      deflateCompressionEnabled: true
     )
 
     let originalMessage = [UInt8]([42, 42, 43, 43])
@@ -2430,7 +2458,6 @@ final class GRPCStreamServerStateMachineTests: XCTestCase {
       [
         ":status": "200",
         "content-type": "application/grpc",
-        "grpc-accept-encoding": "deflate",
         "custom": "value",
       ]
     )
@@ -2552,7 +2579,6 @@ final class GRPCStreamServerStateMachineTests: XCTestCase {
         "custom": "value",
         ":status": "200",
         "content-type": "application/grpc",
-        "grpc-accept-encoding": "deflate",
       ]
     )
 
@@ -2626,7 +2652,6 @@ final class GRPCStreamServerStateMachineTests: XCTestCase {
         "custom": "value",
         ":status": "200",
         "content-type": "application/grpc",
-        "grpc-accept-encoding": "deflate",
       ]
     )