ReadWriteStates.swift 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. /*
  2. * Copyright 2019, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import NIOCore
  17. import SwiftProtobuf
  /// How many messages are expected on a stream.
  enum MessageArity {
    /// A single message is expected.
    case one

    /// Any number of messages (including none) may be sent.
    case many
  }
  /// Holds everything needed to build a `WriteState` once the message encoding is known.
  struct PendingWriteState {
    /// How many messages we expect to write to the stream.
    var arity: MessageArity

    /// The 'content-type' being written.
    var contentType: ContentType

    /// Builds a `WriteState` which is ready to accept writes.
    ///
    /// - Parameters:
    ///   - messageEncoding: The message encoding configuration. When enabled, its `outbound`
    ///     algorithm is handed to the message writer; when disabled the writer is created
    ///     without a compression algorithm.
    ///   - allocator: The allocator handed to the message writer.
    /// - Returns: A `WriteState` in the `.writing` state carrying this value's arity and
    ///   content-type along with a newly created writer.
    func makeWriteState(
      messageEncoding: ClientMessageEncoding,
      allocator: ByteBufferAllocator
    ) -> WriteState {
      let outboundCompression: CompressionAlgorithm?

      switch messageEncoding {
      case .disabled:
        outboundCompression = nil

      case let .enabled(configuration):
        outboundCompression = configuration.outbound
      }

      return .writing(
        self.arity,
        self.contentType,
        LengthPrefixedMessageWriter(compression: outboundCompression, allocator: allocator)
      )
    }
  }
  /// The write state of a stream.
  enum WriteState {
    /// Messages may be written using the associated writer.
    case writing(MessageArity, ContentType, LengthPrefixedMessageWriter)

    /// No further writes may be attempted: either an earlier write failed or it is not valid
    /// for any more messages to be written.
    case notWriting

    /// Writes a message using the writer held by the `.writing` state.
    ///
    /// After a failed write, or after a successful write on a stream whose arity is `.one`,
    /// the state becomes `.notWriting`. Otherwise the state remains `.writing`.
    ///
    /// - Parameters:
    ///   - message: The buffer containing the message to write.
    ///   - compressed: Forwarded to the writer's `write(buffer:compressed:)`.
    /// - Returns: On success, the buffers produced by the writer. On failure,
    ///   `.cardinalityViolation` if the state was already `.notWriting`, or
    ///   `.serializationFailed` if the writer threw an error.
    mutating func write(
      _ message: ByteBuffer,
      compressed: Bool
    ) -> Result<(ByteBuffer, ByteBuffer?), MessageWriteError> {
      guard case .writing(let arity, let contentType, var messageWriter) = self else {
        return .failure(.cardinalityViolation)
      }

      // Move to `.notWriting` while we work with the local writer; this is also the state we
      // must end up in should the write fail.
      self = .notWriting

      let written: (ByteBuffer, ByteBuffer?)
      do {
        written = try messageWriter.write(buffer: message, compressed: compressed)
      } catch {
        return .failure(.serializationFailed)
      }

      switch arity {
      case .one:
        // Only one message was expected, so the stream is no longer writable: stay in
        // `.notWriting`.
        break

      case .many:
        self = .writing(arity, contentType, messageWriter)
      }

      return .success(written)
    }
  }
  /// The ways in which writing a message may fail.
  enum MessageWriteError: Error {
    /// More messages were written than the stream permits.
    case cardinalityViolation

    /// The message could not be serialized.
    case serializationFailed

    /// An invalid state was encountered. This is a serious implementation error.
    case invalidState
  }
  /// Holds everything needed to build a `ReadState` once the compression used on the stream
  /// is known.
  struct PendingReadState {
    /// How many messages we expect to read from the stream.
    var arity: MessageArity

    /// The message encoding configuration, and whether it's enabled or not.
    var messageEncoding: ClientMessageEncoding

    /// Builds a `ReadState` which is ready to read messages.
    ///
    /// A decompressing reader is created only when message encoding is enabled *and* a
    /// compression algorithm is provided; in every other case a plain reader is used.
    ///
    /// - Parameter compression: The compression algorithm to read with, if any. Defaults to
    ///   `nil`.
    /// - Returns: A `ReadState` in the `.reading` state carrying this value's arity and a
    ///   newly created reader.
    func makeReadState(compression: CompressionAlgorithm? = nil) -> ReadState {
      guard case let .enabled(configuration) = self.messageEncoding,
            let algorithm = compression
      else {
        // Encoding is disabled or no algorithm was given: read without decompression.
        return .reading(self.arity, LengthPrefixedMessageReader())
      }

      let decompressingReader = LengthPrefixedMessageReader(
        compression: algorithm,
        decompressionLimit: configuration.decompressionLimit
      )
      return .reading(self.arity, decompressingReader)
    }
  }
  /// The read state of a stream.
  enum ReadState {
    /// Messages may be read using the associated reader.
    case reading(MessageArity, LengthPrefixedMessageReader)

    /// No further reads may be attempted: either an earlier read failed or it is not valid
    /// for any more messages to be read.
    case notReading

    /// Appends `buffer` to the reader and then reads as many length-prefixed serialized
    /// messages as are available.
    ///
    /// When the expected message count is `.one` this produces **at most** one message; once
    /// that message has been produced any later call results in an error.
    ///
    /// - Parameters:
    ///   - buffer: The buffer to read from.
    ///   - maxLength: Forwarded to the reader's `nextMessage(maxLength:)`.
    /// - Returns: On success, the messages which were read; this may be empty because a
    ///   payload may be split across frames. On failure, a `MessageReadError` describing what
    ///   went wrong, in which case the state is left as `.notReading`.
    mutating func readMessages(
      _ buffer: inout ByteBuffer,
      maxLength: Int
    ) -> Result<[ByteBuffer], MessageReadError> {
      guard case .reading(let arity, var messageReader) = self else {
        return .failure(.cardinalityViolation)
      }

      // Avoid CoWs: don't keep the reader in `self` while the local copy is mutated. This is
      // also the state every failure path below must leave us in.
      self = .notReading
      messageReader.append(buffer: &buffer)

      var decoded: [ByteBuffer] = []
      do {
        while let next = try messageReader.nextMessage(maxLength: maxLength) {
          decoded.append(next)
        }
      } catch {
        // Only errors wrapped with context can be mapped to a more specific failure.
        guard let wrapped = error as? GRPCError.WithContext else {
          return .failure(.deserializationFailed)
        }
        if let limitExceeded = wrapped.error as? GRPCError.DecompressionLimitExceeded {
          return .failure(.decompressionLimitExceeded(limitExceeded.compressedSize))
        }
        if let tooLong = wrapped.error as? GRPCError.PayloadLengthLimitExceeded {
          return .failure(.lengthExceedsLimit(tooLong))
        }
        return .failure(.deserializationFailed)
      }

      // Validate the number of decoded messages against what the stream allows.
      switch arity {
      case .many:
        // Any number of messages, including zero, is acceptable.
        self = .reading(arity, messageReader)
        return .success(decoded)

      case .one:
        switch decoded.count {
        case 0:
          // Nothing yet: the payload may be split across frames, so keep reading.
          self = .reading(arity, messageReader)
          return .success(decoded)

        case 1:
          // We can't read more than one message on a unary stream, so remain in
          // `.notReading`. There must be no bytes left over and no partially read message.
          guard messageReader.unprocessedBytes == 0, !messageReader.isReading else {
            return .failure(.leftOverBytes)
          }
          return .success(decoded)

        default:
          // More than one message where only one was expected.
          return .failure(.cardinalityViolation)
        }
      }
    }
  }
  /// The ways in which reading messages may fail.
  enum MessageReadError: Error, Equatable {
    /// Too many messages were read, or a read was attempted when reading was no longer
    /// permitted.
    case cardinalityViolation

    /// Enough messages were read but there are left-over bytes.
    case leftOverBytes

    /// The message could not be deserialized.
    case deserializationFailed

    /// The limit for decompression was exceeded. The associated value is the compressed size
    /// reported by the underlying error.
    case decompressionLimitExceeded(Int)

    /// The length of the message exceeded the permitted maximum length.
    case lengthExceedsLimit(GRPCError.PayloadLengthLimitExceeded)

    /// An invalid state was encountered. This is a serious implementation error.
    case invalidState
  }