Bladeren bron

Fail calls on closed connections. (#451)

Motivation:

Calls on a connection which was already closed or closed during the call
would never fulfill the response/status/metadata futures.

Modifications:

- Observe failures on newly creating stream channels and errors received
  from NIOHTTP2
- Add tests
George Barnett 6 jaren geleden
bovenliggende
commit
4692c04a15

+ 2 - 2
Package.resolved

@@ -24,8 +24,8 @@
         "repositoryURL": "https://github.com/apple/swift-nio.git",
         "state": {
           "branch": null,
-          "revision": "c07fea1aa5fa8147a4f43929fff6d71ec17f01fb",
-          "version": "2.0.1"
+          "revision": "22d49070c556f9b9f741345938af9892473c9f13",
+          "version": "2.0.2"
         }
       },
       {

+ 28 - 10
Sources/SwiftGRPCNIO/ClientCalls/BaseClientCall.swift

@@ -89,6 +89,10 @@ open class BaseClientCall<RequestMessage: Message, ResponseMessage: Message> {
       statusPromise: connection.channel.eventLoop.makePromise(),
       responseObserver: responseObserver)
 
+    self.streamPromise.futureResult.whenFailure { error in
+      self.clientChannelHandler.observeError(error)
+    }
+
     self.createStreamChannel()
     self.setTimeout(callOptions.timeout)
   }
@@ -145,9 +149,7 @@ extension BaseClientCall {
   ///   - requestHead: The request head to send.
   ///   - promise: A promise to fulfill once the request head has been sent.
   internal func sendHead(_ requestHead: HTTPRequestHead, promise: EventLoopPromise<Void>?) {
-    self.subchannel.whenSuccess { channel in
-      channel.writeAndFlush(GRPCClientRequestPart<RequestMessage>.head(requestHead), promise: promise)
-    }
+    self.writeAndFlushOnStream(.head(requestHead), promise: promise)
   }
 
   /// Send the request head once `subchannel` becomes available.
@@ -169,9 +171,7 @@ extension BaseClientCall {
   ///   - message: The message to send.
   ///   - promise: A promise to fulfil when the message reaches the network.
   internal func _sendMessage(_ message: RequestMessage, promise: EventLoopPromise<Void>?) {
-    self.subchannel.whenSuccess { channel in
-      channel.writeAndFlush(GRPCClientRequestPart<RequestMessage>.message(message), promise: promise)
-    }
+    self.writeAndFlushOnStream(.message(message), promise: promise)
   }
 
   /// Send the given message once `subchannel` becomes available.
@@ -190,9 +190,7 @@ extension BaseClientCall {
   /// - Important: This should only ever be called once.
   /// - Parameter promise: A promise to succeed once then end has been sent.
   internal func _sendEnd(promise: EventLoopPromise<Void>?) {
-    self.subchannel.whenSuccess { channel in
-      channel.writeAndFlush(GRPCClientRequestPart<RequestMessage>.end, promise: promise)
-    }
+    self.writeAndFlushOnStream(.end, promise: promise)
   }
 
   /// Send `end` once `subchannel` becomes available.
@@ -206,6 +204,26 @@ extension BaseClientCall {
     return promise.futureResult
   }
 
+  /// Writes the given request on the future `Channel` for the HTTP/2 stream used to make this call.
+  ///
+  /// This method is intended to be used by the `sendX` methods in order to ensure that they fail
+  /// futures associated with this call should the write fail (e.g. due to a closed connection).
+  private func writeAndFlushOnStream(_ request: GRPCClientRequestPart<RequestMessage>, promise: EventLoopPromise<Void>?) {
+    // We need to use a promise here; if the write fails then it _must_ be observed by the handler
+    // to ensure that any futures given to the user are fulfilled.
+    let promise = promise ?? self.connection.channel.eventLoop.makePromise()
+
+    promise.futureResult.whenFailure { error in
+      self.clientChannelHandler.observeError(error)
+    }
+
+    self.subchannel.cascadeFailure(to: promise)
+
+    self.subchannel.whenSuccess { channel in
+      channel.writeAndFlush(NIOAny(request), promise: promise)
+    }
+  }
+
   /// Creates a client-side timeout for this call.
   ///
   /// - Important: This should only ever be called once.
@@ -213,7 +231,7 @@ extension BaseClientCall {
     if timeout == .infinite { return }
 
     self.connection.channel.eventLoop.scheduleTask(in: timeout.asNIOTimeAmount) { [weak self] in
-      self?.clientChannelHandler.observeError(.client(.deadlineExceeded(timeout)))
+      self?.clientChannelHandler.observeError(GRPCError.client(.deadlineExceeded(timeout)))
     }
   }
 

+ 3 - 1
Sources/SwiftGRPCNIO/ClientCalls/BidirectionalStreamingClientCall.swift

@@ -34,7 +34,9 @@ public class BidirectionalStreamingClientCall<RequestMessage: Message, ResponseM
     super.init(connection: connection, path: path, callOptions: callOptions, responseObserver: .callback(handler))
 
     let requestHead = self.makeRequestHead(path: path, host: connection.host, callOptions: callOptions)
-    self.messageQueue = self.messageQueue.flatMap { self.sendHead(requestHead) }
+    self.messageQueue = self.messageQueue.flatMap {
+      self.sendHead(requestHead)
+    }
   }
 
   public func sendMessage(_ message: RequestMessage) -> EventLoopFuture<Void> {

+ 3 - 1
Sources/SwiftGRPCNIO/ClientCalls/ClientStreamingClientCall.swift

@@ -43,7 +43,9 @@ public class ClientStreamingClientCall<RequestMessage: Message, ResponseMessage:
       responseObserver: .succeedPromise(responsePromise))
 
     let requestHead = self.makeRequestHead(path: path, host: connection.host, callOptions: callOptions)
-    self.messageQueue = self.messageQueue.flatMap { self.sendHead(requestHead) }
+    self.messageQueue = self.messageQueue.flatMap {
+      self.sendHead(requestHead)
+    }
   }
 
   public func sendMessage(_ message: RequestMessage) -> EventLoopFuture<Void> {

+ 5 - 3
Sources/SwiftGRPCNIO/ClientCalls/ServerStreamingClientCall.swift

@@ -28,8 +28,10 @@ public class ServerStreamingClientCall<RequestMessage: Message, ResponseMessage:
     super.init(connection: connection, path: path, callOptions: callOptions, responseObserver: .callback(handler))
 
     let requestHead = self.makeRequestHead(path: path, host: connection.host, callOptions: callOptions)
-    self.sendHead(requestHead)
-      .flatMap { self._sendMessage(request) }
-      .whenSuccess { self._sendEnd(promise: nil) }
+    self.sendHead(requestHead).flatMap {
+      self._sendMessage(request)
+    }.whenSuccess {
+      self._sendEnd(promise: nil)
+    }
   }
 }

+ 5 - 3
Sources/SwiftGRPCNIO/ClientCalls/UnaryClientCall.swift

@@ -38,8 +38,10 @@ public class UnaryClientCall<RequestMessage: Message, ResponseMessage: Message>:
       responseObserver: .succeedPromise(responsePromise))
 
     let requestHead = self.makeRequestHead(path: path, host: connection.host, callOptions: callOptions)
-    self.sendHead(requestHead)
-      .flatMap { self._sendMessage(request) }
-      .whenSuccess { self._sendEnd(promise: nil) }
+    self.sendHead(requestHead).flatMap {
+      self._sendMessage(request)
+    }.whenSuccess {
+      self._sendEnd(promise: nil)
+    }
   }
 }

+ 8 - 3
Sources/SwiftGRPCNIO/GRPCClientChannelHandler.swift

@@ -100,11 +100,16 @@ internal class GRPCClientChannelHandler<RequestMessage: Message, ResponseMessage
 
   /// Observe the given error.
   ///
-  /// Calls `observeStatus(status:)`. with `error.asGRPCStatus()`.
+  /// If the error conforms to `GRPCStatusTransformable` then `observeStatus(status:)` is called
+  /// with the transformed error, otherwise `GRPCStatus.processingError` is used.
   ///
   /// - Parameter error: the error to observe.
-  internal func observeError(_ error: GRPCError) {
-    self.observeStatus(error.asGRPCStatus())
+  internal func observeError(_ error: Error) {
+    if let transformable = error as? GRPCStatusTransformable {
+      self.observeStatus(transformable.asGRPCStatus())
+    } else {
+      self.observeStatus(.processingError)
+    }
   }
 }
 

+ 26 - 0
Sources/SwiftGRPCNIO/GRPCStatus.swift

@@ -1,5 +1,7 @@
 import Foundation
+import NIO
 import NIOHTTP1
+import NIOHTTP2
 
 /// Encapsulates the result of a gRPC call.
 public struct GRPCStatus: Error, Equatable {
@@ -33,3 +35,27 @@ extension GRPCStatus: GRPCStatusTransformable {
     return self
   }
 }
+
+extension NIOHTTP2Errors.StreamClosed: GRPCStatusTransformable {
+  public func asGRPCStatus() -> GRPCStatus {
+    return .init(code: .unavailable, message: self.localizedDescription)
+  }
+}
+
+extension NIOHTTP2Errors.IOOnClosedConnection: GRPCStatusTransformable {
+  public func asGRPCStatus() -> GRPCStatus {
+    return .init(code: .unavailable, message: "The connection is closed")
+  }
+}
+
+extension ChannelError: GRPCStatusTransformable {
+  public func asGRPCStatus() -> GRPCStatus {
+    switch self {
+    case .inputClosed, .outputClosed, .ioOnClosedChannel:
+      return .init(code: .unavailable, message: "The connection is closed")
+
+    default:
+      return .processingError
+    }
+  }
+}

+ 14 - 0
Tests/SwiftGRPCNIOTests/EventLoopFuture+Assertions.swift

@@ -66,3 +66,17 @@ extension EventLoopFuture {
     }
   }
 }
+
+extension EventLoopFuture {
+  // TODO: Replace with `always` once https://github.com/apple/swift-nio/pull/981 is released.
+  func peekError(callback: @escaping (Error) -> ()) -> EventLoopFuture<Value> {
+    self.whenFailure(callback)
+    return self
+  }
+
+  // TODO: Replace with `always` once https://github.com/apple/swift-nio/pull/981 is released.
+  func peek(callback: @escaping (Value) -> ()) -> EventLoopFuture<Value> {
+    self.whenSuccess(callback)
+    return self
+  }
+}

+ 33 - 1
Tests/SwiftGRPCNIOTests/NIOBasicEchoTestCase.swift

@@ -150,7 +150,10 @@ class NIOEchoTestCaseBase: XCTestCase {
   }
 
   override func tearDown() {
-    XCTAssertNoThrow(try self.client.connection.close().wait())
+    // Some tests close the channel, so would throw here if called twice.
+    if self.client.connection.channel.isActive {
+      XCTAssertNoThrow(try self.client.connection.close().wait())
+    }
     XCTAssertNoThrow(try self.clientEventLoopGroup.syncShutdownGracefully())
     self.client = nil
     self.clientEventLoopGroup = nil
@@ -163,3 +166,32 @@ class NIOEchoTestCaseBase: XCTestCase {
     super.tearDown()
   }
 }
+
+extension NIOEchoTestCaseBase {
+  func makeExpectation(description: String, expectedFulfillmentCount: Int = 1, assertForOverFulfill: Bool = true) -> XCTestExpectation {
+    let expectation = self.expectation(description: description)
+    expectation.expectedFulfillmentCount = expectedFulfillmentCount
+    expectation.assertForOverFulfill = assertForOverFulfill
+    return expectation
+  }
+
+  func makeStatusExpectation(expectedFulfillmentCount: Int = 1) -> XCTestExpectation {
+    return makeExpectation(description: "Expecting status received",
+                           expectedFulfillmentCount: expectedFulfillmentCount)
+  }
+
+  func makeResponseExpectation(expectedFulfillmentCount: Int = 1) -> XCTestExpectation {
+    return makeExpectation(description: "Expecting \(expectedFulfillmentCount) response(s)",
+      expectedFulfillmentCount: expectedFulfillmentCount)
+  }
+
+  func makeRequestExpectation(expectedFulfillmentCount: Int = 1) -> XCTestExpectation {
+    return makeExpectation(
+      description: "Expecting \(expectedFulfillmentCount) request(s) to have been sent",
+      expectedFulfillmentCount: expectedFulfillmentCount)
+  }
+
+  func makeInitialMetadataExpectation() -> XCTestExpectation {
+    return makeExpectation(description: "Expecting initial metadata")
+  }
+}

+ 167 - 0
Tests/SwiftGRPCNIOTests/NIOClientClosedChannelTests.swift

@@ -0,0 +1,167 @@
+/*
+ * Copyright 2019, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import Foundation
+import SwiftGRPCNIO
+import NIO
+import XCTest
+
+class NIOClientClosedChannelTests: NIOEchoTestCaseBase {
+  func testUnaryOnClosedConnection() throws {
+    let initialMetadataExpectation = self.makeInitialMetadataExpectation()
+    let responseExpectation = self.makeResponseExpectation()
+    let statusExpectation = self.makeStatusExpectation()
+
+    self.client.connection.close().map {
+      self.client.get(Echo_EchoRequest(text: "foo"))
+    }.whenSuccess { get in
+      get.initialMetadata.assertError(fulfill: initialMetadataExpectation)
+      get.response.assertError(fulfill: responseExpectation)
+      get.status.map { $0.code }.assertEqual(.unavailable, fulfill: statusExpectation)
+    }
+
+    self.wait(for: [initialMetadataExpectation, responseExpectation, statusExpectation],
+              timeout: self.defaultTestTimeout)
+  }
+
+  func testClientStreamingOnClosedConnection() throws {
+    let initialMetadataExpectation = self.makeInitialMetadataExpectation()
+    let responseExpectation = self.makeResponseExpectation()
+    let statusExpectation = self.makeStatusExpectation()
+
+    self.client.connection.close().map {
+      self.client.collect()
+    }.whenSuccess { collect in
+      collect.initialMetadata.assertError(fulfill: initialMetadataExpectation)
+      collect.response.assertError(fulfill: responseExpectation)
+      collect.status.map { $0.code }.assertEqual(.unavailable, fulfill: statusExpectation)
+    }
+
+    self.wait(for: [initialMetadataExpectation, responseExpectation, statusExpectation],
+              timeout: self.defaultTestTimeout)
+  }
+
+  func testClientStreamingWhenConnectionIsClosedBetweenMessages() throws {
+    let statusExpectation = self.makeStatusExpectation()
+    let responseExpectation = self.makeResponseExpectation()
+    let requestExpectation = self.makeRequestExpectation(expectedFulfillmentCount: 3)
+
+    let collect = self.client.collect()
+
+    collect.newMessageQueue().flatMap {
+      collect.sendMessage(Echo_EchoRequest(text: "foo"))
+    }.peek {
+      requestExpectation.fulfill()
+    }.flatMap {
+      collect.sendMessage(Echo_EchoRequest(text: "bar"))
+    }.peek {
+      requestExpectation.fulfill()
+    }.flatMap {
+      self.client.connection.close()
+    }.peekError { error in
+      XCTFail("Encountered error before or during closing the connection: \(error)")
+    }.flatMap {
+      collect.sendMessage(Echo_EchoRequest(text: "baz"))
+    }.assertError(fulfill: requestExpectation)
+
+    collect.response.assertError(fulfill: responseExpectation)
+    collect.status.map { $0.code }.assertEqual(.unavailable, fulfill: statusExpectation)
+
+    self.wait(for: [statusExpectation, responseExpectation, requestExpectation],
+              timeout: self.defaultTestTimeout)
+  }
+  
+  func testServerStreamingOnClosedConnection() throws {
+    let initialMetadataExpectation = self.makeInitialMetadataExpectation()
+    let statusExpectation = self.makeStatusExpectation()
+
+    self.client.connection.close().map {
+      self.client.expand(Echo_EchoRequest(text: "foo")) { response in
+        XCTFail("No response expected but got: \(response)")
+      }
+    }.whenSuccess { expand in
+      expand.initialMetadata.assertError(fulfill: initialMetadataExpectation)
+      expand.status.map { $0.code }.assertEqual(.unavailable, fulfill: statusExpectation)
+    }
+
+    self.wait(for: [initialMetadataExpectation, statusExpectation],
+              timeout: self.defaultTestTimeout)
+  }
+
+  func testBidirectionalStreamingOnClosedConnection() throws {
+    let initialMetadataExpectation = self.makeInitialMetadataExpectation()
+    let statusExpectation = self.makeStatusExpectation()
+
+    self.client.connection.close().map {
+      self.client.update { response in
+        XCTFail("No response expected but got: \(response)")
+      }
+    }.whenSuccess { update in
+      update.initialMetadata.assertError(fulfill: initialMetadataExpectation)
+      update.status.map { $0.code }.assertEqual(.unavailable, fulfill: statusExpectation)
+    }
+
+    self.wait(for: [initialMetadataExpectation, statusExpectation],
+              timeout: self.defaultTestTimeout)
+  }
+
+  func testBidirectionalStreamingWhenConnectionIsClosedBetweenMessages() throws {
+    let statusExpectation = self.makeStatusExpectation()
+    let requestExpectation = self.makeRequestExpectation(expectedFulfillmentCount: 3)
+
+    // We can't make any assertions about the number of responses we will receive before closing
+    // the connection; just ignore all responses.
+    let update = self.client.update() { _ in }
+
+    update.newMessageQueue().flatMap {
+      update.sendMessage(Echo_EchoRequest(text: "foo"))
+    }.peek {
+      requestExpectation.fulfill()
+    }.flatMap {
+      update.sendMessage(Echo_EchoRequest(text: "bar"))
+    }.peek {
+      requestExpectation.fulfill()
+    }.flatMap {
+      self.client.connection.close()
+    }.peekError { error in
+      XCTFail("Encountered error before or during closing the connection: \(error)")
+    }.flatMap {
+      update.sendMessage(Echo_EchoRequest(text: "baz"))
+    }.assertError(fulfill: requestExpectation)
+
+    update.status.map { $0.code }.assertEqual(.unavailable, fulfill: statusExpectation)
+
+    self.wait(for: [statusExpectation, requestExpectation], timeout: self.defaultTestTimeout)
+  }
+
+  func testBidirectionalStreamingWithNoPromiseWhenConnectionIsClosedBetweenMessages() throws {
+    let statusExpectation = self.makeStatusExpectation()
+
+    let update = self.client.update() { response in
+      XCTFail("No response expected but got: \(response)")
+    }
+
+    update.newMessageQueue().flatMap {
+      self.client.connection.close()
+    }.peekError { error in
+      XCTFail("Encountered error before or during closing the connection: \(error)")
+    }.whenSuccess {
+      update.sendMessage(Echo_EchoRequest(text: "foo"), promise: nil)
+    }
+
+    update.status.map { $0.code }.assertEqual(.unavailable, fulfill: statusExpectation)
+    self.wait(for: [statusExpectation], timeout: self.defaultTestTimeout)
+  }
+}

+ 0 - 19
Tests/SwiftGRPCNIOTests/NIOFunctionalTests.swift

@@ -37,25 +37,6 @@ class NIOFunctionalTestsInsecureTransport: NIOEchoTestCaseBase {
   }
 }
 
-extension NIOFunctionalTestsInsecureTransport {
-  func makeExpectation(description: String, expectedFulfillmentCount: Int = 1, assertForOverFulfill: Bool = true) -> XCTestExpectation {
-    let expectation = self.expectation(description: description)
-    expectation.expectedFulfillmentCount = expectedFulfillmentCount
-    expectation.assertForOverFulfill = assertForOverFulfill
-    return expectation
-  }
-
-  func makeStatusExpectation(expectedFulfillmentCount: Int = 1) -> XCTestExpectation {
-    return makeExpectation(description: "Expecting status received",
-                           expectedFulfillmentCount: expectedFulfillmentCount)
-  }
-
-  func makeResponseExpectation(expectedFulfillmentCount: Int = 1) -> XCTestExpectation {
-    return makeExpectation(description: "Expecting \(expectedFulfillmentCount) response(s)",
-                           expectedFulfillmentCount: expectedFulfillmentCount)
-  }
-}
-
 extension NIOFunctionalTestsInsecureTransport {
   func doTestUnary(request: Echo_EchoRequest, expect response: Echo_EchoResponse, file: StaticString = #file, line: UInt = #line) {
     let responseExpectation = self.makeResponseExpectation()

+ 16 - 0
Tests/SwiftGRPCNIOTests/XCTestManifests.swift

@@ -127,6 +127,21 @@ extension NIOClientCancellingTests {
     ]
 }
 
+extension NIOClientClosedChannelTests {
+    // DO NOT MODIFY: This is autogenerated, use:
+    //   `swift test --generate-linuxmain`
+    // to regenerate.
+    static let __allTests__NIOClientClosedChannelTests = [
+        ("testBidirectionalStreamingOnClosedConnection", testBidirectionalStreamingOnClosedConnection),
+        ("testBidirectionalStreamingWhenConnectionIsClosedBetweenMessages", testBidirectionalStreamingWhenConnectionIsClosedBetweenMessages),
+        ("testBidirectionalStreamingWithNoPromiseWhenConnectionIsClosedBetweenMessages", testBidirectionalStreamingWithNoPromiseWhenConnectionIsClosedBetweenMessages),
+        ("testClientStreamingOnClosedConnection", testClientStreamingOnClosedConnection),
+        ("testClientStreamingWhenConnectionIsClosedBetweenMessages", testClientStreamingWhenConnectionIsClosedBetweenMessages),
+        ("testServerStreamingOnClosedConnection", testServerStreamingOnClosedConnection),
+        ("testUnaryOnClosedConnection", testUnaryOnClosedConnection),
+    ]
+}
+
 extension NIOClientTLSFailureTests {
     // DO NOT MODIFY: This is autogenerated, use:
     //   `swift test --generate-linuxmain`
@@ -270,6 +285,7 @@ public func __allTests() -> [XCTestCaseEntry] {
         testCase(HTTP1ToRawGRPCServerCodecTests.__allTests__HTTP1ToRawGRPCServerCodecTests),
         testCase(LengthPrefixedMessageReaderTests.__allTests__LengthPrefixedMessageReaderTests),
         testCase(NIOClientCancellingTests.__allTests__NIOClientCancellingTests),
+        testCase(NIOClientClosedChannelTests.__allTests__NIOClientClosedChannelTests),
         testCase(NIOClientTLSFailureTests.__allTests__NIOClientTLSFailureTests),
         testCase(NIOClientTimeoutTests.__allTests__NIOClientTimeoutTests),
         testCase(NIOFunctionalTestsAnonymousClient.__allTests__NIOFunctionalTestsAnonymousClient),