瀏覽代碼

Use EmbeddedChannel for client timeout tests (#524)

Motivation:

The ClientTimeoutTests were often flakey as they relied on waiting
for specified times instead of reacting to signals.

Modifications:

Provide an internal init for ClientConnection which uses an
EmbeddedChannel. Rewrite the ClientTimeoutTests to make use of
this change.

Result:

ClientTimeoutTests are more reliable.
George Barnett 6 年之前
父節點
當前提交
7416161090

+ 3 - 6
Sources/GRPC/ClientCalls/BaseClientCall.swift

@@ -168,21 +168,18 @@ extension BaseClientCall: ClientCall {
 /// - Parameter requestID: The request ID used for this call. If `callOptions` specifies a
 ///   non-nil `reqeuestIDHeader` then this request ID will be added to the headers with the
 ///   specified header name.
-internal func makeRequestHead(path: String, host: String?, callOptions: CallOptions, requestID: String) -> HTTPRequestHead {
+internal func makeRequestHead(path: String, host: String, callOptions: CallOptions, requestID: String) -> HTTPRequestHead {
   var headers: HTTPHeaders = [
     "content-type": "application/grpc",
     // Used to detect incompatible proxies, as per https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests
     "te": "trailers",
     //! FIXME: Add a more specific user-agent, see: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#user-agents
     "user-agent": "grpc-swift-nio",
+    // We're dealing with HTTP/1; the NIO HTTP2ToHTTP1Codec replaces "host" with ":authority".
+    "host": host,
     GRPCHeaderName.acceptEncoding: CompressionMechanism.acceptEncodingHeader,
   ]
 
-  if let host = host {
-    // We're dealing with HTTP/1; the NIO HTTP2ToHTTP1Codec replaces "host" with ":authority".
-    headers.add(name: "host", value: host)
-  }
-
   if callOptions.timeout != .infinite {
     headers.add(name: GRPCHeaderName.timeout, value: String(describing: callOptions.timeout))
   }

+ 89 - 15
Sources/GRPC/ClientConnection.swift

@@ -106,6 +106,57 @@ public class ClientConnection {
     self.didSetChannel(to: channel)
   }
 
+  // This is only internal to expose it for testing.
+  /// Create a `ClientConnection` for testing using the given `EmbeddedChannel`.
+  ///
+  /// - Parameter channel: The embedded channel to create this connection on.
+  /// - Parameter configuration: How this connection should be configured. The `eventLoopGroup`
+  ///     on the configuration will _not_ be used by the call. As such the `eventLoop` of
+  ///     the given `channel` may be used in the configuration to avoid managing an additional
+  ///     event loop group.
+  ///
+  /// - Important:
+  ///   The connectivity state will not be updated using this connection and should not be
+  ///   relied on.
+  ///
+  /// - Precondition:
+  ///   The provided connection target in the `configuration` _must_ be a `SocketAddress`.
+  internal init(channel: EmbeddedChannel, configuration: Configuration) {
+    // We need a .socketAddress to connect to.
+    let socketAddress: SocketAddress
+    switch configuration.target {
+    case .socketAddress(let address):
+      socketAddress = address
+    default:
+      preconditionFailure("target must be SocketAddress when using EmbeddedChannel")
+    }
+
+    self.uuid = UUID()
+    var logger = Logger(subsystem: .clientChannel)
+    logger[metadataKey: MetadataKey.connectionID] = "\(self.uuid)"
+    self.logger = logger
+
+    self.configuration = configuration
+    self.connectivity = ConnectivityStateMonitor(delegate: configuration.connectivityStateDelegate)
+
+    // Configure the channel with the correct handlers and connect to our target.
+    let configuredChannel = ClientConnection.initializeChannel(
+      channel,
+      tls: configuration.tls,
+      errorDelegate: configuration.errorDelegate
+    ).flatMap {
+      channel.connect(to: socketAddress)
+    }.map { _ in
+      return channel as Channel
+    }
+
+    self.multiplexer = configuredChannel.flatMap {
+      $0.pipeline.handler(type: HTTP2StreamMultiplexer.self)
+    }
+
+    self.channel = configuredChannel
+  }
+
   /// The `EventLoop` this connection is using.
   public var eventLoop: EventLoop {
     return self.channel.eventLoop
@@ -286,17 +337,12 @@ extension ClientConnection {
       .channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
       .channelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1)
       .channelInitializer { channel in
-        let tlsConfigured = configuration.tls.map { tlsConfiguration in
-          channel.configureTLS(tlsConfiguration, errorDelegate: configuration.errorDelegate)
-        }
-
-        return (tlsConfigured ?? channel.eventLoop.makeSucceededFuture(())).flatMap {
-          channel.configureHTTP2Pipeline(mode: .client)
-        }.flatMap { _ in
-          let errorHandler = DelegatingErrorHandler(delegate: configuration.errorDelegate)
-          return channel.pipeline.addHandler(errorHandler)
-        }
-    }
+        initializeChannel(
+          channel,
+          tls: configuration.tls,
+          errorDelegate: configuration.errorDelegate
+        )
+      }
 
     if let timeout = timeout {
       logger.info("setting connect timeout to \(timeout) seconds")
@@ -305,6 +351,28 @@ extension ClientConnection {
       return bootstrap
     }
   }
+
+  /// Initialize the channel with the given TLS configuration and error delegate.
+  ///
+  /// - Parameter channel: The channel to initialize.
+  /// - Parameter tls: The optional TLS configuration for the channel.
+  /// - Parameter errorDelegate: Optional client error delegate.
+  private class func initializeChannel(
+    _ channel: Channel,
+    tls: Configuration.TLS?,
+    errorDelegate: ClientErrorDelegate?
+  ) -> EventLoopFuture<Void> {
+    let tlsConfigured = tls.map { tlsConfiguration in
+      channel.configureTLS(tlsConfiguration, errorDelegate: errorDelegate)
+    }
+
+    return (tlsConfigured ?? channel.eventLoop.makeSucceededFuture(())).flatMap {
+      channel.configureHTTP2Pipeline(mode: .client)
+    }.flatMap { _ in
+      let errorHandler = DelegatingErrorHandler(delegate: errorDelegate)
+      return channel.pipeline.addHandler(errorHandler)
+    }
+  }
 }
 
 // MARK: - Configuration structures
@@ -318,11 +386,17 @@ public enum ConnectionTarget {
   /// A NIO socket address.
   case socketAddress(SocketAddress)
 
-  var host: String? {
-    guard case .hostAndPort(let host, _) = self else {
-      return nil
+  var host: String {
+    switch self {
+    case .hostAndPort(let host, _):
+      return host
+    case .socketAddress(.v4(let address)):
+      return address.host
+    case .socketAddress(.v6(let address)):
+      return address.host
+    case .unixDomainSocket, .socketAddress(.unixDomainSocket):
+      return "localhost"
     }
-    return host
   }
 }
 

+ 92 - 60
Tests/GRPCTests/ClientTimeoutTests.swift

@@ -14,109 +14,141 @@
  * limitations under the License.
  */
 import Foundation
-import GRPC
+@testable import GRPC
 import NIO
 import XCTest
 
-class ClientTimeoutTests: EchoTestCaseBase {
-  let optionsWithShortTimeout = CallOptions(timeout: try! GRPCTimeout.milliseconds(10))
-  let moreThanShortTimeout: TimeInterval = 0.020
+class ClientTimeoutTests: GRPCTestCase {
+  var channel: EmbeddedChannel!
+  var client: Echo_EchoServiceClient!
 
-  private func expectDeadlineExceeded(forStatus status: EventLoopFuture<GRPCStatus>,
-                                      file: StaticString = #file, line: UInt = #line) {
-    let statusExpectation = self.expectation(description: "status received")
+  let callOptions = CallOptions(timeout: try! .milliseconds(100))
+  var timeout: GRPCTimeout {
+    return self.callOptions.timeout
+  }
 
-    status.whenSuccess { status in
-      XCTAssertEqual(status.code, .deadlineExceeded, file: file, line: line)
-      statusExpectation.fulfill()
-    }
+  // Note: this is not related to the call timeout since we're using an EmbeddedChannel. We require
+  // this in case the timeout doesn't work.
+  let testTimeout: TimeInterval = 0.1
 
-    status.whenFailure { error in
-      XCTFail("unexpectedly received error for status: \(error)", file: file, line: line)
-    }
+  override func setUp() {
+    super.setUp()
+    let channel = EmbeddedChannel()
+
+    let configuration = ClientConnection.Configuration(
+      target: .socketAddress(try! .init(unixDomainSocketPath: "/foo")),
+      eventLoopGroup: channel.eventLoop
+    )
+
+    let connection = ClientConnection(channel: channel, configuration: configuration)
+    let client = Echo_EchoServiceClient(connection: connection, defaultCallOptions: self.callOptions)
+
+    self.channel = channel
+    self.client = client
   }
 
-  private func expectDeadlineExceeded(forResponse response: EventLoopFuture<Echo_EchoResponse>,
-                                      file: StaticString = #file, line: UInt = #line) {
-    let responseExpectation = self.expectation(description: "response received")
+  override func tearDown() {
+    XCTAssertNoThrow(try self.channel.finish())
+  }
 
-    response.whenFailure { error in
-      XCTAssertEqual((error as? GRPCStatus)?.code, .deadlineExceeded, file: file, line: line)
-      responseExpectation.fulfill()
+  func assertDeadlineExceeded(_ response: EventLoopFuture<Echo_EchoResponse>, expectation: XCTestExpectation) {
+    response.whenComplete { result in
+      switch result {
+      case .success(let response):
+        XCTFail("unexpected response: \(response)")
+      case .failure(let error):
+        if let status = error as? GRPCStatus {
+          XCTAssertEqual(status.code, .deadlineExceeded)
+        } else {
+          XCTFail("unexpected error: \(error)")
+        }
+      }
+      expectation.fulfill()
     }
+  }
 
-    response.whenSuccess { response in
-      XCTFail("response received after deadline", file: file, line: line)
+  func assertDeadlineExceeded(_ status: EventLoopFuture<GRPCStatus>, expectation: XCTestExpectation) {
+    status.whenComplete { result in
+      switch result {
+      case .success(let status):
+        XCTAssertEqual(status.code, .deadlineExceeded)
+      case .failure(let error):
+        XCTFail("unexpected error: \(error)")
+      }
+      expectation.fulfill()
     }
   }
-}
 
-extension ClientTimeoutTests {
-  func testUnaryTimeoutAfterSending() {
-    // The request gets fired on call creation, so we need a very short timeout.
-    let callOptions = CallOptions(timeout: try! .microseconds(100))
-    let call = client.get(Echo_EchoRequest(text: "foo"), callOptions: callOptions)
+  func testUnaryTimeoutAfterSending() throws {
+    let statusExpectation = self.expectation(description: "status fulfilled")
 
-    self.expectDeadlineExceeded(forStatus: call.status)
-    self.expectDeadlineExceeded(forResponse: call.response)
+    let call = self.client.get(Echo_EchoRequest(text: "foo"))
+    channel.embeddedEventLoop.advanceTime(by: self.timeout.asNIOTimeAmount)
 
-    waitForExpectations(timeout: defaultTestTimeout)
+    self.assertDeadlineExceeded(call.status, expectation: statusExpectation)
+    self.wait(for: [statusExpectation], timeout: self.testTimeout)
   }
 
-  func testServerStreamingTimeoutAfterSending() {
-    // The request gets fired on call creation, so we need a very short timeout.
-    let callOptions = CallOptions(timeout: try! .microseconds(100))
-    let call = client.expand(Echo_EchoRequest(text: "foo bar baz"), callOptions: callOptions) { _ in }
+  func testServerStreamingTimeoutAfterSending() throws {
+    let statusExpectation = self.expectation(description: "status fulfilled")
 
-    self.expectDeadlineExceeded(forStatus: call.status)
+    let call = client.expand(Echo_EchoRequest(text: "foo bar baz")) { _ in }
+    channel.embeddedEventLoop.advanceTime(by: self.timeout.asNIOTimeAmount)
 
-    waitForExpectations(timeout: defaultTestTimeout)
+    self.assertDeadlineExceeded(call.status, expectation: statusExpectation)
+    self.wait(for: [statusExpectation], timeout: self.testTimeout)
   }
 
-  func testClientStreamingTimeoutBeforeSending() {
-    let call = client.collect(callOptions: optionsWithShortTimeout)
+  func testClientStreamingTimeoutBeforeSending() throws {
+    let responseExpectation = self.expectation(description: "response fulfilled")
+    let statusExpectation = self.expectation(description: "status fulfilled")
 
-    self.expectDeadlineExceeded(forStatus: call.status)
-    self.expectDeadlineExceeded(forResponse: call.response)
+    let call = client.collect()
+    channel.embeddedEventLoop.advanceTime(by: self.timeout.asNIOTimeAmount)
 
-    waitForExpectations(timeout: defaultTestTimeout)
+    self.assertDeadlineExceeded(call.response, expectation: responseExpectation)
+    self.assertDeadlineExceeded(call.status, expectation: statusExpectation)
+    self.wait(for: [responseExpectation, statusExpectation], timeout: self.testTimeout)
   }
 
-  func testClientStreamingTimeoutAfterSending() {
-    let call = client.collect(callOptions: optionsWithShortTimeout)
+  func testClientStreamingTimeoutAfterSending() throws {
+    let responseExpectation = self.expectation(description: "response fulfilled")
+    let statusExpectation = self.expectation(description: "status fulfilled")
 
-    self.expectDeadlineExceeded(forStatus: call.status)
-    self.expectDeadlineExceeded(forResponse: call.response)
+    let call = client.collect()
 
-    call.sendMessage(Echo_EchoRequest(text: "foo"), promise: nil)
+    self.assertDeadlineExceeded(call.response, expectation: responseExpectation)
+    self.assertDeadlineExceeded(call.status, expectation: statusExpectation)
 
-    // Timeout before sending `.end`
-    Thread.sleep(forTimeInterval: moreThanShortTimeout)
+    call.sendMessage(Echo_EchoRequest(text: "foo"), promise: nil)
     call.sendEnd(promise: nil)
+    channel.embeddedEventLoop.advanceTime(by: self.timeout.asNIOTimeAmount)
 
-    waitForExpectations(timeout: defaultTestTimeout)
+    self.wait(for: [responseExpectation, statusExpectation], timeout: 1.0)
   }
 
   func testBidirectionalStreamingTimeoutBeforeSending() {
-    let call = client.update(callOptions: optionsWithShortTimeout)  { _ in }
+    let statusExpectation = self.expectation(description: "status fulfilled")
 
-    self.expectDeadlineExceeded(forStatus: call.status)
+    let call = client.update { _ in }
 
-    Thread.sleep(forTimeInterval: moreThanShortTimeout)
-    waitForExpectations(timeout: defaultTestTimeout)
+    channel.embeddedEventLoop.advanceTime(by: self.timeout.asNIOTimeAmount)
+
+    self.assertDeadlineExceeded(call.status, expectation: statusExpectation)
+    self.wait(for: [statusExpectation], timeout: self.testTimeout)
   }
 
   func testBidirectionalStreamingTimeoutAfterSending() {
-    let call = client.update(callOptions: optionsWithShortTimeout) { _ in }
+    let statusExpectation = self.expectation(description: "status fulfilled")
 
-    self.expectDeadlineExceeded(forStatus: call.status)
+    let call = client.update { _ in }
 
-    call.sendMessage(Echo_EchoRequest(text: "foo"), promise: nil)
+    self.assertDeadlineExceeded(call.status, expectation: statusExpectation)
 
-    // Timeout before sending `.end`
-    Thread.sleep(forTimeInterval: moreThanShortTimeout)
+    call.sendMessage(Echo_EchoRequest(text: "foo"), promise: nil)
     call.sendEnd(promise: nil)
+    channel.embeddedEventLoop.advanceTime(by: self.timeout.asNIOTimeAmount)
 
-    waitForExpectations(timeout: defaultTestTimeout)
+    self.wait(for: [statusExpectation], timeout: self.testTimeout)
   }
 }