ControlService.swift 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
/*
 * Copyright 2024, gRPC Authors All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
  16. import Foundation
  17. import GRPCCore
  /// A service named `Control` whose RPCs exchange JSON-encoded messages.
  ///
  /// - `Unary`, `ServerStream`, `ClientStream` and `BidiStream` are all served by
  ///   `handle(request:)`, which acts on the `ControlInput` messages the client sends.
  /// - `WaitForCancellation` is served by `waitForCancellation(request:context:)`.
  /// - `PeerInfo` reports the client's and the server's view of the connection.
  @available(gRPCSwiftNIOTransport 2.0, *)
  struct ControlService: RegistrableRPCService {
    func registerMethods<Transport: ServerTransport>(with router: inout RPCRouter<Transport>) {
      // Every streaming shape shares one implementation; only the method name differs.
      for methodName in ["Unary", "ServerStream", "ClientStream", "BidiStream"] {
        router.registerHandler(
          forMethod: MethodDescriptor(fullyQualifiedService: "Control", method: methodName),
          deserializer: JSONDeserializer<ControlInput>(),
          serializer: JSONSerializer<ControlOutput>()
        ) { request, _ in
          try await self.handle(request: request)
        }
      }

      router.registerHandler(
        forMethod: MethodDescriptor(fullyQualifiedService: "Control", method: "WaitForCancellation"),
        deserializer: JSONDeserializer<CancellationKind>(),
        serializer: JSONSerializer<Empty>()
      ) { request, context in
        // The streaming request is collapsed into a single-message request before dispatch.
        try await self.waitForCancellation(
          request: ServerRequest(stream: request),
          context: context
        )
      }

      router.registerHandler(
        forMethod: MethodDescriptor(fullyQualifiedService: "Control", method: "PeerInfo"),
        deserializer: JSONDeserializer<String>(),
        serializer: JSONSerializer<PeerInfoResponse>()
      ) { request, context in
        StreamingServerResponse { writer in
          // The client's view arrives via request metadata; the server's comes from the context.
          let clientView = PeerInfoResponse.PeerInfo(
            local: self.clientLocalPeerInfo(request: request),
            remote: self.clientRemotePeerInfo(request: request)
          )
          let serverView = PeerInfoResponse.PeerInfo(
            local: self.serverLocalPeerInfo(context: context),
            remote: self.serverRemotePeerInfo(context: context)
          )
          try await writer.write(PeerInfoResponse(client: clientView, server: serverView))
          return [:]
        }
      }
    }
  }
  @available(gRPCSwiftNIOTransport 2.0, *)
  extension ControlService {
    /// Returns a response that only finishes once the RPC has been cancelled.
    ///
    /// - Parameters:
    ///   - request: A single `CancellationKind` message selecting how cancellation is observed.
    ///     `.awaitCancelled` awaits `context.cancellation.cancelled`. `.withCancellationHandler`
    ///     waits inside `withRPCCancellationHandler` until its `onCancelRPC` closure runs.
    ///   - context: The server context of the RPC being served.
    /// - Returns: A response that writes no messages and finishes with empty trailers.
    private func waitForCancellation(
      request: ServerRequest<CancellationKind>,
      context: ServerContext
    ) async throws -> StreamingServerResponse<Empty> {
      switch request.message {
      case .awaitCancelled:
        return StreamingServerResponse { _ in
          try await context.cancellation.cancelled
          return [:]
        }

      case .withCancellationHandler:
        // Nothing ever yields into this stream. It exists only so that `for await` can block
        // until the cancellation handler finishes it.
        let signal = AsyncStream.makeStream(of: Void.self)
        return StreamingServerResponse { _ in
          await withRPCCancellationHandler {
            for await _ in signal.stream {}
            return [:]
          } onCancelRPC: {
            signal.continuation.finish()
          }
        }
      }
    }

    /// The server's view of the remote peer, as reported by the server context.
    private func serverRemotePeerInfo(context: ServerContext) -> String {
      context.remotePeer
    }

    /// The server's view of its local peer, as reported by the server context.
    private func serverLocalPeerInfo(context: ServerContext) -> String {
      context.localPeer
    }

    /// The client's view of its remote peer, read from the `remotePeer` request metadata.
    ///
    /// Returns `"<missing>"` if that key is absent.
    private func clientRemotePeerInfo<T>(request: StreamingServerRequest<T>) -> String {
      self.firstStringValue(forKey: "remotePeer", in: request.metadata)
    }

    /// The client's view of its local peer, read from the `localPeer` request metadata.
    ///
    /// Returns `"<missing>"` if that key is absent.
    private func clientLocalPeerInfo<T>(request: StreamingServerRequest<T>) -> String {
      self.firstStringValue(forKey: "localPeer", in: request.metadata)
    }

    /// Returns the first string value stored under `key` in `metadata`, or `"<missing>"`.
    private func firstStringValue(forKey key: String, in metadata: Metadata) -> String {
      // The always-true predicate simply selects the first element of the values sequence.
      metadata[stringValues: key].first(where: { _ in true }) ?? "<missing>"
    }

    /// Serves an RPC whose behaviour is dictated by the `ControlInput` messages the client sends.
    ///
    /// The first message decides the shape of the response:
    /// - If it carries a status and `isTrailersOnly` is set, the RPC fails immediately with that
    ///   status as a trailers-only response.
    /// - Otherwise it determines the initial metadata. That message, and every following one, is
    ///   then handed to `processMessage(_:metadata:writer:)` until one of them finishes the RPC
    ///   or the input stream ends.
    ///
    /// - Parameter request: The client's stream of control messages.
    /// - Returns: A response whose producer writes the requested messages and trailers.
    /// - Throws: `RPCError` carrying the requested status for a trailers-only response.
    ///   Throws `.invalidArgument` if that status can't be expressed as an `RPCError.Code`.
    ///   Also rethrows any error from reading the first request message.
    private func handle(
      request: StreamingServerRequest<ControlInput>
    ) async throws -> StreamingServerResponse<ControlOutput> {
      var iterator = request.messages.makeAsyncIterator()

      guard let message = try await iterator.next() else {
        // Empty input stream, empty output stream.
        return StreamingServerResponse { _ in [:] }
      }

      // Check if the request is for a trailers-only response.
      if let status = message.status, message.isTrailersOnly {
        guard let code = RPCError.Code(status.code) else {
          // The status has no `RPCError.Code` equivalent (i.e. it is OK), so it can't be used
          // for a trailers-only error response.
          throw RPCError(
            code: .invalidArgument,
            message: "Trailers only response must use a non-OK status code"
          )
        }

        let trailers = self.makeTrailers(for: message, requestMetadata: request.metadata)
        throw RPCError(code: code, message: status.message, metadata: trailers)
      }

      // Not a trailers-only response: the first message decides the initial metadata.
      let initialMetadata = self.makeInitialMetadata(for: message, requestMetadata: request.metadata)

      // The iterator needs to be transferred into the response. This is okay: we won't touch the
      // iterator again from the current concurrency domain.
      let transfer = UnsafeTransfer(iterator)
      return StreamingServerResponse(metadata: initialMetadata) { writer in
        // Finish dealing with the first message.
        switch try await self.processMessage(message, metadata: request.metadata, writer: writer) {
        case .return(let trailers):
          return trailers
        case .continue:
          ()
        }

        var iterator = transfer.wrappedValue
        // Process the rest of the messages.
        while let next = try await iterator.next() {
          switch try await self.processMessage(next, metadata: request.metadata, writer: writer) {
          case .return(let trailers):
            return trailers
          case .continue:
            ()
          }
        }

        // Input stream finished without explicitly setting a status; finish the RPC cleanly.
        return [:]
      }
    }

    /// Builds the initial metadata requested by `input`.
    ///
    /// The result is `requestMetadata.echo()` if `echoMetadataInHeaders` is set, otherwise empty.
    /// Every entry of `initialMetadataToAdd` is then added as a string value.
    private func makeInitialMetadata(
      for input: ControlInput,
      requestMetadata: Metadata
    ) -> Metadata {
      var metadata = input.echoMetadataInHeaders ? requestMetadata.echo() : [:]
      for (key, value) in input.initialMetadataToAdd {
        metadata.addString(value, forKey: key)
      }
      return metadata
    }

    /// Builds the trailing metadata requested by `input`.
    ///
    /// The result is `requestMetadata.echo()` if `echoMetadataInTrailers` is set, otherwise empty.
    /// Every entry of `trailingMetadataToAdd` is then added as a string value.
    private func makeTrailers(for input: ControlInput, requestMetadata: Metadata) -> Metadata {
      var trailers = input.echoMetadataInTrailers ? requestMetadata.echo() : [:]
      for (key, value) in input.trailingMetadataToAdd {
        trailers.addString(value, forKey: key)
      }
      return trailers
    }

    /// What `handle(request:)` should do after processing a single control message.
    private enum NextProcessingStep {
      /// Finish the RPC successfully with the given trailers.
      case `return`(Metadata)
      /// Keep reading control messages from the client.
      case `continue`
    }

    /// Applies a single control message.
    ///
    /// First it writes `numberOfMessages` copies of a payload built from `payloadParameters`.
    /// Then it decides whether the RPC should finish.
    ///
    /// - Parameters:
    ///   - input: The control message to apply.
    ///   - metadata: The request metadata, echoed into trailers when requested.
    ///   - writer: The writer for response messages.
    /// - Returns: `.return(trailers)` to finish successfully, or `.continue` to keep reading.
    /// - Throws: `RPCError` carrying the requested non-OK status, or any error from `writer`.
    private func processMessage(
      _ input: ControlInput,
      metadata: Metadata,
      writer: RPCWriter<ControlOutput>
    ) async throws -> NextProcessingStep {
      // If messages were requested, build a response and send them back.
      if input.numberOfMessages > 0 {
        let output = ControlOutput(
          payload: Data(
            repeating: input.payloadParameters.content,
            count: input.payloadParameters.size
          )
        )
        for _ in 0 ..< input.numberOfMessages {
          try await writer.write(output)
        }
      }

      // Check whether the RPC should be finished (i.e. the input `hasStatus`).
      guard let status = input.status else {
        if input.echoMetadataInTrailers || !input.trailingMetadataToAdd.isEmpty {
          // There was no status in the input, but trailers were requested. This is an implicit
          // 'ok' status.
          return .return(self.makeTrailers(for: input, requestMetadata: metadata))
        }
        // No status and no trailers requested. Continue consuming the input stream.
        return .continue
      }

      let trailers = self.makeTrailers(for: input, requestMetadata: metadata)
      if status.code == .ok {
        return .return(trailers)
      }

      // Non-OK status code, throw an error.
      guard let code = RPCError.Code(status.code) else {
        // Invalid error code, throw an appropriate error.
        throw RPCError(
          code: .invalidArgument,
          message: "Invalid error code '\(status.code)'"
        )
      }
      throw RPCError(code: code, message: status.message, metadata: trailers)
    }
  }
  @available(gRPCSwiftNIOTransport 2.0, *)
  extension ControlService {
    /// The message returned by the `PeerInfo` method: how each side sees the connection.
    struct PeerInfoResponse: Codable {
      /// The client's view of the connection, followed by the server's view.
      var client, server: PeerInfo

      /// One side's descriptions of its local and remote peers.
      struct PeerInfo: Codable {
        var local, remote: String
      }
    }
  }
  @available(gRPCSwiftNIOTransport 2.0, *)
  extension Metadata {
    /// Returns a copy of this metadata in which every key is prefixed with `echo-`.
    ///
    /// Any `:` characters are removed from a key before the prefix is added, because header
    /// field names must not contain `:` (for example, `:authority` becomes `echo-authority`).
    /// String and binary values are copied unchanged.
    fileprivate func echo() -> Self {
      var echoed = Metadata()
      echoed.reserveCapacity(self.count)

      for (name, value) in self {
        let echoedName = "echo-\(name.replacingOccurrences(of: ":", with: ""))"
        switch value {
        case .binary(let bytes):
          echoed.addBinary(bytes, forKey: echoedName)
        case .string(let text):
          echoed.addString(text, forKey: echoedName)
        }
      }

      return echoed
    }
  }
  /// Wraps a value so that it can be captured by a `@Sendable` closure without compiler checks.
  ///
  /// The `Sendable` conformance is unchecked. Callers are responsible for ensuring the wrapped
  /// value is never accessed from two concurrency domains at once.
  private struct UnsafeTransfer<Wrapped>: @unchecked Sendable {
    var wrappedValue: Wrapped

    init(_ value: Wrapped) {
      wrappedValue = value
    }
  }
  /// A client interceptor that forwards the client's view of the connection in request metadata.
  ///
  /// The client context's local and remote peers are added under the `localPeer` and
  /// `remotePeer` keys. The `Control` service's `PeerInfo` method reads those keys back.
  @available(gRPCSwiftNIOTransport 2.0, *)
  struct PeerInfoClientInterceptor: ClientInterceptor {
    func intercept<Input, Output>(
      request: StreamingClientRequest<Input>,
      context: ClientContext,
      next: (
        StreamingClientRequest<Input>,
        ClientContext
      ) async throws -> StreamingClientResponse<Output>
    ) async throws -> StreamingClientResponse<Output> where Input: Sendable, Output: Sendable {
      var forwarded = request
      forwarded.metadata.addString(context.localPeer, forKey: "localPeer")
      forwarded.metadata.addString(context.remotePeer, forKey: "remotePeer")
      return try await next(forwarded, context)
    }
  }