Browse Source

Use NIO scheduleCallback API for connection handler timers (#28)

## Motivation

NIO 2.75.0 added new APIs for scheduling callbacks, which can be used to
implement timers with fewer allocations than the previous APIs[^0].

## Modifications

Replace the timers used in the client and server connection handlers
with a new implementation that uses these new NIO APIs.

## Result

This should have no functional impact. However it reduces the
allocations in the connection handler. How many allocations it reduces
will depend on the exact scenario, but running the `echo` server under
Instruments and sending 1000 requests suggests around 10%.

[^0]: https://github.com/apple/swift-nio/releases/tag/2.75.0

---------

Co-authored-by: George Barnett <gbarnett@apple.com>
Si Beaumont 1 year ago
parent
commit
565aae3d57

+ 1 - 1
Package.swift

@@ -39,7 +39,7 @@ let dependencies: [Package.Dependency] = [
   ),
   .package(
     url: "https://github.com/apple/swift-nio.git",
-    from: "2.65.0"
+    from: "2.75.0"
   ),
   .package(
     url: "https://github.com/apple/swift-nio-http2.git",

+ 74 - 58
Sources/GRPCNIOTransportCore/Client/Connection/ClientConnectionHandler.swift

@@ -63,17 +63,14 @@ package final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutbo
   /// The `EventLoop` of the `Channel` this handler exists in.
   private let eventLoop: any EventLoop
 
-  /// The maximum amount of time the connection may be idle for. If the connection remains idle
-  /// (i.e. has no open streams) for this period of time then the connection will be gracefully
-  /// closed.
-  private var maxIdleTimer: Timer?
+  /// The timer used to gracefully close idle connections.
+  private var maxIdleTimerHandler: Timer<MaxIdleTimerHandlerView>?
 
-  /// The amount of time to wait before sending a keep alive ping.
-  private var keepaliveTimer: Timer?
+  /// The timer used to send keep-alive pings.
+  private var keepaliveTimerHandler: Timer<KeepaliveTimerHandlerView>?
 
-  /// The amount of time the client has to reply after sending a keep alive ping. Only used if
-  /// `keepaliveTimer` is set.
-  private var keepaliveTimeoutTimer: Timer
+  /// The timer used to detect keep alive timeouts, if keep-alive pings are enabled.
+  private var keepaliveTimeoutHandler: Timer<KeepaliveTimeoutHandlerView>?
 
   /// Opaque data sent in keep alive pings.
   private let keepalivePingData: HTTP2PingData
@@ -110,14 +107,34 @@ package final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutbo
     keepaliveWithoutCalls: Bool
   ) {
     self.eventLoop = eventLoop
-    self.maxIdleTimer = maxIdleTime.map { Timer(delay: $0) }
-    self.keepaliveTimer = keepaliveTime.map { Timer(delay: $0, repeat: true) }
-    self.keepaliveTimeoutTimer = Timer(delay: keepaliveTimeout ?? .seconds(20))
     self.keepalivePingData = HTTP2PingData(withInteger: .random(in: .min ... .max))
     self.state = StateMachine(allowKeepaliveWithoutCalls: keepaliveWithoutCalls)
 
     self.flushPending = false
     self.inReadLoop = false
+    if let maxIdleTime {
+      self.maxIdleTimerHandler = Timer(
+        eventLoop: eventLoop,
+        duration: maxIdleTime,
+        repeating: false,
+        handler: MaxIdleTimerHandlerView(self)
+      )
+    }
+    if let keepaliveTime {
+      let keepaliveTimeout = keepaliveTimeout ?? .seconds(20)
+      self.keepaliveTimerHandler = Timer(
+        eventLoop: eventLoop,
+        duration: keepaliveTime,
+        repeating: true,
+        handler: KeepaliveTimerHandlerView(self)
+      )
+      self.keepaliveTimeoutHandler = Timer(
+        eventLoop: eventLoop,
+        duration: keepaliveTimeout,
+        repeating: false,
+        handler: KeepaliveTimeoutHandlerView(self)
+      )
+    }
   }
 
   package func handlerAdded(context: ChannelHandlerContext) {
@@ -142,8 +159,8 @@ package final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutbo
       promise.succeed()
     }
 
-    self.keepaliveTimer?.cancel()
-    self.keepaliveTimeoutTimer.cancel()
+    self.keepaliveTimerHandler?.cancel()
+    self.keepaliveTimeoutHandler?.cancel()
     context.fireChannelInactive()
   }
 
@@ -222,11 +239,8 @@ package final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutbo
       // Pings are ack'd by the HTTP/2 handler so we only pay attention to acks here, and in
       // particular only those carrying the keep-alive data.
       if ack, data == self.keepalivePingData {
-        let loopBound = LoopBoundView(handler: self, context: context)
-        self.keepaliveTimeoutTimer.cancel()
-        self.keepaliveTimer?.schedule(on: context.eventLoop) {
-          loopBound.keepaliveTimerFired()
-        }
+        self.keepaliveTimeoutHandler?.cancel()
+        self.keepaliveTimerHandler?.start()
       }
 
     case .settings(.settings(_)):
@@ -236,15 +250,8 @@ package final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutbo
       // becoming active is insufficient as, for example, a TLS handshake may fail after
       // establishing the TCP connection, or the server isn't configured for gRPC (or HTTP/2).
       if isInitialSettings {
-        let loopBound = LoopBoundView(handler: self, context: context)
-        self.keepaliveTimer?.schedule(on: context.eventLoop) {
-          loopBound.keepaliveTimerFired()
-        }
-
-        self.maxIdleTimer?.schedule(on: context.eventLoop) {
-          loopBound.maxIdleTimerFired()
-        }
-
+        self.keepaliveTimerHandler?.start()
+        self.maxIdleTimerHandler?.start()
         context.fireChannelRead(self.wrapInboundOut(.ready))
       }
 
@@ -290,29 +297,44 @@ package final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutbo
   }
 }
 
+// Timer handler views.
 extension ClientConnectionHandler {
-  struct LoopBoundView: @unchecked Sendable {
+  struct MaxIdleTimerHandlerView: @unchecked Sendable, NIOScheduledCallbackHandler {
     private let handler: ClientConnectionHandler
-    private let context: ChannelHandlerContext
 
-    init(handler: ClientConnectionHandler, context: ChannelHandlerContext) {
+    init(_ handler: ClientConnectionHandler) {
       self.handler = handler
-      self.context = context
     }
 
-    func keepaliveTimerFired() {
-      self.context.eventLoop.assertInEventLoop()
-      self.handler.keepaliveTimerFired(context: self.context)
+    func handleScheduledCallback(eventLoop: some EventLoop) {
+      self.handler.eventLoop.assertInEventLoop()
+      self.handler.maxIdleTimerFired()
     }
+  }
+
+  struct KeepaliveTimerHandlerView: @unchecked Sendable, NIOScheduledCallbackHandler {
+    private let handler: ClientConnectionHandler
 
-    func keepaliveTimeoutExpired() {
-      self.context.eventLoop.assertInEventLoop()
-      self.handler.keepaliveTimeoutExpired(context: self.context)
+    init(_ handler: ClientConnectionHandler) {
+      self.handler = handler
     }
 
-    func maxIdleTimerFired() {
-      self.context.eventLoop.assertInEventLoop()
-      self.handler.maxIdleTimerFired(context: self.context)
+    func handleScheduledCallback(eventLoop: some EventLoop) {
+      self.handler.eventLoop.assertInEventLoop()
+      self.handler.keepaliveTimerFired()
+    }
+  }
+
+  struct KeepaliveTimeoutHandlerView: @unchecked Sendable, NIOScheduledCallbackHandler {
+    private let handler: ClientConnectionHandler
+
+    init(_ handler: ClientConnectionHandler) {
+      self.handler = handler
+    }
+
+    func handleScheduledCallback(eventLoop: some EventLoop) {
+      self.handler.eventLoop.assertInEventLoop()
+      self.handler.keepaliveTimeoutExpired()
     }
   }
 }
@@ -356,7 +378,7 @@ extension ClientConnectionHandler {
     self.eventLoop.assertInEventLoop()
 
     // Stream created, so the connection isn't idle.
-    self.maxIdleTimer?.cancel()
+    self.maxIdleTimerHandler?.cancel()
     self.state.streamOpened(id)
   }
 
@@ -368,13 +390,10 @@ extension ClientConnectionHandler {
     case .startIdleTimer(let cancelKeepalive):
       // All streams are closed, restart the idle timer, and stop the keep-alive timer (it may
       // not stop if keep-alive is allowed when there are no active calls).
-      let loopBound = LoopBoundView(handler: self, context: context)
-      self.maxIdleTimer?.schedule(on: context.eventLoop) {
-        loopBound.maxIdleTimerFired()
-      }
+      self.maxIdleTimerHandler?.start()
 
       if cancelKeepalive {
-        self.keepaliveTimer?.cancel()
+        self.keepaliveTimerHandler?.cancel()
       }
 
     case .close:
@@ -397,34 +416,31 @@ extension ClientConnectionHandler {
     }
   }
 
-  private func keepaliveTimerFired(context: ChannelHandlerContext) {
-    guard self.state.sendKeepalivePing() else { return }
+  private func keepaliveTimerFired() {
+    guard self.state.sendKeepalivePing(), let context = self.context else { return }
 
     // Cancel the keep alive timer when the client sends a ping. The timer is resumed when the ping
     // is acknowledged.
-    self.keepaliveTimer?.cancel()
+    self.keepaliveTimerHandler?.cancel()
 
     let ping = HTTP2Frame(streamID: .rootStream, payload: .ping(self.keepalivePingData, ack: false))
     context.write(self.wrapOutboundOut(ping), promise: nil)
     self.maybeFlush(context: context)
 
     // Schedule a timeout on waiting for the response.
-    let loopBound = LoopBoundView(handler: self, context: context)
-    self.keepaliveTimeoutTimer.schedule(on: context.eventLoop) {
-      loopBound.keepaliveTimeoutExpired()
-    }
+    self.keepaliveTimeoutHandler?.start()
   }
 
-  private func keepaliveTimeoutExpired(context: ChannelHandlerContext) {
-    guard self.state.beginClosing() else { return }
+  private func keepaliveTimeoutExpired() {
+    guard self.state.beginClosing(), let context = self.context else { return }
 
     context.fireChannelRead(self.wrapInboundOut(.closing(.keepaliveExpired)))
     self.writeAndFlushGoAway(context: context, message: "keepalive_expired")
     context.close(promise: nil)
   }
 
-  private func maxIdleTimerFired(context: ChannelHandlerContext) {
-    guard self.state.beginClosing() else { return }
+  private func maxIdleTimerFired() {
+    guard self.state.beginClosing(), let context = self.context else { return }
 
     context.fireChannelRead(self.wrapInboundOut(.closing(.idle)))
     self.writeAndFlushGoAway(context: context, message: "idle")

+ 41 - 42
Sources/GRPCNIOTransportCore/Internal/Timer.swift

@@ -16,55 +16,54 @@
 
 package import NIOCore
 
-package struct Timer {
-  /// The delay to wait before running the task.
-  private let delay: TimeAmount
-  /// The task to run, if scheduled.
-  private var task: Kind?
-  /// Whether the task to schedule is repeated.
-  private let `repeat`: Bool
+/// A timer backed by `NIOScheduledCallback`.
+package final class Timer<Handler: NIOScheduledCallbackHandler> where Handler: Sendable {
+  /// The event loop on which to run this timer.
+  private let eventLoop: any EventLoop
 
-  private enum Kind {
-    case once(Scheduled<Void>)
-    case repeated(RepeatedTask)
+  /// The duration of the timer.
+  private let duration: TimeAmount
 
-    func cancel() {
-      switch self {
-      case .once(let task):
-        task.cancel()
-      case .repeated(let task):
-        task.cancel()
-      }
-    }
-  }
+  /// Whether this timer should repeat.
+  private let repeating: Bool
+
+  /// The handler to call when the timer fires.
+  private let handler: Handler
+
+  /// The currently scheduled callback if the timer is running.
+  private var scheduledCallback: NIOScheduledCallback?
 
-  package init(delay: TimeAmount, repeat: Bool = false) {
-    self.delay = delay
-    self.task = nil
-    self.repeat = `repeat`
+  package init(eventLoop: any EventLoop, duration: TimeAmount, repeating: Bool, handler: Handler) {
+    self.eventLoop = eventLoop
+    self.duration = duration
+    self.repeating = repeating
+    self.handler = handler
+    self.scheduledCallback = nil
   }
 
-  /// Schedule a task on the given `EventLoop`.
-  package mutating func schedule(
-    on eventLoop: any EventLoop,
-    work: @escaping @Sendable () throws -> Void
-  ) {
-    self.task?.cancel()
+  /// Cancel the timer, if it is running.
+  package func cancel() {
+    self.eventLoop.assertInEventLoop()
+    guard let scheduledCallback = self.scheduledCallback else { return }
+    scheduledCallback.cancel()
+  }
 
-    if self.repeat {
-      let task = eventLoop.scheduleRepeatedTask(initialDelay: self.delay, delay: self.delay) { _ in
-        try work()
-      }
-      self.task = .repeated(task)
-    } else {
-      let task = eventLoop.scheduleTask(in: self.delay, work)
-      self.task = .once(task)
-    }
+  /// Start or restart the timer.
+  package func start() {
+    self.eventLoop.assertInEventLoop()
+    self.scheduledCallback?.cancel()
+    // Only throws if the event loop is shutting down, so we'll just swallow the error here.
+    self.scheduledCallback = try? self.eventLoop.scheduleCallback(in: self.duration, handler: self)
   }
+}
 
-  /// Cancels the task, if one was scheduled.
-  package mutating func cancel() {
-    self.task?.cancel()
-    self.task = nil
+extension Timer: NIOScheduledCallbackHandler, @unchecked Sendable where Handler: Sendable {
+  /// For repeated timer support, the timer itself proxies the callback and restarts the timer.
+  ///
+  /// - NOTE: Users should not call this function directly.
+  package func handleScheduledCallback(eventLoop: some EventLoop) {
+    self.eventLoop.assertInEventLoop()
+    self.handler.handleScheduledCallback(eventLoop: eventLoop)
+    if self.repeating { self.start() }
   }
 }

+ 131 - 79
Sources/GRPCNIOTransportCore/Server/Connection/ServerConnectionManagementHandler.swift

@@ -48,25 +48,21 @@ package final class ServerConnectionManagementHandler: ChannelDuplexHandler {
   /// The `EventLoop` of the `Channel` this handler exists in.
   private let eventLoop: any EventLoop
 
-  /// The maximum amount of time a connection may be idle for. If the connection remains idle
-  /// (i.e. has no open streams) for this period of time then the connection will be gracefully
-  /// closed.
-  private var maxIdleTimer: Timer?
+  /// The timer used to gracefully close idle connections.
+  private var maxIdleTimerHandler: Timer<MaxIdleTimerHandlerView>?
 
-  /// The maximum age of a connection. If the connection remains open after this amount of time
-  /// then it will be gracefully closed.
-  private var maxAgeTimer: Timer?
+  /// The timer used to gracefully close old connections.
+  private var maxAgeTimerHandler: Timer<MaxAgeTimerHandlerView>?
 
-  /// The maximum amount of time a connection may spend closing gracefully, after which it is
-  /// closed abruptly. The timer starts after the second GOAWAY frame has been sent.
-  private var maxGraceTimer: Timer?
+  /// The timer used to forcefully close a connection during a graceful close.
+  /// The timer starts after the second GOAWAY frame has been sent.
+  private var maxGraceTimerHandler: Timer<MaxGraceTimerHandlerView>?
 
-  /// The amount of time to wait before sending a keep alive ping.
-  private var keepaliveTimer: Timer?
+  /// The timer used to send keep-alive pings.
+  private var keepaliveTimerHandler: Timer<KeepaliveTimerHandlerView>?
 
-  /// The amount of time the client has to reply after sending a keep alive ping. Only used if
-  /// `keepaliveTimer` is set.
-  private var keepaliveTimeoutTimer: Timer
+  /// The timer used to detect keep alive timeouts, if keep-alive pings are enabled.
+  private var keepaliveTimeoutHandler: Timer<KeepaliveTimeoutHandlerView>?
 
   /// Opaque data sent in keep alive pings.
   private let keepalivePingData: HTTP2PingData
@@ -222,14 +218,6 @@ package final class ServerConnectionManagementHandler: ChannelDuplexHandler {
   ) {
     self.eventLoop = eventLoop
 
-    self.maxIdleTimer = maxIdleTime.map { Timer(delay: $0) }
-    self.maxAgeTimer = maxAge.map { Timer(delay: $0) }
-    self.maxGraceTimer = maxGraceTime.map { Timer(delay: $0) }
-
-    self.keepaliveTimer = keepaliveTime.map { Timer(delay: $0) }
-    // Always create a keep alive timeout timer, it's only used if there is a keep alive timer.
-    self.keepaliveTimeoutTimer = Timer(delay: keepaliveTimeout ?? .seconds(20))
-
     // Generate a random value to be used as keep alive ping data.
     let pingData = UInt64.random(in: .min ... .max)
     self.keepalivePingData = HTTP2PingData(withInteger: pingData)
@@ -246,6 +234,47 @@ package final class ServerConnectionManagementHandler: ChannelDuplexHandler {
     self.frameStats = FrameStats()
 
     self.requireALPN = requireALPN
+
+    if let maxIdleTime {
+      self.maxIdleTimerHandler = Timer(
+        eventLoop: eventLoop,
+        duration: maxIdleTime,
+        repeating: false,
+        handler: MaxIdleTimerHandlerView(self)
+      )
+    }
+    if let maxAge {
+      self.maxAgeTimerHandler = Timer(
+        eventLoop: eventLoop,
+        duration: maxAge,
+        repeating: false,
+        handler: MaxAgeTimerHandlerView(self)
+      )
+    }
+    if let maxGraceTime {
+      self.maxGraceTimerHandler = Timer(
+        eventLoop: eventLoop,
+        duration: maxGraceTime,
+        repeating: false,
+        handler: MaxGraceTimerHandlerView(self)
+      )
+    }
+    if let keepaliveTime {
+      let keepaliveTimeout = keepaliveTimeout ?? .seconds(20)
+      // NOTE: The use of a non-repeating timer is deliberate for the server, and is different from the client.
+      self.keepaliveTimerHandler = Timer(
+        eventLoop: eventLoop,
+        duration: keepaliveTime,
+        repeating: false,
+        handler: KeepaliveTimerHandlerView(self)
+      )
+      self.keepaliveTimeoutHandler = Timer(
+        eventLoop: eventLoop,
+        duration: keepaliveTimeout,
+        repeating: false,
+        handler: KeepaliveTimeoutHandlerView(self)
+      )
+    }
   }
 
   package func handlerAdded(context: ChannelHandlerContext) {
@@ -258,29 +287,18 @@ package final class ServerConnectionManagementHandler: ChannelDuplexHandler {
   }
 
   package func channelActive(context: ChannelHandlerContext) {
-    let view = LoopBoundView(handler: self, context: context)
-
-    self.maxAgeTimer?.schedule(on: context.eventLoop) {
-      view.initiateGracefulShutdown()
-    }
-
-    self.maxIdleTimer?.schedule(on: context.eventLoop) {
-      view.initiateGracefulShutdown()
-    }
-
-    self.keepaliveTimer?.schedule(on: context.eventLoop) {
-      view.keepaliveTimerFired()
-    }
-
+    self.maxAgeTimerHandler?.start()
+    self.maxIdleTimerHandler?.start()
+    self.keepaliveTimerHandler?.start()
     context.fireChannelActive()
   }
 
   package func channelInactive(context: ChannelHandlerContext) {
-    self.maxIdleTimer?.cancel()
-    self.maxAgeTimer?.cancel()
-    self.maxGraceTimer?.cancel()
-    self.keepaliveTimer?.cancel()
-    self.keepaliveTimeoutTimer.cancel()
+    self.maxIdleTimerHandler?.cancel()
+    self.maxAgeTimerHandler?.cancel()
+    self.maxGraceTimerHandler?.cancel()
+    self.keepaliveTimerHandler?.cancel()
+    self.keepaliveTimeoutHandler?.cancel()
     context.fireChannelInactive()
   }
 
@@ -293,7 +311,7 @@ package final class ServerConnectionManagementHandler: ChannelDuplexHandler {
       self._streamClosed(event.streamID, channel: context.channel)
 
     case is ChannelShouldQuiesceEvent:
-      self.initiateGracefulShutdown(context: context)
+      self.initiateGracefulShutdown()
 
     case TLSUserEvent.handshakeCompleted(let negotiatedProtocol):
       if negotiatedProtocol == nil, self.requireALPN {
@@ -349,8 +367,8 @@ package final class ServerConnectionManagementHandler: ChannelDuplexHandler {
     self.inReadLoop = true
 
     // Any read data indicates that the connection is alive so cancel the keep-alive timers.
-    self.keepaliveTimer?.cancel()
-    self.keepaliveTimeoutTimer.cancel()
+    self.keepaliveTimerHandler?.cancel()
+    self.keepaliveTimeoutHandler?.cancel()
 
     let frame = self.unwrapInboundIn(data)
     switch frame.payload {
@@ -377,10 +395,7 @@ package final class ServerConnectionManagementHandler: ChannelDuplexHandler {
     self.inReadLoop = false
 
     // Done reading: schedule the keep-alive timer.
-    let view = LoopBoundView(handler: self, context: context)
-    self.keepaliveTimer?.schedule(on: context.eventLoop) {
-      view.keepaliveTimerFired()
-    }
+    self.keepaliveTimerHandler?.start()
 
     context.fireChannelReadComplete()
   }
@@ -390,26 +405,71 @@ package final class ServerConnectionManagementHandler: ChannelDuplexHandler {
   }
 }
 
+// Timer handler views.
 extension ServerConnectionManagementHandler {
-  struct LoopBoundView: @unchecked Sendable {
+  struct MaxIdleTimerHandlerView: @unchecked Sendable, NIOScheduledCallbackHandler {
     private let handler: ServerConnectionManagementHandler
-    private let context: ChannelHandlerContext
 
-    init(handler: ServerConnectionManagementHandler, context: ChannelHandlerContext) {
+    init(_ handler: ServerConnectionManagementHandler) {
       self.handler = handler
-      self.context = context
     }
 
-    func initiateGracefulShutdown() {
-      self.context.eventLoop.assertInEventLoop()
-      self.handler.initiateGracefulShutdown(context: self.context)
+    func handleScheduledCallback(eventLoop: some EventLoop) {
+      self.handler.eventLoop.assertInEventLoop()
+      self.handler.initiateGracefulShutdown()
     }
+  }
+
+  struct MaxAgeTimerHandlerView: @unchecked Sendable, NIOScheduledCallbackHandler {
+    private let handler: ServerConnectionManagementHandler
 
-    func keepaliveTimerFired() {
-      self.context.eventLoop.assertInEventLoop()
-      self.handler.keepaliveTimerFired(context: self.context)
+    init(_ handler: ServerConnectionManagementHandler) {
+      self.handler = handler
     }
 
+    func handleScheduledCallback(eventLoop: some EventLoop) {
+      self.handler.eventLoop.assertInEventLoop()
+      self.handler.initiateGracefulShutdown()
+    }
+  }
+
+  struct MaxGraceTimerHandlerView: @unchecked Sendable, NIOScheduledCallbackHandler {
+    private let handler: ServerConnectionManagementHandler
+
+    init(_ handler: ServerConnectionManagementHandler) {
+      self.handler = handler
+    }
+
+    func handleScheduledCallback(eventLoop: some EventLoop) {
+      self.handler.eventLoop.assertInEventLoop()
+      self.handler.context?.close(promise: nil)
+    }
+  }
+
+  struct KeepaliveTimerHandlerView: @unchecked Sendable, NIOScheduledCallbackHandler {
+    private let handler: ServerConnectionManagementHandler
+
+    init(_ handler: ServerConnectionManagementHandler) {
+      self.handler = handler
+    }
+
+    func handleScheduledCallback(eventLoop: some EventLoop) {
+      self.handler.eventLoop.assertInEventLoop()
+      self.handler.keepaliveTimerFired()
+    }
+  }
+
+  struct KeepaliveTimeoutHandlerView: @unchecked Sendable, NIOScheduledCallbackHandler {
+    private let handler: ServerConnectionManagementHandler
+
+    init(_ handler: ServerConnectionManagementHandler) {
+      self.handler = handler
+    }
+
+    func handleScheduledCallback(eventLoop: some EventLoop) {
+      self.handler.eventLoop.assertInEventLoop()
+      self.handler.initiateGracefulShutdown()
+    }
   }
 }
 
@@ -450,7 +510,7 @@ extension ServerConnectionManagementHandler {
 
   private func _streamCreated(_ id: HTTP2StreamID, channel: any Channel) {
     // The connection isn't idle if a stream is open.
-    self.maxIdleTimer?.cancel()
+    self.maxIdleTimerHandler?.cancel()
     self.state.streamOpened(id)
   }
 
@@ -459,11 +519,7 @@ extension ServerConnectionManagementHandler {
 
     switch self.state.streamClosed(id) {
     case .startIdleTimer:
-      let loopBound = LoopBoundView(handler: self, context: context)
-      self.maxIdleTimer?.schedule(on: context.eventLoop) {
-        loopBound.initiateGracefulShutdown()
-      }
-
+      self.maxIdleTimerHandler?.start()
     case .close:
       context.close(mode: .all, promise: nil)
 
@@ -482,14 +538,15 @@ extension ServerConnectionManagementHandler {
     }
   }
 
-  private func initiateGracefulShutdown(context: ChannelHandlerContext) {
+  private func initiateGracefulShutdown() {
+    guard let context = self.context else { return }
     context.eventLoop.assertInEventLoop()
 
     // Cancel any timers if initiating shutdown.
-    self.maxIdleTimer?.cancel()
-    self.maxAgeTimer?.cancel()
-    self.keepaliveTimer?.cancel()
-    self.keepaliveTimeoutTimer.cancel()
+    self.maxIdleTimerHandler?.cancel()
+    self.maxAgeTimerHandler?.cancel()
+    self.keepaliveTimerHandler?.cancel()
+    self.keepaliveTimeoutHandler?.cancel()
 
     switch self.state.startGracefulShutdown() {
     case .sendGoAwayAndPing(let pingData):
@@ -562,10 +619,7 @@ extension ServerConnectionManagementHandler {
       } else {
         // RPCs may have a grace period for finishing once the second GOAWAY frame has finished.
         // If this is set close the connection abruptly once the grace period passes.
-        let loopBound = NIOLoopBound(context, eventLoop: context.eventLoop)
-        self.maxGraceTimer?.schedule(on: context.eventLoop) {
-          loopBound.value.close(promise: nil)
-        }
+        self.maxGraceTimerHandler?.start()
       }
 
     case .none:
@@ -573,15 +627,13 @@ extension ServerConnectionManagementHandler {
     }
   }
 
-  private func keepaliveTimerFired(context: ChannelHandlerContext) {
+  private func keepaliveTimerFired() {
+    guard let context = self.context else { return }
     let ping = HTTP2Frame(streamID: .rootStream, payload: .ping(self.keepalivePingData, ack: false))
     context.write(self.wrapInboundOut(ping), promise: nil)
     self.maybeFlush(context: context)
 
     // Schedule a timeout on waiting for the response.
-    let loopBound = LoopBoundView(handler: self, context: context)
-    self.keepaliveTimeoutTimer.schedule(on: context.eventLoop) {
-      loopBound.initiateGracefulShutdown()
-    }
+    self.keepaliveTimeoutHandler?.start()
   }
 }

+ 63 - 45
Tests/GRPCNIOTransportCoreTests/Internal/TimerTests.swift

@@ -16,82 +16,100 @@
 
 import GRPCCore
 import GRPCNIOTransportCore
+import NIOCore
 import NIOEmbedded
 import Synchronization
 import XCTest
 
 internal final class TimerTests: XCTestCase {
-  func testScheduleOneOffTimer() {
+  fileprivate struct CounterTimerHandler: NIOScheduledCallbackHandler {
+    let counter = AtomicCounter(0)
+
+    func handleScheduledCallback(eventLoop: some EventLoop) {
+      counter.increment()
+    }
+  }
+
+  func testOneOffTimer() {
     let loop = EmbeddedEventLoop()
     defer { try! loop.close() }
 
-    let value = Atomic(0)
-    var timer = Timer(delay: .seconds(1), repeat: false)
-    timer.schedule(on: loop) {
-      let (old, _) = value.add(1, ordering: .releasing)
-      XCTAssertEqual(old, 0)
-    }
+    let handler = CounterTimerHandler()
+    let timer = Timer(eventLoop: loop, duration: .seconds(1), repeating: false, handler: handler)
+    timer.start()
 
+    // Timer hasn't fired because we haven't reached the required duration.
     loop.advanceTime(by: .milliseconds(999))
-    XCTAssertEqual(value.load(ordering: .acquiring), 0)
+    XCTAssertEqual(handler.counter.value, 0)
+
+    // Timer has fired once.
     loop.advanceTime(by: .milliseconds(1))
-    XCTAssertEqual(value.load(ordering: .acquiring), 1)
+    XCTAssertEqual(handler.counter.value, 1)
 
-    // Run again to make sure the task wasn't repeated.
+    // Timer does not repeat.
     loop.advanceTime(by: .seconds(1))
-    XCTAssertEqual(value.load(ordering: .acquiring), 1)
-  }
+    XCTAssertEqual(handler.counter.value, 1)
 
-  func testCancelOneOffTimer() {
-    let loop = EmbeddedEventLoop()
-    defer { try! loop.close() }
-
-    var timer = Timer(delay: .seconds(1), repeat: false)
-    timer.schedule(on: loop) {
-      XCTFail("Timer wasn't cancelled")
-    }
+    // Timer can be restarted and then fires again after the duration.
+    timer.start()
+    loop.advanceTime(by: .seconds(1))
+    XCTAssertEqual(handler.counter.value, 2)
 
+    // Timer can be cancelled before the duration and then does not fire.
+    timer.start()
     loop.advanceTime(by: .milliseconds(999))
     timer.cancel()
     loop.advanceTime(by: .milliseconds(1))
+    XCTAssertEqual(handler.counter.value, 2)
+
+    // Timer can be restarted after being cancelled.
+    timer.start()
+    loop.advanceTime(by: .seconds(1))
+    XCTAssertEqual(handler.counter.value, 3)
   }
 
-  func testScheduleRepeatedTimer() throws {
+  func testRepeatedTimer() {
     let loop = EmbeddedEventLoop()
     defer { try! loop.close() }
 
-    let counter = AtomicCounter()
-    var timer = Timer(delay: .seconds(1), repeat: true)
-    timer.schedule(on: loop) {
-      counter.increment()
-    }
+    let handler = CounterTimerHandler()
+    let timer = Timer(eventLoop: loop, duration: .seconds(1), repeating: true, handler: handler)
+    timer.start()
 
+    // Timer hasn't fired because we haven't reached the required duration.
     loop.advanceTime(by: .milliseconds(999))
-    XCTAssertEqual(counter.value, 0)
+    XCTAssertEqual(handler.counter.value, 0)
+
+    // Timer has fired once.
+    loop.advanceTime(by: .milliseconds(1))
+    XCTAssertEqual(handler.counter.value, 1)
+
+    // Timer hasn't fired again because we haven't reached the required duration again.
+    loop.advanceTime(by: .milliseconds(999))
+    XCTAssertEqual(handler.counter.value, 1)
+
+    // Timer has fired again.
     loop.advanceTime(by: .milliseconds(1))
-    XCTAssertEqual(counter.value, 1)
+    XCTAssertEqual(handler.counter.value, 2)
 
+    // Timer continues to fire on each second.
     loop.advanceTime(by: .seconds(1))
-    XCTAssertEqual(counter.value, 2)
+    XCTAssertEqual(handler.counter.value, 3)
     loop.advanceTime(by: .seconds(1))
-    XCTAssertEqual(counter.value, 3)
-
-    timer.cancel()
+    XCTAssertEqual(handler.counter.value, 4)
     loop.advanceTime(by: .seconds(1))
-    XCTAssertEqual(counter.value, 3)
-  }
+    XCTAssertEqual(handler.counter.value, 5)
+    loop.advanceTime(by: .seconds(5))
+    XCTAssertEqual(handler.counter.value, 10)
 
-  func testCancelRepeatedTimer() {
-    let loop = EmbeddedEventLoop()
-    defer { try! loop.close() }
-
-    var timer = Timer(delay: .seconds(1), repeat: true)
-    timer.schedule(on: loop) {
-      XCTFail("Timer wasn't cancelled")
-    }
-
-    loop.advanceTime(by: .milliseconds(999))
+    // Timer does not fire again, after being cancelled.
     timer.cancel()
-    loop.advanceTime(by: .milliseconds(1))
+    loop.advanceTime(by: .seconds(5))
+    XCTAssertEqual(handler.counter.value, 10)
+
+    // Timer can be restarted after being cancelled and continues to fire once per second.
+    timer.start()
+    loop.advanceTime(by: .seconds(5))
+    XCTAssertEqual(handler.counter.value, 15)
   }
 }