Browse Source

Code generation for async-await (#1259)

Motivation:

Manually constructing clients and servers is an error prone nightmare.
We should generate them instead!

Modifications:

- Add async-await code-generation for server and client.
- The client code generation is missing "simple-safe" wrappers for now,
  this can be added later.
- Naming represents the current state of the branch rather than anything
  final
- Add options for "ExperimentalAsyncClient" and
  "ExperimentalAsyncServer" -- these may be used in conjunction with the
  'regular' "Client" and "Server" options.

Result:

We can generate async-await style grpc clients and servers.
George Barnett 4 years ago
parent
commit
7e977965da

+ 162 - 0
Sources/protoc-gen-grpc-swift/Generator-Client+AsyncAwait.swift

@@ -0,0 +1,162 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import SwiftProtobuf
+import SwiftProtobufPluginLibrary
+
+// MARK: - Client protocol
+
+extension Generator {
+  internal func printAsyncServiceClientProtocol() {
+    let comments = self.service.protoSourceComments()
+    if !comments.isEmpty {
+      // Source comments already have the leading '///'
+      self.println(comments, newline: false)
+    }
+
+    self.printAvailabilityForAsyncAwait()
+    self.println("\(self.access) protocol \(self.asyncClientProtocolName): GRPCClient {")
+    self.withIndentation {
+      self.println("var serviceName: String { get }")
+      self.println("var interceptors: \(self.clientInterceptorProtocolName)? { get }")
+
+      for method in service.methods {
+        self.println()
+        self.method = method
+
+        let rpcType = streamingType(self.method)
+        let callType = Types.call(for: rpcType)
+
+        let arguments: [String]
+        switch rpcType {
+        case .unary, .serverStreaming:
+          arguments = [
+            "_ request: \(self.methodInputName)",
+            "callOptions: \(Types.clientCallOptions)?",
+          ]
+
+        case .clientStreaming, .bidirectionalStreaming:
+          arguments = [
+            "callOptions: \(Types.clientCallOptions)?",
+          ]
+        }
+
+        self.printFunction(
+          name: "make\(self.method.name)Call",
+          arguments: arguments,
+          returnType: "\(callType)<\(self.methodInputName), \(self.methodOutputName)>",
+          bodyBuilder: nil
+        )
+      }
+    }
+    self.println("}") // protocol
+  }
+}
+
+// MARK: - Client protocol default implementation: Calls
+
+extension Generator {
+  internal func printAsyncClientProtocolExtension() {
+    self.printAvailabilityForAsyncAwait()
+    self.withIndentation("extension \(self.asyncClientProtocolName)", braces: .curly) {
+      // Service name. TODO: use static metadata.
+      self.withIndentation("\(self.access) var serviceName: String", braces: .curly) {
+        self.println("return \"\(self.servicePath)\"")
+      }
+      self.println()
+
+      // Interceptor factory.
+      self.withIndentation(
+        "\(self.access) var interceptors: \(self.clientInterceptorProtocolName)?",
+        braces: .curly
+      ) {
+        self.println("return nil")
+      }
+
+      // 'Unsafe' calls.
+      for method in self.service.methods {
+        self.println()
+        self.method = method
+
+        let rpcType = streamingType(self.method)
+        let callType = Types.call(for: rpcType)
+        let callTypeWithoutPrefix = Types.call(for: rpcType, withGRPCPrefix: false)
+
+        switch rpcType {
+        case .unary, .serverStreaming:
+          self.printFunction(
+            name: "make\(self.method.name)Call",
+            arguments: [
+              "_ request: \(self.methodInputName)",
+              "callOptions: \(Types.clientCallOptions)? = nil",
+            ],
+            returnType: "\(callType)<\(self.methodInputName), \(self.methodOutputName)>",
+            access: self.access
+          ) {
+            self.withIndentation("return self.make\(callTypeWithoutPrefix)", braces: .round) {
+              self.println("path: \(self.methodPath),")
+              self.println("request: request,")
+              self.println("callOptions: callOptions ?? self.defaultCallOptions")
+            }
+          }
+
+        case .clientStreaming, .bidirectionalStreaming:
+          self.printFunction(
+            name: "make\(self.method.name)Call",
+            arguments: ["callOptions: \(Types.clientCallOptions)? = nil"],
+            returnType: "\(callType)<\(self.methodInputName), \(self.methodOutputName)>",
+            access: self.access
+          ) {
+            self.withIndentation("return self.make\(callTypeWithoutPrefix)", braces: .round) {
+              self.println("path: \(self.methodPath),")
+              self.println("callOptions: callOptions ?? self.defaultCallOptions")
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+// MARK: - Client protocol implementation
+
+extension Generator {
+  internal func printAsyncServiceClientImplementation() {
+    self.printAvailabilityForAsyncAwait()
+    self.withIndentation(
+      "\(self.access) struct \(self.asyncClientClassName): \(self.asyncClientProtocolName)",
+      braces: .curly
+    ) {
+      self.println("\(self.access) var channel: GRPCChannel")
+      self.println("\(self.access) var defaultCallOptions: CallOptions")
+      self.println("\(self.access) var interceptors: \(self.clientInterceptorProtocolName)?")
+      self.println()
+
+      self.println("\(self.access) init(")
+      self.withIndentation {
+        self.println("channel: GRPCChannel,")
+        self.println("defaultCallOptions: CallOptions = CallOptions(),")
+        self.println("interceptors: \(self.clientInterceptorProtocolName)? = nil")
+      }
+      self.println(") {")
+      self.withIndentation {
+        self.println("self.channel = channel")
+        self.println("self.defaultCallOptions = defaultCallOptions")
+        self.println("self.interceptors = interceptors")
+      }
+      self.println("}")
+    }
+  }
+}

+ 37 - 5
Sources/protoc-gen-grpc-swift/Generator-Client.swift

@@ -30,6 +30,18 @@ extension Generator {
       self.printServiceClientImplementation()
       self.printServiceClientImplementation()
     }
     }
 
 
+    if self.options.generateAsyncClient {
+      self.println()
+      self.printIfCompilerGuardForAsyncAwait()
+      self.printAsyncServiceClientProtocol()
+      self.println()
+      self.printAsyncClientProtocolExtension()
+      self.println()
+      self.printAsyncServiceClientImplementation()
+      self.println()
+      self.printEndCompilerGuardForAsyncAwait()
+    }
+
     if self.options.generateTestClient {
     if self.options.generateTestClient {
       self.println()
       self.println()
       self.printTestClient()
       self.printTestClient()
@@ -41,19 +53,39 @@ extension Generator {
     arguments: [String],
     arguments: [String],
     returnType: String?,
     returnType: String?,
     access: String? = nil,
     access: String? = nil,
+    sendable: Bool = false,
+    async: Bool = false,
+    throws: Bool = false,
+    genericWhereClause: String? = nil,
     bodyBuilder: (() -> Void)?
     bodyBuilder: (() -> Void)?
   ) {
   ) {
     // Add a space after access, if it exists.
     // Add a space after access, if it exists.
-    let accessOrEmpty = access.map { $0 + " " } ?? ""
-    let `return` = returnType.map { "-> " + $0 } ?? ""
+    let functionHead = (access.map { $0 + " " } ?? "") + (sendable ? "@Sendable " : "")
+    let `return` = returnType.map { " -> " + $0 } ?? ""
+    let genericWhere = genericWhereClause.map { " " + $0 } ?? ""
+
+    let asyncThrows: String
+    switch (async, `throws`) {
+    case (true, true):
+      asyncThrows = " async throws"
+    case (true, false):
+      asyncThrows = " async"
+    case (false, true):
+      asyncThrows = " throws"
+    case (false, false):
+      asyncThrows = ""
+    }
 
 
     let hasBody = bodyBuilder != nil
     let hasBody = bodyBuilder != nil
 
 
     if arguments.isEmpty {
     if arguments.isEmpty {
       // Don't bother splitting across multiple lines if there are no arguments.
       // Don't bother splitting across multiple lines if there are no arguments.
-      self.println("\(accessOrEmpty)func \(name)() \(`return`)", newline: !hasBody)
+      self.println(
+        "\(functionHead)func \(name)()\(asyncThrows)\(`return`)\(genericWhere)",
+        newline: !hasBody
+      )
     } else {
     } else {
-      self.println("\(accessOrEmpty)func \(name)(")
+      self.println("\(functionHead)func \(name)(")
       self.withIndentation {
       self.withIndentation {
         // Add a comma after each argument except the last.
         // Add a comma after each argument except the last.
         arguments.forEach(beforeLast: {
         arguments.forEach(beforeLast: {
@@ -62,7 +94,7 @@ extension Generator {
           self.println($0)
           self.println($0)
         })
         })
       }
       }
-      self.println(") \(`return`)", newline: !hasBody)
+      self.println(")\(asyncThrows)\(`return`)\(genericWhere)", newline: !hasBody)
     }
     }
 
 
     if let bodyBuilder = bodyBuilder {
     if let bodyBuilder = bodyBuilder {

+ 12 - 0
Sources/protoc-gen-grpc-swift/Generator-Names.swift

@@ -74,10 +74,18 @@ extension Generator {
     return nameForPackageService(file, service) + "Provider"
     return nameForPackageService(file, service) + "Provider"
   }
   }
 
 
+  internal var asyncProviderName: String {
+    return nameForPackageService(file, service) + "AsyncProvider"
+  }
+
   internal var clientClassName: String {
   internal var clientClassName: String {
     return nameForPackageService(file, service) + "Client"
     return nameForPackageService(file, service) + "Client"
   }
   }
 
 
+  internal var asyncClientClassName: String {
+    return nameForPackageService(file, service) + "AsyncClient"
+  }
+
   internal var testClientClassName: String {
   internal var testClientClassName: String {
     return nameForPackageService(self.file, self.service) + "TestClient"
     return nameForPackageService(self.file, self.service) + "TestClient"
   }
   }
@@ -86,6 +94,10 @@ extension Generator {
     return nameForPackageService(file, service) + "ClientProtocol"
     return nameForPackageService(file, service) + "ClientProtocol"
   }
   }
 
 
+  internal var asyncClientProtocolName: String {
+    return nameForPackageService(file, service) + "AsyncClientProtocol"
+  }
+
   internal var clientInterceptorProtocolName: String {
   internal var clientInterceptorProtocolName: String {
     return nameForPackageService(file, service) + "ClientInterceptorFactoryProtocol"
     return nameForPackageService(file, service) + "ClientInterceptorFactoryProtocol"
   }
   }

+ 170 - 0
Sources/protoc-gen-grpc-swift/Generator-Server+AsyncAwait.swift

@@ -0,0 +1,170 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import SwiftProtobuf
+import SwiftProtobufPluginLibrary
+
+// MARK: - Protocol
+
+extension Generator {
+  internal func printServerProtocolAsyncAwait() {
+    let sourceComments = self.service.protoSourceComments()
+    if !sourceComments.isEmpty {
+      // Source comments already have the leading '///'
+      self.println(sourceComments, newline: false)
+      self.println("///")
+    }
+    self.println("/// To implement a server, implement an object which conforms to this protocol.")
+    self.printAvailabilityForAsyncAwait()
+    self.withIndentation(
+      "\(self.access) protocol \(self.asyncProviderName): CallHandlerProvider",
+      braces: .curly
+    ) {
+      self.println("var interceptors: \(self.serverInterceptorProtocolName)? { get }")
+
+      for method in service.methods {
+        self.method = method
+        self.println()
+        self.printRPCProtocolRequirement()
+      }
+    }
+  }
+
+  fileprivate func printRPCProtocolRequirement() {
+    // Print any comments; skip the newline as source comments include them already.
+    self.println(self.method.protoSourceComments(), newline: false)
+
+    let arguments: [String]
+    let returnType: String?
+
+    switch streamingType(self.method) {
+    case .unary:
+      arguments = [
+        "request: \(self.methodInputName)",
+        "context: \(Types.serverContext)",
+      ]
+      returnType = self.methodOutputName
+
+    case .clientStreaming:
+      arguments = [
+        "requests: \(Types.requestStream(of: self.methodInputName))",
+        "context: \(Types.serverContext)",
+      ]
+      returnType = self.methodOutputName
+
+    case .serverStreaming:
+      arguments = [
+        "request: \(self.methodInputName)",
+        "responseStream: \(Types.responseStreamWriter(of: self.methodOutputName))",
+        "context: \(Types.serverContext)",
+      ]
+      returnType = nil
+
+    case .bidirectionalStreaming:
+      arguments = [
+        "requests: \(Types.requestStream(of: self.methodInputName))",
+        "responseStream: \(Types.responseStreamWriter(of: self.methodOutputName))",
+        "context: \(Types.serverContext)",
+      ]
+      returnType = nil
+    }
+
+    self.printFunction(
+      name: self.methodFunctionName,
+      arguments: arguments,
+      returnType: returnType,
+      sendable: true,
+      async: true,
+      throws: true,
+      bodyBuilder: nil
+    )
+  }
+}
+
+// MARK: - Protocol Extension; RPC handling
+
+extension Generator {
+  internal func printServerProtocolExtensionAsyncAwait() {
+    // Default extension to provide the service name and routing for methods.
+    self.printAvailabilityForAsyncAwait()
+    self.withIndentation("extension \(self.asyncProviderName)", braces: .curly) {
+      self.withIndentation("\(self.access) var serviceName: Substring", braces: .curly) {
+        self.println("return \"\(self.servicePath)\"")
+      }
+
+      self.println()
+
+      // Default nil interceptor factory.
+      self.withIndentation(
+        "\(self.access) var interceptors: \(self.serverInterceptorProtocolName)?",
+        braces: .curly
+      ) {
+        self.println("return nil")
+      }
+
+      self.println()
+
+      self.printFunction(
+        name: "handle",
+        arguments: [
+          "method name: Substring",
+          "context: CallHandlerContext",
+        ],
+        returnType: "GRPCServerHandlerProtocol?",
+        access: self.access
+      ) {
+        self.println("switch name {")
+        for method in self.service.methods {
+          self.method = method
+
+          let requestType = self.methodInputName
+          let responseType = self.methodOutputName
+          let interceptorFactory = self.methodInterceptorFactoryName
+          let functionName = self.methodFunctionName
+
+          self.withIndentation("case \"\(self.method.name)\":", braces: .none) {
+            self.withIndentation("return \(Types.serverHandler)", braces: .round) {
+              self.println("context: context,")
+              self.println("requestDeserializer: \(Types.deserializer(for: requestType))(),")
+              self.println("responseSerializer: \(Types.serializer(for: responseType))(),")
+              self.println("interceptors: self.interceptors?.\(interceptorFactory)() ?? [],")
+              switch streamingType(self.method) {
+              case .unary:
+                self.println("wrapping: self.\(functionName)(request:context:)")
+
+              case .clientStreaming:
+                self.println("wrapping: self.\(functionName)(requests:context:)")
+
+              case .serverStreaming:
+                self.println("wrapping: self.\(functionName)(request:responseStream:context:)")
+
+              case .bidirectionalStreaming:
+                self.println("wrapping: self.\(functionName)(requests:responseStream:context:)")
+              }
+            }
+          }
+        }
+
+        // Default case.
+        self.println("default:")
+        self.withIndentation {
+          self.println("return nil")
+        }
+
+        self.println("}") // switch
+      }
+    }
+  }
+}

+ 25 - 5
Sources/protoc-gen-grpc-swift/Generator-Server.swift

@@ -19,11 +19,31 @@ import SwiftProtobufPluginLibrary
 
 
 extension Generator {
 extension Generator {
   internal func printServer() {
   internal func printServer() {
-    self.printServerProtocol()
-    self.println()
-    self.printServerProtocolExtension()
-    self.println()
-    self.printServerInterceptorFactoryProtocol()
+    if self.options.generateServer {
+      self.printServerProtocol()
+      self.println()
+      self.printServerProtocolExtension()
+      self.println()
+      self.printServerInterceptorFactoryProtocol()
+      self.println()
+    }
+
+    if self.options.generateAsyncServer {
+      self.printIfCompilerGuardForAsyncAwait()
+      self.println()
+      self.printServerProtocolAsyncAwait()
+      self.println()
+      self.printServerProtocolExtensionAsyncAwait()
+      self.println()
+      self.printEndCompilerGuardForAsyncAwait()
+      self.println()
+    }
+
+    // If we generate only the async server we need to print the interceptor factory protocol (as
+    // it is used by both).
+    if self.options.generateAsyncServer, !self.options.generateServer {
+      self.printServerInterceptorFactoryProtocol()
+    }
   }
   }
 
 
   private func printServerProtocol() {
   private func printServerProtocol() {

+ 59 - 1
Sources/protoc-gen-grpc-swift/Generator.swift

@@ -61,6 +61,52 @@ class Generator {
     self.outdent()
     self.outdent()
   }
   }
 
 
+  internal enum Braces {
+    case none
+    case curly
+    case round
+
+    var open: String {
+      switch self {
+      case .none:
+        return ""
+      case .curly:
+        return "{"
+      case .round:
+        return "("
+      }
+    }
+
+    var close: String {
+      switch self {
+      case .none:
+        return ""
+      case .curly:
+        return "}"
+      case .round:
+        return ")"
+      }
+    }
+  }
+
+  internal func withIndentation(
+    _ header: String,
+    braces: Braces,
+    _ body: () -> Void
+  ) {
+    let spaceBeforeOpeningBrace: Bool
+    switch braces {
+    case .curly:
+      spaceBeforeOpeningBrace = true
+    case .round, .none:
+      spaceBeforeOpeningBrace = false
+    }
+
+    self.println(header + "\(spaceBeforeOpeningBrace ? " " : "")" + "\(braces.open)")
+    self.withIndentation(body: body)
+    self.println(braces.close)
+  }
+
   private func printMain() {
   private func printMain() {
     self.printer.print("""
     self.printer.print("""
     //
     //
@@ -111,11 +157,23 @@ class Generator {
     }
     }
     self.println()
     self.println()
 
 
-    if self.options.generateServer {
+    if self.options.generateServer || self.options.generateAsyncServer {
       for service in self.file.services {
       for service in self.file.services {
         self.service = service
         self.service = service
         printServer()
         printServer()
       }
       }
     }
     }
   }
   }
+
+  func printAvailabilityForAsyncAwait() {
+    self.println("@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)")
+  }
+
+  func printIfCompilerGuardForAsyncAwait() {
+    self.println("#if compiler(>=5.5)")
+  }
+
+  func printEndCompilerGuardForAsyncAwait() {
+    self.println("#endif // compiler(>=5.5)")
+  }
 }
 }

+ 68 - 0
Sources/protoc-gen-grpc-swift/Types.swift

@@ -0,0 +1,68 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+enum Types {
+  static let serverContext = "GRPCAsyncServerCallContext"
+  static let serverHandler = "GRPCAsyncServerHandler"
+
+  static let clientCallOptions = "CallOptions"
+
+  private static let unaryCall = "AsyncUnaryCall"
+  private static let clientStreamingCall = "AsyncClientStreamingCall"
+  private static let serverStreamingCall = "AsyncServerStreamingCall"
+  private static let bidirectionalStreamingCall = "AsyncBidirectionalStreamingCall"
+
+  static func requestStream(of type: String) -> String {
+    return "GRPCAsyncRequestStream<\(type)>"
+  }
+
+  static func responseStream(of type: String) -> String {
+    return "GRPCAsyncResponseStream<\(type)>"
+  }
+
+  static func responseStreamWriter(of type: String) -> String {
+    return "GRPCAsyncResponseStreamWriter<\(type)>"
+  }
+
+  static func serializer(for type: String) -> String {
+    return "ProtobufSerializer<\(type)>"
+  }
+
+  static func deserializer(for type: String) -> String {
+    return "ProtobufDeserializer<\(type)>"
+  }
+
+  static func call(for streamingType: StreamingType, withGRPCPrefix: Bool = true) -> String {
+    let typeName: String
+
+    switch streamingType {
+    case .unary:
+      typeName = Types.unaryCall
+    case .clientStreaming:
+      typeName = Types.clientStreamingCall
+    case .serverStreaming:
+      typeName = Types.serverStreamingCall
+    case .bidirectionalStreaming:
+      typeName = Types.bidirectionalStreamingCall
+    }
+
+    if withGRPCPrefix {
+      return "GRPC" + typeName
+    } else {
+      return typeName
+    }
+  }
+}

+ 19 - 0
Sources/protoc-gen-grpc-swift/options.swift

@@ -52,9 +52,14 @@ final class GeneratorOptions {
   }
   }
 
 
   private(set) var visibility = Visibility.internal
   private(set) var visibility = Visibility.internal
+
   private(set) var generateServer = true
   private(set) var generateServer = true
+  private(set) var generateAsyncServer = false
+
   private(set) var generateClient = true
   private(set) var generateClient = true
+  private(set) var generateAsyncClient = false
   private(set) var generateTestClient = false
   private(set) var generateTestClient = false
+
   private(set) var keepMethodCasing = false
   private(set) var keepMethodCasing = false
   private(set) var protoToModuleMappings = ProtoFileToModuleMappings()
   private(set) var protoToModuleMappings = ProtoFileToModuleMappings()
   private(set) var fileNaming = FileNaming.FullPath
   private(set) var fileNaming = FileNaming.FullPath
@@ -79,6 +84,13 @@ final class GeneratorOptions {
           throw GenerationError.invalidParameterValue(name: pair.key, value: pair.value)
           throw GenerationError.invalidParameterValue(name: pair.key, value: pair.value)
         }
         }
 
 
+      case "ExperimentalAsyncServer":
+        if let value = Bool(pair.value) {
+          self.generateAsyncServer = value
+        } else {
+          throw GenerationError.invalidParameterValue(name: pair.key, value: pair.value)
+        }
+
       case "Client":
       case "Client":
         if let value = Bool(pair.value) {
         if let value = Bool(pair.value) {
           self.generateClient = value
           self.generateClient = value
@@ -86,6 +98,13 @@ final class GeneratorOptions {
           throw GenerationError.invalidParameterValue(name: pair.key, value: pair.value)
           throw GenerationError.invalidParameterValue(name: pair.key, value: pair.value)
         }
         }
 
 
+      case "ExperimentalAsyncClient":
+        if let value = Bool(pair.value) {
+          self.generateAsyncClient = value
+        } else {
+          throw GenerationError.invalidParameterValue(name: pair.key, value: pair.value)
+        }
+
       case "TestClient":
       case "TestClient":
         if let value = Bool(pair.value) {
         if let value = Bool(pair.value) {
           self.generateTestClient = value
           self.generateTestClient = value