ReadWriteStates.swift

/*
 * Copyright 2019, gRPC Authors All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
import NIO
import SwiftProtobuf

/// Number of messages expected on a stream.
enum MessageArity {
  case one
  case many
}

/// Encapsulates the state required to create a new write state.
struct PendingWriteState {
  /// The number of messages we expect to write to the stream.
  var arity: MessageArity

  /// The 'content-type' being written.
  var contentType: ContentType

  func makeWriteState(messageEncoding: ClientMessageEncoding) -> WriteState {
    let compression: CompressionAlgorithm?
    switch messageEncoding {
    case let .enabled(configuration):
      compression = configuration.outbound
    case .disabled:
      compression = nil
    }

    let writer = LengthPrefixedMessageWriter(compression: compression)
    return .writing(self.arity, self.contentType, writer)
  }
}
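
// A minimal usage sketch, assuming callers drive these states as the types above suggest;
// the helper name is illustrative only. The `contentType` parameter stands in for a real
// `ContentType` value, so no assumption is made about that enum's cases.
private func makeWriteStateSketch(contentType: ContentType) -> WriteState {
  // With `.disabled` encoding the writer applies no compression; an `.enabled`
  // configuration would supply its `outbound` algorithm instead.
  let pending = PendingWriteState(arity: .many, contentType: contentType)
  return pending.makeWriteState(messageEncoding: .disabled)
}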

/// The write state of a stream.
enum WriteState {
  /// Writing may be attempted using the given writer.
  case writing(MessageArity, ContentType, LengthPrefixedMessageWriter)

  /// Writing may not be attempted: either a write previously failed or it is not valid for any
  /// more messages to be written.
  case notWriting

  /// Writes a message into a buffer using the `writer` and `allocator`.
  ///
  /// - Parameter message: A buffer containing the serialized message to write.
  /// - Parameter compressed: Whether the message should be compressed.
  /// - Parameter allocator: An allocator to provide a `ByteBuffer` into which the message will be
  ///   written.
  mutating func write(
    _ message: ByteBuffer,
    compressed: Bool,
    allocator: ByteBufferAllocator
  ) -> Result<ByteBuffer, MessageWriteError> {
    switch self {
    case .notWriting:
      return .failure(.cardinalityViolation)

    case let .writing(writeArity, contentType, writer):
      let buffer: ByteBuffer
      do {
        buffer = try writer.write(buffer: message, allocator: allocator, compressed: compressed)
      } catch {
        self = .notWriting
        return .failure(.serializationFailed)
      }

      // If we only expect to write one message then we're no longer writable.
      if case .one = writeArity {
        self = .notWriting
      } else {
        self = .writing(writeArity, contentType, writer)
      }

      return .success(buffer)
    }
  }
}
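
// A minimal usage sketch, assuming a caller holds a mutable `WriteState`; the helper name
// is illustrative only. It shows a single write attempt and how each outcome maps onto the
// states above: on a `.one` stream a successful write flips the state to `.notWriting`, so
// a second attempt would fail with `.cardinalityViolation`.
private func writeSketch(
  state: inout WriteState,
  message: ByteBuffer,
  allocator: ByteBufferAllocator
) {
  switch state.write(message, compressed: false, allocator: allocator) {
  case let .success(framed):
    // `framed` contains the length-prefixed (and possibly compressed) bytes ready for the wire.
    print("framed \(framed.readableBytes) bytes")
  case let .failure(error):
    // `.cardinalityViolation` if the state was already `.notWriting`,
    // `.serializationFailed` if the underlying writer threw.
    print("write failed: \(error)")
  }
}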

enum MessageWriteError: Error {
  /// Too many messages were written.
  case cardinalityViolation

  /// Message serialization failed.
  case serializationFailed

  /// An invalid state was encountered. This is a serious implementation error.
  case invalidState
}

/// Encapsulates the state required to create a new read state.
struct PendingReadState {
  /// The number of messages we expect to read from the stream.
  var arity: MessageArity

  /// The message encoding configuration, and whether it's enabled or not.
  var messageEncoding: ClientMessageEncoding

  func makeReadState(compression: CompressionAlgorithm? = nil) -> ReadState {
    let reader: LengthPrefixedMessageReader
    switch (self.messageEncoding, compression) {
    case let (.enabled(configuration), .some(compression)):
      reader = LengthPrefixedMessageReader(
        compression: compression,
        decompressionLimit: configuration.decompressionLimit
      )

    case (.enabled, .none),
         (.disabled, _):
      reader = LengthPrefixedMessageReader()
    }
    return .reading(self.arity, reader)
  }
}
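
// A minimal usage sketch, assuming no compression was negotiated for the response; the
// helper name is illustrative only. With `compression: nil` the reader is created without
// decompression support, whatever the client's encoding configuration.
private func makeReadStateSketch() -> ReadState {
  let pending = PendingReadState(arity: .one, messageEncoding: .disabled)
  return pending.makeReadState(compression: nil)
}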

/// The read state of a stream.
enum ReadState {
  /// Reading may be attempted using the given reader.
  case reading(MessageArity, LengthPrefixedMessageReader)

  /// Reading may not be attempted: either a read previously failed or it is not valid for any
  /// more messages to be read.
  case notReading

  /// Consume the given `buffer` then attempt to read length-prefixed serialized messages.
  ///
  /// For an expected message count of `.one`, this function will produce **at most** 1 message. If
  /// a message has been produced then subsequent calls will result in an error.
  ///
  /// - Parameter buffer: The buffer to read from.
  mutating func readMessages(_ buffer: inout ByteBuffer) -> Result<[ByteBuffer], MessageReadError> {
    switch self {
    case .notReading:
      return .failure(.cardinalityViolation)

    case .reading(let readArity, var reader):
      reader.append(buffer: &buffer)
      var messages: [ByteBuffer] = []

      do {
        while let serializedBytes = try reader.nextMessage() {
          messages.append(serializedBytes)
        }
      } catch {
        self = .notReading
        if let grpcError = error as? GRPCError.WithContext,
           let limitExceeded = grpcError.error as? GRPCError.DecompressionLimitExceeded {
          return .failure(.decompressionLimitExceeded(limitExceeded.compressedSize))
        } else {
          return .failure(.deserializationFailed)
        }
      }

      // We need to validate the number of messages we decoded. Zero is fine because the payload may
      // be split across frames.
      switch (readArity, messages.count) {
      // Always allowed:
      case (.one, 0),
           (.many, 0...):
        self = .reading(readArity, reader)
        return .success(messages)

      // Also allowed, assuming we have no leftover bytes:
      case (.one, 1):
        // We can't read more than one message on a unary stream.
        self = .notReading
        // We shouldn't have any bytes leftover after reading a single message and we also should not
        // have partially read a message.
        if reader.unprocessedBytes != 0 || reader.isReading {
          return .failure(.leftOverBytes)
        } else {
          return .success(messages)
        }

      // Anything else must be invalid.
      default:
        self = .notReading
        return .failure(.cardinalityViolation)
      }
    }
  }
}
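
// A minimal usage sketch, assuming a caller feeds freshly received bytes into a mutable
// `ReadState`; the helper name is illustrative only. Each call appends the bytes to the
// reader and returns any complete length-prefixed messages; an empty array simply means a
// message is still split across frames.
private func readSketch(state: inout ReadState, received: inout ByteBuffer) {
  switch state.readMessages(&received) {
  case let .success(messages):
    print("decoded \(messages.count) message(s)")
  case let .failure(error):
    // For example `.cardinalityViolation` for a second message on a unary stream, or
    // `.decompressionLimitExceeded` if a payload blew past the configured limit.
    print("read failed: \(error)")
  }
}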

enum MessageReadError: Error, Equatable {
  /// Too many messages were read.
  case cardinalityViolation

  /// Enough messages were read but there are left-over bytes.
  case leftOverBytes

  /// Message deserialization failed.
  case deserializationFailed

  /// The limit for decompression was exceeded.
  case decompressionLimitExceeded(Int)

  /// An invalid state was encountered. This is a serious implementation error.
  case invalidState
}