Browse Source

Extend lifetime of client interceptor pipeline (#1265)

Motivation:

A client call (i.e. the object the user holds) may live longer than the
transport associated with it (roughly speaking, the http/2 stream channel). An
example of this is when interceptors are used to retry an RPC and
redirect responses back to the original call.

However, the interceptor pipeline is held by the transport and is
currently set to nil when the transport is removed from the channel.
This means events invoked from the call object (such as cancellation)
which go via the transport (holding the interceptor pipeline) are
incorrectly failed.

Modifications:

- Have the client interceptor pipeline break the ref cycle between the
  transport and itself when the interceptor pipeline closes rather than
  when the transport is closed
- Emit a cancellation status rather than an error on cancellation
- Update the ordering of when close is called in the interceptor
  pipeline.
- Add and update tests

Result:

"sub"-RPCs may be cancelled.
George Barnett 4 years ago
parent
commit
03010c784c

+ 27 - 13
Sources/GRPC/Interceptor/ClientInterceptorPipeline.swift

@@ -79,16 +79,17 @@ internal final class ClientInterceptorPipeline<Request, Response> {
   internal let _errorDelegate: ClientErrorDelegate?
   internal let _errorDelegate: ClientErrorDelegate?
 
 
   @usableFromInline
   @usableFromInline
-  internal let _onError: (Error) -> Void
+  internal private(set) var _onError: ((Error) -> Void)?
 
 
   @usableFromInline
   @usableFromInline
-  internal let _onCancel: (EventLoopPromise<Void>?) -> Void
+  internal private(set) var _onCancel: ((EventLoopPromise<Void>?) -> Void)?
 
 
   @usableFromInline
   @usableFromInline
-  internal let _onRequestPart: (GRPCClientRequestPart<Request>, EventLoopPromise<Void>?) -> Void
+  internal private(set) var _onRequestPart:
+    ((GRPCClientRequestPart<Request>, EventLoopPromise<Void>?) -> Void)?
 
 
   @usableFromInline
   @usableFromInline
-  internal let _onResponsePart: (GRPCClientResponsePart<Response>) -> Void
+  internal private(set) var _onResponsePart: ((GRPCClientResponsePart<Response>) -> Void)?
 
 
   /// The index after the last user interceptor context index. (i.e. `_userContexts.endIndex`).
   /// The index after the last user interceptor context index. (i.e. `_userContexts.endIndex`).
   @usableFromInline
   @usableFromInline
@@ -217,9 +218,13 @@ internal final class ClientInterceptorPipeline<Request, Response> {
 
 
     case self._tailIndex:
     case self._tailIndex:
       if part.isEnd {
       if part.isEnd {
+        // Update our state before handling the response part.
+        self._isOpen = false
+        self._onResponsePart?(part)
         self.close()
         self.close()
+      } else {
+        self._onResponsePart?(part)
       }
       }
-      self._onResponsePart(part)
 
 
     default:
     default:
       self._userContexts[index].invokeReceive(part)
       self._userContexts[index].invokeReceive(part)
@@ -275,9 +280,8 @@ internal final class ClientInterceptorPipeline<Request, Response> {
   /// Handles a caught error which has traversed the interceptor pipeline.
   /// Handles a caught error which has traversed the interceptor pipeline.
   @usableFromInline
   @usableFromInline
   internal func _errorCaught(_ error: Error) {
   internal func _errorCaught(_ error: Error) {
-    // We're about to complete, close the pipeline.
-    self.close()
-
+    // We're about to call out to an error handler: update our state first.
+    self._isOpen = false
     var unwrappedError: Error
     var unwrappedError: Error
 
 
     // Unwrap the error, if possible.
     // Unwrap the error, if possible.
@@ -295,7 +299,10 @@ internal final class ClientInterceptorPipeline<Request, Response> {
     }
     }
 
 
     // Emit the unwrapped error.
     // Emit the unwrapped error.
-    self._onError(unwrappedError)
+    self._onError?(unwrappedError)
+
+    // Close the pipeline.
+    self.close()
   }
   }
 
 
   /// Writes a request message into the interceptor pipeline.
   /// Writes a request message into the interceptor pipeline.
@@ -351,7 +358,7 @@ internal final class ClientInterceptorPipeline<Request, Response> {
   ) {
   ) {
     switch index {
     switch index {
     case self._headIndex:
     case self._headIndex:
-      self._onRequestPart(part, promise)
+      self._onRequestPart?(part, promise)
 
 
     case self._tailIndex:
     case self._tailIndex:
       self._invokeSend(
       self._invokeSend(
@@ -407,7 +414,7 @@ internal final class ClientInterceptorPipeline<Request, Response> {
   ) {
   ) {
     switch index {
     switch index {
     case self._headIndex:
     case self._headIndex:
-      self._onCancel(promise)
+      self._onCancel?(promise)
 
 
     case self._tailIndex:
     case self._tailIndex:
       self._invokeCancel(
       self._invokeCancel(
@@ -425,7 +432,7 @@ internal final class ClientInterceptorPipeline<Request, Response> {
 
 
 extension ClientInterceptorPipeline {
 extension ClientInterceptorPipeline {
   /// Closes the pipeline. This should be called once, by the tail interceptor, to indicate that
   /// Closes the pipeline. This should be called once, by the tail interceptor, to indicate that
-  /// the RPC has completed.
+  /// the RPC has completed. If this is not called, we will leak.
   /// - Important: This *must* to be called from the `eventLoop`.
   /// - Important: This *must* to be called from the `eventLoop`.
   @inlinable
   @inlinable
   internal func close() {
   internal func close() {
@@ -437,7 +444,14 @@ extension ClientInterceptorPipeline {
     self._scheduledClose = nil
     self._scheduledClose = nil
 
 
     // Cancel the transport.
     // Cancel the transport.
-    self._onCancel(nil)
+    self._onCancel?(nil)
+
+    // `ClientTransport` holds a reference to us and references to itself via these callbacks. Break
+    // these references now by replacing the callbacks.
+    self._onError = nil
+    self._onCancel = nil
+    self._onRequestPart = nil
+    self._onResponsePart = nil
   }
   }
 
 
   /// Sets up a deadline for the pipeline.
   /// Sets up a deadline for the pipeline.

+ 5 - 5
Sources/GRPC/Interceptor/ClientTransport.swift

@@ -85,8 +85,8 @@ internal final class ClientTransport<Request, Response> {
   // trailers here and only forward them when we receive the status.
   // trailers here and only forward them when we receive the status.
   private var trailers: HPACKHeaders?
   private var trailers: HPACKHeaders?
 
 
-  /// The interceptor pipeline connected to this transport. This must be set to `nil` when removed
-  /// from the `ChannelPipeline` in order to break reference cycles.
+  /// The interceptor pipeline connected to this transport. The pipeline also holds references
+  /// to `self` which are dropped when the interceptor pipeline is closed.
   @usableFromInline
   @usableFromInline
   internal var _pipeline: ClientInterceptorPipeline<Request, Response>?
   internal var _pipeline: ClientInterceptorPipeline<Request, Response>?
 
 
@@ -118,6 +118,7 @@ internal final class ClientTransport<Request, Response> {
     self.logger = logger
     self.logger = logger
     self.serializer = serializer
     self.serializer = serializer
     self.deserializer = deserializer
     self.deserializer = deserializer
+    // The references to self held by the pipeline are dropped when it is closed.
     self._pipeline = ClientInterceptorPipeline(
     self._pipeline = ClientInterceptorPipeline(
       eventLoop: eventLoop,
       eventLoop: eventLoop,
       details: details,
       details: details,
@@ -241,7 +242,8 @@ extension ClientTransport {
 
 
     if self.state.cancel() {
     if self.state.cancel() {
       let error = GRPCError.RPCCancelledByClient()
       let error = GRPCError.RPCCancelledByClient()
-      self.forwardErrorToInterceptors(error)
+      let status = error.makeGRPCStatus()
+      self.forwardToInterceptors(.end(status, [:]))
       self.failBufferedWrites(with: error)
       self.failBufferedWrites(with: error)
       self.channel?.close(mode: .all, promise: nil)
       self.channel?.close(mode: .all, promise: nil)
       self.channelPromise?.fail(error)
       self.channelPromise?.fail(error)
@@ -363,11 +365,9 @@ extension ClientTransport {
   private func dropReferences() {
   private func dropReferences() {
     if self.callEventLoop.inEventLoop {
     if self.callEventLoop.inEventLoop {
       self.channel = nil
       self.channel = nil
-      self._pipeline = nil
     } else {
     } else {
       self.callEventLoop.execute {
       self.callEventLoop.execute {
         self.channel = nil
         self.channel = nil
-        self._pipeline = nil
       }
       }
     }
     }
   }
   }

+ 1 - 5
Tests/GRPCTests/ClientCallTests.swift

@@ -197,11 +197,7 @@ class ClientCallTests: GRPCTestCase {
     // Cancellation should succeed.
     // Cancellation should succeed.
     assertThat(try get.cancel().wait(), .doesNotThrow())
     assertThat(try get.cancel().wait(), .doesNotThrow())
 
 
-    // The status promise will fail.
-    assertThat(
-      try promise.futureResult.wait(),
-      .throws(.instanceOf(GRPCError.RPCCancelledByClient.self))
-    )
+    assertThat(try promise.futureResult.wait(), .hasCode(.cancelled))
 
 
     // Cancellation should now fail, we've already cancelled.
     // Cancellation should now fail, we've already cancelled.
     assertThat(try get.cancel().wait(), .throws(.instanceOf(GRPCError.AlreadyComplete.self)))
     assertThat(try get.cancel().wait(), .throws(.instanceOf(GRPCError.AlreadyComplete.self)))

+ 2 - 2
Tests/GRPCTests/ClientCancellingTests.swift

@@ -27,7 +27,7 @@ class ClientCancellingTests: EchoTestCaseBase {
     call.cancel(promise: nil)
     call.cancel(promise: nil)
 
 
     call.response.whenFailure { error in
     call.response.whenFailure { error in
-      XCTAssertTrue(error is GRPCError.RPCCancelledByClient)
+      XCTAssertEqual((error as? GRPCStatus)?.code, .cancelled)
       responseReceived.fulfill()
       responseReceived.fulfill()
     }
     }
 
 
@@ -47,7 +47,7 @@ class ClientCancellingTests: EchoTestCaseBase {
     call.cancel(promise: nil)
     call.cancel(promise: nil)
 
 
     call.response.whenFailure { error in
     call.response.whenFailure { error in
-      XCTAssertTrue(error is GRPCError.RPCCancelledByClient)
+      XCTAssertEqual((error as? GRPCStatus)?.code, .cancelled)
       responseReceived.fulfill()
       responseReceived.fulfill()
     }
     }
 
 

+ 87 - 0
Tests/GRPCTests/EchoHelpers/Interceptors/EchoInterceptorFactories.swift

@@ -0,0 +1,87 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import EchoModel
+import GRPC
+
+// MARK: - Client
+
+internal final class EchoClientInterceptors: Echo_EchoClientInterceptorFactoryProtocol {
+  internal typealias Factory = () -> ClientInterceptor<Echo_EchoRequest, Echo_EchoResponse>
+  private var factories: [Factory] = []
+
+  internal init(_ factories: Factory...) {
+    self.factories = factories
+  }
+
+  internal func register(_ factory: @escaping Factory) {
+    self.factories.append(factory)
+  }
+
+  private func makeInterceptors() -> [ClientInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
+    return self.factories.map { $0() }
+  }
+
+  func makeGetInterceptors() -> [ClientInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
+    return self.makeInterceptors()
+  }
+
+  func makeExpandInterceptors() -> [ClientInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
+    return self.makeInterceptors()
+  }
+
+  func makeCollectInterceptors() -> [ClientInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
+    return self.makeInterceptors()
+  }
+
+  func makeUpdateInterceptors() -> [ClientInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
+    return self.makeInterceptors()
+  }
+}
+
+// MARK: - Server
+
+internal final class EchoServerInterceptors: Echo_EchoServerInterceptorFactoryProtocol {
+  internal typealias Factory = () -> ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>
+  private var factories: [Factory] = []
+
+  internal init(_ factories: Factory...) {
+    self.factories = factories
+  }
+
+  internal func register(_ factory: @escaping Factory) {
+    self.factories.append(factory)
+  }
+
+  private func makeInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
+    return self.factories.map { $0() }
+  }
+
+  func makeGetInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
+    return self.makeInterceptors()
+  }
+
+  func makeExpandInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
+    return self.makeInterceptors()
+  }
+
+  func makeCollectInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
+    return self.makeInterceptors()
+  }
+
+  func makeUpdateInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
+    return self.makeInterceptors()
+  }
+}

+ 202 - 0
Tests/GRPCTests/InterceptedRPCCancellationTests.swift

@@ -0,0 +1,202 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import EchoImplementation
+import EchoModel
+@testable import GRPC
+import NIOCore
+import NIOPosix
+import protocol SwiftProtobuf.Message
+import XCTest
+
+final class InterceptedRPCCancellationTests: GRPCTestCase {
+  func testCancellationWithinInterceptedRPC() throws {
+    // This test validates that when using interceptors to replay an RPC that the lifecycle of
+    // the interceptor pipeline is correctly managed. That is, the transport maintains a reference
+    // to the pipeline for as long as the call is alive (rather than dropping the reference when
+    // the RPC ends).
+    let group = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount)
+    defer {
+      XCTAssertNoThrow(try group.syncShutdownGracefully())
+    }
+
+    // Interceptor checks that a "magic" header is present.
+    let serverInterceptors = EchoServerInterceptors(MagicRequiredServerInterceptor.init)
+    let server = try Server.insecure(group: group)
+      .withLogger(self.serverLogger)
+      .withServiceProviders([EchoProvider(interceptors: serverInterceptors)])
+      .bind(host: "127.0.0.1", port: 0)
+      .wait()
+    defer {
+      XCTAssertNoThrow(try server.close().wait())
+    }
+
+    let connection = ClientConnection.insecure(group: group)
+      .withBackgroundActivityLogger(self.clientLogger)
+      .connect(host: "127.0.0.1", port: server.channel.localAddress!.port!)
+    defer {
+      XCTAssertNoThrow(try connection.close().wait())
+    }
+
+    let clientInterceptors = EchoClientInterceptors()
+    // Retries an RPC with a "magic" header if it fails with the permission denied status code.
+    clientInterceptors.register {
+      MagicAddingClientInterceptor(channel: connection)
+    }
+
+    let echo = Echo_EchoClient(channel: connection, interceptors: clientInterceptors)
+
+    let receivedFirstResponse = connection.eventLoop.makePromise(of: Void.self)
+    let update = echo.update { _ in
+      receivedFirstResponse.succeed(())
+    }
+
+    XCTAssertNoThrow(try update.sendMessage(.with { $0.text = "ping" }).wait())
+    // Wait for the pong: it means the second RPC is up and running and the first should have
+    // completed.
+    XCTAssertNoThrow(try receivedFirstResponse.futureResult.wait())
+    XCTAssertNoThrow(try update.cancel().wait())
+
+    let status = try update.status.wait()
+    XCTAssertEqual(status.code, .cancelled)
+  }
+}
+
+final class MagicRequiredServerInterceptor<
+  Request: Message,
+  Response: Message
+>: ServerInterceptor<Request, Response> {
+  override func receive(
+    _ part: GRPCServerRequestPart<Request>,
+    context: ServerInterceptorContext<Request, Response>
+  ) {
+    switch part {
+    case let .metadata(metadata):
+      if metadata.contains(name: "magic") {
+        context.log.debug("metadata contains magic; accepting rpc")
+        context.receive(part)
+      } else {
+        context.log.debug("metadata does not contains magic; rejecting rpc")
+        let status = GRPCStatus(code: .permissionDenied, message: nil)
+        context.send(.end(status, [:]), promise: nil)
+      }
+    case .message, .end:
+      context.receive(part)
+    }
+  }
+}
+
+final class MagicAddingClientInterceptor<
+  Request: Message,
+  Response: Message
+>: ClientInterceptor<Request, Response> {
+  private let channel: GRPCChannel
+  private var requestParts = CircularBuffer<GRPCClientRequestPart<Request>>()
+  private var retry: Call<Request, Response>?
+
+  init(channel: GRPCChannel) {
+    self.channel = channel
+  }
+
+  override func cancel(
+    promise: EventLoopPromise<Void>?,
+    context: ClientInterceptorContext<Request, Response>
+  ) {
+    if let retry = self.retry {
+      context.log.debug("cancelling retry RPC")
+      retry.cancel(promise: promise)
+    } else {
+      context.cancel(promise: promise)
+    }
+  }
+
+  override func send(
+    _ part: GRPCClientRequestPart<Request>,
+    promise: EventLoopPromise<Void>?,
+    context: ClientInterceptorContext<Request, Response>
+  ) {
+    if let retry = self.retry {
+      context.log.debug("retrying part \(part)")
+      retry.send(part, promise: promise)
+    } else {
+      switch part {
+      case .metadata:
+        // Replace the metadata with the magic words.
+        self.requestParts.append(.metadata(["magic": "it's real!"]))
+      case .message, .end:
+        self.requestParts.append(part)
+      }
+      context.send(part, promise: promise)
+    }
+  }
+
+  override func receive(
+    _ part: GRPCClientResponsePart<Response>,
+    context: ClientInterceptorContext<Request, Response>
+  ) {
+    switch part {
+    case .metadata, .message:
+      XCTFail("Unexpected response part \(part)")
+      context.receive(part)
+
+    case let .end(status, _):
+      guard status.code == .permissionDenied else {
+        XCTFail("Unexpected status code \(status)")
+        context.receive(part)
+        return
+      }
+
+      XCTAssertNil(self.retry)
+
+      context.log.debug("initial rpc failed, retrying")
+
+      self.retry = self.channel.makeCall(
+        path: context.path,
+        type: context.type,
+        callOptions: CallOptions(logger: context.logger),
+        interceptors: []
+      )
+
+      self.retry!.invoke(onError: {
+        context.log.debug("intercepting error from retried rpc")
+        context.errorCaught($0)
+      }) { responsePart in
+        context.log.debug("intercepting response part from retried rpc")
+        context.receive(responsePart)
+      }
+
+      while let requestPart = self.requestParts.popFirst() {
+        context.log.debug("replaying \(requestPart) on new rpc")
+        self.retry!.send(requestPart, promise: nil)
+      }
+    }
+  }
+}
+
+// MARK: - GRPC Logger
+
+// Our tests also check the "Source" of a logger is "GRPC". That assertion fails when we log from
+// tests so we'll use our internal logger instead.
+extension ClientInterceptorContext {
+  var log: GRPCLogger {
+    return GRPCLogger(wrapping: self.logger)
+  }
+}
+
+extension ServerInterceptorContext {
+  var log: GRPCLogger {
+    return GRPCLogger(wrapping: self.logger)
+  }
+}