Browse Source

Add missing transitions to GRPCStreamStateMachine (#1831)

Gustavo Cairo 1 year ago
parent
commit
e3b3e16ea4

+ 93 - 11
Sources/GRPCHTTP2Core/GRPCStreamStateMachine.swift

@@ -110,7 +110,7 @@ private enum GRPCStreamStateMachineState {
   }
 
   struct ClientOpenServerClosedState {
-    var framer: GRPCMessageFramer
+    var framer: GRPCMessageFramer?
     var compressor: Zlib.Compressor?
 
     let deframer: NIOSingleStepByteToMessageProcessor<GRPCMessageDeframer>?
@@ -118,6 +118,21 @@ private enum GRPCStreamStateMachineState {
 
     var inboundMessageBuffer: OneOrManyQueue<[UInt8]>
 
+    // This transition should only happen on the server-side when, upon receiving
+    // initial client metadata, some of the headers are invalid and we must reject
+    // the RPC.
+    // We will mark the client as open (because it sent initial metadata albeit
+    // invalid) but we'll close the server, meaning all future messages sent from
+    // the client will be ignored. Because of this, we won't need to frame or
+    // deframe any messages, as we won't be reading or writing any messages.
+    init(previousState: ClientIdleServerIdleState) {
+      self.framer = nil
+      self.compressor = nil
+      self.deframer = nil
+      self.decompressor = nil
+      self.inboundMessageBuffer = .init()
+    }
+
     init(previousState: ClientOpenServerOpenState) {
       self.framer = previousState.framer
       self.compressor = previousState.compressor
@@ -240,12 +255,25 @@ private enum GRPCStreamStateMachineState {
     // We still need the framer and compressor in case the server has closed
     // but its buffer is not yet empty and still needs to send messages out to
     // the client.
-    var framer: GRPCMessageFramer
+    var framer: GRPCMessageFramer?
     var compressor: Zlib.Compressor?
 
     // These are already deframed, so we don't need the deframer anymore.
     var inboundMessageBuffer: OneOrManyQueue<[UInt8]>
 
+    // This transition should only happen on the server-side when, upon receiving
+    // initial client metadata, some of the headers are invalid and we must reject
+    // the RPC.
+    // We will mark the client as closed (because it set the EOS flag, even if
+    // the initial metadata was invalid) and we'll close the server too.
+    // Because of this, we won't need to frame any messages, as we
+    // won't be writing any messages.
+    init(previousState: ClientIdleServerIdleState) {
+      self.framer = nil
+      self.compressor = nil
+      self.inboundMessageBuffer = .init()
+    }
+
     init(previousState: ClientClosedServerOpenState) {
       self.framer = previousState.framer
       self.compressor = previousState.compressor
@@ -1062,6 +1090,21 @@ extension GRPCStreamStateMachine {
     }
   }
 
+  mutating private func closeServerAndBuildRejectRPCAction(
+    currentState: GRPCStreamStateMachineState.ClientIdleServerIdleState,
+    endStream: Bool,
+    rejectWithStatus status: Status
+  ) -> OnMetadataReceived {
+    if endStream {
+      self.state = .clientClosedServerClosed(.init(previousState: currentState))
+    } else {
+      self.state = .clientOpenServerClosed(.init(previousState: currentState))
+    }
+
+    let trailers = self.makeTrailers(status: status, customMetadata: nil, trailersOnly: true)
+    return .rejectRPC(trailers: trailers)
+  }
+
   private mutating func serverReceive(
     metadata: HPACKHeaders,
     endStream: Bool,
@@ -1071,7 +1114,9 @@ extension GRPCStreamStateMachine {
     case .clientIdleServerIdle(let state):
       let contentType = metadata.firstString(forKey: .contentType)
         .flatMap { ContentType(value: $0) }
-      guard contentType != nil else {
+      if contentType == nil {
+        self.state = .clientOpenServerClosed(.init(previousState: state))
+
         // Respond with HTTP-level Unsupported Media Type status code.
         var trailers = HPACKHeaders()
         trailers.add("415", forKey: .status)
@@ -1080,13 +1125,50 @@ extension GRPCStreamStateMachine {
 
       let path = metadata.firstString(forKey: .path)
         .flatMap { MethodDescriptor(fullyQualifiedMethod: $0) }
-      guard path != nil else {
-        let status = Status(
-          code: .unimplemented,
-          message: "No \(GRPCHTTP2Keys.path.rawValue) header has been set."
+      if path == nil {
+        return self.closeServerAndBuildRejectRPCAction(
+          currentState: state,
+          endStream: endStream,
+          rejectWithStatus: Status(
+            code: .unimplemented,
+            message: "No \(GRPCHTTP2Keys.path.rawValue) header has been set."
+          )
+        )
+      }
+
+      let scheme = metadata.firstString(forKey: .scheme)
+        .flatMap { Scheme(rawValue: $0) }
+      if scheme == nil {
+        return self.closeServerAndBuildRejectRPCAction(
+          currentState: state,
+          endStream: endStream,
+          rejectWithStatus: Status(
+            code: .invalidArgument,
+            message: ":scheme header must be present and one of \"http\" or \"https\"."
+          )
+        )
+      }
+
+      guard let method = metadata.firstString(forKey: .method), method == "POST" else {
+        return self.closeServerAndBuildRejectRPCAction(
+          currentState: state,
+          endStream: endStream,
+          rejectWithStatus: Status(
+            code: .invalidArgument,
+            message: ":method header is expected to be present and have a value of \"POST\"."
+          )
+        )
+      }
+
+      guard let te = metadata.firstString(forKey: .te), te == "trailers" else {
+        return self.closeServerAndBuildRejectRPCAction(
+          currentState: state,
+          endStream: endStream,
+          rejectWithStatus: Status(
+            code: .invalidArgument,
+            message: "\"te\" header is expected to be present and have a value of \"trailers\"."
+          )
         )
-        let trailers = self.makeTrailers(status: status, customMetadata: nil, trailersOnly: true)
-        return .rejectRPC(trailers: trailers)
       }
 
       func isIdentityOrCompatibleEncoding(_ clientEncoding: CompressionAlgorithm) -> Bool {
@@ -1265,7 +1347,7 @@ extension GRPCStreamStateMachine {
       self.state = .clientClosedServerOpen(state)
       return response.map { .sendMessage($0) } ?? .awaitMoreMessages
     case .clientOpenServerClosed(var state):
-      let response = try state.framer.next(compressor: state.compressor)
+      let response = try state.framer?.next(compressor: state.compressor)
       self.state = .clientOpenServerClosed(state)
       if let response {
         return .sendMessage(response)
@@ -1273,7 +1355,7 @@ extension GRPCStreamStateMachine {
         return .noMoreMessages
       }
     case .clientClosedServerClosed(var state):
-      let response = try state.framer.next(compressor: state.compressor)
+      let response = try state.framer?.next(compressor: state.compressor)
       self.state = .clientClosedServerClosed(state)
       if let response {
         return .sendMessage(response)

+ 171 - 2
Tests/GRPCHTTP2CoreTests/GRPCStreamStateMachineTests.swift

@@ -45,7 +45,7 @@ extension HPACKHeaders {
     GRPCHTTP2Keys.contentType.rawValue: "application/grpc",
     GRPCHTTP2Keys.method.rawValue: "POST",
     GRPCHTTP2Keys.scheme.rawValue: "https",
-    GRPCHTTP2Keys.te.rawValue: "te",
+    GRPCHTTP2Keys.te.rawValue: "trailers",
     GRPCHTTP2Keys.acceptEncoding.rawValue: "deflate",
     GRPCHTTP2Keys.encoding.rawValue: "deflate",
   ]
@@ -54,7 +54,7 @@ extension HPACKHeaders {
     GRPCHTTP2Keys.contentType.rawValue: "application/grpc",
     GRPCHTTP2Keys.method.rawValue: "POST",
     GRPCHTTP2Keys.scheme.rawValue: "https",
-    GRPCHTTP2Keys.te.rawValue: "te",
+    GRPCHTTP2Keys.te.rawValue: "trailers",
     GRPCHTTP2Keys.acceptEncoding.rawValue: "gzip",
     GRPCHTTP2Keys.encoding.rawValue: "gzip",
   ]
@@ -68,6 +68,45 @@ extension HPACKHeaders {
   fileprivate static let receivedWithoutEndpoint: Self = [
     GRPCHTTP2Keys.contentType.rawValue: "application/grpc"
   ]
+  fileprivate static let receivedWithoutTE: Self = [
+    GRPCHTTP2Keys.path.rawValue: "test/test",
+    GRPCHTTP2Keys.scheme.rawValue: "http",
+    GRPCHTTP2Keys.method.rawValue: "POST",
+    GRPCHTTP2Keys.contentType.rawValue: "application/grpc",
+  ]
+  fileprivate static let receivedWithInvalidTE: Self = [
+    GRPCHTTP2Keys.path.rawValue: "test/test",
+    GRPCHTTP2Keys.scheme.rawValue: "http",
+    GRPCHTTP2Keys.method.rawValue: "POST",
+    GRPCHTTP2Keys.contentType.rawValue: "application/grpc",
+    GRPCHTTP2Keys.te.rawValue: "invalidte",
+  ]
+  fileprivate static let receivedWithoutMethod: Self = [
+    GRPCHTTP2Keys.path.rawValue: "test/test",
+    GRPCHTTP2Keys.scheme.rawValue: "http",
+    GRPCHTTP2Keys.contentType.rawValue: "application/grpc",
+    GRPCHTTP2Keys.te.rawValue: "trailers",
+  ]
+  fileprivate static let receivedWithInvalidMethod: Self = [
+    GRPCHTTP2Keys.path.rawValue: "test/test",
+    GRPCHTTP2Keys.scheme.rawValue: "http",
+    GRPCHTTP2Keys.method.rawValue: "GET",
+    GRPCHTTP2Keys.contentType.rawValue: "application/grpc",
+    GRPCHTTP2Keys.te.rawValue: "trailers",
+  ]
+  fileprivate static let receivedWithoutScheme: Self = [
+    GRPCHTTP2Keys.path.rawValue: "test/test",
+    GRPCHTTP2Keys.method.rawValue: "POST",
+    GRPCHTTP2Keys.contentType.rawValue: "application/grpc",
+    GRPCHTTP2Keys.te.rawValue: "trailers",
+  ]
+  fileprivate static let receivedWithInvalidScheme: Self = [
+    GRPCHTTP2Keys.path.rawValue: "test/test",
+    GRPCHTTP2Keys.scheme.rawValue: "invalidscheme",
+    GRPCHTTP2Keys.method.rawValue: "POST",
+    GRPCHTTP2Keys.contentType.rawValue: "application/grpc",
+    GRPCHTTP2Keys.te.rawValue: "trailers",
+  ]
 
   // Server
   fileprivate static let serverInitialMetadata: Self = [
@@ -1502,6 +1541,136 @@ final class GRPCStreamServerStateMachineTests: XCTestCase {
     }
   }
 
+  func testReceiveMetadataWhenClientIdleAndServerIdle_MissingTE() throws {
+    var stateMachine = self.makeServerStateMachine(targetState: .clientIdleServerIdle)
+
+    let action = try stateMachine.receive(
+      metadata: .receivedWithoutTE,
+      endStream: false
+    )
+
+    self.assertRejectedRPC(action) { trailers in
+      XCTAssertEqual(
+        trailers,
+        [
+          ":status": "200",
+          "content-type": "application/grpc",
+          "grpc-status": "3",
+          "grpc-status-message":
+            "\"te\" header is expected to be present and have a value of \"trailers\".",
+        ]
+      )
+    }
+  }
+
+  func testReceiveMetadataWhenClientIdleAndServerIdle_InvalidTE() throws {
+    var stateMachine = self.makeServerStateMachine(targetState: .clientIdleServerIdle)
+
+    let action = try stateMachine.receive(
+      metadata: .receivedWithInvalidTE,
+      endStream: false
+    )
+
+    self.assertRejectedRPC(action) { trailers in
+      XCTAssertEqual(
+        trailers,
+        [
+          ":status": "200",
+          "content-type": "application/grpc",
+          "grpc-status": "3",
+          "grpc-status-message":
+            "\"te\" header is expected to be present and have a value of \"trailers\".",
+        ]
+      )
+    }
+  }
+
+  func testReceiveMetadataWhenClientIdleAndServerIdle_MissingMethod() throws {
+    var stateMachine = self.makeServerStateMachine(targetState: .clientIdleServerIdle)
+
+    let action = try stateMachine.receive(
+      metadata: .receivedWithoutMethod,
+      endStream: false
+    )
+
+    self.assertRejectedRPC(action) { trailers in
+      XCTAssertEqual(
+        trailers,
+        [
+          ":status": "200",
+          "content-type": "application/grpc",
+          "grpc-status": "3",
+          "grpc-status-message":
+            ":method header is expected to be present and have a value of \"POST\".",
+        ]
+      )
+    }
+  }
+
+  func testReceiveMetadataWhenClientIdleAndServerIdle_InvalidMethod() throws {
+    var stateMachine = self.makeServerStateMachine(targetState: .clientIdleServerIdle)
+
+    let action = try stateMachine.receive(
+      metadata: .receivedWithInvalidMethod,
+      endStream: false
+    )
+
+    self.assertRejectedRPC(action) { trailers in
+      XCTAssertEqual(
+        trailers,
+        [
+          ":status": "200",
+          "content-type": "application/grpc",
+          "grpc-status": "3",
+          "grpc-status-message":
+            ":method header is expected to be present and have a value of \"POST\".",
+        ]
+      )
+    }
+  }
+
+  func testReceiveMetadataWhenClientIdleAndServerIdle_MissingScheme() throws {
+    var stateMachine = self.makeServerStateMachine(targetState: .clientIdleServerIdle)
+
+    let action = try stateMachine.receive(
+      metadata: .receivedWithoutScheme,
+      endStream: false
+    )
+
+    self.assertRejectedRPC(action) { trailers in
+      XCTAssertEqual(
+        trailers,
+        [
+          ":status": "200",
+          "content-type": "application/grpc",
+          "grpc-status": "3",
+          "grpc-status-message": ":scheme header must be present and one of \"http\" or \"https\".",
+        ]
+      )
+    }
+  }
+
+  func testReceiveMetadataWhenClientIdleAndServerIdle_InvalidScheme() throws {
+    var stateMachine = self.makeServerStateMachine(targetState: .clientIdleServerIdle)
+
+    let action = try stateMachine.receive(
+      metadata: .receivedWithInvalidScheme,
+      endStream: false
+    )
+
+    self.assertRejectedRPC(action) { trailers in
+      XCTAssertEqual(
+        trailers,
+        [
+          ":status": "200",
+          "content-type": "application/grpc",
+          "grpc-status": "3",
+          "grpc-status-message": ":scheme header must be present and one of \"http\" or \"https\".",
+        ]
+      )
+    }
+  }
+
   func testReceiveMetadataWhenClientIdleAndServerIdle_ServerUnsupportedEncoding() throws {
     var stateMachine = self.makeServerStateMachine(targetState: .clientIdleServerIdle)