LengthPrefixedMessageReader.swift 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. /*
  2. * Copyright 2019, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import Foundation
  17. import NIO
  18. import NIOHTTP1
  /// This class reads and decodes length-prefixed gRPC messages.
  ///
  /// Messages are expected to be in the following format:
  /// - compression flag: 0/1 as a 1-byte unsigned integer,
  /// - message length: length of the message as a 4-byte unsigned big-endian integer,
  /// - message: `message_length` bytes.
  ///
  /// Messages may span multiple `ByteBuffer`s, and `ByteBuffer`s may contain multiple
  /// length-prefixed messages. Any part of a message, including the 4-byte length field,
  /// may be split across buffers.
  ///
  /// - SeeAlso:
  /// [gRPC Protocol](https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md)
  public class LengthPrefixedMessageReader {
    public typealias Mode = GRPCError.Origin

    /// Size in bytes of the message length field.
    private static let messageLengthFieldSize = MemoryLayout<UInt32>.size

    /// Passed as the `origin` of any `GRPCError` thrown by this reader.
    private let mode: Mode

    /// Accumulates a message whose bytes span more than one input buffer.
    /// Only non-nil while in the `.isBuffering` state.
    private var buffer: ByteBuffer!
    private var state: State = .expectingCompressedFlag

    private enum State {
      case expectingCompressedFlag
      case expectingMessageLength
      /// Some, but not all, bytes of the message length field have been received.
      case receivingMessageLength(bytesReceived: [UInt8])
      case receivedMessageLength(Int)
      case willBuffer(requiredBytes: Int)
      case isBuffering(requiredBytes: Int)
    }

    public init(mode: Mode) {
      self.mode = mode
    }

    /// Consumes all readable bytes from given buffer and returns all messages which could be read.
    ///
    /// - SeeAlso: `read(messageBuffer:compression:)`
    public func consume(messageBuffer: inout ByteBuffer, compression: CompressionMechanism) throws -> [ByteBuffer] {
      var messages: [ByteBuffer] = []
      while messageBuffer.readableBytes > 0 {
        // `read` only returns `nil` once it has consumed every readable byte; stopping here
        // also guarantees termination should it ever fail to make progress.
        guard let message = try self.read(messageBuffer: &messageBuffer, compression: compression) else {
          break
        }
        messages.append(message)
      }
      return messages
    }

    /// Reads bytes from the given buffer until it is exhausted or a message has been read.
    ///
    /// Length-prefixed messages may be split across multiple input buffers at any point:
    /// after the compression flag, within or after the message length field, or anywhere
    /// within the message itself.
    ///
    /// - Note:
    ///   This method relies on state; if a message is _not_ returned then the next time this
    ///   method is called it expects to read the bytes which follow the most recently read bytes.
    ///
    /// - Parameters:
    ///   - messageBuffer: buffer to read from.
    ///   - compression: compression mechanism the message's compression flag is validated
    ///     against. The message bytes themselves are returned as received.
    /// - Returns: A buffer containing a message if one has been read, or `nil` if not enough
    ///   bytes have been consumed to return a message. When `nil` is returned, every readable
    ///   byte of `messageBuffer` has been consumed.
    /// - Throws: `GRPCError` if the compression flag does not match `compression`, or if
    ///   `compression` is not supported.
    public func read(messageBuffer: inout ByteBuffer, compression: CompressionMechanism) throws -> ByteBuffer? {
      while true {
        switch self.state {
        case .expectingCompressedFlag:
          guard let compressionFlag: Int8 = messageBuffer.readInteger() else { return nil }
          try self.handleCompressionFlag(enabled: compressionFlag != 0, mechanism: compression)
          self.state = .expectingMessageLength

        case .expectingMessageLength:
          if let messageLength: UInt32 = messageBuffer.readInteger() {
            self.state = .receivedMessageLength(numericCast(messageLength))
          } else {
            // Fewer than four bytes are available, so the length field is split across
            // buffers. Collect it piece by piece instead of waiting on bytes that were
            // never consumed, which previously made `consume` loop forever.
            self.state = .receivingMessageLength(bytesReceived: [])
          }

        case .receivingMessageLength(var bytesReceived):
          let fieldSize = LengthPrefixedMessageReader.messageLengthFieldSize
          let available = min(fieldSize - bytesReceived.count, messageBuffer.readableBytes)
          // `available` never exceeds `readableBytes`, so `readBytes` cannot fail here.
          bytesReceived += messageBuffer.readBytes(length: available) ?? []
          guard bytesReceived.count == fieldSize else {
            self.state = .receivingMessageLength(bytesReceived: bytesReceived)
            return nil
          }
          // Decode big-endian, matching the byte order `readInteger()` uses above.
          let messageLength = bytesReceived.reduce(UInt32(0)) { accumulated, byte in
            (accumulated << 8) | UInt32(byte)
          }
          self.state = .receivedMessageLength(numericCast(messageLength))

        case .receivedMessageLength(let messageLength):
          // If the whole message is already available, skip buffering and return a slice.
          guard messageLength <= messageBuffer.readableBytes else {
            self.state = .willBuffer(requiredBytes: messageLength)
            continue
          }
          self.state = .expectingCompressedFlag
          // We know messageBuffer.readableBytes >= messageLength, so it's okay to force unwrap here.
          return messageBuffer.readSlice(length: messageLength)!

        case .willBuffer(let requiredBytes):
          let readableBytes = messageBuffer.readableBytes
          // Take the partial message as a slice. This also moves the reader index so the
          // bytes are not read again. Then reserve room for the whole message on the
          // accumulator itself, rather than on the caller's buffer, so later writes
          // do not have to reallocate.
          var accumulator = messageBuffer.readSlice(length: readableBytes)!
          accumulator.reserveCapacity(requiredBytes)
          self.buffer = accumulator
          self.state = .isBuffering(requiredBytes: requiredBytes - readableBytes)
          return nil

        case .isBuffering(let requiredBytes):
          guard requiredBytes <= messageBuffer.readableBytes else {
            // Still short of a full message: take everything on offer and wait for more.
            let bytesWritten = self.buffer.write(buffer: &messageBuffer)
            self.state = .isBuffering(requiredBytes: requiredBytes - bytesWritten)
            return nil
          }
          // We know messageBuffer.readableBytes >= requiredBytes, so it's okay to force unwrap here.
          var remainder = messageBuffer.readSlice(length: requiredBytes)!
          self.buffer.write(buffer: &remainder)
          self.state = .expectingCompressedFlag
          defer { self.buffer = nil }
          return self.buffer
        }
      }
    }

    /// Checks a message's compression flag against the expected compression mechanism.
    ///
    /// - Throws: `GRPCError` with `.unexpectedCompression` if `flagEnabled` differs from
    ///   `mechanism.requiresFlag`, or with `.unsupportedCompressionMechanism` if `mechanism`
    ///   is not supported.
    private func handleCompressionFlag(enabled flagEnabled: Bool, mechanism: CompressionMechanism) throws {
      guard flagEnabled == mechanism.requiresFlag else {
        throw GRPCError.common(.unexpectedCompression, origin: self.mode)
      }
      guard mechanism.supported else {
        throw GRPCError.common(.unsupportedCompressionMechanism(mechanism.rawValue), origin: self.mode)
      }
    }
  }