Pārlūkot izejas kodu

Better handle client sending GOAWAY

Motivation:

The code paths for handling quiescing made an incorrect assumption about
who may open streams while in the quiescing state. This was mostly fine
but falls apart when the client sends a GOAWAY frame to the server and
then opens a stream: the server hits a precondition failure and crashes.

Modifications:

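- Remove the incorrect preconditions (that shutdown was initiated by us
  and that we are a client) on a stream being created while quiescing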
- Send a GOAWAY frame to the client if the client sends a GOAWAY frame
  to the server and the server had not previously indicated that it was
  going away. The GOAWAY is followed by a ping; the pong acts as an ack
  that the client has received the GOAWAY, at which point the server
  sends another GOAWAY frame with the stream ID ratcheted down to the
  last stream opened by the client.
- Tests

Result:

Server better handles the client sending it a GOAWAY.
George Barnett 3 gadi atpakaļ
vecāks
revīzija
e09cf661a5

BIN
FuzzTesting/FailCases/clusterfuzz-testcase-minimized-ServerFuzzer-release-4739158818553856


+ 16 - 1
Sources/GRPC/GRPCIdleHandler.swift

@@ -153,7 +153,19 @@ internal final class GRPCIdleHandler: ChannelInboundHandler {
         streamID: .rootStream,
         streamID: .rootStream,
         payload: .goAway(lastStreamID: streamID, errorCode: .noError, opaqueData: nil)
         payload: .goAway(lastStreamID: streamID, errorCode: .noError, opaqueData: nil)
       )
       )
-      self.context?.writeAndFlush(self.wrapOutboundOut(goAwayFrame), promise: nil)
+
+      self.context?.write(self.wrapOutboundOut(goAwayFrame), promise: nil)
+
+      // We emit a ping after some GOAWAY frames.
+      if operations.shouldPingAfterGoAway {
+        let pingFrame = HTTP2Frame(
+          streamID: .rootStream,
+          payload: .ping(self.pingHandler.pingDataGoAway, ack: false)
+        )
+        self.context?.write(self.wrapOutboundOut(pingFrame), promise: nil)
+      }
+
+      self.context?.flush()
     }
     }
 
 
     // Close the channel, if necessary.
     // Close the channel, if necessary.
@@ -181,6 +193,9 @@ internal final class GRPCIdleHandler: ChannelInboundHandler {
     case let .reply(framePayload):
     case let .reply(framePayload):
       let frame = HTTP2Frame(streamID: .rootStream, payload: framePayload)
       let frame = HTTP2Frame(streamID: .rootStream, payload: framePayload)
       self.context?.writeAndFlush(self.wrapOutboundOut(frame), promise: nil)
       self.context?.writeAndFlush(self.wrapOutboundOut(frame), promise: nil)
+
+    case .ratchetDownLastSeenStreamID:
+      self.perform(operations: self.stateMachine.ratchetDownGoAwayStreamID())
     }
     }
   }
   }
 
 

+ 39 - 7
Sources/GRPC/GRPCIdleHandlerStateMachine.swift

@@ -189,10 +189,17 @@ struct GRPCIdleHandlerStateMachine {
     /// Whether the channel should be closed.
     /// Whether the channel should be closed.
     private(set) var shouldCloseChannel: Bool
     private(set) var shouldCloseChannel: Bool
 
 
+    /// Whether a ping should be sent after a GOAWAY frame.
+    private(set) var shouldPingAfterGoAway: Bool
+
     fileprivate static let none = Operations()
     fileprivate static let none = Operations()
 
 
-    fileprivate mutating func sendGoAwayFrame(lastPeerInitiatedStreamID streamID: HTTP2StreamID) {
+    fileprivate mutating func sendGoAwayFrame(
+      lastPeerInitiatedStreamID streamID: HTTP2StreamID,
+      followWithPing: Bool = false
+    ) {
       self.sendGoAwayWithLastPeerInitiatedStreamID = streamID
       self.sendGoAwayWithLastPeerInitiatedStreamID = streamID
+      self.shouldPingAfterGoAway = followWithPing
     }
     }
 
 
     fileprivate mutating func cancelIdleTask(_ task: Scheduled<Void>) {
     fileprivate mutating func cancelIdleTask(_ task: Scheduled<Void>) {
@@ -220,6 +227,7 @@ struct GRPCIdleHandlerStateMachine {
       self.idleTask = nil
       self.idleTask = nil
       self.sendGoAwayWithLastPeerInitiatedStreamID = nil
       self.sendGoAwayWithLastPeerInitiatedStreamID = nil
       self.shouldCloseChannel = false
       self.shouldCloseChannel = false
+      self.shouldPingAfterGoAway = false
     }
     }
   }
   }
 
 
@@ -267,12 +275,7 @@ struct GRPCIdleHandlerStateMachine {
       operations.cancelIdleTask(state.idleTask)
       operations.cancelIdleTask(state.idleTask)
 
 
     case var .quiescing(state):
     case var .quiescing(state):
-      precondition(state.initiatedByUs)
-      precondition(state.role == .client)
-      // If we're a client and we initiated shutdown then it's possible for streams to be created in
-      // the quiescing state as there's a delay between stream channels (i.e. `HTTP2StreamChannel`)
-      // being created and us being notified about their creation (via a user event fired by
-      // the `HTTP2Handler`).
+      state.lastPeerInitiatedStreamID = streamID
       state.openStreams += 1
       state.openStreams += 1
       self.state = .quiescing(state)
       self.state = .quiescing(state)
 
 
@@ -466,6 +469,18 @@ struct GRPCIdleHandlerStateMachine {
 
 
       if state.hasOpenStreams {
       if state.hasOpenStreams {
         operations.notifyConnectionManager(about: .quiescing)
         operations.notifyConnectionManager(about: .quiescing)
+        switch state.role {
+        case .client:
+          // The server sent us a GOAWAY; we'll just stop opening new streams and will send a
+          // GOAWAY frame before we close later.
+          ()
+        case .server:
+          // Client sent us a GOAWAY frame; we'll let the streams drain and then close. We'll tell
+          // the client that we're going away and send them a ping. When we receive the pong we will
+          // send another GOAWAY frame with a lower stream ID. In this case, the pong acts as an ack
+          // for the GOAWAY.
+          operations.sendGoAwayFrame(lastPeerInitiatedStreamID: .maxID, followWithPing: true)
+        }
         self.state = .quiescing(.init(fromOperating: state, initiatedByUs: false))
         self.state = .quiescing(.init(fromOperating: state, initiatedByUs: false))
       } else {
       } else {
         // No open streams, we can close as well.
         // No open streams, we can close as well.
@@ -494,6 +509,23 @@ struct GRPCIdleHandlerStateMachine {
     return operations
     return operations
   }
   }
 
 
+  mutating func ratchetDownGoAwayStreamID() -> Operations {
+    var operations: Operations = .none
+
+    switch self.state {
+    case let .quiescing(state):
+      let streamID = state.lastPeerInitiatedStreamID
+      operations.sendGoAwayFrame(lastPeerInitiatedStreamID: streamID)
+    case .operating, .waitingToIdle:
+      // We can only ratchet down the stream ID if we're already quiescing.
+      preconditionFailure()
+    case .closing, .closed:
+      ()
+    }
+
+    return operations
+  }
+
   mutating func receiveSettings(_ settings: HTTP2Settings) -> Operations {
   mutating func receiveSettings(_ settings: HTTP2Settings) -> Operations {
     // Log the change in settings.
     // Log the change in settings.
     self.logger.debug(
     self.logger.debug(

+ 21 - 9
Sources/GRPC/GRPCKeepaliveHandlers.swift

@@ -17,8 +17,11 @@ import NIOCore
 import NIOHTTP2
 import NIOHTTP2
 
 
 struct PingHandler {
 struct PingHandler {
-  /// Code for ping
-  private let pingCode: UInt64
+  /// Opaque ping data used for keep-alive pings.
+  private let pingData: HTTP2PingData
+
+  /// Opaque ping data used for a ping sent after a GOAWAY frame.
+  internal let pingDataGoAway: HTTP2PingData
 
 
   /// The amount of time to wait before sending a keepalive ping.
   /// The amount of time to wait before sending a keepalive ping.
   private let interval: TimeAmount
   private let interval: TimeAmount
@@ -90,6 +93,7 @@ struct PingHandler {
     case schedulePing(delay: TimeAmount, timeout: TimeAmount)
     case schedulePing(delay: TimeAmount, timeout: TimeAmount)
     case cancelScheduledTimeout
     case cancelScheduledTimeout
     case reply(HTTP2Frame.FramePayload)
     case reply(HTTP2Frame.FramePayload)
+    case ratchetDownLastSeenStreamID
   }
   }
 
 
   init(
   init(
@@ -102,7 +106,8 @@ struct PingHandler {
     minimumReceivedPingIntervalWithoutData: TimeAmount? = nil,
     minimumReceivedPingIntervalWithoutData: TimeAmount? = nil,
     maximumPingStrikes: UInt? = nil
     maximumPingStrikes: UInt? = nil
   ) {
   ) {
-    self.pingCode = pingCode
+    self.pingData = HTTP2PingData(withInteger: pingCode)
+    self.pingDataGoAway = HTTP2PingData(withInteger: ~pingCode)
     self.interval = interval
     self.interval = interval
     self.timeout = timeout
     self.timeout = timeout
     self.permitWithoutCalls = permitWithoutCalls
     self.permitWithoutCalls = permitWithoutCalls
@@ -137,8 +142,12 @@ struct PingHandler {
   }
   }
 
 
   private func handlePong(_ pingData: HTTP2PingData) -> Action {
   private func handlePong(_ pingData: HTTP2PingData) -> Action {
-    if pingData.integer == self.pingCode {
+    if pingData == self.pingData {
       return .cancelScheduledTimeout
       return .cancelScheduledTimeout
+    } else if pingData == self.pingDataGoAway {
+      // We received a pong for a ping we sent to trail a GOAWAY frame: this means we can now
+      // send another GOAWAY frame with a (possibly) lower stream ID.
+      return .ratchetDownLastSeenStreamID
     } else {
     } else {
       return .none
       return .none
     }
     }
@@ -161,14 +170,14 @@ struct PingHandler {
         // This is a valid ping, reset our strike count and reply with a pong.
         // This is a valid ping, reset our strike count and reply with a pong.
         self.pingStrikes = 0
         self.pingStrikes = 0
         self.lastReceivedPingDate = self.now()
         self.lastReceivedPingDate = self.now()
-        return .reply(self.generatePingFrame(code: pingData.integer, ack: true))
+        return .reply(self.generatePingFrame(data: pingData, ack: true))
       }
       }
     } else {
     } else {
       // We don't support ping strikes. We'll just reply with a pong.
       // We don't support ping strikes. We'll just reply with a pong.
       //
       //
       // Note: we don't need to update `pingStrikes` or `lastReceivedPingDate` as we don't
       // Note: we don't need to update `pingStrikes` or `lastReceivedPingDate` as we don't
       // support ping strikes.
       // support ping strikes.
-      return .reply(self.generatePingFrame(code: pingData.integer, ack: true))
+      return .reply(self.generatePingFrame(data: pingData, ack: true))
     }
     }
   }
   }
 
 
@@ -176,17 +185,20 @@ struct PingHandler {
     if self.shouldBlockPing {
     if self.shouldBlockPing {
       return .none
       return .none
     } else {
     } else {
-      return .reply(self.generatePingFrame(code: self.pingCode, ack: false))
+      return .reply(self.generatePingFrame(data: self.pingData, ack: false))
     }
     }
   }
   }
 
 
-  private mutating func generatePingFrame(code: UInt64, ack: Bool) -> HTTP2Frame.FramePayload {
+  private mutating func generatePingFrame(
+    data: HTTP2PingData,
+    ack: Bool
+  ) -> HTTP2Frame.FramePayload {
     if self.activeStreams == 0 {
     if self.activeStreams == 0 {
       self.sentPingsWithoutData += 1
       self.sentPingsWithoutData += 1
     }
     }
 
 
     self.lastSentPingDate = self.now()
     self.lastSentPingDate = self.now()
-    return HTTP2Frame.FramePayload.ping(HTTP2PingData(withInteger: code), ack: ack)
+    return HTTP2Frame.FramePayload.ping(data, ack: ack)
   }
   }
 
 
   /// Returns true if, on receipt of a ping, the ping should be regarded as a ping strike.
   /// Returns true if, on receipt of a ping, the ping should be regarded as a ping strike.

+ 50 - 0
Tests/GRPCTests/GRPCIdleHandlerStateMachineTests.swift

@@ -24,6 +24,10 @@ class GRPCIdleHandlerStateMachineTests: GRPCTestCase {
     return GRPCIdleHandlerStateMachine(role: .client, logger: self.clientLogger)
     return GRPCIdleHandlerStateMachine(role: .client, logger: self.clientLogger)
   }
   }
 
 
+  private func makeServerStateMachine() -> GRPCIdleHandlerStateMachine {
+    return GRPCIdleHandlerStateMachine(role: .server, logger: self.serverLogger)
+  }
+
   private func makeNoOpScheduled() -> Scheduled<Void> {
   private func makeNoOpScheduled() -> Scheduled<Void> {
     let loop = EmbeddedEventLoop()
     let loop = EmbeddedEventLoop()
     return loop.scheduleTask(deadline: .distantFuture) { return () }
     return loop.scheduleTask(deadline: .distantFuture) { return () }
@@ -469,6 +473,43 @@ class GRPCIdleHandlerStateMachineTests: GRPCTestCase {
     // The peer initiated shutdown by sending GOAWAY, we'll idle.
     // The peer initiated shutdown by sending GOAWAY, we'll idle.
     op6.assertConnectionManager(.idle)
     op6.assertConnectionManager(.idle)
   }
   }
+
+  func testClientSendsGoAwayAndOpensStream() {
+    var stateMachine = self.makeServerStateMachine()
+
+    let op1 = stateMachine.receiveSettings([])
+    op1.assertConnectionManager(.ready)
+    op1.assertScheduleIdleTimeout()
+
+    // Schedule the idle timeout.
+    let op2 = stateMachine.scheduledIdleTimeoutTask(self.makeNoOpScheduled())
+    op2.assertDoNothing()
+
+    // Create a stream to cancel the task.
+    let op3 = stateMachine.streamCreated(withID: 1)
+    op3.assertCancelIdleTimeout()
+
+    // Receive a GOAWAY frame from the client.
+    let op4 = stateMachine.receiveGoAway()
+    op4.assertGoAway(streamID: .maxID)
+    op4.assertShouldPingAfterGoAway()
+
+    // Create another stream. This is fine, the client hasn't ack'd the ping yet.
+    let op5 = stateMachine.streamCreated(withID: 7)
+    op5.assertDoNothing()
+
+    // Receiving the pong for the post-GOAWAY ping is handled by a different state machine which
+    // will tell us to ratchet down the go away stream ID.
+    let op6 = stateMachine.ratchetDownGoAwayStreamID()
+    op6.assertGoAway(streamID: 7)
+    op6.assertShouldNotPingAfterGoAway()
+
+    let op7 = stateMachine.streamClosed(withID: 7)
+    op7.assertDoNothing()
+
+    let op8 = stateMachine.streamClosed(withID: 1)
+    op8.assertShouldClose()
+  }
 }
 }
 
 
 extension GRPCIdleHandlerStateMachine.Operations {
 extension GRPCIdleHandlerStateMachine.Operations {
@@ -477,6 +518,7 @@ extension GRPCIdleHandlerStateMachine.Operations {
     XCTAssertNil(self.idleTask)
     XCTAssertNil(self.idleTask)
     XCTAssertNil(self.sendGoAwayWithLastPeerInitiatedStreamID)
     XCTAssertNil(self.sendGoAwayWithLastPeerInitiatedStreamID)
     XCTAssertFalse(self.shouldCloseChannel)
     XCTAssertFalse(self.shouldCloseChannel)
+    XCTAssertFalse(self.shouldPingAfterGoAway)
   }
   }
 
 
   func assertGoAway(streamID: HTTP2StreamID) {
   func assertGoAway(streamID: HTTP2StreamID) {
@@ -524,4 +566,12 @@ extension GRPCIdleHandlerStateMachine.Operations {
   func assertShouldNotClose() {
   func assertShouldNotClose() {
     XCTAssertFalse(self.shouldCloseChannel)
     XCTAssertFalse(self.shouldCloseChannel)
   }
   }
+
+  func assertShouldPingAfterGoAway() {
+    XCTAssert(self.shouldPingAfterGoAway)
+  }
+
+  func assertShouldNotPingAfterGoAway() {
+    XCTAssertFalse(self.shouldPingAfterGoAway)
+  }
 }
 }

+ 8 - 0
Tests/GRPCTests/GRPCPingHandlerTests.swift

@@ -347,6 +347,12 @@ class GRPCPingHandlerTests: GRPCTestCase {
     )
     )
   }
   }
 
 
+  func testPongWithGoAwayPingData() {
+    self.setupPingHandler()
+    let response = self.pingHandler.read(pingData: self.pingHandler.pingDataGoAway, ack: true)
+    XCTAssertEqual(response, .ratchetDownLastSeenStreamID)
+  }
+
   private func setupPingHandler(
   private func setupPingHandler(
     pingCode: UInt64 = 1,
     pingCode: UInt64 = 1,
     interval: TimeAmount = .seconds(15),
     interval: TimeAmount = .seconds(15),
@@ -379,6 +385,8 @@ extension PingHandler.Action: Equatable {
       return lhsDelay == rhsDelay && lhsTimeout == rhsTimeout
       return lhsDelay == rhsDelay && lhsTimeout == rhsTimeout
     case (.cancelScheduledTimeout, .cancelScheduledTimeout):
     case (.cancelScheduledTimeout, .cancelScheduledTimeout):
       return true
       return true
+    case (.ratchetDownLastSeenStreamID, .ratchetDownLastSeenStreamID):
+      return true
     case let (.reply(lhsPayload), .reply(rhsPayload)):
     case let (.reply(lhsPayload), .reply(rhsPayload)):
       switch (lhsPayload, rhsPayload) {
       switch (lhsPayload, rhsPayload) {
       case (let .ping(lhsData, ack: lhsAck), let .ping(rhsData, ack: rhsAck)):
       case (let .ping(lhsData, ack: lhsAck), let .ping(rhsData, ack: rhsAck)):

+ 5 - 0
Tests/GRPCTests/ServerFuzzingRegressionTests.swift

@@ -83,4 +83,9 @@ final class ServerFuzzingRegressionTests: GRPCTestCase {
     let name = "clusterfuzz-testcase-minimized-ServerFuzzer-release-5285159577452544"
     let name = "clusterfuzz-testcase-minimized-ServerFuzzer-release-5285159577452544"
     XCTAssertNoThrow(try self.runTest(withInputNamed: name))
     XCTAssertNoThrow(try self.runTest(withInputNamed: name))
   }
   }
+
+  func testFuzzCase_release_4739158818553856() {
+    let name = "clusterfuzz-testcase-minimized-ServerFuzzer-release-4739158818553856"
+    XCTAssertNoThrow(try self.runTest(withInputNamed: name))
+  }
 }
 }