fix tag for nested messages

Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>
This commit is contained in:
Василий Толстов 2021-08-31 19:40:20 +03:00
parent dd62c380f2
commit ee4d83458f

73
ast.go
View File

@ -15,47 +15,58 @@ import (
"google.golang.org/protobuf/proto"
)
var (
astFields = make(map[string]map[string]map[string]*structtag.Tags) // map proto file with proto message ast struct
)
var astFields = make(map[string]map[string]map[string]*structtag.Tags) // map proto file with proto message ast struct
func (g *Generator) astFill(file *protogen.File, message *protogen.Message) error {
for _, field := range message.Fields {
if field.Desc.Options() == nil {
continue
}
if !proto.HasExtension(field.Desc.Options(), tag_options.E_Tags) {
continue
}
opts := proto.GetExtension(field.Desc.Options(), tag_options.E_Tags)
if opts != nil {
fpath := filepath.Join(g.tagPath, file.GeneratedFilenamePrefix+".pb.go")
mp, ok := astFields[fpath]
if !ok {
mp = make(map[string]map[string]*structtag.Tags)
}
nmp, ok := mp[message.GoIdent.GoName]
if !ok {
nmp = make(map[string]*structtag.Tags)
}
tags, err := structtag.Parse(opts.(string))
if err != nil {
return err
}
nmp[field.GoName] = tags
mp[message.GoIdent.GoName] = nmp
astFields[fpath] = mp
}
}
for _, nmessage := range message.Messages {
if err := g.astFill(file, nmessage); err != nil {
return err
}
}
return nil
}
func (g *Generator) astGenerate(plugin *protogen.Plugin) error {
if g.tagPath == "" {
return nil
}
for _, file := range plugin.Files {
if !file.Generate {
continue
}
for _, message := range file.Messages {
for _, field := range message.Fields {
if field.Desc.Options() == nil {
continue
}
if !proto.HasExtension(field.Desc.Options(), tag_options.E_Tags) {
continue
}
opts := proto.GetExtension(field.Desc.Options(), tag_options.E_Tags)
if opts != nil {
fpath := filepath.Join(g.tagPath, file.GeneratedFilenamePrefix+".pb.go")
mp, ok := astFields[fpath]
if !ok {
mp = make(map[string]map[string]*structtag.Tags)
}
nmp, ok := mp[message.GoIdent.GoName]
if !ok {
nmp = make(map[string]*structtag.Tags)
}
tags, err := structtag.Parse(opts.(string))
if err != nil {
return err
}
nmp[field.GoName] = tags
mp[message.GoIdent.GoName] = nmp
astFields[fpath] = mp
}
if err := g.astFill(file, message); err != nil {
return err
}
}
}