Generator-Client.swift 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623
  1. /*
  2. * Copyright 2018, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import Foundation
  17. import SwiftProtobuf
  18. import SwiftProtobufPluginLibrary
  19. extension Generator {
  20. internal func printClient() {
  21. if self.options.generateClient {
  22. self.println()
  23. self.printServiceClientProtocol()
  24. self.println()
  25. self.printClientProtocolExtension()
  26. self.println()
  27. self.printServiceClientInterceptorFactoryProtocol()
  28. self.println()
  29. self.printServiceClientInterceptorFactoryProtocolExtension()
  30. self.println()
  31. self.printServiceClientImplementation()
  32. }
  33. if self.options.generateTestClient {
  34. self.println()
  35. self.printTestClient()
  36. }
  37. }
  38. private func printFunction(
  39. name: String,
  40. arguments: [String],
  41. returnType: String?,
  42. access: String? = nil,
  43. bodyBuilder: (() -> Void)?
  44. ) {
  45. // Add a space after access, if it exists.
  46. let accessOrEmpty = access.map { $0 + " " } ?? ""
  47. let `return` = returnType.map { "-> " + $0 } ?? ""
  48. let hasBody = bodyBuilder != nil
  49. if arguments.isEmpty {
  50. // Don't bother splitting across multiple lines if there are no arguments.
  51. self.println("\(accessOrEmpty)func \(name)() \(`return`)", newline: !hasBody)
  52. } else {
  53. self.println("\(accessOrEmpty)func \(name)(")
  54. self.withIndentation {
  55. // Add a comma after each argument except the last.
  56. arguments.forEach(beforeLast: {
  57. self.println($0 + ",")
  58. }, onLast: {
  59. self.println($0)
  60. })
  61. }
  62. self.println(") \(`return`)", newline: !hasBody)
  63. }
  64. if let bodyBuilder = bodyBuilder {
  65. self.println(" {")
  66. self.withIndentation {
  67. bodyBuilder()
  68. }
  69. self.println("}")
  70. }
  71. }
  72. private func printServiceClientProtocol() {
  73. self.println(
  74. "/// Usage: instantiate \(self.clientClassName), then call methods of this protocol to make API calls."
  75. )
  76. self.println("\(self.access) protocol \(self.clientProtocolName): GRPCClient {")
  77. self.withIndentation {
  78. self.println("var interceptors: \(self.clientInterceptorProtocolName)? { get }")
  79. for method in service.methods {
  80. self.println()
  81. self.method = method
  82. self.printFunction(
  83. name: self.methodFunctionName,
  84. arguments: self.methodArgumentsWithoutDefaults,
  85. returnType: self.methodReturnType,
  86. bodyBuilder: nil
  87. )
  88. }
  89. }
  90. println("}")
  91. }
  92. private func printClientProtocolExtension() {
  93. self.println("extension \(self.clientProtocolName) {")
  94. // Default method implementations.
  95. self.withIndentation {
  96. self.printMethods()
  97. }
  98. self.println("}")
  99. }
  100. private func printServiceClientInterceptorFactoryProtocol() {
  101. self.println("\(self.access) protocol \(self.clientInterceptorProtocolName) {")
  102. self.withIndentation {
  103. // Generic interceptor.
  104. self.println("/// Makes an array of generic interceptors. The per-method interceptor")
  105. self.println("/// factories default to calling this function and it therefore provides a")
  106. self.println("/// convenient way of setting interceptors for all methods on a client.")
  107. self.println("/// - Returns: An array of interceptors generic over `Request` and `Response`.")
  108. self.println("/// Defaults to an empty array.")
  109. self.printFunction(
  110. name: "makeInterceptors<Request: SwiftProtobuf.Message, Response: SwiftProtobuf.Message>",
  111. arguments: [],
  112. returnType: "[ClientInterceptor<Request, Response>]",
  113. bodyBuilder: nil
  114. )
  115. // Method specific interceptors.
  116. for method in service.methods {
  117. self.println()
  118. self.method = method
  119. self.println(
  120. "/// - Returns: Interceptors to use when invoking '\(self.methodFunctionName)'."
  121. )
  122. self.println("/// Defaults to calling `self.makeInterceptors()`.")
  123. // Skip the access, we're defining a protocol.
  124. self.printMethodInterceptorFactory(access: nil)
  125. }
  126. }
  127. self.println("}")
  128. }
  129. private func printMethodInterceptorFactory(
  130. access: String?,
  131. bodyBuilder: (() -> Void)? = nil
  132. ) {
  133. self.printFunction(
  134. name: self.methodInterceptorFactoryName,
  135. arguments: [],
  136. returnType: "[ClientInterceptor<\(self.methodInputName), \(self.methodOutputName)>]",
  137. access: access,
  138. bodyBuilder: bodyBuilder
  139. )
  140. }
  141. private func printServiceClientInterceptorFactoryProtocolExtension() {
  142. self.println("extension \(self.clientInterceptorProtocolName) {")
  143. self.withIndentation {
  144. // Default interceptor factory.
  145. self.printFunction(
  146. name: "makeInterceptors<Request: SwiftProtobuf.Message, Response: SwiftProtobuf.Message>",
  147. arguments: [],
  148. returnType: "[ClientInterceptor<Request, Response>]",
  149. access: self.access
  150. ) {
  151. self.println("return []")
  152. }
  153. for method in self.service.methods {
  154. self.println()
  155. self.method = method
  156. self.printMethodInterceptorFactory(access: self.access) {
  157. self.println("return self.makeInterceptors()")
  158. }
  159. }
  160. }
  161. self.println("}")
  162. }
  163. private func printServiceClientImplementation() {
  164. println("\(access) final class \(clientClassName): \(clientProtocolName) {")
  165. self.withIndentation {
  166. println("\(access) let channel: GRPCChannel")
  167. println("\(access) var defaultCallOptions: CallOptions")
  168. println("\(access) var interceptors: \(clientInterceptorProtocolName)?")
  169. println()
  170. println("/// Creates a client for the \(servicePath) service.")
  171. println("///")
  172. self.printParameters()
  173. println("/// - channel: `GRPCChannel` to the service host.")
  174. println(
  175. "/// - defaultCallOptions: Options to use for each service call if the user doesn't provide them."
  176. )
  177. println("/// - interceptors: A factory providing interceptors for each RPC.")
  178. println("\(access) init(")
  179. self.withIndentation {
  180. println("channel: GRPCChannel,")
  181. println("defaultCallOptions: CallOptions = CallOptions(),")
  182. println("interceptors: \(clientInterceptorProtocolName)? = nil")
  183. }
  184. self.println(") {")
  185. self.withIndentation {
  186. println("self.channel = channel")
  187. println("self.defaultCallOptions = defaultCallOptions")
  188. println("self.interceptors = interceptors")
  189. }
  190. self.println("}")
  191. }
  192. println("}")
  193. }
  194. private func printMethods() {
  195. for method in self.service.methods {
  196. self.println()
  197. self.method = method
  198. switch self.streamType {
  199. case .unary:
  200. self.printUnaryCall()
  201. case .serverStreaming:
  202. self.printServerStreamingCall()
  203. case .clientStreaming:
  204. self.printClientStreamingCall()
  205. case .bidirectionalStreaming:
  206. self.printBidirectionalStreamingCall()
  207. }
  208. }
  209. }
  210. private func printUnaryCall() {
  211. self.println(self.method.documentation(streamingType: self.streamType), newline: false)
  212. self.println("///")
  213. self.printParameters()
  214. self.printRequestParameter()
  215. self.printCallOptionsParameter()
  216. self.println("/// - Returns: A `UnaryCall` with futures for the metadata, status and response.")
  217. self.printFunction(
  218. name: self.methodFunctionName,
  219. arguments: self.methodArguments,
  220. returnType: self.methodReturnType,
  221. access: self.access
  222. ) {
  223. self.println("return self.makeUnaryCall(")
  224. self.withIndentation {
  225. self.println("path: \(self.methodPath),")
  226. self.println("request: request,")
  227. self.println("callOptions: callOptions ?? self.defaultCallOptions,")
  228. self.println(
  229. "interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? []"
  230. )
  231. }
  232. self.println(")")
  233. }
  234. }
  235. private func printServerStreamingCall() {
  236. self.println(self.method.documentation(streamingType: self.streamType), newline: false)
  237. self.println("///")
  238. self.printParameters()
  239. self.printRequestParameter()
  240. self.printCallOptionsParameter()
  241. self.printHandlerParameter()
  242. self.println("/// - Returns: A `ServerStreamingCall` with futures for the metadata and status.")
  243. self.printFunction(
  244. name: self.methodFunctionName,
  245. arguments: self.methodArguments,
  246. returnType: self.methodReturnType,
  247. access: self.access
  248. ) {
  249. self.println("return self.makeServerStreamingCall(")
  250. self.withIndentation {
  251. self.println("path: \(self.methodPath),")
  252. self.println("request: request,")
  253. self.println("callOptions: callOptions ?? self.defaultCallOptions,")
  254. self.println(
  255. "interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? [],"
  256. )
  257. self.println("handler: handler")
  258. }
  259. self.println(")")
  260. }
  261. }
  262. private func printClientStreamingCall() {
  263. self.println(self.method.documentation(streamingType: self.streamType), newline: false)
  264. self.println("///")
  265. self.printClientStreamingDetails()
  266. self.println("///")
  267. self.printParameters()
  268. self.printCallOptionsParameter()
  269. self
  270. .println(
  271. "/// - Returns: A `ClientStreamingCall` with futures for the metadata, status and response."
  272. )
  273. self.printFunction(
  274. name: self.methodFunctionName,
  275. arguments: self.methodArguments,
  276. returnType: self.methodReturnType,
  277. access: self.access
  278. ) {
  279. self.println("return self.makeClientStreamingCall(")
  280. self.withIndentation {
  281. self.println("path: \(self.methodPath),")
  282. self.println("callOptions: callOptions ?? self.defaultCallOptions,")
  283. self.println(
  284. "interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? []"
  285. )
  286. }
  287. self.println(")")
  288. }
  289. }
  290. private func printBidirectionalStreamingCall() {
  291. self.println(self.method.documentation(streamingType: self.streamType), newline: false)
  292. self.println("///")
  293. self.printClientStreamingDetails()
  294. self.println("///")
  295. self.printParameters()
  296. self.printCallOptionsParameter()
  297. self.printHandlerParameter()
  298. self.println("/// - Returns: A `ClientStreamingCall` with futures for the metadata and status.")
  299. self.printFunction(
  300. name: self.methodFunctionName,
  301. arguments: self.methodArguments,
  302. returnType: self.methodReturnType,
  303. access: self.access
  304. ) {
  305. self.println("return self.makeBidirectionalStreamingCall(")
  306. self.withIndentation {
  307. self.println("path: \(self.methodPath),")
  308. self.println("callOptions: callOptions ?? self.defaultCallOptions,")
  309. self.println(
  310. "interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? [],"
  311. )
  312. self.println("handler: handler")
  313. }
  314. self.println(")")
  315. }
  316. }
  317. private func printClientStreamingDetails() {
  318. println("/// Callers should use the `send` method on the returned object to send messages")
  319. println(
  320. "/// to the server. The caller should send an `.end` after the final message has been sent."
  321. )
  322. }
  323. private func printParameters() {
  324. println("/// - Parameters:")
  325. }
  326. private func printRequestParameter() {
  327. println("/// - request: Request to send to \(method.name).")
  328. }
  329. private func printCallOptionsParameter() {
  330. println("/// - callOptions: Call options.")
  331. }
  332. private func printHandlerParameter() {
  333. println("/// - handler: A closure called when each response is received from the server.")
  334. }
  335. }
  336. extension Generator {
  337. fileprivate func printFakeResponseStreams() {
  338. for method in self.service.methods {
  339. self.println()
  340. self.method = method
  341. switch self.streamType {
  342. case .unary, .clientStreaming:
  343. self.printUnaryResponse()
  344. case .serverStreaming, .bidirectionalStreaming:
  345. self.printStreamingResponse()
  346. }
  347. }
  348. }
  349. fileprivate func printUnaryResponse() {
  350. self.printResponseStream(isUnary: true)
  351. self.println()
  352. self.printEnqueueUnaryResponse(isUnary: true)
  353. self.println()
  354. self.printHasResponseStreamEnqueued()
  355. }
  356. fileprivate func printStreamingResponse() {
  357. self.printResponseStream(isUnary: false)
  358. self.println()
  359. self.printEnqueueUnaryResponse(isUnary: false)
  360. self.println()
  361. self.printHasResponseStreamEnqueued()
  362. }
  363. private func printEnqueueUnaryResponse(isUnary: Bool) {
  364. let name: String
  365. let responseArg: String
  366. let responseArgAndType: String
  367. if isUnary {
  368. name = "enqueue\(self.method.name)Response"
  369. responseArg = "response"
  370. responseArgAndType = "_ \(responseArg): \(self.methodOutputName)"
  371. } else {
  372. name = "enqueue\(self.method.name)Responses"
  373. responseArg = "responses"
  374. responseArgAndType = "_ \(responseArg): [\(self.methodOutputName)]"
  375. }
  376. self.printFunction(
  377. name: name,
  378. arguments: [
  379. responseArgAndType,
  380. "_ requestHandler: @escaping (FakeRequestPart<\(self.methodInputName)>) -> () = { _ in }",
  381. ],
  382. returnType: nil,
  383. access: self.access
  384. ) {
  385. self.println("let stream = self.make\(self.method.name)ResponseStream(requestHandler)")
  386. if isUnary {
  387. self.println("// This is the only operation on the stream; try! is fine.")
  388. self.println("try! stream.sendMessage(\(responseArg))")
  389. } else {
  390. self.println("// These are the only operation on the stream; try! is fine.")
  391. self.println("\(responseArg).forEach { try! stream.sendMessage($0) }")
  392. self.println("try! stream.sendEnd()")
  393. }
  394. }
  395. }
  396. private func printResponseStream(isUnary: Bool) {
  397. let type = isUnary ? "FakeUnaryResponse" : "FakeStreamingResponse"
  398. let factory = isUnary ? "makeFakeUnaryResponse" : "makeFakeStreamingResponse"
  399. self
  400. .println(
  401. "/// Make a \(isUnary ? "unary" : "streaming") response for the \(self.method.name) RPC. This must be called"
  402. )
  403. self.println("/// before calling '\(self.methodFunctionName)'. See also '\(type)'.")
  404. self.println("///")
  405. self.println("/// - Parameter requestHandler: a handler for request parts sent by the RPC.")
  406. self.printFunction(
  407. name: "make\(self.method.name)ResponseStream",
  408. arguments: [
  409. "_ requestHandler: @escaping (FakeRequestPart<\(self.methodInputName)>) -> () = { _ in }",
  410. ],
  411. returnType: "\(type)<\(self.methodInputName), \(self.methodOutputName)>",
  412. access: self.access
  413. ) {
  414. self
  415. .println(
  416. "return self.fakeChannel.\(factory)(path: \(self.methodPath), requestHandler: requestHandler)"
  417. )
  418. }
  419. }
  420. private func printHasResponseStreamEnqueued() {
  421. self
  422. .println("/// Returns true if there are response streams enqueued for '\(self.method.name)'")
  423. self.println("\(self.access) var has\(self.method.name)ResponsesRemaining: Bool {")
  424. self.withIndentation {
  425. self.println("return self.fakeChannel.hasFakeResponseEnqueued(forPath: \(self.methodPath))")
  426. }
  427. self.println("}")
  428. }
  429. fileprivate func printTestClient() {
  430. self
  431. .println(
  432. "\(self.access) final class \(self.testClientClassName): \(self.clientProtocolName) {"
  433. )
  434. self.withIndentation {
  435. self.println("private let fakeChannel: FakeChannel")
  436. self.println("\(self.access) var defaultCallOptions: CallOptions")
  437. self.println("\(self.access) var interceptors: \(self.clientInterceptorProtocolName)?")
  438. self.println()
  439. self.println("\(self.access) var channel: GRPCChannel {")
  440. self.withIndentation {
  441. self.println("return self.fakeChannel")
  442. }
  443. self.println("}")
  444. self.println()
  445. self.println("\(self.access) init(")
  446. self.withIndentation {
  447. self.println("fakeChannel: FakeChannel = FakeChannel(),")
  448. self.println("defaultCallOptions callOptions: CallOptions = CallOptions(),")
  449. self.println("interceptors: \(clientInterceptorProtocolName)? = nil")
  450. }
  451. self.println(") {")
  452. self.withIndentation {
  453. self.println("self.fakeChannel = fakeChannel")
  454. self.println("self.defaultCallOptions = callOptions")
  455. self.println("self.interceptors = interceptors")
  456. }
  457. self.println("}")
  458. self.printFakeResponseStreams()
  459. }
  460. self.println("}") // end class
  461. }
  462. }
  463. private extension Generator {
  464. var streamType: StreamingType {
  465. return streamingType(self.method)
  466. }
  467. }
  468. extension Generator {
  469. fileprivate var methodArguments: [String] {
  470. switch self.streamType {
  471. case .unary:
  472. return [
  473. "_ request: \(self.methodInputName)",
  474. "callOptions: CallOptions? = nil",
  475. ]
  476. case .serverStreaming:
  477. return [
  478. "_ request: \(self.methodInputName)",
  479. "callOptions: CallOptions? = nil",
  480. "handler: @escaping (\(methodOutputName)) -> Void",
  481. ]
  482. case .clientStreaming:
  483. return ["callOptions: CallOptions? = nil"]
  484. case .bidirectionalStreaming:
  485. return [
  486. "callOptions: CallOptions? = nil",
  487. "handler: @escaping (\(methodOutputName)) -> Void",
  488. ]
  489. }
  490. }
  491. fileprivate var methodArgumentsWithoutDefaults: [String] {
  492. return self.methodArguments.map { arg in
  493. // Remove default arg from call options.
  494. if arg == "callOptions: CallOptions? = nil" {
  495. return "callOptions: CallOptions?"
  496. } else {
  497. return arg
  498. }
  499. }
  500. }
  501. fileprivate var methodArgumentsWithoutCallOptions: [String] {
  502. return self.methodArguments.filter {
  503. !$0.hasPrefix("callOptions: ")
  504. }
  505. }
  506. fileprivate var methodReturnType: String {
  507. switch self.streamType {
  508. case .unary:
  509. return "UnaryCall<\(self.methodInputName), \(self.methodOutputName)>"
  510. case .serverStreaming:
  511. return "ServerStreamingCall<\(self.methodInputName), \(self.methodOutputName)>"
  512. case .clientStreaming:
  513. return "ClientStreamingCall<\(self.methodInputName), \(self.methodOutputName)>"
  514. case .bidirectionalStreaming:
  515. return "BidirectionalStreamingCall<\(self.methodInputName), \(self.methodOutputName)>"
  516. }
  517. }
  518. }
  519. private extension StreamingType {
  520. var name: String {
  521. switch self {
  522. case .unary:
  523. return "Unary"
  524. case .clientStreaming:
  525. return "Client streaming"
  526. case .serverStreaming:
  527. return "Server streaming"
  528. case .bidirectionalStreaming:
  529. return "Bidirectional streaming"
  530. }
  531. }
  532. }
  533. extension MethodDescriptor {
  534. var documentation: String? {
  535. let comments = self.protoSourceComments(commentPrefix: "")
  536. return comments.isEmpty ? nil : comments
  537. }
  538. fileprivate func documentation(streamingType: StreamingType) -> String {
  539. let sourceComments = self.protoSourceComments()
  540. if sourceComments.isEmpty {
  541. return "/// \(streamingType.name) call to \(self.name)\n" // comments end with "\n" already.
  542. } else {
  543. return sourceComments // already prefixed with "///"
  544. }
  545. }
  546. }
  547. extension Array {
  548. /// Like `forEach` except that the `body` closure operates on all elements except for the last,
  549. /// and the `last` closure only operates on the last element.
  550. fileprivate func forEach(beforeLast body: (Element) -> Void, onLast last: (Element) -> Void) {
  551. for element in self.dropLast() {
  552. body(element)
  553. }
  554. if let lastElement = self.last {
  555. last(lastElement)
  556. }
  557. }
  558. }