Browse Source

Fix keepalive logic (#50)

This PR fixes https://github.com/grpc/grpc-swift/issues/2095.

## Motivation

As per the gRPC specification, the server must keep track of pings from
each client and, if the number of ping strikes goes over a threshold,
send a GOAWAY frame and close the connection. The number of ping strikes
must be reset every time the server writes a HEADERS or DATA frame.
However, there is a bug in the current keepalive implementation: it does
not properly keep track of when HEADERS/DATA frames are written, so the
strikes are never reset, causing the server to always end up closing
connections when keepalive pings are enabled.

There was also a second bug: the GOAWAY frame wasn't actually sent to
the client because we were closing the connection straight away, before
the frame had been flushed.

## Modifications

This PR fixes both bugs:
- It updates the connection's FrameStats whenever a HEADERS or DATA
  frame is written, as described above.
- It delays the channel close after sending the GOAWAY frame by one
  event loop tick to make sure the frame gets flushed and delivered to
  the client.

## Results

Fewer bugs!
Gus Cairo 11 months ago
parent
commit
47e0be124e

+ 6 - 1
Sources/GRPCNIOTransportCore/Internal/NIOChannelPipeline+GRPC.swift

@@ -71,6 +71,10 @@ extension ChannelPipeline.SynchronousOperations {
     var http2HandlerStreamConfiguration = NIOHTTP2Handler.StreamConfiguration()
     http2HandlerStreamConfiguration.targetWindowSize = clampedTargetWindowSize
 
+    let boundConnectionManagementHandler = NIOLoopBound(
+      serverConnectionHandler.syncView,
+      eventLoop: self.eventLoop
+    )
     let streamMultiplexer = try self.configureAsyncHTTP2Pipeline(
       mode: .server,
       streamDelegate: serverConnectionHandler.http2StreamDelegate,
@@ -86,7 +90,8 @@ extension ChannelPipeline.SynchronousOperations {
           acceptedEncodings: compressionConfig.enabledAlgorithms,
           maxPayloadSize: rpcConfig.maxRequestPayloadSize,
           methodDescriptorPromise: methodDescriptorPromise,
-          eventLoop: streamChannel.eventLoop
+          eventLoop: streamChannel.eventLoop,
+          connectionManagementHandler: boundConnectionManagementHandler.value
         )
         try streamChannel.pipeline.syncOperations.addHandler(streamHandler)
 

+ 9 - 3
Sources/GRPCNIOTransportCore/Server/Connection/ServerConnectionManagementHandler.swift

@@ -121,9 +121,9 @@ package final class ServerConnectionManagementHandler: ChannelDuplexHandler {
   }
 
   /// Stats about recently written frames. Used to determine whether to reset keep-alive state.
-  private var frameStats: FrameStats
+  package var frameStats: FrameStats
 
-  struct FrameStats {
+  package struct FrameStats {
     private(set) var didWriteHeadersOrData = false
 
     /// Mark that a HEADERS frame has been written.
@@ -609,7 +609,13 @@ extension ServerConnectionManagementHandler {
 
       context.write(self.wrapOutboundOut(goAway), promise: nil)
       self.maybeFlush(context: context)
-      context.close(promise: nil)
+
+      // We must delay the channel close after sending the GOAWAY packet by a tick to make sure it
+      // gets flushed and delivered to the client before the connection is closed.
+      let loopBound = NIOLoopBound(context, eventLoop: context.eventLoop)
+      context.eventLoop.execute {
+        loopBound.value.close(promise: nil)
+      }
 
     case .sendAck:
       ()  // ACKs are sent by NIO's HTTP/2 handler, don't double ack.

+ 9 - 0
Sources/GRPCNIOTransportCore/Server/GRPCServerStreamHandler.swift

@@ -42,6 +42,8 @@ package final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChan
 
   private var cancellationHandle: Optional<ServerContext.RPCCancellationHandle>
 
+  package let connectionManagementHandler: ServerConnectionManagementHandler.SyncView
+
   // Existential errors unconditionally allocate, avoid this per-use allocation by doing it
   // statically.
   private static let handlerRemovedBeforeDescriptorResolved: any Error = RPCError(
@@ -55,6 +57,7 @@ package final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChan
     maxPayloadSize: Int,
     methodDescriptorPromise: EventLoopPromise<MethodDescriptor>,
     eventLoop: any EventLoop,
+    connectionManagementHandler: ServerConnectionManagementHandler.SyncView,
     cancellationHandler: ServerContext.RPCCancellationHandle? = nil,
     skipStateMachineAssertions: Bool = false
   ) {
@@ -66,6 +69,7 @@ package final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChan
     self.methodDescriptorPromise = methodDescriptorPromise
     self.cancellationHandle = cancellationHandler
     self.eventLoop = eventLoop
+    self.connectionManagementHandler = connectionManagementHandler
   }
 
   package func setCancellationHandle(_ handle: ServerContext.RPCCancellationHandle) {
@@ -136,13 +140,16 @@ extension GRPCServerStreamHandler {
               switch self.stateMachine.nextInboundMessage() {
               case .receiveMessage(let message):
                 context.fireChannelRead(self.wrapInboundOut(.message(message)))
+
               case .awaitMoreMessages:
                 break loop
+
               case .noMoreMessages:
                 context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)
                 break loop
               }
             }
+
           case .doNothing:
             ()
           }
@@ -261,6 +268,7 @@ extension GRPCServerStreamHandler {
         self.flushPending = true
         let headers = try self.stateMachine.send(metadata: metadata)
         context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: promise)
+        self.connectionManagementHandler.wroteHeadersFrame()
       } catch let invalidState {
         let error = RPCError(invalidState)
         promise?.fail(error)
@@ -270,6 +278,7 @@ extension GRPCServerStreamHandler {
     case .message(let message):
       do {
         try self.stateMachine.send(message: message, promise: promise)
+        self.connectionManagementHandler.wroteDataFrame()
       } catch let invalidState {
         let error = RPCError(invalidState)
         promise?.fail(error)

+ 13 - 1
Tests/GRPCNIOTransportCoreTests/Client/Connection/Utilities/ConnectionTest.swift

@@ -114,12 +114,24 @@ extension ConnectionTest {
             let h2 = NIOHTTP2Handler(mode: .server)
             let mux = HTTP2StreamMultiplexer(mode: .server, channel: channel) { stream in
               let sync = stream.pipeline.syncOperations
+              let connectionManagementHandler = ServerConnectionManagementHandler(
+                eventLoop: stream.eventLoop,
+                maxIdleTime: nil,
+                maxAge: nil,
+                maxGraceTime: nil,
+                keepaliveTime: nil,
+                keepaliveTimeout: nil,
+                allowKeepaliveWithoutCalls: false,
+                minPingIntervalWithoutCalls: .minutes(5),
+                requireALPN: false
+              )
               let handler = GRPCServerStreamHandler(
                 scheme: .http,
                 acceptedEncodings: .none,
                 maxPayloadSize: .max,
                 methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self),
-                eventLoop: stream.eventLoop
+                eventLoop: stream.eventLoop,
+                connectionManagementHandler: connectionManagementHandler.syncView
               )
 
               return stream.eventLoop.makeCompletedFuture {

+ 13 - 1
Tests/GRPCNIOTransportCoreTests/Client/Connection/Utilities/TestServer.swift

@@ -70,12 +70,24 @@ final class TestServer: Sendable {
         let sync = channel.pipeline.syncOperations
         let multiplexer = try sync.configureAsyncHTTP2Pipeline(mode: .server) { stream in
           stream.eventLoop.makeCompletedFuture {
+            let connectionManagementHandler = ServerConnectionManagementHandler(
+              eventLoop: stream.eventLoop,
+              maxIdleTime: nil,
+              maxAge: nil,
+              maxGraceTime: nil,
+              keepaliveTime: nil,
+              keepaliveTimeout: nil,
+              allowKeepaliveWithoutCalls: false,
+              minPingIntervalWithoutCalls: .minutes(5),
+              requireALPN: false
+            )
             let handler = GRPCServerStreamHandler(
               scheme: .http,
               acceptedEncodings: .all,
               maxPayloadSize: .max,
               methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self),
-              eventLoop: stream.eventLoop
+              eventLoop: stream.eventLoop,
+              connectionManagementHandler: connectionManagementHandler.syncView
             )
 
             try stream.pipeline.syncOperations.addHandlers(handler)

+ 86 - 6
Tests/GRPCNIOTransportCoreTests/Server/GRPCServerStreamHandlerTests.swift

@@ -33,12 +33,25 @@ final class GRPCServerStreamHandlerTests: XCTestCase {
     descriptorPromise: EventLoopPromise<MethodDescriptor>? = nil,
     disableAssertions: Bool = false
   ) -> GRPCServerStreamHandler {
+    let serverConnectionManagementHandler = ServerConnectionManagementHandler(
+      eventLoop: channel.eventLoop,
+      maxIdleTime: nil,
+      maxAge: nil,
+      maxGraceTime: nil,
+      keepaliveTime: nil,
+      keepaliveTimeout: nil,
+      allowKeepaliveWithoutCalls: false,
+      minPingIntervalWithoutCalls: .minutes(5),
+      requireALPN: false
+    )
+
     return GRPCServerStreamHandler(
       scheme: scheme,
       acceptedEncodings: acceptedEncodings,
       maxPayloadSize: maxPayloadSize,
       methodDescriptorPromise: descriptorPromise ?? channel.eventLoop.makePromise(),
       eventLoop: channel.eventLoop,
+      connectionManagementHandler: serverConnectionManagementHandler.syncView,
       skipStateMachineAssertions: disableAssertions
     )
   }
@@ -974,28 +987,50 @@ final class GRPCServerStreamHandlerTests: XCTestCase {
 }
 
 struct ServerStreamHandlerTests {
-  private func makeServerStreamHandler(
+  struct ConnectionAndStreamHandlers {
+    let streamHandler: GRPCServerStreamHandler
+    let connectionHandler: ServerConnectionManagementHandler
+  }
+
+  private func makeServerConnectionAndStreamHandlers(
     channel: any Channel,
     scheme: Scheme = .http,
     acceptedEncodings: CompressionAlgorithmSet = [],
     maxPayloadSize: Int = .max,
     descriptorPromise: EventLoopPromise<MethodDescriptor>? = nil,
     disableAssertions: Bool = false
-  ) -> GRPCServerStreamHandler {
-    return GRPCServerStreamHandler(
+  ) -> ConnectionAndStreamHandlers {
+    let connectionManagementHandler = ServerConnectionManagementHandler(
+      eventLoop: channel.eventLoop,
+      maxIdleTime: nil,
+      maxAge: nil,
+      maxGraceTime: nil,
+      keepaliveTime: nil,
+      keepaliveTimeout: nil,
+      allowKeepaliveWithoutCalls: false,
+      minPingIntervalWithoutCalls: .minutes(5),
+      requireALPN: false
+    )
+    let streamHandler = GRPCServerStreamHandler(
       scheme: scheme,
       acceptedEncodings: acceptedEncodings,
       maxPayloadSize: maxPayloadSize,
       methodDescriptorPromise: descriptorPromise ?? channel.eventLoop.makePromise(),
       eventLoop: channel.eventLoop,
+      connectionManagementHandler: connectionManagementHandler.syncView,
       skipStateMachineAssertions: disableAssertions
     )
+
+    return ConnectionAndStreamHandlers(
+      streamHandler: streamHandler,
+      connectionHandler: connectionManagementHandler
+    )
   }
 
   @Test("ChannelShouldQuiesceEvent is buffered and turns into RPC cancellation")
   func shouldQuiesceEventIsBufferedBeforeHandleIsSet() async throws {
     let channel = EmbeddedChannel()
-    let handler = self.makeServerStreamHandler(channel: channel)
+    let handler = self.makeServerConnectionAndStreamHandlers(channel: channel).streamHandler
     try channel.pipeline.syncOperations.addHandler(handler)
     channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
 
@@ -1011,7 +1046,7 @@ struct ServerStreamHandlerTests {
   @Test("ChannelShouldQuiesceEvent turns into RPC cancellation")
   func shouldQuiesceEventTriggersCancellation() async throws {
     let channel = EmbeddedChannel()
-    let handler = self.makeServerStreamHandler(channel: channel)
+    let handler = self.makeServerConnectionAndStreamHandlers(channel: channel).streamHandler
     try channel.pipeline.syncOperations.addHandler(handler)
 
     await withServerContextRPCCancellationHandle { handle in
@@ -1028,7 +1063,7 @@ struct ServerStreamHandlerTests {
   @Test("RST_STREAM turns into RPC cancellation")
   func rstStreamTriggersCancellation() async throws {
     let channel = EmbeddedChannel()
-    let handler = self.makeServerStreamHandler(channel: channel)
+    let handler = self.makeServerConnectionAndStreamHandlers(channel: channel).streamHandler
     try channel.pipeline.syncOperations.addHandler(handler)
 
     await withServerContextRPCCancellationHandle { handle in
@@ -1045,6 +1080,51 @@ struct ServerStreamHandlerTests {
     _ = try? channel.finish()
   }
 
+  @Test("Connection FrameStats are updated when writing headers or data frames")
+  func connectionFrameStatsAreUpdatedAccordingly() async throws {
+    let channel = EmbeddedChannel()
+    let handlers = self.makeServerConnectionAndStreamHandlers(channel: channel)
+    try channel.pipeline.syncOperations.addHandler(handlers.streamHandler)
+
+    // We have written nothing yet, so expect FrameStats/didWriteHeadersOrData to be false
+    #expect(!handlers.connectionHandler.frameStats.didWriteHeadersOrData)
+
+    // FrameStats aren't affected by pings received
+    channel.pipeline.fireChannelRead(
+      NIOAny(HTTP2Frame.FramePayload.ping(.init(withInteger: 42), ack: false))
+    )
+    #expect(!handlers.connectionHandler.frameStats.didWriteHeadersOrData)
+
+    // Now write back headers and make sure FrameStats are updated accordingly:
+    // To do that, we first need to receive client's initial metadata...
+    let clientInitialMetadata: HPACKHeaders = [
+      GRPCHTTP2Keys.path.rawValue: "/SomeService/SomeMethod",
+      GRPCHTTP2Keys.scheme.rawValue: "http",
+      GRPCHTTP2Keys.method.rawValue: "POST",
+      GRPCHTTP2Keys.contentType.rawValue: "application/grpc",
+      GRPCHTTP2Keys.te.rawValue: "trailers",
+    ]
+    try channel.writeInbound(
+      HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata))
+    )
+
+    // Now we write back server's initial metadata...
+    let serverInitialMetadata = RPCResponsePart.metadata([:])
+    try channel.writeOutbound(serverInitialMetadata)
+
+    // And this should have updated the FrameStats
+    #expect(handlers.connectionHandler.frameStats.didWriteHeadersOrData)
+
+    // Manually reset the FrameStats to make sure that writing data also updates it correctly.
+    handlers.connectionHandler.frameStats.reset()
+    #expect(!handlers.connectionHandler.frameStats.didWriteHeadersOrData)
+    try channel.writeOutbound(RPCResponsePart.message([42]))
+    #expect(handlers.connectionHandler.frameStats.didWriteHeadersOrData)
+
+    // Clean up.
+    // Throwing is fine: the channel is closed abruptly, errors are expected.
+    _ = try? channel.finish()
+  }
 }
 
 extension EmbeddedChannel {