diff --git a/ast.go b/ast.go index 966e4a9..0286a24 100644 --- a/ast.go +++ b/ast.go @@ -15,47 +15,58 @@ import ( "google.golang.org/protobuf/proto" ) -var ( - astFields = make(map[string]map[string]map[string]*structtag.Tags) // map proto file with proto message ast struct -) +var astFields = make(map[string]map[string]map[string]*structtag.Tags) // map proto file with proto message ast struct + +func (g *Generator) astFill(file *protogen.File, message *protogen.Message) error { + for _, field := range message.Fields { + if field.Desc.Options() == nil { + continue + } + if !proto.HasExtension(field.Desc.Options(), tag_options.E_Tags) { + continue + } + + opts := proto.GetExtension(field.Desc.Options(), tag_options.E_Tags) + if opts != nil { + fpath := filepath.Join(g.tagPath, file.GeneratedFilenamePrefix+".pb.go") + mp, ok := astFields[fpath] + if !ok { + mp = make(map[string]map[string]*structtag.Tags) + } + nmp, ok := mp[message.GoIdent.GoName] + if !ok { + nmp = make(map[string]*structtag.Tags) + } + tags, err := structtag.Parse(opts.(string)) + if err != nil { + return err + } + nmp[field.GoName] = tags + mp[message.GoIdent.GoName] = nmp + astFields[fpath] = mp + } + } + for _, nmessage := range message.Messages { + if err := g.astFill(file, nmessage); err != nil { + return err + } + } + + return nil +} func (g *Generator) astGenerate(plugin *protogen.Plugin) error { if g.tagPath == "" { return nil } + for _, file := range plugin.Files { if !file.Generate { continue } - for _, message := range file.Messages { - for _, field := range message.Fields { - if field.Desc.Options() == nil { - continue - } - if !proto.HasExtension(field.Desc.Options(), tag_options.E_Tags) { - continue - } - - opts := proto.GetExtension(field.Desc.Options(), tag_options.E_Tags) - if opts != nil { - fpath := filepath.Join(g.tagPath, file.GeneratedFilenamePrefix+".pb.go") - mp, ok := astFields[fpath] - if !ok { - mp = make(map[string]map[string]*structtag.Tags) - } - nmp, ok := mp[message.GoIdent.GoName] - if !ok { - nmp = make(map[string]*structtag.Tags) - } - tags, err := structtag.Parse(opts.(string)) - if err != nil { - return err - } - nmp[field.GoName] = tags - mp[message.GoIdent.GoName] = nmp - astFields[fpath] = mp - } + if err := g.astFill(file, message); err != nil { + return err } } }