Browse Source

Alllow calls to run on a specified event loop (#1156)

Motivation:

It can be helpful for callers to know which event loop their RPCs are
completed on. Currently the caller must hop to the desired event loop if
it differs from the event loop used by the underlying transport. This
can be tedious and error prone since it must be done in a number of
places.

Modifications:

- Separate the event loop an RPC runs on (i.e. the one used for
  interceptors and response futures) and the event loop the RPC is
  transported on (i.e. the one used by the underlying 'Channel').
- Add an 'EventLoopPreference' to 'CallOptions' allowing users to
  specify they are 'indifferent' (the existing and default behaviour,
  use an event loop provided by the framework) or use an 'exact' event
  loop.
- Add tests.

Result:

Callers can specify an event loop to run their RPCs on.
George Barnett 4 years ago
parent
commit
25f6cdaddf

+ 71 - 1
Sources/GRPC/CallOptions.swift

@@ -15,6 +15,7 @@
  */
 import struct Foundation.UUID
 import Logging
+import NIO
 import NIOHPACK
 import NIOHTTP1
 import NIOHTTP2
@@ -61,6 +62,16 @@ public struct CallOptions {
   /// messages associated with the call.
   public var requestIDHeader: String?
 
+  /// A preference for the `EventLoop` that the call is executed on.
+  ///
+  /// The `EventLoop` resulting from the preference will be used to create any `EventLoopFuture`s
+  /// associated with the call, such as the `response` for calls with a single response (i.e. unary
+  /// and client streaming). For calls which stream responses (server streaming and bidirectional
+  /// streaming) the response handler is executed on this event loop.
+  ///
+  /// Note that the underlying connection is not guaranteed to run on the same event loop.
+  public var eventLoopPreference: EventLoopPreference
+
   /// A logger used for the call. Defaults to a no-op logger.
   ///
   /// If a `requestIDProvider` exists then a request ID will automatically attached to the logger's
@@ -75,16 +86,41 @@ public struct CallOptions {
     requestIDHeader: String? = nil,
     cacheable: Bool = false,
     logger: Logger = Logger(label: "io.grpc", factory: { _ in SwiftLogNoOpLogHandler() })
+  ) {
+    self.init(
+      customMetadata: customMetadata,
+      timeLimit: timeLimit,
+      messageEncoding: messageEncoding,
+      requestIDProvider: requestIDProvider,
+      requestIDHeader: requestIDHeader,
+      eventLoopPreference: .indifferent,
+      cacheable: cacheable,
+      logger: logger
+    )
+  }
+
+  public init(
+    customMetadata: HPACKHeaders = HPACKHeaders(),
+    timeLimit: TimeLimit = .none,
+    messageEncoding: ClientMessageEncoding = .disabled,
+    requestIDProvider: RequestIDProvider = .autogenerated,
+    requestIDHeader: String? = nil,
+    eventLoopPreference: EventLoopPreference,
+    cacheable: Bool = false,
+    logger: Logger = Logger(label: "io.grpc", factory: { _ in SwiftLogNoOpLogHandler() })
   ) {
     self.customMetadata = customMetadata
     self.messageEncoding = messageEncoding
     self.requestIDProvider = requestIDProvider
     self.requestIDHeader = requestIDHeader
-    self.cacheable = false
+    self.cacheable = cacheable
     self.timeLimit = timeLimit
     self.logger = logger
+    self.eventLoopPreference = eventLoopPreference
   }
+}
 
+extension CallOptions {
   public struct RequestIDProvider {
     private enum RequestIDSource {
       case none
@@ -128,3 +164,37 @@ public struct CallOptions {
     }
   }
 }
+
+extension CallOptions {
+  public struct EventLoopPreference {
+    /// No preference. The framework will assign an `EventLoop`.
+    public static let indifferent = EventLoopPreference(.indifferent)
+
+    /// Use the provided `EventLoop` for the call.
+    public static func exact(_ eventLoop: EventLoop) -> EventLoopPreference {
+      return EventLoopPreference(.exact(eventLoop))
+    }
+
+    private enum Preference {
+      case indifferent
+      case exact(EventLoop)
+    }
+
+    private var preference: Preference
+
+    private init(_ preference: Preference) {
+      self.preference = preference
+    }
+  }
+}
+
+extension CallOptions.EventLoopPreference {
+  internal var exact: EventLoop? {
+    switch self.preference {
+    case let .exact(eventLoop):
+      return eventLoop
+    case .indifferent:
+      return nil
+    }
+  }
+}

+ 2 - 1
Sources/GRPC/ClientCalls/Call.swift

@@ -271,6 +271,7 @@ extension Call {
         to: self.path,
         for: self.type,
         withOptions: self.options,
+        onEventLoop: self.eventLoop,
         interceptedBy: self.interceptors,
         onError: onError,
         onResponsePart: onResponsePart
@@ -334,7 +335,7 @@ extension Call {
 
     case let (.none, .invoked(transport)):
       // Just ask the transport.
-      return transport.channel()
+      return transport.getChannel()
     }
   }
 }

+ 4 - 2
Sources/GRPC/ClientConnection.swift

@@ -152,11 +152,12 @@ extension ClientConnection: GRPCChannel {
     var options = callOptions
     self.populateLogger(in: &options)
     let multiplexer = self.getMultiplexer()
+    let eventLoop = callOptions.eventLoopPreference.exact ?? multiplexer.eventLoop
 
     return Call(
       path: path,
       type: type,
-      eventLoop: multiplexer.eventLoop,
+      eventLoop: eventLoop,
       options: options,
       interceptors: interceptors,
       transportFactory: .http2(
@@ -177,11 +178,12 @@ extension ClientConnection: GRPCChannel {
     var options = callOptions
     self.populateLogger(in: &options)
     let multiplexer = self.getMultiplexer()
+    let eventLoop = callOptions.eventLoopPreference.exact ?? multiplexer.eventLoop
 
     return Call(
       path: path,
       type: type,
-      eventLoop: multiplexer.eventLoop,
+      eventLoop: eventLoop,
       options: options,
       interceptors: interceptors,
       transportFactory: .http2(

+ 2 - 2
Sources/GRPC/FakeChannel.swift

@@ -113,7 +113,7 @@ public class FakeChannel: GRPCChannel {
       eventLoop: eventLoop,
       options: callOptions,
       interceptors: interceptors,
-      transportFactory: .fake(stream, on: eventLoop)
+      transportFactory: .fake(stream)
     )
   }
 
@@ -131,7 +131,7 @@ public class FakeChannel: GRPCChannel {
       eventLoop: eventLoop,
       options: callOptions,
       interceptors: interceptors,
-      transportFactory: .fake(stream, on: eventLoop)
+      transportFactory: .fake(stream)
     )
   }
 

+ 231 - 101
Sources/GRPC/Interceptor/ClientTransport.swift

@@ -33,12 +33,12 @@ import NIOHTTP2
 ///
 /// ### Thread Safety
 ///
-/// This class is not thread safe. All methods **must** be executed on the transport's `eventLoop`.
+/// This class is not thread safe. All methods **must** be executed on the transport's `callEventLoop`.
 @usableFromInline
 internal final class ClientTransport<Request, Response> {
-  /// The `EventLoop` this transport is running on.
+  /// The `EventLoop` the call is running on. State must be accessed from this event loop.
   @usableFromInline
-  internal let eventLoop: EventLoop
+  internal let callEventLoop: EventLoop
 
   /// The current state of the transport.
   private var state: ClientTransportState = .idle
@@ -92,8 +92,8 @@ internal final class ClientTransport<Request, Response> {
   @usableFromInline
   internal var _pipeline: ClientInterceptorPipeline<Request, Response>?
 
-  /// The 'ChannelHandlerContext'.
-  private var context: ChannelHandlerContext?
+  /// The `NIO.Channel` used by the transport, if it is available.
+  private var channel: Channel?
 
   /// Our current state as logging metadata.
   private var stateForLogging: Logger.MetadataValue {
@@ -114,7 +114,7 @@ internal final class ClientTransport<Request, Response> {
     onError: @escaping (Error) -> Void,
     onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
   ) {
-    self.eventLoop = eventLoop
+    self.callEventLoop = eventLoop
     self.callDetails = details
     self.serializer = serializer
     self.deserializer = deserializer
@@ -134,9 +134,9 @@ internal final class ClientTransport<Request, Response> {
 
   /// Configure the transport to communicate with the server.
   /// - Parameter configurator: A callback to invoke in order to configure this transport.
-  /// - Important: This *must* to be called from the `eventLoop`.
+  /// - Important: This *must* to be called from the `callEventLoop`.
   internal func configure(_ configurator: @escaping (ChannelHandler) -> EventLoopFuture<Void>) {
-    self.eventLoop.assertInEventLoop()
+    self.callEventLoop.assertInEventLoop()
     if self.state.configureTransport() {
       self.configure(using: configurator)
     }
@@ -146,10 +146,10 @@ internal final class ClientTransport<Request, Response> {
   /// - Parameters:
   ///   - part: The part to send.
   ///   - promise: A promise which will be completed when the request part has been handled.
-  /// - Important: This *must* to be called from the `eventLoop`.
+  /// - Important: This *must* to be called from the `callEventLoop`.
   @inlinable
   internal func send(_ part: GRPCClientRequestPart<Request>, promise: EventLoopPromise<Void>?) {
-    self.eventLoop.assertInEventLoop()
+    self.callEventLoop.assertInEventLoop()
     if let pipeline = self._pipeline {
       pipeline.send(part, promise: promise)
     } else {
@@ -161,7 +161,7 @@ internal final class ClientTransport<Request, Response> {
   /// - Parameter promise: A promise which will be completed when the cancellation attempt has
   ///   been handled.
   internal func cancel(promise: EventLoopPromise<Void>?) {
-    self.eventLoop.assertInEventLoop()
+    self.callEventLoop.assertInEventLoop()
     if let pipeline = self._pipeline {
       pipeline.cancel(promise: promise)
     } else {
@@ -170,21 +170,21 @@ internal final class ClientTransport<Request, Response> {
   }
 
   /// A request for the underlying `Channel`.
-  internal func channel() -> EventLoopFuture<Channel> {
-    self.eventLoop.assertInEventLoop()
+  internal func getChannel() -> EventLoopFuture<Channel> {
+    self.callEventLoop.assertInEventLoop()
 
     // Do we already have a promise?
     if let promise = self.channelPromise {
       return promise.futureResult
     } else {
       // Make and store the promise.
-      let promise = self.eventLoop.makePromise(of: Channel.self)
+      let promise = self.callEventLoop.makePromise(of: Channel.self)
       self.channelPromise = promise
 
       // Ask the state machine if we can have it.
       switch self.state.getChannel() {
       case .succeed:
-        if let channel = self.context?.channel {
+        if let channel = self.channel {
           promise.succeed(channel)
         }
 
@@ -207,17 +207,26 @@ extension ClientTransport {
   /// - Parameters:
   ///   - part: The request part to send.
   ///   - promise: A promise which will be completed when the part has been handled.
-  /// - Important: This *must* to be called from the `eventLoop`.
+  /// - Important: This *must* to be called from the `callEventLoop`.
   private func sendFromPipeline(
     _ part: GRPCClientRequestPart<Request>,
     promise: EventLoopPromise<Void>?
   ) {
-    self.eventLoop.assertInEventLoop()
+    self.callEventLoop.assertInEventLoop()
     switch self.state.send() {
     case .writeToBuffer:
       self.buffer(part, promise: promise)
+
     case .writeToChannel:
-      self.write(part, promise: promise, flush: self.shouldFlush(after: part))
+      // Banging the channel is okay here: we'll only be told to 'writeToChannel' if we're in the
+      // correct state, the requirements of that state are having an active `Channel`.
+      self.writeToChannel(
+        self.channel!,
+        part: part,
+        promise: promise,
+        flush: self.shouldFlush(after: part)
+      )
+
     case .alreadyComplete:
       promise?.fail(GRPCError.AlreadyComplete())
     }
@@ -225,15 +234,15 @@ extension ClientTransport {
 
   /// Attempt to cancel the RPC. Should only be called from the interceptor pipeline.
   /// - Parameter promise: A promise which will be completed when the cancellation has been handled.
-  /// - Important: This *must* to be called from the `eventLoop`.
+  /// - Important: This *must* to be called from the `callEventLoop`.
   private func cancelFromPipeline(promise: EventLoopPromise<Void>?) {
-    self.eventLoop.assertInEventLoop()
+    self.callEventLoop.assertInEventLoop()
 
     if self.state.cancel() {
       let error = GRPCError.RPCCancelledByClient()
       self.forwardErrorToInterceptors(error)
       self.failBufferedWrites(with: error)
-      self.context?.channel.close(mode: .all, promise: nil)
+      self.channel?.close(mode: .all, promise: nil)
       self.channelPromise?.fail(error)
       promise?.succeed(())
     } else {
@@ -252,56 +261,86 @@ extension ClientTransport: ChannelInboundHandler {
   typealias OutboundOut = _RawGRPCClientRequestPart
 
   @usableFromInline
-  func handlerAdded(context: ChannelHandlerContext) {
-    self.context = context
+  internal func handlerRemoved(context: ChannelHandlerContext) {
+    self.dropReferences()
   }
 
   @usableFromInline
-  internal func handlerRemoved(context: ChannelHandlerContext) {
-    self.eventLoop.assertInEventLoop()
-    self.context = nil
-    // Break the reference cycle.
-    self._pipeline = nil
+  internal func errorCaught(context: ChannelHandlerContext, error: Error) {
+    self.handleError(error)
   }
 
-  internal func channelError(_ error: Error) {
-    self.eventLoop.assertInEventLoop()
+  @usableFromInline
+  internal func channelActive(context: ChannelHandlerContext) {
+    self.transportActivated(channel: context.channel)
+  }
 
-    switch self.state.channelError() {
-    case .doNothing:
-      ()
-    case .propagateError:
-      self.forwardErrorToInterceptors(error)
-      self.failBufferedWrites(with: error)
+  @usableFromInline
+  internal func channelInactive(context: ChannelHandlerContext) {
+    self.transportDeactivated()
+  }
 
-    case .propagateErrorAndClose:
-      self.forwardErrorToInterceptors(error)
-      self.failBufferedWrites(with: error)
-      self.context?.close(mode: .all, promise: nil)
+  @usableFromInline
+  internal func channelRead(context: ChannelHandlerContext, data: NIOAny) {
+    switch self.unwrapInboundIn(data) {
+    case let .initialMetadata(headers):
+      self.receiveFromChannel(initialMetadata: headers)
+
+    case let .message(box):
+      self.receiveFromChannel(message: box.message)
+
+    case let .trailingMetadata(trailers):
+      self.receiveFromChannel(trailingMetadata: trailers)
+
+    case let .status(status):
+      self.receiveFromChannel(status: status)
     }
+
+    // (We're the end of the channel. No need to forward anything.)
   }
+}
 
-  @usableFromInline
-  internal func errorCaught(context: ChannelHandlerContext, error: Error) {
-    self.channelError(error)
+extension ClientTransport {
+  /// The `Channel` became active. Send out any buffered requests.
+  private func transportActivated(channel: Channel) {
+    if self.callEventLoop.inEventLoop {
+      self._transportActivated(channel: channel)
+    } else {
+      self.callEventLoop.execute {
+        self._transportActivated(channel: channel)
+      }
+    }
   }
 
-  @usableFromInline
-  internal func channelActive(context: ChannelHandlerContext) {
-    self.eventLoop.assertInEventLoop()
+  /// On-loop implementation of `transportActivated(channel:)`.
+  private func _transportActivated(channel: Channel) {
+    self.callEventLoop.assertInEventLoop()
     self.logger.debug("activated stream channel", source: "GRPC")
-    if self.state.channelActive() {
-      self.unbuffer(to: context.channel)
+
+    if self.state.activate() {
+      self.channel = channel
+      self.unbuffer()
     } else {
-      context.close(mode: .all, promise: nil)
+      channel.close(mode: .all, promise: nil)
     }
   }
 
-  @usableFromInline
-  internal func channelInactive(context: ChannelHandlerContext) {
-    self.eventLoop.assertInEventLoop()
+  /// The `Channel` became inactive. Fail any buffered writes and forward an error to the
+  /// interceptor pipeline if necessary.
+  private func transportDeactivated() {
+    if self.callEventLoop.inEventLoop {
+      self._transportDeactivated()
+    } else {
+      self.callEventLoop.execute {
+        self._transportDeactivated()
+      }
+    }
+  }
 
-    switch self.state.channelInactive() {
+  /// On-loop implementation of `transportDeactivated()`.
+  private func _transportDeactivated() {
+    self.callEventLoop.assertInEventLoop()
+    switch self.state.deactivate() {
     case .doNothing:
       ()
 
@@ -316,40 +355,125 @@ extension ClientTransport: ChannelInboundHandler {
     }
   }
 
-  @usableFromInline
-  internal func channelRead(context: ChannelHandlerContext, data: NIOAny) {
-    self.eventLoop.assertInEventLoop()
-    let part = self.unwrapInboundIn(data)
+  /// Drops any references to the `Channel` and interceptor pipeline.
+  private func dropReferences() {
+    if self.callEventLoop.inEventLoop {
+      self.channel = nil
+      self._pipeline = nil
+    } else {
+      self.callEventLoop.execute {
+        self.channel = nil
+        self._pipeline = nil
+      }
+    }
+  }
 
-    switch part {
-    case let .initialMetadata(headers):
-      if self.state.channelRead(isEnd: false) {
-        self.forwardToInterceptors(.metadata(headers))
+  /// Handles an error caught in the pipeline or from elsewhere. The error may be forwarded to the
+  /// interceptor pipeline and any buffered writes will be failed. Any underlying `Channel` will
+  /// also be closed.
+  internal func handleError(_ error: Error) {
+    if self.callEventLoop.inEventLoop {
+      self._handleError(error)
+    } else {
+      self.callEventLoop.execute {
+        self._handleError(error)
       }
+    }
+  }
 
-    case let .message(context):
-      do {
-        let message = try self.deserializer.deserialize(byteBuffer: context.message)
-        if self.state.channelRead(isEnd: false) {
-          self.forwardToInterceptors(.message(message))
-        }
-      } catch {
-        self.channelError(error)
+  /// On-loop implementation of `handleError(_:)`.
+  private func _handleError(_ error: Error) {
+    self.callEventLoop.assertInEventLoop()
+
+    switch self.state.handleError() {
+    case .doNothing:
+      ()
+
+    case .propagateError:
+      self.forwardErrorToInterceptors(error)
+      self.failBufferedWrites(with: error)
+
+    case .propagateErrorAndClose:
+      self.forwardErrorToInterceptors(error)
+      self.failBufferedWrites(with: error)
+      self.channel?.close(mode: .all, promise: nil)
+    }
+  }
+
+  /// Receive initial metadata from the `Channel`.
+  private func receiveFromChannel(initialMetadata headers: HPACKHeaders) {
+    if self.callEventLoop.inEventLoop {
+      self._receiveFromChannel(initialMetadata: headers)
+    } else {
+      self.callEventLoop.execute {
+        self._receiveFromChannel(initialMetadata: headers)
       }
+    }
+  }
 
-    case let .trailingMetadata(trailers):
-      // The `Channel` delivers trailers and `GRPCStatus` separately, we want to emit them together
-      // in the interceptor pipeline.
+  /// On-loop implementation of `receiveFromChannel(initialMetadata:)`.
+  private func _receiveFromChannel(initialMetadata headers: HPACKHeaders) {
+    self.callEventLoop.assertInEventLoop()
+    if self.state.channelRead(isEnd: false) {
+      self.forwardToInterceptors(.metadata(headers))
+    }
+  }
+
+  /// Receive response message bytes from the `Channel`.
+  private func receiveFromChannel(message buffer: ByteBuffer) {
+    if self.callEventLoop.inEventLoop {
+      self._receiveFromChannel(message: buffer)
+    } else {
+      self.callEventLoop.execute {
+        self._receiveFromChannel(message: buffer)
+      }
+    }
+  }
+
+  /// On-loop implementation of `receiveFromChannel(message:)`.
+  private func _receiveFromChannel(message buffer: ByteBuffer) {
+    self.callEventLoop.assertInEventLoop()
+    do {
+      let message = try self.deserializer.deserialize(byteBuffer: buffer)
+      if self.state.channelRead(isEnd: false) {
+        self.forwardToInterceptors(.message(message))
+      }
+    } catch {
+      self.handleError(error)
+    }
+  }
+
+  /// Receive trailing metadata from the `Channel`.
+  private func receiveFromChannel(trailingMetadata trailers: HPACKHeaders) {
+    // The `Channel` delivers trailers and `GRPCStatus` separately, we want to emit them together
+    // in the interceptor pipeline.
+    if self.callEventLoop.inEventLoop {
       self.trailers = trailers
+    } else {
+      self.callEventLoop.execute {
+        self.trailers = trailers
+      }
+    }
+  }
 
-    case let .status(status):
-      if self.state.channelRead(isEnd: true) {
-        self.forwardToInterceptors(.end(status, self.trailers ?? [:]))
-        self.trailers = nil
+  /// Receive the final status from the `Channel`.
+  private func receiveFromChannel(status: GRPCStatus) {
+    if self.callEventLoop.inEventLoop {
+      self._receiveFromChannel(status: status)
+    } else {
+      self.callEventLoop.execute {
+        self._receiveFromChannel(status: status)
       }
     }
+  }
 
-    // (We're the end of the channel. No need to forward anything.)
+  /// On-loop implementation of `receiveFromChannel(status:)`.
+  private func _receiveFromChannel(status: GRPCStatus) {
+    self.callEventLoop.assertInEventLoop()
+    if self.state.channelRead(isEnd: true) {
+      self.forwardToInterceptors(.end(status, self.trailers ?? [:]))
+      self.trailers = nil
+    }
   }
 }
 
@@ -533,7 +657,7 @@ extension ClientTransportState {
   }
 
   /// `channelActive` was invoked on the transport by the `Channel`.
-  mutating func channelActive() -> Bool {
+  mutating func activate() -> Bool {
     // The channel has become active: what now?
     switch self {
     case .idle:
@@ -565,7 +689,7 @@ extension ClientTransportState {
   }
 
   /// `channelInactive` was invoked on the transport by the `Channel`.
-  mutating func channelInactive() -> ChannelInactiveAction {
+  mutating func deactivate() -> ChannelInactiveAction {
     switch self {
     case .idle:
       // We can't become inactive before we've requested a `Channel`.
@@ -611,7 +735,7 @@ extension ClientTransportState {
     }
   }
 
-  enum ChannelErrorAction {
+  enum HandleErrorAction {
     /// Propagate the error to the interceptor pipeline and fail any buffered writes.
     case propagateError
     /// As above, but close the 'Channel' as well.
@@ -620,8 +744,8 @@ extension ClientTransportState {
     case doNothing
   }
 
-  /// We received an error from the `Channel`.
-  mutating func channelError() -> ChannelErrorAction {
+  /// An error was caught.
+  mutating func handleError() -> HandleErrorAction {
     switch self {
     case .idle:
       // The `Channel` can't error if it doesn't exist.
@@ -684,11 +808,13 @@ extension ClientTransport {
   /// Configures this transport with the `configurator`.
   private func configure(using configurator: (ChannelHandler) -> EventLoopFuture<Void>) {
     configurator(self).whenFailure { error in
+      // We might be on a different EL, but `handleError` will sort that out for us, so no need to
+      // hop.
       if error is GRPCStatus || error is GRPCStatusTransformable {
-        self.channelError(error)
+        self.handleError(error)
       } else {
         // Fallback to something which will mark the RPC as 'unavailable'.
-        self.channelError(ConnectionFailure(reason: error))
+        self.handleError(ConnectionFailure(reason: error))
       }
     }
   }
@@ -701,6 +827,7 @@ extension ClientTransport {
     _ part: GRPCClientRequestPart<Request>,
     promise: EventLoopPromise<Void>?
   ) {
+    self.callEventLoop.assertInEventLoop()
     self.logger.debug("buffering request part", metadata: [
       "request_part": "\(part.name)",
       "call_state": self.stateForLogging,
@@ -709,8 +836,13 @@ extension ClientTransport {
   }
 
   /// Writes any buffered request parts to the `Channel`.
-  /// - Parameter channel: The `Channel` to write any buffered request parts to.
-  private func unbuffer(to channel: Channel) {
+  private func unbuffer() {
+    self.callEventLoop.assertInEventLoop()
+
+    guard let channel = self.channel else {
+      return
+    }
+
     // Save any flushing until we're done writing.
     var shouldFlush = false
 
@@ -732,7 +864,7 @@ extension ClientTransport {
           shouldFlush = self.shouldFlush(after: write.request)
         }
 
-        self.write(write.request, promise: write.promise, flush: false)
+        self.writeToChannel(channel, part: write.request, promise: write.promise, flush: false)
       }
 
       // Okay, flush now.
@@ -775,52 +907,50 @@ extension ClientTransport {
 
   /// Write a request part to the `Channel`.
   /// - Parameters:
+  ///   - channel: The `Channel` to write `part` to.
   ///   - part: The request part to write.
-  ///   - channel: The `Channel` to write `part` in to.
   ///   - promise: A promise to complete once the write has been completed.
   ///   - flush: Whether to flush the `Channel` after writing.
-  private func write(
-    _ part: GRPCClientRequestPart<Request>,
+  private func writeToChannel(
+    _ channel: Channel,
+    part: GRPCClientRequestPart<Request>,
     promise: EventLoopPromise<Void>?,
     flush: Bool
   ) {
-    guard let context = self.context else {
-      promise?.fail(GRPCError.AlreadyComplete())
-      return
-    }
-
     switch part {
     case let .metadata(headers):
       let head = self.makeRequestHead(with: headers)
-      context.channel.write(self.wrapOutboundOut(.head(head)), promise: promise)
+      channel.write(self.wrapOutboundOut(.head(head)), promise: promise)
 
     case let .message(request, metadata):
       do {
-        let bytes = try self.serializer.serialize(request, allocator: context.channel.allocator)
+        let bytes = try self.serializer.serialize(request, allocator: channel.allocator)
         let message = _MessageContext<ByteBuffer>(bytes, compressed: metadata.compress)
-        context.channel.write(self.wrapOutboundOut(.message(message)), promise: promise)
+        channel.write(self.wrapOutboundOut(.message(message)), promise: promise)
       } catch {
-        self.channelError(error)
+        self.handleError(error)
       }
 
     case .end:
-      context.channel.write(self.wrapOutboundOut(.end), promise: promise)
+      channel.write(self.wrapOutboundOut(.end), promise: promise)
     }
 
     if flush {
-      context.channel.flush()
+      channel.flush()
     }
   }
 
   /// Forward the response part to the interceptor pipeline.
   /// - Parameter part: The response part to forward.
   private func forwardToInterceptors(_ part: GRPCClientResponsePart<Response>) {
+    self.callEventLoop.assertInEventLoop()
     self._pipeline?.receive(part)
   }
 
   /// Forward the error to the interceptor pipeline.
   /// - Parameter error: The error to forward.
   private func forwardErrorToInterceptors(_ error: Error) {
+    self.callEventLoop.assertInEventLoop()
     self._pipeline?.errorCaught(error)
   }
 }

+ 11 - 15
Sources/GRPC/Interceptor/ClientTransportFactory.swift

@@ -88,12 +88,10 @@ internal struct ClientTransportFactory<Request, Response> {
   /// - Parameter fakeResponse: The fake response stream.
   /// - Returns: A factory for making and configuring fake transport.
   internal static func fake<Request: SwiftProtobuf.Message, Response: SwiftProtobuf.Message>(
-    _ fakeResponse: _FakeResponseStream<Request, Response>?,
-    on eventLoop: EventLoop
+    _ fakeResponse: _FakeResponseStream<Request, Response>?
   ) -> ClientTransportFactory<Request, Response> {
     let factory = FakeClientTransportFactory(
       fakeResponse,
-      on: eventLoop,
       requestSerializer: ProtobufSerializer(),
       requestDeserializer: ProtobufDeserializer(),
       responseSerializer: ProtobufSerializer(),
@@ -106,12 +104,10 @@ internal struct ClientTransportFactory<Request, Response> {
   /// - Parameter fakeResponse: The fake response stream.
   /// - Returns: A factory for making and configuring fake transport.
   internal static func fake<Request: GRPCPayload, Response: GRPCPayload>(
-    _ fakeResponse: _FakeResponseStream<Request, Response>?,
-    on eventLoop: EventLoop
+    _ fakeResponse: _FakeResponseStream<Request, Response>?
   ) -> ClientTransportFactory<Request, Response> {
     let factory = FakeClientTransportFactory(
       fakeResponse,
-      on: eventLoop,
       requestSerializer: GRPCPayloadSerializer(),
       requestDeserializer: GRPCPayloadDeserializer(),
       responseSerializer: GRPCPayloadSerializer(),
@@ -133,6 +129,7 @@ internal struct ClientTransportFactory<Request, Response> {
     to path: String,
     for type: GRPCCallType,
     withOptions options: CallOptions,
+    onEventLoop eventLoop: EventLoop,
     interceptedBy interceptors: [ClientInterceptor<Request, Response>],
     onError: @escaping (Error) -> Void,
     onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
@@ -143,6 +140,7 @@ internal struct ClientTransportFactory<Request, Response> {
         to: path,
         for: type,
         withOptions: options,
+        onEventLoop: eventLoop,
         interceptedBy: interceptors,
         onError: onError,
         onResponsePart: onResponsePart
@@ -154,6 +152,7 @@ internal struct ClientTransportFactory<Request, Response> {
         to: path,
         for: type,
         withOptions: options,
+        onEventLoop: eventLoop,
         interceptedBy: interceptors,
         onError: onError,
         onResponsePart
@@ -203,13 +202,14 @@ private struct HTTP2ClientTransportFactory<Request, Response> {
     to path: String,
     for type: GRPCCallType,
     withOptions options: CallOptions,
+    onEventLoop eventLoop: EventLoop,
     interceptedBy interceptors: [ClientInterceptor<Request, Response>],
     onError: @escaping (Error) -> Void,
     onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
   ) -> ClientTransport<Request, Response> {
     return ClientTransport(
       details: self.makeCallDetails(type: type, path: path, options: options),
-      eventLoop: self.multiplexer.eventLoop,
+      eventLoop: eventLoop,
       interceptors: interceptors,
       serializer: self.serializer,
       deserializer: self.deserializer,
@@ -269,10 +269,6 @@ private struct FakeClientTransportFactory<Request, Response> {
   /// configure their client. The result will be a transport which immediately fails.
   private var fakeResponseStream: _FakeResponseStream<Request, Response>?
 
-  /// The `EventLoop` from the response stream, or an `EmbeddedEventLoop` should the response
-  /// stream be `nil`.
-  private var eventLoop: EventLoop
-
   /// The request serializer.
   private let requestSerializer: AnySerializer<Request>
 
@@ -289,7 +285,6 @@ private struct FakeClientTransportFactory<Request, Response> {
     ResponseDeserializer: MessageDeserializer
   >(
     _ fakeResponseStream: _FakeResponseStream<Request, Response>?,
-    on eventLoop: EventLoop,
     requestSerializer: RequestSerializer,
     requestDeserializer: RequestDeserializer,
     responseSerializer: ResponseSerializer,
@@ -300,7 +295,6 @@ private struct FakeClientTransportFactory<Request, Response> {
     ResponseDeserializer.Output == Response
   {
     self.fakeResponseStream = fakeResponseStream
-    self.eventLoop = eventLoop
     self.requestSerializer = AnySerializer(wrapping: requestSerializer)
     self.responseDeserializer = AnyDeserializer(wrapping: responseDeserializer)
     self.codec = GRPCClientReverseCodecHandler(
@@ -313,6 +307,7 @@ private struct FakeClientTransportFactory<Request, Response> {
     to path: String,
     for type: GRPCCallType,
     withOptions options: CallOptions,
+    onEventLoop eventLoop: EventLoop,
     interceptedBy interceptors: [ClientInterceptor<Request, Response>],
     onError: @escaping (Error) -> Void,
     _ onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
@@ -325,7 +320,7 @@ private struct FakeClientTransportFactory<Request, Response> {
         scheme: "http",
         options: options
       ),
-      eventLoop: self.eventLoop,
+      eventLoop: eventLoop,
       interceptors: interceptors,
       serializer: self.requestSerializer,
       deserializer: self.responseDeserializer,
@@ -347,7 +342,8 @@ private struct FakeClientTransportFactory<Request, Response> {
           }
         }
       } else {
-        return self.eventLoop.makeFailedFuture(GRPCStatus(code: .unavailable, message: nil))
+        return transport.callEventLoop
+          .makeFailedFuture(GRPCStatus(code: .unavailable, message: nil))
       }
     }
   }

+ 141 - 0
Tests/GRPCTests/ClientEventLoopPreferenceTests.swift

@@ -0,0 +1,141 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import EchoImplementation
+import EchoModel
+import GRPC
+import NIO
+import XCTest
+
+final class ClientEventLoopPreferenceTests: GRPCTestCase {
+  private var group: MultiThreadedEventLoopGroup!
+
+  private var serverLoop: EventLoop!
+  private var clientLoop: EventLoop!
+  private var clientCallbackLoop: EventLoop!
+
+  private var server: Server!
+  private var connection: ClientConnection!
+
+  private var echo: Echo_EchoClient {
+    let options = CallOptions(
+      eventLoopPreference: .exact(self.clientCallbackLoop),
+      logger: self.clientLogger
+    )
+
+    return Echo_EchoClient(channel: self.connection, defaultCallOptions: options)
+  }
+
+  override func setUp() {
+    super.setUp()
+
+    self.group = MultiThreadedEventLoopGroup(numberOfThreads: 3)
+    self.serverLoop = self.group.next()
+    self.clientLoop = self.group.next()
+    self.clientCallbackLoop = self.group.next()
+
+    XCTAssert(self.serverLoop !== self.clientLoop)
+    XCTAssert(self.serverLoop !== self.clientCallbackLoop)
+    XCTAssert(self.clientLoop !== self.clientCallbackLoop)
+
+    self.server = try! Server.insecure(group: self.serverLoop)
+      .withLogger(self.serverLogger)
+      .withServiceProviders([EchoProvider()])
+      .bind(host: "localhost", port: 0)
+      .wait()
+
+    self.connection = ClientConnection.insecure(group: self.clientLoop)
+      .withBackgroundActivityLogger(self.clientLogger)
+      .connect(host: "localhost", port: self.server.channel.localAddress!.port!)
+  }
+
+  override func tearDown() {
+    XCTAssertNoThrow(try self.connection.close().wait())
+    XCTAssertNoThrow(try self.server.close().wait())
+    XCTAssertNoThrow(try self.group.syncShutdownGracefully())
+
+    super.tearDown()
+  }
+
+  private func assertClientCallbackEventLoop(_ eventLoop: EventLoop, line: UInt = #line) {
+    XCTAssert(eventLoop === self.clientCallbackLoop, line: line)
+  }
+
+  func testUnaryWithDifferentEventLoop() throws {
+    let get = self.echo.get(.with { $0.text = "Hello!" })
+
+    self.assertClientCallbackEventLoop(get.eventLoop)
+    self.assertClientCallbackEventLoop(get.initialMetadata.eventLoop)
+    self.assertClientCallbackEventLoop(get.response.eventLoop)
+    self.assertClientCallbackEventLoop(get.trailingMetadata.eventLoop)
+    self.assertClientCallbackEventLoop(get.status.eventLoop)
+
+    assertThat(try get.response.wait(), .is(.with { $0.text = "Swift echo get: Hello!" }))
+    assertThat(try get.status.wait(), .hasCode(.ok))
+  }
+
+  func testClientStreamingWithDifferentEventLoop() throws {
+    let collect = self.echo.collect()
+
+    self.assertClientCallbackEventLoop(collect.eventLoop)
+    self.assertClientCallbackEventLoop(collect.initialMetadata.eventLoop)
+    self.assertClientCallbackEventLoop(collect.response.eventLoop)
+    self.assertClientCallbackEventLoop(collect.trailingMetadata.eventLoop)
+    self.assertClientCallbackEventLoop(collect.status.eventLoop)
+
+    XCTAssertNoThrow(try collect.sendMessage(.with { $0.text = "a" }).wait())
+    XCTAssertNoThrow(try collect.sendEnd().wait())
+
+    assertThat(try collect.response.wait(), .is(.with { $0.text = "Swift echo collect: a" }))
+    assertThat(try collect.status.wait(), .hasCode(.ok))
+  }
+
+  func testServerStreamingWithDifferentEventLoop() throws {
+    let response = self.clientCallbackLoop.makePromise(of: Void.self)
+
+    let expand = self.echo.expand(.with { $0.text = "a" }) { _ in
+      self.clientCallbackLoop.preconditionInEventLoop()
+      response.succeed(())
+    }
+
+    self.assertClientCallbackEventLoop(expand.eventLoop)
+    self.assertClientCallbackEventLoop(expand.initialMetadata.eventLoop)
+    self.assertClientCallbackEventLoop(expand.trailingMetadata.eventLoop)
+    self.assertClientCallbackEventLoop(expand.status.eventLoop)
+
+    XCTAssertNoThrow(try response.futureResult.wait())
+    assertThat(try expand.status.wait(), .hasCode(.ok))
+  }
+
+  func testBidirectionalStreamingWithDifferentEventLoop() throws {
+    let response = self.clientCallbackLoop.makePromise(of: Void.self)
+
+    let update = self.echo.update { _ in
+      self.clientCallbackLoop.preconditionInEventLoop()
+      response.succeed(())
+    }
+
+    self.assertClientCallbackEventLoop(update.eventLoop)
+    self.assertClientCallbackEventLoop(update.initialMetadata.eventLoop)
+    self.assertClientCallbackEventLoop(update.trailingMetadata.eventLoop)
+    self.assertClientCallbackEventLoop(update.status.eventLoop)
+
+    XCTAssertNoThrow(try update.sendMessage(.with { $0.text = "a" }).wait())
+    XCTAssertNoThrow(try update.sendEnd().wait())
+
+    XCTAssertNoThrow(try response.futureResult.wait())
+    assertThat(try update.status.wait(), .hasCode(.ok))
+  }
+}