HTTP1ToRawGRPCServerCodec.swift 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. import Foundation
  2. import NIO
  3. import NIOHTTP1
  4. import NIOFoundationCompat
  /// A gRPC request part whose message payload has not been decoded.
  ///
  /// - `head`: the HTTP request head, passed through unchanged.
  /// - `message`: the payload of a single message, with its length prefix already removed.
  /// - `end`: the request stream has finished.
  public enum RawGRPCServerRequestPart {
    case head(HTTPRequestHead), message(ByteBuffer), end
  }
  /// A gRPC response part whose message payload is already serialized.
  ///
  /// - `headers`: headers to send ahead of any message.
  /// - `message`: the payload of a single message; the length prefix is added when it is written.
  /// - `status`: the final status of the call, which terminates the response.
  public enum RawGRPCServerResponsePart {
    case headers(HTTPHeaders), message(ByteBuffer), status(GRPCStatus)
  }
  /// A simple channel handler that translates HTTP1 data types into gRPC packets, and vice versa.
  ///
  /// This codec allows us to use the "raw" gRPC protocol on a low level, with further handlers operating on the
  /// protocol at a "higher" level.
  ///
  /// We use HTTP1 (instead of HTTP2) primitives, as these are easier to work with than raw HTTP2
  /// primitives while providing all the functionality we need. In addition, this should make implementing
  /// gRPC-over-HTTP1 (sometimes also called pRPC) easier in the future.
  ///
  /// The translation from HTTP2 to HTTP1 is done by `HTTP2ToHTTP1ServerCodec`.
  public final class HTTP1ToRawGRPCServerCodec {
    // Size of the prefix in front of every message: a 1-byte compression flag and a 4-byte message length.
    private let protobufMetadataSize = 5
    // Content type of the current request, set when its head is read; `nil` if the request's
    // content-type header was not recognized.
    private var contentType: ContentType?

    // The following buffers are implicitly unwrapped optionals on purpose. With regular optionals,
    // developers are encouraged to unwrap them using guard-else statements, which don't work cleanly
    // with structs: the guard-else creates a copy of the struct, which would then have to be assigned
    // back into the property for the changes to take effect. Force unwrapping avoids those
    // reassignments and keeps the code a bit cleaner.

    // Holds received binary request bytes that cannot be processed yet because a message is split
    // across multiple buffers; `nil` when nothing is pending.
    private var binaryRequestBuffer: NIO.ByteBuffer!

    // Buffers for text (base64) encoded data. Only used when the content-type is application/grpc-web-text.
    // TODO(kaipi): Extract all gRPC Web processing logic into an independent handler only added on
    // the HTTP1.1 pipeline, as it's starting to get in the way of readability.
    // Request side: base64 characters received but not yet decoded.
    private var requestTextBuffer: NIO.ByteBuffer!
    // Response side: framed response data aggregated until the status is written.
    private var responseTextBuffer: NIO.ByteBuffer!

    // Progress of the request (inbound) and response (outbound) halves of the stream.
    var inboundState = InboundState.expectingHeaders
    var outboundState = OutboundState.expectingHeaders
  }
  extension HTTP1ToRawGRPCServerCodec {
    /// Request content types this codec knows how to handle.
    private enum ContentType: String {
      /// Standard gRPC with binary messages.
      case binary = "application/grpc"
      /// gRPC-Web with base64 encoded messages.
      case text = "application/grpc-web-text"
      /// gRPC-Web with binary messages.
      case web = "application/grpc-web"
    }

    /// Where the request (inbound) half of the stream currently is.
    enum InboundState {
      /// Position within the length-prefixed message framing of the request body.
      enum Body {
        case expectingCompressedFlag
        case expectingMessageLength
        /// Waiting for this many more payload bytes of the current message.
        case expectingMoreMessageBytes(UInt32)
      }

      case expectingHeaders
      case expectingBody(Body)
      /// Drop any further input, e.g. once `.end` was seen, or after an error was raised while the
      /// stream is waiting to close.
      case ignore
    }

    /// Where the response (outbound) half of the stream currently is.
    enum OutboundState {
      case expectingHeaders, expectingBodyOrStatus, ignore
    }
  }
  extension HTTP1ToRawGRPCServerCodec: ChannelInboundHandler {
    public typealias InboundIn = HTTPServerRequestPart
    public typealias InboundOut = RawGRPCServerRequestPart

    /// Translates an inbound HTTP1 request part into raw gRPC request parts.
    ///
    /// Errors raised while processing are forwarded with `fireErrorCaught`, after which all further
    /// inbound data is ignored.
    public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
      if case .ignore = inboundState { return }

      do {
        switch self.unwrapInboundIn(data) {
        case .head(let requestHead):
          inboundState = try processHead(ctx: ctx, requestHead: requestHead)
        case .body(var body):
          inboundState = try processBody(ctx: ctx, body: &body)
        case .end(let trailers):
          inboundState = try processEnd(ctx: ctx, trailers: trailers)
        }
      } catch {
        ctx.fireErrorCaught(error)
        inboundState = .ignore
      }
    }

    /// Records the request's content type and forwards the request head.
    ///
    /// - Returns: the next inbound state, which expects the start of the first message.
    /// - Throws: `GRPCServerError.invalidState` if a head has already been received.
    func processHead(ctx: ChannelHandlerContext, requestHead: HTTPRequestHead) throws -> InboundState {
      guard case .expectingHeaders = inboundState else {
        throw GRPCServerError.invalidState("expected state .expectingHeaders, got \(inboundState)")
      }

      if let contentTypeHeader = requestHead.headers["content-type"].first {
        contentType = parseContentType(contentTypeHeader)
      } else {
        // If the Content-Type is not present, assume the request is binary encoded gRPC.
        contentType = .binary
      }

      if contentType == .text {
        requestTextBuffer = ctx.channel.allocator.buffer(capacity: 0)
      }

      ctx.fireChannelRead(self.wrapInboundOut(.head(requestHead)))
      return .expectingBody(.expectingCompressedFlag)
    }

    /// Maps a `content-type` header value to a known `ContentType`, or `nil` if it is not recognized.
    ///
    /// gRPC allows a message-format suffix on the content type (e.g. "application/grpc+proto" or
    /// "application/grpc-web+proto"). The suffix does not affect the framing, so when there is no exact
    /// match the part before the first "+" is matched instead.
    private func parseContentType(_ header: String) -> ContentType? {
      if let exactMatch = ContentType(rawValue: header) {
        return exactMatch
      }
      guard let baseType = header.split(separator: "+", maxSplits: 1, omittingEmptySubsequences: false).first else {
        return nil
      }
      return ContentType(rawValue: String(baseType))
    }

    /// Processes a chunk of the request body and forwards every complete message it contains.
    ///
    /// If the content type is text, the incoming bytes are base64 decoded before being framed. When the
    /// request is chunked, only the longest prefix whose length is a multiple of 4 is decoded; the
    /// remaining characters stay in `requestTextBuffer` until the next chunk arrives.
    ///
    /// - Returns: the next inbound state.
    /// - Throws: `GRPCServerError.invalidState` if no head has been received,
    ///   `GRPCServerError.base64DecodeError` if the text cannot be decoded, or any error thrown by
    ///   `processBodyState`.
    func processBody(ctx: ChannelHandlerContext, body: inout ByteBuffer) throws -> InboundState {
      guard case .expectingBody(let bodyState) = inboundState else {
        throw GRPCServerError.invalidState("expected state .expectingBody(_), got \(inboundState)")
      }

      if contentType == .text {
        precondition(requestTextBuffer != nil)
        requestTextBuffer.write(buffer: &body)
        // Base64 maps every 3 bytes to 4 characters, so only whole groups of 4 can be decoded.
        let readyBytes = requestTextBuffer.readableBytes - (requestTextBuffer.readableBytes % 4)
        guard let base64Encoded = requestTextBuffer.readString(length: readyBytes),
          let decodedData = Data(base64Encoded: base64Encoded) else {
            throw GRPCServerError.base64DecodeError
        }
        // `body` was fully drained into `requestTextBuffer` above; refill it with the decoded bytes.
        body.write(bytes: decodedData)
      }

      return .expectingBody(try processBodyState(ctx: ctx, initialState: bodyState, messageBuffer: &body))
    }

    /// Reads as many length-prefixed messages from `messageBuffer` as possible and forwards each one.
    ///
    /// Each message has the following format:
    /// - 1 byte "compressed" flag (must be zero, as compression is not supported)
    /// - 4 byte unsigned integer payload length (N)
    /// - N bytes payload (normally a valid wire-format protocol buffer)
    ///
    /// A message may be split across several buffers, including in the middle of its length prefix.
    /// Bytes that cannot be processed yet are kept in `binaryRequestBuffer` and combined with the next
    /// buffer.
    ///
    /// - Parameters:
    ///   - initialState: where in the framing the previous buffer stopped.
    ///   - messageBuffer: the newly received bytes.
    /// - Returns: where in the framing the next buffer should resume.
    /// - Throws: `GRPCServerError.unexpectedCompression` if a message has a non-zero compressed flag.
    func processBodyState(ctx: ChannelHandlerContext, initialState: InboundState.Body, messageBuffer: inout ByteBuffer) throws -> InboundState.Body {
      var bodyState = initialState

      // A previous buffer ended partway through a 4-byte message length. Put the stashed bytes in front
      // of the new ones so the length can be read in one piece. (While expecting payload bytes,
      // `binaryRequestBuffer` holds the partial payload instead; that case is handled below.)
      if self.binaryRequestBuffer != nil, case .expectingMessageLength = bodyState {
        self.binaryRequestBuffer.write(buffer: &messageBuffer)
        messageBuffer = self.binaryRequestBuffer!
        self.binaryRequestBuffer = nil
      }

      while true {
        switch bodyState {
        case .expectingCompressedFlag:
          guard let compressedFlag: Int8 = messageBuffer.readInteger() else { return .expectingCompressedFlag }
          // TODO: Add support for compression.
          guard compressedFlag == 0 else { throw GRPCServerError.unexpectedCompression }
          bodyState = .expectingMessageLength

        case .expectingMessageLength:
          guard let messageLength: UInt32 = messageBuffer.readInteger() else {
            // Fewer than 4 length bytes are available and `readInteger` consumed none of them. Keep them
            // for the next buffer instead of silently dropping them.
            if messageBuffer.readableBytes > 0 {
              self.binaryRequestBuffer = messageBuffer
            }
            return .expectingMessageLength
          }
          bodyState = .expectingMoreMessageBytes(messageLength)

        case .expectingMoreMessageBytes(let bytesOutstanding):
          // Messages may be spread across multiple `ByteBuffer`s, so buffer them into
          // `binaryRequestBuffer`. When a message is contained within a single `ByteBuffer`, only a
          // slice is taken, so no extra writes are incurred.
          guard messageBuffer.readableBytes >= bytesOutstanding else {
            let remainingBytes = bytesOutstanding - numericCast(messageBuffer.readableBytes)
            if self.binaryRequestBuffer != nil {
              self.binaryRequestBuffer.write(buffer: &messageBuffer)
            } else {
              // Capacity is measured from index 0, so account for the bytes already written (including
              // the consumed prefix) plus the bytes still to come.
              messageBuffer.reserveCapacity(messageBuffer.writerIndex + numericCast(remainingBytes))
              self.binaryRequestBuffer = messageBuffer
            }
            return .expectingMoreMessageBytes(remainingBytes)
          }

          // The guard above ensures readableBytes >= bytesOutstanding, so the force unwrap cannot fail.
          var slice = messageBuffer.readSlice(length: numericCast(bytesOutstanding))!
          if self.binaryRequestBuffer != nil {
            self.binaryRequestBuffer.write(buffer: &slice)
            ctx.fireChannelRead(self.wrapInboundOut(.message(self.binaryRequestBuffer)))
          } else {
            ctx.fireChannelRead(self.wrapInboundOut(.message(slice)))
          }
          self.binaryRequestBuffer = nil
          bodyState = .expectingCompressedFlag
        }
      }
    }

    /// Forwards the end of the request stream; no further inbound data is processed afterwards.
    ///
    /// - Throws: `GRPCServerError.invalidState` if the request carried HTTP trailers.
    private func processEnd(ctx: ChannelHandlerContext, trailers: HTTPHeaders?) throws -> InboundState {
      if let trailers = trailers {
        throw GRPCServerError.invalidState("unexpected trailers received \(trailers)")
      }
      ctx.fireChannelRead(self.wrapInboundOut(.end))
      return .ignore
    }
  }
  extension HTTP1ToRawGRPCServerCodec: ChannelOutboundHandler {
    public typealias OutboundIn = RawGRPCServerResponsePart
    public typealias OutboundOut = HTTPServerResponsePart

    /// Translates an outbound raw gRPC response part into HTTP1 response parts.
    ///
    /// Parts that arrive out of order, or after the status has been written, are not forwarded. Their
    /// promise is failed with `GRPCServerError.invalidState` so the writer is not left waiting forever.
    public func write(ctx: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
      if case .ignore = outboundState {
        promise?.fail(error: GRPCServerError.invalidState("unexpected write in state \(outboundState)"))
        return
      }

      switch self.unwrapOutboundIn(data) {
      case .headers(var headers):
        guard case .expectingHeaders = outboundState else {
          promise?.fail(error: GRPCServerError.invalidState("expected state .expectingHeaders, got \(outboundState)"))
          return
        }

        var version = HTTPVersion(major: 2, minor: 0)
        if let contentType = contentType {
          headers.add(name: "content-type", value: contentType.rawValue)
          // gRPC-Web content types are served over HTTP/1.1.
          if contentType != .binary {
            version = .init(major: 1, minor: 1)
          }
        }
        if contentType == .text {
          responseTextBuffer = ctx.channel.allocator.buffer(capacity: 0)
        }

        ctx.write(self.wrapOutboundOut(.head(HTTPResponseHead(version: version, status: .ok, headers: headers))), promise: promise)
        outboundState = .expectingBodyOrStatus

      case .message(var messageBytes):
        guard case .expectingBodyOrStatus = outboundState else {
          promise?.fail(error: GRPCServerError.invalidState("expected state .expectingBodyOrStatus, got \(outboundState)"))
          return
        }

        // Write out a length-delimited message payload. See `processBodyState` for the corresponding format.
        var responseBuffer = ctx.channel.allocator.buffer(capacity: messageBytes.readableBytes + protobufMetadataSize)
        responseBuffer.write(integer: Int8(0)) // Compression flag: no compression
        responseBuffer.write(integer: UInt32(messageBytes.readableBytes))
        responseBuffer.write(buffer: &messageBytes)

        if contentType == .text {
          precondition(responseTextBuffer != nil)
          // Store the response in a separate buffer. The message cannot be written directly, because it
          // must be aggregated with all the other responses plus the trailers so that the base64 response
          // is encoded as a single byte stream.
          responseTextBuffer.write(buffer: &responseBuffer)
          // The data has been accepted, so complete the promise now to let the writer (e.g. a
          // server-streaming provider) continue sending.
          promise?.succeed(result: Void())
        } else {
          ctx.write(self.wrapOutboundOut(.body(.byteBuffer(responseBuffer))), promise: promise)
        }

      case .status(let status):
        // If we error before sending the initial headers (e.g. unimplemented method), the response head has
        // not been sent yet. NIOHTTP2 doesn't support sending a single frame as a "Trailers-Only" response,
        // so loop back and send the response head first.
        if case .expectingHeaders = outboundState {
          self.write(ctx: ctx, data: NIOAny(RawGRPCServerResponsePart.headers(HTTPHeaders())), promise: nil)
        }

        var trailers = status.trailingMetadata
        trailers.add(name: "grpc-status", value: String(describing: status.code.rawValue))
        trailers.add(name: "grpc-message", value: status.message)

        // Each response part gets exactly one write carrying the caller's promise: the final `.end`.
        switch contentType {
        case .text?:
          precondition(responseTextBuffer != nil)
          var trailersFrame = makeWebTrailersFrame(ctx: ctx, trailers: trailers)
          responseTextBuffer.write(buffer: &trailersFrame)
          // TODO: Binary responses whose length is not a multiple of 3 end with = or == when encoded in
          // base64. Investigate whether this might have any effect on the transport mechanism and
          // client decoding. Initial results say that they are innocuous, but we might have to keep
          // an eye on this in case something trips up.
          if let binaryData = responseTextBuffer.readData(length: responseTextBuffer.readableBytes) {
            let encodedData = binaryData.base64EncodedString()
            responseTextBuffer.clear()
            responseTextBuffer.reserveCapacity(encodedData.utf8.count)
            responseTextBuffer.write(string: encodedData)
          }
          // After collecting all responses for gRPC Web connections, send one final aggregated response.
          ctx.write(self.wrapOutboundOut(.body(.byteBuffer(responseTextBuffer))), promise: nil)
          ctx.write(self.wrapOutboundOut(.end(nil)), promise: promise)

        case .web?:
          // gRPC-Web carries the status as a trailers frame in the response body rather than as HTTP
          // trailers, as per https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-WEB.md
          let trailersFrame = makeWebTrailersFrame(ctx: ctx, trailers: trailers)
          ctx.write(self.wrapOutboundOut(.body(.byteBuffer(trailersFrame))), promise: nil)
          ctx.write(self.wrapOutboundOut(.end(nil)), promise: promise)

        case .binary?, nil:
          ctx.write(self.wrapOutboundOut(.end(trailers)), promise: promise)
        }

        outboundState = .ignore
        inboundState = .ignore
      }
    }

    /// Encodes `trailers` as a gRPC-Web trailers frame.
    ///
    /// The frame is laid out like a length-delimited message whose flag byte has its most significant
    /// bit set (0x80). Its payload is the trailers as "name: value" lines separated by "\r\n", as per
    /// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-WEB.md
    private func makeWebTrailersFrame(ctx: ChannelHandlerContext, trailers: HTTPHeaders) -> ByteBuffer {
      let textTrailers = trailers.map { name, value in "\(name): \(value)" }.joined(separator: "\r\n")
      var frame = ctx.channel.allocator.buffer(capacity: protobufMetadataSize + textTrailers.utf8.count)
      frame.write(integer: UInt8(0x80))
      frame.write(integer: UInt32(textTrailers.utf8.count))
      frame.write(string: textTrailers)
      return frame
    }
  }