ServerCodeTranslator.swift 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524
  1. /*
  2. * Copyright 2023, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  /// Creates a representation for the server code that will be generated based on the ``CodeGenerationRequest`` object
  /// specifications, using types from ``StructuredSwiftRepresentation``.
  ///
  /// For every service in the request, four declarations are produced, in this order:
  /// 1. a streaming protocol declaring every method as bidirectional streaming;
  /// 2. an extension of the streaming protocol implementing `registerMethods(with:)`, the
  ///    `GRPCCore.RegistrableRPCService` requirement, which registers one handler per method;
  /// 3. a service protocol, refining the streaming protocol, declaring every method with the shape described in the
  ///    source IDL (unary, client streaming, server streaming or bidirectional streaming);
  /// 4. an extension of the service protocol that implements the streaming-protocol requirement of every method that
  ///    is not bidirectional streaming, by adapting between single and streaming requests/responses.
  ///
  /// For example, in the case of a service called "Bar", in the "foo" namespace, which has
  /// one unary method "baz", the ``ServerCodeTranslator`` will create
  /// a representation for the following generated code (the serializer and deserializer expressions are whatever
  /// the request's `lookupSerializer`/`lookupDeserializer` return; Protobuf ones are shown here):
  ///
  /// ```swift
  /// @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  /// public protocol foo_BarStreamingServiceProtocol: GRPCCore.RegistrableRPCService {
  ///   func baz(
  ///     request: ServerRequest.Stream<foo.Bar.Method.baz.Input>
  ///   ) async throws -> ServerResponse.Stream<foo.Bar.Method.baz.Output>
  /// }
  /// /// Conformance to `GRPCCore.RegistrableRPCService`.
  /// @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  /// extension foo.Bar.StreamingServiceProtocol {
  ///   @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  ///   public func registerMethods(with router: inout GRPCCore.RPCRouter) {
  ///     router.registerHandler(
  ///       forMethod: foo.Bar.Method.baz.descriptor,
  ///       deserializer: ProtobufDeserializer<foo.Bar.Method.baz.Input>(),
  ///       serializer: ProtobufSerializer<foo.Bar.Method.baz.Output>(),
  ///       handler: { request in try await self.baz(request: request) }
  ///     )
  ///   }
  /// }
  /// @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  /// public protocol foo_BarServiceProtocol: foo.Bar.StreamingServiceProtocol {
  ///   func baz(
  ///     request: ServerRequest.Single<foo.Bar.Method.baz.Input>
  ///   ) async throws -> ServerResponse.Single<foo.Bar.Method.baz.Output>
  /// }
  /// /// Partial conformance to `foo_BarStreamingServiceProtocol`.
  /// @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
  /// extension foo.Bar.ServiceProtocol {
  ///   public func baz(
  ///     request: ServerRequest.Stream<foo.Bar.Method.baz.Input>
  ///   ) async throws -> ServerResponse.Stream<foo.Bar.Method.baz.Output> {
  ///     let response = try await self.baz(request: ServerRequest.Single(stream: request))
  ///     return ServerResponse.Stream(single: response)
  ///   }
  /// }
  /// ```
  struct ServerCodeTranslator: SpecializedTranslator {
    /// The access level applied to the generated protocols and implemented functions.
    var accessLevel: SourceGenerator.Configuration.AccessLevel

    init(accessLevel: SourceGenerator.Configuration.AccessLevel) {
      self.accessLevel = accessLevel
    }

    /// Translates every service of `codeGenerationRequest` into server-side declarations.
    ///
    /// - Parameter codeGenerationRequest: The request describing the services to generate code for.
    /// - Returns: Four code blocks per service, in service order: the streaming protocol, its
    ///   `GRPCCore.RegistrableRPCService` conformance extension, the service protocol, and the service protocol's
    ///   extension providing partial conformance to the streaming protocol.
    /// - Note: This implementation never throws; `throws` presumably comes from the `SpecializedTranslator`
    ///   requirement — confirm before removing it.
    func translate(from codeGenerationRequest: CodeGenerationRequest) throws -> [CodeBlock] {
      var codeBlocks = [CodeBlock]()
      for service in codeGenerationRequest.services {
        // The streaming protocol declares every service method as bidirectional streaming.
        let streamingProtocol = CodeBlockItem.declaration(self.makeStreamingProtocol(for: service))
        codeBlocks.append(CodeBlock(item: streamingProtocol))

        // Extension implementing 'registerMethods(with:)', the 'RegistrableRPCService' requirement.
        let conformanceToRPCServiceExtension = CodeBlockItem.declaration(
          self.makeConformanceToRPCServiceExtension(for: service, in: codeGenerationRequest)
        )
        codeBlocks.append(
          CodeBlock(
            comment: .doc("Conformance to `GRPCCore.RegistrableRPCService`."),
            item: conformanceToRPCServiceExtension
          )
        )

        // The service protocol declares the methods as they are described in the source IDL (unary,
        // client/server streaming or bidirectional streaming).
        let serviceProtocol = CodeBlockItem.declaration(self.makeServiceProtocol(for: service))
        codeBlocks.append(CodeBlock(item: serviceProtocol))

        // Extension giving the service protocol partial conformance to the streaming protocol.
        let extensionServiceProtocol = CodeBlockItem.declaration(
          self.makeExtensionServiceProtocol(for: service)
        )
        codeBlocks.append(
          CodeBlock(
            comment: .doc(
              "Partial conformance to `\(self.protocolName(service: service, streaming: true))`."
            ),
            item: extensionServiceProtocol
          )
        )
      }
      return codeBlocks
    }
  }
  extension ServerCodeTranslator {
    /// Builds the streaming protocol declaration for `service`.
    ///
    /// The protocol is named `protocolName(service:streaming: true)` and refines `GRPCCore.RegistrableRPCService`.
    /// It declares one bidirectional-streaming requirement per method of the service. The service and method
    /// documentation from the request is attached as pre-formatted comments. The protocol is wrapped in the
    /// translator's availability guard.
    private func makeStreamingProtocol(
      for service: CodeGenerationRequest.ServiceDescriptor
    ) -> Declaration {
      // Every method yields exactly one requirement, so this is a plain `map` (never drops elements).
      let methods = service.methods.map { method in
        Declaration.commentable(
          .preFormatted(method.documentation),
          .function(
            FunctionDescription(signature: self.makeStreamingMethodSignature(for: method, in: service))
          )
        )
      }
      let streamingProtocol = Declaration.protocol(
        .init(
          accessModifier: self.accessModifier,
          name: self.protocolName(service: service, streaming: true),
          conformances: ["GRPCCore.RegistrableRPCService"],
          members: methods
        )
      )
      return .commentable(
        .preFormatted(service.documentation),
        .guarded(self.availabilityGuard, streamingProtocol)
      )
    }

    /// Builds the bidirectional-streaming signature of `method`:
    /// `func <name>(request: ServerRequest.Stream<Input>) async throws -> ServerResponse.Stream<Output>`.
    ///
    /// - Parameters:
    ///   - method: The method whose signature is generated.
    ///   - service: The service owning `method`, used to resolve its `Input`/`Output` typealiases.
    ///   - accessModifier: The function's access modifier; `nil` (the default) emits none, as for protocol
    ///     requirements.
    private func makeStreamingMethodSignature(
      for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
      in service: CodeGenerationRequest.ServiceDescriptor,
      accessModifier: AccessModifier? = nil
    ) -> FunctionSignatureDescription {
      return self.makeMethodSignature(
        for: method,
        in: service,
        requestKind: "Stream",
        responseKind: "Stream",
        accessModifier: accessModifier
      )
    }

    /// Shared builder for handler signatures:
    /// `func <name>(request: ServerRequest.<requestKind><Input>) async throws -> ServerResponse.<responseKind><Output>`.
    ///
    /// - Parameters:
    ///   - method: The method whose signature is generated; its lower-case generated name is the function name.
    ///   - service: The service owning `method`, used to resolve its `Input`/`Output` typealiases.
    ///   - requestKind: The `ServerRequest` member wrapping the input type (`"Stream"` or `"Single"`).
    ///   - responseKind: The `ServerResponse` member wrapping the output type (`"Stream"` or `"Single"`).
    ///   - accessModifier: The function's access modifier, or `nil` for none.
    private func makeMethodSignature(
      for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
      in service: CodeGenerationRequest.ServiceDescriptor,
      requestKind: String,
      responseKind: String,
      accessModifier: AccessModifier?
    ) -> FunctionSignatureDescription {
      let inputType = self.methodInputOutputTypealias(for: method, service: service, type: .input)
      let outputType = self.methodInputOutputTypealias(for: method, service: service, type: .output)
      return FunctionSignatureDescription(
        accessModifier: accessModifier,
        kind: .function(name: method.name.generatedLowerCase),
        parameters: [
          .init(
            label: "request",
            type: .generic(
              wrapper: .member(["ServerRequest", requestKind]),
              wrapped: .member(inputType)
            )
          )
        ],
        keywords: [.async, .throws],
        returnType: .identifierType(
          .generic(
            wrapper: .member(["ServerResponse", responseKind]),
            wrapped: .member(outputType)
          )
        )
      )
    }

    /// Builds the extension on the streaming protocol's typealias that holds `registerMethods(with:)`.
    ///
    /// The extension is wrapped in the translator's availability guard.
    private func makeConformanceToRPCServiceExtension(
      for service: CodeGenerationRequest.ServiceDescriptor,
      in codeGenerationRequest: CodeGenerationRequest
    ) -> Declaration {
      let streamingProtocol = self.protocolNameTypealias(service: service, streaming: true)
      let registerMethods = self.makeRegisterRPCsMethod(for: service, in: codeGenerationRequest)
      return .guarded(
        self.availabilityGuard,
        .extension(onType: streamingProtocol, declarations: [registerMethods])
      )
    }

    /// Builds `func registerMethods(with router: inout GRPCCore.RPCRouter)`.
    ///
    /// The function uses the translator's access modifier. It is wrapped in the availability guard, in addition to
    /// the guard on the enclosing extension. Its body registers one handler per method of `service`.
    private func makeRegisterRPCsMethod(
      for service: CodeGenerationRequest.ServiceDescriptor,
      in codeGenerationRequest: CodeGenerationRequest
    ) -> Declaration {
      let signature = FunctionSignatureDescription(
        accessModifier: self.accessModifier,
        kind: .function(name: "registerMethods"),
        parameters: [
          .init(
            label: "with",
            name: "router",
            type: .member(["GRPCCore", "RPCRouter"]),
            `inout`: true
          )
        ]
      )
      let body = self.makeRegisterRPCsMethodBody(for: service, in: codeGenerationRequest)
      return .guarded(
        self.availabilityGuard,
        .function(signature: signature, body: body)
      )
    }

    /// Builds one `router.registerHandler(...)` call per method of `service`, in method order.
    private func makeRegisterRPCsMethodBody(
      for service: CodeGenerationRequest.ServiceDescriptor,
      in codeGenerationRequest: CodeGenerationRequest
    ) -> [CodeBlock] {
      // Each method produces exactly one call, so `map` rather than `compactMap`.
      return service.methods.map { method in
        CodeBlock.expression(
          Expression.functionCall(
            calledExpression: .memberAccess(
              MemberAccessDescription(left: .identifierPattern("router"), right: "registerHandler")
            ),
            arguments: self.makeArgumentsForRegisterHandler(
              for: method,
              in: service,
              from: codeGenerationRequest
            )
          )
        )
      }
    }

    /// Builds the arguments of `router.registerHandler` for `method`.
    ///
    /// The arguments are, in order:
    /// - `forMethod:`: the method descriptor path;
    /// - `deserializer:` / `serializer:`: the expressions returned by the request's lookups for the input and
    ///   output typealiases;
    /// - `handler:`: the closure `{ request in try await self.<method>(request: request) }`.
    private func makeArgumentsForRegisterHandler(
      for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
      in service: CodeGenerationRequest.ServiceDescriptor,
      from codeGenerationRequest: CodeGenerationRequest
    ) -> [FunctionArgumentDescription] {
      let descriptorPath = self.methodDescriptorPath(for: method, service: service)
      let deserializer = codeGenerationRequest.lookupDeserializer(
        self.methodInputOutputTypealias(for: method, service: service, type: .input)
      )
      let serializer = codeGenerationRequest.lookupSerializer(
        self.methodInputOutputTypealias(for: method, service: service, type: .output)
      )

      // `self.<method>(request: request)`, forwarded from the handler closure's argument.
      let handlerCall = Expression.functionCall(
        calledExpression: .memberAccess(
          MemberAccessDescription(
            left: .identifierPattern("self"),
            right: method.name.generatedLowerCase
          )
        ),
        arguments: [
          FunctionArgumentDescription(label: "request", expression: .identifierPattern("request"))
        ]
      )
      let handlerClosureBody = Expression.unaryKeyword(
        kind: .try,
        expression: .unaryKeyword(kind: .await, expression: handlerCall)
      )
      let handler = Expression.closureInvocation(
        .init(argumentNames: ["request"], body: [.expression(handlerClosureBody)])
      )

      return [
        .init(label: "forMethod", expression: .identifierPattern(descriptorPath)),
        .init(label: "deserializer", expression: .identifierPattern(deserializer)),
        .init(label: "serializer", expression: .identifierPattern(serializer)),
        .init(label: "handler", expression: handler),
      ]
    }

    /// Builds the service protocol for `service`.
    ///
    /// The protocol is named `protocolName(service:streaming: false)` and refines the streaming protocol's typealias.
    /// It declares each method with the request/response shape from the IDL.
    private func makeServiceProtocol(
      for service: CodeGenerationRequest.ServiceDescriptor
    ) -> Declaration {
      // Every method yields exactly one requirement, so `map` rather than `compactMap`.
      let methods = service.methods.map { method in
        self.makeServiceProtocolMethod(for: method, in: service)
      }
      let protocolName = self.protocolName(service: service, streaming: false)
      let streamingProtocol = self.protocolNameTypealias(service: service, streaming: true)
      return .commentable(
        .preFormatted(service.documentation),
        .guarded(
          self.availabilityGuard,
          .protocol(
            ProtocolDescription(
              accessModifier: self.accessModifier,
              name: protocolName,
              conformances: [streamingProtocol],
              members: methods
            )
          )
        )
      )
    }

    /// Builds the service-protocol requirement for `method`.
    ///
    /// The request is wrapped in `ServerRequest.Stream` when the input is streaming and in `ServerRequest.Single`
    /// otherwise. The response follows the same rule with `ServerResponse` and the output. The method's
    /// documentation is attached as a pre-formatted comment.
    private func makeServiceProtocolMethod(
      for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
      in service: CodeGenerationRequest.ServiceDescriptor,
      accessModifier: AccessModifier? = nil
    ) -> Declaration {
      let signature = self.makeMethodSignature(
        for: method,
        in: service,
        requestKind: method.isInputStreaming ? "Stream" : "Single",
        responseKind: method.isOutputStreaming ? "Stream" : "Single",
        accessModifier: accessModifier
      )
      return .commentable(
        .preFormatted(method.documentation),
        .function(FunctionDescription(signature: signature))
      )
    }

    /// Builds the extension on the service protocol's typealias that implements the streaming-protocol requirements.
    ///
    /// Bidirectional-streaming methods contribute no member; see `makeServiceProtocolExtensionMethod`.
    private func makeExtensionServiceProtocol(
      for service: CodeGenerationRequest.ServiceDescriptor
    ) -> Declaration {
      // `compactMap` is intentional here: bidirectional-streaming methods produce `nil`.
      let methods = service.methods.compactMap { method in
        self.makeServiceProtocolExtensionMethod(for: method, in: service)
      }
      let protocolName = self.protocolNameTypealias(service: service, streaming: false)
      return .guarded(
        self.availabilityGuard,
        .extension(onType: protocolName, declarations: methods)
      )
    }

    /// Builds the streaming-signature implementation of `method` that adapts to the service-protocol requirement.
    ///
    /// - Returns: `nil` for bidirectional-streaming methods, whose requirement has the same signature in both
    ///   protocols, so there is nothing to adapt. Otherwise, a function with the translator's access modifier whose
    ///   body is `let response = ...` followed by a `return`.
    private func makeServiceProtocolExtensionMethod(
      for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
      in service: CodeGenerationRequest.ServiceDescriptor
    ) -> Declaration? {
      guard !(method.isInputStreaming && method.isOutputStreaming) else {
        return nil
      }
      let response = CodeBlock(item: .declaration(self.makeResponse(for: method)))
      let returnStatement = CodeBlock(item: .expression(self.makeReturnStatement(for: method)))
      return .function(
        signature: self.makeStreamingMethodSignature(
          for: method,
          in: service,
          accessModifier: self.accessModifier
        ),
        body: [response, returnStatement]
      )
    }

    /// Builds `let response = try await self.<method>(request: <request>)`.
    ///
    /// `<request>` is `ServerRequest.Single(stream: request)` when the method's input is not streaming. Otherwise it
    /// is the incoming `request` unchanged.
    private func makeResponse(
      for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor
    ) -> Declaration {
      let serverRequest: Expression
      if method.isInputStreaming {
        serverRequest = Expression.identifierPattern("request")
      } else {
        // Transform the streaming request into a single request.
        serverRequest = Expression.functionCall(
          calledExpression: .memberAccess(
            MemberAccessDescription(
              left: .identifierPattern("ServerRequest"),
              right: "Single"
            )
          ),
          arguments: [
            FunctionArgumentDescription(label: "stream", expression: .identifierPattern("request"))
          ]
        )
      }
      // Call the corresponding service-protocol method.
      let serviceProtocolMethod = Expression.functionCall(
        calledExpression: .memberAccess(
          MemberAccessDescription(
            left: .identifierPattern("self"),
            right: method.name.generatedLowerCase
          )
        ),
        arguments: [FunctionArgumentDescription(label: "request", expression: serverRequest)]
      )
      let responseValue = Expression.unaryKeyword(
        kind: .try,
        expression: .unaryKeyword(kind: .await, expression: serviceProtocolMethod)
      )
      return .variable(kind: .let, left: "response", right: responseValue)
    }

    /// Builds the `return` statement of the adapting implementation.
    ///
    /// It returns `ServerResponse.Stream(single: response)` when the method's output is not streaming, and
    /// `response` unchanged otherwise.
    private func makeReturnStatement(
      for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor
    ) -> Expression {
      let returnValue: Expression
      if method.isOutputStreaming {
        returnValue = .identifierPattern("response")
      } else {
        // Transform the single response into a streaming one.
        returnValue = .functionCall(
          calledExpression: .memberAccess(
            MemberAccessDescription(
              left: .identifierType(.member(["ServerResponse"])),
              right: "Stream"
            )
          ),
          arguments: [
            FunctionArgumentDescription(label: "single", expression: .identifierPattern("response"))
          ]
        )
      }
      return .unaryKeyword(kind: .return, expression: returnValue)
    }

    /// Selects which of a method's message typealiases to reference.
    fileprivate enum InputOutputType {
      case input
      case output
    }

    /// The fully qualified path of a method's namespace: `<service typealias>.Method.<MethodName>`.
    private func methodPath(
      for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
      service: CodeGenerationRequest.ServiceDescriptor
    ) -> String {
      return "\(service.namespacedTypealiasGeneratedName).Method.\(method.name.generatedUpperCase)"
    }

    /// Generates the fully qualified name of the typealias for the input or output type of a method,
    /// i.e. `<service typealias>.Method.<MethodName>.Input` or `.Output`.
    private func methodInputOutputTypealias(
      for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
      service: CodeGenerationRequest.ServiceDescriptor,
      type: InputOutputType
    ) -> String {
      let suffix: String
      switch type {
      case .input:
        suffix = "Input"
      case .output:
        suffix = "Output"
      }
      return "\(self.methodPath(for: method, service: service)).\(suffix)"
    }

    /// Generates the fully qualified name of a method descriptor:
    /// `<service typealias>.Method.<MethodName>.descriptor`.
    private func methodDescriptorPath(
      for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
      service: CodeGenerationRequest.ServiceDescriptor
    ) -> String {
      return "\(self.methodPath(for: method, service: service)).descriptor"
    }

    /// Generates the fully qualified name of the typealias for a service protocol:
    /// `<service typealias>.StreamingServiceProtocol` or `<service typealias>.ServiceProtocol`.
    internal func protocolNameTypealias(
      service: CodeGenerationRequest.ServiceDescriptor,
      streaming: Bool
    ) -> String {
      let protocolSuffix = streaming ? "StreamingServiceProtocol" : "ServiceProtocol"
      return "\(service.namespacedTypealiasGeneratedName).\(protocolSuffix)"
    }

    /// Generates the declared name of a service protocol:
    /// `<namespaced name>StreamingServiceProtocol` or `<namespaced name>ServiceProtocol`.
    internal func protocolName(
      service: CodeGenerationRequest.ServiceDescriptor,
      streaming: Bool
    ) -> String {
      let protocolSuffix = streaming ? "StreamingServiceProtocol" : "ServiceProtocol"
      return "\(service.namespacedGeneratedName)\(protocolSuffix)"
    }
  }