ReadWriteStates.swift 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. /*
  2. * Copyright 2019, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import NIO
  17. import SwiftProtobuf
  /// How many messages a stream is expected to carry.
  enum MessageArity {
    /// Exactly one message; e.g. a unary stream.
    case one

    /// Any number of messages.
    case many
  }
  /// The information needed to create a `WriteState` once the message encoding is known.
  struct PendingWriteState {
    /// How many messages are expected to be written to the stream.
    var arity: MessageArity

    /// The 'content-type' which will be written.
    var contentType: ContentType

    /// Creates a `.writing` state for this stream's arity and content-type.
    ///
    /// The writer is configured with the outbound compression algorithm from `messageEncoding`
    /// when message encoding is enabled, and with no compression algorithm (`nil`) otherwise.
    ///
    /// - Parameter messageEncoding: The client's message encoding configuration.
    /// - Returns: A `WriteState.writing` state holding a newly created writer.
    func makeWriteState(messageEncoding: ClientMessageEncoding) -> WriteState {
      let outboundCompression: CompressionAlgorithm?
      if case let .enabled(configuration) = messageEncoding {
        outboundCompression = configuration.outbound
      } else {
        outboundCompression = nil
      }

      return .writing(
        self.arity,
        self.contentType,
        LengthPrefixedMessageWriter(compression: outboundCompression)
      )
    }
  }
  /// The write state of a stream.
  enum WriteState {
    /// Writing may be attempted using the given writer.
    case writing(MessageArity, ContentType, LengthPrefixedMessageWriter)

    /// Writing may not be attempted: either an earlier write failed or no further messages may
    /// be written on this stream.
    case notWriting

    /// Serializes `message` into a newly allocated buffer using this state's writer.
    ///
    /// After a successful write a `.one` state moves to `.notWriting`, while a `.many` state stays
    /// `.writing`. A failed write always moves the state to `.notWriting`.
    ///
    /// - Parameters:
    ///   - message: The message to write.
    ///   - compressed: Forwarded to the writer as its `compressed:` argument.
    ///   - allocator: Allocates the `ByteBuffer` the message is written into.
    /// - Returns: The buffer holding the written message; `.cardinalityViolation` if writing is not
    ///   permitted in the current state; or `.serializationFailed` if the writer threw.
    mutating func write(
      _ message: GRPCPayload,
      compressed: Bool,
      allocator: ByteBufferAllocator
    ) -> Result<ByteBuffer, MessageWriteError> {
      guard case let .writing(arity, contentType, writer) = self else {
        return .failure(.cardinalityViolation)
      }

      // An initial capacity of zero is fine: the writer reserves whatever space it needs.
      var output = allocator.buffer(capacity: 0)
      do {
        try writer.write(message, into: &output, compressed: compressed)
      } catch {
        self = .notWriting
        return .failure(.serializationFailed)
      }

      // A stream expecting a single message can't be written to again.
      switch arity {
      case .one:
        self = .notWriting
      case .many:
        self = .writing(arity, contentType, writer)
      }

      return .success(output)
    }
  }
  /// Errors which may occur while writing a message with `WriteState`.
  enum MessageWriteError: Error {
    /// A write was attempted while writing was not permitted, e.g. after the single message of a
    /// `.one` stream was written.
    case cardinalityViolation

    /// The message could not be serialized.
    case serializationFailed

    /// The state machine was in an unexpected state; this indicates a serious implementation error.
    case invalidState
  }
  /// The information needed to create a `ReadState` once the inbound compression is known.
  struct PendingReadState {
    /// How many messages are expected to be read from the stream.
    var arity: MessageArity

    /// The message encoding configuration, and whether it's enabled or not.
    var messageEncoding: ClientMessageEncoding

    /// Creates a `.reading` state for this stream's arity.
    ///
    /// The reader is created with `compression` and the configured decompression limit only when
    /// message encoding is enabled AND a compression algorithm is supplied. In every other case
    /// a reader created with no arguments is used.
    ///
    /// - Parameter compression: The compression algorithm of inbound messages, if any.
    /// - Returns: A `ReadState.reading` state holding a newly created reader.
    func makeReadState(compression: CompressionAlgorithm? = nil) -> ReadState {
      guard
        case let .enabled(configuration) = self.messageEncoding,
        let algorithm = compression
      else {
        return .reading(self.arity, LengthPrefixedMessageReader())
      }

      let decompressingReader = LengthPrefixedMessageReader(
        compression: algorithm,
        decompressionLimit: configuration.decompressionLimit
      )
      return .reading(self.arity, decompressingReader)
    }
  }
  /// The read state of a stream.
  enum ReadState {
    /// Reading may be attempted using the given reader.
    case reading(MessageArity, LengthPrefixedMessageReader)

    /// Reading may not be attempted: either a read previously failed or it is not valid for any
    /// more messages to be read.
    case notReading

    /// Consume the given `buffer` then attempt to read and subsequently decode length-prefixed
    /// serialized messages.
    ///
    /// For an expected message count of `.one`, this function will produce **at most** 1 message. If
    /// a message has been produced then subsequent calls will result in an error.
    ///
    /// - Parameters:
    ///   - buffer: The buffer to read from; it is handed to the reader via `append(buffer:)`.
    ///   - as: The type of message to decode.
    /// - Returns: The messages decoded by this call. There may be none, because a message may be
    ///   split across several buffers. Otherwise returns a `MessageReadError`:
    ///   - `.cardinalityViolation` if reading is not permitted or too many messages arrived;
    ///   - `.leftOverBytes` if bytes remain after the single message of a `.one` stream;
    ///   - `.decompressionLimitExceeded` or `.deserializationFailed` if decoding failed.
    ///   Every failure leaves the state as `.notReading`.
    mutating func readMessages<MessageType: GRPCPayload>(
      _ buffer: inout ByteBuffer,
      as: MessageType.Type = MessageType.self
    ) -> Result<[MessageType], MessageReadError> {
      guard case .reading(let readArity, var reader) = self else {
        return .failure(.cardinalityViolation)
      }

      // `reader` is a copy of the reader stored in `self`, and every path below re-assigns `self`.
      // Dropping the stored copy first keeps `reader` uniquely referenced while it is mutated.
      // If the reader buffers bytes in copy-on-write storage (presumably a `ByteBuffer`; TODO
      // confirm), this avoids duplicating those bytes on every call.
      self = .notReading

      reader.append(buffer: &buffer)

      var messages: [MessageType] = []
      do {
        while var serializedBytes = try reader.nextMessage() {
          messages.append(try MessageType(serializedByteBuffer: &serializedBytes))
        }
      } catch {
        // A failed read is terminal.
        self = .notReading
        if let grpcError = error as? GRPCError.WithContext,
          let limitExceeded = grpcError.error as? GRPCError.DecompressionLimitExceeded {
          return .failure(.decompressionLimitExceeded(limitExceeded.compressedSize))
        } else {
          return .failure(.deserializationFailed)
        }
      }

      // Validate the number of messages decoded. Zero is always fine because the payload may be
      // split across frames.
      switch (readArity, messages.count) {
      case (.one, 0), (.many, _):
        self = .reading(readArity, reader)
        return .success(messages)

      case (.one, 1):
        // We can't read more than one message on a unary stream.
        self = .notReading
        // There must be no bytes left over after the single message, and no partially read one.
        if reader.unprocessedBytes != 0 || reader.isReading {
          return .failure(.leftOverBytes)
        } else {
          return .success(messages)
        }

      default:
        // More than one message on a `.one` stream.
        self = .notReading
        return .failure(.cardinalityViolation)
      }
    }
  }
  /// Errors which may occur while reading messages with `ReadState`.
  enum MessageReadError: Error, Equatable {
    /// Reading was not permitted, or more messages were read than the stream's arity allows.
    case cardinalityViolation

    /// The expected number of messages was read, but there are left-over bytes.
    case leftOverBytes

    /// A message could not be deserialized.
    case deserializationFailed

    /// The decompression limit was exceeded. The associated value is the `compressedSize` carried
    /// by the underlying `DecompressionLimitExceeded` error.
    case decompressionLimitExceeded(Int)

    /// The state machine was in an unexpected state; this indicates a serious implementation error.
    case invalidState
  }