diff --git a/main.go b/main.go index 9f1a6e6..87d74e9 100644 --- a/main.go +++ b/main.go @@ -8,6 +8,7 @@ import ( "github.com/golang/protobuf/proto" "github.com/golang/protobuf/protoc-gen-go/generator" + "github.com/golang/protobuf/protoc-gen-go/plugin" ) func main() { @@ -71,17 +72,32 @@ func main() { } } + tmplMap := make(map[string]*plugin_go.CodeGeneratorResponse_File) + concatOrAppend := func(file *plugin_go.CodeGeneratorResponse_File) { + if val, ok := tmplMap[*file.Name]; ok { + *val.Content += *file.Content + } else { + tmplMap[*file.Name] = file + g.Response.File = append(g.Response.File, file) + } + } + // Generate the encoders for _, file := range g.Request.GetProtoFile() { if all { encoder := NewGenericTemplateBasedEncoder(templateDir, file, debug, destinationDir) - g.Response.File = append(g.Response.File, encoder.Files()...) + for _, tmpl := range encoder.Files() { + concatOrAppend(tmpl) + } + continue } for _, service := range file.GetService() { encoder := NewGenericServiceTemplateBasedEncoder(templateDir, service, file, debug, destinationDir) - g.Response.File = append(g.Response.File, encoder.Files()...) + for _, tmpl := range encoder.Files() { + concatOrAppend(tmpl) + } } }