ServerCodeTranslator.swift 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
  1. /*
  2. * Copyright 2023, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. /// Creates a representation for the server code that will be generated based on the ``CodeGenerationRequest`` object
  17. /// specifications, using types from ``StructuredSwiftRepresentation``.
  18. ///
  19. /// For example, in the case of a service called "Bar", in the "foo" namespace which has
  20. /// one method "baz", the ``ServerCodeTranslator`` will create
  21. /// a representation for the following generated code:
  22. ///
  23. /// ```swift
  24. /// public protocol foo_BarServiceStreamingProtocol: GRPCCore.RegistrableRPCService {
  25. /// func baz(
  26. /// request: ServerRequest.Stream<foo.Method.baz.Input>
  27. /// ) async throws -> ServerResponse.Stream<foo.Method.baz.Output>
  28. /// }
  29. /// // Generated conformance to `RegistrableRPCService`.
  30. /// extension foo.Bar.StreamingServiceProtocol {
  31. /// public func registerRPCs(with router: inout RPCRouter) {
  32. /// router.registerHandler(
  33. /// forMethod: foo.Method.baz.descriptor,
  34. /// deserializer: ProtobufDeserializer<foo.Method.baz.Input>(),
  35. /// serializer: ProtobufSerializer<foo.Method.baz.Output>(),
  36. /// handler: { request in try await self.baz(request: request) }
  37. /// )
  38. /// }
  39. /// }
  40. /// public protocol foo_BarServiceProtocol: foo.Bar.StreamingServiceProtocol {
  41. /// func baz(
  42. /// request: ServerRequest.Single<foo.Bar.Method.baz.Input>
  43. /// ) async throws -> ServerResponse.Single<foo.Bar.Method.baz.Output>
  44. /// }
  45. /// // Generated partial conformance to `foo_BarStreamingServiceProtocol`.
  46. /// extension foo.Bar.ServiceProtocol {
  47. /// public func baz(
  48. /// request: ServerRequest.Stream<foo.Bar.Method.baz.Input>
  49. /// ) async throws -> ServerResponse.Stream<foo.Bar.Method.baz.Output> {
  50. /// let response = try await self.baz(request: ServerRequest.Single(stream: request)
  51. /// return ServerResponse.Stream(single: response)
  52. /// }
  53. /// }
  54. ///```
  55. struct ServerCodeTranslator: SpecializedTranslator {
  56. var accessLevel: SourceGenerator.Configuration.AccessLevel
  57. init(accessLevel: SourceGenerator.Configuration.AccessLevel) {
  58. self.accessLevel = accessLevel
  59. }
  60. func translate(from codeGenerationRequest: CodeGenerationRequest) throws -> [CodeBlock] {
  61. var codeBlocks = [CodeBlock]()
  62. for service in codeGenerationRequest.services {
  63. // Create the streaming protocol that declares the service methods as bidirectional streaming.
  64. let streamingProtocol = CodeBlockItem.declaration(self.makeStreamingProtocol(for: service))
  65. codeBlocks.append(CodeBlock(item: streamingProtocol))
  66. // Create extension for implementing the 'registerRPCs' function which is a 'RegistrableRPCService' requirement.
  67. let conformanceToRPCServiceExtension = CodeBlockItem.declaration(
  68. self.makeConformanceToRPCServiceExtension(for: service, in: codeGenerationRequest)
  69. )
  70. codeBlocks.append(
  71. CodeBlock(
  72. comment: .doc("Conformance to `GRPCCore.RegistrableRPCService`."),
  73. item: conformanceToRPCServiceExtension
  74. )
  75. )
  76. // Create the service protocol that declares the service methods as they are described in the Source IDL (unary,
  77. // client/server streaming or bidirectional streaming).
  78. let serviceProtocol = CodeBlockItem.declaration(self.makeServiceProtocol(for: service))
  79. codeBlocks.append(CodeBlock(item: serviceProtocol))
  80. // Create extension for partial conformance to the streaming protocol.
  81. let extensionServiceProtocol = CodeBlockItem.declaration(
  82. self.makeExtensionServiceProtocol(for: service)
  83. )
  84. codeBlocks.append(
  85. CodeBlock(
  86. comment: .doc(
  87. "Partial conformance to `\(self.protocolName(service: service, streaming: true))`."
  88. ),
  89. item: extensionServiceProtocol
  90. )
  91. )
  92. }
  93. return codeBlocks
  94. }
  95. }
  96. extension ServerCodeTranslator {
  97. private func makeStreamingProtocol(
  98. for service: CodeGenerationRequest.ServiceDescriptor
  99. ) -> Declaration {
  100. let methods = service.methods.compactMap {
  101. Declaration.commentable(
  102. .preFormatted($0.documentation),
  103. .function(
  104. FunctionDescription(
  105. signature: self.makeStreamingMethodSignature(for: $0, in: service)
  106. )
  107. )
  108. )
  109. }
  110. let streamingProtocol = Declaration.protocol(
  111. .init(
  112. accessModifier: self.accessModifier,
  113. name: self.protocolName(service: service, streaming: true),
  114. conformances: ["GRPCCore.RegistrableRPCService"],
  115. members: methods
  116. )
  117. )
  118. return .commentable(
  119. .preFormatted(service.documentation),
  120. .guarded(self.availabilityGuard, streamingProtocol)
  121. )
  122. }
  123. private func makeStreamingMethodSignature(
  124. for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
  125. in service: CodeGenerationRequest.ServiceDescriptor,
  126. accessModifier: AccessModifier? = nil
  127. ) -> FunctionSignatureDescription {
  128. return FunctionSignatureDescription(
  129. accessModifier: accessModifier,
  130. kind: .function(name: method.name.generatedLowerCase),
  131. parameters: [
  132. .init(
  133. label: "request",
  134. type: .generic(
  135. wrapper: .member(["ServerRequest", "Stream"]),
  136. wrapped: .member(
  137. self.methodInputOutputTypealias(for: method, service: service, type: .input)
  138. )
  139. )
  140. )
  141. ],
  142. keywords: [.async, .throws],
  143. returnType: .identifierType(
  144. .generic(
  145. wrapper: .member(["ServerResponse", "Stream"]),
  146. wrapped: .member(
  147. self.methodInputOutputTypealias(for: method, service: service, type: .output)
  148. )
  149. )
  150. )
  151. )
  152. }
  153. private func makeConformanceToRPCServiceExtension(
  154. for service: CodeGenerationRequest.ServiceDescriptor,
  155. in codeGenerationRequest: CodeGenerationRequest
  156. ) -> Declaration {
  157. let streamingProtocol = self.protocolNameTypealias(service: service, streaming: true)
  158. let registerRPCMethod = self.makeRegisterRPCsMethod(for: service, in: codeGenerationRequest)
  159. return .extension(
  160. onType: streamingProtocol,
  161. declarations: [registerRPCMethod]
  162. )
  163. }
  164. private func makeRegisterRPCsMethod(
  165. for service: CodeGenerationRequest.ServiceDescriptor,
  166. in codeGenerationRequest: CodeGenerationRequest
  167. ) -> Declaration {
  168. let registerRPCsSignature = FunctionSignatureDescription(
  169. accessModifier: self.accessModifier,
  170. kind: .function(name: "registerMethods"),
  171. parameters: [
  172. .init(
  173. label: "with",
  174. name: "router",
  175. type: .member(["GRPCCore", "RPCRouter"]),
  176. `inout`: true
  177. )
  178. ]
  179. )
  180. let registerRPCsBody = self.makeRegisterRPCsMethodBody(for: service, in: codeGenerationRequest)
  181. return .guarded(
  182. self.availabilityGuard,
  183. .function(signature: registerRPCsSignature, body: registerRPCsBody)
  184. )
  185. }
  186. private func makeRegisterRPCsMethodBody(
  187. for service: CodeGenerationRequest.ServiceDescriptor,
  188. in codeGenerationRequest: CodeGenerationRequest
  189. ) -> [CodeBlock] {
  190. let registerHandlerCalls = service.methods.compactMap {
  191. CodeBlock.expression(
  192. Expression.functionCall(
  193. calledExpression: .memberAccess(
  194. MemberAccessDescription(left: .identifierPattern("router"), right: "registerHandler")
  195. ),
  196. arguments: self.makeArgumentsForRegisterHandler(
  197. for: $0,
  198. in: service,
  199. from: codeGenerationRequest
  200. )
  201. )
  202. )
  203. }
  204. return registerHandlerCalls
  205. }
  206. private func makeArgumentsForRegisterHandler(
  207. for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
  208. in service: CodeGenerationRequest.ServiceDescriptor,
  209. from codeGenerationRequest: CodeGenerationRequest
  210. ) -> [FunctionArgumentDescription] {
  211. var arguments = [FunctionArgumentDescription]()
  212. arguments.append(
  213. .init(
  214. label: "forMethod",
  215. expression: .identifierPattern(
  216. self.methodDescriptorPath(for: method, service: service)
  217. )
  218. )
  219. )
  220. arguments.append(
  221. .init(
  222. label: "deserializer",
  223. expression: .identifierPattern(
  224. codeGenerationRequest.lookupDeserializer(
  225. self.methodInputOutputTypealias(for: method, service: service, type: .input)
  226. )
  227. )
  228. )
  229. )
  230. arguments.append(
  231. .init(
  232. label: "serializer",
  233. expression:
  234. .identifierPattern(
  235. codeGenerationRequest.lookupSerializer(
  236. self.methodInputOutputTypealias(for: method, service: service, type: .output)
  237. )
  238. )
  239. )
  240. )
  241. let getFunctionCall = Expression.functionCall(
  242. calledExpression: .memberAccess(
  243. MemberAccessDescription(
  244. left: .identifierPattern("self"),
  245. right: method.name.generatedLowerCase
  246. )
  247. ),
  248. arguments: [
  249. FunctionArgumentDescription(label: "request", expression: .identifierPattern("request"))
  250. ]
  251. )
  252. let handlerClosureBody = Expression.unaryKeyword(
  253. kind: .try,
  254. expression: .unaryKeyword(kind: .await, expression: getFunctionCall)
  255. )
  256. arguments.append(
  257. .init(
  258. label: "handler",
  259. expression: .closureInvocation(
  260. .init(argumentNames: ["request"], body: [.expression(handlerClosureBody)])
  261. )
  262. )
  263. )
  264. return arguments
  265. }
  266. private func makeServiceProtocol(
  267. for service: CodeGenerationRequest.ServiceDescriptor
  268. ) -> Declaration {
  269. let methods = service.methods.compactMap {
  270. self.makeServiceProtocolMethod(for: $0, in: service)
  271. }
  272. let protocolName = self.protocolName(service: service, streaming: false)
  273. let streamingProtocol = self.protocolNameTypealias(service: service, streaming: true)
  274. return .commentable(
  275. .preFormatted(service.documentation),
  276. .protocol(
  277. ProtocolDescription(
  278. accessModifier: self.accessModifier,
  279. name: protocolName,
  280. conformances: [streamingProtocol],
  281. members: methods
  282. )
  283. )
  284. )
  285. }
  286. private func makeServiceProtocolMethod(
  287. for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
  288. in service: CodeGenerationRequest.ServiceDescriptor,
  289. accessModifier: AccessModifier? = nil
  290. ) -> Declaration {
  291. let inputStreaming = method.isInputStreaming ? "Stream" : "Single"
  292. let outputStreaming = method.isOutputStreaming ? "Stream" : "Single"
  293. let inputTypealiasComponents = self.methodInputOutputTypealias(
  294. for: method,
  295. service: service,
  296. type: .input
  297. )
  298. let outputTypealiasComponents = self.methodInputOutputTypealias(
  299. for: method,
  300. service: service,
  301. type: .output
  302. )
  303. let functionSignature = FunctionSignatureDescription(
  304. accessModifier: accessModifier,
  305. kind: .function(name: method.name.generatedLowerCase),
  306. parameters: [
  307. .init(
  308. label: "request",
  309. type:
  310. .generic(
  311. wrapper: .member(["ServerRequest", inputStreaming]),
  312. wrapped: .member(inputTypealiasComponents)
  313. )
  314. )
  315. ],
  316. keywords: [.async, .throws],
  317. returnType: .identifierType(
  318. .generic(
  319. wrapper: .member(["ServerResponse", outputStreaming]),
  320. wrapped: .member(outputTypealiasComponents)
  321. )
  322. )
  323. )
  324. return .commentable(
  325. .preFormatted(method.documentation),
  326. .function(FunctionDescription(signature: functionSignature))
  327. )
  328. }
  329. private func makeExtensionServiceProtocol(
  330. for service: CodeGenerationRequest.ServiceDescriptor
  331. ) -> Declaration {
  332. let methods = service.methods.compactMap {
  333. self.makeServiceProtocolExtensionMethod(for: $0, in: service)
  334. }
  335. let protocolName = self.protocolNameTypealias(service: service, streaming: false)
  336. return .extension(
  337. onType: protocolName,
  338. declarations: methods
  339. )
  340. }
  341. private func makeServiceProtocolExtensionMethod(
  342. for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
  343. in service: CodeGenerationRequest.ServiceDescriptor
  344. ) -> Declaration? {
  345. // The method has the same definition in StreamingServiceProtocol and ServiceProtocol.
  346. if method.isInputStreaming && method.isOutputStreaming {
  347. return nil
  348. }
  349. let response = CodeBlock(item: .declaration(self.makeResponse(for: method)))
  350. let returnStatement = CodeBlock(item: .expression(self.makeReturnStatement(for: method)))
  351. return .function(
  352. signature: self.makeStreamingMethodSignature(
  353. for: method,
  354. in: service,
  355. accessModifier: self.accessModifier
  356. ),
  357. body: [response, returnStatement]
  358. )
  359. }
  360. private func makeResponse(
  361. for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor
  362. ) -> Declaration {
  363. let serverRequest: Expression
  364. if !method.isInputStreaming {
  365. // Transform the streaming request into a unary request.
  366. serverRequest = Expression.functionCall(
  367. calledExpression: .memberAccess(
  368. MemberAccessDescription(
  369. left: .identifierPattern("ServerRequest"),
  370. right: "Single"
  371. )
  372. ),
  373. arguments: [
  374. FunctionArgumentDescription(label: "stream", expression: .identifierPattern("request"))
  375. ]
  376. )
  377. } else {
  378. serverRequest = Expression.identifierPattern("request")
  379. }
  380. // Call to the corresponding ServiceProtocol method.
  381. let serviceProtocolMethod = Expression.functionCall(
  382. calledExpression: .memberAccess(
  383. MemberAccessDescription(
  384. left: .identifierPattern("self"),
  385. right: method.name.generatedLowerCase
  386. )
  387. ),
  388. arguments: [FunctionArgumentDescription(label: "request", expression: serverRequest)]
  389. )
  390. let responseValue = Expression.unaryKeyword(
  391. kind: .try,
  392. expression: .unaryKeyword(kind: .await, expression: serviceProtocolMethod)
  393. )
  394. return .variable(kind: .let, left: "response", right: responseValue)
  395. }
  396. private func makeReturnStatement(
  397. for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor
  398. ) -> Expression {
  399. let returnValue: Expression
  400. // Transforming the unary response into a streaming one.
  401. if !method.isOutputStreaming {
  402. returnValue = .functionCall(
  403. calledExpression: .memberAccess(
  404. MemberAccessDescription(
  405. left: .identifierType(.member(["ServerResponse"])),
  406. right: "Stream"
  407. )
  408. ),
  409. arguments: [
  410. (FunctionArgumentDescription(label: "single", expression: .identifierPattern("response")))
  411. ]
  412. )
  413. } else {
  414. returnValue = .identifierPattern("response")
  415. }
  416. return .unaryKeyword(kind: .return, expression: returnValue)
  417. }
  418. fileprivate enum InputOutputType {
  419. case input
  420. case output
  421. }
  422. /// Generates the fully qualified name of the typealias for the input or output type of a method.
  423. private func methodInputOutputTypealias(
  424. for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
  425. service: CodeGenerationRequest.ServiceDescriptor,
  426. type: InputOutputType
  427. ) -> String {
  428. var components: String =
  429. "\(service.namespacedTypealiasGeneratedName).Method.\(method.name.generatedUpperCase)"
  430. switch type {
  431. case .input:
  432. components.append(".Input")
  433. case .output:
  434. components.append(".Output")
  435. }
  436. return components
  437. }
  438. /// Generates the fully qualified name of a method descriptor.
  439. private func methodDescriptorPath(
  440. for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor,
  441. service: CodeGenerationRequest.ServiceDescriptor
  442. ) -> String {
  443. return
  444. "\(service.namespacedTypealiasGeneratedName).Method.\(method.name.generatedUpperCase).descriptor"
  445. }
  446. /// Generates the fully qualified name of the type alias for a service protocol.
  447. internal func protocolNameTypealias(
  448. service: CodeGenerationRequest.ServiceDescriptor,
  449. streaming: Bool
  450. ) -> String {
  451. if streaming {
  452. return "\(service.namespacedTypealiasGeneratedName).StreamingServiceProtocol"
  453. }
  454. return "\(service.namespacedTypealiasGeneratedName).ServiceProtocol"
  455. }
  456. /// Generates the name of a service protocol.
  457. internal func protocolName(
  458. service: CodeGenerationRequest.ServiceDescriptor,
  459. streaming: Bool
  460. ) -> String {
  461. if streaming {
  462. return "\(service.namespacedGeneratedName)StreamingServiceProtocol"
  463. }
  464. return "\(service.namespacedGeneratedName)ServiceProtocol"
  465. }
  466. }