InProcessClientTransport.swift 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. /*
  2. * Copyright 2023, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import GRPCCore
  /// An in-process implementation of a ``ClientTransport``.
  ///
  /// This is useful when you're interested in testing your application without any actual networking layers
  /// involved, as the client and server will communicate directly with each other via in-process streams.
  ///
  /// To use this client, you'll have to provide an ``InProcessServerTransport`` upon creation. You may also
  /// provide a ``MethodConfigurations`` collection containing per-method ``MethodConfiguration``s for your
  /// transport; if omitted, an empty collection is used.
  ///
  /// Once you have a client, you must keep a long-running task executing ``connect(lazily:)``, which
  /// will return only once all streams have been finished and ``close()`` has been called on this client; or
  /// when the containing task is cancelled.
  ///
  /// To execute requests using this client, use ``withStream(descriptor:_:)``. If this function is
  /// called before ``connect(lazily:)`` is called, then any streams will remain pending and the call will
  /// block until ``connect(lazily:)`` is called, ``close()`` is called (in which case the call throws), or
  /// the task is cancelled.
  ///
  /// - SeeAlso: ``ClientTransport``
  @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  public struct InProcessClientTransport: ClientTransport {
    private enum State: Sendable {
      /// Open streams keyed by stream ID; each value pairs the client-side stream with its
      /// server-side counterpart.
      typealias OpenStreams = [Int: (
        RPCStream<Inbound, Outbound>,
        RPCStream<RPCAsyncSequence<RPCRequestPart>, RPCWriter<RPCResponsePart>.Closable>
      )]

      struct UnconnectedState {
        var serverTransport: InProcessServerTransport
        /// One continuation per `withStream` call that is waiting for the transport to leave the
        /// unconnected state. Finishing a continuation releases its waiter.
        var pendingStreams: [AsyncStream<Void>.Continuation]

        init(serverTransport: InProcessServerTransport) {
          self.serverTransport = serverTransport
          self.pendingStreams = []
        }
      }

      struct ConnectedState {
        var serverTransport: InProcessServerTransport
        var nextStreamID: Int
        var openStreams: OpenStreams
        /// Finishing this continuation ends the `for await` loop in `connect(lazily:)`.
        var signalEndContinuation: AsyncStream<Void>.Continuation

        init(
          fromUnconnected state: UnconnectedState,
          signalEndContinuation: AsyncStream<Void>.Continuation
        ) {
          self.serverTransport = state.serverTransport
          self.nextStreamID = 0
          self.openStreams = [:]
          self.signalEndContinuation = signalEndContinuation
        }
      }

      struct ClosedState {
        var openStreams: OpenStreams
        /// Present only when the transport was closed while streams were still open; it is
        /// finished once the last of those streams is removed.
        var signalEndContinuation: AsyncStream<Void>.Continuation?

        init() {
          self.openStreams = [:]
          self.signalEndContinuation = nil
        }

        init(fromConnected state: ConnectedState) {
          self.openStreams = state.openStreams
          self.signalEndContinuation = state.signalEndContinuation
        }
      }

      case unconnected(UnconnectedState)
      case connected(ConnectedState)
      case closed(ClosedState)
    }

    public typealias Inbound = RPCAsyncSequence<RPCResponsePart>
    public typealias Outbound = RPCWriter<RPCRequestPart>.Closable

    /// The retry throttle for this transport, created with `maximumTokens: 10` and `tokenRatio: 0.1`.
    public let retryThrottle: RetryThrottle

    private let methodConfiguration: MethodConfigurations
    private let state: _LockedValueBox<State>

    /// Creates a new in-process client transport.
    ///
    /// - Parameters:
    ///   - server: The server transport that streams opened by this client are handed to.
    ///   - methodConfiguration: Per-method configuration returned by
    ///     ``executionConfiguration(forMethod:)``. Defaults to an empty collection.
    public init(
      server: InProcessServerTransport,
      methodConfiguration: MethodConfigurations = MethodConfigurations()
    ) {
      self.retryThrottle = RetryThrottle(maximumTokens: 10, tokenRatio: 0.1)
      self.methodConfiguration = methodConfiguration
      self.state = _LockedValueBox(.unconnected(.init(serverTransport: server)))
    }

    /// Establish and maintain a connection to the remote destination.
    ///
    /// Maintains a long-lived connection, or set of connections, to a remote destination.
    /// Connections may be added or removed over time as required by the implementation and the
    /// demand for streams by the client.
    ///
    /// Implementations of this function will typically create a long-lived task group which
    /// maintains connections. The function exits when all open streams have been closed and new connections
    /// are no longer required by the caller who signals this by calling ``close()``, or by cancelling the
    /// task this function runs in.
    ///
    /// - Parameter lazily: This parameter is ignored in this implementation.
    /// - Throws: ``RPCError`` with code ``RPCError/Code/failedPrecondition`` if the transport is
    ///   already connected or has been closed.
    public func connect(lazily: Bool) async throws {
      let (stream, continuation) = AsyncStream<Void>.makeStream()
      try self.state.withLockedValue { state in
        switch state {
        case .unconnected(let unconnectedState):
          state = .connected(
            .init(
              fromUnconnected: unconnectedState,
              signalEndContinuation: continuation
            )
          )
          // Release every `withStream` call that was queued waiting for a connection.
          for pendingStream in unconnectedState.pendingStreams {
            pendingStream.finish()
          }
        case .connected:
          throw RPCError(
            code: .failedPrecondition,
            message: "Already connected to server."
          )
        case .closed:
          throw RPCError(
            code: .failedPrecondition,
            message: "Can't connect to server, transport is closed."
          )
        }
      }

      for await _ in stream {
        // This for-await loop will exit (and thus `connect(lazily:)` will return)
        // only when the task is cancelled, or when the stream's continuation is
        // finished - whichever happens first.
        // The continuation will be finished when `close()` is called and there
        // are no more open streams.
      }

      // If at this point there are any open streams, it's because cancellation
      // occurred and all open streams must now be closed.
      let openStreams = self.state.withLockedValue { state in
        switch state {
        case .unconnected:
          // We have transitioned to connected, and we can't transition back.
          fatalError("Invalid state")
        case .connected(let connectedState):
          state = .closed(.init())
          return connectedState.openStreams.values
        case .closed(let closedState):
          return closedState.openStreams.values
        }
      }

      for (clientStream, serverStream) in openStreams {
        clientStream.outbound.finish(throwing: CancellationError())
        serverStream.outbound.finish(throwing: CancellationError())
      }
    }

    /// Signal to the transport that no new streams may be created.
    ///
    /// Existing streams may run to completion naturally but calling ``withStream(descriptor:_:)``
    /// will result in an ``RPCError`` with code ``RPCError/Code/failedPrecondition`` being thrown.
    /// This also applies to calls that were queued waiting for ``connect(lazily:)``: they are released
    /// and fail with the same error.
    ///
    /// If you want to forcefully cancel all active streams then cancel the task running ``connect(lazily:)``.
    public func close() {
      let continuationsToFinish: [AsyncStream<Void>.Continuation] = self.state.withLockedValue { state in
        switch state {
        case .unconnected(let unconnectedState):
          state = .closed(.init())
          // `connect(lazily:)` can no longer succeed, so queued `withStream` calls would otherwise
          // wait until cancelled. Releasing them lets them observe `.closed` and throw.
          return unconnectedState.pendingStreams
        case .connected(let connectedState):
          if connectedState.openStreams.isEmpty {
            state = .closed(.init())
            return [connectedState.signalEndContinuation]
          } else {
            // Keep the end signal around: the last stream to finish will fire it.
            state = .closed(.init(fromConnected: connectedState))
            return []
          }
        case .closed:
          return []
        }
      }

      // Finish outside the lock, mirroring how `withStream` signals the end continuation.
      for continuation in continuationsToFinish {
        continuation.finish()
      }
    }

    /// Opens a stream using the transport, and uses it as input into a user-provided closure.
    ///
    /// - Important: The opened stream is closed after the closure is finished.
    ///
    /// This transport implementation throws ``RPCError/Code/failedPrecondition`` if the transport
    /// is closing or has been closed.
    ///
    /// This implementation will queue any streams (and thus block this call) if this function is called before
    /// ``connect(lazily:)``, until a connection is established - at which point all streams will be
    /// created - or until ``close()`` is called, in which case the call throws.
    ///
    /// - Parameters:
    ///   - descriptor: A description of the method to open a stream for.
    ///   - closure: A closure that takes the opened stream as parameter.
    /// - Returns: Whatever value was returned from `closure`.
    /// - Throws: ``RPCError`` if the transport is closed or the server rejects the stream,
    ///   `CancellationError` if the task is cancelled while waiting for a connection, or
    ///   whatever `closure` throws.
    public func withStream<T>(
      descriptor: MethodDescriptor,
      _ closure: (RPCStream<Inbound, Outbound>) async throws -> T
    ) async throws -> T {
      let request = RPCAsyncSequence<RPCRequestPart>._makeBackpressuredStream(watermarks: (16, 32))
      let response = RPCAsyncSequence<RPCResponsePart>._makeBackpressuredStream(watermarks: (16, 32))
      let clientStream = RPCStream(
        descriptor: descriptor,
        inbound: response.stream,
        outbound: request.writer
      )
      let serverStream = RPCStream(
        descriptor: descriptor,
        inbound: request.stream,
        outbound: response.writer
      )

      let waitForConnectionStream: AsyncStream<Void>? = self.state.withLockedValue { state in
        if case .unconnected(var unconnectedState) = state {
          let (stream, continuation) = AsyncStream<Void>.makeStream()
          unconnectedState.pendingStreams.append(continuation)
          state = .unconnected(unconnectedState)
          return stream
        }
        return nil
      }

      if let waitForConnectionStream {
        for await _ in waitForConnectionStream {
          // This loop will exit when the task is cancelled, or when the client
          // connects or is closed (both of which finish the pending continuation).
        }
        try Task.checkCancellation()
      }

      let streamID = try self.state.withLockedValue { state in
        switch state {
        case .unconnected:
          // The state cannot be unconnected: the pending continuation is only finished
          // when leaving the unconnected state (via `connect(lazily:)` or `close()`).
          // The only other way out of the loop above is cancellation, which is
          // checked right after the loop.
          fatalError("Invalid state.")
        case .connected(var connectedState):
          let streamID = connectedState.nextStreamID
          do {
            try connectedState.serverTransport.acceptStream(serverStream)
            connectedState.openStreams[streamID] = (clientStream, serverStream)
            connectedState.nextStreamID += 1
            state = .connected(connectedState)
          } catch let acceptStreamError as RPCError {
            serverStream.outbound.finish(throwing: acceptStreamError)
            clientStream.outbound.finish(throwing: acceptStreamError)
            throw acceptStreamError
          } catch {
            serverStream.outbound.finish(throwing: error)
            clientStream.outbound.finish(throwing: error)
            throw RPCError(code: .unknown, message: "Unknown error: \(error).")
          }
          return streamID
        case .closed:
          let error = RPCError(
            code: .failedPrecondition,
            message: "The client transport is closed."
          )
          serverStream.outbound.finish(throwing: error)
          clientStream.outbound.finish(throwing: error)
          throw error
        }
      }

      defer {
        clientStream.outbound.finish()

        let maybeEndContinuation = self.state.withLockedValue { state in
          switch state {
          case .unconnected:
            // The state cannot be unconnected at this point, because if we made
            // it this far, it's because the transport was connected.
            // Once connected, it's impossible to transition back to unconnected,
            // so this is an invalid state.
            fatalError("Invalid state")
          case .connected(var connectedState):
            connectedState.openStreams.removeValue(forKey: streamID)
            state = .connected(connectedState)
          case .closed(var closedState):
            closedState.openStreams.removeValue(forKey: streamID)
            state = .closed(closedState)
            if closedState.openStreams.isEmpty {
              // This was the last open stream: signal the closure of the client.
              return closedState.signalEndContinuation
            }
          }
          return nil
        }
        maybeEndContinuation?.finish()
      }

      return try await closure(clientStream)
    }

    /// Returns the execution configuration for a given method.
    ///
    /// - Parameter descriptor: The method to lookup configuration for.
    /// - Returns: Execution configuration for the method, if it exists.
    public func executionConfiguration(
      forMethod descriptor: MethodDescriptor
    ) -> MethodConfiguration? {
      self.methodConfiguration[descriptor]
    }
  }