add new helpers

This commit is contained in:
Pat Moroney 2018-01-12 10:38:10 -07:00
parent 7e17e4319f
commit decb64ccd8
2 changed files with 319 additions and 0 deletions

View File

@ -52,6 +52,7 @@ func NewGenericServiceTemplateBasedEncoder(templateDir string, service *descript
if debug {
log.Printf("new encoder: file=%q service=%q template-dir=%q", file.GetName(), service.GetName(), templateDir)
}
pgghelpers.InitPathMap(file)
return
}
@ -68,6 +69,8 @@ func NewGenericTemplateBasedEncoder(templateDir string, file *descriptor.FileDes
if debug {
log.Printf("new encoder: file=%q template-dir=%q", file.GetName(), templateDir)
}
pgghelpers.InitPathMap(file)
pgghelpers.InitPathMap(file)
return
}

View File

@ -99,6 +99,7 @@ var ProtoHelpersFuncMap = template.FuncMap{
"isFieldRepeated": isFieldRepeated,
"haskellType": haskellType,
"goType": goType,
"goZeroValue": goZeroValue,
"goTypeWithPackage": goTypeWithPackage,
"jsType": jsType,
"jsSuffixReserved": jsSuffixReservedKeyword,
@ -110,6 +111,183 @@ var ProtoHelpersFuncMap = template.FuncMap{
"urlHasVarsFromMessage": urlHasVarsFromMessage,
"lowerGoNormalize": lowerGoNormalize,
"goNormalize": goNormalize,
"leadingComment": leadingComment,
"trailingComment": trailingComment,
"leadingDetachedComments": leadingDetachedComments,
"stringFieldExtension": stringFieldExtension,
"boolFieldExtension": boolFieldExtension,
"isFieldMap": isFieldMap,
"fieldMapKeyType": fieldMapKeyType,
"fieldMapValueType": fieldMapValueType,
"replaceDict": replaceDict,
}
var pathMap map[interface{}]*descriptor.SourceCodeInfo_Location
func InitPathMap(file *descriptor.FileDescriptorProto) {
pathMap = make(map[interface{}]*descriptor.SourceCodeInfo_Location)
addToPathMap(file.GetSourceCodeInfo(), file, []int32{})
}
func InitPathMaps(files []*descriptor.FileDescriptorProto) {
pathMap = make(map[interface{}]*descriptor.SourceCodeInfo_Location)
for _, file := range files {
addToPathMap(file.GetSourceCodeInfo(), file, []int32{})
}
}
func addToPathMap(info *descriptor.SourceCodeInfo, i interface{}, path []int32) {
loc := findLoc(info, path)
if loc != nil {
pathMap[i] = loc
}
switch d := i.(type) {
case *descriptor.FileDescriptorProto:
for index, descriptor := range d.MessageType {
addToPathMap(info, descriptor, newPath(path, 4, index))
}
for index, descriptor := range d.EnumType {
addToPathMap(info, descriptor, newPath(path, 5, index))
}
for index, descriptor := range d.Service {
addToPathMap(info, descriptor, newPath(path, 6, index))
}
case *descriptor.DescriptorProto:
for index, descriptor := range d.Field {
addToPathMap(info, descriptor, newPath(path, 2, index))
}
for index, descriptor := range d.NestedType {
addToPathMap(info, descriptor, newPath(path, 3, index))
}
for index, descriptor := range d.EnumType {
addToPathMap(info, descriptor, newPath(path, 4, index))
}
case *descriptor.EnumDescriptorProto:
for index, descriptor := range d.Value {
addToPathMap(info, descriptor, newPath(path, 2, index))
}
case *descriptor.ServiceDescriptorProto:
for index, descriptor := range d.Method {
addToPathMap(info, descriptor, newPath(path, 2, index))
}
}
}
func newPath(base []int32, field int32, index int) []int32 {
p := append([]int32{}, base...)
p = append(p, field, int32(index))
return p
}
func findLoc(info *descriptor.SourceCodeInfo, path []int32) *descriptor.SourceCodeInfo_Location {
for _, loc := range info.GetLocation() {
if samePath(loc.Path, path) {
return loc
}
}
return nil
}
func samePath(a, b []int32) bool {
if len(a) != len(b) {
return false
}
for i, p := range a {
if p != b[i] {
return false
}
}
return true
}
func findSourceInfoLocation(i interface{}) *descriptor.SourceCodeInfo_Location {
if pathMap == nil {
return nil
}
return pathMap[i]
}
func leadingComment(i interface{}) string {
loc := pathMap[i]
return loc.GetLeadingComments()
}
func trailingComment(i interface{}) string {
loc := pathMap[i]
return loc.GetTrailingComments()
}
func leadingDetachedComments(i interface{}) []string {
loc := pathMap[i]
return loc.GetLeadingDetachedComments()
}
func stringFieldExtension(fieldId int32, f *descriptor.FieldDescriptorProto) string {
if f == nil {
return ""
}
if f.Options == nil {
return ""
}
var extendedType *descriptor.FieldOptions
var extensionType *string
eds := proto.RegisteredExtensions(f.Options)
if eds[fieldId] == nil {
ed := &proto.ExtensionDesc{
ExtendedType: extendedType,
ExtensionType: extensionType,
Field: fieldId,
Tag: fmt.Sprintf("bytes,%d", fieldId),
}
proto.RegisterExtension(ed)
eds = proto.RegisteredExtensions(f.Options)
}
ext, err := proto.GetExtension(f.Options, eds[fieldId])
if err != nil {
return ""
}
str, ok := ext.(*string)
if !ok {
return ""
}
return *str
}
func boolFieldExtension(fieldId int32, f *descriptor.FieldDescriptorProto) bool {
if f == nil {
return false
}
if f.Options == nil {
return false
}
var extendedType *descriptor.FieldOptions
var extensionType *bool
eds := proto.RegisteredExtensions(f.Options)
if eds[fieldId] == nil {
ed := &proto.ExtensionDesc{
ExtendedType: extendedType,
ExtensionType: extensionType,
Field: fieldId,
Tag: fmt.Sprintf("varint,%d", fieldId),
}
proto.RegisterExtension(ed)
eds = proto.RegisteredExtensions(f.Options)
}
ext, err := proto.GetExtension(f.Options, eds[fieldId])
if err != nil {
return false
}
str, ok := ext.(*bool)
if !ok {
return false
}
return *str
}
func init() {
@ -180,6 +358,9 @@ func isFieldMessage(f *descriptor.FieldDescriptorProto) bool {
}
func isFieldRepeated(f *descriptor.FieldDescriptorProto) bool {
if f == nil {
return false
}
if f.Type != nil && f.Label != nil && *f.Label == descriptor.FieldDescriptorProto_LABEL_REPEATED {
return true
}
@ -187,6 +368,98 @@ func isFieldRepeated(f *descriptor.FieldDescriptorProto) bool {
return false
}
func isFieldMap(f *descriptor.FieldDescriptorProto, m *descriptor.DescriptorProto) bool {
if f.TypeName == nil {
return false
}
shortName := shortType(*f.TypeName)
var nt *descriptor.DescriptorProto
for _, t := range m.NestedType {
if *t.Name == shortName {
nt = t
break
}
}
if nt == nil {
return false
}
for _, f := range nt.Field {
switch *f.Name {
case "key":
if *f.Number != 1 {
return false
}
case "value":
if *f.Number != 2 {
return false
}
default:
return false
}
}
return true
}
func fieldMapKeyType(f *descriptor.FieldDescriptorProto, m *descriptor.DescriptorProto) *descriptor.FieldDescriptorProto {
if f.TypeName == nil {
return nil
}
shortName := shortType(*f.TypeName)
var nt *descriptor.DescriptorProto
for _, t := range m.NestedType {
if *t.Name == shortName {
nt = t
break
}
}
if nt == nil {
return nil
}
for _, f := range nt.Field {
if *f.Name == "key" {
return f
}
}
return nil
}
func fieldMapValueType(f *descriptor.FieldDescriptorProto, m *descriptor.DescriptorProto) *descriptor.FieldDescriptorProto {
if f.TypeName == nil {
return nil
}
shortName := shortType(*f.TypeName)
var nt *descriptor.DescriptorProto
for _, t := range m.NestedType {
if *t.Name == shortName {
nt = t
break
}
}
if nt == nil {
return nil
}
for _, f := range nt.Field {
if *f.Name == "value" {
return f
}
}
return nil
}
func goTypeWithPackage(f *descriptor.FieldDescriptorProto) string {
pkg := ""
if *f.Type == descriptor.FieldDescriptorProto_TYPE_MESSAGE || *f.Type == descriptor.FieldDescriptorProto_TYPE_ENUM {
@ -319,6 +592,38 @@ func goType(pkg string, f *descriptor.FieldDescriptorProto) string {
}
}
func goZeroValue(f *descriptor.FieldDescriptorProto) string {
if *f.Label == descriptor.FieldDescriptorProto_LABEL_REPEATED {
return "nil"
}
switch *f.Type {
case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
return "0.0"
case descriptor.FieldDescriptorProto_TYPE_FLOAT:
return "0.0"
case descriptor.FieldDescriptorProto_TYPE_INT64:
return "0"
case descriptor.FieldDescriptorProto_TYPE_UINT64:
return "0"
case descriptor.FieldDescriptorProto_TYPE_INT32:
return "0"
case descriptor.FieldDescriptorProto_TYPE_UINT32:
return "0"
case descriptor.FieldDescriptorProto_TYPE_BOOL:
return "false"
case descriptor.FieldDescriptorProto_TYPE_STRING:
return "\"\""
case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
return "nil"
case descriptor.FieldDescriptorProto_TYPE_BYTES:
return "0"
case descriptor.FieldDescriptorProto_TYPE_ENUM:
return "nil"
default:
return "nil"
}
}
func jsType(f *descriptor.FieldDescriptorProto) string {
template := "%s"
if isFieldRepeated(f) {
@ -502,3 +807,14 @@ func formatID(base string, formatted string) string {
}
return formatted
}
func replaceDict(src string, dict map[string]interface{}) string {
for old, v := range dict {
new, ok := v.(string)
if !ok {
continue
}
src = strings.Replace(src, old, new, -1)
}
return src
}