diff --git a/encoder.go b/encoder.go index 53a0779..f5d5fe7 100644 --- a/encoder.go +++ b/encoder.go @@ -52,6 +52,7 @@ func NewGenericServiceTemplateBasedEncoder(templateDir string, service *descript if debug { log.Printf("new encoder: file=%q service=%q template-dir=%q", file.GetName(), service.GetName(), templateDir) } + pgghelpers.InitPathMap(file) return } @@ -68,6 +69,7 @@ func NewGenericTemplateBasedEncoder(templateDir string, file *descriptor.FileDes if debug { log.Printf("new encoder: file=%q template-dir=%q", file.GetName(), templateDir) } + pgghelpers.InitPathMap(file) return } diff --git a/helpers/helpers.go b/helpers/helpers.go index d7bcccf..da92080 100644 --- a/helpers/helpers.go +++ b/helpers/helpers.go @@ -99,6 +99,7 @@ var ProtoHelpersFuncMap = template.FuncMap{ "isFieldRepeated": isFieldRepeated, "haskellType": haskellType, "goType": goType, + "goZeroValue": goZeroValue, "goTypeWithPackage": goTypeWithPackage, "jsType": jsType, "jsSuffixReserved": jsSuffixReservedKeyword, @@ -110,6 +111,185 @@ var ProtoHelpersFuncMap = template.FuncMap{ "urlHasVarsFromMessage": urlHasVarsFromMessage, "lowerGoNormalize": lowerGoNormalize, "goNormalize": goNormalize, + "leadingComment": leadingComment, + "trailingComment": trailingComment, + "leadingDetachedComments": leadingDetachedComments, + "stringFieldExtension": stringFieldExtension, + "boolFieldExtension": boolFieldExtension, + "isFieldMap": isFieldMap, + "fieldMapKeyType": fieldMapKeyType, + "fieldMapValueType": fieldMapValueType, + "replaceDict": replaceDict, +} + +var pathMap map[interface{}]*descriptor.SourceCodeInfo_Location + +func InitPathMap(file *descriptor.FileDescriptorProto) { + pathMap = make(map[interface{}]*descriptor.SourceCodeInfo_Location) + addToPathMap(file.GetSourceCodeInfo(), file, []int32{}) +} + +func InitPathMaps(files []*descriptor.FileDescriptorProto) { + pathMap = make(map[interface{}]*descriptor.SourceCodeInfo_Location) + for _, file := range files { + addToPathMap(file.GetSourceCodeInfo(), file, []int32{}) + } +} + +// addToPathMap traverses through the AST adding SourceCodeInfo_Location entries to the pathMap. +// Since the AST is a tree, the recursion finishes once it has gone through all the nodes. +func addToPathMap(info *descriptor.SourceCodeInfo, i interface{}, path []int32) { + loc := findLoc(info, path) + if loc != nil { + pathMap[i] = loc + } + switch d := i.(type) { + case *descriptor.FileDescriptorProto: + for index, descriptor := range d.MessageType { + addToPathMap(info, descriptor, newPath(path, 4, index)) + } + for index, descriptor := range d.EnumType { + addToPathMap(info, descriptor, newPath(path, 5, index)) + } + for index, descriptor := range d.Service { + addToPathMap(info, descriptor, newPath(path, 6, index)) + } + case *descriptor.DescriptorProto: + for index, descriptor := range d.Field { + addToPathMap(info, descriptor, newPath(path, 2, index)) + } + for index, descriptor := range d.NestedType { + addToPathMap(info, descriptor, newPath(path, 3, index)) + } + for index, descriptor := range d.EnumType { + addToPathMap(info, descriptor, newPath(path, 4, index)) + } + case *descriptor.EnumDescriptorProto: + for index, descriptor := range d.Value { + addToPathMap(info, descriptor, newPath(path, 2, index)) + } + case *descriptor.ServiceDescriptorProto: + for index, descriptor := range d.Method { + addToPathMap(info, descriptor, newPath(path, 2, index)) + } + } +} + +func newPath(base []int32, field int32, index int) []int32 { + p := append([]int32{}, base...) + p = append(p, field, int32(index)) + return p +} + +func findLoc(info *descriptor.SourceCodeInfo, path []int32) *descriptor.SourceCodeInfo_Location { + for _, loc := range info.GetLocation() { + if samePath(loc.Path, path) { + return loc + } + } + return nil +} + +func samePath(a, b []int32) bool { + if len(a) != len(b) { + return false + } + for i, p := range a { + if p != b[i] { + return false + } + } + return true +} + +func findSourceInfoLocation(i interface{}) *descriptor.SourceCodeInfo_Location { + if pathMap == nil { + return nil + } + return pathMap[i] +} + +func leadingComment(i interface{}) string { + loc := pathMap[i] + return loc.GetLeadingComments() +} +func trailingComment(i interface{}) string { + loc := pathMap[i] + return loc.GetTrailingComments() +} +func leadingDetachedComments(i interface{}) []string { + loc := pathMap[i] + return loc.GetLeadingDetachedComments() +} + +func stringFieldExtension(fieldID int32, f *descriptor.FieldDescriptorProto) string { + if f == nil { + return "" + } + if f.Options == nil { + return "" + } + var extendedType *descriptor.FieldOptions + var extensionType *string + + eds := proto.RegisteredExtensions(f.Options) + if eds[fieldID] == nil { + ed := &proto.ExtensionDesc{ + ExtendedType: extendedType, + ExtensionType: extensionType, + Field: fieldID, + Tag: fmt.Sprintf("bytes,%d", fieldID), + } + proto.RegisterExtension(ed) + eds = proto.RegisteredExtensions(f.Options) + } + + ext, err := proto.GetExtension(f.Options, eds[fieldID]) + if err != nil { + return "" + } + + str, ok := ext.(*string) + if !ok { + return "" + } + + return *str +} + +func boolFieldExtension(fieldID int32, f *descriptor.FieldDescriptorProto) bool { + if f == nil { + return false + } + if f.Options == nil { + return false + } + var extendedType *descriptor.FieldOptions + var extensionType *bool + + eds := proto.RegisteredExtensions(f.Options) + if eds[fieldID] == nil { + ed := &proto.ExtensionDesc{ + ExtendedType: extendedType, + ExtensionType: extensionType, + Field: fieldID, + Tag: fmt.Sprintf("varint,%d", fieldID), + } + proto.RegisterExtension(ed) + eds = proto.RegisteredExtensions(f.Options) + } + + ext, err := proto.GetExtension(f.Options, eds[fieldID]) + if err != nil { + return false + } + + str, ok := ext.(*bool) + if !ok { + return false + } + + return *str } func init() { @@ -180,6 +360,9 @@ func isFieldMessage(f *descriptor.FieldDescriptorProto) bool { } func isFieldRepeated(f *descriptor.FieldDescriptorProto) bool { + if f == nil { + return false + } if f.Type != nil && f.Label != nil && *f.Label == descriptor.FieldDescriptorProto_LABEL_REPEATED { return true } @@ -187,6 +370,98 @@ func isFieldRepeated(f *descriptor.FieldDescriptorProto) bool { return false } +func isFieldMap(f *descriptor.FieldDescriptorProto, m *descriptor.DescriptorProto) bool { + if f.TypeName == nil { + return false + } + + shortName := shortType(*f.TypeName) + var nt *descriptor.DescriptorProto + for _, t := range m.NestedType { + if *t.Name == shortName { + nt = t + break + } + } + + if nt == nil { + return false + } + + for _, f := range nt.Field { + switch *f.Name { + case "key": + if *f.Number != 1 { + return false + } + case "value": + if *f.Number != 2 { + return false + } + default: + return false + } + } + + return true +} + +func fieldMapKeyType(f *descriptor.FieldDescriptorProto, m *descriptor.DescriptorProto) *descriptor.FieldDescriptorProto { + if f.TypeName == nil { + return nil + } + + shortName := shortType(*f.TypeName) + var nt *descriptor.DescriptorProto + for _, t := range m.NestedType { + if *t.Name == shortName { + nt = t + break + } + } + + if nt == nil { + return nil + } + + for _, f := range nt.Field { + if *f.Name == "key" { + return f + } + } + + return nil + +} + +func fieldMapValueType(f *descriptor.FieldDescriptorProto, m *descriptor.DescriptorProto) *descriptor.FieldDescriptorProto { + if f.TypeName == nil { + return nil + } + + shortName := shortType(*f.TypeName) + var nt *descriptor.DescriptorProto + for _, t := range m.NestedType { + if *t.Name == shortName { + nt = t + break + } + } + + if nt == nil { + return nil + } + + for _, f := range nt.Field { + if *f.Name == "value" { + return f + } + } + + return nil + +} + func goTypeWithPackage(f *descriptor.FieldDescriptorProto) string { pkg := "" if *f.Type == descriptor.FieldDescriptorProto_TYPE_MESSAGE || *f.Type == descriptor.FieldDescriptorProto_TYPE_ENUM { @@ -319,6 +594,39 @@ func goType(pkg string, f *descriptor.FieldDescriptorProto) string { } } +func goZeroValue(f *descriptor.FieldDescriptorProto) string { + const nilString = "nil" + if *f.Label == descriptor.FieldDescriptorProto_LABEL_REPEATED { + return nilString + } + switch *f.Type { + case descriptor.FieldDescriptorProto_TYPE_DOUBLE: + return "0.0" + case descriptor.FieldDescriptorProto_TYPE_FLOAT: + return "0.0" + case descriptor.FieldDescriptorProto_TYPE_INT64: + return "0" + case descriptor.FieldDescriptorProto_TYPE_UINT64: + return "0" + case descriptor.FieldDescriptorProto_TYPE_INT32: + return "0" + case descriptor.FieldDescriptorProto_TYPE_UINT32: + return "0" + case descriptor.FieldDescriptorProto_TYPE_BOOL: + return "false" + case descriptor.FieldDescriptorProto_TYPE_STRING: + return "\"\"" + case descriptor.FieldDescriptorProto_TYPE_MESSAGE: + return nilString + case descriptor.FieldDescriptorProto_TYPE_BYTES: + return "0" + case descriptor.FieldDescriptorProto_TYPE_ENUM: + return nilString + default: + return nilString + } +} + func jsType(f *descriptor.FieldDescriptorProto) string { template := "%s" if isFieldRepeated(f) { @@ -502,3 +810,14 @@ func formatID(base string, formatted string) string { } return formatted } + +func replaceDict(src string, dict map[string]interface{}) string { + for old, v := range dict { + new, ok := v.(string) + if !ok { + continue + } + src = strings.Replace(src, old, new, -1) + } + return src +}