Browse Source

Support new handlers in the server codec + state machine (#1101)

Motivation:

We have new ways of handling RPCs on the server, the server codec and
state machine should use them!

Modifications:

- In the state machine: try to handle an RPC using the new API, falling
  back to the old way
- In the codec: store a mode, i.e. how we're handling the RPC.
- Flushing has become the responsibility of the server codec (as opposed
  to the RPC-specific channel handler): now we only flush at the end of
  'channelReadComplete' (if a flush is pending), or on calls to
  'flush(context:)' if we aren't currently reading.

Result:

The server supports the new RPC handlers.
George Barnett 5 years ago
parent
commit
b4bebe5000

+ 204 - 68
Sources/GRPC/HTTP2ToRawGRPCServerCodec.swift

@@ -30,6 +30,23 @@ internal final class HTTP2ToRawGRPCServerCodec: ChannelDuplexHandler, GRPCServer
   private let errorDelegate: ServerErrorDelegate?
   private var context: ChannelHandlerContext!
 
+  /// The mode we're operating in.
+  private var mode: Mode = .notConfigured
+
+  /// Whether we are currently reading data from the `Channel`. Should be set to `false` once a
+  /// burst of reading has completed.
+  private var isReading = false
+
+  /// Indicates whether a flush event is pending. If a flush is received while `isReading` is `true`
+  /// then it is held until the read completes in order to elide unnecessary flushes.
+  private var flushPending = false
+
+  private enum Mode {
+    case notConfigured
+    case legacy
+    case handler(GRPCServerHandlerProtocol)
+  }
+
   init(
     servicesByName: [Substring: CallHandlerProvider],
     encoding: ServerMessageEncoding,
@@ -46,53 +63,36 @@ internal final class HTTP2ToRawGRPCServerCodec: ChannelDuplexHandler, GRPCServer
     )
   }
 
-  func handlerAdded(context: ChannelHandlerContext) {
+  internal func handlerAdded(context: ChannelHandlerContext) {
     self.context = context
   }
 
-  func handlerRemoved(context: ChannelHandlerContext) {
+  internal func handlerRemoved(context: ChannelHandlerContext) {
     self.context = nil
   }
 
-  /// Called when the pipeline has finished configuring.
-  private func configured() {
-    switch self.state.pipelineConfigured() {
-    case let .forwardHeaders(headers):
-      self.context.fireChannelRead(self.wrapInboundOut(.metadata(headers)))
-
-    case let .forwardHeadersAndRead(headers):
-      self.context.fireChannelRead(self.wrapInboundOut(.metadata(headers)))
-      self.tryReadingMessage()
+  internal func errorCaught(context: ChannelHandlerContext, error: Error) {
+    switch self.mode {
+    case .notConfigured:
+      context.close(mode: .all, promise: nil)
+    case .legacy:
+      context.fireErrorCaught(error)
+    case let .handler(hander):
+      hander.receiveError(error)
     }
   }
 
-  /// Try to read a request message from the buffer.
-  private func tryReadingMessage() {
-    let action = self.state.readNextRequest()
-    switch action {
-    case .none:
-      ()
-
-    case let .forwardMessage(buffer):
-      self.context.fireChannelRead(self.wrapInboundOut(.message(buffer)))
-
-    case let .forwardMessageAndEnd(buffer):
-      self.context.fireChannelRead(self.wrapInboundOut(.message(buffer)))
-      self.context.fireChannelRead(self.wrapInboundOut(.end))
-
-    case let .forwardMessageThenReadNextMessage(buffer):
-      self.context.fireChannelRead(self.wrapInboundOut(.message(buffer)))
-      self.tryReadingMessage()
-
-    case .forwardEnd:
-      self.context.fireChannelRead(self.wrapInboundOut(.end))
-
-    case let .errorCaught(error):
-      self.context.fireErrorCaught(error)
+  internal func channelInactive(context: ChannelHandlerContext) {
+    switch self.mode {
+    case .notConfigured, .legacy:
+      context.fireChannelInactive()
+    case let .handler(handler):
+      handler.finish()
     }
   }
 
-  func channelRead(context: ChannelHandlerContext, data: NIOAny) {
+  internal func channelRead(context: ChannelHandlerContext, data: NIOAny) {
+    self.isReading = true
     let payload = self.unwrapInboundIn(data)
 
     switch payload {
@@ -108,11 +108,16 @@ internal final class HTTP2ToRawGRPCServerCodec: ChannelDuplexHandler, GRPCServer
       )
 
       switch receiveHeaders {
-      case let .configurePipeline(handler):
+      case let .configureLegacy(handler):
+        self.mode = .legacy
         context.channel.pipeline.addHandler(handler).whenSuccess {
           self.configured()
         }
 
+      case let .configure(handler):
+        self.mode = .handler(handler)
+        self.configured()
+
       case let .rejectRPC(trailers):
         // We're not handling this request: write headers and end stream.
         let payload = HTTP2Frame.FramePayload.headers(.init(headers: trailers, endStream: true))
@@ -145,62 +150,179 @@ internal final class HTTP2ToRawGRPCServerCodec: ChannelDuplexHandler, GRPCServer
     }
   }
 
-  func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
+  internal func channelReadComplete(context: ChannelHandlerContext) {
+    self.isReading = false
+
+    if self.flushPending {
+      self.flushPending = false
+      context.flush()
+    }
+
+    context.fireChannelReadComplete()
+  }
+
+  internal func write(
+    context: ChannelHandlerContext,
+    data: NIOAny,
+    promise: EventLoopPromise<Void>?
+  ) {
     let responsePart = self.unwrapOutboundIn(data)
 
     switch responsePart {
     case let .metadata(headers):
-      switch self.state.send(headers: headers) {
-      case let .success(headers):
-        let payload = HTTP2Frame.FramePayload.headers(.init(headers: headers))
-        context.write(self.wrapOutboundOut(payload), promise: promise)
+      self.sendMetadata(headers, promise: promise)
 
-      case let .failure(error):
-        promise?.fail(error)
+    case let .message(buffer, metadata):
+      self.sendMessage(buffer, metadata: metadata, promise: promise)
+
+    case let .end(status, trailers):
+      self.sendEnd(status: status, trailers: trailers, promise: promise)
+    }
+  }
+
+  internal func flush(context: ChannelHandlerContext) {
+    if self.isReading {
+      // We're already reading; record the flush and emit it when the read completes.
+      self.flushPending = true
+    } else {
+      // Not reading: flush now.
+      context.flush()
+    }
+  }
+
+  /// Called when the pipeline has finished configuring.
+  private func configured() {
+    switch self.state.pipelineConfigured() {
+    case let .forwardHeaders(headers):
+      switch self.mode {
+      case .notConfigured:
+        preconditionFailure()
+      case .legacy:
+        self.context.fireChannelRead(self.wrapInboundOut(.metadata(headers)))
+      case let .handler(handler):
+        handler.receiveMetadata(headers)
       }
 
-    case let .message(buffer, metadata):
-      let writeBuffer = self.state.send(
-        buffer: buffer,
-        allocator: context.channel.allocator,
-        compress: metadata.compress
-      )
+    case let .forwardHeadersAndRead(headers):
+      switch self.mode {
+      case .notConfigured:
+        preconditionFailure()
+      case .legacy:
+        self.context.fireChannelRead(self.wrapInboundOut(.metadata(headers)))
+      case let .handler(handler):
+        handler.receiveMetadata(headers)
+      }
+      self.tryReadingMessage()
+    }
+  }
 
-      switch writeBuffer {
-      case let .success(buffer):
-        let payload = HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer)))
-        context.write(self.wrapOutboundOut(payload), promise: promise)
+  /// Try to read a request message from the buffer.
+  private func tryReadingMessage() {
+    let action = self.state.readNextRequest()
+    switch action {
+    case .none:
+      ()
 
-      case let .failure(error):
-        promise?.fail(error)
+    case let .forwardMessage(buffer):
+      switch self.mode {
+      case .notConfigured:
+        preconditionFailure()
+      case .legacy:
+        self.context.fireChannelRead(self.wrapInboundOut(.message(buffer)))
+      case let .handler(handler):
+        handler.receiveMessage(buffer)
       }
 
-    case let .end(status, trailers):
-      switch self.state.send(status: status, trailers: trailers) {
-      case let .success(trailers):
-        // Always end stream for status and trailers.
-        let payload = HTTP2Frame.FramePayload.headers(.init(headers: trailers, endStream: true))
-        context.write(self.wrapOutboundOut(payload), promise: promise)
+    case let .forwardMessageAndEnd(buffer):
+      switch self.mode {
+      case .notConfigured:
+        preconditionFailure()
+      case .legacy:
+        self.context.fireChannelRead(self.wrapInboundOut(.message(buffer)))
+        self.context.fireChannelRead(self.wrapInboundOut(.end))
+      case let .handler(handler):
+        handler.receiveMessage(buffer)
+        handler.receiveEnd()
+      }
 
-      case let .failure(error):
-        promise?.fail(error)
+    case let .forwardMessageThenReadNextMessage(buffer):
+      switch self.mode {
+      case .notConfigured:
+        preconditionFailure()
+      case .legacy:
+        self.context.fireChannelRead(self.wrapInboundOut(.message(buffer)))
+      case let .handler(handler):
+        handler.receiveMessage(buffer)
+      }
+      self.tryReadingMessage()
+
+    case .forwardEnd:
+      switch self.mode {
+      case .notConfigured:
+        preconditionFailure()
+      case .legacy:
+        self.context.fireChannelRead(self.wrapInboundOut(.end))
+      case let .handler(handler):
+        handler.receiveEnd()
+      }
+
+    case let .errorCaught(error):
+      switch self.mode {
+      case .notConfigured:
+        preconditionFailure()
+      case .legacy:
+        self.context.fireErrorCaught(error)
+      case let .handler(handler):
+        handler.receiveError(error)
       }
     }
   }
 
   internal func sendMetadata(
-    _ metadata: HPACKHeaders,
+    _ headers: HPACKHeaders,
     promise: EventLoopPromise<Void>?
   ) {
-    fatalError("TODO: not used yet")
+    switch self.state.send(headers: headers) {
+    case let .success(headers):
+      let payload = HTTP2Frame.FramePayload.headers(.init(headers: headers))
+      self.context.write(self.wrapOutboundOut(payload), promise: promise)
+
+      if self.isReading {
+        self.flushPending = true
+      } else {
+        self.context.flush()
+      }
+
+    case let .failure(error):
+      promise?.fail(error)
+    }
   }
 
   internal func sendMessage(
-    _ bytes: ByteBuffer,
+    _ buffer: ByteBuffer,
     metadata: MessageMetadata,
     promise: EventLoopPromise<Void>?
   ) {
-    fatalError("TODO: not used yet")
+    let writeBuffer = self.state.send(
+      buffer: buffer,
+      allocator: self.context.channel.allocator,
+      compress: metadata.compress
+    )
+
+    switch writeBuffer {
+    case let .success(buffer):
+      let payload = HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer)))
+      self.context.write(self.wrapOutboundOut(payload), promise: promise)
+
+      if self.isReading {
+        self.flushPending = true
+      } else {
+        self.context.flush()
+      }
+
+    case let .failure(error):
+      promise?.fail(error)
+    }
   }
 
   internal func sendEnd(
@@ -208,6 +330,20 @@ internal final class HTTP2ToRawGRPCServerCodec: ChannelDuplexHandler, GRPCServer
     trailers: HPACKHeaders,
     promise: EventLoopPromise<Void>?
   ) {
-    fatalError("TODO: not used yet")
+    switch self.state.send(status: status, trailers: trailers) {
+    case let .success(trailers):
+      // Always end stream for status and trailers.
+      let payload = HTTP2Frame.FramePayload.headers(.init(headers: trailers, endStream: true))
+      self.context.write(self.wrapOutboundOut(payload), promise: promise)
+
+      if self.isReading {
+        self.flushPending = true
+      } else {
+        self.context.flush()
+      }
+
+    case let .failure(error):
+      promise?.fail(error)
+    }
   }
 }

+ 27 - 18
Sources/GRPC/HTTP2ToRawGRPCStateMachine.swift

@@ -232,7 +232,9 @@ extension HTTP2ToRawGRPCStateMachine {
 
   enum ReceiveHeadersAction {
     /// Configure the pipeline with the given call handler.
-    case configurePipeline(GRPCCallHandler)
+    case configureLegacy(GRPCCallHandler)
+    /// Configure the RPC to use the given server handler.
+    case configure(GRPCServerHandlerProtocol)
     /// Reject the RPC by writing out the given headers and setting end-stream.
     case rejectRPC(HPACKHeaders)
   }
@@ -331,25 +333,32 @@ extension HTTP2ToRawGRPCStateMachine.RequestIdleResponseIdleState {
 
     // We have a matching service, hopefully we have a provider for the method too.
     let method = Substring(callPath.method)
-    guard let handler = service.handleMethod(method, callHandlerContext: context) else {
-      return self.methodNotImplemented(path, contentType: contentType)
-    }
 
-    // Finally, on to the next state!
-    let requestOpenResponseIdle = HTTP2ToRawGRPCStateMachine.RequestOpenResponseIdleState(
-      reader: reader,
-      writer: writer,
-      contentType: contentType,
-      acceptEncoding: acceptableRequestEncoding,
-      responseEncoding: responseEncoding,
-      normalizeHeaders: self.normalizeHeaders,
-      configurationState: .configuring(headers)
-    )
+    func nextState() -> HTTP2ToRawGRPCStateMachine.RequestOpenResponseIdleState {
+      return HTTP2ToRawGRPCStateMachine.RequestOpenResponseIdleState(
+        reader: reader,
+        writer: writer,
+        contentType: contentType,
+        acceptEncoding: acceptableRequestEncoding,
+        responseEncoding: responseEncoding,
+        normalizeHeaders: self.normalizeHeaders,
+        configurationState: .configuring(headers)
+      )
+    }
 
-    return .init(
-      state: .requestOpenResponseIdle(requestOpenResponseIdle),
-      action: .configurePipeline(handler)
-    )
+    if let handler = service.handle(method: method, context: context) {
+      return .init(
+        state: .requestOpenResponseIdle(nextState()),
+        action: .configure(handler)
+      )
+    } else if let handler = service.handleMethod(method, callHandlerContext: context) {
+      return .init(
+        state: .requestOpenResponseIdle(nextState()),
+        action: .configureLegacy(handler)
+      )
+    } else {
+      return self.methodNotImplemented(path, contentType: contentType)
+    }
   }
 
   /// The 'content-type' is not supported; close with status code 415.

+ 1 - 1
Tests/GRPCTests/XCTestHelpers.swift

@@ -481,7 +481,7 @@ struct Matcher<Value> {
   static func configure() -> Matcher<HTTP2ToRawGRPCStateMachine.ReceiveHeadersAction> {
     return .init { actual in
       switch actual {
-      case .configurePipeline:
+      case .configureLegacy, .configure:
         return .match
       default:
         return .noMatch(actual: "\(actual)", expected: "configurePipeline")