ReadWriteStates.swift 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. /*
  2. * Copyright 2019, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import NIOCore
  17. import SwiftProtobuf
  /// Describes how many messages are expected on a stream: `one` permits at most a single
  /// message, while `many` permits any number of messages.
  enum MessageArity {
    case one, many
  }
  /// Everything needed to build a `WriteState` for a stream.
  struct PendingWriteState {
    /// How many messages the stream expects to write.
    var arity: MessageArity

    /// The 'content-type' of the messages being written.
    var contentType: ContentType

    /// Builds the `WriteState` for this stream.
    ///
    /// Outbound compression is taken from the encoding configuration only when message encoding
    /// is enabled; otherwise the writer is created without a compression algorithm.
    ///
    /// - Parameters:
    ///   - messageEncoding: The client's message encoding configuration.
    ///   - allocator: Passed to the message writer.
    /// - Returns: A new `WriteState` using this state's arity and content type.
    func makeWriteState(
      messageEncoding: ClientMessageEncoding,
      allocator: ByteBufferAllocator
    ) -> WriteState {
      var outboundCompression: CompressionAlgorithm? = nil
      if case let .enabled(encodingConfiguration) = messageEncoding {
        outboundCompression = encodingConfiguration.outbound
      }

      return WriteState(
        arity: self.arity,
        contentType: self.contentType,
        writer: CoalescingLengthPrefixedMessageWriter(
          compression: outboundCompression,
          allocator: allocator
        )
      )
    }
  }
  /// The write state of a stream: appends outbound messages to a writer and enforces the
  /// stream's expected message arity.
  struct WriteState {
    /// The number of messages the stream expects to write.
    private var arity: MessageArity
    /// The 'content-type' being written.
    private var contentType: ContentType
    /// The writer which `write(_:compressed:promise:)` appends to and `next()` drains.
    private var writer: CoalescingLengthPrefixedMessageWriter
    /// Whether another message may be written. Cleared after the first successful write when
    /// `arity` is `.one`.
    private var canWrite: Bool

    init(
      arity: MessageArity,
      contentType: ContentType,
      writer: CoalescingLengthPrefixedMessageWriter
    ) {
      self.arity = arity
      self.contentType = contentType
      self.writer = writer
      self.canWrite = true
    }

    /// Appends a serialized message to the `writer`.
    ///
    /// When `arity` is `.one`, only the first call succeeds. Later calls return
    /// `.failure(.cardinalityViolation)` without touching the writer, and in that case `promise`
    /// is neither stored nor completed by this method.
    ///
    /// - Parameters:
    ///   - message: The serialized message bytes to write.
    ///   - compressed: Forwarded to the writer as its `compress` flag for this message.
    ///   - promise: Forwarded to the writer along with the message.
    /// - Returns: `.success(())` if the message was appended, otherwise
    ///   `.failure(.cardinalityViolation)`.
    mutating func write(
      _ message: ByteBuffer,
      compressed: Bool,
      promise: EventLoopPromise<Void>?
    ) -> Result<Void, MessageWriteError> {
      guard self.canWrite else {
        // NOTE(review): the caller presumably needs to fail `promise` in this case — confirm.
        return .failure(.cardinalityViolation)
      }

      self.writer.append(buffer: message, compress: compressed, promise: promise)

      switch self.arity {
      case .one:
        // A single-message stream may not be written to again.
        self.canWrite = false
      case .many:
        break
      }

      return .success(())
    }

    /// Returns the next item produced by the `writer`, if any.
    ///
    /// - Returns: `nil` if the writer has nothing to emit. Otherwise the writer's result is
    ///   returned with any writer error mapped to `.serializationFailed`, paired with the promise
    ///   the writer returned alongside it.
    mutating func next() -> (Result<ByteBuffer, MessageWriteError>, EventLoopPromise<Void>?)? {
      guard let next = self.writer.next() else {
        return nil
      }
      // The writer's underlying error is discarded; callers only see `.serializationFailed`.
      return (next.0.mapError { _ in .serializationFailed }, next.1)
    }
  }
  /// Errors which may occur while writing messages to a stream.
  enum MessageWriteError: Error {
    /// More messages were written than the stream's arity permits.
    case cardinalityViolation

    /// A message could not be serialized.
    case serializationFailed

    /// The state machine reached a state it should never be in; this indicates an
    /// implementation bug.
    case invalidState
  }
  /// Everything needed to build a `ReadState` for a stream.
  struct PendingReadState {
    /// How many messages the stream expects to read.
    var arity: MessageArity

    /// The message encoding configuration, including whether encoding is enabled.
    var messageEncoding: ClientMessageEncoding

    /// Builds the `ReadState` for this stream.
    ///
    /// The reader is configured with `compression` and the configured decompression limit only
    /// when message encoding is enabled *and* an algorithm is supplied. In every other case the
    /// reader is created with its default initializer.
    ///
    /// - Parameter compression: The compression algorithm to configure the reader with, if any.
    /// - Returns: A `.reading` state carrying this state's arity.
    func makeReadState(compression: CompressionAlgorithm? = nil) -> ReadState {
      guard case let .enabled(configuration) = self.messageEncoding,
            let algorithm = compression
      else {
        return .reading(self.arity, LengthPrefixedMessageReader())
      }

      let decompressingReader = LengthPrefixedMessageReader(
        compression: algorithm,
        decompressionLimit: configuration.decompressionLimit
      )
      return .reading(self.arity, decompressingReader)
    }
  }
  /// The read state of a stream.
  enum ReadState {
    /// Reading may be attempted using the given reader. The arity bounds how many messages may
    /// be read in total.
    case reading(MessageArity, LengthPrefixedMessageReader)

    /// Reading may not be attempted: either a read previously failed or it is not valid for any
    /// more messages to be read.
    case notReading

    /// Consumes `buffer`, then reads as many length-prefixed serialized messages as the reader
    /// can produce.
    ///
    /// For an arity of `.one`, this function will produce **at most** one message. Once a message
    /// has been produced the state becomes `.notReading`, so subsequent calls fail with
    /// `.cardinalityViolation`. Any failure also leaves the state as `.notReading`.
    ///
    /// - Parameters:
    ///   - buffer: The bytes to read from; passed `inout` to the reader's `append(buffer:)`.
    ///   - maxLength: The maximum message length, forwarded to the reader's
    ///     `nextMessage(maxLength:)`.
    /// - Returns: The messages read. The array may be empty, because a message may be split
    ///   across several buffers. On failure:
    ///   - `.cardinalityViolation` if reading is not permitted, or if more than one message was
    ///     read for an arity of `.one`;
    ///   - `.leftOverBytes` if the single message of a `.one` stream was followed by unprocessed
    ///     bytes or a partially read message;
    ///   - `.decompressionLimitExceeded` or `.lengthExceedsLimit` if the reader threw the
    ///     corresponding `GRPCError` wrapped in `GRPCError.WithContext`;
    ///   - `.deserializationFailed` for any other error thrown by the reader.
    mutating func readMessages(
      _ buffer: inout ByteBuffer,
      maxLength: Int
    ) -> Result<[ByteBuffer], MessageReadError> {
      guard case .reading(let readArity, var reader) = self else {
        return .failure(.cardinalityViolation)
      }

      // Avoid CoWs: don't keep a second copy of `reader` alive in `self` while mutating it.
      // This also means every early return below leaves the state as `.notReading`.
      self = .notReading

      reader.append(buffer: &buffer)

      var messages: [ByteBuffer] = []
      do {
        while let serializedBytes = try reader.nextMessage(maxLength: maxLength) {
          messages.append(serializedBytes)
        }
      } catch {
        return .failure(ReadState.readError(from: error))
      }

      // Validate the number of messages decoded. Zero is always fine because the payload may be
      // split across frames.
      switch (readArity, messages.count) {
      case (.many, _), (.one, 0):
        // More messages may follow: resume reading with the updated reader.
        self = .reading(readArity, reader)
        return .success(messages)

      case (.one, 1):
        // We can't read more than one message on a unary stream, so the state stays
        // `.notReading`. There must be no leftover bytes and no partially read message.
        if reader.unprocessedBytes != 0 || reader.isReading {
          return .failure(.leftOverBytes)
        }
        return .success(messages)

      case (.one, _):
        // More than one message on a `.one` stream.
        return .failure(.cardinalityViolation)
      }
    }

    /// Maps an error thrown by the reader to the corresponding `MessageReadError`.
    ///
    /// Limit errors are recognised only when wrapped in `GRPCError.WithContext`; anything else
    /// becomes `.deserializationFailed`.
    private static func readError(from error: Error) -> MessageReadError {
      guard let contextualError = error as? GRPCError.WithContext else {
        return .deserializationFailed
      }

      if let limit = contextualError.error as? GRPCError.DecompressionLimitExceeded {
        return .decompressionLimitExceeded(limit.compressedSize)
      } else if let limit = contextualError.error as? GRPCError.PayloadLengthLimitExceeded {
        return .lengthExceedsLimit(limit)
      } else {
        return .deserializationFailed
      }
    }
  }
  /// Errors which may occur while reading messages from a stream.
  enum MessageReadError: Error, Equatable {
    /// Too many messages were read, or a read was attempted when no more messages are permitted.
    case cardinalityViolation

    /// Enough messages were read but there are left-over bytes.
    case leftOverBytes

    /// Message deserialization failed.
    case deserializationFailed

    /// The limit for decompression was exceeded. The associated value is the compressed size;
    /// `ReadState.readMessages` takes it from `GRPCError.DecompressionLimitExceeded.compressedSize`.
    case decompressionLimitExceeded(Int)

    /// The length of the message exceeded the permitted maximum length.
    case lengthExceedsLimit(GRPCError.PayloadLengthLimitExceeded)

    /// An invalid state was encountered. This is a serious implementation error.
    case invalidState
  }