Browse Source

Conform client/server connection handlers to h2 stream delegate (#1875)

Motivation:

The natural NIO API to use when dealing with HTTP/2 multiplexing doesn't
emit user inbound events when streams are created and closed. Instead
this is signalled via the `NIOHTTP2StreamDelegate`.

Modifications:

- Conform the `ClientConnectionHandler` and the
  `ServerConnectionManagementHandler` to `NIOHTTP2StreamDelegate`.
- Update tests

Result:

`ClientConnectionHandler` and `ServerConnectionManagementHandler` can be
used as stream delegates with NIO HTTP/2 is configured for async
multiplexing.
George Barnett 1 year ago
parent
commit
8b6a8f4098

+ 46 - 23
Sources/GRPCHTTP2Core/Client/Connection/ClientConnectionHandler.swift

@@ -86,6 +86,9 @@ final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutboundHandl
   /// Resets once `channelReadComplete` returns.
   private var inReadLoop: Bool
 
+  /// The context of the channel this handler is in.
+  private var context: ChannelHandlerContext?
+
   /// Creates a new handler which manages the lifecycle of a connection.
   ///
   /// - Parameters:
@@ -118,6 +121,11 @@ final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutboundHandl
 
   func handlerAdded(context: ChannelHandlerContext) {
     assert(context.eventLoop === self.eventLoop)
+    self.context = context
+  }
+
+  func handlerRemoved(context: ChannelHandlerContext) {
+    self.context = nil
   }
 
   func channelActive(context: ChannelHandlerContext) {
@@ -144,31 +152,10 @@ final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutboundHandl
   func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
     switch event {
     case let event as NIOHTTP2StreamCreatedEvent:
-      // Stream created, so the connection isn't idle.
-      self.maxIdleTimer?.cancel()
-      self.state.streamOpened(event.streamID)
+      self.streamCreated(event.streamID, channel: context.channel)
 
     case let event as StreamClosedEvent:
-      switch self.state.streamClosed(event.streamID) {
-      case .startIdleTimer(let cancelKeepalive):
-        // All streams are closed, restart the idle timer, and stop the keep-alive timer (it may
-        // not stop if keep-alive is allowed when there are no active calls).
-        self.maxIdleTimer?.schedule(on: context.eventLoop) {
-          self.maxIdleTimerFired(context: context)
-        }
-
-        if cancelKeepalive {
-          self.keepaliveTimer?.cancel()
-        }
-
-      case .close:
-        // Connection was closing but waiting for all streams to close. They must all be closed
-        // now so close the connection.
-        context.close(promise: nil)
-
-      case .none:
-        ()
-      }
+      self.streamClosed(event.streamID, channel: context.channel)
 
     default:
       ()
@@ -263,6 +250,42 @@ final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutboundHandl
   }
 }
 
+extension ClientConnectionHandler: NIOHTTP2StreamDelegate {
+  func streamCreated(_ id: HTTP2StreamID, channel: any Channel) {
+    self.eventLoop.assertInEventLoop()
+
+    // Stream created, so the connection isn't idle.
+    self.maxIdleTimer?.cancel()
+    self.state.streamOpened(id)
+  }
+
+  func streamClosed(_ id: HTTP2StreamID, channel: any Channel) {
+    guard let context = self.context else { return }
+    self.eventLoop.assertInEventLoop()
+
+    switch self.state.streamClosed(id) {
+    case .startIdleTimer(let cancelKeepalive):
+      // All streams are closed, restart the idle timer, and stop the keep-alive timer (it may
+      // not stop if keep-alive is allowed when there are no active calls).
+      self.maxIdleTimer?.schedule(on: context.eventLoop) {
+        self.maxIdleTimerFired(context: context)
+      }
+
+      if cancelKeepalive {
+        self.keepaliveTimer?.cancel()
+      }
+
+    case .close:
+      // Connection was closing but waiting for all streams to close. They must all be closed
+      // now so close the connection.
+      context.close(promise: nil)
+
+    case .none:
+      ()
+    }
+  }
+}
+
 extension ClientConnectionHandler {
   private func maybeFlush(context: ChannelHandlerContext) {
     if self.inReadLoop {

+ 35 - 15
Sources/GRPCHTTP2Core/Server/Connection/ServerConnectionManagementHandler.swift

@@ -75,6 +75,9 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {
   /// Resets once `channelReadComplete` returns.
   private var inReadLoop: Bool
 
+  /// The context of the channel this handler is in.
+  private var context: ChannelHandlerContext?
+
   /// The current state of the connection.
   private var state: StateMachine
 
@@ -236,6 +239,11 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {
 
   func handlerAdded(context: ChannelHandlerContext) {
     assert(context.eventLoop === self.eventLoop)
+    self.context = context
+  }
+
+  func handlerRemoved(context: ChannelHandlerContext) {
+    self.context = nil
   }
 
   func channelActive(context: ChannelHandlerContext) {
@@ -266,23 +274,10 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {
   func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
     switch event {
     case let event as NIOHTTP2StreamCreatedEvent:
-      // The connection isn't idle if a stream is open.
-      self.maxIdleTimer?.cancel()
-      self.state.streamOpened(event.streamID)
+      self.streamCreated(event.streamID, channel: context.channel)
 
     case let event as StreamClosedEvent:
-      switch self.state.streamClosed(event.streamID) {
-      case .startIdleTimer:
-        self.maxIdleTimer?.schedule(on: context.eventLoop) {
-          self.initiateGracefulShutdown(context: context)
-        }
-
-      case .close:
-        context.close(mode: .all, promise: nil)
-
-      case .none:
-        ()
-      }
+      self.streamClosed(event.streamID, channel: context.channel)
 
     default:
       ()
@@ -335,6 +330,31 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {
   }
 }
 
+extension ServerConnectionManagementHandler: NIOHTTP2StreamDelegate {
+  func streamCreated(_ id: HTTP2StreamID, channel: any Channel) {
+    // The connection isn't idle if a stream is open.
+    self.maxIdleTimer?.cancel()
+    self.state.streamOpened(id)
+  }
+
+  func streamClosed(_ id: HTTP2StreamID, channel: any Channel) {
+    guard let context = self.context else { return }
+
+    switch self.state.streamClosed(id) {
+    case .startIdleTimer:
+      self.maxIdleTimer?.schedule(on: context.eventLoop) {
+        self.initiateGracefulShutdown(context: context)
+      }
+
+    case .close:
+      context.close(mode: .all, promise: nil)
+
+    case .none:
+      ()
+    }
+  }
+}
+
 extension ServerConnectionManagementHandler {
   private func maybeFlush(context: ChannelHandlerContext) {
     if self.inReadLoop {

+ 4 - 8
Tests/GRPCHTTP2CoreTests/Client/Connection/ClientConnectionHandlerTests.swift

@@ -226,6 +226,7 @@ final class ClientConnectionHandlerTests: XCTestCase {
 extension ClientConnectionHandlerTests {
   struct Connection {
     let channel: EmbeddedChannel
+    let streamDelegate: any NIOHTTP2StreamDelegate
     var loop: EmbeddedEventLoop {
       self.channel.embeddedEventLoop
     }
@@ -245,6 +246,7 @@ extension ClientConnectionHandlerTests {
         keepaliveWithoutCalls: allowKeepaliveWithoutCalls
       )
 
+      self.streamDelegate = handler
       self.channel = EmbeddedChannel(handler: handler, loop: loop)
     }
 
@@ -253,17 +255,11 @@ extension ClientConnectionHandlerTests {
     }
 
     func streamOpened(_ id: HTTP2StreamID) {
-      let event = NIOHTTP2StreamCreatedEvent(
-        streamID: id,
-        localInitialWindowSize: nil,
-        remoteInitialWindowSize: nil
-      )
-      self.channel.pipeline.fireUserInboundEventTriggered(event)
+      self.streamDelegate.streamCreated(id, channel: self.channel)
     }
 
     func streamClosed(_ id: HTTP2StreamID) {
-      let event = StreamClosedEvent(streamID: id, reason: nil)
-      self.channel.pipeline.fireUserInboundEventTriggered(event)
+      self.streamDelegate.streamClosed(id, channel: self.channel)
     }
 
     func goAway(

+ 4 - 8
Tests/GRPCHTTP2CoreTests/Server/Connection/ServerConnectionManagementHandlerTests.swift

@@ -341,6 +341,7 @@ extension ServerConnectionManagementHandlerTests {
 extension ServerConnectionManagementHandlerTests {
   struct Connection {
     let channel: EmbeddedChannel
+    let streamDelegate: any NIOHTTP2StreamDelegate
     let syncView: ServerConnectionManagementHandler.SyncView
 
     var loop: EmbeddedEventLoop {
@@ -378,6 +379,7 @@ extension ServerConnectionManagementHandlerTests {
         clock: self.clock
       )
 
+      self.streamDelegate = handler
       self.syncView = handler.syncView
       self.channel = EmbeddedChannel(handler: handler, loop: loop)
     }
@@ -398,17 +400,11 @@ extension ServerConnectionManagementHandlerTests {
     }
 
     func streamOpened(_ id: HTTP2StreamID) {
-      let event = NIOHTTP2StreamCreatedEvent(
-        streamID: id,
-        localInitialWindowSize: nil,
-        remoteInitialWindowSize: nil
-      )
-      self.channel.pipeline.fireUserInboundEventTriggered(event)
+      self.streamDelegate.streamCreated(id, channel: self.channel)
     }
 
     func streamClosed(_ id: HTTP2StreamID) {
-      let event = StreamClosedEvent(streamID: id, reason: nil)
-      self.channel.pipeline.fireUserInboundEventTriggered(event)
+      self.streamDelegate.streamClosed(id, channel: self.channel)
     }
 
     func ping(data: HTTP2PingData, ack: Bool) throws {