ClientCodeTranslator.swift 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456
  1. /*
  2. * Copyright 2023, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  /// Creates a representation for the client code that will be generated based on the ``CodeGenerationRequest`` object
  /// specifications, using types from ``StructuredSwiftRepresentation``.
  ///
  /// For example, in the case of a service called "Bar", in the "foo" namespace which has
  /// one unary method "baz", the ``ClientCodeTranslator`` will create
  /// a representation for the following generated code:
  ///
  /// ```swift
  /// public protocol Foo_BarClientProtocol: Sendable {
  ///   func baz<R: Sendable>(
  ///     request: ClientRequest.Single<Foo.Bar.Method.Baz.Input>,
  ///     serializer: some MessageSerializer<Foo.Bar.Method.Baz.Input>,
  ///     deserializer: some MessageDeserializer<Foo.Bar.Method.Baz.Output>,
  ///     _ body: @Sendable @escaping (ClientResponse.Single<Foo.Bar.Method.Baz.Output>) async throws -> R
  ///   ) async throws -> R
  /// }
  /// extension Foo.Bar.ClientProtocol {
  ///   public func baz<R: Sendable>(
  ///     request: ClientRequest.Single<Foo.Bar.Method.Baz.Input>,
  ///     _ body: @Sendable @escaping (ClientResponse.Single<Foo.Bar.Method.Baz.Output>) async throws -> R
  ///   ) async throws -> R {
  ///     try await self.baz(
  ///       request: request,
  ///       serializer: ProtobufSerializer<Foo.Bar.Method.Baz.Input>(),
  ///       deserializer: ProtobufDeserializer<Foo.Bar.Method.Baz.Output>(),
  ///       body
  ///     )
  ///   }
  /// }
  /// public struct Foo_BarClient: Foo.Bar.ClientProtocol {
  ///   private let client: GRPCCore.GRPCClient
  ///   public init(client: GRPCCore.GRPCClient) {
  ///     self.client = client
  ///   }
  ///   public func baz<R: Sendable>(
  ///     request: ClientRequest.Single<Foo.Bar.Method.Baz.Input>,
  ///     serializer: some MessageSerializer<Foo.Bar.Method.Baz.Input>,
  ///     deserializer: some MessageDeserializer<Foo.Bar.Method.Baz.Output>,
  ///     _ body: @Sendable @escaping (ClientResponse.Single<Foo.Bar.Method.Baz.Output>) async throws -> R
  ///   ) async throws -> R {
  ///     try await self.client.unary(
  ///       request: request,
  ///       descriptor: Foo.Bar.Method.Baz.descriptor,
  ///       serializer: serializer,
  ///       deserializer: deserializer,
  ///       handler: body
  ///     )
  ///   }
  /// }
  /// ```
  struct ClientCodeTranslator: SpecializedTranslator {
    /// The access level this translator was configured with.
    var accessLevel: SourceGenerator.Configuration.AccessLevel

    /// Creates a translator configured with the given access level.
    init(accessLevel: SourceGenerator.Configuration.AccessLevel) {
      self.accessLevel = accessLevel
    }

    /// Builds the client code blocks for every service in the request.
    ///
    /// Each service contributes three declarations, in this order: the client
    /// protocol (carrying the service documentation), the protocol extension,
    /// and the client struct (also carrying the service documentation).
    ///
    /// - Parameter codeGenerationRequest: The request describing the services.
    /// - Returns: The code blocks for all services, in service order.
    func translate(from codeGenerationRequest: CodeGenerationRequest) throws -> [CodeBlock] {
      return codeGenerationRequest.services.flatMap { service -> [CodeBlock] in
        let clientProtocol = self.makeClientProtocol(for: service, in: codeGenerationRequest)
        let protocolExtension = self.makeExtensionProtocol(
          for: service,
          in: codeGenerationRequest
        )
        let clientStruct = self.makeClientStruct(for: service, in: codeGenerationRequest)

        return [
          .declaration(.commentable(.preFormatted(service.documentation), clientProtocol)),
          .declaration(protocolExtension),
          .declaration(.commentable(.preFormatted(service.documentation), clientStruct)),
        ]
      }
    }
  }
  96. extension ClientCodeTranslator {
  /// Builds the `Sendable` client protocol declaration for a service.
  ///
  /// The protocol is named `<namespacedGeneratedName>ClientProtocol` and holds
  /// one requirement per service method; the requirements are generated with
  /// explicit serializer and deserializer parameters.
  ///
  /// - Parameters:
  ///   - service: The service whose methods become protocol requirements.
  ///   - codeGenerationRequest: The request the service belongs to.
  /// - Returns: The protocol declaration.
  private func makeClientProtocol(
    for service: CodeGenerationRequest.ServiceDescriptor,
    in codeGenerationRequest: CodeGenerationRequest
  ) -> Declaration {
    var requirements = [Declaration]()
    for method in service.methods {
      let requirement = self.makeClientProtocolMethod(
        for: method,
        in: service,
        from: codeGenerationRequest,
        generateSerializerDeserializer: false
      )
      requirements.append(requirement)
    }

    let description = ProtocolDescription(
      accessModifier: self.accessModifier,
      name: "\(service.namespacedGeneratedName)ClientProtocol",
      conformances: ["Sendable"],
      members: requirements
    )
    return Declaration.protocol(description)
  }
  /// Builds the extension on `<namespacedTypealiasGeneratedName>.ClientProtocol`.
  ///
  /// The extension holds one method per service method. Those methods are
  /// generated without serializer/deserializer parameters and with a body, and
  /// carry the translator's access modifier.
  ///
  /// - Parameters:
  ///   - service: The service whose methods get convenience overloads.
  ///   - codeGenerationRequest: The request the service belongs to.
  /// - Returns: The extension declaration.
  private func makeExtensionProtocol(
    for service: CodeGenerationRequest.ServiceDescriptor,
    in codeGenerationRequest: CodeGenerationRequest
  ) -> Declaration {
    var overloads = [Declaration]()
    for method in service.methods {
      let overload = self.makeClientProtocolMethod(
        for: method,
        in: service,
        from: codeGenerationRequest,
        generateSerializerDeserializer: true,
        accessModifier: self.accessModifier
      )
      overloads.append(overload)
    }

    let description = ExtensionDescription(
      onType: "\(service.namespacedTypealiasGeneratedName).ClientProtocol",
      declarations: overloads
    )
    return Declaration.extension(description)
  }
  /// Builds one client method declaration, either as a bodiless protocol
  /// requirement or as a protocol-extension method with a body.
  ///
  /// The signature is always generic over `R` (constrained to `Sendable` via a
  /// where clause), marked `async throws`, and returns `R`.
  ///
  /// - Parameters:
  ///   - method: The RPC the declaration is generated for.
  ///   - service: The service that owns `method`.
  ///   - codeGenerationRequest: The request the service belongs to.
  ///   - generateSerializerDeserializer: When `true`, the method gets a body
  ///     built by `makeSerializerDeserializerCall` and no documentation comment.
  ///     When `false`, a bodiless function carrying the method's documentation
  ///     is returned.
  ///   - accessModifier: Access modifier for the signature; defaults to `nil`.
  /// - Returns: The function declaration.
  private func makeClientProtocolMethod(
    for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
    in service: CodeGenerationRequest.ServiceDescriptor,
    from codeGenerationRequest: CodeGenerationRequest,
    generateSerializerDeserializer: Bool,
    accessModifier: AccessModifier? = nil
  ) -> Declaration {
    let signature = FunctionSignatureDescription(
      accessModifier: accessModifier,
      kind: .function(name: method.name.generatedLowerCase, isStatic: false),
      generics: [.member("R")],
      parameters: self.makeParameters(
        for: method,
        in: service,
        from: codeGenerationRequest,
        generateSerializerDeserializer: generateSerializerDeserializer
      ),
      keywords: [.async, .throws],
      returnType: .identifierType(.member("R")),
      whereClause: WhereClause(requirements: [.conformance("R", "Sendable")])
    )

    // Requirement form: no body, documented with the method's own comment.
    guard generateSerializerDeserializer else {
      return .commentable(
        .preFormatted(method.documentation),
        .function(signature: signature)
      )
    }

    // Extension form: the body is the generated forwarding call.
    let forwardingBody = self.makeSerializerDeserializerCall(
      for: method,
      in: service,
      from: codeGenerationRequest
    )
    return .function(signature: signature, body: forwardingBody)
  }
  /// Builds the body of a protocol-extension method: a single
  /// `try await self.<method>(request:serializer:deserializer:_:)` expression.
  ///
  /// The serializer and deserializer arguments are the identifiers returned by
  /// the request's `lookupSerializer` / `lookupDeserializer` for the method's
  /// input and output typealias names.
  ///
  /// - Parameters:
  ///   - method: The RPC being called.
  ///   - service: The service that owns `method`.
  ///   - codeGenerationRequest: Supplies the serializer/deserializer lookups.
  /// - Returns: A one-element array holding the `try await` call.
  private func makeSerializerDeserializerCall(
    for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
    in service: CodeGenerationRequest.ServiceDescriptor,
    from codeGenerationRequest: CodeGenerationRequest
  ) -> [CodeBlock] {
    let inputType = self.methodInputOutputTypealias(for: method, service: service, type: .input)
    let outputType = self.methodInputOutputTypealias(for: method, service: service, type: .output)

    let callee = Expression.memberAccess(
      MemberAccessDescription(
        left: .identifierPattern("self"),
        right: method.name.generatedLowerCase
      )
    )

    let arguments = [
      FunctionArgumentDescription(label: "request", expression: .identifierPattern("request")),
      FunctionArgumentDescription(
        label: "serializer",
        expression: .identifierPattern(codeGenerationRequest.lookupSerializer(inputType))
      ),
      FunctionArgumentDescription(
        label: "deserializer",
        expression: .identifierPattern(codeGenerationRequest.lookupDeserializer(outputType))
      ),
      // The trailing closure parameter is passed without a label.
      FunctionArgumentDescription(expression: .identifierPattern("body")),
    ]

    let call = Expression.functionCall(calledExpression: callee, arguments: arguments)
    let tryAwaitCall = Expression.unaryKeyword(
      kind: .try,
      expression: Expression.unaryKeyword(kind: .await, expression: call)
    )
    return [CodeBlock(item: .expression(tryAwaitCall))]
  }
  /// Builds the parameter list of a generated client method.
  ///
  /// The list always starts with the request parameter and ends with the body
  /// closure. The serializer and deserializer parameters sit between them only
  /// when `generateSerializerDeserializer` is `false`.
  ///
  /// - Parameters:
  ///   - method: The RPC the parameters are generated for.
  ///   - service: The service that owns `method`.
  ///   - codeGenerationRequest: Not read by this function.
  ///   - generateSerializerDeserializer: When `true`, the explicit serializer
  ///     and deserializer parameters are left out.
  /// - Returns: The ordered parameter descriptions.
  private func makeParameters(
    for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
    in service: CodeGenerationRequest.ServiceDescriptor,
    from codeGenerationRequest: CodeGenerationRequest,
    generateSerializerDeserializer: Bool
  ) -> [ParameterDescription] {
    let request = self.clientRequestParameter(for: method, in: service)
    let body = self.bodyParameter(for: method, in: service)

    if generateSerializerDeserializer {
      return [request, body]
    }

    return [
      request,
      self.serializerParameter(for: method, in: service),
      self.deserializerParameter(for: method, in: service),
      body,
    ]
  }
  /// Builds the `request` parameter: `ClientRequest.Stream<Input>` for
  /// input-streaming methods, `ClientRequest.Single<Input>` otherwise.
  ///
  /// - Parameters:
  ///   - method: The RPC whose input type and streaming mode are used.
  ///   - service: The service that owns `method`.
  /// - Returns: The parameter description labelled `request`.
  private func clientRequestParameter(
    for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
    in service: CodeGenerationRequest.ServiceDescriptor
  ) -> ParameterDescription {
    let streamingKind: String
    if method.isInputStreaming {
      streamingKind = "Stream"
    } else {
      streamingKind = "Single"
    }

    let inputType = self.methodInputOutputTypealias(for: method, service: service, type: .input)
    let requestWrapper = ExistingTypeDescription.member(["ClientRequest", streamingKind])

    return ParameterDescription(
      label: "request",
      type: .generic(wrapper: requestWrapper, wrapped: .member(inputType))
    )
  }
  /// Builds the `serializer` parameter, typed `some MessageSerializer<Input>`.
  ///
  /// - Parameters:
  ///   - method: The RPC whose input typealias is serialized.
  ///   - service: The service that owns `method`.
  /// - Returns: The parameter description labelled `serializer`.
  private func serializerParameter(
    for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
    in service: CodeGenerationRequest.ServiceDescriptor
  ) -> ParameterDescription {
    let inputType = self.methodInputOutputTypealias(for: method, service: service, type: .input)
    let serializerType = ExistingTypeDescription.generic(
      wrapper: .member("MessageSerializer"),
      wrapped: .member(inputType)
    )
    // NOTE(review): keep the explicit `ExistingTypeDescription.` prefix; a bare
    // `.some` could resolve to `Optional.some` if the `type:` parameter is optional.
    let opaqueSerializerType = ExistingTypeDescription.some(serializerType)

    return ParameterDescription(label: "serializer", type: opaqueSerializerType)
  }
  /// Builds the `deserializer` parameter, typed `some MessageDeserializer<Output>`.
  ///
  /// - Parameters:
  ///   - method: The RPC whose output typealias is deserialized.
  ///   - service: The service that owns `method`.
  /// - Returns: The parameter description labelled `deserializer`.
  private func deserializerParameter(
    for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
    in service: CodeGenerationRequest.ServiceDescriptor
  ) -> ParameterDescription {
    let outputType = self.methodInputOutputTypealias(for: method, service: service, type: .output)
    let deserializerType = ExistingTypeDescription.generic(
      wrapper: .member("MessageDeserializer"),
      wrapped: .member(outputType)
    )
    // NOTE(review): keep the explicit `ExistingTypeDescription.` prefix; a bare
    // `.some` could resolve to `Optional.some` if the `type:` parameter is optional.
    let opaqueDeserializerType = ExistingTypeDescription.some(deserializerType)

    return ParameterDescription(label: "deserializer", type: opaqueDeserializerType)
  }
  /// Builds the `body` closure parameter of a generated client method.
  ///
  /// The closure is `@Sendable`, `@escaping`, `async throws`, returns `R`, and
  /// takes a `ClientResponse.Stream<Output>` for output-streaming methods or a
  /// `ClientResponse.Single<Output>` otherwise.
  ///
  /// - Parameters:
  ///   - method: The RPC whose output type and streaming mode are used.
  ///   - service: The service that owns `method`.
  /// - Returns: The parameter description named `body`.
  private func bodyParameter(
    for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
    in service: CodeGenerationRequest.ServiceDescriptor
  ) -> ParameterDescription {
    let responseKind: String
    if method.isOutputStreaming {
      responseKind = "Stream"
    } else {
      responseKind = "Single"
    }

    let outputType = self.methodInputOutputTypealias(for: method, service: service, type: .output)
    let responseType = ExistingTypeDescription.generic(
      wrapper: .member(["ClientResponse", responseKind]),
      wrapped: .member(outputType)
    )

    let closureSignature = ClosureSignatureDescription(
      parameters: [.init(type: responseType)],
      keywords: [.async, .throws],
      returnType: .identifierType(.member("R")),
      sendable: true,
      escaping: true
    )
    return ParameterDescription(name: "body", type: .closure(closureSignature))
  }
  /// Builds the concrete client struct for a service, wrapped in the
  /// translator's availability guard.
  ///
  /// The struct is named `<namespacedGeneratedName>Client`, conforms to
  /// `<namespacedTypealiasGeneratedName>.ClientProtocol`, and contains, in
  /// order: a private `client` constant of type `GRPCCore.GRPCClient`, the
  /// initializer, and one documented method per service method.
  ///
  /// - Parameters:
  ///   - service: The service the client struct is generated for.
  ///   - codeGenerationRequest: The request the service belongs to.
  /// - Returns: The guarded struct declaration.
  private func makeClientStruct(
    for service: CodeGenerationRequest.ServiceDescriptor,
    in codeGenerationRequest: CodeGenerationRequest
  ) -> Declaration {
    let storedClient = Declaration.variable(
      accessModifier: .private,
      kind: .let,
      left: "client",
      type: .member(["GRPCCore", "GRPCClient"])
    )

    var members: [Declaration] = [storedClient, self.makeClientVariable()]
    for method in service.methods {
      let implementation = self.makeClientMethod(
        for: method,
        in: service,
        from: codeGenerationRequest
      )
      members.append(
        Declaration.commentable(.preFormatted(method.documentation), implementation)
      )
    }

    let description = StructDescription(
      accessModifier: self.accessModifier,
      name: "\(service.namespacedGeneratedName)Client",
      conformances: ["\(service.namespacedTypealiasGeneratedName).ClientProtocol"],
      members: members
    )
    return .guarded(self.availabilityGuard, .struct(description))
  }
  /// Builds the client struct's initializer declaration.
  ///
  /// Despite the name, this returns an initializer rather than a variable: it
  /// takes a `client: GRPCCore.GRPCClient` parameter, uses the translator's
  /// access modifier, and its body is the single assignment
  /// `self.client = client`.
  ///
  /// - Returns: The initializer declaration.
  private func makeClientVariable() -> Declaration {
    let storedClient = Expression.memberAccess(
      MemberAccessDescription(left: .identifierPattern("self"), right: "client")
    )
    let assignment = Expression.assignment(
      left: storedClient,
      right: .identifierPattern("client")
    )

    let signature = FunctionSignatureDescription(
      accessModifier: self.accessModifier,
      kind: .initializer,
      parameters: [.init(label: "client", type: .member(["GRPCCore", "GRPCClient"]))]
    )
    return .function(
      signature: signature,
      body: [CodeBlock(item: .expression(assignment))]
    )
  }
  /// Returns the name of the client call that matches a method's streaming mode.
  ///
  /// - Parameters:
  ///   - isInputStreaming: Whether the request is a stream.
  ///   - isOutputStreaming: Whether the response is a stream.
  /// - Returns: `"bidirectionalStreaming"`, `"clientStreaming"`,
  ///   `"serverStreaming"` or `"unary"`.
  private func clientMethod(
    isInputStreaming: Bool,
    isOutputStreaming: Bool
  ) -> String {
    if isInputStreaming {
      return isOutputStreaming ? "bidirectionalStreaming" : "clientStreaming"
    }
    return isOutputStreaming ? "serverStreaming" : "unary"
  }
  /// Builds one method of the client struct.
  ///
  /// The method has the full parameter list (request, serializer,
  /// deserializer, body), is generic over `R` (constrained to `Sendable` via a
  /// where clause), is `async throws`, and returns `R`. Its body is a single
  /// `try await self.client.<call>(request:descriptor:serializer:deserializer:handler:)`
  /// expression, where `<call>` is chosen by `clientMethod` from the method's
  /// streaming flags and `descriptor` is
  /// `<namespacedTypealiasGeneratedName>.Method.<Name>.descriptor`.
  ///
  /// - Parameters:
  ///   - method: The RPC the method is generated for.
  ///   - service: The service that owns `method`.
  ///   - codeGenerationRequest: The request the service belongs to.
  /// - Returns: The function declaration.
  private func makeClientMethod(
    for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
    in service: CodeGenerationRequest.ServiceDescriptor,
    from codeGenerationRequest: CodeGenerationRequest
  ) -> Declaration {
    let rpcCallName = self.clientMethod(
      isInputStreaming: method.isInputStreaming,
      isOutputStreaming: method.isOutputStreaming
    )
    let descriptorPath =
      "\(service.namespacedTypealiasGeneratedName).Method.\(method.name.generatedUpperCase).descriptor"

    let callee = Expression.memberAccess(
      MemberAccessDescription(left: .identifierPattern("self.client"), right: rpcCallName)
    )
    let arguments: [FunctionArgumentDescription] = [
      FunctionArgumentDescription(label: "request", expression: .identifierPattern("request")),
      FunctionArgumentDescription(
        label: "descriptor",
        expression: .identifierPattern(descriptorPath)
      ),
      FunctionArgumentDescription(
        label: "serializer",
        expression: .identifierPattern("serializer")
      ),
      FunctionArgumentDescription(
        label: "deserializer",
        expression: .identifierPattern("deserializer")
      ),
      FunctionArgumentDescription(label: "handler", expression: .identifierPattern("body")),
    ]
    let call = Expression.functionCall(calledExpression: callee, arguments: arguments)

    let tryAwaitCall = UnaryKeywordDescription(
      kind: .try,
      expression: .unaryKeyword(kind: .await, expression: call)
    )

    return .function(
      accessModifier: self.accessModifier,
      kind: .function(name: method.name.generatedLowerCase, isStatic: false),
      generics: [.member("R")],
      parameters: self.makeParameters(
        for: method,
        in: service,
        from: codeGenerationRequest,
        generateSerializerDeserializer: false
      ),
      keywords: [.async, .throws],
      returnType: .identifierType(.member("R")),
      whereClause: WhereClause(requirements: [.conformance("R", "Sendable")]),
      body: [.expression(.unaryKeyword(tryAwaitCall))]
    )
  }
  /// Selects which side of a method's message pair a typealias name refers to.
  fileprivate enum InputOutputType {
    /// The method's request message type.
    case input
    /// The method's response message type.
    case output
  }
  /// Generates the fully qualified name of the typealias for the input or output type of a method.
  ///
  /// The result has the form
  /// `<namespacedTypealiasGeneratedName>.Method.<generatedUpperCase name>.Input`
  /// (or `.Output`).
  ///
  /// - Parameters:
  ///   - method: The method whose upper-case generated name is used.
  ///   - service: The service supplying the namespaced typealias prefix.
  ///   - type: Whether to name the input or the output typealias.
  /// - Returns: The dotted typealias path as a string.
  private func methodInputOutputTypealias(
    for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
    service: CodeGenerationRequest.ServiceDescriptor,
    type: InputOutputType
  ) -> String {
    let suffix: String
    switch type {
    case .input:
      suffix = ".Input"
    case .output:
      suffix = ".Output"
    }

    let methodPath =
      "\(service.namespacedTypealiasGeneratedName).Method.\(method.name.generatedUpperCase)"
    return methodPath + suffix
  }
  429. }