Просмотр исходного кода

Add structured swift represntation for the generated server (#2122)

Motivation:

Follow-up from #2117, we also need the server code.

Modifications:

- Add structured swift for server code

Result:

Can build server code from structured swift
George Barnett 1 год назад
Родитель
Сommit
3cbc033f16

+ 425 - 0
Sources/GRPCCodeGen/Internal/StructuredSwift+Server.swift

@@ -0,0 +1,425 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+extension FunctionSignatureDescription {
+  /// ```
+  /// func <Method>(
+  ///   request: GRPCCore.ServerRequest<Input>,
+  ///   context: GRPCCore.ServerContext
+  /// ) async throws -> GRPCCore.ServerResponse<Output>
+  /// ```
+  static func serverMethod(
+    accessLevel: AccessModifier? = nil,
+    name: String,
+    input: String,
+    output: String,
+    streamingInput: Bool,
+    streamingOutput: Bool
+  ) -> Self {
+    return FunctionSignatureDescription(
+      accessModifier: accessLevel,
+      kind: .function(name: name),
+      parameters: [
+        ParameterDescription(
+          label: "request",
+          type: .serverRequest(forType: input, streaming: streamingInput)
+        ),
+        ParameterDescription(label: "context", type: .serverContext),
+      ],
+      keywords: [.async, .throws],
+      returnType: .identifierType(.serverResponse(forType: output, streaming: streamingOutput))
+    )
+  }
+}
+
+extension ProtocolDescription {
+  /// ```
+  /// protocol <Name>: GRPCCore.RegistrableRPCService {
+  ///   ...
+  /// }
+  /// ```
+  static func streamingService(
+    accessLevel: AccessModifier? = nil,
+    name: String,
+    methods: [MethodDescriptor]
+  ) -> Self {
+    return ProtocolDescription(
+      accessModifier: accessLevel,
+      name: name,
+      conformances: ["GRPCCore.RegistrableRPCService"],
+      members: methods.map { method in
+        .commentable(
+          .preFormatted(method.documentation),
+          .function(
+            signature: .serverMethod(
+              name: method.name.generatedLowerCase,
+              input: method.inputType,
+              output: method.outputType,
+              streamingInput: true,
+              streamingOutput: true
+            )
+          )
+        )
+      }
+    )
+  }
+}
+
+extension ExtensionDescription {
+  /// ```
+  /// extension <ExtensionName> {
+  ///   func registerMethods(with router: inout GRPCCore.RPCRouter) {
+  ///     // ...
+  ///   }
+  /// }
+  /// ```
+  static func registrableRPCServiceDefaultImplementation(
+    accessLevel: AccessModifier? = nil,
+    on extensionName: String,
+    serviceNamespace: String,
+    methods: [MethodDescriptor],
+    serializer: (String) -> String,
+    deserializer: (String) -> String
+  ) -> Self {
+    return ExtensionDescription(
+      onType: extensionName,
+      declarations: [
+        .function(
+          .registerMethods(
+            accessLevel: accessLevel,
+            serviceNamespace: serviceNamespace,
+            methods: methods,
+            serializer: serializer,
+            deserializer: deserializer
+          )
+        )
+      ]
+    )
+  }
+}
+
+extension ProtocolDescription {
+  /// ```
+  /// protocol <Name>: <StreamingProtocol> {
+  ///   ...
+  /// }
+  /// ```
+  static func service(
+    accessLevel: AccessModifier? = nil,
+    name: String,
+    streamingProtocol: String,
+    methods: [MethodDescriptor]
+  ) -> Self {
+    return ProtocolDescription(
+      accessModifier: accessLevel,
+      name: name,
+      conformances: [streamingProtocol],
+      members: methods.map { method in
+        .commentable(
+          .preFormatted(method.documentation),
+          .function(
+            signature: .serverMethod(
+              name: method.name.generatedLowerCase,
+              input: method.inputType,
+              output: method.outputType,
+              streamingInput: method.isInputStreaming,
+              streamingOutput: method.isOutputStreaming
+            )
+          )
+        )
+      }
+    )
+  }
+}
+
+extension FunctionCallDescription {
+  /// ```
+  /// self.<Name>(request: request, context: context)
+  /// ```
+  static func serverMethodCallOnSelf(
+    name: String,
+    requestArgument: Expression = .identifierPattern("request")
+  ) -> Self {
+    return FunctionCallDescription(
+      calledExpression: .memberAccess(
+        MemberAccessDescription(
+          left: .identifierPattern("self"),
+          right: name
+        )
+      ),
+      arguments: [
+        FunctionArgumentDescription(
+          label: "request",
+          expression: requestArgument
+        ),
+        FunctionArgumentDescription(
+          label: "context",
+          expression: .identifierPattern("context")
+        ),
+      ]
+    )
+  }
+}
+
+extension ClosureInvocationDescription {
+  /// ```
+  /// { router, context in
+  ///   try await self.<Method>(
+  ///     request: request,
+  ///     context: context
+  ///   )
+  /// }
+  /// ```
+  static func routerHandlerInvokingRPC(method: String) -> Self {
+    return ClosureInvocationDescription(
+      argumentNames: ["request", "context"],
+      body: [
+        .expression(
+          .unaryKeyword(
+            kind: .try,
+            expression: .unaryKeyword(
+              kind: .await,
+              expression: .functionCall(.serverMethodCallOnSelf(name: method))
+            )
+          )
+        )
+      ]
+    )
+  }
+}
+
+/// ```
+/// router.registerHandler(
+///   forMethod: ...,
+///   deserializer: ...
+///   serializer: ...
+///   handler: { request, context in
+///     // ...
+///   }
+/// )
+/// ```
+extension FunctionCallDescription {
+  static func registerWithRouter(
+    serviceNamespace: String,
+    methodNamespace: String,
+    methodName: String,
+    inputDeserializer: String,
+    outputSerializer: String
+  ) -> Self {
+    return FunctionCallDescription(
+      calledExpression: .memberAccess(
+        .init(left: .identifierPattern("router"), right: "registerHandler")
+      ),
+      arguments: [
+        FunctionArgumentDescription(
+          label: "forMethod",
+          expression: .identifierPattern("\(serviceNamespace).Method.\(methodNamespace).descriptor")
+        ),
+        FunctionArgumentDescription(
+          label: "deserializer",
+          expression: .identifierPattern(inputDeserializer)
+        ),
+        FunctionArgumentDescription(
+          label: "serializer",
+          expression: .identifierPattern(outputSerializer)
+        ),
+        FunctionArgumentDescription(
+          label: "handler",
+          expression: .closureInvocation(.routerHandlerInvokingRPC(method: methodName))
+        ),
+      ]
+    )
+  }
+}
+
+extension FunctionDescription {
+  /// ```
+  /// func registerMethods(with router: inout GRPCCore.RPCRouter) {
+  ///   // ...
+  /// }
+  /// ```
+  static func registerMethods(
+    accessLevel: AccessModifier? = nil,
+    serviceNamespace: String,
+    methods: [MethodDescriptor],
+    serializer: (String) -> String,
+    deserializer: (String) -> String
+  ) -> Self {
+    return FunctionDescription(
+      accessModifier: accessLevel,
+      kind: .function(name: "registerMethods"),
+      parameters: [
+        ParameterDescription(
+          label: "with",
+          name: "router",
+          type: .rpcRouter,
+          `inout`: true
+        )
+      ],
+      body: methods.map { method in
+        .functionCall(
+          .registerWithRouter(
+            serviceNamespace: serviceNamespace,
+            methodNamespace: method.name.generatedUpperCase,
+            methodName: method.name.generatedLowerCase,
+            inputDeserializer: deserializer(method.inputType),
+            outputSerializer: serializer(method.outputType)
+          )
+        )
+      }
+    )
+  }
+}
+
+extension FunctionDescription {
+  /// ```
+  /// func <Name>(
+  ///   request: GRPCCore.StreamingServerRequest<Input>
+  ///   context: GRPCCore.ServerContext
+  /// ) async throws -> GRPCCore.StreamingServerResponse<Output> {
+  ///   let response = try await self.<Name>(
+  ///     request: GRPCCore.ServerRequest(stream: request),
+  ///     context: context
+  ///   )
+  ///   return GRPCCore.StreamingServerResponse(single: response)
+  /// }
+  /// ```
+  static func serverStreamingMethodsCallingMethod(
+    accessLevel: AccessModifier? = nil,
+    name: String,
+    input: String,
+    output: String,
+    streamingInput: Bool,
+    streamingOutput: Bool
+  ) -> FunctionDescription {
+    let signature: FunctionSignatureDescription = .serverMethod(
+      accessLevel: accessLevel,
+      name: name,
+      input: input,
+      output: output,
+      // This method converts from the fully streamed version to the specified version.
+      streamingInput: true,
+      streamingOutput: true
+    )
+
+    // Call the underlying function.
+    let functionCall: Expression = .functionCall(
+      calledExpression: .memberAccess(
+        MemberAccessDescription(
+          left: .identifierPattern("self"),
+          right: name
+        )
+      ),
+      arguments: [
+        FunctionArgumentDescription(
+          label: "request",
+          expression: streamingInput
+            ? .identifierPattern("request")
+            : .functionCall(
+              calledExpression: .identifierType(.serverRequest(forType: nil, streaming: false)),
+              arguments: [
+                FunctionArgumentDescription(
+                  label: "stream",
+                  expression: .identifierPattern("request")
+                )
+              ]
+            )
+        ),
+        FunctionArgumentDescription(
+          label: "context",
+          expression: .identifierPattern("context")
+        ),
+      ]
+    )
+
+    // Call the function and assign to 'response'.
+    let response: Declaration = .variable(
+      kind: .let,
+      left: "response",
+      right: .unaryKeyword(
+        kind: .try,
+        expression: .unaryKeyword(
+          kind: .await,
+          expression: functionCall
+        )
+      )
+    )
+
+    // Build the return statement.
+    let returnExpression: Expression = .unaryKeyword(
+      kind: .return,
+      expression: streamingOutput
+        ? .identifierPattern("response")
+        : .functionCall(
+          calledExpression: .identifierType(.serverResponse(forType: nil, streaming: true)),
+          arguments: [
+            FunctionArgumentDescription(
+              label: "single",
+              expression: .identifierPattern("response")
+            )
+          ]
+        )
+    )
+
+    return Self(
+      signature: signature,
+      body: [.declaration(response), .expression(returnExpression)]
+    )
+  }
+}
+
+extension ExtensionDescription {
+  /// ```
+  /// extension <ExtensionName> {
+  ///   func <Name>(
+  ///     request: GRPCCore.StreamingServerRequest<Input>
+  ///     context: GRPCCore.ServerContext
+  ///   ) async throws -> GRPCCore.StreamingServerResponse<Output> {
+  ///     let response = try await self.<Name>(
+  ///       request: GRPCCore.ServerRequest(stream: request),
+  ///       context: context
+  ///     )
+  ///     return GRPCCore.StreamingServerResponse(single: response)
+  ///   }
+  ///   ...
+  /// }
+  /// ```
+  static func streamingServiceProtocolDefaultImplementation(
+    accessModifier: AccessModifier? = nil,
+    on extensionName: String,
+    methods: [MethodDescriptor]
+  ) -> Self {
+    return ExtensionDescription(
+      onType: extensionName,
+      declarations: methods.compactMap { method -> Declaration? in
+        // Bidirectional streaming methods don't need a default implementation as their signatures
+        // match across the two protocols.
+        if method.isInputStreaming, method.isOutputStreaming { return nil }
+
+        return .function(
+          .serverStreamingMethodsCallingMethod(
+            accessLevel: accessModifier,
+            name: method.name.generatedLowerCase,
+            input: method.inputType,
+            output: method.outputType,
+            streamingInput: method.isInputStreaming,
+            streamingOutput: method.isOutputStreaming
+          )
+        )
+      }
+    )
+  }
+}

+ 336 - 0
Tests/GRPCCodeGenTests/Internal/StructuredSwift+ServerTests.swift

@@ -0,0 +1,336 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import Testing
+
+@testable import GRPCCodeGen
+
+extension StructuedSwiftTests {
+  @Suite("Server")
+  struct Server {
+    @Test(
+      "func <Method>(request:context:) async throws -> ...",
+      arguments: AccessModifier.allCases,
+      RPCKind.allCases
+    )
+    func serverMethodSignature(access: AccessModifier, kind: RPCKind) {
+      let decl: FunctionSignatureDescription = .serverMethod(
+        accessLevel: access,
+        name: "foo",
+        input: "Input",
+        output: "Output",
+        streamingInput: kind.streamsInput,
+        streamingOutput: kind.streamsOutput
+      )
+
+      let expected: String
+
+      switch kind {
+      case .unary:
+        expected = """
+          \(access) func foo(
+            request: GRPCCore.ServerRequest<Input>,
+            context: GRPCCore.ServerContext
+          ) async throws -> GRPCCore.ServerResponse<Output>
+          """
+      case .clientStreaming:
+        expected = """
+          \(access) func foo(
+            request: GRPCCore.StreamingServerRequest<Input>,
+            context: GRPCCore.ServerContext
+          ) async throws -> GRPCCore.ServerResponse<Output>
+          """
+      case .serverStreaming:
+        expected = """
+          \(access) func foo(
+            request: GRPCCore.ServerRequest<Input>,
+            context: GRPCCore.ServerContext
+          ) async throws -> GRPCCore.StreamingServerResponse<Output>
+          """
+      case .bidirectionalStreaming:
+        expected = """
+          \(access) func foo(
+            request: GRPCCore.StreamingServerRequest<Input>,
+            context: GRPCCore.ServerContext
+          ) async throws -> GRPCCore.StreamingServerResponse<Output>
+          """
+      }
+
+      #expect(render(.function(signature: decl)) == expected)
+    }
+
+    @Test("protocol StreamingServiceProtocol { ... }", arguments: AccessModifier.allCases)
+    func serverStreamingServiceProtocol(access: AccessModifier) {
+      let decl: ProtocolDescription = .streamingService(
+        accessLevel: access,
+        name: "FooService",
+        methods: [
+          .init(
+            documentation: "/// Some docs",
+            name: .init(base: "Foo", generatedUpperCase: "Foo", generatedLowerCase: "foo"),
+            isInputStreaming: false,
+            isOutputStreaming: false,
+            inputType: "FooInput",
+            outputType: "FooOutput"
+          )
+        ]
+      )
+
+      let expected = """
+        \(access) protocol FooService: GRPCCore.RegistrableRPCService {
+          /// Some docs
+          func foo(
+            request: GRPCCore.StreamingServerRequest<FooInput>,
+            context: GRPCCore.ServerContext
+          ) async throws -> GRPCCore.StreamingServerResponse<FooOutput>
+        }
+        """
+
+      #expect(render(.protocol(decl)) == expected)
+    }
+
+    @Test("protocol ServiceProtocol { ... }", arguments: AccessModifier.allCases)
+    func serverServiceProtocol(access: AccessModifier) {
+      let decl: ProtocolDescription = .service(
+        accessLevel: access,
+        name: "FooService",
+        streamingProtocol: "FooService_StreamingServiceProtocol",
+        methods: [
+          .init(
+            documentation: "/// Some docs",
+            name: .init(base: "Foo", generatedUpperCase: "Foo", generatedLowerCase: "foo"),
+            isInputStreaming: false,
+            isOutputStreaming: false,
+            inputType: "FooInput",
+            outputType: "FooOutput"
+          )
+        ]
+      )
+
+      let expected = """
+        \(access) protocol FooService: FooService_StreamingServiceProtocol {
+          /// Some docs
+          func foo(
+            request: GRPCCore.ServerRequest<FooInput>,
+            context: GRPCCore.ServerContext
+          ) async throws -> GRPCCore.ServerResponse<FooOutput>
+        }
+        """
+
+      #expect(render(.protocol(decl)) == expected)
+    }
+
+    @Test("{ router, context in try await self.<Method>(...) }")
+    func routerHandlerInvokingRPC() {
+      let expression: ClosureInvocationDescription = .routerHandlerInvokingRPC(method: "foo")
+      let expected = """
+        { request, context in
+          try await self.foo(
+            request: request,
+            context: context
+          )
+        }
+        """
+      #expect(render(.closureInvocation(expression)) == expected)
+    }
+
+    @Test("router.registerHandler(...) { ... }")
+    func registerMethodsWithRouter() {
+      let expression: FunctionCallDescription = .registerWithRouter(
+        serviceNamespace: "FooService",
+        methodNamespace: "Bar",
+        methodName: "bar",
+        inputDeserializer: "Deserialize<BarInput>()",
+        outputSerializer: "Serialize<BarOutput>()"
+      )
+
+      let expected = """
+        router.registerHandler(
+          forMethod: FooService.Method.Bar.descriptor,
+          deserializer: Deserialize<BarInput>(),
+          serializer: Serialize<BarOutput>(),
+          handler: { request, context in
+            try await self.bar(
+              request: request,
+              context: context
+            )
+          }
+        )
+        """
+
+      #expect(render(.functionCall(expression)) == expected)
+    }
+
+    @Test("func registerMethods(router:)", arguments: AccessModifier.allCases)
+    func registerMethods(access: AccessModifier) {
+      let expression: FunctionDescription = .registerMethods(
+        accessLevel: access,
+        serviceNamespace: "FooService",
+        methods: [
+          .init(
+            documentation: "",
+            name: .init(base: "Bar", generatedUpperCase: "Bar", generatedLowerCase: "bar"),
+            isInputStreaming: false,
+            isOutputStreaming: false,
+            inputType: "BarInput",
+            outputType: "BarOutput"
+          )
+        ]
+      ) { type in
+        "Serialize<\(type)>()"
+      } deserializer: { type in
+        "Deserialize<\(type)>()"
+      }
+
+      let expected = """
+        \(access) func registerMethods(with router: inout GRPCCore.RPCRouter) {
+          router.registerHandler(
+            forMethod: FooService.Method.Bar.descriptor,
+            deserializer: Deserialize<BarInput>(),
+            serializer: Serialize<BarOutput>(),
+            handler: { request, context in
+              try await self.bar(
+                request: request,
+                context: context
+              )
+            }
+          )
+        }
+        """
+
+      #expect(render(.function(expression)) == expected)
+    }
+
+    @Test(
+      "func <Method>(request:context:) async throw { ... (convert to/from single) ... }",
+      arguments: AccessModifier.allCases,
+      RPCKind.allCases
+    )
+    func serverStreamingMethodsCallingMethod(access: AccessModifier, kind: RPCKind) {
+      let expression: FunctionDescription = .serverStreamingMethodsCallingMethod(
+        accessLevel: access,
+        name: "foo",
+        input: "Input",
+        output: "Output",
+        streamingInput: kind.streamsInput,
+        streamingOutput: kind.streamsOutput
+      )
+
+      let expected: String
+
+      switch kind {
+      case .unary:
+        expected = """
+          \(access) func foo(
+            request: GRPCCore.StreamingServerRequest<Input>,
+            context: GRPCCore.ServerContext
+          ) async throws -> GRPCCore.StreamingServerResponse<Output> {
+            let response = try await self.foo(
+              request: GRPCCore.ServerRequest(stream: request),
+              context: context
+            )
+            return GRPCCore.StreamingServerResponse(single: response)
+          }
+          """
+      case .serverStreaming:
+        expected = """
+          \(access) func foo(
+            request: GRPCCore.StreamingServerRequest<Input>,
+            context: GRPCCore.ServerContext
+          ) async throws -> GRPCCore.StreamingServerResponse<Output> {
+            let response = try await self.foo(
+              request: GRPCCore.ServerRequest(stream: request),
+              context: context
+            )
+            return response
+          }
+          """
+      case .clientStreaming:
+        expected = """
+          \(access) func foo(
+            request: GRPCCore.StreamingServerRequest<Input>,
+            context: GRPCCore.ServerContext
+          ) async throws -> GRPCCore.StreamingServerResponse<Output> {
+            let response = try await self.foo(
+              request: request,
+              context: context
+            )
+            return GRPCCore.StreamingServerResponse(single: response)
+          }
+          """
+      case .bidirectionalStreaming:
+        expected = """
+          \(access) func foo(
+            request: GRPCCore.StreamingServerRequest<Input>,
+            context: GRPCCore.ServerContext
+          ) async throws -> GRPCCore.StreamingServerResponse<Output> {
+            let response = try await self.foo(
+              request: request,
+              context: context
+            )
+            return response
+          }
+          """
+      }
+
+      #expect(render(.function(expression)) == expected)
+    }
+
+    @Test("extension FooService_ServiceProtocol { ... }", arguments: AccessModifier.allCases)
+    func streamingServiceProtocolDefaultImplementation(access: AccessModifier) {
+      let decl: ExtensionDescription = .streamingServiceProtocolDefaultImplementation(
+        accessModifier: access,
+        on: "Foo_ServiceProtocol",
+        methods: [
+          .init(
+            documentation: "",
+            name: .init(base: "Foo", generatedUpperCase: "Foo", generatedLowerCase: "foo"),
+            isInputStreaming: false,
+            isOutputStreaming: false,
+            inputType: "FooInput",
+            outputType: "FooOutput"
+          ),
+          // Will be ignored as a bidirectional streaming method.
+          .init(
+            documentation: "",
+            name: .init(base: "Bar", generatedUpperCase: "Bar", generatedLowerCase: "bar"),
+            isInputStreaming: true,
+            isOutputStreaming: true,
+            inputType: "BarInput",
+            outputType: "BarOutput"
+          ),
+        ]
+      )
+
+      let expected = """
+        extension Foo_ServiceProtocol {
+          \(access) func foo(
+            request: GRPCCore.StreamingServerRequest<FooInput>,
+            context: GRPCCore.ServerContext
+          ) async throws -> GRPCCore.StreamingServerResponse<FooOutput> {
+            let response = try await self.foo(
+              request: GRPCCore.ServerRequest(stream: request),
+              context: context
+            )
+            return GRPCCore.StreamingServerResponse(single: response)
+          }
+        }
+        """
+
+      #expect(render(.extension(decl)) == expected)
+    }
+  }
+}