diff --git a/encoder.go b/encoder.go index c8b1b84..7f6a149 100644 --- a/encoder.go +++ b/encoder.go @@ -1,32 +1,73 @@ package main import ( + "path/filepath" + "github.com/golang/protobuf/protoc-gen-go/descriptor" "github.com/golang/protobuf/protoc-gen-go/plugin" + "github.com/kr/fs" ) type GenericTemplateBasedEncoder struct { - service *descriptor.ServiceDescriptorProto - file *descriptor.FileDescriptorProto + templateDir string + service *descriptor.ServiceDescriptorProto + file *descriptor.FileDescriptorProto } -func NewGenericTemplateBasedEncoder(service *descriptor.ServiceDescriptorProto, file *descriptor.FileDescriptorProto) (e *GenericTemplateBasedEncoder) { +func NewGenericTemplateBasedEncoder(templateDir string, service *descriptor.ServiceDescriptorProto, file *descriptor.FileDescriptorProto) (e *GenericTemplateBasedEncoder) { e = &GenericTemplateBasedEncoder{ - service: service, - file: file, + service: service, + file: file, + templateDir: templateDir, } return } -func (e *GenericTemplateBasedEncoder) Files() []*plugin_go.CodeGeneratorResponse_File { - //log.Printf("file: %v\n", e.file) - //log.Printf("service: %v\n", e.service) - var content string = "hello world" - var fileName string = "test.txt" - return []*plugin_go.CodeGeneratorResponse_File{ - &plugin_go.CodeGeneratorResponse_File{ - Content: &content, - Name: &fileName, - }, +func (e *GenericTemplateBasedEncoder) templates() ([]string, error) { + filenames := []string{} + + walker := fs.Walk(e.templateDir) + for walker.Step() { + if err := walker.Err(); err != nil { + return nil, err + } + + if walker.Stat().IsDir() { + continue + } + + if filepath.Ext(walker.Path()) != ".tmpl" { + continue + } + + rel, err := filepath.Rel(e.templateDir, walker.Path()) + if err != nil { + return nil, err + } + + filenames = append(filenames, rel) } + + return filenames, nil +} + +func (e *GenericTemplateBasedEncoder) Files() []*plugin_go.CodeGeneratorResponse_File { + files := []*plugin_go.CodeGeneratorResponse_File{} + + templates, err := e.templates() + if err != nil { + panic(err) + } + + for _, templateFilename := range templates { + filename := templateFilename[0 : len(templateFilename)-len(".tmpl")] + + content := "hello world" + files = append(files, &plugin_go.CodeGeneratorResponse_File{ + Content: &content, + Name: &filename, + }) + } + + return files } diff --git a/main.go b/main.go index e4d54f5..59e60d2 100644 --- a/main.go +++ b/main.go @@ -29,7 +29,7 @@ func main() { // Generate the clients for _, file := range g.Request.GetProtoFile() { for _, service := range file.GetService() { - encoder := NewGenericTemplateBasedEncoder(service, file) + encoder := NewGenericTemplateBasedEncoder("templates", service, file) g.Response.File = append(g.Response.File, encoder.Files()...) } }