diff --git a/helpers/helpers.go b/helpers/helpers.go index 20a1cc3..3c408fa 100644 --- a/helpers/helpers.go +++ b/helpers/helpers.go @@ -146,10 +146,13 @@ var ProtoHelpersFuncMap = template.FuncMap{ "leadingComment": leadingComment, "trailingComment": trailingComment, "leadingDetachedComments": leadingDetachedComments, + "stringMessageExtension": stringMessageExtension, "stringFieldExtension": stringFieldExtension, "int64FieldExtension": int64FieldExtension, + "int64MessageExtension": int64MessageExtension, "stringMethodOptionsExtension": stringMethodOptionsExtension, "boolMethodOptionsExtension": boolMethodOptionsExtension, + "boolMessageExtension": boolMessageExtension, "boolFieldExtension": boolFieldExtension, "isFieldMap": isFieldMap, "fieldMapKeyType": fieldMapKeyType, @@ -413,6 +416,76 @@ func int64FieldExtension(fieldID int32, f *descriptor.FieldDescriptorProto) int6 return *i } +func int64MessageExtension(fieldID int32, f *descriptor.DescriptorProto) int64 { + if f == nil { + return 0 + } + if f.Options == nil { + return 0 + } + var extendedType *descriptor.MessageOptions + var extensionType *string + + eds := proto.RegisteredExtensions(f.Options) + if eds[fieldID] == nil { + ed := &proto.ExtensionDesc{ + ExtendedType: extendedType, + ExtensionType: extensionType, + Field: fieldID, + Tag: fmt.Sprintf("bytes,%d", fieldID), + } + proto.RegisterExtension(ed) + eds = proto.RegisteredExtensions(f.Options) + } + + ext, err := proto.GetExtension(f.Options, eds[fieldID]) + if err != nil { + return 0 + } + + i, ok := ext.(*int64) + if !ok { + return 0 + } + + return *i +} + +func stringMessageExtension(fieldID int32, f *descriptor.DescriptorProto) string { + if f == nil { + return "" + } + if f.Options == nil { + return "" + } + var extendedType *descriptor.MessageOptions + var extensionType *string + + eds := proto.RegisteredExtensions(f.Options) + if eds[fieldID] == nil { + ed := &proto.ExtensionDesc{ + ExtendedType: extendedType, + ExtensionType: extensionType, + Field: fieldID, + Tag: fmt.Sprintf("bytes,%d", fieldID), + } + proto.RegisterExtension(ed) + eds = proto.RegisteredExtensions(f.Options) + } + + ext, err := proto.GetExtension(f.Options, eds[fieldID]) + if err != nil { + return "" + } + + str, ok := ext.(*string) + if !ok { + return "" + } + + return *str +} + func boolMethodOptionsExtension(fieldID int32, f *descriptor.MethodDescriptorProto) bool { if f == nil { return false @@ -483,6 +556,41 @@ func boolFieldExtension(fieldID int32, f *descriptor.FieldDescriptorProto) bool return *b } +func boolMessageExtension(fieldID int32, f *descriptor.DescriptorProto) bool { + if f == nil { + return false + } + if f.Options == nil { + return false + } + var extendedType *descriptor.MessageOptions + var extensionType *bool + + eds := proto.RegisteredExtensions(f.Options) + if eds[fieldID] == nil { + ed := &proto.ExtensionDesc{ + ExtendedType: extendedType, + ExtensionType: extensionType, + Field: fieldID, + Tag: fmt.Sprintf("varint,%d", fieldID), + } + proto.RegisterExtension(ed) + eds = proto.RegisteredExtensions(f.Options) + } + + ext, err := proto.GetExtension(f.Options, eds[fieldID]) + if err != nil { + return false + } + + b, ok := ext.(*bool) + if !ok { + return false + } + + return *b +} + func init() { for k, v := range sprig.TxtFuncMap() { ProtoHelpersFuncMap[k] = v