Sfoglia il codice sorgente

Merge pull request #1428 from glbrntt/gb-merge-async

Merge the async branch into main
George Barnett 3 anni fa
parent
commit
3792d550bd

+ 243 - 0
Sources/protoc-gen-grpc-swift/Generator-Client+AsyncAwait.swift

@@ -0,0 +1,243 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import SwiftProtobuf
+import SwiftProtobufPluginLibrary
+
+// MARK: - Client protocol
+
+extension Generator {
+  internal func printAsyncServiceClientProtocol() {
+    let comments = self.service.protoSourceComments()
+    if !comments.isEmpty {
+      // Source comments already have the leading '///'
+      self.println(comments, newline: false)
+    }
+
+    self.printAvailabilityForAsyncAwait()
+    self.println("\(self.access) protocol \(self.asyncClientProtocolName): GRPCClient {")
+    self.withIndentation {
+      self.println("static var serviceDescriptor: GRPCServiceDescriptor { get }")
+      self.println("var interceptors: \(self.clientInterceptorProtocolName)? { get }")
+
+      for method in service.methods {
+        self.println()
+        self.method = method
+
+        let rpcType = streamingType(self.method)
+        let callType = Types.call(for: rpcType)
+
+        let arguments: [String]
+        switch rpcType {
+        case .unary, .serverStreaming:
+          arguments = [
+            "_ request: \(self.methodInputName)",
+            "callOptions: \(Types.clientCallOptions)?",
+          ]
+
+        case .clientStreaming, .bidirectionalStreaming:
+          arguments = [
+            "callOptions: \(Types.clientCallOptions)?",
+          ]
+        }
+
+        self.printFunction(
+          name: self.methodMakeFunctionCallName,
+          arguments: arguments,
+          returnType: "\(callType)<\(self.methodInputName), \(self.methodOutputName)>",
+          bodyBuilder: nil
+        )
+      }
+    }
+    self.println("}") // protocol
+  }
+}
+
+// MARK: - Client protocol default implementation: Calls
+
+extension Generator {
+  internal func printAsyncClientProtocolExtension() {
+    self.printAvailabilityForAsyncAwait()
+    self.withIndentation("extension \(self.asyncClientProtocolName)", braces: .curly) {
+      // Service descriptor.
+      self.withIndentation(
+        "\(self.access) static var serviceDescriptor: GRPCServiceDescriptor",
+        braces: .curly
+      ) {
+        self.println("return \(self.serviceClientMetadata).serviceDescriptor")
+      }
+
+      self.println()
+
+      // Interceptor factory.
+      self.withIndentation(
+        "\(self.access) var interceptors: \(self.clientInterceptorProtocolName)?",
+        braces: .curly
+      ) {
+        self.println("return nil")
+      }
+
+      // 'Unsafe' calls.
+      for method in self.service.methods {
+        self.println()
+        self.method = method
+
+        let rpcType = streamingType(self.method)
+        let callType = Types.call(for: rpcType)
+        let callTypeWithoutPrefix = Types.call(for: rpcType, withGRPCPrefix: false)
+
+        switch rpcType {
+        case .unary, .serverStreaming:
+          self.printFunction(
+            name: self.methodMakeFunctionCallName,
+            arguments: [
+              "_ request: \(self.methodInputName)",
+              "callOptions: \(Types.clientCallOptions)? = nil",
+            ],
+            returnType: "\(callType)<\(self.methodInputName), \(self.methodOutputName)>",
+            access: self.access
+          ) {
+            self.withIndentation("return self.make\(callTypeWithoutPrefix)", braces: .round) {
+              self.println("path: \(self.methodPathUsingClientMetadata),")
+              self.println("request: request,")
+              self.println("callOptions: callOptions ?? self.defaultCallOptions,")
+              self.println(
+                "interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? []"
+              )
+            }
+          }
+
+        case .clientStreaming, .bidirectionalStreaming:
+          self.printFunction(
+            name: self.methodMakeFunctionCallName,
+            arguments: ["callOptions: \(Types.clientCallOptions)? = nil"],
+            returnType: "\(callType)<\(self.methodInputName), \(self.methodOutputName)>",
+            access: self.access
+          ) {
+            self.withIndentation("return self.make\(callTypeWithoutPrefix)", braces: .round) {
+              self.println("path: \(self.methodPathUsingClientMetadata),")
+              self.println("callOptions: callOptions ?? self.defaultCallOptions,")
+              self.println(
+                "interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? []"
+              )
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+// MARK: - Client protocol extension: "Simple, but safe" call wrappers.
+
+extension Generator {
+  internal func printAsyncClientProtocolSafeWrappersExtension() {
+    self.printAvailabilityForAsyncAwait()
+    self.withIndentation("extension \(self.asyncClientProtocolName)", braces: .curly) {
+      for (i, method) in self.service.methods.enumerated() {
+        self.method = method
+
+        let rpcType = streamingType(self.method)
+        let callTypeWithoutPrefix = Types.call(for: rpcType, withGRPCPrefix: false)
+
+        let streamsResponses = [.serverStreaming, .bidirectionalStreaming].contains(rpcType)
+        let streamsRequests = [.clientStreaming, .bidirectionalStreaming].contains(rpcType)
+
+        // (protocol, requires sendable)
+        let sequenceProtocols: [(String, Bool)?] = streamsRequests
+          ? [("Sequence", false), ("AsyncSequence", true)]
+          : [nil]
+
+        for (j, sequenceProtocol) in sequenceProtocols.enumerated() {
+          // Print a new line if this is not the first function in the extension.
+          if i > 0 || j > 0 {
+            self.println()
+          }
+          let functionName = streamsRequests
+            ? "\(self.methodFunctionName)<RequestStream>"
+            : self.methodFunctionName
+          let requestParamName = streamsRequests ? "requests" : "request"
+          let requestParamType = streamsRequests ? "RequestStream" : self.methodInputName
+          let returnType = streamsResponses
+            ? Types.responseStream(of: self.methodOutputName)
+            : self.methodOutputName
+          let maybeWhereClause = sequenceProtocol.map { protocolName, mustBeSendable -> String in
+            let constraints = [
+              "RequestStream: \(protocolName)" + (mustBeSendable ? " & Sendable" : ""),
+              "RequestStream.Element == \(self.methodInputName)",
+            ]
+
+            return "where " + constraints.joined(separator: ", ")
+          }
+          self.printFunction(
+            name: functionName,
+            arguments: [
+              "_ \(requestParamName): \(requestParamType)",
+              "callOptions: \(Types.clientCallOptions)? = nil",
+            ],
+            returnType: returnType,
+            access: self.access,
+            async: !streamsResponses,
+            throws: !streamsResponses,
+            genericWhereClause: maybeWhereClause
+          ) {
+            self.withIndentation(
+              "return\(!streamsResponses ? " try await" : "") self.perform\(callTypeWithoutPrefix)",
+              braces: .round
+            ) {
+              self.println("path: \(self.methodPathUsingClientMetadata),")
+              self.println("\(requestParamName): \(requestParamName),")
+              self.println("callOptions: callOptions ?? self.defaultCallOptions,")
+              self.println(
+                "interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? []"
+              )
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+// MARK: - Client protocol implementation
+
+extension Generator {
+  internal func printAsyncServiceClientImplementation() {
+    self.printAvailabilityForAsyncAwait()
+    self.withIndentation(
+      "\(self.access) struct \(self.asyncClientStructName): \(self.asyncClientProtocolName)",
+      braces: .curly
+    ) {
+      self.println("\(self.access) var channel: GRPCChannel")
+      self.println("\(self.access) var defaultCallOptions: CallOptions")
+      self.println("\(self.access) var interceptors: \(self.clientInterceptorProtocolName)?")
+      self.println()
+
+      self.println("\(self.access) init(")
+      self.withIndentation {
+        self.println("channel: GRPCChannel,")
+        self.println("defaultCallOptions: CallOptions = CallOptions(),")
+        self.println("interceptors: \(self.clientInterceptorProtocolName)? = nil")
+      }
+      self.println(") {")
+      self.withIndentation {
+        self.println("self.channel = channel")
+        self.println("self.defaultCallOptions = defaultCallOptions")
+        self.println("self.interceptors = interceptors")
+      }
+      self.println("}")
+    }
+  }
+}

+ 119 - 21
Sources/protoc-gen-grpc-swift/Generator-Client.swift

@@ -25,9 +25,25 @@ extension Generator {
       self.println()
       self.printClientProtocolExtension()
       self.println()
+      self.printClassBackedServiceClientImplementation()
+      self.println()
+      self.printStructBackedServiceClientImplementation()
+      self.println()
+      self.printIfCompilerGuardForAsyncAwait()
+      self.printAsyncServiceClientProtocol()
+      self.println()
+      self.printAsyncClientProtocolExtension()
+      self.println()
+      self.printAsyncClientProtocolSafeWrappersExtension()
+      self.println()
+      self.printAsyncServiceClientImplementation()
+      self.println()
+      self.printEndCompilerGuardForAsyncAwait()
+      self.println()
+      // Both implementations share definitions for interceptors and metadata.
       self.printServiceClientInterceptorFactoryProtocol()
       self.println()
-      self.printServiceClientImplementation()
+      self.printClientMetadata()
     }
 
     if self.options.generateTestClient {
@@ -41,19 +57,39 @@ extension Generator {
     arguments: [String],
     returnType: String?,
     access: String? = nil,
+    sendable: Bool = false,
+    async: Bool = false,
+    throws: Bool = false,
+    genericWhereClause: String? = nil,
     bodyBuilder: (() -> Void)?
   ) {
     // Add a space after access, if it exists.
-    let accessOrEmpty = access.map { $0 + " " } ?? ""
-    let `return` = returnType.map { "-> " + $0 } ?? ""
+    let functionHead = (access.map { $0 + " " } ?? "") + (sendable ? "@Sendable " : "")
+    let `return` = returnType.map { " -> " + $0 } ?? ""
+    let genericWhere = genericWhereClause.map { " " + $0 } ?? ""
+
+    let asyncThrows: String
+    switch (async, `throws`) {
+    case (true, true):
+      asyncThrows = " async throws"
+    case (true, false):
+      asyncThrows = " async"
+    case (false, true):
+      asyncThrows = " throws"
+    case (false, false):
+      asyncThrows = ""
+    }
 
     let hasBody = bodyBuilder != nil
 
     if arguments.isEmpty {
       // Don't bother splitting across multiple lines if there are no arguments.
-      self.println("\(accessOrEmpty)func \(name)() \(`return`)", newline: !hasBody)
+      self.println(
+        "\(functionHead)func \(name)()\(asyncThrows)\(`return`)\(genericWhere)",
+        newline: !hasBody
+      )
     } else {
-      self.println("\(accessOrEmpty)func \(name)(")
+      self.println("\(functionHead)func \(name)(")
       self.withIndentation {
         // Add a comma after each argument except the last.
         arguments.forEach(beforeLast: {
@@ -62,7 +98,7 @@ extension Generator {
           self.println($0)
         })
       }
-      self.println(") \(`return`)", newline: !hasBody)
+      self.println(")\(asyncThrows)\(`return`)\(genericWhere)", newline: !hasBody)
     }
 
     if let bodyBuilder = bodyBuilder {
@@ -123,7 +159,7 @@ extension Generator {
   }
 
   private func printServiceClientInterceptorFactoryProtocol() {
-    self.println("\(self.access) protocol \(self.clientInterceptorProtocolName) {")
+    self.println("\(self.access) protocol \(self.clientInterceptorProtocolName): GRPCSendable {")
     self.withIndentation {
       // Method specific interceptors.
       for method in service.methods {
@@ -152,10 +188,62 @@ extension Generator {
     )
   }
 
-  private func printServiceClientImplementation() {
+  private func printClassBackedServiceClientImplementation() {
+    self.printIfCompilerGuardForAsyncAwait()
+    self.println("@available(*, deprecated)")
+    self.println("extension \(clientClassName): @unchecked Sendable {}")
+    self.printEndCompilerGuardForAsyncAwait()
+    self.println()
+    self.println("@available(*, deprecated, renamed: \"\(clientStructName)\")")
     println("\(access) final class \(clientClassName): \(clientProtocolName) {")
     self.withIndentation {
+      println("private let lock = Lock()")
+      println("private var _defaultCallOptions: CallOptions")
+      println("private var _interceptors: \(clientInterceptorProtocolName)?")
+
       println("\(access) let channel: GRPCChannel")
+      println("\(access) var defaultCallOptions: CallOptions {")
+      self.withIndentation {
+        println("get { self.lock.withLock { return self._defaultCallOptions } }")
+        println("set { self.lock.withLockVoid { self._defaultCallOptions = newValue } }")
+      }
+      self.println("}")
+      println("\(access) var interceptors: \(clientInterceptorProtocolName)? {")
+      self.withIndentation {
+        println("get { self.lock.withLock { return self._interceptors } }")
+        println("set { self.lock.withLockVoid { self._interceptors = newValue } }")
+      }
+      println("}")
+      println()
+      println("/// Creates a client for the \(servicePath) service.")
+      println("///")
+      self.printParameters()
+      println("///   - channel: `GRPCChannel` to the service host.")
+      println(
+        "///   - defaultCallOptions: Options to use for each service call if the user doesn't provide them."
+      )
+      println("///   - interceptors: A factory providing interceptors for each RPC.")
+      println("\(access) init(")
+      self.withIndentation {
+        println("channel: GRPCChannel,")
+        println("defaultCallOptions: CallOptions = CallOptions(),")
+        println("interceptors: \(clientInterceptorProtocolName)? = nil")
+      }
+      self.println(") {")
+      self.withIndentation {
+        println("self.channel = channel")
+        println("self._defaultCallOptions = defaultCallOptions")
+        println("self._interceptors = interceptors")
+      }
+      self.println("}")
+    }
+    println("}")
+  }
+
+  private func printStructBackedServiceClientImplementation() {
+    println("\(access) struct \(clientStructName): \(clientProtocolName) {")
+    self.withIndentation {
+      println("\(access) var channel: GRPCChannel")
       println("\(access) var defaultCallOptions: CallOptions")
       println("\(access) var interceptors: \(clientInterceptorProtocolName)?")
       println()
@@ -220,7 +308,7 @@ extension Generator {
     ) {
       self.println("return self.makeUnaryCall(")
       self.withIndentation {
-        self.println("path: \(self.methodPath),")
+        self.println("path: \(self.methodPathUsingClientMetadata),")
         self.println("request: request,")
         self.println("callOptions: callOptions ?? self.defaultCallOptions,")
         self.println(
@@ -247,7 +335,7 @@ extension Generator {
     ) {
       self.println("return self.makeServerStreamingCall(")
       self.withIndentation {
-        self.println("path: \(self.methodPath),")
+        self.println("path: \(self.methodPathUsingClientMetadata),")
         self.println("request: request,")
         self.println("callOptions: callOptions ?? self.defaultCallOptions,")
         self.println(
@@ -278,7 +366,7 @@ extension Generator {
     ) {
       self.println("return self.makeClientStreamingCall(")
       self.withIndentation {
-        self.println("path: \(self.methodPath),")
+        self.println("path: \(self.methodPathUsingClientMetadata),")
         self.println("callOptions: callOptions ?? self.defaultCallOptions,")
         self.println(
           "interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? []"
@@ -305,7 +393,7 @@ extension Generator {
     ) {
       self.println("return self.makeBidirectionalStreamingCall(")
       self.withIndentation {
-        self.println("path: \(self.methodPath),")
+        self.println("path: \(self.methodPathUsingClientMetadata),")
         self.println("callOptions: callOptions ?? self.defaultCallOptions,")
         self.println(
           "interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? [],"
@@ -428,7 +516,7 @@ extension Generator {
     ) {
       self
         .println(
-          "return self.fakeChannel.\(factory)(path: \(self.methodPath), requestHandler: requestHandler)"
+          "return self.fakeChannel.\(factory)(path: \(self.methodPathUsingClientMetadata), requestHandler: requestHandler)"
         )
     }
   }
@@ -438,21 +526,31 @@ extension Generator {
       .println("/// Returns true if there are response streams enqueued for '\(self.method.name)'")
     self.println("\(self.access) var has\(self.method.name)ResponsesRemaining: Bool {")
     self.withIndentation {
-      self.println("return self.fakeChannel.hasFakeResponseEnqueued(forPath: \(self.methodPath))")
+      self.println(
+        "return self.fakeChannel.hasFakeResponseEnqueued(forPath: \(self.methodPathUsingClientMetadata))"
+      )
     }
     self.println("}")
   }
 
   private func printTestClient() {
-    self
-      .println(
-        "\(self.access) final class \(self.testClientClassName): \(self.clientProtocolName) {"
-      )
+    self.printIfCompilerGuardForAsyncAwait()
+    self.println("@available(swift, deprecated: 5.6)")
+    self.println("extension \(self.testClientClassName): @unchecked Sendable {}")
+    self.printEndCompilerGuardForAsyncAwait()
+    self.println()
+    self.println(
+      "@available(swift, deprecated: 5.6, message: \"Test clients are not Sendable "
+        + "but the 'GRPCClient' API requires clients to be Sendable. Using a localhost client and "
+        + "server is the recommended alternative.\")"
+    )
+    self.println(
+      "\(self.access) final class \(self.testClientClassName): \(self.clientProtocolName) {"
+    )
     self.withIndentation {
       self.println("private let fakeChannel: FakeChannel")
-      self.println("\(self.access) var defaultCallOptions: CallOptions")
-      self.println("\(self.access) var interceptors: \(self.clientInterceptorProtocolName)?")
-
+      self.println("\(access) var defaultCallOptions: CallOptions")
+      self.println("\(access) var interceptors: \(clientInterceptorProtocolName)?")
       self.println()
       self.println("\(self.access) var channel: GRPCChannel {")
       self.withIndentation {

+ 82 - 0
Sources/protoc-gen-grpc-swift/Generator-Metadata.swift

@@ -0,0 +1,82 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import SwiftProtobuf
+import SwiftProtobufPluginLibrary
+
+extension Generator {
+  internal func printServerMetadata() {
+    self.printMetadata(server: true)
+  }
+
+  internal func printClientMetadata() {
+    self.printMetadata(server: false)
+  }
+
+  private func printMetadata(server: Bool) {
+    let enumName = server ? self.serviceServerMetadata : self.serviceClientMetadata
+
+    self.withIndentation("\(self.access) enum \(enumName)", braces: .curly) {
+      self.println("\(self.access) static let serviceDescriptor = GRPCServiceDescriptor(")
+      self.withIndentation {
+        self.println("name: \(quoted(self.service.name)),")
+        self.println("fullName: \(quoted(self.servicePath)),")
+        self.println("methods: [")
+        for method in self.service.methods {
+          self.method = method
+          self.withIndentation {
+            self.println("\(enumName).Methods.\(self.methodFunctionName),")
+          }
+        }
+        self.println("]")
+      }
+      self.println(")")
+      self.println()
+
+      self.withIndentation("\(self.access) enum Methods", braces: .curly) {
+        for (offset, method) in self.service.methods.enumerated() {
+          self.method = method
+          self.println(
+            "\(self.access) static let \(self.methodFunctionName) = GRPCMethodDescriptor("
+          )
+          self.withIndentation {
+            self.println("name: \(quoted(self.method.name)),")
+            self.println("path: \(quoted(self.methodPath)),")
+            self.println("type: \(streamingType(self.method).asGRPCCallTypeCase)")
+          }
+          self.println(")")
+
+          if (offset + 1) < self.service.methods.count {
+            self.println()
+          }
+        }
+      }
+    }
+  }
+}
+
+extension Generator {
+  internal var serviceServerMetadata: String {
+    return nameForPackageService(self.file, self.service) + "ServerMetadata"
+  }
+
+  internal var serviceClientMetadata: String {
+    return nameForPackageService(self.file, self.service) + "ClientMetadata"
+  }
+
+  internal var methodPathUsingClientMetadata: String {
+    return "\(self.serviceClientMetadata).Methods.\(self.methodFunctionName).path"
+  }
+}

+ 38 - 1
Sources/protoc-gen-grpc-swift/Generator-Names.swift

@@ -78,10 +78,22 @@ extension Generator {
     return nameForPackageService(file, service) + "Provider"
   }
 
+  internal var asyncProviderName: String {
+    return nameForPackageService(file, service) + "AsyncProvider"
+  }
+
   internal var clientClassName: String {
     return nameForPackageService(file, service) + "Client"
   }
 
+  internal var clientStructName: String {
+    return nameForPackageService(file, service) + "NIOClient"
+  }
+
+  internal var asyncClientStructName: String {
+    return nameForPackageService(file, service) + "AsyncClient"
+  }
+
   internal var testClientClassName: String {
     return nameForPackageService(self.file, self.service) + "TestClient"
   }
@@ -90,6 +102,10 @@ extension Generator {
     return nameForPackageService(file, service) + "ClientProtocol"
   }
 
+  internal var asyncClientProtocolName: String {
+    return nameForPackageService(file, service) + "AsyncClientProtocol"
+  }
+
   internal var clientInterceptorProtocolName: String {
     return nameForPackageService(file, service) + "ClientInterceptorFactoryProtocol"
   }
@@ -111,6 +127,19 @@ extension Generator {
     return self.sanitize(fieldName: name)
   }
 
+  internal var methodMakeFunctionCallName: String {
+    let name: String
+
+    if self.options.keepMethodCasing {
+      name = self.method.name
+    } else {
+      name = NamingUtils.toUpperCamelCase(self.method.name)
+    }
+
+    let fnName = "make\(name)Call"
+    return self.sanitize(fieldName: fnName)
+  }
+
   internal func sanitize(fieldName string: String) -> String {
     if quotableFieldNames.contains(string) {
       return "`\(string)`"
@@ -139,6 +168,14 @@ extension Generator {
   }
 
   internal var methodPath: String {
-    return "\"/" + self.servicePath + "/" + method.name + "\""
+    return "/" + self.fullMethodName
   }
+
+  internal var fullMethodName: String {
+    return self.servicePath + "/" + self.method.name
+  }
+}
+
+internal func quoted(_ str: String) -> String {
+  return "\"" + str + "\""
 }

+ 184 - 0
Sources/protoc-gen-grpc-swift/Generator-Server+AsyncAwait.swift

@@ -0,0 +1,184 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import SwiftProtobuf
+import SwiftProtobufPluginLibrary
+
+// MARK: - Protocol
+
+extension Generator {
+  internal func printServerProtocolAsyncAwait() {
+    let sourceComments = self.service.protoSourceComments()
+    if !sourceComments.isEmpty {
+      // Source comments already have the leading '///'
+      self.println(sourceComments, newline: false)
+      self.println("///")
+    }
+    self.println("/// To implement a server, implement an object which conforms to this protocol.")
+    self.printAvailabilityForAsyncAwait()
+    self.withIndentation(
+      "\(self.access) protocol \(self.asyncProviderName): CallHandlerProvider",
+      braces: .curly
+    ) {
+      self.println("static var serviceDescriptor: GRPCServiceDescriptor { get }")
+      self.println("var interceptors: \(self.serverInterceptorProtocolName)? { get }")
+
+      for method in service.methods {
+        self.method = method
+        self.println()
+        self.printRPCProtocolRequirement()
+      }
+    }
+  }
+
+  private func printRPCProtocolRequirement() {
+    // Print any comments; skip the newline as source comments include them already.
+    self.println(self.method.protoSourceComments(), newline: false)
+
+    let arguments: [String]
+    let returnType: String?
+
+    switch streamingType(self.method) {
+    case .unary:
+      arguments = [
+        "request: \(self.methodInputName)",
+        "context: \(Types.serverContext)",
+      ]
+      returnType = self.methodOutputName
+
+    case .clientStreaming:
+      arguments = [
+        "requestStream: \(Types.requestStream(of: self.methodInputName))",
+        "context: \(Types.serverContext)",
+      ]
+      returnType = self.methodOutputName
+
+    case .serverStreaming:
+      arguments = [
+        "request: \(self.methodInputName)",
+        "responseStream: \(Types.responseStreamWriter(of: self.methodOutputName))",
+        "context: \(Types.serverContext)",
+      ]
+      returnType = nil
+
+    case .bidirectionalStreaming:
+      arguments = [
+        "requestStream: \(Types.requestStream(of: self.methodInputName))",
+        "responseStream: \(Types.responseStreamWriter(of: self.methodOutputName))",
+        "context: \(Types.serverContext)",
+      ]
+      returnType = nil
+    }
+
+    self.printFunction(
+      name: self.methodFunctionName,
+      arguments: arguments,
+      returnType: returnType,
+      sendable: true,
+      async: true,
+      throws: true,
+      bodyBuilder: nil
+    )
+  }
+}
+
+// MARK: - Protocol Extension; RPC handling
+
+extension Generator {
+  internal func printServerProtocolExtensionAsyncAwait() {
+    // Default extension to provide the service name and routing for methods.
+    self.printAvailabilityForAsyncAwait()
+    self.withIndentation("extension \(self.asyncProviderName)", braces: .curly) {
+      self.withIndentation(
+        "\(self.access) static var serviceDescriptor: GRPCServiceDescriptor",
+        braces: .curly
+      ) {
+        self.println("return \(self.serviceServerMetadata).serviceDescriptor")
+      }
+
+      self.println()
+
+      // This fulfils a requirement from 'CallHandlerProvider'
+      self.withIndentation("\(self.access) var serviceName: Substring", braces: .curly) {
+        /// This API returns a Substring (hence the '[...]')
+        self.println("return \(self.serviceServerMetadata).serviceDescriptor.fullName[...]")
+      }
+
+      self.println()
+
+      // Default nil interceptor factory.
+      self.withIndentation(
+        "\(self.access) var interceptors: \(self.serverInterceptorProtocolName)?",
+        braces: .curly
+      ) {
+        self.println("return nil")
+      }
+
+      self.println()
+
+      self.printFunction(
+        name: "handle",
+        arguments: [
+          "method name: Substring",
+          "context: CallHandlerContext",
+        ],
+        returnType: "GRPCServerHandlerProtocol?",
+        access: self.access
+      ) {
+        self.println("switch name {")
+        for method in self.service.methods {
+          self.method = method
+
+          let requestType = self.methodInputName
+          let responseType = self.methodOutputName
+          let interceptorFactory = self.methodInterceptorFactoryName
+          let functionName = self.methodFunctionName
+
+          self.withIndentation("case \"\(self.method.name)\":", braces: .none) {
+            self.withIndentation("return \(Types.serverHandler)", braces: .round) {
+              self.println("context: context,")
+              self.println("requestDeserializer: \(Types.deserializer(for: requestType))(),")
+              self.println("responseSerializer: \(Types.serializer(for: responseType))(),")
+              self.println("interceptors: self.interceptors?.\(interceptorFactory)() ?? [],")
+              switch streamingType(self.method) {
+              case .unary:
+                self.println("wrapping: self.\(functionName)(request:context:)")
+
+              case .clientStreaming:
+                self.println("wrapping: self.\(functionName)(requestStream:context:)")
+
+              case .serverStreaming:
+                self.println("wrapping: self.\(functionName)(request:responseStream:context:)")
+
+              case .bidirectionalStreaming:
+                self.println(
+                  "wrapping: self.\(functionName)(requestStream:responseStream:context:)"
+                )
+              }
+            }
+          }
+        }
+
+        // Default case.
+        self.println("default:")
+        self.withIndentation {
+          self.println("return nil")
+        }
+
+        self.println("}") // switch
+      }
+    }
+  }
+}

+ 22 - 6
Sources/protoc-gen-grpc-swift/Generator-Server.swift

@@ -19,11 +19,24 @@ import SwiftProtobufPluginLibrary
 
 extension Generator {
   internal func printServer() {
-    self.printServerProtocol()
-    self.println()
-    self.printServerProtocolExtension()
-    self.println()
-    self.printServerInterceptorFactoryProtocol()
+    if self.options.generateServer {
+      self.printServerProtocol()
+      self.println()
+      self.printServerProtocolExtension()
+      self.println()
+      self.printIfCompilerGuardForAsyncAwait()
+      self.println()
+      self.printServerProtocolAsyncAwait()
+      self.println()
+      self.printServerProtocolExtensionAsyncAwait()
+      self.println()
+      self.printEndCompilerGuardForAsyncAwait()
+      self.println()
+      // Both implementations share definitions for interceptors and metadata.
+      self.printServerInterceptorFactoryProtocol()
+      self.println()
+      self.printServerMetadata()
+    }
   }
 
   private func printServerProtocol() {
@@ -71,7 +84,10 @@ extension Generator {
   private func printServerProtocolExtension() {
     self.println("extension \(self.providerName) {")
     self.withIndentation {
-      self.println("\(self.access) var serviceName: Substring { return \"\(self.servicePath)\" }")
+      self.withIndentation("\(self.access) var serviceName: Substring", braces: .curly) {
+        /// This API returns a Substring (hence the '[...]')
+        self.println("return \(self.serviceServerMetadata).serviceDescriptor.fullName[...]")
+      }
       self.println()
       self.println(
         "/// Determines, calls and returns the appropriate request handler, depending on the request's method."

+ 59 - 0
Sources/protoc-gen-grpc-swift/Generator.swift

@@ -61,6 +61,52 @@ class Generator {
     self.outdent()
   }
 
+  internal enum Braces {
+    case none
+    case curly
+    case round
+
+    var open: String {
+      switch self {
+      case .none:
+        return ""
+      case .curly:
+        return "{"
+      case .round:
+        return "("
+      }
+    }
+
+    var close: String {
+      switch self {
+      case .none:
+        return ""
+      case .curly:
+        return "}"
+      case .round:
+        return ")"
+      }
+    }
+  }
+
+  internal func withIndentation(
+    _ header: String,
+    braces: Braces,
+    _ body: () -> Void
+  ) {
+    let spaceBeforeOpeningBrace: Bool
+    switch braces {
+    case .curly:
+      spaceBeforeOpeningBrace = true
+    case .round, .none:
+      spaceBeforeOpeningBrace = false
+    }
+
+    self.println(header + "\(spaceBeforeOpeningBrace ? " " : "")" + "\(braces.open)")
+    self.withIndentation(body: body)
+    self.println(braces.close)
+  }
+
   private func printMain() {
     self.printer.print("""
     //
@@ -90,6 +136,7 @@ class Generator {
     let moduleNames = [
       self.options.gRPCModuleName,
       "NIO",
+      "NIOConcurrencyHelpers",
       self.options.swiftProtobufModuleName,
     ]
 
@@ -118,4 +165,16 @@ class Generator {
       }
     }
   }
+
+  func printAvailabilityForAsyncAwait() {
+    self.println("@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)")
+  }
+
+  func printIfCompilerGuardForAsyncAwait() {
+    self.println("#if compiler(>=5.6)")
+  }
+
+  func printEndCompilerGuardForAsyncAwait() {
+    self.println("#endif // compiler(>=5.6)")
+  }
 }

+ 15 - 0
Sources/protoc-gen-grpc-swift/StreamingType.swift

@@ -22,6 +22,21 @@ internal enum StreamingType {
   case bidirectionalStreaming
 }
 
+extension StreamingType {
+  internal var asGRPCCallTypeCase: String {
+    switch self {
+    case .unary:
+      return "GRPCCallType.unary"
+    case .clientStreaming:
+      return "GRPCCallType.clientStreaming"
+    case .serverStreaming:
+      return "GRPCCallType.serverStreaming"
+    case .bidirectionalStreaming:
+      return "GRPCCallType.bidirectionalStreaming"
+    }
+  }
+}
+
 internal func streamingType(_ method: MethodDescriptor) -> StreamingType {
   if method.proto.clientStreaming {
     if method.proto.serverStreaming {

+ 68 - 0
Sources/protoc-gen-grpc-swift/Types.swift

@@ -0,0 +1,68 @@
+/*
+ * Copyright 2021, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+enum Types {
+  static let serverContext = "GRPCAsyncServerCallContext"
+  static let serverHandler = "GRPCAsyncServerHandler"
+
+  static let clientCallOptions = "CallOptions"
+
+  private static let unaryCall = "AsyncUnaryCall"
+  private static let clientStreamingCall = "AsyncClientStreamingCall"
+  private static let serverStreamingCall = "AsyncServerStreamingCall"
+  private static let bidirectionalStreamingCall = "AsyncBidirectionalStreamingCall"
+
+  static func requestStream(of type: String) -> String {
+    return "GRPCAsyncRequestStream<\(type)>"
+  }
+
+  static func responseStream(of type: String) -> String {
+    return "GRPCAsyncResponseStream<\(type)>"
+  }
+
+  static func responseStreamWriter(of type: String) -> String {
+    return "GRPCAsyncResponseStreamWriter<\(type)>"
+  }
+
+  static func serializer(for type: String) -> String {
+    return "ProtobufSerializer<\(type)>"
+  }
+
+  static func deserializer(for type: String) -> String {
+    return "ProtobufDeserializer<\(type)>"
+  }
+
+  static func call(for streamingType: StreamingType, withGRPCPrefix: Bool = true) -> String {
+    let typeName: String
+
+    switch streamingType {
+    case .unary:
+      typeName = Types.unaryCall
+    case .clientStreaming:
+      typeName = Types.clientStreamingCall
+    case .serverStreaming:
+      typeName = Types.serverStreamingCall
+    case .bidirectionalStreaming:
+      typeName = Types.bidirectionalStreamingCall
+    }
+
+    if withGRPCPrefix {
+      return "GRPC" + typeName
+    } else {
+      return typeName
+    }
+  }
+}

+ 7 - 4
Sources/protoc-gen-grpc-swift/main.swift

@@ -80,12 +80,11 @@ func outputFileName(
   }
 }
 
-var generatedFiles: [String: Int] = [:]
-
 func uniqueOutputFileName(
   component: String,
   fileDescriptor: FileDescriptor,
-  fileNamingOption: FileNaming
+  fileNamingOption: FileNaming,
+  generatedFiles: inout [String: Int]
 ) -> String {
   let defaultName = outputFileName(
     component: component,
@@ -121,6 +120,9 @@ func main() throws {
   // Build the SwiftProtobufPluginLibrary model of the plugin input
   let descriptorSet = DescriptorSet(protos: request.protoFile)
 
+  // A count of generated files by desired name (actual name may differ to avoid collisions).
+  var generatedFiles: [String: Int] = [:]
+
   // Only generate output for services.
   for name in request.fileToGenerate {
     let fileDescriptor = descriptorSet.lookupFileDescriptor(protoName: name)
@@ -128,7 +130,8 @@ func main() throws {
       let grpcFileName = uniqueOutputFileName(
         component: "grpc",
         fileDescriptor: fileDescriptor,
-        fileNamingOption: options.fileNaming
+        fileNamingOption: options.fileNaming,
+        generatedFiles: &generatedFiles
       )
       let grpcGenerator = Generator(fileDescriptor, options: options)
       var grpcFile = Google_Protobuf_Compiler_CodeGeneratorResponse.File()

+ 3 - 0
Sources/protoc-gen-grpc-swift/options.swift

@@ -52,9 +52,12 @@ final class GeneratorOptions {
   }
 
   private(set) var visibility = Visibility.internal
+
   private(set) var generateServer = true
+
   private(set) var generateClient = true
   private(set) var generateTestClient = false
+
   private(set) var keepMethodCasing = false
   private(set) var protoToModuleMappings = ProtoFileToModuleMappings()
   private(set) var fileNaming = FileNaming.FullPath