diff --git a/encoder.go b/encoder.go index fe33c9a..6d5e39a 100644 --- a/encoder.go +++ b/encoder.go @@ -2,6 +2,8 @@ package main import ( "bytes" + "fmt" + "io/ioutil" "log" "net/url" "os" @@ -12,11 +14,13 @@ import ( "github.com/golang/protobuf/protoc-gen-go/descriptor" plugin_go "github.com/golang/protobuf/protoc-gen-go/plugin" + "github.com/unistack-org/protoc-gen-micro/assets" pgghelpers "github.com/unistack-org/protoc-gen-micro/helpers" ) type GenericTemplateBasedEncoder struct { templateDir string + assetsDir string service *descriptor.ServiceDescriptorProto file *descriptor.FileDescriptorProto enum []*descriptor.EnumDescriptorProto @@ -77,6 +81,51 @@ func NewGenericTemplateBasedEncoder(templateDir string, file *descriptor.FileDes func (e *GenericTemplateBasedEncoder) templates() ([]string, error) { filenames := []string{} + if e.templateDir == "" { + dir, err := assets.Assets.Open("/") + if err != nil { + return nil, fmt.Errorf("failed to open assets dir") + } + + fi, err := dir.Readdir(-1) + if err != nil { + return nil, fmt.Errorf("failed to get assets files") + } + + if debug { + log.Printf("components to generate: %v", components) + } + + for _, f := range fi { + skip := true + for _, component := range components { + if component == "all" || strings.Contains(f.Name(), "_"+component+".pb.go") { + skip = false + } + } + if skip { + if debug { + log.Printf("skip template %s", f.Name()) + } + continue + } + + if f.IsDir() { + continue + } + if filepath.Ext(f.Name()) != ".tmpl" { + continue + } + if e.debug { + log.Printf("new template: %q", f.Name()) + } + + filenames = append(filenames, f.Name()) + } + + return filenames, nil + } + err := filepath.Walk(e.templateDir, func(path string, info os.FileInfo, err error) error { if err != nil { return err @@ -156,10 +205,27 @@ func (e *GenericTemplateBasedEncoder) genAst(templateFilename string) (*Ast, err } func (e *GenericTemplateBasedEncoder) buildContent(templateFilename string) (string, string, error) { - // initialize template engine - fullPath := filepath.Join(e.templateDir, templateFilename) - templateName := filepath.Base(fullPath) - tmpl, err := template.New(templateName).Funcs(pgghelpers.ProtoHelpersFuncMap).ParseFiles(fullPath) + var tmpl *template.Template + var err error + + if e.templateDir == "" { + fs, err := assets.Assets.Open("/" + templateFilename) + if err != nil { + return "", "", err + } + buf, err := ioutil.ReadAll(fs) + if err != nil { + return "", "", err + } + if err = fs.Close(); err == nil { + tmpl, err = template.New("/" + templateFilename).Funcs(pgghelpers.ProtoHelpersFuncMap).Parse(string(buf)) + } + } else { + // initialize template engine + fullPath := filepath.Join(e.templateDir, templateFilename) + templateName := filepath.Base(fullPath) + tmpl, err = template.New(templateName).Funcs(pgghelpers.ProtoHelpersFuncMap).ParseFiles(fullPath) + } if err != nil { return "", "", err } diff --git a/main.go b/main.go index 3e0d582..3b25de8 100644 --- a/main.go +++ b/main.go @@ -3,18 +3,15 @@ package main import ( "fmt" "go/format" - "io" "io/ioutil" "log" "net/url" "os" - "path/filepath" "strings" "github.com/golang/protobuf/protoc-gen-go/generator" plugin_go "github.com/golang/protobuf/protoc-gen-go/plugin" ggdescriptor "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor" - "github.com/unistack-org/protoc-gen-micro/assets" pgghelpers "github.com/unistack-org/protoc-gen-micro/helpers" "google.golang.org/protobuf/proto" ) @@ -28,6 +25,17 @@ const ( boolFalse = "false" ) +var ( + templateDir = "" + templateRepo = "" + destinationDir = "." + debug = false + all = false + singlePackageMode = false + fileMode = false + components = []string{"micro", "grpc"} +) + func main() { g := generator.New() @@ -47,16 +55,6 @@ func main() { g.CommandLineParameters(g.Request.GetParameter()) // Parse parameters - var ( - templateDir = "" - templateRepo = "" - destinationDir = "." - debug = false - all = false - singlePackageMode = false - fileMode = false - components = []string{"micro", "grpc"} - ) if parameter := g.Request.GetParameter(); parameter != "" { for _, param := range strings.Split(parameter, ",") { parts := strings.Split(param, "=") @@ -135,7 +133,7 @@ func main() { } } - if templateDir == "" || templateRepo != "" { + if templateDir == "" && templateRepo != "" { if templateDir, err = ioutil.TempDir("", "gen-*"); err != nil { g.Error(err, "failed to create tmp dir") } @@ -149,63 +147,6 @@ func main() { if err = clone(templateRepo, templateDir); err != nil { g.Error(err, "failed to clone repo") } - } else { - dir, err := assets.Assets.Open("/") - if err != nil { - g.Error(err, "failed to open assets dir") - } - fi, err := dir.Readdir(-1) - if err != nil { - g.Error(err, "failed to get assets files") - } - - if debug { - log.Printf("components to generate: %v", components) - } - - for _, f := range fi { - skip := true - for _, component := range components { - if component == "all" || strings.Contains(f.Name(), "_"+component+".pb.go") { - skip = false - } - } - if skip { - if debug { - log.Printf("skip template %s", f.Name()) - } - continue - } - if debug { - log.Printf("copy template %s", f.Name()) - } - fpath := filepath.Join(templateDir, f.Name()) - if err = os.MkdirAll(filepath.Dir(fpath), os.FileMode(0755)); err != nil { - g.Error(err, "failed to create nested dir") - } - if f.IsDir() { - continue - } - fd, err := os.OpenFile(fpath, os.O_CREATE|os.O_TRUNC|os.O_RDWR, f.Mode()) - if err != nil { - g.Error(err, "failed to create template file") - } - fs, err := assets.Assets.Open(f.Name()) - if err != nil { - g.Error(err, "failed to open template file") - } - if _, err = io.Copy(fd, fs); err != nil { - fd.Close() - fs.Close() - g.Error(err, "failed to copy template file") - } - if err = fd.Close(); err != nil { - g.Error(err, "failed to flush template file") - } - if err = fs.Close(); err != nil { - g.Error(err, "failed to flush template file") - } - } } }