  /*
   * Copyright 2019, gRPC Authors All rights reserved.
   *
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   *     http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   * See the License for the specific language governing permissions and
   * limitations under the License.
   */
  import NIO
  import SwiftProtobuf
  /// The number of messages expected on a stream.
  ///
  /// - `one`: at most a single message; once it has been read or written the stream accepts no more.
  /// - `many`: any number of messages, including zero.
  enum MessageArity {
    case one, many
  }
  /// The configuration needed to build a fresh `WriteState` for a stream.
  struct PendingWriteState {
    /// How many messages the stream is expected to write.
    var arity: MessageArity

    /// The compression mechanism applied when writing messages.
    var compression: CompressionMechanism

    /// The 'content-type' of the messages being written.
    var contentType: ContentType

    /// Builds a `.writing` state backed by a new `LengthPrefixedMessageWriter` that uses
    /// `compression`.
    func makeWriteState() -> WriteState {
      let writer = LengthPrefixedMessageWriter(compression: compression)
      return WriteState.writing(arity, contentType, writer)
    }
  }
  /// The write state of a stream.
  enum WriteState {
    /// Writes are permitted and are framed by the associated writer.
    case writing(MessageArity, ContentType, LengthPrefixedMessageWriter)

    /// Writes are not permitted: either an earlier write failed or no further messages may be
    /// written to the stream.
    case notWriting

    /// Serializes `message` and writes it through the stream's writer into a freshly allocated
    /// buffer.
    ///
    /// The state moves to `.notWriting` after a serialization failure, or after a successful write
    /// on a stream whose arity is `.one`.
    ///
    /// - Parameter message: The `Message` to write.
    /// - Parameter allocator: Provides the `ByteBuffer` the message is written into.
    /// - Returns: The buffer containing the written message, or `.cardinalityViolation` if writing is
    ///   not permitted in the current state, or `.serializationFailed` if `message` could not be
    ///   serialized.
    mutating func write(
      _ message: Message,
      allocator: ByteBufferAllocator
    ) -> Result<ByteBuffer, MessageWriteError> {
      guard case let .writing(arity, contentType, writer) = self else {
        return .failure(.cardinalityViolation)
      }

      guard let payload = try? message.serializedData() else {
        self = .notWriting
        return .failure(.serializationFailed)
      }

      // A zero capacity is fine: the writer allocates the space it needs.
      var buffer = allocator.buffer(capacity: 0)
      writer.write(payload, into: &buffer)

      // A stream expecting a single message cannot be written to again.
      switch arity {
      case .one:
        self = .notWriting
      case .many:
        self = .writing(arity, contentType, writer)
      }

      return .success(buffer)
    }
  }
  /// Errors which may occur while writing a message to a stream.
  enum MessageWriteError: Error {
    /// A write was attempted when no more messages may be written, e.g. too many messages were
    /// written.
    case cardinalityViolation

    /// The message could not be serialized.
    case serializationFailed

    /// An invalid state was encountered; this indicates a serious implementation error.
    case invalidState
  }
  /// The read state of a stream.
  enum ReadState {
    /// Reading may be attempted using the given reader.
    case reading(MessageArity, LengthPrefixedMessageReader)

    /// Reading may not be attempted: either a read previously failed or it is not valid for any
    /// more messages to be read.
    case notReading

    /// Consume the given `buffer` then attempt to read and subsequently decode length-prefixed
    /// serialized messages.
    ///
    /// For an expected message count of `.one`, this function will produce **at most** 1 message. If
    /// a message has been produced then subsequent calls will result in an error.
    ///
    /// Any failure, whether the reader fails to produce a message or a message fails to decode,
    /// moves the state to `.notReading`.
    ///
    /// - Parameters:
    ///   - buffer: The buffer to read from; it is passed to the reader's `append(buffer:)`.
    ///   - as: The type of message to decode. This is usually inferred.
    /// - Returns: The messages decoded by this call, or a `MessageReadError`. An empty list is
    ///   valid because a message may be split across frames. The error cases are:
    ///   - `.cardinalityViolation` if reading is not permitted, or if a `.one` stream yields more
    ///     than one message.
    ///   - `.deserializationFailed` if the reader throws or a message cannot be decoded.
    ///   - `.leftOverBytes` if a `.one` stream has unprocessed or partially read bytes after its
    ///     message.
    mutating func readMessages<MessageType: Message>(
      _ buffer: inout ByteBuffer,
      as: MessageType.Type = MessageType.self
    ) -> Result<[MessageType], MessageReadError> {
      switch self {
      case .notReading:
        return .failure(.cardinalityViolation)

      case let .reading(readArity, reader):
        reader.append(buffer: &buffer)
        var messages: [MessageType] = []

        do {
          // Use `try`, not `try?`: an error from the reader must fail the read. Otherwise it would
          // silently end the loop and report success with a partial list of messages.
          while var serializedBytes = try reader.nextMessage() {
            // Force unwrapping is okay here: we will always be able to read `readableBytes`.
            let serializedData = serializedBytes.readData(length: serializedBytes.readableBytes)!
            messages.append(try MessageType(serializedData: serializedData))
          }
        } catch {
          self = .notReading
          return .failure(.deserializationFailed)
        }

        // Validate the number of messages decoded. Zero is always fine because the payload may
        // be split across frames.
        switch (readArity, messages.count) {
        case (.one, 0), (.many, _):
          self = .reading(readArity, reader)
          return .success(messages)

        case (.one, 1):
          // We can't read more than one message on a unary stream.
          self = .notReading
          // We shouldn't have any bytes left over after reading a single message, and we also
          // should not have partially read a message.
          if reader.unprocessedBytes != 0 || reader.isReading {
            return .failure(.leftOverBytes)
          } else {
            return .success(messages)
          }

        case (.one, _):
          // More than one message was decoded on a `.one` stream.
          self = .notReading
          return .failure(.cardinalityViolation)
        }
      }
    }
  }
  /// Errors which may occur while reading messages from a stream.
  enum MessageReadError: Error {
    /// A read was attempted when no more messages may be read, e.g. too many messages were read.
    case cardinalityViolation

    /// The expected number of messages was read, but there are left-over bytes.
    case leftOverBytes

    /// A message could not be deserialized.
    case deserializationFailed

    /// An invalid state was encountered; this indicates a serious implementation error.
    case invalidState
  }