protoc-gen-go-micro/helpers.go

285 lines
7.4 KiB
Go
Raw Normal View History

package main
import (
"encoding/json"
"fmt"
"strings"
"text/template"
"github.com/Masterminds/sprig"
2016-12-20 13:31:46 +03:00
"github.com/huandu/xstrings"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/protoc-gen-go/descriptor"
options "github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis/google/api"
)
var ProtoHelpersFuncMap = template.FuncMap{
"string": func(i interface {
String() string
}) string {
return i.String()
},
"json": func(v interface{}) string {
a, _ := json.Marshal(v)
return string(a)
},
"prettyjson": func(v interface{}) string {
a, _ := json.MarshalIndent(v, "", " ")
return string(a)
},
"splitArray": func(sep string, s string) []string {
return strings.Split(s, sep)
},
"first": func(a []string) string {
return a[0]
},
"last": func(a []string) string {
return a[len(a)-1]
},
2016-12-19 17:43:38 +03:00
"upperFirst": func(s string) string {
return strings.ToUpper(s[:1]) + s[1:]
},
"lowerFirst": func(s string) string {
return strings.ToLower(s[:1]) + s[1:]
},
2016-12-20 13:31:46 +03:00
"camelCase": func(s string) string {
2017-01-04 01:41:58 +03:00
if len(s) > 1 {
return xstrings.ToCamelCase(s)
}
return strings.ToUpper(s[:1])
2016-12-20 13:31:46 +03:00
},
2017-01-11 18:33:03 +03:00
"lowerCamelCase": func(s string) string {
if len(s) > 1 {
s = xstrings.ToCamelCase(s)
}
2017-01-04 01:41:58 +03:00
2017-01-11 18:33:03 +03:00
return strings.ToLower(s[:1]) + s[1:]
},
2016-12-20 13:31:46 +03:00
"kebabCase": func(s string) string {
return strings.Replace(xstrings.ToSnakeCase(s), "_", "-", -1)
},
"snakeCase": xstrings.ToSnakeCase,
"getMessageType": getMessageType,
"isFieldMessage": isFieldMessage,
"isFieldRepeated": isFieldRepeated,
"goType": goType,
"jsType": jsType,
"namespacedFlowType": namespacedFlowType,
"httpVerb": httpVerb,
"httpPath": httpPath,
"shortType": shortType,
"urlHasVarsFromMessage": urlHasVarsFromMessage,
}
// init merges every function from sprig's text-template function map into
// ProtoHelpersFuncMap. The assignment overwrites, so a sprig function
// replaces any helper already registered under the same key.
func init() {
	sprigFuncs := sprig.TxtFuncMap()
	for name := range sprigFuncs {
		ProtoHelpersFuncMap[name] = sprigFuncs[name]
	}
}
// getMessageType looks up a message declared in f by its type name.
//
// name is a dotted type name of the form
// .packageName.MessageTypeName.InnerMessageTypeName (e.g. .article.ProductTag).
// Only the segment after the last dot is compared, and only against the
// messages listed directly in f.MessageType. It returns nil when nothing
// matches.
func getMessageType(f *descriptor.FileDescriptorProto, name string) *descriptor.DescriptorProto {
	// Keep only what follows the final "."; when there is no dot,
	// LastIndex yields -1 and the whole name is used.
	want := name[strings.LastIndex(name, ".")+1:]
	for i := range f.MessageType {
		if msg := f.MessageType[i]; *msg.Name == want {
			return msg
		}
	}
	return nil
}
func isFieldMessage(f *descriptor.FieldDescriptorProto) bool {
2016-12-29 19:21:53 +03:00
if f.Type != nil && *f.Type == descriptor.FieldDescriptorProto_TYPE_MESSAGE {
return true
}
return false
}
// isFieldRepeated reports whether f carries the LABEL_REPEATED label. A field
// with no type or no label set is reported as not repeated.
func isFieldRepeated(f *descriptor.FieldDescriptorProto) bool {
	if f.Type == nil || f.Label == nil {
		return false
	}
	return *f.Label == descriptor.FieldDescriptorProto_LABEL_REPEATED
}
// goType returns the Go type expression for field f.
//
// pkg is the Go package qualifier placed in front of message and enum type
// names. Repeated fields are returned as slices of the element type. Scalar
// kinds follow the protobuf-to-Go mapping: int32, sint32 and sfixed32 map to
// int32, fixed32 to uint32, the 64-bit kinds likewise, and bytes maps to
// []byte ([][]byte when repeated). Enum fields yield a pointer to the
// qualified enum type regardless of their label. Any kind not listed maps to
// "interface{}".
//
// f.Type must be set; f.TypeName must be set for message and enum fields.
func goType(pkg string, f *descriptor.FieldDescriptorProto) string {
	// A field without a label is treated as non-repeated rather than
	// dereferencing a nil pointer.
	repeated := f.Label != nil && *f.Label == descriptor.FieldDescriptorProto_LABEL_REPEATED
	wrap := func(elem string) string {
		if repeated {
			return "[]" + elem
		}
		return elem
	}

	switch *f.Type {
	case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
		return wrap("float64")
	case descriptor.FieldDescriptorProto_TYPE_FLOAT:
		return wrap("float32")
	case descriptor.FieldDescriptorProto_TYPE_INT64,
		descriptor.FieldDescriptorProto_TYPE_SINT64,
		descriptor.FieldDescriptorProto_TYPE_SFIXED64:
		return wrap("int64")
	case descriptor.FieldDescriptorProto_TYPE_UINT64,
		descriptor.FieldDescriptorProto_TYPE_FIXED64:
		return wrap("uint64")
	case descriptor.FieldDescriptorProto_TYPE_INT32,
		descriptor.FieldDescriptorProto_TYPE_SINT32,
		descriptor.FieldDescriptorProto_TYPE_SFIXED32:
		// Signed 32-bit kinds are int32, not uint32.
		return wrap("int32")
	case descriptor.FieldDescriptorProto_TYPE_UINT32,
		descriptor.FieldDescriptorProto_TYPE_FIXED32:
		return wrap("uint32")
	case descriptor.FieldDescriptorProto_TYPE_BOOL:
		return wrap("bool")
	case descriptor.FieldDescriptorProto_TYPE_STRING:
		return wrap("string")
	case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
		return wrap(fmt.Sprintf("*%s.%s", pkg, shortType(*f.TypeName)))
	case descriptor.FieldDescriptorProto_TYPE_BYTES:
		// A single bytes value is already a byte slice.
		return wrap("[]byte")
	case descriptor.FieldDescriptorProto_TYPE_ENUM:
		return fmt.Sprintf("*%s.%s", pkg, shortType(*f.TypeName))
	default:
		return "interface{}"
	}
}
2017-01-04 02:35:21 +03:00
func jsType(f *descriptor.FieldDescriptorProto) string {
2017-01-05 20:04:20 +03:00
template := "%s"
if isFieldRepeated(f) == true {
template = "Array<%s>"
}
2017-01-04 02:35:21 +03:00
switch *f.Type {
2017-01-10 13:40:48 +03:00
case descriptor.FieldDescriptorProto_TYPE_MESSAGE,
descriptor.FieldDescriptorProto_TYPE_ENUM:
2017-01-10 14:24:06 +03:00
return fmt.Sprintf(template, namespacedFlowType(*f.TypeName))
2017-01-05 20:04:20 +03:00
case descriptor.FieldDescriptorProto_TYPE_DOUBLE,
descriptor.FieldDescriptorProto_TYPE_FLOAT,
descriptor.FieldDescriptorProto_TYPE_INT64,
descriptor.FieldDescriptorProto_TYPE_UINT64,
descriptor.FieldDescriptorProto_TYPE_INT32,
descriptor.FieldDescriptorProto_TYPE_FIXED64,
descriptor.FieldDescriptorProto_TYPE_FIXED32,
descriptor.FieldDescriptorProto_TYPE_UINT32,
descriptor.FieldDescriptorProto_TYPE_SFIXED32,
descriptor.FieldDescriptorProto_TYPE_SFIXED64,
descriptor.FieldDescriptorProto_TYPE_SINT32,
descriptor.FieldDescriptorProto_TYPE_SINT64:
return fmt.Sprintf(template, "number")
2017-01-04 02:35:21 +03:00
case descriptor.FieldDescriptorProto_TYPE_BOOL:
2017-01-05 20:04:20 +03:00
return fmt.Sprintf(template, "boolean")
2017-01-05 20:32:54 +03:00
case descriptor.FieldDescriptorProto_TYPE_BYTES:
2017-01-11 16:52:00 +03:00
return fmt.Sprintf(template, "Uint8Array")
2017-01-04 02:35:21 +03:00
case descriptor.FieldDescriptorProto_TYPE_STRING:
2017-01-05 20:04:20 +03:00
return fmt.Sprintf(template, "string")
2017-01-04 02:35:21 +03:00
default:
2017-01-05 20:04:20 +03:00
return fmt.Sprintf(template, "any")
2017-01-04 02:35:21 +03:00
}
}
// shortType returns the part of s that follows its last "." separator, or s
// itself when it contains no dot.
func shortType(s string) string {
	// LastIndex yields -1 when there is no dot, so the slice starts at 0.
	return s[strings.LastIndex(s, ".")+1:]
}
2017-01-10 14:24:06 +03:00
// namespacedFlowType turns a dotted protobuf type name into a Flow type name:
// leading dots are stripped and every remaining "." becomes "$", e.g.
// ".article.ProductTag" becomes "article$ProductTag".
func namespacedFlowType(s string) string {
	withoutLeadingDots := strings.TrimLeft(s, ".")
	return strings.Replace(withoutLeadingDots, ".", "$", -1)
}
// httpPath returns the URL path template from the google.api.http option of
// method m.
//
// If the extension cannot be read, the error text is returned in place of a
// path; if the extension has an unexpected type, a message describing that
// type is returned. A pattern kind not handled below yields "".
func httpPath(m *descriptor.MethodDescriptorProto) string {
	ext, err := proto.GetExtension(m.Options, options.E_Http)
	if err != nil {
		return err.Error()
	}

	rule, isRule := ext.(*options.HttpRule)
	if !isRule {
		return fmt.Sprintf("extension is %T; want an HttpRule", ext)
	}

	var path string
	switch pattern := rule.Pattern.(type) {
	case *options.HttpRule_Get:
		path = pattern.Get
	case *options.HttpRule_Post:
		path = pattern.Post
	case *options.HttpRule_Put:
		path = pattern.Put
	case *options.HttpRule_Delete:
		path = pattern.Delete
	case *options.HttpRule_Patch:
		path = pattern.Patch
	case *options.HttpRule_Custom:
		path = pattern.Custom.Path
	}
	return path
}
// httpVerb returns the HTTP method named by the google.api.http option of
// method m: "GET", "POST", "PUT", "DELETE" or "PATCH" for the standard
// patterns, or the custom pattern's Kind.
//
// If the extension cannot be read, the error text is returned in place of a
// verb; if the extension has an unexpected type, a message describing that
// type is returned. A pattern kind not handled below yields "".
func httpVerb(m *descriptor.MethodDescriptorProto) string {
	ext, err := proto.GetExtension(m.Options, options.E_Http)
	if err != nil {
		return err.Error()
	}

	rule, isRule := ext.(*options.HttpRule)
	if !isRule {
		return fmt.Sprintf("extension is %T; want an HttpRule", ext)
	}

	var verb string
	switch pattern := rule.Pattern.(type) {
	case *options.HttpRule_Get:
		verb = "GET"
	case *options.HttpRule_Post:
		verb = "POST"
	case *options.HttpRule_Put:
		verb = "PUT"
	case *options.HttpRule_Delete:
		verb = "DELETE"
	case *options.HttpRule_Patch:
		verb = "PATCH"
	case *options.HttpRule_Custom:
		verb = pattern.Custom.Kind
	}
	return verb
}
// urlHasVarsFromMessage reports whether path contains a "{name}" placeholder
// for at least one non-message field of d.
func urlHasVarsFromMessage(path string, d *descriptor.DescriptorProto) bool {
	for _, fld := range d.Field {
		// Message-typed fields are never considered as URL variables.
		if isFieldMessage(fld) {
			continue
		}
		placeholder := "{" + *fld.Name + "}"
		if strings.Contains(path, placeholder) {
			return true
		}
	}
	return false
}