Browse Source

Merge pull request from GHSA-r6ww-5963-7r95

Better handle client sending GOAWAY
George Barnett 3 years ago
parent
commit
858f977f2a

BIN
FuzzTesting/FailCases/clusterfuzz-testcase-minimized-ServerFuzzer-release-4739158818553856


+ 16 - 1
Sources/GRPC/GRPCIdleHandler.swift

@@ -153,7 +153,19 @@ internal final class GRPCIdleHandler: ChannelInboundHandler {
         streamID: .rootStream,
         payload: .goAway(lastStreamID: streamID, errorCode: .noError, opaqueData: nil)
       )
-      self.context?.writeAndFlush(self.wrapOutboundOut(goAwayFrame), promise: nil)
+
+      self.context?.write(self.wrapOutboundOut(goAwayFrame), promise: nil)
+
+      // We emit a ping after some GOAWAY frames.
+      if operations.shouldPingAfterGoAway {
+        let pingFrame = HTTP2Frame(
+          streamID: .rootStream,
+          payload: .ping(self.pingHandler.pingDataGoAway, ack: false)
+        )
+        self.context?.write(self.wrapOutboundOut(pingFrame), promise: nil)
+      }
+
+      self.context?.flush()
     }
 
     // Close the channel, if necessary.
@@ -181,6 +193,9 @@ internal final class GRPCIdleHandler: ChannelInboundHandler {
     case let .reply(framePayload):
       let frame = HTTP2Frame(streamID: .rootStream, payload: framePayload)
       self.context?.writeAndFlush(self.wrapOutboundOut(frame), promise: nil)
+
+    case .ratchetDownLastSeenStreamID:
+      self.perform(operations: self.stateMachine.ratchetDownGoAwayStreamID())
     }
   }
 

+ 39 - 7
Sources/GRPC/GRPCIdleHandlerStateMachine.swift

@@ -189,10 +189,17 @@ struct GRPCIdleHandlerStateMachine {
     /// Whether the channel should be closed.
     private(set) var shouldCloseChannel: Bool
 
+    /// Whether a ping should be sent after a GOAWAY frame.
+    private(set) var shouldPingAfterGoAway: Bool
+
     fileprivate static let none = Operations()
 
-    fileprivate mutating func sendGoAwayFrame(lastPeerInitiatedStreamID streamID: HTTP2StreamID) {
+    fileprivate mutating func sendGoAwayFrame(
+      lastPeerInitiatedStreamID streamID: HTTP2StreamID,
+      followWithPing: Bool = false
+    ) {
       self.sendGoAwayWithLastPeerInitiatedStreamID = streamID
+      self.shouldPingAfterGoAway = followWithPing
     }
 
     fileprivate mutating func cancelIdleTask(_ task: Scheduled<Void>) {
@@ -220,6 +227,7 @@ struct GRPCIdleHandlerStateMachine {
       self.idleTask = nil
       self.sendGoAwayWithLastPeerInitiatedStreamID = nil
       self.shouldCloseChannel = false
+      self.shouldPingAfterGoAway = false
     }
   }
 
@@ -267,12 +275,7 @@ struct GRPCIdleHandlerStateMachine {
       operations.cancelIdleTask(state.idleTask)
 
     case var .quiescing(state):
-      precondition(state.initiatedByUs)
-      precondition(state.role == .client)
-      // If we're a client and we initiated shutdown then it's possible for streams to be created in
-      // the quiescing state as there's a delay between stream channels (i.e. `HTTP2StreamChannel`)
-      // being created and us being notified about their creation (via a user event fired by
-      // the `HTTP2Handler`).
+      state.lastPeerInitiatedStreamID = streamID
       state.openStreams += 1
       self.state = .quiescing(state)
 
@@ -466,6 +469,18 @@ struct GRPCIdleHandlerStateMachine {
 
       if state.hasOpenStreams {
         operations.notifyConnectionManager(about: .quiescing)
+        switch state.role {
+        case .client:
+          // The server sent us a GOAWAY frame: we'll just stop opening new streams and will send
+          // our own GOAWAY frame later, before we close.
+          ()
+        case .server:
+          // Client sent us a GOAWAY frame; we'll let the streams drain and then close. We'll tell
+          // the client that we're going away and send them a ping. When we receive the pong we will
+          // send another GOAWAY frame with a lower stream ID. In this case, the pong acts as an ack
+          // for the GOAWAY.
+          operations.sendGoAwayFrame(lastPeerInitiatedStreamID: .maxID, followWithPing: true)
+        }
         self.state = .quiescing(.init(fromOperating: state, initiatedByUs: false))
       } else {
         // No open streams, we can close as well.
@@ -494,6 +509,23 @@ struct GRPCIdleHandlerStateMachine {
     return operations
   }
 
+  /// Produces the operations required to follow up an earlier GOAWAY frame with one carrying the
+  /// ID of the last stream initiated by the peer.
+  ///
+  /// This is called when a pong matching the ping sent after a GOAWAY frame is received. Pongs are
+  /// peer-controlled input, so a matching pong may arrive in any state, including before any
+  /// GOAWAY frame has been sent. States in which there is nothing to ratchet down are therefore
+  /// treated as a no-op rather than trapping.
+  ///
+  /// - Returns: Operations containing a GOAWAY frame to send if the connection is quiescing,
+  ///   otherwise no operations.
+  mutating func ratchetDownGoAwayStreamID() -> Operations {
+    var operations: Operations = .none
+
+    switch self.state {
+    case let .quiescing(state):
+      // Tell the peer the ID of the last stream it initiated which we know about.
+      operations.sendGoAwayFrame(lastPeerInitiatedStreamID: state.lastPeerInitiatedStreamID)
+
+    case .operating, .waitingToIdle:
+      // We can only ratchet down the stream ID if we're already quiescing. We can end up here if
+      // the peer sends an unsolicited pong carrying the post-GOAWAY ping data: ignore it, since a
+      // remote peer must not be able to crash the process.
+      ()
+
+    case .closing, .closed:
+      // The connection is already going away; there is nothing left to ratchet down.
+      ()
+    }
+
+    return operations
+  }
+
   mutating func receiveSettings(_ settings: HTTP2Settings) -> Operations {
     // Log the change in settings.
     self.logger.debug(

+ 21 - 9
Sources/GRPC/GRPCKeepaliveHandlers.swift

@@ -17,8 +17,11 @@ import NIOCore
 import NIOHTTP2
 
 struct PingHandler {
-  /// Code for ping
-  private let pingCode: UInt64
+  /// Opaque ping data used for keep-alive pings.
+  private let pingData: HTTP2PingData
+
+  /// Opaque ping data used for a ping sent after a GOAWAY frame.
+  internal let pingDataGoAway: HTTP2PingData
 
   /// The amount of time to wait before sending a keepalive ping.
   private let interval: TimeAmount
@@ -90,6 +93,7 @@ struct PingHandler {
     case schedulePing(delay: TimeAmount, timeout: TimeAmount)
     case cancelScheduledTimeout
     case reply(HTTP2Frame.FramePayload)
+    case ratchetDownLastSeenStreamID
   }
 
   init(
@@ -102,7 +106,8 @@ struct PingHandler {
     minimumReceivedPingIntervalWithoutData: TimeAmount? = nil,
     maximumPingStrikes: UInt? = nil
   ) {
-    self.pingCode = pingCode
+    self.pingData = HTTP2PingData(withInteger: pingCode)
+    self.pingDataGoAway = HTTP2PingData(withInteger: ~pingCode)
     self.interval = interval
     self.timeout = timeout
     self.permitWithoutCalls = permitWithoutCalls
@@ -137,8 +142,12 @@ struct PingHandler {
   }
 
   private func handlePong(_ pingData: HTTP2PingData) -> Action {
-    if pingData.integer == self.pingCode {
+    if pingData == self.pingData {
       return .cancelScheduledTimeout
+    } else if pingData == self.pingDataGoAway {
+      // We received a pong for a ping we sent to trail a GOAWAY frame: this means we can now
+      // send another GOAWAY frame with a (possibly) lower stream ID.
+      return .ratchetDownLastSeenStreamID
     } else {
       return .none
     }
@@ -161,14 +170,14 @@ struct PingHandler {
         // This is a valid ping, reset our strike count and reply with a pong.
         self.pingStrikes = 0
         self.lastReceivedPingDate = self.now()
-        return .reply(self.generatePingFrame(code: pingData.integer, ack: true))
+        return .reply(self.generatePingFrame(data: pingData, ack: true))
       }
     } else {
       // We don't support ping strikes. We'll just reply with a pong.
       //
       // Note: we don't need to update `pingStrikes` or `lastReceivedPingDate` as we don't
       // support ping strikes.
-      return .reply(self.generatePingFrame(code: pingData.integer, ack: true))
+      return .reply(self.generatePingFrame(data: pingData, ack: true))
     }
   }
 
@@ -176,17 +185,20 @@ struct PingHandler {
     if self.shouldBlockPing {
       return .none
     } else {
-      return .reply(self.generatePingFrame(code: self.pingCode, ack: false))
+      return .reply(self.generatePingFrame(data: self.pingData, ack: false))
     }
   }
 
-  private mutating func generatePingFrame(code: UInt64, ack: Bool) -> HTTP2Frame.FramePayload {
+  private mutating func generatePingFrame(
+    data: HTTP2PingData,
+    ack: Bool
+  ) -> HTTP2Frame.FramePayload {
     if self.activeStreams == 0 {
       self.sentPingsWithoutData += 1
     }
 
     self.lastSentPingDate = self.now()
-    return HTTP2Frame.FramePayload.ping(HTTP2PingData(withInteger: code), ack: ack)
+    return HTTP2Frame.FramePayload.ping(data, ack: ack)
   }
 
   /// Returns true if, on receipt of a ping, the ping should be regarded as a ping strike.

+ 50 - 0
Tests/GRPCTests/GRPCIdleHandlerStateMachineTests.swift

@@ -24,6 +24,10 @@ class GRPCIdleHandlerStateMachineTests: GRPCTestCase {
     return GRPCIdleHandlerStateMachine(role: .client, logger: self.clientLogger)
   }
 
+  /// Makes an idle handler state machine with the server role which logs to the server logger.
+  private func makeServerStateMachine() -> GRPCIdleHandlerStateMachine {
+    return .init(role: .server, logger: self.serverLogger)
+  }
+
   private func makeNoOpScheduled() -> Scheduled<Void> {
     let loop = EmbeddedEventLoop()
     return loop.scheduleTask(deadline: .distantFuture) { return () }
@@ -469,6 +473,43 @@ class GRPCIdleHandlerStateMachineTests: GRPCTestCase {
     // The peer initiated shutdown by sending GOAWAY, we'll idle.
     op6.assertConnectionManager(.idle)
   }
+
+  /// A server receiving GOAWAY from the client while a stream is open should answer with a GOAWAY
+  /// for the maximum stream ID followed by a ping, tolerate a further stream being created, and
+  /// later ratchet the GOAWAY stream ID down to the last stream it saw.
+  func testClientSendsGoAwayAndOpensStream() {
+    var server = self.makeServerStateMachine()
+
+    // Receiving the initial settings marks the connection as ready and asks for an idle timeout.
+    let onSettings = server.receiveSettings([])
+    onSettings.assertConnectionManager(.ready)
+    onSettings.assertScheduleIdleTimeout()
+
+    // Hand the scheduled idle timeout task to the state machine.
+    let onIdleTaskScheduled = server.scheduledIdleTimeoutTask(self.makeNoOpScheduled())
+    onIdleTaskScheduled.assertDoNothing()
+
+    // Opening the first stream cancels the idle timeout.
+    let onFirstStreamCreated = server.streamCreated(withID: 1)
+    onFirstStreamCreated.assertCancelIdleTimeout()
+
+    // The client sends us a GOAWAY frame: we reply with GOAWAY for the max stream ID and a ping.
+    let onGoAway = server.receiveGoAway()
+    onGoAway.assertGoAway(streamID: .maxID)
+    onGoAway.assertShouldPingAfterGoAway()
+
+    // A second stream is created. This is fine, the client hasn't ack'd the ping yet.
+    let onSecondStreamCreated = server.streamCreated(withID: 7)
+    onSecondStreamCreated.assertDoNothing()
+
+    // The pong is received elsewhere (by the ping handler), which then tells us to ratchet down
+    // the GOAWAY stream ID. No further ping should follow this GOAWAY.
+    let onRatchetDown = server.ratchetDownGoAwayStreamID()
+    onRatchetDown.assertGoAway(streamID: 7)
+    onRatchetDown.assertShouldNotPingAfterGoAway()
+
+    // Closing the newer stream leaves the first one open, so there is nothing to do yet.
+    let onSecondStreamClosed = server.streamClosed(withID: 7)
+    onSecondStreamClosed.assertDoNothing()
+
+    // Closing the last open stream should close the connection.
+    let onFirstStreamClosed = server.streamClosed(withID: 1)
+    onFirstStreamClosed.assertShouldClose()
+  }
 }
 
 extension GRPCIdleHandlerStateMachine.Operations {
@@ -477,6 +518,7 @@ extension GRPCIdleHandlerStateMachine.Operations {
     XCTAssertNil(self.idleTask)
     XCTAssertNil(self.sendGoAwayWithLastPeerInitiatedStreamID)
     XCTAssertFalse(self.shouldCloseChannel)
+    XCTAssertFalse(self.shouldPingAfterGoAway)
   }
 
   func assertGoAway(streamID: HTTP2StreamID) {
@@ -524,4 +566,12 @@ extension GRPCIdleHandlerStateMachine.Operations {
   func assertShouldNotClose() {
     XCTAssertFalse(self.shouldCloseChannel)
   }
+
+  /// Asserts that these operations request a ping to be sent after the GOAWAY frame.
+  func assertShouldPingAfterGoAway() {
+    XCTAssertTrue(self.shouldPingAfterGoAway)
+  }
+
+  /// Asserts that these operations do not request a ping to be sent after the GOAWAY frame.
+  func assertShouldNotPingAfterGoAway() {
+    let shouldPing = self.shouldPingAfterGoAway
+    XCTAssertFalse(shouldPing)
+  }
 }

+ 8 - 0
Tests/GRPCTests/GRPCPingHandlerTests.swift

@@ -347,6 +347,12 @@ class GRPCPingHandlerTests: GRPCTestCase {
     )
   }
 
+  /// A pong carrying the post-GOAWAY ping data should ask for the GOAWAY stream ID to be
+  /// ratcheted down.
+  func testPongWithGoAwayPingData() {
+    self.setupPingHandler()
+
+    let goAwayPingData = self.pingHandler.pingDataGoAway
+    let action = self.pingHandler.read(pingData: goAwayPingData, ack: true)
+
+    XCTAssertEqual(action, .ratchetDownLastSeenStreamID)
+  }
+
   private func setupPingHandler(
     pingCode: UInt64 = 1,
     interval: TimeAmount = .seconds(15),
@@ -379,6 +385,8 @@ extension PingHandler.Action: Equatable {
       return lhsDelay == rhsDelay && lhsTimeout == rhsTimeout
     case (.cancelScheduledTimeout, .cancelScheduledTimeout):
       return true
+    case (.ratchetDownLastSeenStreamID, .ratchetDownLastSeenStreamID):
+      return true
     case let (.reply(lhsPayload), .reply(rhsPayload)):
       switch (lhsPayload, rhsPayload) {
       case (let .ping(lhsData, ack: lhsAck), let .ping(rhsData, ack: rhsAck)):

+ 5 - 0
Tests/GRPCTests/ServerFuzzingRegressionTests.swift

@@ -83,4 +83,9 @@ final class ServerFuzzingRegressionTests: GRPCTestCase {
     let name = "clusterfuzz-testcase-minimized-ServerFuzzer-release-5285159577452544"
     XCTAssertNoThrow(try self.runTest(withInputNamed: name))
   }
+
+  /// Regression test for fuzzer test case 4739158818553856: running the input must not throw.
+  func testFuzzCase_release_4739158818553856() {
+    XCTAssertNoThrow(
+      try self.runTest(
+        withInputNamed: "clusterfuzz-testcase-minimized-ServerFuzzer-release-4739158818553856"
+      )
+    )
+  }
 }