Browse Source

Add a handler for managing connections on the server (#1762)

Motivation:

Servers must manage connections created by clients. Part of this is
gracefully closing connections (by sending GOAWAY frames and ratcheting
down the last stream ID) in response to various conditions: the client
sending too many pings, the connection being idle too long, the
connection existing for longer than some configured limit, etc.

A previous change added a state machine which handles much of this
behaviour. This change adds a channel handler which builds on top of
that state machine.

Modifications:

- Add a channel handler for managing connections on the server.

Result:

We have a handler in place which can manage connections on the server.
George Barnett 2 years ago
parent
commit
941f500c3f

+ 4 - 1
Package.swift

@@ -306,7 +306,10 @@ extension Target {
   static let grpcHTTP2CoreTests: Target = .testTarget(
     name: "GRPCHTTP2CoreTests",
     dependencies: [
-      .grpcHTTP2Core
+      .grpcHTTP2Core,
+      .nioCore,
+      .nioHTTP2,
+      .nioEmbedded,
     ]
   )
   

+ 0 - 19
Sources/GRPCHTTP2Core/Server/Connection/ServerConnectionHandler.swift

@@ -1,19 +0,0 @@
-/*
- * Copyright 2024, gRPC Authors All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Temporary namespace. Will be replaced with a channel handler.
-enum ServerConnectionHandler {
-}

+ 4 - 5
Sources/GRPCHTTP2Core/Server/Connection/ServerConnectionHandler+StateMachine.swift → Sources/GRPCHTTP2Core/Server/Connection/ServerConnectionManagementHandler+StateMachine.swift

@@ -17,7 +17,7 @@
 import NIOCore
 import NIOHTTP2
 
-extension ServerConnectionHandler {
+extension ServerConnectionManagementHandler {
   /// Tracks the state of TCP connections at the server.
   ///
   /// The state machine manages the state for the graceful shutdown procedure as well as policing
@@ -248,7 +248,7 @@ extension ServerConnectionHandler {
   }
 }
 
-extension ServerConnectionHandler.StateMachine {
+extension ServerConnectionManagementHandler.StateMachine {
   fileprivate struct KeepAlive {
     /// Allow the client to send keep alive pings when there are no active calls.
     private let allowWithoutCalls: Bool
@@ -267,8 +267,7 @@ extension ServerConnectionHandler.StateMachine {
     /// alive (a low number of strikes is therefore expected and okay).
     private var pingStrikes: Int
 
-    /// The last time a valid ping happened. This may be in the distant past if there is no such
-    /// time (for example the connection is new and there are no active calls).
+    /// The last time a valid ping happened.
     ///
     /// Note: `distantPast` isn't used to indicate no previous valid ping as `NIODeadline` uses
     /// the monotonic clock on Linux which uses an undefined starting point and in some cases isn't
@@ -320,7 +319,7 @@ extension ServerConnectionHandler.StateMachine {
   }
 }
 
-extension ServerConnectionHandler.StateMachine {
+extension ServerConnectionManagementHandler.StateMachine {
   fileprivate enum State {
     /// The connection is active.
     struct Active {

+ 473 - 0
Sources/GRPCHTTP2Core/Server/Connection/ServerConnectionManagementHandler.swift

@@ -0,0 +1,473 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import NIOCore
+import NIOHTTP2
+
+/// A `ChannelHandler` which manages the lifecycle of a gRPC connection over HTTP/2.
+///
+/// This handler is responsible for managing several aspects of the connection. These include:
+/// 1. Handling the graceful close of connections. When gracefully closing a connection the server
+///    sends a GOAWAY frame with the last stream ID set to the maximum stream ID allowed followed by
+///    a PING frame. On receipt of the PING frame the server sends another GOAWAY frame with the
+///    highest ID of all streams which have been opened. After this, the handler closes the
+///    connection once all streams are closed.
+/// 2. Enforcing that graceful shutdown doesn't exceed a configured limit (if configured).
+/// 3. Gracefully closing the connection once it reaches the maximum configured age (if configured).
+/// 4. Gracefully closing the connection once it has been idle for a given period of time (if
+///    configured).
+/// 5. Periodically sending keep alive pings to the client (if configured) and closing the
+///    connection if necessary.
+/// 6. Policing pings sent by the client to ensure that the client isn't misconfigured to send
+///    too many pings.
+///
+/// Some of the behaviours are described in:
+/// - [gRFC A8](https://github.com/grpc/proposal/blob/master/A8-client-side-keepalive.md), and
+/// - [gRFC A9](https://github.com/grpc/proposal/blob/master/A9-server-side-conn-mgt.md).
+final class ServerConnectionManagementHandler: ChannelDuplexHandler {
+  typealias InboundIn = HTTP2Frame
+  typealias InboundOut = HTTP2Frame
+  typealias OutboundIn = HTTP2Frame
+  typealias OutboundOut = HTTP2Frame
+
+  /// The `EventLoop` of the `Channel` this handler exists in.
+  private let eventLoop: EventLoop
+
+  /// The maximum amount of time a connection may be idle for. If the connection remains idle
+  /// (i.e. has no open streams) for this period of time then the connection will be gracefully
+  /// closed.
+  private var maxIdleTimer: Timer?
+
+  /// The maximum age of a connection. If the connection remains open after this amount of time
+  /// then it will be gracefully closed.
+  private var maxAgeTimer: Timer?
+
+  /// The maximum amount of time a connection may spend closing gracefully, after which it is
+  /// closed abruptly. The timer starts after the second GOAWAY frame has been sent.
+  private var maxGraceTimer: Timer?
+
+  /// The amount of time to wait before sending a keep alive ping.
+  private var keepAliveTimer: Timer?
+
+  /// The amount of time the client has to reply after sending a keep alive ping. Only used if
+  /// `keepAliveTimer` is set.
+  private var keepAliveTimeoutTimer: Timer
+
+  private struct Timer {
+    /// The delay to wait before running the task.
+    private let delay: TimeAmount
+    /// The task to run, if scheduled.
+    private var task: Scheduled<Void>?
+
+    init(delay: TimeAmount) {
+      self.delay = delay
+      self.task = nil
+    }
+
+    /// Schedule a task on the given `EventLoop`.
+    mutating func schedule(on eventLoop: EventLoop, task: @escaping () throws -> Void) {
+      self.task?.cancel()
+      self.task = eventLoop.scheduleTask(in: self.delay, task)
+    }
+
+    /// Cancels the task, if one was scheduled.
+    mutating func cancel() {
+      self.task?.cancel()
+      self.task = nil
+    }
+  }
+
+  /// Opaque data sent in keep alive pings.
+  private let keepAlivePingData: HTTP2PingData
+
+  /// Whether a flush is pending.
+  private var flushPending: Bool
+  /// Whether `channelRead` has been called and `channelReadComplete` hasn't yet been called.
+  /// Resets once `channelReadComplete` returns.
+  private var inReadLoop: Bool
+
+  /// The current state of the connection.
+  private var state: StateMachine
+
+  /// The clock.
+  private let clock: Clock
+
+  /// A clock providing the current time.
+  ///
+  /// This is necessary for testing where a manual clock can be used and advanced from the test.
+  /// While NIO's `EmbeddedEventLoop` provides control over its view of time (and therefore any
+  /// events scheduled on it) it doesn't offer a way to get the current time. This is usually done
+  /// via `NIODeadline`.
+  enum Clock {
+    case nio
+    case manual(Manual)
+
+    func now() -> NIODeadline {
+      switch self {
+      case .nio:
+        return .now()
+      case .manual(let clock):
+        return clock.time
+      }
+    }
+
+    final class Manual {
+      private(set) var time: NIODeadline
+
+      init() {
+        self.time = .uptimeNanoseconds(0)
+      }
+
+      func advance(by amount: TimeAmount) {
+        self.time = self.time + amount
+      }
+    }
+  }
+
+  /// Stats about recently written frames. Used to determine whether to reset keep-alive state.
+  private var frameStats: FrameStats
+
+  struct FrameStats {
+    private(set) var didWriteHeadersOrData = false
+
+    /// Mark that a HEADERS frame has been written.
+    mutating func wroteHeaders() {
+      self.didWriteHeadersOrData = true
+    }
+
+    /// Mark that DATA frame has been written.
+    mutating func wroteData() {
+      self.didWriteHeadersOrData = true
+    }
+
+    /// Resets the state such that no HEADERS or DATA frames have been written.
+    mutating func reset() {
+      self.didWriteHeadersOrData = false
+    }
+  }
+
+  /// A synchronous view over this handler.
+  var syncView: SyncView {
+    return SyncView(self)
+  }
+
+  /// A synchronous view over this handler.
+  ///
+  /// Methods on this view *must* be called from the same `EventLoop` as the `Channel` in which
+  /// this handler exists.
+  struct SyncView {
+    private let handler: ServerConnectionManagementHandler
+
+    fileprivate init(_ handler: ServerConnectionManagementHandler) {
+      self.handler = handler
+    }
+
+    /// Notify the handler that the connection has received a flush event.
+    func connectionWillFlush() {
+      // The handler can't rely on `flush(context:)` due to its expected position in the pipeline.
+      // It's expected to be placed after the HTTP/2 handler (i.e. closer to the application) as
+      // it needs to receive HTTP/2 frames. However, flushes from stream channels aren't sent down
+      // the entire connection channel, instead they are sent from the point in the channel they
+      // are multiplexed from (either the HTTP/2 handler or the HTTP/2 multiplexing handler,
+      // depending on how multiplexing is configured).
+      self.handler.eventLoop.assertInEventLoop()
+      if self.handler.frameStats.didWriteHeadersOrData {
+        self.handler.frameStats.reset()
+        self.handler.state.resetKeepAliveState()
+      }
+    }
+
+    /// Notify the handler that a HEADERS frame was written in the last write loop.
+    func wroteHeadersFrame() {
+      self.handler.eventLoop.assertInEventLoop()
+      self.handler.frameStats.wroteHeaders()
+    }
+
+    /// Notify the handler that a DATA frame was written in the last write loop.
+    func wroteDataFrame() {
+      self.handler.eventLoop.assertInEventLoop()
+      self.handler.frameStats.wroteData()
+    }
+  }
+
+  /// Creates a new handler which manages the lifecycle of a connection.
+  ///
+  /// - Parameters:
+  ///   - eventLoop: The `EventLoop` of the `Channel` this handler is placed in.
+  ///   - maxIdleTime: The maximum amount time a connection may be idle for before being closed.
+  ///   - maxAge: The maximum amount of time a connection may exist before being gracefully closed.
+  ///   - maxGraceTime: The maximum amount of time that the connection has to close gracefully.
+  ///   - keepAliveTime: The amount of time to wait after reading data before sending a keep-alive
+  ///       ping.
+  ///   - keepAliveTimeout: The amount of time the client has to reply after the server sends a
+  ///       keep-alive ping to keep the connection open. The connection is closed if no reply
+  ///       is received.
+  ///   - allowKeepAliveWithoutCalls: Whether the server allows the client to send keep-alive pings
+  ///       when there are no calls in progress.
+  ///   - minPingIntervalWithoutCalls: The minimum allowed interval the client is allowed to send
+  ///       keep-alive pings. Pings more frequent than this interval count as 'strikes' and the
+  ///       connection is closed if there are too many strikes.
+  ///   - clock: A clock providing the current time.
+  init(
+    eventLoop: EventLoop,
+    maxIdleTime: TimeAmount?,
+    maxAge: TimeAmount?,
+    maxGraceTime: TimeAmount?,
+    keepAliveTime: TimeAmount?,
+    keepAliveTimeout: TimeAmount?,
+    allowKeepAliveWithoutCalls: Bool,
+    minPingIntervalWithoutCalls: TimeAmount,
+    clock: Clock = .nio
+  ) {
+    self.eventLoop = eventLoop
+
+    self.maxIdleTimer = maxIdleTime.map { Timer(delay: $0) }
+    self.maxAgeTimer = maxAge.map { Timer(delay: $0) }
+    self.maxGraceTimer = maxGraceTime.map { Timer(delay: $0) }
+
+    self.keepAliveTimer = keepAliveTime.map { Timer(delay: $0) }
+    // Always create a keep alive timeout timer, it's only used if there is a keep alive timer.
+    self.keepAliveTimeoutTimer = Timer(delay: keepAliveTimeout ?? .seconds(20))
+
+    // Generate a random value to be used as keep alive ping data.
+    let pingData = UInt64.random(in: .min ... .max)
+    self.keepAlivePingData = HTTP2PingData(withInteger: pingData)
+
+    self.state = StateMachine(
+      allowKeepAliveWithoutCalls: allowKeepAliveWithoutCalls,
+      minPingReceiveIntervalWithoutCalls: minPingIntervalWithoutCalls,
+      goAwayPingData: HTTP2PingData(withInteger: ~pingData)
+    )
+
+    self.flushPending = false
+    self.inReadLoop = false
+    self.clock = clock
+    self.frameStats = FrameStats()
+  }
+
+  func handlerAdded(context: ChannelHandlerContext) {
+    assert(context.eventLoop === self.eventLoop)
+  }
+
+  func channelActive(context: ChannelHandlerContext) {
+    self.maxAgeTimer?.schedule(on: context.eventLoop) {
+      self.initiateGracefulShutdown(context: context)
+    }
+
+    self.maxIdleTimer?.schedule(on: context.eventLoop) {
+      self.initiateGracefulShutdown(context: context)
+    }
+
+    self.keepAliveTimer?.schedule(on: context.eventLoop) {
+      self.keepAliveTimerFired(context: context)
+    }
+
+    context.fireChannelActive()
+  }
+
+  func channelInactive(context: ChannelHandlerContext) {
+    self.maxIdleTimer?.cancel()
+    self.maxAgeTimer?.cancel()
+    self.maxGraceTimer?.cancel()
+    self.keepAliveTimer?.cancel()
+    self.keepAliveTimeoutTimer.cancel()
+    context.fireChannelInactive()
+  }
+
+  func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
+    switch event {
+    case let event as NIOHTTP2StreamCreatedEvent:
+      // The connection isn't idle if a stream is open.
+      self.maxIdleTimer?.cancel()
+      self.state.streamOpened(event.streamID)
+
+    case let event as StreamClosedEvent:
+      switch self.state.streamClosed(event.streamID) {
+      case .startIdleTimer:
+        self.maxIdleTimer?.schedule(on: context.eventLoop) {
+          self.initiateGracefulShutdown(context: context)
+        }
+
+      case .close:
+        context.close(mode: .all, promise: nil)
+
+      case .none:
+        ()
+      }
+
+    default:
+      ()
+    }
+
+    context.fireUserInboundEventTriggered(event)
+  }
+
+  func channelRead(context: ChannelHandlerContext, data: NIOAny) {
+    self.inReadLoop = true
+
+    // Any read data indicates that the connection is alive so cancel the keep-alive timers.
+    self.keepAliveTimer?.cancel()
+    self.keepAliveTimeoutTimer.cancel()
+
+    let frame = self.unwrapInboundIn(data)
+    switch frame.payload {
+    case .ping(let data, let ack):
+      if ack {
+        self.handlePingAck(context: context, data: data)
+      } else {
+        self.handlePing(context: context, data: data)
+      }
+
+    default:
+      ()  // Only interested in PING frames, ignore the rest.
+    }
+
+    context.fireChannelRead(data)
+  }
+
+  func channelReadComplete(context: ChannelHandlerContext) {
+    while self.flushPending {
+      self.flushPending = false
+      context.flush()
+    }
+
+    self.inReadLoop = false
+
+    // Done reading: schedule the keep-alive timer.
+    self.keepAliveTimer?.schedule(on: context.eventLoop) {
+      self.keepAliveTimerFired(context: context)
+    }
+
+    context.fireChannelReadComplete()
+  }
+
+  func flush(context: ChannelHandlerContext) {
+    self.maybeFlush(context: context)
+  }
+}
+
+extension ServerConnectionManagementHandler {
+  private func maybeFlush(context: ChannelHandlerContext) {
+    if self.inReadLoop {
+      self.flushPending = true
+    } else {
+      context.flush()
+    }
+  }
+
+  private func initiateGracefulShutdown(context: ChannelHandlerContext) {
+    context.eventLoop.assertInEventLoop()
+
+    // Cancel any timers if initiating shutdown.
+    self.maxIdleTimer?.cancel()
+    self.maxAgeTimer?.cancel()
+    self.keepAliveTimer?.cancel()
+    self.keepAliveTimeoutTimer.cancel()
+
+    switch self.state.startGracefulShutdown() {
+    case .sendGoAwayAndPing(let pingData):
+      // There's a time window between the server sending a GOAWAY frame and the client receiving
+      // it. During this time the client may open new streams as it doesn't yet know about the
+      // GOAWAY frame.
+      //
+      // The server therefore sends a GOAWAY with the last stream ID set to the maximum stream ID
+      // and follows it with a PING frame. When the server receives the ack for the PING frame it
+      // knows that the client has received the initial GOAWAY frame and that no more streams may
+      // be opened. The server can then send an additional GOAWAY frame with a more representative
+      // last stream ID.
+      let goAway = HTTP2Frame(
+        streamID: .rootStream,
+        payload: .goAway(
+          lastStreamID: .maxID,
+          errorCode: .noError,
+          opaqueData: nil
+        )
+      )
+
+      let ping = HTTP2Frame(streamID: .rootStream, payload: .ping(pingData, ack: false))
+
+      context.write(self.wrapOutboundOut(goAway), promise: nil)
+      context.write(self.wrapOutboundOut(ping), promise: nil)
+      self.maybeFlush(context: context)
+
+    case .none:
+      ()  // Already shutting down.
+    }
+  }
+
+  private func handlePing(context: ChannelHandlerContext, data: HTTP2PingData) {
+    switch self.state.receivedPing(atTime: self.clock.now(), data: data) {
+    case .enhanceYourCalmThenClose(let streamID):
+      let goAway = HTTP2Frame(
+        streamID: .rootStream,
+        payload: .goAway(
+          lastStreamID: streamID,
+          errorCode: .enhanceYourCalm,
+          opaqueData: context.channel.allocator.buffer(string: "too_many_pings")
+        )
+      )
+
+      context.write(self.wrapOutboundOut(goAway), promise: nil)
+      self.maybeFlush(context: context)
+      context.close(promise: nil)
+
+    case .sendAck:
+      let ping = HTTP2Frame(streamID: .rootStream, payload: .ping(data, ack: true))
+      context.write(self.wrapOutboundOut(ping), promise: nil)
+      self.maybeFlush(context: context)
+
+    case .none:
+      ()
+    }
+  }
+
+  private func handlePingAck(context: ChannelHandlerContext, data: HTTP2PingData) {
+    switch self.state.receivedPingAck(data: data) {
+    case .sendGoAway(let streamID, let close):
+      let goAway = HTTP2Frame(
+        streamID: .rootStream,
+        payload: .goAway(lastStreamID: streamID, errorCode: .noError, opaqueData: nil)
+      )
+
+      context.write(self.wrapOutboundOut(goAway), promise: nil)
+      self.maybeFlush(context: context)
+
+      if close {
+        context.close(promise: nil)
+      } else {
+        // RPCs may have a grace period for finishing once the second GOAWAY frame has finished.
+        // If this is set close the connection abruptly once the grace period passes.
+        self.maxGraceTimer?.schedule(on: context.eventLoop) {
+          context.close(promise: nil)
+        }
+      }
+
+    case .none:
+      ()
+    }
+  }
+
+  private func keepAliveTimerFired(context: ChannelHandlerContext) {
+    let ping = HTTP2Frame(streamID: .rootStream, payload: .ping(self.keepAlivePingData, ack: false))
+    context.write(self.wrapInboundOut(ping), promise: nil)
+    self.maybeFlush(context: context)
+
+    // Schedule a timeout on waiting for the response.
+    self.keepAliveTimeoutTimer.schedule(on: context.eventLoop) {
+      self.initiateGracefulShutdown(context: context)
+    }
+  }
+}

+ 3 - 3
Tests/GRPCHTTP2CoreTests/Server/Connection/ServerConnectionHandler+StateMachineTests.swift → Tests/GRPCHTTP2CoreTests/Server/Connection/ServerConnectionManagementHandler+StateMachineTests.swift

@@ -20,12 +20,12 @@ import XCTest
 
 @testable import GRPCHTTP2Core
 
-final class ServerConnectionHandlerStateMachineTests: XCTestCase {
+final class ServerConnectionManagementHandlerStateMachineTests: XCTestCase {
   private func makeStateMachine(
     allowKeepAliveWithoutCalls: Bool = false,
     minPingReceiveIntervalWithoutCalls: TimeAmount = .minutes(5),
     goAwayPingData: HTTP2PingData = HTTP2PingData(withInteger: 42)
-  ) -> ServerConnectionHandler.StateMachine {
+  ) -> ServerConnectionManagementHandler.StateMachine {
     return .init(
       allowKeepAliveWithoutCalls: allowKeepAliveWithoutCalls,
       minPingReceiveIntervalWithoutCalls: minPingReceiveIntervalWithoutCalls,
@@ -169,7 +169,7 @@ final class ServerConnectionHandlerStateMachineTests: XCTestCase {
   }
 
   func testPingStrikeUsingMinReceiveInterval(
-    state: inout ServerConnectionHandler.StateMachine,
+    state: inout ServerConnectionManagementHandler.StateMachine,
     interval: TimeAmount,
     expectedID id: HTTP2StreamID
   ) {

+ 428 - 0
Tests/GRPCHTTP2CoreTests/Server/Connection/ServerConnectionManagementHandlerTests.swift

@@ -0,0 +1,428 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import NIOCore
+import NIOEmbedded
+import NIOHTTP2
+import XCTest
+
+@testable import GRPCHTTP2Core
+
+final class ServerConnectionManagementHandlerTests: XCTestCase {
+  func testIdleTimeoutOnNewConnection() throws {
+    let connection = try Connection(maxIdleTime: .minutes(1))
+    try connection.activate()
+    // Hit the max idle time.
+    connection.advanceTime(by: .minutes(1))
+
+    // Follow the graceful shutdown flow.
+    try self.testGracefulShutdown(connection: connection, lastStreamID: 0)
+
+    // Closed because no streams were open.
+    try connection.waitUntilClosed()
+  }
+
+  func testIdleTimerIsCancelledWhenStreamIsOpened() throws {
+    let connection = try Connection(maxIdleTime: .minutes(1))
+    try connection.activate()
+
+    // Open a stream to cancel the idle timer and run through the max idle time.
+    connection.streamOpened(1)
+    connection.advanceTime(by: .minutes(1))
+
+    // No GOAWAY frame means the timer was cancelled.
+    XCTAssertNil(try connection.readFrame())
+  }
+
+  func testIdleTimerStartsWhenAllStreamsAreClosed() throws {
+    let connection = try Connection(maxIdleTime: .minutes(1))
+    try connection.activate()
+
+    // Open a stream to cancel the idle timer and run through the max idle time.
+    connection.streamOpened(1)
+    connection.advanceTime(by: .minutes(1))
+    XCTAssertNil(try connection.readFrame())
+
+    // Close the stream to start the timer again.
+    connection.streamClosed(1)
+    connection.advanceTime(by: .minutes(1))
+
+    // Follow the graceful shutdown flow.
+    try self.testGracefulShutdown(connection: connection, lastStreamID: 1)
+
+    // Closed because no streams were open.
+    try connection.waitUntilClosed()
+  }
+
+  func testMaxAge() throws {
+    let connection = try Connection(maxAge: .minutes(1))
+    try connection.activate()
+
+    // Open some streams.
+    connection.streamOpened(1)
+    connection.streamOpened(3)
+
+    // Run to the max age and follow the graceful shutdown flow.
+    connection.advanceTime(by: .minutes(1))
+    try self.testGracefulShutdown(connection: connection, lastStreamID: 3)
+
+    // Close the streams.
+    connection.streamClosed(1)
+    connection.streamClosed(3)
+
+    // Connection will be closed now.
+    try connection.waitUntilClosed()
+  }
+
+  func testGracefulShutdownRatchetsDownStreamID() throws {
+    // This test uses the idle timeout to trigger graceful shutdown. The mechanism is the same
+    // regardless of how it's triggered.
+    let connection = try Connection(maxIdleTime: .minutes(1))
+    try connection.activate()
+
+    // Trigger the shutdown, but open a stream during shutdown.
+    connection.advanceTime(by: .minutes(1))
+    try self.testGracefulShutdown(
+      connection: connection,
+      lastStreamID: 1,
+      streamToOpenBeforePingAck: 1
+    )
+
+    // Close the stream to trigger closing the connection.
+    connection.streamClosed(1)
+    try connection.waitUntilClosed()
+  }
+
+  func testGracefulShutdownGracePeriod() throws {
+    // This test uses the idle timeout to trigger graceful shutdown. The mechanism is the same
+    // regardless of how it's triggered.
+    let connection = try Connection(
+      maxIdleTime: .minutes(1),
+      maxGraceTime: .seconds(5)
+    )
+    try connection.activate()
+
+    // Trigger the shutdown, but open a stream during shutdown.
+    connection.advanceTime(by: .minutes(1))
+    try self.testGracefulShutdown(
+      connection: connection,
+      lastStreamID: 1,
+      streamToOpenBeforePingAck: 1
+    )
+
+    // Wait out the grace period without closing the stream.
+    connection.advanceTime(by: .seconds(5))
+    try connection.waitUntilClosed()
+  }
+
+  func testKeepAliveOnNewConnection() throws {
+    let connection = try Connection(
+      keepAliveTime: .minutes(5),
+      keepAliveTimeout: .seconds(5)
+    )
+    try connection.activate()
+
+    // Wait for the keep alive timer to fire which should cause the server to send a keep
+    // alive PING.
+    connection.advanceTime(by: .minutes(5))
+    let frame1 = try XCTUnwrap(connection.readFrame())
+    XCTAssertEqual(frame1.streamID, .rootStream)
+    try XCTAssertPing(frame1.payload) { data, ack in
+      XCTAssertFalse(ack)
+      // Data is opaque, send it back.
+      try connection.ping(data: data, ack: true)
+    }
+
+    // Run past the timeout, nothing should happen.
+    connection.advanceTime(by: .seconds(5))
+    XCTAssertNil(try connection.readFrame())
+  }
+
+  func testKeepAliveStartsAfterReadLoop() throws {
+    let connection = try Connection(
+      keepAliveTime: .minutes(5),
+      keepAliveTimeout: .seconds(5)
+    )
+    try connection.activate()
+
+    // Write a frame into the channel _without_ calling channel read complete. This will cancel
+    // the keep alive timer.
+    let settings = HTTP2Frame(streamID: .rootStream, payload: .settings(.settings([])))
+    connection.channel.pipeline.fireChannelRead(NIOAny(settings))
+
+    // Run out the keep alive timer, it shouldn't fire.
+    connection.advanceTime(by: .minutes(5))
+    XCTAssertNil(try connection.readFrame())
+
+    // Fire channel read complete to start the keep alive timer again.
+    connection.channel.pipeline.fireChannelReadComplete()
+
+    // Now expire the keep alive timer again, we should read out a PING frame.
+    connection.advanceTime(by: .minutes(5))
+    let frame1 = try XCTUnwrap(connection.readFrame())
+    XCTAssertEqual(frame1.streamID, .rootStream)
+    XCTAssertPing(frame1.payload) { data, ack in
+      XCTAssertFalse(ack)
+    }
+  }
+
+  func testKeepAliveOnNewConnectionWithoutResponse() throws {
+    let connection = try Connection(
+      keepAliveTime: .minutes(5),
+      keepAliveTimeout: .seconds(5)
+    )
+    try connection.activate()
+
+    // Wait for the keep alive timer to fire which should cause the server to send a keep
+    // alive PING.
+    connection.advanceTime(by: .minutes(5))
+    let frame1 = try XCTUnwrap(connection.readFrame())
+    XCTAssertEqual(frame1.streamID, .rootStream)
+    XCTAssertPing(frame1.payload) { data, ack in
+      XCTAssertFalse(ack)
+    }
+
+    // We didn't ack the PING, the connection should shutdown after the timeout.
+    connection.advanceTime(by: .seconds(5))
+    try self.testGracefulShutdown(connection: connection, lastStreamID: 0)
+
+    // Connection is closed now.
+    try connection.waitUntilClosed()
+  }
+
+  func testClientKeepAlivePolicing() throws {
+    let connection = try Connection(
+      allowKeepAliveWithoutCalls: true,
+      minPingIntervalWithoutCalls: .minutes(1)
+    )
+    try connection.activate()
+
+    // The first ping is valid, the second and third are strikes.
+    for _ in 1 ... 3 {
+      try connection.ping(data: HTTP2PingData(), ack: false)
+      let frame = try XCTUnwrap(connection.readFrame())
+      XCTAssertEqual(frame.streamID, .rootStream)
+      XCTAssertPing(frame.payload) { data, ack in
+        XCTAssertEqual(data, HTTP2PingData())
+        XCTAssertTrue(ack)
+      }
+    }
+
+    // The fourth ping is the third strike and triggers a GOAWAY.
+    try connection.ping(data: HTTP2PingData(), ack: false)
+    let frame = try XCTUnwrap(connection.readFrame())
+    XCTAssertEqual(frame.streamID, .rootStream)
+    XCTAssertGoAway(frame.payload) { streamID, error, data in
+      XCTAssertEqual(streamID, .rootStream)
+      XCTAssertEqual(error, .enhanceYourCalm)
+      XCTAssertEqual(data, ByteBuffer(string: "too_many_pings"))
+    }
+
+    // The server should close the connection.
+    try connection.waitUntilClosed()
+  }
+
+  func testClientKeepAliveWithPermissibleIntervals() throws {
+    let connection = try Connection(
+      allowKeepAliveWithoutCalls: true,
+      minPingIntervalWithoutCalls: .minutes(1),
+      manualClock: true
+    )
+    try connection.activate()
+
+    for _ in 1 ... 100 {
+      try connection.ping(data: HTTP2PingData(), ack: false)
+      let frame = try XCTUnwrap(connection.readFrame())
+      XCTAssertEqual(frame.streamID, .rootStream)
+      XCTAssertPing(frame.payload) { data, ack in
+        XCTAssertEqual(data, HTTP2PingData())
+        XCTAssertTrue(ack)
+      }
+
+      // Advance by the ping interval.
+      connection.advanceTime(by: .minutes(1))
+    }
+  }
+
+  func testClientKeepAliveResetState() throws {
+    let connection = try Connection(
+      allowKeepAliveWithoutCalls: true,
+      minPingIntervalWithoutCalls: .minutes(1)
+    )
+    try connection.activate()
+
+    func sendThreeKeepAlivePings() throws {
+      // The first ping is valid, the second and third are strikes.
+      for _ in 1 ... 3 {
+        try connection.ping(data: HTTP2PingData(), ack: false)
+        let frame = try XCTUnwrap(connection.readFrame())
+        XCTAssertEqual(frame.streamID, .rootStream)
+        XCTAssertPing(frame.payload) { data, ack in
+          XCTAssertEqual(data, HTTP2PingData())
+          XCTAssertTrue(ack)
+        }
+      }
+    }
+
+    try sendThreeKeepAlivePings()
+
+    // "send" a HEADERS frame and flush to reset keep alive state.
+    connection.syncView.wroteHeadersFrame()
+    connection.syncView.connectionWillFlush()
+
+    // As above, the first ping is valid, the next two are strikes.
+    try sendThreeKeepAlivePings()
+
+    // The next ping is the third strike and triggers a GOAWAY.
+    try connection.ping(data: HTTP2PingData(), ack: false)
+    let frame = try XCTUnwrap(connection.readFrame())
+    XCTAssertEqual(frame.streamID, .rootStream)
+    XCTAssertGoAway(frame.payload) { streamID, error, data in
+      XCTAssertEqual(streamID, .rootStream)
+      XCTAssertEqual(error, .enhanceYourCalm)
+      XCTAssertEqual(data, ByteBuffer(string: "too_many_pings"))
+    }
+
+    // The server should close the connection.
+    try connection.waitUntilClosed()
+  }
+}
+
+extension ServerConnectionManagementHandlerTests {
+  private func testGracefulShutdown(
+    connection: Connection,
+    lastStreamID: HTTP2StreamID,
+    streamToOpenBeforePingAck: HTTP2StreamID? = nil
+  ) throws {
+    let frame1 = try XCTUnwrap(connection.readFrame())
+    XCTAssertEqual(frame1.streamID, .rootStream)
+    XCTAssertGoAway(frame1.payload) { streamID, errorCode, _ in
+      XCTAssertEqual(streamID, .maxID)
+      XCTAssertEqual(errorCode, .noError)
+    }
+
+    // Followed by a PING
+    let frame2 = try XCTUnwrap(connection.readFrame())
+    XCTAssertEqual(frame2.streamID, .rootStream)
+    try XCTAssertPing(frame2.payload) { data, ack in
+      XCTAssertFalse(ack)
+
+      if let id = streamToOpenBeforePingAck {
+        connection.streamOpened(id)
+      }
+
+      // Send the PING ACK.
+      try connection.ping(data: data, ack: true)
+    }
+
+    // PING ACK triggers another GOAWAY.
+    let frame3 = try XCTUnwrap(connection.readFrame())
+    XCTAssertEqual(frame3.streamID, .rootStream)
+    XCTAssertGoAway(frame3.payload) { streamID, errorCode, _ in
+      XCTAssertEqual(streamID, lastStreamID)
+      XCTAssertEqual(errorCode, .noError)
+    }
+  }
+}
+
+extension ServerConnectionManagementHandlerTests {
+  struct Connection {
+    let channel: EmbeddedChannel
+    let syncView: ServerConnectionManagementHandler.SyncView
+
+    var loop: EmbeddedEventLoop {
+      self.channel.embeddedEventLoop
+    }
+
+    private let clock: ServerConnectionManagementHandler.Clock
+
+    init(
+      maxIdleTime: TimeAmount? = nil,
+      maxAge: TimeAmount? = nil,
+      maxGraceTime: TimeAmount? = nil,
+      keepAliveTime: TimeAmount? = nil,
+      keepAliveTimeout: TimeAmount? = nil,
+      allowKeepAliveWithoutCalls: Bool = false,
+      minPingIntervalWithoutCalls: TimeAmount = .minutes(5),
+      manualClock: Bool = false
+    ) throws {
+      if manualClock {
+        self.clock = .manual(ServerConnectionManagementHandler.Clock.Manual())
+      } else {
+        self.clock = .nio
+      }
+
+      let loop = EmbeddedEventLoop()
+      let handler = ServerConnectionManagementHandler(
+        eventLoop: loop,
+        maxIdleTime: maxIdleTime,
+        maxAge: maxAge,
+        maxGraceTime: maxGraceTime,
+        keepAliveTime: keepAliveTime,
+        keepAliveTimeout: keepAliveTimeout,
+        allowKeepAliveWithoutCalls: allowKeepAliveWithoutCalls,
+        minPingIntervalWithoutCalls: minPingIntervalWithoutCalls,
+        clock: self.clock
+      )
+
+      self.syncView = handler.syncView
+      self.channel = EmbeddedChannel(handler: handler, loop: loop)
+    }
+
+    func activate() throws {
+      try self.channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)).wait()
+    }
+
+    func advanceTime(by delta: TimeAmount) {
+      switch self.clock {
+      case .nio:
+        ()
+      case .manual(let clock):
+        clock.advance(by: delta)
+      }
+
+      self.loop.advanceTime(by: delta)
+    }
+
+    func streamOpened(_ id: HTTP2StreamID) {
+      let event = NIOHTTP2StreamCreatedEvent(
+        streamID: id,
+        localInitialWindowSize: nil,
+        remoteInitialWindowSize: nil
+      )
+      self.channel.pipeline.fireUserInboundEventTriggered(event)
+    }
+
+    func streamClosed(_ id: HTTP2StreamID) {
+      let event = StreamClosedEvent(streamID: id, reason: nil)
+      self.channel.pipeline.fireUserInboundEventTriggered(event)
+    }
+
+    func ping(data: HTTP2PingData, ack: Bool) throws {
+      let frame = HTTP2Frame(streamID: .rootStream, payload: .ping(data, ack: ack))
+      try self.channel.writeInbound(frame)
+    }
+
+    func readFrame() throws -> HTTP2Frame? {
+      return try self.channel.readOutbound(as: HTTP2Frame.self)
+    }
+
+    func waitUntilClosed() throws {
+      self.channel.embeddedEventLoop.run()
+      try self.channel.closeFuture.wait()
+    }
+  }
+}

+ 43 - 0
Tests/GRPCHTTP2CoreTests/XCTest+FramePayload.swift

@@ -0,0 +1,43 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import NIOCore
+import NIOHTTP2
+import XCTest
+
+func XCTAssertGoAway(
+  _ payload: HTTP2Frame.FramePayload,
+  verify: (HTTP2StreamID, HTTP2ErrorCode, ByteBuffer?) throws -> Void = { _, _, _ in }
+) rethrows {
+  switch payload {
+  case .goAway(let lastStreamID, let errorCode, let opaqueData):
+    try verify(lastStreamID, errorCode, opaqueData)
+  default:
+    XCTFail("Expected '.goAway' got '\(payload)'")
+  }
+}
+
+func XCTAssertPing(
+  _ payload: HTTP2Frame.FramePayload,
+  verify: (HTTP2PingData, Bool) throws -> Void = { _, _ in }
+) rethrows {
+  switch payload {
+  case .ping(let data, ack: let ack):
+    try verify(data, ack)
+  default:
+    XCTFail("Expected '.ping' got '\(payload)'")
+  }
+}