// ClientCodeTranslator.swift
  /*
   * Copyright 2023, gRPC Authors All rights reserved.
   *
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   *     http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   * See the License for the specific language governing permissions and
   * limitations under the License.
   */
  /// Creates a representation for the client code that will be generated based on the ``CodeGenerationRequest`` object
  /// specifications, using types from ``StructuredSwiftRepresentation``.
  ///
  /// For every service in the request, three declarations are produced, in this order:
  /// 1. a `<Namespace>_<Service>ClientProtocol` protocol whose requirements take explicit serializer and
  ///    deserializer arguments,
  /// 2. an extension on `<Namespace>_<Service>.ClientProtocol` with convenience overloads that supply the serializer
  ///    and deserializer looked up from the request, and
  /// 3. a `<Namespace>_<Service>Client` struct that forwards each call to a `GRPCCore.GRPCClient`.
  ///
  /// For example, in the case of a service called "Bar", in the "foo" namespace which has
  /// one unary method "baz" with input type "Input" and output type "Output", the ``ClientCodeTranslator`` will create
  /// a representation for the following generated code (the serializer/deserializer expressions in the extension are
  /// whatever the request's lookup functions return; Protobuf ones are shown here):
  ///
  /// ```swift
  /// @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  /// public protocol Foo_BarClientProtocol: Sendable {
  ///   func baz<R>(
  ///     request: ClientRequest.Single<Foo_Bar_Input>,
  ///     serializer: some MessageSerializer<Foo_Bar_Input>,
  ///     deserializer: some MessageDeserializer<Foo_Bar_Output>,
  ///     options: CallOptions,
  ///     _ body: @Sendable @escaping (ClientResponse.Single<Foo_Bar_Output>) async throws -> R
  ///   ) async throws -> R where R: Sendable
  /// }
  ///
  /// @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  /// extension Foo_Bar.ClientProtocol {
  ///   public func baz<R>(
  ///     request: ClientRequest.Single<Foo_Bar_Input>,
  ///     options: CallOptions = .defaults,
  ///     _ body: @Sendable @escaping (ClientResponse.Single<Foo_Bar_Output>) async throws -> R
  ///   ) async throws -> R where R: Sendable {
  ///     try await self.baz(
  ///       request: request,
  ///       serializer: ProtobufSerializer<Foo_Bar_Input>(),
  ///       deserializer: ProtobufDeserializer<Foo_Bar_Output>(),
  ///       options: options,
  ///       body
  ///     )
  ///   }
  /// }
  ///
  /// @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  /// public struct Foo_BarClient: Foo_Bar.ClientProtocol {
  ///   private let client: GRPCCore.GRPCClient
  ///
  ///   public init(client: GRPCCore.GRPCClient) {
  ///     self.client = client
  ///   }
  ///
  ///   public func baz<R>(
  ///     request: ClientRequest.Single<Foo_Bar_Input>,
  ///     serializer: some MessageSerializer<Foo_Bar_Input>,
  ///     deserializer: some MessageDeserializer<Foo_Bar_Output>,
  ///     options: CallOptions = .defaults,
  ///     _ body: @Sendable @escaping (ClientResponse.Single<Foo_Bar_Output>) async throws -> R
  ///   ) async throws -> R where R: Sendable {
  ///     try await self.client.unary(
  ///       request: request,
  ///       descriptor: Foo_Bar.Method.Baz.descriptor,
  ///       serializer: serializer,
  ///       deserializer: deserializer,
  ///       options: options,
  ///       handler: body
  ///     )
  ///   }
  /// }
  /// ```
  struct ClientCodeTranslator: SpecializedTranslator {
    /// The access level used for the generated declarations.
    var accessLevel: SourceGenerator.Configuration.AccessLevel

    init(accessLevel: SourceGenerator.Configuration.AccessLevel) {
      self.accessLevel = accessLevel
    }

    /// Translates every service in `codeGenerationRequest` into client code.
    ///
    /// - Parameter codeGenerationRequest: The request describing the services to generate clients for.
    /// - Returns: Three code blocks per service, in service order: the client protocol, the protocol extension
    ///   and the client struct. The protocol and the struct are commented with the service's documentation;
    ///   the extension is not.
    func translate(from codeGenerationRequest: CodeGenerationRequest) throws -> [CodeBlock] {
      return codeGenerationRequest.services.flatMap { service -> [CodeBlock] in
        let clientProtocol = Declaration.commentable(
          .preFormatted(service.documentation),
          self.makeClientProtocol(for: service, in: codeGenerationRequest)
        )
        let protocolExtension = self.makeExtensionProtocol(for: service, in: codeGenerationRequest)
        let clientStruct = Declaration.commentable(
          .preFormatted(service.documentation),
          self.makeClientStruct(for: service, in: codeGenerationRequest)
        )
        return [
          .declaration(clientProtocol),
          .declaration(protocolExtension),
          .declaration(clientStruct),
        ]
      }
    }
  }
  extension ClientCodeTranslator {
    /// Makes the client protocol for `service`, named `<namespacedGeneratedName>ClientProtocol`.
    ///
    /// Each requirement takes explicit `serializer` and `deserializer` parameters and an `options` parameter
    /// without a default value (Swift does not permit default arguments on protocol requirements). Each requirement
    /// is commented with its method's documentation, and the protocol is wrapped in the availability guard.
    private func makeClientProtocol(
      for service: CodeGenerationRequest.ServiceDescriptor,
      in codeGenerationRequest: CodeGenerationRequest
    ) -> Declaration {
      let requirements = service.methods.map { method in
        self.makeClientProtocolMethod(
          for: method,
          in: service,
          from: codeGenerationRequest,
          includeBody: false,
          includeDefaultCallOptions: false
        )
      }

      let clientProtocol = Declaration.protocol(
        ProtocolDescription(
          accessModifier: self.accessModifier,
          name: "\(service.namespacedGeneratedName)ClientProtocol",
          conformances: ["Sendable"],
          members: requirements
        )
      )
      return .guarded(self.availabilityGuard, clientProtocol)
    }

    /// Makes an extension on `<namespacedGeneratedName>.ClientProtocol` with one convenience method per RPC.
    ///
    /// The convenience methods drop the `serializer`/`deserializer` parameters, default `options` to `.defaults`,
    /// and forward to the matching protocol requirement with the serializer and deserializer obtained from
    /// `codeGenerationRequest`.
    ///
    /// NOTE(review): this extends `<name>.ClientProtocol` (nested), not the `<name>ClientProtocol` declared by
    /// `makeClientProtocol`; the nested name is presumably a typealias emitted by another translator — confirm.
    private func makeExtensionProtocol(
      for service: CodeGenerationRequest.ServiceDescriptor,
      in codeGenerationRequest: CodeGenerationRequest
    ) -> Declaration {
      let conveniences = service.methods.map { method in
        self.makeClientProtocolMethod(
          for: method,
          in: service,
          from: codeGenerationRequest,
          includeBody: true,
          accessModifier: self.accessModifier,
          includeDefaultCallOptions: true
        )
      }

      let clientProtocolExtension = Declaration.extension(
        ExtensionDescription(
          onType: "\(service.namespacedGeneratedName).ClientProtocol",
          declarations: conveniences
        )
      )
      return .guarded(self.availabilityGuard, clientProtocolExtension)
    }

    /// Makes one client method for `method`, either as a protocol requirement or as an extension convenience.
    ///
    /// - Parameters:
    ///   - method: The RPC to declare a method for.
    ///   - service: The service that `method` belongs to.
    ///   - codeGenerationRequest: Source of the serializer/deserializer lookups used by the convenience body.
    ///   - includeBody: `true` builds the protocol-extension convenience: no serialization parameters and a body
    ///     that forwards to the requirement. `false` builds the bodiless requirement, commented with the method's
    ///     documentation.
    ///   - accessModifier: Access modifier for the declaration; `nil` omits it.
    ///   - includeDefaultCallOptions: Whether the `options` parameter defaults to `.defaults`.
    /// - Returns: A generic `async throws` function returning `R`, constrained by `R: Sendable`.
    private func makeClientProtocolMethod(
      for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
      in service: CodeGenerationRequest.ServiceDescriptor,
      from codeGenerationRequest: CodeGenerationRequest,
      includeBody: Bool,
      accessModifier: AccessModifier? = nil,
      includeDefaultCallOptions: Bool
    ) -> Declaration {
      let isProtocolExtension = includeBody
      let methodParameters = self.makeParameters(
        for: method,
        in: service,
        from: codeGenerationRequest,
        // The serializer/deserializer for the protocol extension method are supplied by its body.
        includeSerializationParameters: !isProtocolExtension,
        includeDefaultCallOptions: includeDefaultCallOptions
      )

      let functionSignature = FunctionSignatureDescription(
        accessModifier: accessModifier,
        kind: .function(name: method.name.generatedLowerCase, isStatic: false),
        generics: [.member("R")],
        parameters: methodParameters,
        keywords: [.async, .throws],
        returnType: .identifierType(.member("R")),
        whereClause: WhereClause(requirements: [.conformance("R", "Sendable")])
      )

      guard includeBody else {
        return .commentable(
          .preFormatted(method.documentation),
          .function(signature: functionSignature)
        )
      }

      let body = self.makeClientProtocolMethodCall(
        for: method,
        in: service,
        from: codeGenerationRequest
      )
      return .function(signature: functionSignature, body: body)
    }

    /// Makes the body of a protocol-extension convenience method, i.e.
    /// `try await self.<method>(request: request, serializer: <lookup>, deserializer: <lookup>, options: options, body)`.
    private func makeClientProtocolMethodCall(
      for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
      in service: CodeGenerationRequest.ServiceDescriptor,
      from codeGenerationRequest: CodeGenerationRequest
    ) -> [CodeBlock] {
      let serializer = codeGenerationRequest.lookupSerializer(method.inputType)
      let deserializer = codeGenerationRequest.lookupDeserializer(method.outputType)

      let functionCall = Expression.functionCall(
        calledExpression: .memberAccess(
          MemberAccessDescription(
            left: .identifierPattern("self"),
            right: method.name.generatedLowerCase
          )
        ),
        arguments: [
          FunctionArgumentDescription(label: "request", expression: .identifierPattern("request")),
          FunctionArgumentDescription(label: "serializer", expression: .identifierPattern(serializer)),
          FunctionArgumentDescription(label: "deserializer", expression: .identifierPattern(deserializer)),
          FunctionArgumentDescription(label: "options", expression: .identifierPattern("options")),
          FunctionArgumentDescription(expression: .identifierPattern("body")),
        ]
      )

      let tryAwaitFunctionCall = Expression.unaryKeyword(
        kind: .try,
        expression: .unaryKeyword(kind: .await, expression: functionCall)
      )
      return [CodeBlock(item: .expression(tryAwaitFunctionCall))]
    }

    /// Makes the parameter list for a client method, in this order:
    /// `request`, optionally `serializer` and `deserializer`, `options`, and the unlabelled `body` closure.
    ///
    /// - Parameters:
    ///   - includeSerializationParameters: Whether to include the `serializer` and `deserializer` parameters.
    ///   - includeDefaultCallOptions: Whether `options` gets the default value `.defaults`.
    private func makeParameters(
      for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
      in service: CodeGenerationRequest.ServiceDescriptor,
      from codeGenerationRequest: CodeGenerationRequest,
      includeSerializationParameters: Bool,
      includeDefaultCallOptions: Bool
    ) -> [ParameterDescription] {
      var parameters: [ParameterDescription] = [self.clientRequestParameter(for: method, in: service)]

      if includeSerializationParameters {
        parameters.append(self.serializerParameter(for: method, in: service))
        parameters.append(self.deserializerParameter(for: method, in: service))
      }

      let defaultCallOptions: Expression? =
        includeDefaultCallOptions ? .memberAccess(MemberAccessDescription(right: "defaults")) : nil
      parameters.append(
        ParameterDescription(
          label: "options",
          type: .member("CallOptions"),
          defaultValue: defaultCallOptions
        )
      )

      parameters.append(self.bodyParameter(for: method, in: service))
      return parameters
    }

    /// Makes the `request` parameter: `ClientRequest.Stream<Input>` for client-streaming methods,
    /// `ClientRequest.Single<Input>` otherwise.
    private func clientRequestParameter(
      for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
      in service: CodeGenerationRequest.ServiceDescriptor
    ) -> ParameterDescription {
      let requestKind = method.isInputStreaming ? "Stream" : "Single"
      return ParameterDescription(
        label: "request",
        type: .generic(
          wrapper: .member(["ClientRequest", requestKind]),
          wrapped: .member(method.inputType)
        )
      )
    }

    /// Makes the `serializer: some MessageSerializer<Input>` parameter.
    private func serializerParameter(
      for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
      in service: CodeGenerationRequest.ServiceDescriptor
    ) -> ParameterDescription {
      return ParameterDescription(
        label: "serializer",
        type: ExistingTypeDescription.some(
          .generic(
            wrapper: .member("MessageSerializer"),
            wrapped: .member(method.inputType)
          )
        )
      )
    }

    /// Makes the `deserializer: some MessageDeserializer<Output>` parameter.
    private func deserializerParameter(
      for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
      in service: CodeGenerationRequest.ServiceDescriptor
    ) -> ParameterDescription {
      return ParameterDescription(
        label: "deserializer",
        type: ExistingTypeDescription.some(
          .generic(
            wrapper: .member("MessageDeserializer"),
            wrapped: .member(method.outputType)
          )
        )
      )
    }

    /// Makes the unlabelled `body` parameter: a `@Sendable @escaping` closure taking
    /// `ClientResponse.Stream<Output>` (server-streaming methods) or `ClientResponse.Single<Output>` and
    /// returning `R` asynchronously, possibly throwing.
    private func bodyParameter(
      for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
      in service: CodeGenerationRequest.ServiceDescriptor
    ) -> ParameterDescription {
      // The response shape depends on whether the *server* streams, i.e. on the output side.
      let responseKind = method.isOutputStreaming ? "Stream" : "Single"
      let responseType = ExistingTypeDescription.generic(
        wrapper: .member(["ClientResponse", responseKind]),
        wrapped: .member(method.outputType)
      )
      let bodyClosure = ClosureSignatureDescription(
        parameters: [.init(type: responseType)],
        keywords: [.async, .throws],
        returnType: .identifierType(.member("R")),
        sendable: true,
        escaping: true
      )
      return ParameterDescription(name: "body", type: .closure(bodyClosure))
    }

    /// Makes the `<namespacedGeneratedName>Client` struct.
    ///
    /// The struct conforms to `<namespacedGeneratedName>.ClientProtocol`, stores a private
    /// `GRPCCore.GRPCClient`, and has one method per RPC, each commented with the method's documentation.
    /// It is wrapped in the availability guard.
    private func makeClientStruct(
      for service: CodeGenerationRequest.ServiceDescriptor,
      in codeGenerationRequest: CodeGenerationRequest
    ) -> Declaration {
      let clientProperty = Declaration.variable(
        accessModifier: .private,
        kind: .let,
        left: "client",
        type: .member(["GRPCCore", "GRPCClient"])
      )
      let initializer = self.makeClientInitializer()
      let methods = service.methods.map { method in
        Declaration.commentable(
          .preFormatted(method.documentation),
          self.makeClientMethod(for: method, in: service, from: codeGenerationRequest)
        )
      }

      let clientStruct = Declaration.struct(
        StructDescription(
          accessModifier: self.accessModifier,
          name: "\(service.namespacedGeneratedName)Client",
          conformances: ["\(service.namespacedGeneratedName).ClientProtocol"],
          members: [clientProperty, initializer] + methods
        )
      )
      return .guarded(self.availabilityGuard, clientStruct)
    }

    /// Makes `init(client: GRPCCore.GRPCClient)`, which stores its argument in `self.client`.
    private func makeClientInitializer() -> Declaration {
      let assignClient = Expression.assignment(
        left: .memberAccess(
          MemberAccessDescription(left: .identifierPattern("self"), right: "client")
        ),
        right: .identifierPattern("client")
      )
      return .function(
        signature: .init(
          accessModifier: self.accessModifier,
          kind: .initializer,
          parameters: [.init(label: "client", type: .member(["GRPCCore", "GRPCClient"]))]
        ),
        body: [CodeBlock(item: .expression(assignClient))]
      )
    }

    /// Returns the name of the `GRPCClient` method used for an RPC with the given streaming configuration.
    private func clientMethod(
      isInputStreaming: Bool,
      isOutputStreaming: Bool
    ) -> String {
      switch (isInputStreaming, isOutputStreaming) {
      case (true, true):
        return "bidirectionalStreaming"
      case (true, false):
        return "clientStreaming"
      case (false, true):
        return "serverStreaming"
      case (false, false):
        return "unary"
      }
    }

    /// Makes the client struct's implementation of `method`.
    ///
    /// The generated body is
    /// `try await self.client.<rpcKind>(request:descriptor:serializer:deserializer:options:handler:)`, where
    /// `<rpcKind>` comes from `clientMethod(isInputStreaming:isOutputStreaming:)`, and the descriptor is
    /// `<namespacedGeneratedName>.Method.<MethodName>.descriptor`.
    private func makeClientMethod(
      for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
      in service: CodeGenerationRequest.ServiceDescriptor,
      from codeGenerationRequest: CodeGenerationRequest
    ) -> Declaration {
      let parameters = self.makeParameters(
        for: method,
        in: service,
        from: codeGenerationRequest,
        includeSerializationParameters: true,
        includeDefaultCallOptions: true
      )
      let grpcMethodName = self.clientMethod(
        isInputStreaming: method.isInputStreaming,
        isOutputStreaming: method.isOutputStreaming
      )

      // `self.client`, built as a real member access rather than an identifier containing a dot.
      let client = Expression.memberAccess(
        MemberAccessDescription(left: .identifierPattern("self"), right: "client")
      )
      let functionCall = Expression.functionCall(
        calledExpression: .memberAccess(
          MemberAccessDescription(left: client, right: grpcMethodName)
        ),
        arguments: [
          .init(label: "request", expression: .identifierPattern("request")),
          .init(
            label: "descriptor",
            expression: .identifierPattern(
              "\(service.namespacedGeneratedName).Method.\(method.name.generatedUpperCase).descriptor"
            )
          ),
          .init(label: "serializer", expression: .identifierPattern("serializer")),
          .init(label: "deserializer", expression: .identifierPattern("deserializer")),
          .init(label: "options", expression: .identifierPattern("options")),
          .init(label: "handler", expression: .identifierPattern("body")),
        ]
      )
      let tryAwaitFunctionCall = Expression.unaryKeyword(
        kind: .try,
        expression: .unaryKeyword(kind: .await, expression: functionCall)
      )

      return .function(
        accessModifier: self.accessModifier,
        kind: .function(name: method.name.generatedLowerCase, isStatic: false),
        generics: [.member("R")],
        parameters: parameters,
        keywords: [.async, .throws],
        returnType: .identifierType(.member("R")),
        whereClause: WhereClause(requirements: [.conformance("R", "Sendable")]),
        body: [.expression(tryAwaitFunctionCall)]
      )
    }
  }