ControlService.swift 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. /*
  2. * Copyright 2024, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import Foundation
  17. import GRPCCore
  /// Test service whose behaviour is driven entirely by the messages clients send to it.
  struct ControlService: RegistrableRPCService {
    /// Registers every `Control` method on `router`.
    ///
    /// - Parameter router: The router to add the handlers to.
    func registerMethods(with router: inout RPCRouter) {
      // All four call shapes share one implementation; `handle(request:)` decides what to send
      // back from the `ControlInput` messages it receives.
      let controlledMethods = ["Unary", "ServerStream", "ClientStream", "BidiStream"]
      for methodName in controlledMethods {
        router.registerHandler(
          forMethod: MethodDescriptor(fullyQualifiedService: "Control", method: methodName),
          deserializer: JSONDeserializer<ControlInput>(),
          serializer: JSONSerializer<ControlOutput>(),
          handler: { request, _ in
            try await self.handle(request: request)
          }
        )
      }

      router.registerHandler(
        forMethod: MethodDescriptor(fullyQualifiedService: "Control", method: "WaitForCancellation"),
        deserializer: JSONDeserializer<CancellationKind>(),
        serializer: JSONSerializer<Empty>(),
        handler: { request, context in
          // Collapse the stream into its single request message before choosing how to wait.
          let singleRequest = try await ServerRequest(stream: request)
          return try await self.waitForCancellation(request: singleRequest, context: context)
        }
      )

      router.registerHandler(
        forMethod: MethodDescriptor(fullyQualifiedService: "Control", method: "PeerInfo"),
        deserializer: JSONDeserializer<String>(),
        serializer: JSONSerializer<PeerInfoResponse>()
      ) { request, context in
        return StreamingServerResponse { writer in
          // Client-side view, as attached to the request metadata by the client.
          let clientView = PeerInfoResponse.PeerInfo(
            local: self.clientLocalPeerInfo(request: request),
            remote: self.clientRemotePeerInfo(request: request)
          )
          // Server-side view, taken from the server context.
          let serverView = PeerInfoResponse.PeerInfo(
            local: self.serverLocalPeerInfo(context: context),
            remote: self.serverRemotePeerInfo(context: context)
          )
          try await writer.write(PeerInfoResponse(client: clientView, server: serverView))
          return [:]
        }
      }
    }
  }
  extension ControlService {
    /// Serves `Control.WaitForCancellation`.
    ///
    /// The returned response's producer only completes once the RPC is cancelled. The mechanism
    /// used to observe cancellation is chosen by `request.message`.
    ///
    /// - Parameters:
    ///   - request: The single request message selecting the cancellation mechanism.
    ///   - context: The server context of the RPC.
    private func waitForCancellation(
      request: ServerRequest<CancellationKind>,
      context: ServerContext
    ) async throws -> StreamingServerResponse<Empty> {
      switch request.message {
      case .awaitCancelled:
        // Suspend on the context's cancellation handle.
        return StreamingServerResponse { _ in
          try await context.cancellation.cancelled
          return [:]
        }

      case .withCancellationHandler:
        // Park on a stream that is only ever finished from the RPC cancellation handler.
        let signal = AsyncStream.makeStream(of: Void.self)
        return StreamingServerResponse { _ in
          await withRPCCancellationHandler {
            for await _ in signal.stream {}
            return [:]
          } onCancelRPC: {
            signal.continuation.finish()
          }
        }
      }
    }

    /// Returns `context.remotePeer`, the server's view of its peer.
    private func serverRemotePeerInfo(context: ServerContext) -> String {
      context.remotePeer
    }

    /// Returns `context.localPeer`, the server's view of itself.
    private func serverLocalPeerInfo(context: ServerContext) -> String {
      context.localPeer
    }

    /// Returns the client's view of its peer, read from the "remotePeer" request metadata.
    private func clientRemotePeerInfo<T>(request: StreamingServerRequest<T>) -> String {
      self.firstMetadataString(forKey: "remotePeer", in: request)
    }

    /// Returns the client's view of itself, read from the "localPeer" request metadata.
    private func clientLocalPeerInfo<T>(request: StreamingServerRequest<T>) -> String {
      self.firstMetadataString(forKey: "localPeer", in: request)
    }

    /// Returns the first string value stored under `key` in the request metadata.
    ///
    /// When the key is absent (e.g. the client did not install `PeerInfoClientInterceptor`), a
    /// descriptive placeholder is returned rather than trapping: a crash here would take down
    /// the whole server process instead of failing only the misconfigured caller.
    private func firstMetadataString<T>(
      forKey key: String,
      in request: StreamingServerRequest<T>
    ) -> String {
      request.metadata[stringValues: key].first(where: { _ in true }) ?? "<missing '\(key)' metadata>"
    }

    /// Shared handler for the `Unary`, `ServerStream`, `ClientStream` and `BidiStream` methods.
    ///
    /// Each inbound `ControlInput` tells the server what to do: write messages, attach metadata
    /// and/or finish the RPC with a status. The first message is read before the response is
    /// created so that it can decide the initial metadata, or request a trailers-only response.
    ///
    /// - Parameter request: The inbound stream of control messages.
    /// - Returns: A response whose producer processes the first and all remaining messages.
    /// - Throws: `RPCError` when the first message requests a trailers-only response
    ///   (`.invalidArgument` if its status code is not a valid error code), or any error thrown
    ///   while reading the first message.
    private func handle(
      request: StreamingServerRequest<ControlInput>
    ) async throws -> StreamingServerResponse<ControlOutput> {
      var iterator = request.messages.makeAsyncIterator()

      guard let firstMessage = try await iterator.next() else {
        // Empty input stream, empty output stream.
        return StreamingServerResponse { _ in [:] }
      }

      // A trailers-only response is produced by throwing before any response is returned.
      if let status = firstMessage.status, firstMessage.isTrailersOnly {
        let trailers = self.makeTrailers(for: firstMessage, requestMetadata: request.metadata)
        guard let code = RPCError.Code(status.code) else {
          // No `RPCError.Code` exists for this status (e.g. `.ok`), so the request is invalid.
          throw RPCError(
            code: .invalidArgument,
            message: "Trailers only response must use a non-OK status code"
          )
        }
        throw RPCError(code: code, message: status.message, metadata: trailers)
      }

      // Not a trailers-only response: build the initial metadata, optionally echoing the
      // request metadata back.
      var initialMetadata = firstMessage.echoMetadataInHeaders ? request.metadata.echo() : [:]
      for (key, value) in firstMessage.initialMetadataToAdd {
        initialMetadata.addString(value, forKey: key)
      }

      // The iterator needs to be transferred into the response. This is okay: we won't touch the
      // iterator again from the current concurrency domain.
      let transfer = UnsafeTransfer(iterator)

      return StreamingServerResponse(metadata: initialMetadata) { writer in
        var remaining = transfer.wrappedValue
        var nextMessage: ControlInput? = firstMessage

        // Process the already-received first message, then the rest of the input stream, until
        // a message asks for the RPC to finish.
        while let message = nextMessage {
          switch try await self.processMessage(message, metadata: request.metadata, writer: writer) {
          case .return(let trailers):
            return trailers
          case .continue:
            nextMessage = try await remaining.next()
          }
        }

        // Input stream finished without explicitly setting a status; finish the RPC cleanly.
        return [:]
      }
    }

    /// What the response producer should do after processing one control message.
    private enum NextProcessingStep {
      /// Finish the RPC with these trailers.
      case `return`(Metadata)
      /// Keep consuming the input stream.
      case `continue`
    }

    /// Builds the trailing metadata requested by `input`.
    ///
    /// The result is the request metadata echoed back (when `echoMetadataInTrailers` is set),
    /// followed by every entry of `trailingMetadataToAdd`.
    private func makeTrailers(for input: ControlInput, requestMetadata: Metadata) -> Metadata {
      var trailers = input.echoMetadataInTrailers ? requestMetadata.echo() : [:]
      for (key, value) in input.trailingMetadataToAdd {
        trailers.addString(value, forKey: key)
      }
      return trailers
    }

    /// Applies a single control message to the response.
    ///
    /// - Parameters:
    ///   - input: The control message to apply.
    ///   - metadata: The request metadata, used when echoing metadata into trailers.
    ///   - writer: Writer for response messages.
    /// - Returns: `.return(trailers)` if the RPC should finish successfully, otherwise
    ///   `.continue`.
    /// - Throws: `RPCError` carrying the requested non-OK status, or any error from `writer`.
    private func processMessage(
      _ input: ControlInput,
      metadata: Metadata,
      writer: RPCWriter<ControlOutput>
    ) async throws -> NextProcessingStep {
      // If messages were requested, build one payload and write it that many times.
      if input.numberOfMessages > 0 {
        let output = ControlOutput(
          payload: Data(
            repeating: input.payloadParameters.content,
            count: input.payloadParameters.size
          )
        )
        for _ in 0 ..< input.numberOfMessages {
          try await writer.write(output)
        }
      }

      // Check whether the RPC should be finished (i.e. the input `hasStatus`).
      guard let status = input.status else {
        if input.echoMetadataInTrailers || !input.trailingMetadataToAdd.isEmpty {
          // No status, but trailing metadata was requested (echoed and/or explicitly added).
          // This is an implicit 'ok' status.
          return .return(self.makeTrailers(for: input, requestMetadata: metadata))
        } else {
          // No status and no trailers requested. Continue consuming the input stream.
          return .continue
        }
      }

      let trailers = self.makeTrailers(for: input, requestMetadata: metadata)
      if status.code == .ok {
        return .return(trailers)
      }

      // Non-OK status code, throw an error.
      if let code = RPCError.Code(status.code) {
        throw RPCError(code: code, message: status.message, metadata: trailers)
      } else {
        // Invalid error code, throw an appropriate error.
        throw RPCError(
          code: .invalidArgument,
          message: "Invalid error code '\(status.code)'"
        )
      }
    }
  }
  extension ControlService {
    /// Response message of the `Control.PeerInfo` method: how each side of the RPC describes
    /// itself and its peer.
    struct PeerInfoResponse: Codable {
      /// The client's view, taken from the `localPeer`/`remotePeer` request metadata.
      var client: PeerInfo
      /// The server's view, taken from the server context.
      var server: PeerInfo

      /// One side's description of the connection.
      struct PeerInfo: Codable {
        /// This side's own peer string (its `localPeer`).
        var local: String
        /// The other side's peer string (this side's `remotePeer`).
        var remote: String
      }
    }
  }
  extension Metadata {
    /// Returns a copy of this metadata with every key prefixed by "echo-".
    ///
    /// Any ':' characters are removed from keys, and values (string or binary) are copied
    /// unchanged in their original order.
    fileprivate func echo() -> Self {
      var echoed = Metadata()
      echoed.reserveCapacity(self.count)

      for (originalKey, value) in self {
        // Header field names mustn't contain ":", so strip colons before prefixing.
        let echoedKey = "echo-" + originalKey.replacingOccurrences(of: ":", with: "")
        switch value {
        case .binary(let bytes):
          echoed.addBinary(bytes, forKey: echoedKey)
        case .string(let text):
          echoed.addString(text, forKey: echoedKey)
        }
      }

      return echoed
    }
  }
  /// Moves a value across a concurrency boundary without compile-time `Sendable` checking.
  ///
  /// The conformance is unchecked: whoever creates a transfer must guarantee that the wrapped
  /// value is no longer used from the concurrency domain it is transferred out of.
  private struct UnsafeTransfer<Wrapped>: @unchecked Sendable {
    var wrappedValue: Wrapped

    init(_ value: Wrapped) {
      self.wrappedValue = value
    }
  }
  /// Client interceptor that attaches the client's own view of the connection to every request.
  ///
  /// It adds `context.localPeer` under "localPeer" and `context.remotePeer` under "remotePeer"
  /// to the request metadata, then forwards the request unchanged otherwise.
  struct PeerInfoClientInterceptor: ClientInterceptor {
    func intercept<Input: Sendable, Output: Sendable>(
      request: StreamingClientRequest<Input>,
      context: ClientContext,
      next: (
        StreamingClientRequest<Input>,
        ClientContext
      ) async throws -> StreamingClientResponse<Output>
    ) async throws -> StreamingClientResponse<Output> {
      var annotated = request
      annotated.metadata.addString(context.localPeer, forKey: "localPeer")
      annotated.metadata.addString(context.remotePeer, forKey: "remotePeer")
      return try await next(annotated, context)
    }
  }