Browse Source

Add a 'closeFuture' to the server context (#1147)

Motivation:

Freeing up user-allocated resources for RPCs is difficult to do
correctly: there isn't a single obvious way to do it for all types, and
in some cases it's impossible since the handler doesn't always know when
the RPC has been terminated. That makes it nigh on impossible for users
to correctly manage resources.

Modifications:

- Add a 'closeFuture' to the server context allowing users to register a
  callbacks to be executed when the RPC completes.
- In existing context implementations this is just the stream channels
  'closeFuture'
- Tests for each RPC style: when the RPC completes successfully, when
  the RPC handler future fails, and when the connection is killed
  mid-rpc

Results:

- It's easier for users to tear down resources when the RPC has
  completed
- Resolves #1145
George Barnett 4 years ago
parent
commit
caefcd5de2

+ 1 - 0
Sources/GRPC/CallHandlers/BidirectionalStreamingServerHandler.swift

@@ -167,6 +167,7 @@ public final class BidirectionalStreamingServerHandler<
         logger: self.context.logger,
         userInfoRef: self.userInfoRef,
         compressionIsEnabled: self.context.encoding.isEnabled,
+        closeFuture: self.context.closeFuture,
         sendResponse: self.interceptResponse(_:metadata:promise:)
       )
 

+ 2 - 1
Sources/GRPC/CallHandlers/ClientStreamingServerHandler.swift

@@ -166,7 +166,8 @@ public final class ClientStreamingServerHandler<
         eventLoop: self.context.eventLoop,
         headers: headers,
         logger: self.context.logger,
-        userInfoRef: self.userInfoRef
+        userInfoRef: self.userInfoRef,
+        closeFuture: self.context.closeFuture
       )
 
       // Move to the next state.

+ 1 - 0
Sources/GRPC/CallHandlers/ServerStreamingServerHandler.swift

@@ -164,6 +164,7 @@ public final class ServerStreamingServerHandler<
         logger: self.context.logger,
         userInfoRef: self.userInfoRef,
         compressionIsEnabled: self.context.encoding.isEnabled,
+        closeFuture: self.context.closeFuture,
         sendResponse: self.interceptResponse(_:metadata:promise:)
       )
 

+ 2 - 1
Sources/GRPC/CallHandlers/UnaryServerHandler.swift

@@ -160,7 +160,8 @@ public final class UnaryServerHandler<
         eventLoop: self.context.eventLoop,
         headers: headers,
         logger: self.context.logger,
-        userInfoRef: self.userInfoRef
+        userInfoRef: self.userInfoRef,
+        closeFuture: self.context.closeFuture
       )
 
       // Move to the next state.

+ 2 - 0
Sources/GRPC/GRPCServerRequestRoutingHandler.swift

@@ -57,6 +57,8 @@ public struct CallHandlerContext {
   internal var responseWriter: GRPCServerResponseWriter
   @usableFromInline
   internal var allocator: ByteBufferAllocator
+  @usableFromInline
+  internal var closeFuture: EventLoopFuture<Void>
 }
 
 /// A call URI split into components.

+ 2 - 1
Sources/GRPC/HTTP2ToRawGRPCServerCodec.swift

@@ -99,7 +99,8 @@ internal final class HTTP2ToRawGRPCServerCodec: ChannelInboundHandler, GRPCServe
         remoteAddress: context.channel.remoteAddress,
         logger: self.logger,
         allocator: context.channel.allocator,
-        responseWriter: self
+        responseWriter: self,
+        closeFuture: context.channel.closeFuture
       )
 
       switch receiveHeaders {

+ 12 - 6
Sources/GRPC/HTTP2ToRawGRPCStateMachine.swift

@@ -278,7 +278,8 @@ extension HTTP2ToRawGRPCStateMachine.RequestIdleResponseIdleState {
     remoteAddress: SocketAddress?,
     logger: Logger,
     allocator: ByteBufferAllocator,
-    responseWriter: GRPCServerResponseWriter
+    responseWriter: GRPCServerResponseWriter,
+    closeFuture: EventLoopFuture<Void>
   ) -> HTTP2ToRawGRPCStateMachine.StateAndReceiveHeadersAction {
     // Extract and validate the content type. If it's nil we need to close.
     guard let contentType = self.extractContentType(from: headers) else {
@@ -326,7 +327,8 @@ extension HTTP2ToRawGRPCStateMachine.RequestIdleResponseIdleState {
       path: path,
       remoteAddress: remoteAddress,
       responseWriter: responseWriter,
-      allocator: allocator
+      allocator: allocator,
+      closeFuture: closeFuture
     )
 
     // We have a matching service, hopefully we have a provider for the method too.
@@ -834,7 +836,8 @@ extension HTTP2ToRawGRPCStateMachine {
     remoteAddress: SocketAddress?,
     logger: Logger,
     allocator: ByteBufferAllocator,
-    responseWriter: GRPCServerResponseWriter
+    responseWriter: GRPCServerResponseWriter,
+    closeFuture: EventLoopFuture<Void>
   ) -> ReceiveHeadersAction {
     return self.withStateAvoidingCoWs { state in
       state.receive(
@@ -844,7 +847,8 @@ extension HTTP2ToRawGRPCStateMachine {
         remoteAddress: remoteAddress,
         logger: logger,
         allocator: allocator,
-        responseWriter: responseWriter
+        responseWriter: responseWriter,
+        closeFuture: closeFuture
       )
     }
   }
@@ -934,7 +938,8 @@ extension HTTP2ToRawGRPCStateMachine.State {
     remoteAddress: SocketAddress?,
     logger: Logger,
     allocator: ByteBufferAllocator,
-    responseWriter: GRPCServerResponseWriter
+    responseWriter: GRPCServerResponseWriter,
+    closeFuture: EventLoopFuture<Void>
   ) -> HTTP2ToRawGRPCStateMachine.ReceiveHeadersAction {
     switch self {
     // This is the only state in which we can receive headers. Everything else is invalid.
@@ -946,7 +951,8 @@ extension HTTP2ToRawGRPCStateMachine.State {
         remoteAddress: remoteAddress,
         logger: logger,
         allocator: allocator,
-        responseWriter: responseWriter
+        responseWriter: responseWriter,
+        closeFuture: closeFuture
       )
       self = stateAndAction.state
       return stateAndAction.action

+ 49 - 2
Sources/GRPC/ServerCallContexts/ServerCallContext.swift

@@ -38,6 +38,24 @@ public protocol ServerCallContext: AnyObject {
   /// this value to take effect compression must have been enabled on the server and a compression
   /// algorithm must have been negotiated with the client.
   var compressionEnabled: Bool { get set }
+
+  /// A future which completes when the call closes. This may be used to register callbacks which
+  /// free up resources used by the RPC.
+  var closeFuture: EventLoopFuture<Void> { get }
+}
+
+extension ServerCallContext {
+  // Default implementation to avoid breaking API.
+  public var closeFuture: EventLoopFuture<Void> {
+    return self.eventLoop.makeFailedFuture(GRPCStatus.closeFutureNotImplemented)
+  }
+}
+
+extension GRPCStatus {
+  internal static let closeFutureNotImplemented = GRPCStatus(
+    code: .unimplemented,
+    message: "This context type has not implemented support for a 'closeFuture'"
+  )
 }
 
 /// Base class providing data provided to the framework user for all server calls.
@@ -111,13 +129,40 @@ open class ServerCallContextBase: ServerCallContext {
 
   private var _trailers: HPACKHeaders = [:]
 
+  /// A future which completes when the call closes. This may be used to register callbacks which
+  /// free up resources used by the RPC.
+  public let closeFuture: EventLoopFuture<Void>
+
+  @available(*, deprecated, renamed: "init(eventLoop:headers:logger:userInfo:closeFuture:)")
   public convenience init(
     eventLoop: EventLoop,
     headers: HPACKHeaders,
     logger: Logger,
     userInfo: UserInfo = UserInfo()
   ) {
-    self.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: .init(userInfo))
+    self.init(
+      eventLoop: eventLoop,
+      headers: headers,
+      logger: logger,
+      userInfoRef: .init(userInfo),
+      closeFuture: eventLoop.makeFailedFuture(GRPCStatus.closeFutureNotImplemented)
+    )
+  }
+
+  public convenience init(
+    eventLoop: EventLoop,
+    headers: HPACKHeaders,
+    logger: Logger,
+    userInfo: UserInfo = UserInfo(),
+    closeFuture: EventLoopFuture<Void>
+  ) {
+    self.init(
+      eventLoop: eventLoop,
+      headers: headers,
+      logger: logger,
+      userInfoRef: .init(userInfo),
+      closeFuture: closeFuture
+    )
   }
 
   @inlinable
@@ -125,11 +170,13 @@ open class ServerCallContextBase: ServerCallContext {
     eventLoop: EventLoop,
     headers: HPACKHeaders,
     logger: Logger,
-    userInfoRef: Ref<UserInfo>
+    userInfoRef: Ref<UserInfo>,
+    closeFuture: EventLoopFuture<Void>
   ) {
     self.eventLoop = eventLoop
     self.headers = headers
     self.userInfoRef = userInfoRef
     self.logger = logger
+    self.closeFuture = closeFuture
   }
 }

+ 41 - 4
Sources/GRPC/ServerCallContexts/StreamingResponseCallContext.swift

@@ -31,13 +31,36 @@ open class StreamingResponseCallContext<ResponsePayload>: ServerCallContextBase
   /// handler.
   public let statusPromise: EventLoopPromise<GRPCStatus>
 
+  @available(*, deprecated, renamed: "init(eventLoop:headers:logger:userInfo:closeFuture:)")
   public convenience init(
     eventLoop: EventLoop,
     headers: HPACKHeaders,
     logger: Logger,
     userInfo: UserInfo = UserInfo()
   ) {
-    self.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: .init(userInfo))
+    self.init(
+      eventLoop: eventLoop,
+      headers: headers,
+      logger: logger,
+      userInfoRef: .init(userInfo),
+      closeFuture: eventLoop.makeFailedFuture(GRPCStatus.closeFutureNotImplemented)
+    )
+  }
+
+  public convenience init(
+    eventLoop: EventLoop,
+    headers: HPACKHeaders,
+    logger: Logger,
+    userInfo: UserInfo = UserInfo(),
+    closeFuture: EventLoopFuture<Void>
+  ) {
+    self.init(
+      eventLoop: eventLoop,
+      headers: headers,
+      logger: logger,
+      userInfoRef: .init(userInfo),
+      closeFuture: closeFuture
+    )
   }
 
   @inlinable
@@ -45,10 +68,17 @@ open class StreamingResponseCallContext<ResponsePayload>: ServerCallContextBase
     eventLoop: EventLoop,
     headers: HPACKHeaders,
     logger: Logger,
-    userInfoRef: Ref<UserInfo>
+    userInfoRef: Ref<UserInfo>,
+    closeFuture: EventLoopFuture<Void>
   ) {
     self.statusPromise = eventLoop.makePromise()
-    super.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: userInfoRef)
+    super.init(
+      eventLoop: eventLoop,
+      headers: headers,
+      logger: logger,
+      userInfoRef: userInfoRef,
+      closeFuture: closeFuture
+    )
   }
 
   /// Send a response to the client.
@@ -131,11 +161,18 @@ internal final class _StreamingResponseCallContext<Request, Response>:
     logger: Logger,
     userInfoRef: Ref<UserInfo>,
     compressionIsEnabled: Bool,
+    closeFuture: EventLoopFuture<Void>,
     sendResponse: @escaping (Response, MessageMetadata, EventLoopPromise<Void>?) -> Void
   ) {
     self._sendResponse = sendResponse
     self._compressionEnabledOnServer = compressionIsEnabled
-    super.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: userInfoRef)
+    super.init(
+      eventLoop: eventLoop,
+      headers: headers,
+      logger: logger,
+      userInfoRef: userInfoRef,
+      closeFuture: closeFuture
+    )
   }
 
   @inlinable

+ 33 - 3
Sources/GRPC/ServerCallContexts/UnaryResponseCallContext.swift

@@ -51,13 +51,36 @@ open class UnaryResponseCallContext<Response>: ServerCallContextBase, StatusOnly
 
   private var _responseStatus: GRPCStatus = .ok
 
+  @available(*, deprecated, renamed: "init(eventLoop:headers:logger:userInfo:closeFuture:)")
   public convenience init(
     eventLoop: EventLoop,
     headers: HPACKHeaders,
     logger: Logger,
     userInfo: UserInfo = UserInfo()
   ) {
-    self.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: .init(userInfo))
+    self.init(
+      eventLoop: eventLoop,
+      headers: headers,
+      logger: logger,
+      userInfoRef: .init(userInfo),
+      closeFuture: eventLoop.makeFailedFuture(GRPCStatus.closeFutureNotImplemented)
+    )
+  }
+
+  public convenience init(
+    eventLoop: EventLoop,
+    headers: HPACKHeaders,
+    logger: Logger,
+    userInfo: UserInfo = UserInfo(),
+    closeFuture: EventLoopFuture<Void>
+  ) {
+    self.init(
+      eventLoop: eventLoop,
+      headers: headers,
+      logger: logger,
+      userInfoRef: .init(userInfo),
+      closeFuture: closeFuture
+    )
   }
 
   @inlinable
@@ -65,10 +88,17 @@ open class UnaryResponseCallContext<Response>: ServerCallContextBase, StatusOnly
     eventLoop: EventLoop,
     headers: HPACKHeaders,
     logger: Logger,
-    userInfoRef: Ref<UserInfo>
+    userInfoRef: Ref<UserInfo>,
+    closeFuture: EventLoopFuture<Void>
   ) {
     self.responsePromise = eventLoop.makePromise()
-    super.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: userInfoRef)
+    super.init(
+      eventLoop: eventLoop,
+      headers: headers,
+      logger: logger,
+      userInfoRef: userInfoRef,
+      closeFuture: closeFuture
+    )
   }
 }
 

+ 82 - 0
Tests/GRPCTests/EchoHelpers/Interceptors/DelegatingClientInterceptor.swift

@@ -0,0 +1,82 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import EchoModel
+import GRPC
+import NIO
+import SwiftProtobuf
+
+/// A client interceptor which delegates the implementation of `send` and `receive` to callbacks.
+final class DelegatingClientInterceptor<
+  Request: Message,
+  Response: Message
+>: ClientInterceptor<Request, Response> {
+  typealias RequestPart = GRPCClientRequestPart<Request>
+  typealias ResponsePart = GRPCClientResponsePart<Response>
+  typealias Context = ClientInterceptorContext<Request, Response>
+  typealias OnSend = (RequestPart, EventLoopPromise<Void>?, Context) -> Void
+  typealias OnReceive = (ResponsePart, Context) -> Void
+
+  private let onSend: OnSend
+  private let onReceive: OnReceive
+
+  init(
+    onSend: @escaping OnSend = { part, promise, context in context.send(part, promise: promise) },
+    onReceive: @escaping OnReceive = { part, context in context.receive(part) }
+  ) {
+    self.onSend = onSend
+    self.onReceive = onReceive
+  }
+
+  override func send(
+    _ part: GRPCClientRequestPart<Request>,
+    promise: EventLoopPromise<Void>?,
+    context: ClientInterceptorContext<Request, Response>
+  ) {
+    self.onSend(part, promise, context)
+  }
+
+  override func receive(
+    _ part: GRPCClientResponsePart<Response>,
+    context: ClientInterceptorContext<Request, Response>
+  ) {
+    self.onReceive(part, context)
+  }
+}
+
+class DelegatingEchoClientInterceptorFactory: Echo_EchoClientInterceptorFactoryProtocol {
+  typealias OnSend = DelegatingClientInterceptor<Echo_EchoRequest, Echo_EchoResponse>.OnSend
+  let interceptor: DelegatingClientInterceptor<Echo_EchoRequest, Echo_EchoResponse>
+
+  init(onSend: @escaping OnSend) {
+    self.interceptor = DelegatingClientInterceptor(onSend: onSend)
+  }
+
+  func makeGetInterceptors() -> [ClientInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
+    return [self.interceptor]
+  }
+
+  func makeExpandInterceptors() -> [ClientInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
+    return [self.interceptor]
+  }
+
+  func makeCollectInterceptors() -> [ClientInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
+    return [self.interceptor]
+  }
+
+  func makeUpdateInterceptors() -> [ClientInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
+    return [self.interceptor]
+  }
+}

+ 67 - 0
Tests/GRPCTests/EchoHelpers/Providers/DelegatingOnCloseEchoProvider.swift

@@ -0,0 +1,67 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import EchoModel
+import GRPC
+import NIO
+
+/// An `Echo_EchoProvider` which sets `onClose` for each RPC and then calls a delegate to provide
+/// the RPC implementation.
+class OnCloseEchoProvider: Echo_EchoProvider {
+  let interceptors: Echo_EchoServerInterceptorFactoryProtocol?
+
+  let onClose: (Result<Void, Error>) -> Void
+  let delegate: Echo_EchoProvider
+
+  init(
+    delegate: Echo_EchoProvider,
+    interceptors: Echo_EchoServerInterceptorFactoryProtocol? = nil,
+    onClose: @escaping (Result<Void, Error>) -> Void
+  ) {
+    self.delegate = delegate
+    self.onClose = onClose
+    self.interceptors = interceptors
+  }
+
+  func get(
+    request: Echo_EchoRequest,
+    context: StatusOnlyCallContext
+  ) -> EventLoopFuture<Echo_EchoResponse> {
+    context.closeFuture.whenComplete(self.onClose)
+    return self.delegate.get(request: request, context: context)
+  }
+
+  func expand(
+    request: Echo_EchoRequest,
+    context: StreamingResponseCallContext<Echo_EchoResponse>
+  ) -> EventLoopFuture<GRPCStatus> {
+    context.closeFuture.whenComplete(self.onClose)
+    return self.delegate.expand(request: request, context: context)
+  }
+
+  func collect(
+    context: UnaryResponseCallContext<Echo_EchoResponse>
+  ) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
+    context.closeFuture.whenComplete(self.onClose)
+    return self.delegate.collect(context: context)
+  }
+
+  func update(
+    context: StreamingResponseCallContext<Echo_EchoResponse>
+  ) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
+    context.closeFuture.whenComplete(self.onClose)
+    return self.delegate.update(context: context)
+  }
+}

+ 53 - 0
Tests/GRPCTests/EchoHelpers/Providers/FailingEchoProvider.swift

@@ -0,0 +1,53 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import EchoModel
+import GRPC
+import NIO
+
+/// An `Echo_EchoProvider` which always returns failed future for each RPC.
+class FailingEchoProvider: Echo_EchoProvider {
+  let interceptors: Echo_EchoServerInterceptorFactoryProtocol?
+
+  init(interceptors: Echo_EchoServerInterceptorFactoryProtocol? = nil) {
+    self.interceptors = interceptors
+  }
+
+  func get(
+    request: Echo_EchoRequest,
+    context: StatusOnlyCallContext
+  ) -> EventLoopFuture<Echo_EchoResponse> {
+    return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
+  }
+
+  func expand(
+    request: Echo_EchoRequest,
+    context: StreamingResponseCallContext<Echo_EchoResponse>
+  ) -> EventLoopFuture<GRPCStatus> {
+    return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
+  }
+
+  func collect(
+    context: UnaryResponseCallContext<Echo_EchoResponse>
+  ) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
+    return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
+  }
+
+  func update(
+    context: StreamingResponseCallContext<Echo_EchoResponse>
+  ) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
+    return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
+  }
+}

+ 62 - 0
Tests/GRPCTests/EchoHelpers/Providers/NeverResolvingEchoProvider.swift

@@ -0,0 +1,62 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import EchoModel
+import GRPC
+import NIO
+
+/// An `Echo_EchoProvider` which returns a failed future for each RPC which resolves in the distant
+/// future.
+class NeverResolvingEchoProvider: Echo_EchoProvider {
+  let interceptors: Echo_EchoServerInterceptorFactoryProtocol?
+
+  init(interceptors: Echo_EchoServerInterceptorFactoryProtocol? = nil) {
+    self.interceptors = interceptors
+  }
+
+  func get(
+    request: Echo_EchoRequest,
+    context: StatusOnlyCallContext
+  ) -> EventLoopFuture<Echo_EchoResponse> {
+    return context.eventLoop.scheduleTask(deadline: .distantFuture) {
+      throw GRPCStatus.processingError
+    }.futureResult
+  }
+
+  func expand(
+    request: Echo_EchoRequest,
+    context: StreamingResponseCallContext<Echo_EchoResponse>
+  ) -> EventLoopFuture<GRPCStatus> {
+    return context.eventLoop.scheduleTask(deadline: .distantFuture) {
+      throw GRPCStatus.processingError
+    }.futureResult
+  }
+
+  func collect(
+    context: UnaryResponseCallContext<Echo_EchoResponse>
+  ) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
+    return context.eventLoop.scheduleTask(deadline: .distantFuture) {
+      throw GRPCStatus.processingError
+    }.futureResult
+  }
+
+  func update(
+    context: StreamingResponseCallContext<Echo_EchoResponse>
+  ) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
+    return context.eventLoop.scheduleTask(deadline: .distantFuture) {
+      throw GRPCStatus.processingError
+    }.futureResult
+  }
+}

+ 24 - 12
Tests/GRPCTests/HTTP2ToRawGRPCStateMachineTests.swift

@@ -102,7 +102,8 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       remoteAddress: nil,
       logger: self.logger,
       allocator: ByteBufferAllocator(),
-      responseWriter: NoOpResponseWriter()
+      responseWriter: NoOpResponseWriter(),
+      closeFuture: self.eventLoop.makeSucceededVoidFuture()
     )
 
     assertThat(receiveHeadersAction, .is(.configure()))
@@ -174,7 +175,8 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       remoteAddress: nil,
       logger: self.logger,
       allocator: ByteBufferAllocator(),
-      responseWriter: NoOpResponseWriter()
+      responseWriter: NoOpResponseWriter(),
+      closeFuture: self.eventLoop.makeSucceededVoidFuture()
     )
     assertThat(action, .is(.configure()))
   }
@@ -188,7 +190,8 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       remoteAddress: nil,
       logger: self.logger,
       allocator: ByteBufferAllocator(),
-      responseWriter: NoOpResponseWriter()
+      responseWriter: NoOpResponseWriter(),
+      closeFuture: self.eventLoop.makeSucceededVoidFuture()
     )
     assertThat(action, .is(.rejectRPC(.contains(":status", ["415"]))))
   }
@@ -202,7 +205,8 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       remoteAddress: nil,
       logger: self.logger,
       allocator: ByteBufferAllocator(),
-      responseWriter: NoOpResponseWriter()
+      responseWriter: NoOpResponseWriter(),
+      closeFuture: self.eventLoop.makeSucceededVoidFuture()
     )
     assertThat(action, .is(.rejectRPC(.trailersOnly(code: .unimplemented))))
   }
@@ -216,7 +220,8 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       remoteAddress: nil,
       logger: self.logger,
       allocator: ByteBufferAllocator(),
-      responseWriter: NoOpResponseWriter()
+      responseWriter: NoOpResponseWriter(),
+      closeFuture: self.eventLoop.makeSucceededVoidFuture()
     )
     assertThat(action, .is(.rejectRPC(.trailersOnly(code: .unimplemented))))
   }
@@ -230,7 +235,8 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       remoteAddress: nil,
       logger: self.logger,
       allocator: ByteBufferAllocator(),
-      responseWriter: NoOpResponseWriter()
+      responseWriter: NoOpResponseWriter(),
+      closeFuture: self.eventLoop.makeSucceededVoidFuture()
     )
     assertThat(action, .is(.rejectRPC(.trailersOnly(code: .unimplemented))))
   }
@@ -244,7 +250,8 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       remoteAddress: nil,
       logger: self.logger,
       allocator: ByteBufferAllocator(),
-      responseWriter: NoOpResponseWriter()
+      responseWriter: NoOpResponseWriter(),
+      closeFuture: self.eventLoop.makeSucceededVoidFuture()
     )
     assertThat(action, .is(.rejectRPC(.trailersOnly(code: .unimplemented))))
   }
@@ -259,7 +266,8 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       remoteAddress: nil,
       logger: self.logger,
       allocator: ByteBufferAllocator(),
-      responseWriter: NoOpResponseWriter()
+      responseWriter: NoOpResponseWriter(),
+      closeFuture: self.eventLoop.makeSucceededVoidFuture()
     )
     assertThat(action, .is(.rejectRPC(.trailersOnly(code: .invalidArgument))))
   }
@@ -274,7 +282,8 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       remoteAddress: nil,
       logger: self.logger,
       allocator: ByteBufferAllocator(),
-      responseWriter: NoOpResponseWriter()
+      responseWriter: NoOpResponseWriter(),
+      closeFuture: self.eventLoop.makeSucceededVoidFuture()
     )
 
     assertThat(action, .is(.rejectRPC(.trailersOnly(code: .unimplemented))))
@@ -295,7 +304,8 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       remoteAddress: nil,
       logger: self.logger,
       allocator: ByteBufferAllocator(),
-      responseWriter: NoOpResponseWriter()
+      responseWriter: NoOpResponseWriter(),
+      closeFuture: self.eventLoop.makeSucceededVoidFuture()
     )
 
     // This is expected: however, we also expect 'grpc-accept-encoding' to be in the response
@@ -319,7 +329,8 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       remoteAddress: nil,
       logger: self.logger,
       allocator: ByteBufferAllocator(),
-      responseWriter: NoOpResponseWriter()
+      responseWriter: NoOpResponseWriter(),
+      closeFuture: self.eventLoop.makeSucceededVoidFuture()
     )
 
     assertThat(action, .is(.configure()))
@@ -335,7 +346,8 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       remoteAddress: nil,
       logger: self.logger,
       allocator: ByteBufferAllocator(),
-      responseWriter: NoOpResponseWriter()
+      responseWriter: NoOpResponseWriter(),
+      closeFuture: self.eventLoop.makeSucceededVoidFuture()
     )
 
     // This is expected, but we need to check the value of 'grpc-encoding' in the response headers.

+ 2 - 1
Tests/GRPCTests/ServerInterceptorTests.swift

@@ -57,7 +57,8 @@ class ServerInterceptorTests: GRPCTestCase {
       eventLoop: self.eventLoop,
       path: path,
       responseWriter: self.recorder,
-      allocator: ByteBufferAllocator()
+      allocator: ByteBufferAllocator(),
+      closeFuture: self.eventLoop.makeSucceededVoidFuture()
     )
   }
 

+ 243 - 0
Tests/GRPCTests/ServerOnCloseTests.swift

@@ -0,0 +1,243 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import EchoImplementation
+import EchoModel
+import GRPC
+import NIO
+import NIOConcurrencyHelpers
+import XCTest
+
+final class ServerOnCloseTests: GRPCTestCase {
+  private let group = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount)
+  private var server: Server!
+  private var client: ClientConnection!
+  private var echo: Echo_EchoClient!
+
+  private var eventLoop: EventLoop {
+    return self.group.next()
+  }
+
+  override func tearDown() {
+    // Some tests shut down the client/server so we tolerate errors here.
+    try? self.client.close().wait()
+    try? self.server.close().wait()
+    XCTAssertNoThrow(try self.group.syncShutdownGracefully())
+  }
+
+  private func setUp(provider: Echo_EchoProvider) throws {
+    self.server = try Server.insecure(group: self.group)
+      .withLogger(self.serverLogger)
+      .withServiceProviders([provider])
+      .bind(host: "localhost", port: 0)
+      .wait()
+
+    print(self.server.channel.localAddress!.port!)
+
+    self.client = ClientConnection.insecure(group: self.group)
+      .withBackgroundActivityLogger(self.clientLogger)
+      .connect(host: "localhost", port: self.server.channel.localAddress!.port!)
+
+    self.echo = Echo_EchoClient(
+      channel: self.client,
+      defaultCallOptions: CallOptions(logger: self.clientLogger)
+    )
+  }
+
+  private func startServer(
+    echoDelegate: Echo_EchoProvider,
+    onClose: @escaping (Result<Void, Error>) -> Void
+  ) {
+    let provider = OnCloseEchoProvider(delegate: echoDelegate, onClose: onClose)
+    XCTAssertNoThrow(try self.setUp(provider: provider))
+  }
+
+  private func doTestUnary(
+    echoProvider: Echo_EchoProvider,
+    completesWithStatus code: GRPCStatus.Code
+  ) {
+    let promise = self.eventLoop.makePromise(of: Void.self)
+    self.startServer(echoDelegate: echoProvider) { result in
+      promise.completeWith(result)
+    }
+
+    let get = self.echo.get(.with { $0.text = "" })
+    assertThat(try get.status.wait(), .hasCode(code))
+    XCTAssertNoThrow(try promise.futureResult.wait())
+  }
+
+  func testUnaryOnCloseHappyPath() throws {
+    self.doTestUnary(echoProvider: EchoProvider(), completesWithStatus: .ok)
+  }
+
+  func testUnaryOnCloseAfterUserFunctionFails() throws {
+    self.doTestUnary(echoProvider: FailingEchoProvider(), completesWithStatus: .internalError)
+  }
+
+  func testUnaryOnCloseAfterClientKilled() throws {
+    let promise = self.eventLoop.makePromise(of: Void.self)
+    self.startServer(echoDelegate: NeverResolvingEchoProvider()) { result in
+      promise.completeWith(result)
+    }
+
+    // We want to wait until the client has sent the request parts before closing. We'll grab the
+    // promise for sending end.
+    let endSent = self.client.eventLoop.makePromise(of: Void.self)
+    self.echo.interceptors = DelegatingEchoClientInterceptorFactory { part, promise, context in
+      switch part {
+      case .metadata, .message:
+        context.send(part, promise: promise)
+      case .end:
+        endSent.futureResult.cascade(to: promise)
+        context.send(part, promise: endSent)
+      }
+    }
+
+    _ = self.echo.get(.with { $0.text = "" })
+    // Make sure end has been sent before closing the connection.
+    XCTAssertNoThrow(try endSent.futureResult.wait())
+    XCTAssertNoThrow(try self.client.close().wait())
+    XCTAssertNoThrow(try promise.futureResult.wait())
+  }
+
+  private func doTestClientStreaming(
+    echoProvider: Echo_EchoProvider,
+    completesWithStatus code: GRPCStatus.Code
+  ) {
+    let promise = self.eventLoop.makePromise(of: Void.self)
+    self.startServer(echoDelegate: echoProvider) { result in
+      promise.completeWith(result)
+    }
+
+    let collect = self.echo.collect()
+    // We don't know if we'll send successfully or not.
+    try? collect.sendEnd().wait()
+    assertThat(try collect.status.wait(), .hasCode(code))
+    XCTAssertNoThrow(try promise.futureResult.wait())
+  }
+
+  func testClientStreamingOnCloseHappyPath() throws {
+    self.doTestClientStreaming(echoProvider: EchoProvider(), completesWithStatus: .ok)
+  }
+
+  func testClientStreamingOnCloseAfterUserFunctionFails() throws {
+    self.doTestClientStreaming(
+      echoProvider: FailingEchoProvider(),
+      completesWithStatus: .internalError
+    )
+  }
+
+  func testClientStreamingOnCloseAfterClientKilled() throws {
+    let promise = self.eventLoop.makePromise(of: Void.self)
+    self.startServer(echoDelegate: NeverResolvingEchoProvider()) { error in
+      promise.completeWith(error)
+    }
+
+    let collect = self.echo.collect()
+    XCTAssertNoThrow(try collect.sendMessage(.with { $0.text = "" }).wait())
+    XCTAssertNoThrow(try self.client.close().wait())
+    XCTAssertNoThrow(try promise.futureResult.wait())
+  }
+
+  private func doTestServerStreaming(
+    echoProvider: Echo_EchoProvider,
+    completesWithStatus code: GRPCStatus.Code
+  ) {
+    let promise = self.eventLoop.makePromise(of: Void.self)
+    self.startServer(echoDelegate: echoProvider) { result in
+      promise.completeWith(result)
+    }
+
+    let expand = self.echo.expand(.with { $0.text = "1 2 3" }) { _ in /* ignore responses */ }
+    assertThat(try expand.status.wait(), .hasCode(code))
+    XCTAssertNoThrow(try promise.futureResult.wait())
+  }
+
+  func testServerStreamingOnCloseHappyPath() throws {
+    self.doTestServerStreaming(echoProvider: EchoProvider(), completesWithStatus: .ok)
+  }
+
+  func testServerStreamingOnCloseAfterUserFunctionFails() throws {
+    self.doTestServerStreaming(
+      echoProvider: FailingEchoProvider(),
+      completesWithStatus: .internalError
+    )
+  }
+
+  func testServerStreamingOnCloseAfterClientKilled() throws {
+    let promise = self.eventLoop.makePromise(of: Void.self)
+    self.startServer(echoDelegate: NeverResolvingEchoProvider()) { result in
+      promise.completeWith(result)
+    }
+
+    // We want to wait until the client has sent the request parts before closing. We'll grab the
+    // promise for sending end.
+    let endSent = self.client.eventLoop.makePromise(of: Void.self)
+    self.echo.interceptors = DelegatingEchoClientInterceptorFactory { part, promise, context in
+      switch part {
+      case .metadata, .message:
+        context.send(part, promise: promise)
+      case .end:
+        endSent.futureResult.cascade(to: promise)
+        context.send(part, promise: endSent)
+      }
+    }
+
+    _ = self.echo.expand(.with { $0.text = "1 2 3" }) { _ in /* ignore responses */ }
+    // Make sure end has been sent before closing the connection.
+    XCTAssertNoThrow(try endSent.futureResult.wait())
+    XCTAssertNoThrow(try self.client.close().wait())
+    XCTAssertNoThrow(try promise.futureResult.wait())
+  }
+
+  private func doTestBidirectionalStreaming(
+    echoProvider: Echo_EchoProvider,
+    completesWithStatus code: GRPCStatus.Code
+  ) {
+    let promise = self.eventLoop.makePromise(of: Void.self)
+    self.startServer(echoDelegate: echoProvider) { result in
+      promise.completeWith(result)
+    }
+
+    let update = self.echo.update { _ in /* ignored */ }
+    // We don't know if we'll send successfully or not.
+    try? update.sendEnd().wait()
+    assertThat(try update.status.wait(), .hasCode(code))
+    XCTAssertNoThrow(try promise.futureResult.wait())
+  }
+
+  func testBidirectionalStreamingOnCloseHappyPath() throws {
+    self.doTestBidirectionalStreaming(echoProvider: EchoProvider(), completesWithStatus: .ok)
+  }
+
+  func testBidirectionalStreamingOnCloseAfterUserFunctionFails() throws {
+    self.doTestBidirectionalStreaming(
+      echoProvider: FailingEchoProvider(),
+      completesWithStatus: .internalError
+    )
+  }
+
+  func testBidirectionalStreamingOnCloseAfterClientKilled() throws {
+    let promise = self.eventLoop.makePromise(of: Void.self)
+    self.startServer(echoDelegate: NeverResolvingEchoProvider()) { result in
+      promise.completeWith(result)
+    }
+
+    let update = self.echo.update { _ in /* ignored */ }
+    XCTAssertNoThrow(try update.sendMessage(.with { $0.text = "" }).wait())
+    XCTAssertNoThrow(try self.client.close().wait())
+    XCTAssertNoThrow(try promise.futureResult.wait())
+  }
+}

+ 2 - 1
Tests/GRPCTests/UnaryServerHandlerTests.swift

@@ -68,7 +68,8 @@ extension ServerHandlerTestCase {
       path: "/ignored",
       remoteAddress: nil,
       responseWriter: self.recorder,
-      allocator: self.allocator
+      allocator: self.allocator,
+      closeFuture: self.eventLoop.makeSucceededVoidFuture()
     )
   }
 }