BaseCallHandler.swift 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. import Foundation
  2. import SwiftProtobuf
  3. import NIO
  4. import NIOHTTP1
  5. /// Provides a means for decoding incoming gRPC messages into protobuf objects.
  6. ///
  7. /// Calls through to `processMessage` for individual messages it receives, which needs to be implemented by subclasses.
  /// Base call handler that receives decoded gRPC request messages as protobuf objects.
  ///
  /// Every decoded request message is handed to `processMessage(_:)`, which concrete
  /// subclasses must override; the base implementation traps.
  public class BaseCallHandler<RequestMessage: Message, ResponseMessage: Message>: GRPCCallHandler {
    /// `true` while this handler may still write to the client.
    private var serverCanWrite = true

    /// Notified of every error received in `errorCaught(ctx:error:)`.
    ///
    /// Held weakly: it is only consulted while something else keeps it alive.
    private weak var errorDelegate: ServerErrorDelegate?

    /// - Parameter errorDelegate: Optional observer/transformer for errors caught by this handler.
    public init(errorDelegate: ServerErrorDelegate?) {
      self.errorDelegate = errorDelegate
    }

    /// Builds the codec that sits in front of this handler, typed to this handler's
    /// request and response message types.
    public func makeGRPCServerCodec() -> ChannelHandler {
      return GRPCServerCodec<RequestMessage, ResponseMessage>()
    }

    /// Handles one decoded request message.
    ///
    /// Subclasses must override this; the base implementation calls `fatalError`.
    /// - Parameter message: The decoded request message.
    /// - Throws: Any error; the inbound handling routes it to `errorCaught(ctx:error:)`.
    public func processMessage(_ message: RequestMessage) throws {
      fatalError("needs to be overridden")
    }

    /// Invoked when the client half-closes the stream, i.e. it will send no more data.
    ///
    /// Does nothing by default; override when the end-of-stream event matters.
    public func endOfStreamReceived() {}
  }
  extension BaseCallHandler: ChannelInboundHandler {
    public typealias InboundIn = GRPCServerRequestPart<RequestMessage>

    /// Reports `error` to the `errorDelegate` (if any) and answers the client with a status.
    ///
    /// The delegate first observes the original error and may then transform it. If the
    /// resulting error conforms to `GRPCStatusTransformable`, its own status is sent;
    /// otherwise `GRPCStatus.processingError` is sent. The status goes through this
    /// handler's `write(ctx:data:promise:)` with no promise, so if a status has already
    /// been written the new one is silently dropped.
    public func errorCaught(ctx: ChannelHandlerContext, error: Error) {
      errorDelegate?.observe(error)

      let reportedError = errorDelegate?.transform(error) ?? error
      let transformable = reportedError as? GRPCStatusTransformable
      let status = transformable.map { $0.asGRPCStatus() } ?? GRPCStatus.processingError

      let statusPart = GRPCServerResponsePart<ResponseMessage>.status(status)
      self.write(ctx: ctx, data: NIOAny(statusPart), promise: nil)
    }

    /// Dispatches each inbound request part.
    ///
    /// - `.head`: should already have been consumed by `GRPCChannelHandler`, so reaching
    ///   here is reported as `GRPCServerError.invalidState`.
    /// - `.message`: passed to `processMessage(_:)`; anything it throws is reported via
    ///   `errorCaught(ctx:error:)`.
    /// - `.end`: forwarded to `endOfStreamReceived()`.
    public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
      let requestPart = self.unwrapInboundIn(data)

      switch requestPart {
      case let .head(requestHead):
        let stateError = GRPCServerError.invalidState("unexpected request head received \(requestHead)")
        self.errorCaught(ctx: ctx, error: stateError)

      case let .message(message):
        do {
          try self.processMessage(message)
        } catch {
          self.errorCaught(ctx: ctx, error: error)
        }

      case .end:
        self.endOfStreamReceived()
      }
    }
  }
  extension BaseCallHandler: ChannelOutboundHandler {
    public typealias OutboundIn = GRPCServerResponsePart<ResponseMessage>
    public typealias OutboundOut = GRPCServerResponsePart<ResponseMessage>

    /// Passes response parts to the next handler, allowing at most one status.
    ///
    /// After a `.status` part has been written, every later write is rejected by failing
    /// its promise with `GRPCServerError.serverNotWritable`. A status part is written and
    /// flushed immediately; any other part is only written and waits for a later flush.
    public func write(ctx: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
      guard serverCanWrite else {
        promise?.fail(error: GRPCServerError.serverNotWritable)
        return
      }

      switch unwrapOutboundIn(data) {
      case .status:
        // Close the gate before forwarding, so nothing can follow the single status.
        serverCanWrite = false
        ctx.writeAndFlush(data, promise: promise)
      default:
        ctx.write(data, promise: promise)
      }
    }
  }