commit
be5f14041f
@ -52,6 +52,7 @@ func NewGenericServiceTemplateBasedEncoder(templateDir string, service *descript
|
||||
if debug {
|
||||
log.Printf("new encoder: file=%q service=%q template-dir=%q", file.GetName(), service.GetName(), templateDir)
|
||||
}
|
||||
pgghelpers.InitPathMap(file)
|
||||
|
||||
return
|
||||
}
|
||||
@ -68,6 +69,7 @@ func NewGenericTemplateBasedEncoder(templateDir string, file *descriptor.FileDes
|
||||
if debug {
|
||||
log.Printf("new encoder: file=%q template-dir=%q", file.GetName(), templateDir)
|
||||
}
|
||||
pgghelpers.InitPathMap(file)
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -99,6 +99,7 @@ var ProtoHelpersFuncMap = template.FuncMap{
|
||||
"isFieldRepeated": isFieldRepeated,
|
||||
"haskellType": haskellType,
|
||||
"goType": goType,
|
||||
"goZeroValue": goZeroValue,
|
||||
"goTypeWithPackage": goTypeWithPackage,
|
||||
"jsType": jsType,
|
||||
"jsSuffixReserved": jsSuffixReservedKeyword,
|
||||
@ -110,6 +111,185 @@ var ProtoHelpersFuncMap = template.FuncMap{
|
||||
"urlHasVarsFromMessage": urlHasVarsFromMessage,
|
||||
"lowerGoNormalize": lowerGoNormalize,
|
||||
"goNormalize": goNormalize,
|
||||
"leadingComment": leadingComment,
|
||||
"trailingComment": trailingComment,
|
||||
"leadingDetachedComments": leadingDetachedComments,
|
||||
"stringFieldExtension": stringFieldExtension,
|
||||
"boolFieldExtension": boolFieldExtension,
|
||||
"isFieldMap": isFieldMap,
|
||||
"fieldMapKeyType": fieldMapKeyType,
|
||||
"fieldMapValueType": fieldMapValueType,
|
||||
"replaceDict": replaceDict,
|
||||
}
|
||||
|
||||
var pathMap map[interface{}]*descriptor.SourceCodeInfo_Location
|
||||
|
||||
func InitPathMap(file *descriptor.FileDescriptorProto) {
|
||||
pathMap = make(map[interface{}]*descriptor.SourceCodeInfo_Location)
|
||||
addToPathMap(file.GetSourceCodeInfo(), file, []int32{})
|
||||
}
|
||||
|
||||
func InitPathMaps(files []*descriptor.FileDescriptorProto) {
|
||||
pathMap = make(map[interface{}]*descriptor.SourceCodeInfo_Location)
|
||||
for _, file := range files {
|
||||
addToPathMap(file.GetSourceCodeInfo(), file, []int32{})
|
||||
}
|
||||
}
|
||||
|
||||
// addToPathMap traverses through the AST adding SourceCodeInfo_Location entries to the pathMap.
|
||||
// Since the AST is a tree, the recursion finishes once it has gone through all the nodes.
|
||||
func addToPathMap(info *descriptor.SourceCodeInfo, i interface{}, path []int32) {
|
||||
loc := findLoc(info, path)
|
||||
if loc != nil {
|
||||
pathMap[i] = loc
|
||||
}
|
||||
switch d := i.(type) {
|
||||
case *descriptor.FileDescriptorProto:
|
||||
for index, descriptor := range d.MessageType {
|
||||
addToPathMap(info, descriptor, newPath(path, 4, index))
|
||||
}
|
||||
for index, descriptor := range d.EnumType {
|
||||
addToPathMap(info, descriptor, newPath(path, 5, index))
|
||||
}
|
||||
for index, descriptor := range d.Service {
|
||||
addToPathMap(info, descriptor, newPath(path, 6, index))
|
||||
}
|
||||
case *descriptor.DescriptorProto:
|
||||
for index, descriptor := range d.Field {
|
||||
addToPathMap(info, descriptor, newPath(path, 2, index))
|
||||
}
|
||||
for index, descriptor := range d.NestedType {
|
||||
addToPathMap(info, descriptor, newPath(path, 3, index))
|
||||
}
|
||||
for index, descriptor := range d.EnumType {
|
||||
addToPathMap(info, descriptor, newPath(path, 4, index))
|
||||
}
|
||||
case *descriptor.EnumDescriptorProto:
|
||||
for index, descriptor := range d.Value {
|
||||
addToPathMap(info, descriptor, newPath(path, 2, index))
|
||||
}
|
||||
case *descriptor.ServiceDescriptorProto:
|
||||
for index, descriptor := range d.Method {
|
||||
addToPathMap(info, descriptor, newPath(path, 2, index))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newPath(base []int32, field int32, index int) []int32 {
|
||||
p := append([]int32{}, base...)
|
||||
p = append(p, field, int32(index))
|
||||
return p
|
||||
}
|
||||
|
||||
func findLoc(info *descriptor.SourceCodeInfo, path []int32) *descriptor.SourceCodeInfo_Location {
|
||||
for _, loc := range info.GetLocation() {
|
||||
if samePath(loc.Path, path) {
|
||||
return loc
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func samePath(a, b []int32) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i, p := range a {
|
||||
if p != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func findSourceInfoLocation(i interface{}) *descriptor.SourceCodeInfo_Location {
|
||||
if pathMap == nil {
|
||||
return nil
|
||||
}
|
||||
return pathMap[i]
|
||||
}
|
||||
|
||||
func leadingComment(i interface{}) string {
|
||||
loc := pathMap[i]
|
||||
return loc.GetLeadingComments()
|
||||
}
|
||||
func trailingComment(i interface{}) string {
|
||||
loc := pathMap[i]
|
||||
return loc.GetTrailingComments()
|
||||
}
|
||||
func leadingDetachedComments(i interface{}) []string {
|
||||
loc := pathMap[i]
|
||||
return loc.GetLeadingDetachedComments()
|
||||
}
|
||||
|
||||
func stringFieldExtension(fieldID int32, f *descriptor.FieldDescriptorProto) string {
|
||||
if f == nil {
|
||||
return ""
|
||||
}
|
||||
if f.Options == nil {
|
||||
return ""
|
||||
}
|
||||
var extendedType *descriptor.FieldOptions
|
||||
var extensionType *string
|
||||
|
||||
eds := proto.RegisteredExtensions(f.Options)
|
||||
if eds[fieldID] == nil {
|
||||
ed := &proto.ExtensionDesc{
|
||||
ExtendedType: extendedType,
|
||||
ExtensionType: extensionType,
|
||||
Field: fieldID,
|
||||
Tag: fmt.Sprintf("bytes,%d", fieldID),
|
||||
}
|
||||
proto.RegisterExtension(ed)
|
||||
eds = proto.RegisteredExtensions(f.Options)
|
||||
}
|
||||
|
||||
ext, err := proto.GetExtension(f.Options, eds[fieldID])
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
str, ok := ext.(*string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
return *str
|
||||
}
|
||||
|
||||
func boolFieldExtension(fieldID int32, f *descriptor.FieldDescriptorProto) bool {
|
||||
if f == nil {
|
||||
return false
|
||||
}
|
||||
if f.Options == nil {
|
||||
return false
|
||||
}
|
||||
var extendedType *descriptor.FieldOptions
|
||||
var extensionType *bool
|
||||
|
||||
eds := proto.RegisteredExtensions(f.Options)
|
||||
if eds[fieldID] == nil {
|
||||
ed := &proto.ExtensionDesc{
|
||||
ExtendedType: extendedType,
|
||||
ExtensionType: extensionType,
|
||||
Field: fieldID,
|
||||
Tag: fmt.Sprintf("varint,%d", fieldID),
|
||||
}
|
||||
proto.RegisterExtension(ed)
|
||||
eds = proto.RegisteredExtensions(f.Options)
|
||||
}
|
||||
|
||||
ext, err := proto.GetExtension(f.Options, eds[fieldID])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
str, ok := ext.(*bool)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return *str
|
||||
}
|
||||
|
||||
func init() {
|
||||
@ -180,6 +360,9 @@ func isFieldMessage(f *descriptor.FieldDescriptorProto) bool {
|
||||
}
|
||||
|
||||
func isFieldRepeated(f *descriptor.FieldDescriptorProto) bool {
|
||||
if f == nil {
|
||||
return false
|
||||
}
|
||||
if f.Type != nil && f.Label != nil && *f.Label == descriptor.FieldDescriptorProto_LABEL_REPEATED {
|
||||
return true
|
||||
}
|
||||
@ -187,6 +370,98 @@ func isFieldRepeated(f *descriptor.FieldDescriptorProto) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func isFieldMap(f *descriptor.FieldDescriptorProto, m *descriptor.DescriptorProto) bool {
|
||||
if f.TypeName == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
shortName := shortType(*f.TypeName)
|
||||
var nt *descriptor.DescriptorProto
|
||||
for _, t := range m.NestedType {
|
||||
if *t.Name == shortName {
|
||||
nt = t
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if nt == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, f := range nt.Field {
|
||||
switch *f.Name {
|
||||
case "key":
|
||||
if *f.Number != 1 {
|
||||
return false
|
||||
}
|
||||
case "value":
|
||||
if *f.Number != 2 {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func fieldMapKeyType(f *descriptor.FieldDescriptorProto, m *descriptor.DescriptorProto) *descriptor.FieldDescriptorProto {
|
||||
if f.TypeName == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
shortName := shortType(*f.TypeName)
|
||||
var nt *descriptor.DescriptorProto
|
||||
for _, t := range m.NestedType {
|
||||
if *t.Name == shortName {
|
||||
nt = t
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if nt == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, f := range nt.Field {
|
||||
if *f.Name == "key" {
|
||||
return f
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func fieldMapValueType(f *descriptor.FieldDescriptorProto, m *descriptor.DescriptorProto) *descriptor.FieldDescriptorProto {
|
||||
if f.TypeName == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
shortName := shortType(*f.TypeName)
|
||||
var nt *descriptor.DescriptorProto
|
||||
for _, t := range m.NestedType {
|
||||
if *t.Name == shortName {
|
||||
nt = t
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if nt == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, f := range nt.Field {
|
||||
if *f.Name == "value" {
|
||||
return f
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func goTypeWithPackage(f *descriptor.FieldDescriptorProto) string {
|
||||
pkg := ""
|
||||
if *f.Type == descriptor.FieldDescriptorProto_TYPE_MESSAGE || *f.Type == descriptor.FieldDescriptorProto_TYPE_ENUM {
|
||||
@ -319,6 +594,39 @@ func goType(pkg string, f *descriptor.FieldDescriptorProto) string {
|
||||
}
|
||||
}
|
||||
|
||||
func goZeroValue(f *descriptor.FieldDescriptorProto) string {
|
||||
const nilString = "nil"
|
||||
if *f.Label == descriptor.FieldDescriptorProto_LABEL_REPEATED {
|
||||
return nilString
|
||||
}
|
||||
switch *f.Type {
|
||||
case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
|
||||
return "0.0"
|
||||
case descriptor.FieldDescriptorProto_TYPE_FLOAT:
|
||||
return "0.0"
|
||||
case descriptor.FieldDescriptorProto_TYPE_INT64:
|
||||
return "0"
|
||||
case descriptor.FieldDescriptorProto_TYPE_UINT64:
|
||||
return "0"
|
||||
case descriptor.FieldDescriptorProto_TYPE_INT32:
|
||||
return "0"
|
||||
case descriptor.FieldDescriptorProto_TYPE_UINT32:
|
||||
return "0"
|
||||
case descriptor.FieldDescriptorProto_TYPE_BOOL:
|
||||
return "false"
|
||||
case descriptor.FieldDescriptorProto_TYPE_STRING:
|
||||
return "\"\""
|
||||
case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
|
||||
return nilString
|
||||
case descriptor.FieldDescriptorProto_TYPE_BYTES:
|
||||
return "0"
|
||||
case descriptor.FieldDescriptorProto_TYPE_ENUM:
|
||||
return nilString
|
||||
default:
|
||||
return nilString
|
||||
}
|
||||
}
|
||||
|
||||
func jsType(f *descriptor.FieldDescriptorProto) string {
|
||||
template := "%s"
|
||||
if isFieldRepeated(f) {
|
||||
@ -502,3 +810,14 @@ func formatID(base string, formatted string) string {
|
||||
}
|
||||
return formatted
|
||||
}
|
||||
|
||||
func replaceDict(src string, dict map[string]interface{}) string {
|
||||
for old, v := range dict {
|
||||
new, ok := v.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
src = strings.Replace(src, old, new, -1)
|
||||
}
|
||||
return src
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user