Просмотр исходного кода

Better handle client sending GOAWAY

Motivation:

The code paths for handling quiescing made an incorrect assumption about
who may open streams while in the quiescing state. This was mostly fine
but fell apart when the client sends a GOAWAY frame to the server and
then opens a stream. This results in the server crashing.

Modifications:

- Remove the unnecessary precondition
- Send a GOAWAY frame to the client if the client sends a GOAWAY frame
  to the server and the server had not previously indicated that it was
  going away. This is followed by a ping whose pong is used as an ack
  that the client has received the GOAWAY and to then ratchet down the
  stream ID in the GOAWAY frame.
- Tests

Result:

Server better handles the client sending it a GOAWAY.
George Barnett 3 лет назад
Родитель
Сommit
e09cf661a5

BIN
FuzzTesting/FailCases/clusterfuzz-testcase-minimized-ServerFuzzer-release-4739158818553856


+ 16 - 1
Sources/GRPC/GRPCIdleHandler.swift

@@ -153,7 +153,19 @@ internal final class GRPCIdleHandler: ChannelInboundHandler {
         streamID: .rootStream,
         payload: .goAway(lastStreamID: streamID, errorCode: .noError, opaqueData: nil)
       )
-      self.context?.writeAndFlush(self.wrapOutboundOut(goAwayFrame), promise: nil)
+
+      self.context?.write(self.wrapOutboundOut(goAwayFrame), promise: nil)
+
+      // We emit a ping after some GOAWAY frames.
+      if operations.shouldPingAfterGoAway {
+        let pingFrame = HTTP2Frame(
+          streamID: .rootStream,
+          payload: .ping(self.pingHandler.pingDataGoAway, ack: false)
+        )
+        self.context?.write(self.wrapOutboundOut(pingFrame), promise: nil)
+      }
+
+      self.context?.flush()
     }
 
     // Close the channel, if necessary.
@@ -181,6 +193,9 @@ internal final class GRPCIdleHandler: ChannelInboundHandler {
     case let .reply(framePayload):
       let frame = HTTP2Frame(streamID: .rootStream, payload: framePayload)
       self.context?.writeAndFlush(self.wrapOutboundOut(frame), promise: nil)
+
+    case .ratchetDownLastSeenStreamID:
+      self.perform(operations: self.stateMachine.ratchetDownGoAwayStreamID())
     }
   }
 

+ 39 - 7
Sources/GRPC/GRPCIdleHandlerStateMachine.swift

@@ -189,10 +189,17 @@ struct GRPCIdleHandlerStateMachine {
     /// Whether the channel should be closed.
     private(set) var shouldCloseChannel: Bool
 
+    /// Whether a ping should be sent after a GOAWAY frame.
+    private(set) var shouldPingAfterGoAway: Bool
+
     fileprivate static let none = Operations()
 
-    fileprivate mutating func sendGoAwayFrame(lastPeerInitiatedStreamID streamID: HTTP2StreamID) {
+    fileprivate mutating func sendGoAwayFrame(
+      lastPeerInitiatedStreamID streamID: HTTP2StreamID,
+      followWithPing: Bool = false
+    ) {
       self.sendGoAwayWithLastPeerInitiatedStreamID = streamID
+      self.shouldPingAfterGoAway = followWithPing
     }
 
     fileprivate mutating func cancelIdleTask(_ task: Scheduled<Void>) {
@@ -220,6 +227,7 @@ struct GRPCIdleHandlerStateMachine {
       self.idleTask = nil
       self.sendGoAwayWithLastPeerInitiatedStreamID = nil
       self.shouldCloseChannel = false
+      self.shouldPingAfterGoAway = false
     }
   }
 
@@ -267,12 +275,7 @@ struct GRPCIdleHandlerStateMachine {
       operations.cancelIdleTask(state.idleTask)
 
     case var .quiescing(state):
-      precondition(state.initiatedByUs)
-      precondition(state.role == .client)
-      // If we're a client and we initiated shutdown then it's possible for streams to be created in
-      // the quiescing state as there's a delay between stream channels (i.e. `HTTP2StreamChannel`)
-      // being created and us being notified about their creation (via a user event fired by
-      // the `HTTP2Handler`).
+      state.lastPeerInitiatedStreamID = streamID
       state.openStreams += 1
       self.state = .quiescing(state)
 
@@ -466,6 +469,18 @@ struct GRPCIdleHandlerStateMachine {
 
       if state.hasOpenStreams {
         operations.notifyConnectionManager(about: .quiescing)
+        switch state.role {
+        case .client:
+          // The server sent us a GOAWAY we'll just stop opening new streams and will send a GOAWAY
+          // frame before we close later.
+          ()
+        case .server:
+          // Client sent us a GOAWAY frame; we'll let the streams drain and then close. We'll tell
+          // the client that we're going away and send them a ping. When we receive the pong we will
+          // send another GOAWAY frame with a lower stream ID. In this case, the pong acts as an ack
+          // for the GOAWAY.
+          operations.sendGoAwayFrame(lastPeerInitiatedStreamID: .maxID, followWithPing: true)
+        }
         self.state = .quiescing(.init(fromOperating: state, initiatedByUs: false))
       } else {
         // No open streams, we can close as well.
@@ -494,6 +509,23 @@ struct GRPCIdleHandlerStateMachine {
     return operations
   }
 
+  mutating func ratchetDownGoAwayStreamID() -> Operations {
+    var operations: Operations = .none
+
+    switch self.state {
+    case let .quiescing(state):
+      let streamID = state.lastPeerInitiatedStreamID
+      operations.sendGoAwayFrame(lastPeerInitiatedStreamID: streamID)
+    case .operating, .waitingToIdle:
+      // We can only ratchet down the stream ID if we're already quiescing.
+      preconditionFailure()
+    case .closing, .closed:
+      ()
+    }
+
+    return operations
+  }
+
   mutating func receiveSettings(_ settings: HTTP2Settings) -> Operations {
     // Log the change in settings.
     self.logger.debug(

+ 21 - 9
Sources/GRPC/GRPCKeepaliveHandlers.swift

@@ -17,8 +17,11 @@ import NIOCore
 import NIOHTTP2
 
 struct PingHandler {
-  /// Code for ping
-  private let pingCode: UInt64
+  /// Opaque ping data used for keep-alive pings.
+  private let pingData: HTTP2PingData
+
+  /// Opaque ping data used for a ping sent after a GOAWAY frame.
+  internal let pingDataGoAway: HTTP2PingData
 
   /// The amount of time to wait before sending a keepalive ping.
   private let interval: TimeAmount
@@ -90,6 +93,7 @@ struct PingHandler {
     case schedulePing(delay: TimeAmount, timeout: TimeAmount)
     case cancelScheduledTimeout
     case reply(HTTP2Frame.FramePayload)
+    case ratchetDownLastSeenStreamID
   }
 
   init(
@@ -102,7 +106,8 @@ struct PingHandler {
     minimumReceivedPingIntervalWithoutData: TimeAmount? = nil,
     maximumPingStrikes: UInt? = nil
   ) {
-    self.pingCode = pingCode
+    self.pingData = HTTP2PingData(withInteger: pingCode)
+    self.pingDataGoAway = HTTP2PingData(withInteger: ~pingCode)
     self.interval = interval
     self.timeout = timeout
     self.permitWithoutCalls = permitWithoutCalls
@@ -137,8 +142,12 @@ struct PingHandler {
   }
 
   private func handlePong(_ pingData: HTTP2PingData) -> Action {
-    if pingData.integer == self.pingCode {
+    if pingData == self.pingData {
       return .cancelScheduledTimeout
+    } else if pingData == self.pingDataGoAway {
+      // We received a pong for a ping we sent to trail a GOAWAY frame: this means we can now
+      // send another GOAWAY frame with a (possibly) lower stream ID.
+      return .ratchetDownLastSeenStreamID
     } else {
       return .none
     }
@@ -161,14 +170,14 @@ struct PingHandler {
         // This is a valid ping, reset our strike count and reply with a pong.
         self.pingStrikes = 0
         self.lastReceivedPingDate = self.now()
-        return .reply(self.generatePingFrame(code: pingData.integer, ack: true))
+        return .reply(self.generatePingFrame(data: pingData, ack: true))
       }
     } else {
       // We don't support ping strikes. We'll just reply with a pong.
       //
       // Note: we don't need to update `pingStrikes` or `lastReceivedPingDate` as we don't
       // support ping strikes.
-      return .reply(self.generatePingFrame(code: pingData.integer, ack: true))
+      return .reply(self.generatePingFrame(data: pingData, ack: true))
     }
   }
 
@@ -176,17 +185,20 @@ struct PingHandler {
     if self.shouldBlockPing {
       return .none
     } else {
-      return .reply(self.generatePingFrame(code: self.pingCode, ack: false))
+      return .reply(self.generatePingFrame(data: self.pingData, ack: false))
     }
   }
 
-  private mutating func generatePingFrame(code: UInt64, ack: Bool) -> HTTP2Frame.FramePayload {
+  private mutating func generatePingFrame(
+    data: HTTP2PingData,
+    ack: Bool
+  ) -> HTTP2Frame.FramePayload {
     if self.activeStreams == 0 {
       self.sentPingsWithoutData += 1
     }
 
     self.lastSentPingDate = self.now()
-    return HTTP2Frame.FramePayload.ping(HTTP2PingData(withInteger: code), ack: ack)
+    return HTTP2Frame.FramePayload.ping(data, ack: ack)
   }
 
   /// Returns true if, on receipt of a ping, the ping should be regarded as a ping strike.

+ 50 - 0
Tests/GRPCTests/GRPCIdleHandlerStateMachineTests.swift

@@ -24,6 +24,10 @@ class GRPCIdleHandlerStateMachineTests: GRPCTestCase {
     return GRPCIdleHandlerStateMachine(role: .client, logger: self.clientLogger)
   }
 
+  private func makeServerStateMachine() -> GRPCIdleHandlerStateMachine {
+    return GRPCIdleHandlerStateMachine(role: .server, logger: self.serverLogger)
+  }
+
   private func makeNoOpScheduled() -> Scheduled<Void> {
     let loop = EmbeddedEventLoop()
     return loop.scheduleTask(deadline: .distantFuture) { return () }
@@ -469,6 +473,43 @@ class GRPCIdleHandlerStateMachineTests: GRPCTestCase {
     // The peer initiated shutdown by sending GOAWAY, we'll idle.
     op6.assertConnectionManager(.idle)
   }
+
+  func testClientSendsGoAwayAndOpensStream() {
+    var stateMachine = self.makeServerStateMachine()
+
+    let op1 = stateMachine.receiveSettings([])
+    op1.assertConnectionManager(.ready)
+    op1.assertScheduleIdleTimeout()
+
+    // Schedule the idle timeout.
+    let op2 = stateMachine.scheduledIdleTimeoutTask(self.makeNoOpScheduled())
+    op2.assertDoNothing()
+
+    // Create a stream to cancel the task.
+    let op3 = stateMachine.streamCreated(withID: 1)
+    op3.assertCancelIdleTimeout()
+
+    // Receive a GOAWAY frame from the client.
+    let op4 = stateMachine.receiveGoAway()
+    op4.assertGoAway(streamID: .maxID)
+    op4.assertShouldPingAfterGoAway()
+
+    // Create another stream. This is fine, the client hasn't ack'd the ping yet.
+    let op5 = stateMachine.streamCreated(withID: 7)
+    op5.assertDoNothing()
+
+    // Receiving the ping is handled by a different state machine which will tell us to ratchet
+    // down the go away stream ID.
+    let op6 = stateMachine.ratchetDownGoAwayStreamID()
+    op6.assertGoAway(streamID: 7)
+    op6.assertShouldNotPingAfterGoAway()
+
+    let op7 = stateMachine.streamClosed(withID: 7)
+    op7.assertDoNothing()
+
+    let op8 = stateMachine.streamClosed(withID: 1)
+    op8.assertShouldClose()
+  }
 }
 
 extension GRPCIdleHandlerStateMachine.Operations {
@@ -477,6 +518,7 @@ extension GRPCIdleHandlerStateMachine.Operations {
     XCTAssertNil(self.idleTask)
     XCTAssertNil(self.sendGoAwayWithLastPeerInitiatedStreamID)
     XCTAssertFalse(self.shouldCloseChannel)
+    XCTAssertFalse(self.shouldPingAfterGoAway)
   }
 
   func assertGoAway(streamID: HTTP2StreamID) {
@@ -524,4 +566,12 @@ extension GRPCIdleHandlerStateMachine.Operations {
   func assertShouldNotClose() {
     XCTAssertFalse(self.shouldCloseChannel)
   }
+
+  func assertShouldPingAfterGoAway() {
+    XCTAssert(self.shouldPingAfterGoAway)
+  }
+
+  func assertShouldNotPingAfterGoAway() {
+    XCTAssertFalse(self.shouldPingAfterGoAway)
+  }
 }

+ 8 - 0
Tests/GRPCTests/GRPCPingHandlerTests.swift

@@ -347,6 +347,12 @@ class GRPCPingHandlerTests: GRPCTestCase {
     )
   }
 
+  func testPongWithGoAwayPingData() {
+    self.setupPingHandler()
+    let response = self.pingHandler.read(pingData: self.pingHandler.pingDataGoAway, ack: true)
+    XCTAssertEqual(response, .ratchetDownLastSeenStreamID)
+  }
+
   private func setupPingHandler(
     pingCode: UInt64 = 1,
     interval: TimeAmount = .seconds(15),
@@ -379,6 +385,8 @@ extension PingHandler.Action: Equatable {
       return lhsDelay == rhsDelay && lhsTimeout == rhsTimeout
     case (.cancelScheduledTimeout, .cancelScheduledTimeout):
       return true
+    case (.ratchetDownLastSeenStreamID, .ratchetDownLastSeenStreamID):
+      return true
     case let (.reply(lhsPayload), .reply(rhsPayload)):
       switch (lhsPayload, rhsPayload) {
       case (let .ping(lhsData, ack: lhsAck), let .ping(rhsData, ack: rhsAck)):

+ 5 - 0
Tests/GRPCTests/ServerFuzzingRegressionTests.swift

@@ -83,4 +83,9 @@ final class ServerFuzzingRegressionTests: GRPCTestCase {
     let name = "clusterfuzz-testcase-minimized-ServerFuzzer-release-5285159577452544"
     XCTAssertNoThrow(try self.runTest(withInputNamed: name))
   }
+
+  func testFuzzCase_release_4739158818553856() {
+    let name = "clusterfuzz-testcase-minimized-ServerFuzzer-release-4739158818553856"
+    XCTAssertNoThrow(try self.runTest(withInputNamed: name))
+  }
 }