Browse Source

Expose closeFuture in server interceptor context (#1553)

Motivation:

Server calls expose a `closeFuture` where users can register callbacks
to tear things down when the RPC ends. Interceptors don't have this
capability and must rely on observing an `.end`.

Modifications:

Expose the `closeFuture` from `ServerCallContext` to
the `ServerInterceptorContext`.

Result:

- Users can be notified in interceptors when the call ends.
- Resolves #1552
George Barnett 2 years ago
parent
commit
bcca31f3cf

+ 1 - 0
Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerHandler.swift

@@ -283,6 +283,7 @@ internal final class AsyncServerHandler<
       callType: callType,
       remoteAddress: context.remoteAddress,
       userInfoRef: self.userInfoRef,
+      closeFuture: context.closeFuture,
       interceptors: interceptors,
       onRequestPart: self.receiveInterceptedPart(_:),
       onResponsePart: self.sendInterceptedPart(_:promise:)

+ 1 - 0
Sources/GRPC/CallHandlers/BidirectionalStreamingServerHandler.swift

@@ -92,6 +92,7 @@ public final class BidirectionalStreamingServerHandler<
       callType: .bidirectionalStreaming,
       remoteAddress: context.remoteAddress,
       userInfoRef: userInfoRef,
+      closeFuture: context.closeFuture,
       interceptors: interceptors,
       onRequestPart: self.receiveInterceptedPart(_:),
       onResponsePart: self.sendInterceptedPart(_:promise:)

+ 1 - 0
Sources/GRPC/CallHandlers/ClientStreamingServerHandler.swift

@@ -93,6 +93,7 @@ public final class ClientStreamingServerHandler<
       callType: .clientStreaming,
       remoteAddress: context.remoteAddress,
       userInfoRef: userInfoRef,
+      closeFuture: context.closeFuture,
       interceptors: interceptors,
       onRequestPart: self.receiveInterceptedPart(_:),
       onResponsePart: self.sendInterceptedPart(_:promise:)

+ 1 - 0
Sources/GRPC/CallHandlers/ServerStreamingServerHandler.swift

@@ -89,6 +89,7 @@ public final class ServerStreamingServerHandler<
       callType: .serverStreaming,
       remoteAddress: context.remoteAddress,
       userInfoRef: userInfoRef,
+      closeFuture: context.closeFuture,
       interceptors: interceptors,
       onRequestPart: self.receiveInterceptedPart(_:),
       onResponsePart: self.sendInterceptedPart(_:promise:)

+ 1 - 0
Sources/GRPC/CallHandlers/UnaryServerHandler.swift

@@ -87,6 +87,7 @@ public final class UnaryServerHandler<
       callType: .unary,
       remoteAddress: context.remoteAddress,
       userInfoRef: userInfoRef,
+      closeFuture: context.closeFuture,
       interceptors: interceptors,
       onRequestPart: self.receiveInterceptedPart(_:),
       onResponsePart: self.sendInterceptedPart(_:promise:)

+ 6 - 0
Sources/GRPC/Interceptor/ServerInterceptorContext.swift

@@ -54,6 +54,12 @@ public struct ServerInterceptorContext<Request, Response> {
     return self._pipeline.remoteAddress
   }
 
+  /// A future which completes when the call closes. This may be used to register callbacks which
+  /// free up resources used by the interceptor.
+  public var closeFuture: EventLoopFuture<Void> {
+    return self._pipeline.closeFuture
+  }
+
   /// A 'UserInfo' dictionary.
   ///
   /// - Important: While `UserInfo` has value-semantics, this property retrieves from, and sets a

+ 7 - 0
Sources/GRPC/Interceptor/ServerInterceptorPipeline.swift

@@ -42,6 +42,11 @@ internal final class ServerInterceptorPipeline<Request, Response> {
   @usableFromInline
   internal let userInfoRef: Ref<UserInfo>
 
+  /// A future which completes when the call closes. This may be used to register callbacks which
+  /// free up resources used by the interceptor.
+  @usableFromInline
+  internal let closeFuture: EventLoopFuture<Void>
+
   /// Called when a response part has traversed the interceptor pipeline.
   @usableFromInline
   internal let _onResponsePart: (GRPCServerResponsePart<Response>, EventLoopPromise<Void>?) -> Void
@@ -99,6 +104,7 @@ internal final class ServerInterceptorPipeline<Request, Response> {
     callType: GRPCCallType,
     remoteAddress: SocketAddress?,
     userInfoRef: Ref<UserInfo>,
+    closeFuture: EventLoopFuture<Void>,
     interceptors: [ServerInterceptor<Request, Response>],
     onRequestPart: @escaping (GRPCServerRequestPart<Request>) -> Void,
     onResponsePart: @escaping (GRPCServerResponsePart<Response>, EventLoopPromise<Void>?) -> Void
@@ -109,6 +115,7 @@ internal final class ServerInterceptorPipeline<Request, Response> {
     self.type = callType
     self.remoteAddress = remoteAddress
     self.userInfoRef = userInfoRef
+    self.closeFuture = closeFuture
 
     self._onResponsePart = onResponsePart
     self._onRequestPart = onRequestPart

+ 59 - 1
Tests/GRPCTests/InterceptorsTests.swift

@@ -13,6 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+import Atomics
 import EchoImplementation
 import EchoModel
 import GRPC
@@ -28,6 +29,7 @@ class InterceptorsTests: GRPCTestCase {
   private var server: Server!
   private var connection: ClientConnection!
   private var echo: Echo_EchoNIOClient!
+  private let onCloseCounter = ManagedAtomic<Int>(0)
 
   override func setUp() {
     super.setUp()
@@ -35,7 +37,7 @@ class InterceptorsTests: GRPCTestCase {
 
     self.server = try! Server.insecure(group: self.group)
       .withServiceProviders([
-        EchoProvider(),
+        EchoProvider(interceptors: CountOnCloseInterceptors(counter: self.onCloseCounter)),
         HelloWorldProvider(interceptors: HelloWorldServerInterceptorFactory()),
       ])
       .withLogger(self.serverLogger)
@@ -64,6 +66,8 @@ class InterceptorsTests: GRPCTestCase {
     let get = self.echo.get(.with { $0.text = "hello" })
     assertThat(try get.response.wait(), .is(.with { $0.text = "hello :teg ohce tfiwS" }))
     assertThat(try get.status.wait(), .hasCode(.ok))
+
+    XCTAssertEqual(self.onCloseCounter.load(ordering: .sequentiallyConsistent), 1)
   }
 
   func testCollect() {
@@ -73,6 +77,8 @@ class InterceptorsTests: GRPCTestCase {
     collect.sendEnd(promise: nil)
     assertThat(try collect.response.wait(), .is(.with { $0.text = "3 4 1 2 :tcelloc ohce tfiwS" }))
     assertThat(try collect.status.wait(), .hasCode(.ok))
+
+    XCTAssertEqual(self.onCloseCounter.load(ordering: .sequentiallyConsistent), 1)
   }
 
   func testExpand() {
@@ -81,6 +87,8 @@ class InterceptorsTests: GRPCTestCase {
       assertThat(response, .is(.with { $0.text = "hello :)0( dnapxe ohce tfiwS" }))
     }
     assertThat(try expand.status.wait(), .hasCode(.ok))
+
+    XCTAssertEqual(self.onCloseCounter.load(ordering: .sequentiallyConsistent), 1)
   }
 
   func testUpdate() {
@@ -91,6 +99,8 @@ class InterceptorsTests: GRPCTestCase {
     update.sendMessage(.with { $0.text = "hello" }, promise: nil)
     update.sendEnd(promise: nil)
     assertThat(try update.status.wait(), .hasCode(.ok))
+
+    XCTAssertEqual(self.onCloseCounter.load(ordering: .sequentiallyConsistent), 1)
   }
 
   func testSayHello() {
@@ -360,6 +370,54 @@ final class ReversingInterceptors: Echo_EchoClientInterceptorFactoryProtocol {
   }
 }
 
+final class CountOnCloseInterceptors: Echo_EchoServerInterceptorFactoryProtocol {
+  // This interceptor is stateless, let's just share it.
+  private let interceptors: [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>]
+
+  init(counter: ManagedAtomic<Int>) {
+    self.interceptors = [CountOnCloseServerInterceptor(counter: counter)]
+  }
+
+  func makeGetInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
+    return self.interceptors
+  }
+
+  func makeExpandInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
+    return self.interceptors
+  }
+
+  func makeCollectInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
+    return self.interceptors
+  }
+
+  func makeUpdateInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
+    return self.interceptors
+  }
+}
+
+final class CountOnCloseServerInterceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse> {
+  private let counter: ManagedAtomic<Int>
+
+  init(counter: ManagedAtomic<Int>) {
+    self.counter = counter
+  }
+
+  override func receive(
+    _ part: GRPCServerRequestPart<Echo_EchoRequest>,
+    context: ServerInterceptorContext<Echo_EchoRequest, Echo_EchoResponse>
+  ) {
+    switch part {
+    case .metadata:
+      context.closeFuture.whenComplete { _ in
+        self.counter.wrappingIncrement(ordering: .sequentiallyConsistent)
+      }
+    default:
+      ()
+    }
+    context.receive(part)
+  }
+}
+
 private enum MagicKey: UserInfo.Key {
   typealias Value = String
 }

+ 1 - 0
Tests/GRPCTests/ServerInterceptorPipelineTests.swift

@@ -43,6 +43,7 @@ class ServerInterceptorPipelineTests: GRPCTestCase {
       callType: callType,
       remoteAddress: nil,
       userInfoRef: Ref(UserInfo()),
+      closeFuture: self.embeddedEventLoop.makeSucceededVoidFuture(),
       interceptors: interceptors,
       onRequestPart: onRequestPart,
       onResponsePart: onResponsePart