瀏覽代碼

Move server connection management tests to swift-testing (#14)

Motivation:

I'd like to add more tests to the server connection management handler.
Ideally these would be written using swift-testing.

Modifications:

- Migrate server connection management handler tests
- Use `package` access to avoid `@testable` import

Result:

Fewer XCTest tests
George Barnett 1 年之前
父節點
當前提交
bed03b0d3a

+ 31 - 31
Sources/GRPCNIOTransportCore/Server/Connection/ServerConnectionManagementHandler.swift

@@ -14,10 +14,10 @@
  * limitations under the License.
  */
 
-internal import GRPCCore
-internal import NIOCore
-internal import NIOHTTP2
-internal import NIOTLS
+private import GRPCCore
+package import NIOCore
+package import NIOHTTP2
+private import NIOTLS
 
 /// A `ChannelHandler` which manages the lifecycle of a gRPC connection over HTTP/2.
 ///
@@ -39,11 +39,11 @@ internal import NIOTLS
 /// Some of the behaviours are described in:
 /// - [gRFC A8](https://github.com/grpc/proposal/blob/master/A8-client-side-keepalive.md), and
 /// - [gRFC A9](https://github.com/grpc/proposal/blob/master/A9-server-side-conn-mgt.md).
-final class ServerConnectionManagementHandler: ChannelDuplexHandler {
-  typealias InboundIn = HTTP2Frame
-  typealias InboundOut = HTTP2Frame
-  typealias OutboundIn = HTTP2Frame
-  typealias OutboundOut = HTTP2Frame
+package final class ServerConnectionManagementHandler: ChannelDuplexHandler {
+  package typealias InboundIn = HTTP2Frame
+  package typealias InboundOut = HTTP2Frame
+  package typealias OutboundIn = HTTP2Frame
+  package typealias OutboundOut = HTTP2Frame
 
   /// The `EventLoop` of the `Channel` this handler exists in.
   private let eventLoop: any EventLoop
@@ -98,7 +98,7 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {
   /// While NIO's `EmbeddedEventLoop` provides control over its view of time (and therefore any
   /// events scheduled on it) it doesn't offer a way to get the current time. This is usually done
   /// via `NIODeadline`.
-  enum Clock {
+  package enum Clock {
     case nio
     case manual(Manual)
 
@@ -111,14 +111,14 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {
       }
     }
 
-    final class Manual {
+    package final class Manual {
       private(set) var time: NIODeadline
 
-      init() {
+      package init() {
         self.time = .uptimeNanoseconds(0)
       }
 
-      func advance(by amount: TimeAmount) {
+      package func advance(by amount: TimeAmount) {
         self.time = self.time + amount
       }
     }
@@ -147,7 +147,7 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {
   }
 
   /// A synchronous view over this handler.
-  var syncView: SyncView {
+  package var syncView: SyncView {
     return SyncView(self)
   }
 
@@ -155,7 +155,7 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {
   ///
   /// Methods on this view *must* be called from the same `EventLoop` as the `Channel` in which
   /// this handler exists.
-  struct SyncView {
+  package struct SyncView {
     private let handler: ServerConnectionManagementHandler
 
     fileprivate init(_ handler: ServerConnectionManagementHandler) {
@@ -163,7 +163,7 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {
     }
 
     /// Notify the handler that the connection has received a flush event.
-    func connectionWillFlush() {
+    package func connectionWillFlush() {
       // The handler can't rely on `flush(context:)` due to its expected position in the pipeline.
       // It's expected to be placed after the HTTP/2 handler (i.e. closer to the application) as
       // it needs to receive HTTP/2 frames. However, flushes from stream channels aren't sent down
@@ -178,13 +178,13 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {
     }
 
     /// Notify the handler that a HEADERS frame was written in the last write loop.
-    func wroteHeadersFrame() {
+    package func wroteHeadersFrame() {
       self.handler.eventLoop.assertInEventLoop()
       self.handler.frameStats.wroteHeaders()
     }
 
     /// Notify the handler that a DATA frame was written in the last write loop.
-    func wroteDataFrame() {
+    package func wroteDataFrame() {
       self.handler.eventLoop.assertInEventLoop()
       self.handler.frameStats.wroteData()
     }
@@ -208,7 +208,7 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {
   ///       keep-alive pings. Pings more frequent than this interval count as 'strikes' and the
   ///       connection is closed if there are too many strikes.
   ///   - clock: A clock providing the current time.
-  init(
+  package init(
     eventLoop: any EventLoop,
     maxIdleTime: TimeAmount?,
     maxAge: TimeAmount?,
@@ -248,16 +248,16 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {
     self.requireALPN = requireALPN
   }
 
-  func handlerAdded(context: ChannelHandlerContext) {
+  package func handlerAdded(context: ChannelHandlerContext) {
     assert(context.eventLoop === self.eventLoop)
     self.context = context
   }
 
-  func handlerRemoved(context: ChannelHandlerContext) {
+  package func handlerRemoved(context: ChannelHandlerContext) {
     self.context = nil
   }
 
-  func channelActive(context: ChannelHandlerContext) {
+  package func channelActive(context: ChannelHandlerContext) {
     let view = LoopBoundView(handler: self, context: context)
 
     self.maxAgeTimer?.schedule(on: context.eventLoop) {
@@ -275,7 +275,7 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {
     context.fireChannelActive()
   }
 
-  func channelInactive(context: ChannelHandlerContext) {
+  package func channelInactive(context: ChannelHandlerContext) {
     self.maxIdleTimer?.cancel()
     self.maxAgeTimer?.cancel()
     self.maxGraceTimer?.cancel()
@@ -284,7 +284,7 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {
     context.fireChannelInactive()
   }
 
-  func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
+  package func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
     switch event {
     case let event as NIOHTTP2StreamCreatedEvent:
       self._streamCreated(event.streamID, channel: context.channel)
@@ -314,7 +314,7 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {
     context.fireUserInboundEventTriggered(event)
   }
 
-  func channelRead(context: ChannelHandlerContext, data: NIOAny) {
+  package func channelRead(context: ChannelHandlerContext, data: NIOAny) {
     self.inReadLoop = true
 
     // Any read data indicates that the connection is alive so cancel the keep-alive timers.
@@ -337,7 +337,7 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {
     context.fireChannelRead(data)
   }
 
-  func channelReadComplete(context: ChannelHandlerContext) {
+  package func channelReadComplete(context: ChannelHandlerContext) {
     while self.flushPending {
       self.flushPending = false
       context.flush()
@@ -354,7 +354,7 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {
     context.fireChannelReadComplete()
   }
 
-  func flush(context: ChannelHandlerContext) {
+  package func flush(context: ChannelHandlerContext) {
     self.maybeFlush(context: context)
   }
 }
@@ -383,7 +383,7 @@ extension ServerConnectionManagementHandler {
 }
 
 extension ServerConnectionManagementHandler {
-  struct HTTP2StreamDelegate: @unchecked Sendable, NIOHTTP2StreamDelegate {
+  package struct HTTP2StreamDelegate: @unchecked Sendable, NIOHTTP2StreamDelegate {
     // @unchecked is okay: the only methods do the appropriate event-loop dance.
 
     private let handler: ServerConnectionManagementHandler
@@ -392,7 +392,7 @@ extension ServerConnectionManagementHandler {
       self.handler = handler
     }
 
-    func streamCreated(_ id: HTTP2StreamID, channel: any Channel) {
+    package func streamCreated(_ id: HTTP2StreamID, channel: any Channel) {
       if self.handler.eventLoop.inEventLoop {
         self.handler._streamCreated(id, channel: channel)
       } else {
@@ -402,7 +402,7 @@ extension ServerConnectionManagementHandler {
       }
     }
 
-    func streamClosed(_ id: HTTP2StreamID, channel: any Channel) {
+    package func streamClosed(_ id: HTTP2StreamID, channel: any Channel) {
       if self.handler.eventLoop.inEventLoop {
         self.handler._streamClosed(id, channel: channel)
       } else {
@@ -413,7 +413,7 @@ extension ServerConnectionManagementHandler {
     }
   }
 
-  var http2StreamDelegate: HTTP2StreamDelegate {
+  package var http2StreamDelegate: HTTP2StreamDelegate {
     return HTTP2StreamDelegate(self)
   }
 

+ 82 - 68
Tests/GRPCNIOTransportCoreTests/Server/Connection/ServerConnectionManagementHandlerTests.swift

@@ -14,15 +14,15 @@
  * limitations under the License.
  */
 
+import GRPCNIOTransportCore
 import NIOCore
 import NIOEmbedded
 import NIOHTTP2
-import XCTest
+import Testing
 
-@testable import GRPCNIOTransportCore
-
-final class ServerConnectionManagementHandlerTests: XCTestCase {
-  func testIdleTimeoutOnNewConnection() throws {
+struct ServerConnectionManagementHandlerTests {
+  @Test("Idle timeout on new connection")
+  func idleTimeoutOnNewConnection() throws {
     let connection = try Connection(maxIdleTime: .minutes(1))
     try connection.activate()
     // Hit the max idle time.
@@ -35,7 +35,8 @@ final class ServerConnectionManagementHandlerTests: XCTestCase {
     try connection.waitUntilClosed()
   }
 
-  func testIdleTimerIsCancelledWhenStreamIsOpened() throws {
+  @Test("Idle timeout is cancelled when stream is opened")
+  func idleTimerIsCancelledWhenStreamIsOpened() throws {
     let connection = try Connection(maxIdleTime: .minutes(1))
     try connection.activate()
 
@@ -44,17 +45,18 @@ final class ServerConnectionManagementHandlerTests: XCTestCase {
     connection.advanceTime(by: .minutes(1))
 
     // No GOAWAY frame means the timer was cancelled.
-    XCTAssertNil(try connection.readFrame())
+    #expect(try connection.readFrame() == nil)
   }
 
-  func testIdleTimerStartsWhenAllStreamsAreClosed() throws {
+  @Test("Idle timer starts when all streams are closed")
+  func idleTimerStartsWhenAllStreamsAreClosed() throws {
     let connection = try Connection(maxIdleTime: .minutes(1))
     try connection.activate()
 
     // Open a stream to cancel the idle timer and run through the max idle time.
     connection.streamOpened(1)
     connection.advanceTime(by: .minutes(1))
-    XCTAssertNil(try connection.readFrame())
+    #expect(try connection.readFrame() == nil)
 
     // Close the stream to start the timer again.
     connection.streamClosed(1)
@@ -67,7 +69,8 @@ final class ServerConnectionManagementHandlerTests: XCTestCase {
     try connection.waitUntilClosed()
   }
 
-  func testMaxAge() throws {
+  @Test("Connection shutdown after max age is reached")
+  func maxAge() throws {
     let connection = try Connection(maxAge: .minutes(1))
     try connection.activate()
 
@@ -87,7 +90,8 @@ final class ServerConnectionManagementHandlerTests: XCTestCase {
     try connection.waitUntilClosed()
   }
 
-  func testGracefulShutdownRatchetsDownStreamID() throws {
+  @Test("Graceful shutdown ratchets down last stream ID")
+  func gracefulShutdownRatchetsDownStreamID() throws {
     // This test uses the idle timeout to trigger graceful shutdown. The mechanism is the same
     // regardless of how it's triggered.
     let connection = try Connection(maxIdleTime: .minutes(1))
@@ -106,7 +110,8 @@ final class ServerConnectionManagementHandlerTests: XCTestCase {
     try connection.waitUntilClosed()
   }
 
-  func testGracefulShutdownGracePeriod() throws {
+  @Test("Graceful shutdown promoted to close after grace period")
+  func gracefulShutdownGracePeriod() throws {
     // This test uses the idle timeout to trigger graceful shutdown. The mechanism is the same
     // regardless of how it's triggered.
     let connection = try Connection(
@@ -128,7 +133,8 @@ final class ServerConnectionManagementHandlerTests: XCTestCase {
     try connection.waitUntilClosed()
   }
 
-  func testKeepaliveOnNewConnection() throws {
+  @Test("Keepalive works on new connection")
+  func keepaliveOnNewConnection() throws {
     let connection = try Connection(
       keepaliveTime: .minutes(5),
       keepaliveTimeout: .seconds(5)
@@ -138,20 +144,20 @@ final class ServerConnectionManagementHandlerTests: XCTestCase {
     // Wait for the keep alive timer to fire which should cause the server to send a keep
     // alive PING.
     connection.advanceTime(by: .minutes(5))
-    let frame1 = try XCTUnwrap(connection.readFrame())
-    XCTAssertEqual(frame1.streamID, .rootStream)
-    try XCTAssertPing(frame1.payload) { data, ack in
-      XCTAssertFalse(ack)
-      // Data is opaque, send it back.
-      try connection.ping(data: data, ack: true)
-    }
+    let frame1 = try #require(try connection.readFrame())
+    #expect(frame1.streamID == .rootStream)
+    let (data, ack) = try #require(frame1.payload.ping)
+    #expect(!ack)
+    // Data is opaque, send it back.
+    try connection.ping(data: data, ack: true)
 
     // Run past the timeout, nothing should happen.
     connection.advanceTime(by: .seconds(5))
-    XCTAssertNil(try connection.readFrame())
+    #expect(try connection.readFrame() == nil)
   }
 
-  func testKeepaliveStartsAfterReadLoop() throws {
+  @Test("Keepalive starts after read loop")
+  func keepaliveStartsAfterReadLoop() throws {
     let connection = try Connection(
       keepaliveTime: .minutes(5),
       keepaliveTimeout: .seconds(5)
@@ -165,21 +171,21 @@ final class ServerConnectionManagementHandlerTests: XCTestCase {
 
     // Run out the keep alive timer, it shouldn't fire.
     connection.advanceTime(by: .minutes(5))
-    XCTAssertNil(try connection.readFrame())
+    #expect(try connection.readFrame() == nil)
 
     // Fire channel read complete to start the keep alive timer again.
     connection.channel.pipeline.fireChannelReadComplete()
 
     // Now expire the keep alive timer again, we should read out a PING frame.
     connection.advanceTime(by: .minutes(5))
-    let frame1 = try XCTUnwrap(connection.readFrame())
-    XCTAssertEqual(frame1.streamID, .rootStream)
-    XCTAssertPing(frame1.payload) { data, ack in
-      XCTAssertFalse(ack)
-    }
+    let frame1 = try #require(try connection.readFrame())
+    #expect(frame1.streamID == .rootStream)
+    let (_, ack) = try #require(frame1.payload.ping)
+    #expect(!ack)
   }
 
-  func testKeepaliveOnNewConnectionWithoutResponse() throws {
+  @Test("Keepalive works on new connection without response")
+  func keepaliveOnNewConnectionWithoutResponse() throws {
     let connection = try Connection(
       keepaliveTime: .minutes(5),
       keepaliveTimeout: .seconds(5)
@@ -189,11 +195,10 @@ final class ServerConnectionManagementHandlerTests: XCTestCase {
     // Wait for the keep alive timer to fire which should cause the server to send a keep
     // alive PING.
     connection.advanceTime(by: .minutes(5))
-    let frame1 = try XCTUnwrap(connection.readFrame())
-    XCTAssertEqual(frame1.streamID, .rootStream)
-    XCTAssertPing(frame1.payload) { data, ack in
-      XCTAssertFalse(ack)
-    }
+    let frame1 = try #require(try connection.readFrame())
+    #expect(frame1.streamID == .rootStream)
+    let (_, ack) = try #require(frame1.payload.ping)
+    #expect(!ack)
 
     // We didn't ack the PING, the connection should shutdown after the timeout.
     connection.advanceTime(by: .seconds(5))
@@ -203,7 +208,8 @@ final class ServerConnectionManagementHandlerTests: XCTestCase {
     try connection.waitUntilClosed()
   }
 
-  func testClientKeepalivePolicing() throws {
+  @Test("Keepalive sent by client is policed")
+  func clientKeepalivePolicing() throws {
     let connection = try Connection(
       allowKeepaliveWithoutCalls: true,
       minPingIntervalWithoutCalls: .minutes(1)
@@ -213,24 +219,25 @@ final class ServerConnectionManagementHandlerTests: XCTestCase {
     // The first ping is valid, the second and third are strikes.
     for _ in 1 ... 3 {
       try connection.ping(data: HTTP2PingData(), ack: false)
-      XCTAssertNil(try connection.readFrame())
+      #expect(try connection.readFrame() == nil)
     }
 
     // The fourth ping is the third strike and triggers a GOAWAY.
     try connection.ping(data: HTTP2PingData(), ack: false)
-    let frame = try XCTUnwrap(connection.readFrame())
-    XCTAssertEqual(frame.streamID, .rootStream)
-    XCTAssertGoAway(frame.payload) { streamID, error, data in
-      XCTAssertEqual(streamID, .rootStream)
-      XCTAssertEqual(error, .enhanceYourCalm)
-      XCTAssertEqual(data, ByteBuffer(string: "too_many_pings"))
-    }
+    let frame = try #require(try connection.readFrame())
+    #expect(frame.streamID == .rootStream)
+    let (streamID, error, data) = try #require(frame.payload.goAway)
+
+    #expect(streamID == .rootStream)
+    #expect(error == .enhanceYourCalm)
+    #expect(data == ByteBuffer(string: "too_many_pings"))
 
     // The server should close the connection.
     try connection.waitUntilClosed()
   }
 
-  func testClientKeepaliveWithPermissibleIntervals() throws {
+  @Test("Client keepalive works with permissible intervals")
+  func clientKeepaliveWithPermissibleIntervals() throws {
     let connection = try Connection(
       allowKeepaliveWithoutCalls: true,
       minPingIntervalWithoutCalls: .minutes(1),
@@ -240,14 +247,15 @@ final class ServerConnectionManagementHandlerTests: XCTestCase {
 
     for _ in 1 ... 100 {
       try connection.ping(data: HTTP2PingData(), ack: false)
-      XCTAssertNil(try connection.readFrame())
+      #expect(try connection.readFrame() == nil)
 
       // Advance by the ping interval.
       connection.advanceTime(by: .minutes(1))
     }
   }
 
-  func testClientKeepaliveResetState() throws {
+  @Test("Client keepalive works after reset state")
+  func clientKeepaliveResetState() throws {
     let connection = try Connection(
       allowKeepaliveWithoutCalls: true,
       minPingIntervalWithoutCalls: .minutes(1)
@@ -258,7 +266,7 @@ final class ServerConnectionManagementHandlerTests: XCTestCase {
       // The first ping is valid, the second and third are strikes.
       for _ in 1 ... 3 {
         try connection.ping(data: HTTP2PingData(), ack: false)
-        XCTAssertNil(try connection.readFrame())
+        #expect(try connection.readFrame() == nil)
       }
     }
 
@@ -273,13 +281,13 @@ final class ServerConnectionManagementHandlerTests: XCTestCase {
 
     // The next ping is the third strike and triggers a GOAWAY.
     try connection.ping(data: HTTP2PingData(), ack: false)
-    let frame = try XCTUnwrap(connection.readFrame())
-    XCTAssertEqual(frame.streamID, .rootStream)
-    XCTAssertGoAway(frame.payload) { streamID, error, data in
-      XCTAssertEqual(streamID, .rootStream)
-      XCTAssertEqual(error, .enhanceYourCalm)
-      XCTAssertEqual(data, ByteBuffer(string: "too_many_pings"))
-    }
+    let frame = try #require(try connection.readFrame())
+    #expect(frame.streamID == .rootStream)
+    let (streamID, error, data) = try #require(frame.payload.goAway)
+
+    #expect(streamID == .rootStream)
+    #expect(error == .enhanceYourCalm)
+    #expect(data == ByteBuffer(string: "too_many_pings"))
 
     // The server should close the connection.
     try connection.waitUntilClosed()
@@ -292,18 +300,22 @@ extension ServerConnectionManagementHandlerTests {
     lastStreamID: HTTP2StreamID,
     streamToOpenBeforePingAck: HTTP2StreamID? = nil
   ) throws {
-    let frame1 = try XCTUnwrap(connection.readFrame())
-    XCTAssertEqual(frame1.streamID, .rootStream)
-    XCTAssertGoAway(frame1.payload) { streamID, errorCode, _ in
-      XCTAssertEqual(streamID, .maxID)
-      XCTAssertEqual(errorCode, .noError)
+    do {
+      let frame = try #require(try connection.readFrame())
+      #expect(frame.streamID == .rootStream)
+
+      let (streamID, errorCode, _) = try #require(frame.payload.goAway)
+      #expect(streamID == .maxID)
+      #expect(errorCode == .noError)
     }
 
     // Followed by a PING
-    let frame2 = try XCTUnwrap(connection.readFrame())
-    XCTAssertEqual(frame2.streamID, .rootStream)
-    try XCTAssertPing(frame2.payload) { data, ack in
-      XCTAssertFalse(ack)
+    do {
+      let frame = try #require(try connection.readFrame())
+      #expect(frame.streamID == .rootStream)
+
+      let (data, ack) = try #require(frame.payload.ping)
+      #expect(!ack)
 
       if let id = streamToOpenBeforePingAck {
         connection.streamOpened(id)
@@ -314,11 +326,13 @@ extension ServerConnectionManagementHandlerTests {
     }
 
     // PING ACK triggers another GOAWAY.
-    let frame3 = try XCTUnwrap(connection.readFrame())
-    XCTAssertEqual(frame3.streamID, .rootStream)
-    XCTAssertGoAway(frame3.payload) { streamID, errorCode, _ in
-      XCTAssertEqual(streamID, lastStreamID)
-      XCTAssertEqual(errorCode, .noError)
+    do {
+      let frame = try #require(try connection.readFrame())
+      #expect(frame.streamID == .rootStream)
+
+      let (streamID, errorCode, _) = try #require(frame.payload.goAway)
+      #expect(streamID == lastStreamID)
+      #expect(errorCode == .noError)
     }
   }
 }

+ 20 - 0
Tests/GRPCNIOTransportCoreTests/XCTest+FramePayload.swift

@@ -41,3 +41,23 @@ func XCTAssertPing(
     XCTFail("Expected '.ping' got '\(payload)'")
   }
 }
+
+extension HTTP2Frame.FramePayload {
+  var goAway: (lastStreamID: HTTP2StreamID, errorCode: HTTP2ErrorCode, opaqueData: ByteBuffer?)? {
+    switch self {
+    case .goAway(let lastStreamID, let errorCode, let opaqueData):
+      return (lastStreamID, errorCode, opaqueData)
+    default:
+      return nil
+    }
+  }
+
+  var ping: (data: HTTP2PingData, ack: Bool)? {
+    switch self {
+    case .ping(let data, ack: let ack):
+      return (data, ack)
+    default:
+      return nil
+    }
+  }
+}