Browse Source

Fixed .proto file imports handling in ProtobufCodeGenParser (#1804)

Motivation:

The .proto file imports from a .proto file were added as file imports in the generated code.
The correct behavior is to add dependencies on the modules containing generated code and types
associated with the .proto files, if they are generated in a different module from the one
of the .proto file we are parsing.

Modifications:

- Created a method that adds only the necessary imports associated with a .proto file
into the CodeGenerationRequest object
- updated the test for the parser

Result:

Imports will be generated correctly for '.proto' files.
Stefana-Ioana Dranca 1 year ago
parent
commit
130605f286

+ 30 - 6
Sources/GRPCProtobufCodeGen/ProtobufCodeGenParser.swift

@@ -24,9 +24,17 @@ import struct GRPCCodeGen.CodeGenerationRequest
 internal struct ProtobufCodeGenParser {
 internal struct ProtobufCodeGenParser {
   let input: FileDescriptor
   let input: FileDescriptor
   let namer: SwiftProtobufNamer
   let namer: SwiftProtobufNamer
+  let extraModuleImports: [String]
+  let protoToModuleMappings: ProtoFileToModuleMappings
 
 
-  internal init(input: FileDescriptor, protoFileModuleMappings: ProtoFileToModuleMappings) {
+  internal init(
+    input: FileDescriptor,
+    protoFileModuleMappings: ProtoFileToModuleMappings,
+    extraModuleImports: [String]
+  ) {
     self.input = input
     self.input = input
+    self.extraModuleImports = extraModuleImports
+    self.protoToModuleMappings = protoFileModuleMappings
     self.namer = SwiftProtobufNamer(
     self.namer = SwiftProtobufNamer(
       currentFile: input,
       currentFile: input,
       protoFileToModuleMappings: protoFileModuleMappings
       protoFileToModuleMappings: protoFileModuleMappings
@@ -50,10 +58,6 @@ internal struct ProtobufCodeGenParser {
       //   https://github.com/grpc/grpc-swift
       //   https://github.com/grpc/grpc-swift
 
 
       """
       """
-    var dependencies = self.input.dependencies.map {
-      CodeGenerationRequest.Dependency(module: $0.name)
-    }
-    dependencies.append(CodeGenerationRequest.Dependency(module: "GRPCProtobuf"))
     let lookupSerializer: (String) -> String = { messageType in
     let lookupSerializer: (String) -> String = { messageType in
       "ProtobufSerializer<\(messageType)>()"
       "ProtobufSerializer<\(messageType)>()"
     }
     }
@@ -71,7 +75,7 @@ internal struct ProtobufCodeGenParser {
     return CodeGenerationRequest(
     return CodeGenerationRequest(
       fileName: self.input.name,
       fileName: self.input.name,
       leadingTrivia: header + leadingTrivia,
       leadingTrivia: header + leadingTrivia,
-      dependencies: dependencies,
+      dependencies: self.codeDependencies,
       services: services,
       services: services,
       lookupSerializer: lookupSerializer,
       lookupSerializer: lookupSerializer,
       lookupDeserializer: lookupDeserializer
       lookupDeserializer: lookupDeserializer
@@ -79,6 +83,26 @@ internal struct ProtobufCodeGenParser {
   }
   }
 }
 }
 
 
+extension ProtobufCodeGenParser {
+  fileprivate var codeDependencies: [CodeGenerationRequest.Dependency] {
+    var codeDependencies: [CodeGenerationRequest.Dependency] = [.init(module: "GRPCProtobuf")]
+    // Adding as dependencies the modules containing generated code or types for
+    // '.proto' files imported in the '.proto' file we are parsing.
+    codeDependencies.append(
+      contentsOf: (self.protoToModuleMappings.neededModules(forFile: self.input) ?? []).map {
+        CodeGenerationRequest.Dependency(module: $0)
+      }
+    )
+    // Adding extra imports passed in as an option to the plugin.
+    codeDependencies.append(
+      contentsOf: self.extraModuleImports.sorted().map {
+        CodeGenerationRequest.Dependency(module: $0)
+      }
+    )
+    return codeDependencies
+  }
+}
+
 extension CodeGenerationRequest.ServiceDescriptor {
 extension CodeGenerationRequest.ServiceDescriptor {
   fileprivate init(
   fileprivate init(
     descriptor: ServiceDescriptor,
     descriptor: ServiceDescriptor,

+ 4 - 2
Sources/GRPCProtobufCodeGen/ProtobufCodeGenerator.swift

@@ -28,11 +28,13 @@ public struct ProtobufCodeGenerator {
 
 
   public func generateCode(
   public func generateCode(
     from fileDescriptor: FileDescriptor,
     from fileDescriptor: FileDescriptor,
-    protoFileModuleMappings: ProtoFileToModuleMappings
+    protoFileModuleMappings: ProtoFileToModuleMappings,
+    extraModuleImports: [String]
   ) throws -> String {
   ) throws -> String {
     let parser = ProtobufCodeGenParser(
     let parser = ProtobufCodeGenParser(
       input: fileDescriptor,
       input: fileDescriptor,
-      protoFileModuleMappings: protoFileModuleMappings
+      protoFileModuleMappings: protoFileModuleMappings,
+      extraModuleImports: extraModuleImports
     )
     )
     let sourceGenerator = SourceGenerator(configuration: self.configuration)
     let sourceGenerator = SourceGenerator(configuration: self.configuration)
 
 

+ 2 - 1
Sources/protoc-gen-grpc-swift/main.swift

@@ -174,7 +174,8 @@ func main(args: [String]) throws {
           )
           )
           grpcFile.content = try grpcGenerator.generateCode(
           grpcFile.content = try grpcGenerator.generateCode(
             from: fileDescriptor,
             from: fileDescriptor,
-            protoFileModuleMappings: options.protoToModuleMappings
+            protoFileModuleMappings: options.protoToModuleMappings,
+            extraModuleImports: options.extraModuleImports
           )
           )
         } else {
         } else {
           let grpcGenerator = Generator(fileDescriptor, options: options)
           let grpcGenerator = Generator(fileDescriptor, options: options)

+ 37 - 7
Tests/GRPCProtobufCodeGenTests/ProtobufCodeGenParserTests.swift

@@ -23,7 +23,20 @@ import XCTest
 
 
 final class ProtobufCodeGenParserTests: XCTestCase {
 final class ProtobufCodeGenParserTests: XCTestCase {
   func testParser() throws {
   func testParser() throws {
-    let descriptorSet = DescriptorSet(protos: [Google_Protobuf_FileDescriptorProto.helloWorld])
+    let descriptorSet = DescriptorSet(
+      protos: [
+        Google_Protobuf_FileDescriptorProto(
+          name: "same-module.proto",
+          package: "same-package"
+        ),
+        Google_Protobuf_FileDescriptorProto(
+          name: "different-module.proto",
+          package: "different-package"
+        ),
+        Google_Protobuf_FileDescriptorProto.helloWorld,
+      ]
+    )
+
     guard let fileDescriptor = descriptorSet.fileDescriptor(named: "helloworld.proto") else {
     guard let fileDescriptor = descriptorSet.fileDescriptor(named: "helloworld.proto") else {
       return XCTFail(
       return XCTFail(
         """
         """
@@ -31,9 +44,18 @@ final class ProtobufCodeGenParserTests: XCTestCase {
         """
         """
       )
       )
     }
     }
+    let moduleMappings = SwiftProtobuf_GenSwift_ModuleMappings.with {
+      $0.mapping = [
+        SwiftProtobuf_GenSwift_ModuleMappings.Entry.with {
+          $0.protoFilePath = ["different-module.proto"]
+          $0.moduleName = "DifferentModule"
+        }
+      ]
+    }
     let parsedCodeGenRequest = try ProtobufCodeGenParser(
     let parsedCodeGenRequest = try ProtobufCodeGenParser(
       input: fileDescriptor,
       input: fileDescriptor,
-      protoFileModuleMappings: ProtoFileToModuleMappings()
+      protoFileModuleMappings: ProtoFileToModuleMappings(moduleMappingsProto: moduleMappings),
+      extraModuleImports: ["ExtraModule"]
     ).parse()
     ).parse()
     XCTAssertEqual(parsedCodeGenRequest.fileName, "helloworld.proto")
     XCTAssertEqual(parsedCodeGenRequest.fileName, "helloworld.proto")
     XCTAssertEqual(
     XCTAssertEqual(
@@ -65,6 +87,11 @@ final class ProtobufCodeGenParserTests: XCTestCase {
       """
       """
     )
     )
 
 
+    XCTAssertEqual(parsedCodeGenRequest.dependencies.count, 3)
+    let expectedDependencyNames = ["GRPCProtobuf", "DifferentModule", "ExtraModule"]
+    let parsedDependencyNames = parsedCodeGenRequest.dependencies.map { $0.module }
+    XCTAssertEqual(parsedDependencyNames, expectedDependencyNames)
+
     XCTAssertEqual(parsedCodeGenRequest.services.count, 1)
     XCTAssertEqual(parsedCodeGenRequest.services.count, 1)
 
 
     let expectedMethod = CodeGenerationRequest.ServiceDescriptor.MethodDescriptor(
     let expectedMethod = CodeGenerationRequest.ServiceDescriptor.MethodDescriptor(
@@ -108,11 +135,6 @@ final class ProtobufCodeGenParserTests: XCTestCase {
       parsedCodeGenRequest.lookupDeserializer("HelloRequest"),
       parsedCodeGenRequest.lookupDeserializer("HelloRequest"),
       "ProtobufDeserializer<HelloRequest>()"
       "ProtobufDeserializer<HelloRequest>()"
     )
     )
-    XCTAssertEqual(parsedCodeGenRequest.dependencies.count, 1)
-    XCTAssertEqual(
-      parsedCodeGenRequest.dependencies[0],
-      CodeGenerationRequest.Dependency(module: "GRPCProtobuf")
-    )
   }
   }
 }
 }
 
 
@@ -158,6 +180,8 @@ extension Google_Protobuf_FileDescriptorProto {
     return Google_Protobuf_FileDescriptorProto.with {
     return Google_Protobuf_FileDescriptorProto.with {
       $0.name = "helloworld.proto"
       $0.name = "helloworld.proto"
       $0.package = "helloworld"
       $0.package = "helloworld"
+      $0.dependency = ["same-module.proto", "different-module.proto"]
+      $0.publicDependency = [1, 2]
       $0.messageType = [requestType, responseType]
       $0.messageType = [requestType, responseType]
       $0.service = [service]
       $0.service = [service]
       $0.sourceCodeInfo = Google_Protobuf_SourceCodeInfo.with {
       $0.sourceCodeInfo = Google_Protobuf_SourceCodeInfo.with {
@@ -199,4 +223,10 @@ extension Google_Protobuf_FileDescriptorProto {
       $0.syntax = "proto3"
       $0.syntax = "proto3"
     }
     }
   }
   }
+
+  internal init(name: String, package: String) {
+    self.init()
+    self.name = name
+    self.package = package
+  }
 }
 }

+ 26 - 2
Tests/GRPCProtobufCodeGenTests/ProtobufCodeGeneratorTests.swift

@@ -55,6 +55,8 @@ final class ProtobufCodeGeneratorTests: XCTestCase {
 
 
         import GRPCCore
         import GRPCCore
         import GRPCProtobuf
         import GRPCProtobuf
+        import DifferentModule
+        import ExtraModule
 
 
         internal enum Helloworld {
         internal enum Helloworld {
             internal enum Greeter {
             internal enum Greeter {
@@ -164,6 +166,8 @@ final class ProtobufCodeGeneratorTests: XCTestCase {
 
 
         import GRPCCore
         import GRPCCore
         import GRPCProtobuf
         import GRPCProtobuf
+        import DifferentModule
+        import ExtraModule
 
 
         public enum Helloworld {
         public enum Helloworld {
           public enum Greeter {
           public enum Greeter {
@@ -258,6 +262,8 @@ final class ProtobufCodeGeneratorTests: XCTestCase {
 
 
         import GRPCCore
         import GRPCCore
         import GRPCProtobuf
         import GRPCProtobuf
+        import DifferentModule
+        import ExtraModule
 
 
         package enum Helloworld {
         package enum Helloworld {
           package enum Greeter {
           package enum Greeter {
@@ -393,7 +399,15 @@ final class ProtobufCodeGeneratorTests: XCTestCase {
       server: server,
       server: server,
       indentation: indentation
       indentation: indentation
     )
     )
-    let descriptorSet = DescriptorSet(protos: [Google_Protobuf_FileDescriptorProto.helloWorld])
+    let descriptorSet = DescriptorSet(
+      protos: [
+        Google_Protobuf_FileDescriptorProto(name: "same-module.proto", package: "same-package"),
+        Google_Protobuf_FileDescriptorProto(
+          name: "different-module.proto",
+          package: "different-package"
+        ),
+        Google_Protobuf_FileDescriptorProto.helloWorld,
+      ])
     guard let fileDescriptor = descriptorSet.fileDescriptor(named: "helloworld.proto") else {
     guard let fileDescriptor = descriptorSet.fileDescriptor(named: "helloworld.proto") else {
       return XCTFail(
       return XCTFail(
         """
         """
@@ -401,11 +415,21 @@ final class ProtobufCodeGeneratorTests: XCTestCase {
         """
         """
       )
       )
     }
     }
+
+    let moduleMappings = SwiftProtobuf_GenSwift_ModuleMappings.with {
+      $0.mapping = [
+        SwiftProtobuf_GenSwift_ModuleMappings.Entry.with {
+          $0.protoFilePath = ["different-module.proto"]
+          $0.moduleName = "DifferentModule"
+        }
+      ]
+    }
     let generator = ProtobufCodeGenerator(configuration: configs)
     let generator = ProtobufCodeGenerator(configuration: configs)
     try XCTAssertEqualWithDiff(
     try XCTAssertEqualWithDiff(
       try generator.generateCode(
       try generator.generateCode(
         from: fileDescriptor,
         from: fileDescriptor,
-        protoFileModuleMappings: ProtoFileToModuleMappings()
+        protoFileModuleMappings: ProtoFileToModuleMappings(moduleMappingsProto: moduleMappings),
+        extraModuleImports: ["ExtraModule"]
       ),
       ),
       expectedCode
       expectedCode
     )
     )