Moved to google.golang.org/genproto/googleapis/api/annotations

Fixes #52
This commit is contained in:
Valerio Gheri
2017-03-31 18:01:58 +02:00
parent 024c5a4e4e
commit c40779224f
2037 changed files with 831329 additions and 1854 deletions

View File

@@ -0,0 +1,299 @@
package descriptor
import (
"fmt"
"path"
"path/filepath"
"strings"
"github.com/golang/glog"
descriptor "github.com/golang/protobuf/protoc-gen-go/descriptor"
plugin "github.com/golang/protobuf/protoc-gen-go/plugin"
)
// Registry is a registry of information extracted from plugin.CodeGeneratorRequest.
type Registry struct {
// msgs is a mapping from fully-qualified message name to descriptor
msgs map[string]*Message
// enums is a mapping from fully-qualified enum name to descriptor
enums map[string]*Enum
// files is a mapping from file path to descriptor
files map[string]*File
// prefix is a prefix to be inserted to golang package paths generated from proto package names.
prefix string
// pkgMap is a user-specified mapping from file path to proto package.
pkgMap map[string]string
// pkgAliases is a mapping from package aliases to package paths in go which are already taken.
pkgAliases map[string]string
// allowDeleteBody permits http delete methods to have a body
allowDeleteBody bool
}
// NewRegistry returns a new Registry.
func NewRegistry() *Registry {
return &Registry{
msgs: make(map[string]*Message),
enums: make(map[string]*Enum),
files: make(map[string]*File),
pkgMap: make(map[string]string),
pkgAliases: make(map[string]string),
}
}
// Load loads definitions of services, methods, messages, enumerations and fields from "req".
func (r *Registry) Load(req *plugin.CodeGeneratorRequest) error {
for _, file := range req.GetProtoFile() {
r.loadFile(file)
}
var targetPkg string
for _, name := range req.FileToGenerate {
target := r.files[name]
if target == nil {
return fmt.Errorf("no such file: %s", name)
}
name := packageIdentityName(target.FileDescriptorProto)
if targetPkg == "" {
targetPkg = name
} else {
if targetPkg != name {
return fmt.Errorf("inconsistent package names: %s %s", targetPkg, name)
}
}
if err := r.loadServices(target); err != nil {
return err
}
}
return nil
}
// loadFile loads messages, enumerations and fields from "file".
// It does not loads services and methods in "file". You need to call
// loadServices after loadFiles is called for all files to load services and methods.
func (r *Registry) loadFile(file *descriptor.FileDescriptorProto) {
pkg := GoPackage{
Path: r.goPackagePath(file),
Name: defaultGoPackageName(file),
}
if err := r.ReserveGoPackageAlias(pkg.Name, pkg.Path); err != nil {
for i := 0; ; i++ {
alias := fmt.Sprintf("%s_%d", pkg.Name, i)
if err := r.ReserveGoPackageAlias(alias, pkg.Path); err == nil {
pkg.Alias = alias
break
}
}
}
f := &File{
FileDescriptorProto: file,
GoPkg: pkg,
}
r.files[file.GetName()] = f
r.registerMsg(f, nil, file.GetMessageType())
r.registerEnum(f, nil, file.GetEnumType())
}
func (r *Registry) registerMsg(file *File, outerPath []string, msgs []*descriptor.DescriptorProto) {
for i, md := range msgs {
m := &Message{
File: file,
Outers: outerPath,
DescriptorProto: md,
Index: i,
}
for _, fd := range md.GetField() {
m.Fields = append(m.Fields, &Field{
Message: m,
FieldDescriptorProto: fd,
})
}
file.Messages = append(file.Messages, m)
r.msgs[m.FQMN()] = m
glog.V(1).Infof("register name: %s", m.FQMN())
var outers []string
outers = append(outers, outerPath...)
outers = append(outers, m.GetName())
r.registerMsg(file, outers, m.GetNestedType())
r.registerEnum(file, outers, m.GetEnumType())
}
}
func (r *Registry) registerEnum(file *File, outerPath []string, enums []*descriptor.EnumDescriptorProto) {
for i, ed := range enums {
e := &Enum{
File: file,
Outers: outerPath,
EnumDescriptorProto: ed,
Index: i,
}
file.Enums = append(file.Enums, e)
r.enums[e.FQEN()] = e
glog.V(1).Infof("register enum name: %s", e.FQEN())
}
}
// LookupMsg looks up a message type by "name".
// It tries to resolve "name" from "location" if "name" is a relative message name.
func (r *Registry) LookupMsg(location, name string) (*Message, error) {
glog.V(1).Infof("lookup %s from %s", name, location)
if strings.HasPrefix(name, ".") {
m, ok := r.msgs[name]
if !ok {
return nil, fmt.Errorf("no message found: %s", name)
}
return m, nil
}
if !strings.HasPrefix(location, ".") {
location = fmt.Sprintf(".%s", location)
}
components := strings.Split(location, ".")
for len(components) > 0 {
fqmn := strings.Join(append(components, name), ".")
if m, ok := r.msgs[fqmn]; ok {
return m, nil
}
components = components[:len(components)-1]
}
return nil, fmt.Errorf("no message found: %s", name)
}
// LookupEnum looks up a enum type by "name".
// It tries to resolve "name" from "location" if "name" is a relative enum name.
func (r *Registry) LookupEnum(location, name string) (*Enum, error) {
glog.V(1).Infof("lookup enum %s from %s", name, location)
if strings.HasPrefix(name, ".") {
e, ok := r.enums[name]
if !ok {
return nil, fmt.Errorf("no enum found: %s", name)
}
return e, nil
}
if !strings.HasPrefix(location, ".") {
location = fmt.Sprintf(".%s", location)
}
components := strings.Split(location, ".")
for len(components) > 0 {
fqen := strings.Join(append(components, name), ".")
if e, ok := r.enums[fqen]; ok {
return e, nil
}
components = components[:len(components)-1]
}
return nil, fmt.Errorf("no enum found: %s", name)
}
// LookupFile looks up a file by name.
func (r *Registry) LookupFile(name string) (*File, error) {
f, ok := r.files[name]
if !ok {
return nil, fmt.Errorf("no such file given: %s", name)
}
return f, nil
}
// AddPkgMap adds a mapping from a .proto file to proto package name.
func (r *Registry) AddPkgMap(file, protoPkg string) {
r.pkgMap[file] = protoPkg
}
// SetPrefix registeres the perfix to be added to go package paths generated from proto package names.
func (r *Registry) SetPrefix(prefix string) {
r.prefix = prefix
}
// ReserveGoPackageAlias reserves the unique alias of go package.
// If succeeded, the alias will be never used for other packages in generated go files.
// If failed, the alias is already taken by another package, so you need to use another
// alias for the package in your go files.
func (r *Registry) ReserveGoPackageAlias(alias, pkgpath string) error {
if taken, ok := r.pkgAliases[alias]; ok {
if taken == pkgpath {
return nil
}
return fmt.Errorf("package name %s is already taken. Use another alias", alias)
}
r.pkgAliases[alias] = pkgpath
return nil
}
// goPackagePath returns the go package path which go files generated from "f" should have.
// It respects the mapping registered by AddPkgMap if exists. Or use go_package as import path
// if it includes a slash, Otherwide, it generates a path from the file name of "f".
func (r *Registry) goPackagePath(f *descriptor.FileDescriptorProto) string {
name := f.GetName()
if pkg, ok := r.pkgMap[name]; ok {
return path.Join(r.prefix, pkg)
}
gopkg := f.Options.GetGoPackage()
idx := strings.LastIndex(gopkg, "/")
if idx >= 0 {
return gopkg
}
return path.Join(r.prefix, path.Dir(name))
}
// GetAllFQMNs returns a list of all FQMNs
func (r *Registry) GetAllFQMNs() []string {
var keys []string
for k := range r.msgs {
keys = append(keys, k)
}
return keys
}
// GetAllFQENs returns a list of all FQENs
func (r *Registry) GetAllFQENs() []string {
var keys []string
for k := range r.enums {
keys = append(keys, k)
}
return keys
}
// SetAllowDeleteBody controls whether http delete methods may have a
// body or fail loading if encountered.
func (r *Registry) SetAllowDeleteBody(allow bool) {
r.allowDeleteBody = allow
}
// defaultGoPackageName returns the default go package name to be used for go files generated from "f".
// You might need to use an unique alias for the package when you import it. Use ReserveGoPackageAlias to get a unique alias.
func defaultGoPackageName(f *descriptor.FileDescriptorProto) string {
name := packageIdentityName(f)
return strings.Replace(name, ".", "_", -1)
}
// packageIdentityName returns the identity of packages.
// protoc-gen-grpc-gateway rejects CodeGenerationRequests which contains more than one packages
// as protoc-gen-go does.
func packageIdentityName(f *descriptor.FileDescriptorProto) string {
if f.Options != nil && f.Options.GoPackage != nil {
gopkg := f.Options.GetGoPackage()
idx := strings.LastIndex(gopkg, "/")
if idx < 0 {
return gopkg
}
return gopkg[idx+1:]
}
if f.Package == nil {
base := filepath.Base(f.GetName())
ext := filepath.Ext(base)
return strings.TrimSuffix(base, ext)
}
return f.GetPackage()
}

View File

@@ -0,0 +1,533 @@
package descriptor
import (
"testing"
"github.com/golang/protobuf/proto"
descriptor "github.com/golang/protobuf/protoc-gen-go/descriptor"
plugin "github.com/golang/protobuf/protoc-gen-go/plugin"
)
func loadFile(t *testing.T, reg *Registry, src string) *descriptor.FileDescriptorProto {
var file descriptor.FileDescriptorProto
if err := proto.UnmarshalText(src, &file); err != nil {
t.Fatalf("proto.UnmarshalText(%s, &file) failed with %v; want success", src, err)
}
reg.loadFile(&file)
return &file
}
func load(t *testing.T, reg *Registry, src string) error {
var req plugin.CodeGeneratorRequest
if err := proto.UnmarshalText(src, &req); err != nil {
t.Fatalf("proto.UnmarshalText(%s, &file) failed with %v; want success", src, err)
}
return reg.Load(&req)
}
func TestLoadFile(t *testing.T) {
reg := NewRegistry()
fd := loadFile(t, reg, `
name: 'example.proto'
package: 'example'
message_type <
name: 'ExampleMessage'
field <
name: 'str'
label: LABEL_OPTIONAL
type: TYPE_STRING
number: 1
>
>
`)
file := reg.files["example.proto"]
if file == nil {
t.Errorf("reg.files[%q] = nil; want non-nil", "example.proto")
return
}
wantPkg := GoPackage{Path: ".", Name: "example"}
if got, want := file.GoPkg, wantPkg; got != want {
t.Errorf("file.GoPkg = %#v; want %#v", got, want)
}
msg, err := reg.LookupMsg("", ".example.ExampleMessage")
if err != nil {
t.Errorf("reg.LookupMsg(%q, %q)) failed with %v; want success", "", ".example.ExampleMessage", err)
return
}
if got, want := msg.DescriptorProto, fd.MessageType[0]; got != want {
t.Errorf("reg.lookupMsg(%q, %q).DescriptorProto = %#v; want %#v", "", ".example.ExampleMessage", got, want)
}
if got, want := msg.File, file; got != want {
t.Errorf("msg.File = %v; want %v", got, want)
}
if got := msg.Outers; got != nil {
t.Errorf("msg.Outers = %v; want %v", got, nil)
}
if got, want := len(msg.Fields), 1; got != want {
t.Errorf("len(msg.Fields) = %d; want %d", got, want)
} else if got, want := msg.Fields[0].FieldDescriptorProto, fd.MessageType[0].Field[0]; got != want {
t.Errorf("msg.Fields[0].FieldDescriptorProto = %v; want %v", got, want)
} else if got, want := msg.Fields[0].Message, msg; got != want {
t.Errorf("msg.Fields[0].Message = %v; want %v", got, want)
}
if got, want := len(file.Messages), 1; got != want {
t.Errorf("file.Meeesages = %#v; want %#v", file.Messages, []*Message{msg})
}
if got, want := file.Messages[0], msg; got != want {
t.Errorf("file.Meeesages[0] = %v; want %v", got, want)
}
}
func TestLoadFileNestedPackage(t *testing.T) {
reg := NewRegistry()
loadFile(t, reg, `
name: 'example.proto'
package: 'example.nested.nested2'
`)
file := reg.files["example.proto"]
if file == nil {
t.Errorf("reg.files[%q] = nil; want non-nil", "example.proto")
return
}
wantPkg := GoPackage{Path: ".", Name: "example_nested_nested2"}
if got, want := file.GoPkg, wantPkg; got != want {
t.Errorf("file.GoPkg = %#v; want %#v", got, want)
}
}
func TestLoadFileWithDir(t *testing.T) {
reg := NewRegistry()
loadFile(t, reg, `
name: 'path/to/example.proto'
package: 'example'
`)
file := reg.files["path/to/example.proto"]
if file == nil {
t.Errorf("reg.files[%q] = nil; want non-nil", "example.proto")
return
}
wantPkg := GoPackage{Path: "path/to", Name: "example"}
if got, want := file.GoPkg, wantPkg; got != want {
t.Errorf("file.GoPkg = %#v; want %#v", got, want)
}
}
func TestLoadFileWithoutPackage(t *testing.T) {
reg := NewRegistry()
loadFile(t, reg, `
name: 'path/to/example_file.proto'
`)
file := reg.files["path/to/example_file.proto"]
if file == nil {
t.Errorf("reg.files[%q] = nil; want non-nil", "example.proto")
return
}
wantPkg := GoPackage{Path: "path/to", Name: "example_file"}
if got, want := file.GoPkg, wantPkg; got != want {
t.Errorf("file.GoPkg = %#v; want %#v", got, want)
}
}
func TestLoadFileWithMapping(t *testing.T) {
reg := NewRegistry()
reg.AddPkgMap("path/to/example.proto", "example.com/proj/example/proto")
loadFile(t, reg, `
name: 'path/to/example.proto'
package: 'example'
`)
file := reg.files["path/to/example.proto"]
if file == nil {
t.Errorf("reg.files[%q] = nil; want non-nil", "example.proto")
return
}
wantPkg := GoPackage{Path: "example.com/proj/example/proto", Name: "example"}
if got, want := file.GoPkg, wantPkg; got != want {
t.Errorf("file.GoPkg = %#v; want %#v", got, want)
}
}
func TestLoadFileWithPackageNameCollision(t *testing.T) {
reg := NewRegistry()
loadFile(t, reg, `
name: 'path/to/another.proto'
package: 'example'
`)
loadFile(t, reg, `
name: 'path/to/example.proto'
package: 'example'
`)
if err := reg.ReserveGoPackageAlias("ioutil", "io/ioutil"); err != nil {
t.Fatalf("reg.ReserveGoPackageAlias(%q) failed with %v; want success", "ioutil", err)
}
loadFile(t, reg, `
name: 'path/to/ioutil.proto'
package: 'ioutil'
`)
file := reg.files["path/to/another.proto"]
if file == nil {
t.Errorf("reg.files[%q] = nil; want non-nil", "path/to/another.proto")
return
}
wantPkg := GoPackage{Path: "path/to", Name: "example"}
if got, want := file.GoPkg, wantPkg; got != want {
t.Errorf("file.GoPkg = %#v; want %#v", got, want)
}
file = reg.files["path/to/example.proto"]
if file == nil {
t.Errorf("reg.files[%q] = nil; want non-nil", "path/to/example.proto")
return
}
wantPkg = GoPackage{Path: "path/to", Name: "example", Alias: ""}
if got, want := file.GoPkg, wantPkg; got != want {
t.Errorf("file.GoPkg = %#v; want %#v", got, want)
}
file = reg.files["path/to/ioutil.proto"]
if file == nil {
t.Errorf("reg.files[%q] = nil; want non-nil", "path/to/ioutil.proto")
return
}
wantPkg = GoPackage{Path: "path/to", Name: "ioutil", Alias: "ioutil_0"}
if got, want := file.GoPkg, wantPkg; got != want {
t.Errorf("file.GoPkg = %#v; want %#v", got, want)
}
}
func TestLoadFileWithIdenticalGoPkg(t *testing.T) {
reg := NewRegistry()
reg.AddPkgMap("path/to/another.proto", "example.com/example")
reg.AddPkgMap("path/to/example.proto", "example.com/example")
loadFile(t, reg, `
name: 'path/to/another.proto'
package: 'example'
`)
loadFile(t, reg, `
name: 'path/to/example.proto'
package: 'example'
`)
file := reg.files["path/to/example.proto"]
if file == nil {
t.Errorf("reg.files[%q] = nil; want non-nil", "example.proto")
return
}
wantPkg := GoPackage{Path: "example.com/example", Name: "example"}
if got, want := file.GoPkg, wantPkg; got != want {
t.Errorf("file.GoPkg = %#v; want %#v", got, want)
}
file = reg.files["path/to/another.proto"]
if file == nil {
t.Errorf("reg.files[%q] = nil; want non-nil", "example.proto")
return
}
wantPkg = GoPackage{Path: "example.com/example", Name: "example"}
if got, want := file.GoPkg, wantPkg; got != want {
t.Errorf("file.GoPkg = %#v; want %#v", got, want)
}
}
func TestLoadFileWithPrefix(t *testing.T) {
reg := NewRegistry()
reg.SetPrefix("third_party")
loadFile(t, reg, `
name: 'path/to/example.proto'
package: 'example'
`)
file := reg.files["path/to/example.proto"]
if file == nil {
t.Errorf("reg.files[%q] = nil; want non-nil", "example.proto")
return
}
wantPkg := GoPackage{Path: "third_party/path/to", Name: "example"}
if got, want := file.GoPkg, wantPkg; got != want {
t.Errorf("file.GoPkg = %#v; want %#v", got, want)
}
}
func TestLookupMsgWithoutPackage(t *testing.T) {
reg := NewRegistry()
fd := loadFile(t, reg, `
name: 'example.proto'
message_type <
name: 'ExampleMessage'
field <
name: 'str'
label: LABEL_OPTIONAL
type: TYPE_STRING
number: 1
>
>
`)
msg, err := reg.LookupMsg("", ".ExampleMessage")
if err != nil {
t.Errorf("reg.LookupMsg(%q, %q)) failed with %v; want success", "", ".ExampleMessage", err)
return
}
if got, want := msg.DescriptorProto, fd.MessageType[0]; got != want {
t.Errorf("reg.lookupMsg(%q, %q).DescriptorProto = %#v; want %#v", "", ".ExampleMessage", got, want)
}
}
func TestLookupMsgWithNestedPackage(t *testing.T) {
reg := NewRegistry()
fd := loadFile(t, reg, `
name: 'example.proto'
package: 'nested.nested2.mypackage'
message_type <
name: 'ExampleMessage'
field <
name: 'str'
label: LABEL_OPTIONAL
type: TYPE_STRING
number: 1
>
>
`)
for _, name := range []string{
"nested.nested2.mypackage.ExampleMessage",
"nested2.mypackage.ExampleMessage",
"mypackage.ExampleMessage",
"ExampleMessage",
} {
msg, err := reg.LookupMsg("nested.nested2.mypackage", name)
if err != nil {
t.Errorf("reg.LookupMsg(%q, %q)) failed with %v; want success", ".nested.nested2.mypackage", name, err)
return
}
if got, want := msg.DescriptorProto, fd.MessageType[0]; got != want {
t.Errorf("reg.lookupMsg(%q, %q).DescriptorProto = %#v; want %#v", ".nested.nested2.mypackage", name, got, want)
}
}
for _, loc := range []string{
".nested.nested2.mypackage",
"nested.nested2.mypackage",
".nested.nested2",
"nested.nested2",
".nested",
"nested",
".",
"",
"somewhere.else",
} {
name := "nested.nested2.mypackage.ExampleMessage"
msg, err := reg.LookupMsg(loc, name)
if err != nil {
t.Errorf("reg.LookupMsg(%q, %q)) failed with %v; want success", loc, name, err)
return
}
if got, want := msg.DescriptorProto, fd.MessageType[0]; got != want {
t.Errorf("reg.lookupMsg(%q, %q).DescriptorProto = %#v; want %#v", loc, name, got, want)
}
}
for _, loc := range []string{
".nested.nested2.mypackage",
"nested.nested2.mypackage",
".nested.nested2",
"nested.nested2",
".nested",
"nested",
} {
name := "nested2.mypackage.ExampleMessage"
msg, err := reg.LookupMsg(loc, name)
if err != nil {
t.Errorf("reg.LookupMsg(%q, %q)) failed with %v; want success", loc, name, err)
return
}
if got, want := msg.DescriptorProto, fd.MessageType[0]; got != want {
t.Errorf("reg.lookupMsg(%q, %q).DescriptorProto = %#v; want %#v", loc, name, got, want)
}
}
}
func TestLoadWithInconsistentTargetPackage(t *testing.T) {
for _, spec := range []struct {
req string
consistent bool
}{
// root package, no explicit go package
{
req: `
file_to_generate: 'a.proto'
file_to_generate: 'b.proto'
proto_file <
name: 'a.proto'
message_type < name: 'A' >
service <
name: "AService"
method <
name: "Meth"
input_type: "A"
output_type: "A"
options <
[google.api.http] < post: "/v1/a" body: "*" >
>
>
>
>
proto_file <
name: 'b.proto'
message_type < name: 'B' >
service <
name: "BService"
method <
name: "Meth"
input_type: "B"
output_type: "B"
options <
[google.api.http] < post: "/v1/b" body: "*" >
>
>
>
>
`,
consistent: false,
},
// named package, no explicit go package
{
req: `
file_to_generate: 'a.proto'
file_to_generate: 'b.proto'
proto_file <
name: 'a.proto'
package: 'example.foo'
message_type < name: 'A' >
service <
name: "AService"
method <
name: "Meth"
input_type: "A"
output_type: "A"
options <
[google.api.http] < post: "/v1/a" body: "*" >
>
>
>
>
proto_file <
name: 'b.proto'
package: 'example.foo'
message_type < name: 'B' >
service <
name: "BService"
method <
name: "Meth"
input_type: "B"
output_type: "B"
options <
[google.api.http] < post: "/v1/b" body: "*" >
>
>
>
>
`,
consistent: true,
},
// root package, explicit go package
{
req: `
file_to_generate: 'a.proto'
file_to_generate: 'b.proto'
proto_file <
name: 'a.proto'
options < go_package: 'foo' >
message_type < name: 'A' >
service <
name: "AService"
method <
name: "Meth"
input_type: "A"
output_type: "A"
options <
[google.api.http] < post: "/v1/a" body: "*" >
>
>
>
>
proto_file <
name: 'b.proto'
options < go_package: 'foo' >
message_type < name: 'B' >
service <
name: "BService"
method <
name: "Meth"
input_type: "B"
output_type: "B"
options <
[google.api.http] < post: "/v1/b" body: "*" >
>
>
>
>
`,
consistent: true,
},
// named package, explicit go package
{
req: `
file_to_generate: 'a.proto'
file_to_generate: 'b.proto'
proto_file <
name: 'a.proto'
package: 'example.foo'
options < go_package: 'foo' >
message_type < name: 'A' >
service <
name: "AService"
method <
name: "Meth"
input_type: "A"
output_type: "A"
options <
[google.api.http] < post: "/v1/a" body: "*" >
>
>
>
>
proto_file <
name: 'b.proto'
package: 'example.foo'
options < go_package: 'foo' >
message_type < name: 'B' >
service <
name: "BService"
method <
name: "Meth"
input_type: "B"
output_type: "B"
options <
[google.api.http] < post: "/v1/b" body: "*" >
>
>
>
>
`,
consistent: true,
},
} {
reg := NewRegistry()
err := load(t, reg, spec.req)
if got, want := err == nil, spec.consistent; got != want {
if want {
t.Errorf("reg.Load(%s) failed with %v; want success", spec.req, err)
continue
}
t.Errorf("reg.Load(%s) succeeded; want an package inconsistency error", spec.req)
}
}
}

View File

@@ -0,0 +1,266 @@
package descriptor
import (
"fmt"
"strings"
"github.com/golang/glog"
"github.com/golang/protobuf/proto"
descriptor "github.com/golang/protobuf/protoc-gen-go/descriptor"
"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/httprule"
options "google.golang.org/genproto/googleapis/api/annotations"
)
// loadServices registers services and their methods from "targetFile" to "r".
// It must be called after loadFile is called for all files so that loadServices
// can resolve names of message types and their fields.
func (r *Registry) loadServices(file *File) error {
glog.V(1).Infof("Loading services from %s", file.GetName())
var svcs []*Service
for _, sd := range file.GetService() {
glog.V(2).Infof("Registering %s", sd.GetName())
svc := &Service{
File: file,
ServiceDescriptorProto: sd,
}
for _, md := range sd.GetMethod() {
glog.V(2).Infof("Processing %s.%s", sd.GetName(), md.GetName())
opts, err := extractAPIOptions(md)
if err != nil {
glog.Errorf("Failed to extract ApiMethodOptions from %s.%s: %v", svc.GetName(), md.GetName(), err)
return err
}
if opts == nil {
glog.V(1).Infof("Found non-target method: %s.%s", svc.GetName(), md.GetName())
}
meth, err := r.newMethod(svc, md, opts)
if err != nil {
return err
}
svc.Methods = append(svc.Methods, meth)
}
if len(svc.Methods) == 0 {
continue
}
glog.V(2).Infof("Registered %s with %d method(s)", svc.GetName(), len(svc.Methods))
svcs = append(svcs, svc)
}
file.Services = svcs
return nil
}
func (r *Registry) newMethod(svc *Service, md *descriptor.MethodDescriptorProto, opts *options.HttpRule) (*Method, error) {
requestType, err := r.LookupMsg(svc.File.GetPackage(), md.GetInputType())
if err != nil {
return nil, err
}
responseType, err := r.LookupMsg(svc.File.GetPackage(), md.GetOutputType())
if err != nil {
return nil, err
}
meth := &Method{
Service: svc,
MethodDescriptorProto: md,
RequestType: requestType,
ResponseType: responseType,
}
newBinding := func(opts *options.HttpRule, idx int) (*Binding, error) {
var (
httpMethod string
pathTemplate string
)
switch {
case opts.GetGet() != "":
httpMethod = "GET"
pathTemplate = opts.GetGet()
if opts.Body != "" {
return nil, fmt.Errorf("needs request body even though http method is GET: %s", md.GetName())
}
case opts.GetPut() != "":
httpMethod = "PUT"
pathTemplate = opts.GetPut()
case opts.GetPost() != "":
httpMethod = "POST"
pathTemplate = opts.GetPost()
case opts.GetDelete() != "":
httpMethod = "DELETE"
pathTemplate = opts.GetDelete()
if opts.Body != "" && !r.allowDeleteBody {
return nil, fmt.Errorf("needs request body even though http method is DELETE: %s", md.GetName())
}
case opts.GetPatch() != "":
httpMethod = "PATCH"
pathTemplate = opts.GetPatch()
case opts.GetCustom() != nil:
custom := opts.GetCustom()
httpMethod = custom.Kind
pathTemplate = custom.Path
default:
glog.V(1).Infof("No pattern specified in google.api.HttpRule: %s", md.GetName())
return nil, nil
}
parsed, err := httprule.Parse(pathTemplate)
if err != nil {
return nil, err
}
tmpl := parsed.Compile()
if md.GetClientStreaming() && len(tmpl.Fields) > 0 {
return nil, fmt.Errorf("cannot use path parameter in client streaming")
}
b := &Binding{
Method: meth,
Index: idx,
PathTmpl: tmpl,
HTTPMethod: httpMethod,
}
for _, f := range tmpl.Fields {
param, err := r.newParam(meth, f)
if err != nil {
return nil, err
}
b.PathParams = append(b.PathParams, param)
}
// TODO(yugui) Handle query params
b.Body, err = r.newBody(meth, opts.Body)
if err != nil {
return nil, err
}
return b, nil
}
b, err := newBinding(opts, 0)
if err != nil {
return nil, err
}
if b != nil {
meth.Bindings = append(meth.Bindings, b)
}
for i, additional := range opts.GetAdditionalBindings() {
if len(additional.AdditionalBindings) > 0 {
return nil, fmt.Errorf("additional_binding in additional_binding not allowed: %s.%s", svc.GetName(), meth.GetName())
}
b, err := newBinding(additional, i+1)
if err != nil {
return nil, err
}
meth.Bindings = append(meth.Bindings, b)
}
return meth, nil
}
func extractAPIOptions(meth *descriptor.MethodDescriptorProto) (*options.HttpRule, error) {
if meth.Options == nil {
return nil, nil
}
if !proto.HasExtension(meth.Options, options.E_Http) {
return nil, nil
}
ext, err := proto.GetExtension(meth.Options, options.E_Http)
if err != nil {
return nil, err
}
opts, ok := ext.(*options.HttpRule)
if !ok {
return nil, fmt.Errorf("extension is %T; want an HttpRule", ext)
}
return opts, nil
}
func (r *Registry) newParam(meth *Method, path string) (Parameter, error) {
msg := meth.RequestType
fields, err := r.resolveFiledPath(msg, path)
if err != nil {
return Parameter{}, err
}
l := len(fields)
if l == 0 {
return Parameter{}, fmt.Errorf("invalid field access list for %s", path)
}
target := fields[l-1].Target
switch target.GetType() {
case descriptor.FieldDescriptorProto_TYPE_MESSAGE, descriptor.FieldDescriptorProto_TYPE_GROUP:
return Parameter{}, fmt.Errorf("aggregate type %s in parameter of %s.%s: %s", target.Type, meth.Service.GetName(), meth.GetName(), path)
}
return Parameter{
FieldPath: FieldPath(fields),
Method: meth,
Target: fields[l-1].Target,
}, nil
}
func (r *Registry) newBody(meth *Method, path string) (*Body, error) {
msg := meth.RequestType
switch path {
case "":
return nil, nil
case "*":
return &Body{FieldPath: nil}, nil
}
fields, err := r.resolveFiledPath(msg, path)
if err != nil {
return nil, err
}
return &Body{FieldPath: FieldPath(fields)}, nil
}
// lookupField looks up a field named "name" within "msg".
// It returns nil if no such field found.
func lookupField(msg *Message, name string) *Field {
for _, f := range msg.Fields {
if f.GetName() == name {
return f
}
}
return nil
}
// resolveFieldPath resolves "path" into a list of fieldDescriptor, starting from "msg".
func (r *Registry) resolveFiledPath(msg *Message, path string) ([]FieldPathComponent, error) {
if path == "" {
return nil, nil
}
root := msg
var result []FieldPathComponent
for i, c := range strings.Split(path, ".") {
if i > 0 {
f := result[i-1].Target
switch f.GetType() {
case descriptor.FieldDescriptorProto_TYPE_MESSAGE, descriptor.FieldDescriptorProto_TYPE_GROUP:
var err error
msg, err = r.LookupMsg(msg.FQMN(), f.GetTypeName())
if err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("not an aggregate type: %s in %s", f.GetName(), path)
}
}
glog.V(2).Infof("Lookup %s in %s", c, msg.FQMN())
f := lookupField(msg, c)
if f == nil {
return nil, fmt.Errorf("no field %q found in %s", path, root.GetName())
}
if f.GetLabel() == descriptor.FieldDescriptorProto_LABEL_REPEATED {
return nil, fmt.Errorf("repeated field not allowed in field path: %s in %s", f.GetName(), path)
}
result = append(result, FieldPathComponent{Name: c, Target: f})
}
return result, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,322 @@
package descriptor
import (
"fmt"
"strings"
descriptor "github.com/golang/protobuf/protoc-gen-go/descriptor"
gogen "github.com/golang/protobuf/protoc-gen-go/generator"
"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/httprule"
)
// GoPackage represents a golang package
type GoPackage struct {
// Path is the package path to the package.
Path string
// Name is the package name of the package
Name string
// Alias is an alias of the package unique within the current invokation of grpc-gateway generator.
Alias string
}
// Standard returns whether the import is a golang standard package.
func (p GoPackage) Standard() bool {
return !strings.Contains(p.Path, ".")
}
// String returns a string representation of this package in the form of import line in golang.
func (p GoPackage) String() string {
if p.Alias == "" {
return fmt.Sprintf("%q", p.Path)
}
return fmt.Sprintf("%s %q", p.Alias, p.Path)
}
// File wraps descriptor.FileDescriptorProto for richer features.
type File struct {
*descriptor.FileDescriptorProto
// GoPkg is the go package of the go file generated from this file..
GoPkg GoPackage
// Messages is the list of messages defined in this file.
Messages []*Message
// Enums is the list of enums defined in this file.
Enums []*Enum
// Services is the list of services defined in this file.
Services []*Service
}
// proto2 determines if the syntax of the file is proto2.
func (f *File) proto2() bool {
return f.Syntax == nil || f.GetSyntax() == "proto2"
}
// Message describes a protocol buffer message types
type Message struct {
// File is the file where the message is defined
File *File
// Outers is a list of outer messages if this message is a nested type.
Outers []string
*descriptor.DescriptorProto
Fields []*Field
// Index is proto path index of this message in File.
Index int
}
// FQMN returns a fully qualified message name of this message.
func (m *Message) FQMN() string {
components := []string{""}
if m.File.Package != nil {
components = append(components, m.File.GetPackage())
}
components = append(components, m.Outers...)
components = append(components, m.GetName())
return strings.Join(components, ".")
}
// GoType returns a go type name for the message type.
// It prefixes the type name with the package alias if
// its belonging package is not "currentPackage".
func (m *Message) GoType(currentPackage string) string {
var components []string
components = append(components, m.Outers...)
components = append(components, m.GetName())
name := strings.Join(components, "_")
if m.File.GoPkg.Path == currentPackage {
return name
}
pkg := m.File.GoPkg.Name
if alias := m.File.GoPkg.Alias; alias != "" {
pkg = alias
}
return fmt.Sprintf("%s.%s", pkg, name)
}
// Enum describes a protocol buffer enum types
type Enum struct {
// File is the file where the enum is defined
File *File
// Outers is a list of outer messages if this enum is a nested type.
Outers []string
*descriptor.EnumDescriptorProto
Index int
}
// FQEN returns a fully qualified enum name of this enum.
func (e *Enum) FQEN() string {
components := []string{""}
if e.File.Package != nil {
components = append(components, e.File.GetPackage())
}
components = append(components, e.Outers...)
components = append(components, e.GetName())
return strings.Join(components, ".")
}
// Service wraps descriptor.ServiceDescriptorProto for richer features.
type Service struct {
// File is the file where this service is defined.
File *File
*descriptor.ServiceDescriptorProto
// Methods is the list of methods defined in this service.
Methods []*Method
}
// Method wraps descriptor.MethodDescriptorProto for richer features.
type Method struct {
// Service is the service which this method belongs to.
Service *Service
*descriptor.MethodDescriptorProto
// RequestType is the message type of requests to this method.
RequestType *Message
// ResponseType is the message type of responses from this method.
ResponseType *Message
Bindings []*Binding
}
// Binding describes how an HTTP endpoint is bound to a gRPC method.
type Binding struct {
// Method is the method which the endpoint is bound to.
Method *Method
// Index is a zero-origin index of the binding in the target method
Index int
// PathTmpl is path template where this method is mapped to.
PathTmpl httprule.Template
// HTTPMethod is the HTTP method which this method is mapped to.
HTTPMethod string
// PathParams is the list of parameters provided in HTTP request paths.
PathParams []Parameter
// Body describes parameters provided in HTTP request body.
Body *Body
}
// ExplicitParams returns a list of explicitly bound parameters of "b",
// i.e. a union of field path for body and field paths for path parameters.
func (b *Binding) ExplicitParams() []string {
var result []string
if b.Body != nil {
result = append(result, b.Body.FieldPath.String())
}
for _, p := range b.PathParams {
result = append(result, p.FieldPath.String())
}
return result
}
// Field wraps descriptor.FieldDescriptorProto for richer features.
type Field struct {
// Message is the message type which this field belongs to.
Message *Message
// FieldMessage is the message type of the field.
FieldMessage *Message
*descriptor.FieldDescriptorProto
}
// Parameter is a parameter provided in http requests
type Parameter struct {
// FieldPath is a path to a proto field which this parameter is mapped to.
FieldPath
// Target is the proto field which this parameter is mapped to.
Target *Field
// Method is the method which this parameter is used for.
Method *Method
}
// ConvertFuncExpr returns a go expression of a converter function.
// The converter function converts a string into a value for the parameter.
func (p Parameter) ConvertFuncExpr() (string, error) {
tbl := proto3ConvertFuncs
if p.Target.Message.File.proto2() {
tbl = proto2ConvertFuncs
}
typ := p.Target.GetType()
conv, ok := tbl[typ]
if !ok {
return "", fmt.Errorf("unsupported field type %s of parameter %s in %s.%s", typ, p.FieldPath, p.Method.Service.GetName(), p.Method.GetName())
}
return conv, nil
}
// Body describes a http requtest body to be sent to the method.
type Body struct {
// FieldPath is a path to a proto field which the request body is mapped to.
// The request body is mapped to the request type itself if FieldPath is empty.
FieldPath FieldPath
}
// RHS returns a right-hand-side expression in go to be used to initialize method request object.
// It starts with "msgExpr", which is the go expression of the method request object.
func (b Body) RHS(msgExpr string) string {
return b.FieldPath.RHS(msgExpr)
}
// FieldPath is a path to a field from a request message.
type FieldPath []FieldPathComponent
// String returns a string representation of the field path.
func (p FieldPath) String() string {
var components []string
for _, c := range p {
components = append(components, c.Name)
}
return strings.Join(components, ".")
}
// IsNestedProto3 indicates whether the FieldPath is a nested Proto3 path.
func (p FieldPath) IsNestedProto3() bool {
if len(p) > 1 && !p[0].Target.Message.File.proto2() {
return true
}
return false
}
// RHS is a right-hand-side expression in go to be used to assign a value to the target field.
// It starts with "msgExpr", which is the go expression of the method request object.
func (p FieldPath) RHS(msgExpr string) string {
l := len(p)
if l == 0 {
return msgExpr
}
components := []string{msgExpr}
for i, c := range p {
if i == l-1 {
components = append(components, c.RHS())
continue
}
components = append(components, c.LHS())
}
return strings.Join(components, ".")
}
// FieldPathComponent is a path component in FieldPath
type FieldPathComponent struct {
// Name is a name of the proto field which this component corresponds to.
// TODO(yugui) is this necessary?
Name string
// Target is the proto field which this component corresponds to.
Target *Field
}
// RHS returns a right-hand-side expression in go for this field.
func (c FieldPathComponent) RHS() string {
return gogen.CamelCase(c.Name)
}
// LHS returns a left-hand-side expression in go for this field.
func (c FieldPathComponent) LHS() string {
if c.Target.Message.File.proto2() {
return fmt.Sprintf("Get%s()", gogen.CamelCase(c.Name))
}
return gogen.CamelCase(c.Name)
}
var (
proto3ConvertFuncs = map[descriptor.FieldDescriptorProto_Type]string{
descriptor.FieldDescriptorProto_TYPE_DOUBLE: "runtime.Float64",
descriptor.FieldDescriptorProto_TYPE_FLOAT: "runtime.Float32",
descriptor.FieldDescriptorProto_TYPE_INT64: "runtime.Int64",
descriptor.FieldDescriptorProto_TYPE_UINT64: "runtime.Uint64",
descriptor.FieldDescriptorProto_TYPE_INT32: "runtime.Int32",
descriptor.FieldDescriptorProto_TYPE_FIXED64: "runtime.Uint64",
descriptor.FieldDescriptorProto_TYPE_FIXED32: "runtime.Uint32",
descriptor.FieldDescriptorProto_TYPE_BOOL: "runtime.Bool",
descriptor.FieldDescriptorProto_TYPE_STRING: "runtime.String",
// FieldDescriptorProto_TYPE_GROUP
// FieldDescriptorProto_TYPE_MESSAGE
// FieldDescriptorProto_TYPE_BYTES
// TODO(yugui) Handle bytes
descriptor.FieldDescriptorProto_TYPE_UINT32: "runtime.Uint32",
// FieldDescriptorProto_TYPE_ENUM
// TODO(yugui) Handle Enum
descriptor.FieldDescriptorProto_TYPE_SFIXED32: "runtime.Int32",
descriptor.FieldDescriptorProto_TYPE_SFIXED64: "runtime.Int64",
descriptor.FieldDescriptorProto_TYPE_SINT32: "runtime.Int32",
descriptor.FieldDescriptorProto_TYPE_SINT64: "runtime.Int64",
}
proto2ConvertFuncs = map[descriptor.FieldDescriptorProto_Type]string{
descriptor.FieldDescriptorProto_TYPE_DOUBLE: "runtime.Float64P",
descriptor.FieldDescriptorProto_TYPE_FLOAT: "runtime.Float32P",
descriptor.FieldDescriptorProto_TYPE_INT64: "runtime.Int64P",
descriptor.FieldDescriptorProto_TYPE_UINT64: "runtime.Uint64P",
descriptor.FieldDescriptorProto_TYPE_INT32: "runtime.Int32P",
descriptor.FieldDescriptorProto_TYPE_FIXED64: "runtime.Uint64P",
descriptor.FieldDescriptorProto_TYPE_FIXED32: "runtime.Uint32P",
descriptor.FieldDescriptorProto_TYPE_BOOL: "runtime.BoolP",
descriptor.FieldDescriptorProto_TYPE_STRING: "runtime.StringP",
// FieldDescriptorProto_TYPE_GROUP
// FieldDescriptorProto_TYPE_MESSAGE
// FieldDescriptorProto_TYPE_BYTES
// TODO(yugui) Handle bytes
descriptor.FieldDescriptorProto_TYPE_UINT32: "runtime.Uint32P",
// FieldDescriptorProto_TYPE_ENUM
// TODO(yugui) Handle Enum
descriptor.FieldDescriptorProto_TYPE_SFIXED32: "runtime.Int32P",
descriptor.FieldDescriptorProto_TYPE_SFIXED64: "runtime.Int64P",
descriptor.FieldDescriptorProto_TYPE_SINT32: "runtime.Int32P",
descriptor.FieldDescriptorProto_TYPE_SINT64: "runtime.Int64P",
}
)

View File

@@ -0,0 +1,206 @@
package descriptor
import (
"testing"
"github.com/golang/protobuf/proto"
descriptor "github.com/golang/protobuf/protoc-gen-go/descriptor"
)
func TestGoPackageStandard(t *testing.T) {
for _, spec := range []struct {
pkg GoPackage
want bool
}{
{
pkg: GoPackage{Path: "fmt", Name: "fmt"},
want: true,
},
{
pkg: GoPackage{Path: "encoding/json", Name: "json"},
want: true,
},
{
pkg: GoPackage{Path: "github.com/golang/protobuf/jsonpb", Name: "jsonpb"},
want: false,
},
{
pkg: GoPackage{Path: "golang.org/x/net/context", Name: "context"},
want: false,
},
{
pkg: GoPackage{Path: "github.com/grpc-ecosystem/grpc-gateway", Name: "main"},
want: false,
},
{
pkg: GoPackage{Path: "github.com/google/googleapis/google/api/http.pb", Name: "http_pb", Alias: "htpb"},
want: false,
},
} {
if got, want := spec.pkg.Standard(), spec.want; got != want {
t.Errorf("%#v.Standard() = %v; want %v", spec.pkg, got, want)
}
}
}
func TestGoPackageString(t *testing.T) {
for _, spec := range []struct {
pkg GoPackage
want string
}{
{
pkg: GoPackage{Path: "fmt", Name: "fmt"},
want: `"fmt"`,
},
{
pkg: GoPackage{Path: "encoding/json", Name: "json"},
want: `"encoding/json"`,
},
{
pkg: GoPackage{Path: "github.com/golang/protobuf/jsonpb", Name: "jsonpb"},
want: `"github.com/golang/protobuf/jsonpb"`,
},
{
pkg: GoPackage{Path: "golang.org/x/net/context", Name: "context"},
want: `"golang.org/x/net/context"`,
},
{
pkg: GoPackage{Path: "github.com/grpc-ecosystem/grpc-gateway", Name: "main"},
want: `"github.com/grpc-ecosystem/grpc-gateway"`,
},
{
pkg: GoPackage{Path: "github.com/google/googleapis/google/api/http.pb", Name: "http_pb", Alias: "htpb"},
want: `htpb "github.com/google/googleapis/google/api/http.pb"`,
},
} {
if got, want := spec.pkg.String(), spec.want; got != want {
t.Errorf("%#v.String() = %q; want %q", spec.pkg, got, want)
}
}
}
func TestFieldPath(t *testing.T) {
var fds []*descriptor.FileDescriptorProto
for _, src := range []string{
`
name: 'example.proto'
package: 'example'
message_type <
name: 'Nest'
field <
name: 'nest2_field'
label: LABEL_OPTIONAL
type: TYPE_MESSAGE
type_name: 'Nest2'
number: 1
>
field <
name: 'terminal_field'
label: LABEL_OPTIONAL
type: TYPE_STRING
number: 2
>
>
syntax: "proto3"
`, `
name: 'another.proto'
package: 'example'
message_type <
name: 'Nest2'
field <
name: 'nest_field'
label: LABEL_OPTIONAL
type: TYPE_MESSAGE
type_name: 'Nest'
number: 1
>
field <
name: 'terminal_field'
label: LABEL_OPTIONAL
type: TYPE_STRING
number: 2
>
>
syntax: "proto2"
`,
} {
var fd descriptor.FileDescriptorProto
if err := proto.UnmarshalText(src, &fd); err != nil {
t.Fatalf("proto.UnmarshalText(%s, &fd) failed with %v; want success", src, err)
}
fds = append(fds, &fd)
}
nest := &Message{
DescriptorProto: fds[0].MessageType[0],
Fields: []*Field{
{FieldDescriptorProto: fds[0].MessageType[0].Field[0]},
{FieldDescriptorProto: fds[0].MessageType[0].Field[1]},
},
}
nest2 := &Message{
DescriptorProto: fds[1].MessageType[0],
Fields: []*Field{
{FieldDescriptorProto: fds[1].MessageType[0].Field[0]},
{FieldDescriptorProto: fds[1].MessageType[0].Field[1]},
},
}
file1 := &File{
FileDescriptorProto: fds[0],
GoPkg: GoPackage{Path: "example", Name: "example"},
Messages: []*Message{nest},
}
file2 := &File{
FileDescriptorProto: fds[1],
GoPkg: GoPackage{Path: "example", Name: "example"},
Messages: []*Message{nest2},
}
crossLinkFixture(file1)
crossLinkFixture(file2)
c1 := FieldPathComponent{
Name: "nest_field",
Target: nest2.Fields[0],
}
if got, want := c1.LHS(), "GetNestField()"; got != want {
t.Errorf("c1.LHS() = %q; want %q", got, want)
}
if got, want := c1.RHS(), "NestField"; got != want {
t.Errorf("c1.RHS() = %q; want %q", got, want)
}
c2 := FieldPathComponent{
Name: "nest2_field",
Target: nest.Fields[0],
}
if got, want := c2.LHS(), "Nest2Field"; got != want {
t.Errorf("c2.LHS() = %q; want %q", got, want)
}
if got, want := c2.LHS(), "Nest2Field"; got != want {
t.Errorf("c2.LHS() = %q; want %q", got, want)
}
fp := FieldPath{
c1, c2, c1, FieldPathComponent{
Name: "terminal_field",
Target: nest.Fields[1],
},
}
if got, want := fp.RHS("resp"), "resp.GetNestField().Nest2Field.GetNestField().TerminalField"; got != want {
t.Errorf("fp.RHS(%q) = %q; want %q", "resp", got, want)
}
fp2 := FieldPath{
c2, c1, c2, FieldPathComponent{
Name: "terminal_field",
Target: nest2.Fields[1],
},
}
if got, want := fp2.RHS("resp"), "resp.Nest2Field.GetNestField().Nest2Field.TerminalField"; got != want {
t.Errorf("fp2.RHS(%q) = %q; want %q", "resp", got, want)
}
var fpEmpty FieldPath
if got, want := fpEmpty.RHS("resp"), "resp"; got != want {
t.Errorf("fpEmpty.RHS(%q) = %q; want %q", "resp", got, want)
}
}

View File

@@ -0,0 +1,13 @@
// Package generator provides an abstract interface to code generators.
package generator
import (
plugin "github.com/golang/protobuf/protoc-gen-go/plugin"
"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
)
// Generator is an abstraction of code generators.
type Generator interface {
// Generate generates output files from input .proto files.
Generate(targets []*descriptor.File) ([]*plugin.CodeGeneratorResponse_File, error)
}

View File

@@ -0,0 +1,2 @@
// Package gengateway provides a code generator for grpc gateway files.
package gengateway

View File

@@ -0,0 +1,111 @@
package gengateway
import (
"errors"
"fmt"
"go/format"
"path"
"path/filepath"
"strings"
"github.com/golang/glog"
"github.com/golang/protobuf/proto"
plugin "github.com/golang/protobuf/protoc-gen-go/plugin"
"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
gen "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/generator"
options "google.golang.org/genproto/googleapis/api/annotations"
)
var (
errNoTargetService = errors.New("no target service defined in the file")
)
type generator struct {
reg *descriptor.Registry
baseImports []descriptor.GoPackage
useRequestContext bool
}
// New returns a new generator which generates grpc gateway files.
func New(reg *descriptor.Registry, useRequestContext bool) gen.Generator {
var imports []descriptor.GoPackage
for _, pkgpath := range []string{
"io",
"net/http",
"github.com/grpc-ecosystem/grpc-gateway/runtime",
"github.com/grpc-ecosystem/grpc-gateway/utilities",
"github.com/golang/protobuf/proto",
"golang.org/x/net/context",
"google.golang.org/grpc",
"google.golang.org/grpc/codes",
"google.golang.org/grpc/grpclog",
} {
pkg := descriptor.GoPackage{
Path: pkgpath,
Name: path.Base(pkgpath),
}
if err := reg.ReserveGoPackageAlias(pkg.Name, pkg.Path); err != nil {
for i := 0; ; i++ {
alias := fmt.Sprintf("%s_%d", pkg.Name, i)
if err := reg.ReserveGoPackageAlias(alias, pkg.Path); err != nil {
continue
}
pkg.Alias = alias
break
}
}
imports = append(imports, pkg)
}
return &generator{reg: reg, baseImports: imports, useRequestContext: useRequestContext}
}
func (g *generator) Generate(targets []*descriptor.File) ([]*plugin.CodeGeneratorResponse_File, error) {
var files []*plugin.CodeGeneratorResponse_File
for _, file := range targets {
glog.V(1).Infof("Processing %s", file.GetName())
code, err := g.generate(file)
if err == errNoTargetService {
glog.V(1).Infof("%s: %v", file.GetName(), err)
continue
}
if err != nil {
return nil, err
}
formatted, err := format.Source([]byte(code))
if err != nil {
glog.Errorf("%v: %s", err, code)
return nil, err
}
name := file.GetName()
ext := filepath.Ext(name)
base := strings.TrimSuffix(name, ext)
output := fmt.Sprintf("%s.pb.gw.go", base)
files = append(files, &plugin.CodeGeneratorResponse_File{
Name: proto.String(output),
Content: proto.String(string(formatted)),
})
glog.V(1).Infof("Will emit %s", output)
}
return files, nil
}
func (g *generator) generate(file *descriptor.File) (string, error) {
pkgSeen := make(map[string]bool)
var imports []descriptor.GoPackage
for _, pkg := range g.baseImports {
pkgSeen[pkg.Path] = true
imports = append(imports, pkg)
}
for _, svc := range file.Services {
for _, m := range svc.Methods {
pkg := m.RequestType.File.GoPkg
if m.Options == nil || !proto.HasExtension(m.Options, options.E_Http) ||
pkg == file.GoPkg || pkgSeen[pkg.Path] {
continue
}
pkgSeen[pkg.Path] = true
imports = append(imports, pkg)
}
}
return applyTemplate(param{File: file, Imports: imports, UseRequestContext: g.useRequestContext})
}

View File

@@ -0,0 +1,88 @@
package gengateway
import (
"strings"
"testing"
"github.com/golang/protobuf/proto"
protodescriptor "github.com/golang/protobuf/protoc-gen-go/descriptor"
"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
)
func TestGenerateServiceWithoutBindings(t *testing.T) {
msgdesc := &protodescriptor.DescriptorProto{
Name: proto.String("ExampleMessage"),
}
msg := &descriptor.Message{
DescriptorProto: msgdesc,
}
msg1 := &descriptor.Message{
DescriptorProto: msgdesc,
File: &descriptor.File{
GoPkg: descriptor.GoPackage{
Path: "github.com/golang/protobuf/ptypes/empty",
Name: "empty",
},
},
}
meth := &protodescriptor.MethodDescriptorProto{
Name: proto.String("Example"),
InputType: proto.String("ExampleMessage"),
OutputType: proto.String("ExampleMessage"),
}
meth1 := &protodescriptor.MethodDescriptorProto{
Name: proto.String("ExampleWithoutBindings"),
InputType: proto.String("empty.Empty"),
OutputType: proto.String("empty.Empty"),
}
svc := &protodescriptor.ServiceDescriptorProto{
Name: proto.String("ExampleService"),
Method: []*protodescriptor.MethodDescriptorProto{meth, meth1},
}
file := descriptor.File{
FileDescriptorProto: &protodescriptor.FileDescriptorProto{
Name: proto.String("example.proto"),
Package: proto.String("example"),
Dependency: []string{"a.example/b/c.proto", "a.example/d/e.proto"},
MessageType: []*protodescriptor.DescriptorProto{msgdesc},
Service: []*protodescriptor.ServiceDescriptorProto{svc},
},
GoPkg: descriptor.GoPackage{
Path: "example.com/path/to/example/example.pb",
Name: "example_pb",
},
Messages: []*descriptor.Message{msg},
Services: []*descriptor.Service{
{
ServiceDescriptorProto: svc,
Methods: []*descriptor.Method{
{
MethodDescriptorProto: meth,
RequestType: msg,
ResponseType: msg,
Bindings: []*descriptor.Binding{
{
HTTPMethod: "GET",
Body: &descriptor.Body{FieldPath: nil},
},
},
},
{
MethodDescriptorProto: meth1,
RequestType: msg1,
ResponseType: msg1,
},
},
},
},
}
g := &generator{}
got, err := g.generate(crossLinkFixture(&file))
if err != nil {
t.Errorf("generate(%#v) failed with %v; want success", file, err)
return
}
if notwanted := `"github.com/golang/protobuf/ptypes/empty"`; strings.Contains(got, notwanted) {
t.Errorf("generate(%#v) = %s; does not want to contain %s", file, got, notwanted)
}
}

View File

@@ -0,0 +1,398 @@
package gengateway
import (
"bytes"
"fmt"
"strings"
"text/template"
"github.com/golang/glog"
"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
"github.com/grpc-ecosystem/grpc-gateway/utilities"
)
type param struct {
*descriptor.File
Imports []descriptor.GoPackage
UseRequestContext bool
}
type binding struct {
*descriptor.Binding
}
// HasQueryParam determines if the binding needs parameters in query string.
//
// It sometimes returns true even though actually the binding does not need.
// But it is not serious because it just results in a small amount of extra codes generated.
func (b binding) HasQueryParam() bool {
if b.Body != nil && len(b.Body.FieldPath) == 0 {
return false
}
fields := make(map[string]bool)
for _, f := range b.Method.RequestType.Fields {
fields[f.GetName()] = true
}
if b.Body != nil {
delete(fields, b.Body.FieldPath.String())
}
for _, p := range b.PathParams {
delete(fields, p.FieldPath.String())
}
return len(fields) > 0
}
func (b binding) QueryParamFilter() queryParamFilter {
var seqs [][]string
if b.Body != nil {
seqs = append(seqs, strings.Split(b.Body.FieldPath.String(), "."))
}
for _, p := range b.PathParams {
seqs = append(seqs, strings.Split(p.FieldPath.String(), "."))
}
return queryParamFilter{utilities.NewDoubleArray(seqs)}
}
// queryParamFilter is a wrapper of utilities.DoubleArray which provides String() to output DoubleArray.Encoding in a stable and predictable format.
type queryParamFilter struct {
*utilities.DoubleArray
}
func (f queryParamFilter) String() string {
encodings := make([]string, len(f.Encoding))
for str, enc := range f.Encoding {
encodings[enc] = fmt.Sprintf("%q: %d", str, enc)
}
e := strings.Join(encodings, ", ")
return fmt.Sprintf("&utilities.DoubleArray{Encoding: map[string]int{%s}, Base: %#v, Check: %#v}", e, f.Base, f.Check)
}
type trailerParams struct {
Services []*descriptor.Service
UseRequestContext bool
}
func applyTemplate(p param) (string, error) {
w := bytes.NewBuffer(nil)
if err := headerTemplate.Execute(w, p); err != nil {
return "", err
}
var targetServices []*descriptor.Service
for _, svc := range p.Services {
var methodWithBindingsSeen bool
for _, meth := range svc.Methods {
glog.V(2).Infof("Processing %s.%s", svc.GetName(), meth.GetName())
methName := strings.Title(*meth.Name)
meth.Name = &methName
for _, b := range meth.Bindings {
methodWithBindingsSeen = true
if err := handlerTemplate.Execute(w, binding{Binding: b}); err != nil {
return "", err
}
}
}
if methodWithBindingsSeen {
targetServices = append(targetServices, svc)
}
}
if len(targetServices) == 0 {
return "", errNoTargetService
}
tp := trailerParams{
Services: targetServices,
UseRequestContext: p.UseRequestContext,
}
if err := trailerTemplate.Execute(w, tp); err != nil {
return "", err
}
return w.String(), nil
}
var (
headerTemplate = template.Must(template.New("header").Parse(`
// Code generated by protoc-gen-grpc-gateway
// source: {{.GetName}}
// DO NOT EDIT!
/*
Package {{.GoPkg.Name}} is a reverse proxy.
It translates gRPC into RESTful JSON APIs.
*/
package {{.GoPkg.Name}}
import (
{{range $i := .Imports}}{{if $i.Standard}}{{$i | printf "%s\n"}}{{end}}{{end}}
{{range $i := .Imports}}{{if not $i.Standard}}{{$i | printf "%s\n"}}{{end}}{{end}}
)
var _ codes.Code
var _ io.Reader
var _ = runtime.String
var _ = utilities.NewDoubleArray
`))
handlerTemplate = template.Must(template.New("handler").Parse(`
{{if and .Method.GetClientStreaming .Method.GetServerStreaming}}
{{template "bidi-streaming-request-func" .}}
{{else if .Method.GetClientStreaming}}
{{template "client-streaming-request-func" .}}
{{else}}
{{template "client-rpc-request-func" .}}
{{end}}
`))
_ = template.Must(handlerTemplate.New("request-func-signature").Parse(strings.Replace(`
{{if .Method.GetServerStreaming}}
func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, marshaler runtime.Marshaler, client {{.Method.Service.GetName}}Client, req *http.Request, pathParams map[string]string) ({{.Method.Service.GetName}}_{{.Method.GetName}}Client, runtime.ServerMetadata, error)
{{else}}
func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, marshaler runtime.Marshaler, client {{.Method.Service.GetName}}Client, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error)
{{end}}`, "\n", "", -1)))
_ = template.Must(handlerTemplate.New("client-streaming-request-func").Parse(`
{{template "request-func-signature" .}} {
var metadata runtime.ServerMetadata
stream, err := client.{{.Method.GetName}}(ctx)
if err != nil {
grpclog.Printf("Failed to start streaming: %v", err)
return nil, metadata, err
}
dec := marshaler.NewDecoder(req.Body)
for {
var protoReq {{.Method.RequestType.GoType .Method.Service.File.GoPkg.Path}}
err = dec.Decode(&protoReq)
if err == io.EOF {
break
}
if err != nil {
grpclog.Printf("Failed to decode request: %v", err)
return nil, metadata, grpc.Errorf(codes.InvalidArgument, "%v", err)
}
if err = stream.Send(&protoReq); err != nil {
grpclog.Printf("Failed to send request: %v", err)
return nil, metadata, err
}
}
if err := stream.CloseSend(); err != nil {
grpclog.Printf("Failed to terminate client stream: %v", err)
return nil, metadata, err
}
header, err := stream.Header()
if err != nil {
grpclog.Printf("Failed to get header from client: %v", err)
return nil, metadata, err
}
metadata.HeaderMD = header
{{if .Method.GetServerStreaming}}
return stream, metadata, nil
{{else}}
msg, err := stream.CloseAndRecv()
metadata.TrailerMD = stream.Trailer()
return msg, metadata, err
{{end}}
}
`))
_ = template.Must(handlerTemplate.New("client-rpc-request-func").Parse(`
{{if .HasQueryParam}}
var (
filter_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}} = {{.QueryParamFilter}}
)
{{end}}
{{template "request-func-signature" .}} {
var protoReq {{.Method.RequestType.GoType .Method.Service.File.GoPkg.Path}}
var metadata runtime.ServerMetadata
{{if .Body}}
if err := marshaler.NewDecoder(req.Body).Decode(&{{.Body.RHS "protoReq"}}); err != nil {
return nil, metadata, grpc.Errorf(codes.InvalidArgument, "%v", err)
}
{{end}}
{{if .PathParams}}
var (
val string
ok bool
err error
_ = err
)
{{range $param := .PathParams}}
val, ok = pathParams[{{$param | printf "%q"}}]
if !ok {
return nil, metadata, grpc.Errorf(codes.InvalidArgument, "missing parameter %s", {{$param | printf "%q"}})
}
{{if $param.IsNestedProto3 }}
err = runtime.PopulateFieldFromPath(&protoReq, {{$param | printf "%q"}}, val)
{{else}}
{{$param.RHS "protoReq"}}, err = {{$param.ConvertFuncExpr}}(val)
{{end}}
if err != nil {
return nil, metadata, err
}
{{end}}
{{end}}
{{if .HasQueryParam}}
if err := runtime.PopulateQueryParameters(&protoReq, req.URL.Query(), filter_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}); err != nil {
return nil, metadata, grpc.Errorf(codes.InvalidArgument, "%v", err)
}
{{end}}
{{if .Method.GetServerStreaming}}
stream, err := client.{{.Method.GetName}}(ctx, &protoReq)
if err != nil {
return nil, metadata, err
}
header, err := stream.Header()
if err != nil {
return nil, metadata, err
}
metadata.HeaderMD = header
return stream, metadata, nil
{{else}}
msg, err := client.{{.Method.GetName}}(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD))
return msg, metadata, err
{{end}}
}`))
_ = template.Must(handlerTemplate.New("bidi-streaming-request-func").Parse(`
{{template "request-func-signature" .}} {
var metadata runtime.ServerMetadata
stream, err := client.{{.Method.GetName}}(ctx)
if err != nil {
grpclog.Printf("Failed to start streaming: %v", err)
return nil, metadata, err
}
dec := marshaler.NewDecoder(req.Body)
handleSend := func() error {
var protoReq {{.Method.RequestType.GoType .Method.Service.File.GoPkg.Path}}
err = dec.Decode(&protoReq)
if err == io.EOF {
return err
}
if err != nil {
grpclog.Printf("Failed to decode request: %v", err)
return err
}
if err = stream.Send(&protoReq); err != nil {
grpclog.Printf("Failed to send request: %v", err)
return err
}
return nil
}
if err := handleSend(); err != nil {
if cerr := stream.CloseSend(); cerr != nil {
grpclog.Printf("Failed to terminate client stream: %v", cerr)
}
if err == io.EOF {
return stream, metadata, nil
}
return nil, metadata, err
}
go func() {
for {
if err := handleSend(); err != nil {
break
}
}
if err := stream.CloseSend(); err != nil {
grpclog.Printf("Failed to terminate client stream: %v", err)
}
}()
header, err := stream.Header()
if err != nil {
grpclog.Printf("Failed to get header from client: %v", err)
return nil, metadata, err
}
metadata.HeaderMD = header
return stream, metadata, nil
}
`))
trailerTemplate = template.Must(template.New("trailer").Parse(`
{{$UseRequestContext := .UseRequestContext}}
{{range $svc := .Services}}
// Register{{$svc.GetName}}HandlerFromEndpoint is same as Register{{$svc.GetName}}Handler but
// automatically dials to "endpoint" and closes the connection when "ctx" gets done.
func Register{{$svc.GetName}}HandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) {
conn, err := grpc.Dial(endpoint, opts...)
if err != nil {
return err
}
defer func() {
if err != nil {
if cerr := conn.Close(); cerr != nil {
grpclog.Printf("Failed to close conn to %s: %v", endpoint, cerr)
}
return
}
go func() {
<-ctx.Done()
if cerr := conn.Close(); cerr != nil {
grpclog.Printf("Failed to close conn to %s: %v", endpoint, cerr)
}
}()
}()
return Register{{$svc.GetName}}Handler(ctx, mux, conn)
}
// Register{{$svc.GetName}}Handler registers the http handlers for service {{$svc.GetName}} to "mux".
// The handlers forward requests to the grpc endpoint over "conn".
func Register{{$svc.GetName}}Handler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error {
client := New{{$svc.GetName}}Client(conn)
{{range $m := $svc.Methods}}
{{range $b := $m.Bindings}}
mux.Handle({{$b.HTTPMethod | printf "%q"}}, pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
{{- if $UseRequestContext }}
ctx, cancel := context.WithCancel(req.Context())
{{- else -}}
ctx, cancel := context.WithCancel(ctx)
{{- end }}
defer cancel()
if cn, ok := w.(http.CloseNotifier); ok {
go func(done <-chan struct{}, closed <-chan bool) {
select {
case <-done:
case <-closed:
cancel()
}
}(ctx.Done(), cn.CloseNotify())
}
inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
rctx, err := runtime.AnnotateContext(ctx, req)
if err != nil {
runtime.HTTPError(ctx, outboundMarshaler, w, req, err)
}
resp, md, err := request_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(rctx, inboundMarshaler, client, req, pathParams)
ctx = runtime.NewServerMetadataContext(ctx, md)
if err != nil {
runtime.HTTPError(ctx, outboundMarshaler, w, req, err)
return
}
{{if $m.GetServerStreaming}}
forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(ctx, outboundMarshaler, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...)
{{else}}
forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(ctx, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
{{end}}
})
{{end}}
{{end}}
return nil
}
var (
{{range $m := $svc.Methods}}
{{range $b := $m.Bindings}}
pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}} = runtime.MustPattern(runtime.NewPattern({{$b.PathTmpl.Version}}, {{$b.PathTmpl.OpCodes | printf "%#v"}}, {{$b.PathTmpl.Pool | printf "%#v"}}, {{$b.PathTmpl.Verb | printf "%q"}}))
{{end}}
{{end}}
)
var (
{{range $m := $svc.Methods}}
{{range $b := $m.Bindings}}
forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}} = {{if $m.GetServerStreaming}}runtime.ForwardResponseStream{{else}}runtime.ForwardResponseMessage{{end}}
{{end}}
{{end}}
)
{{end}}`))
)

View File

@@ -0,0 +1,404 @@
package gengateway
import (
"strings"
"testing"
"github.com/golang/protobuf/proto"
protodescriptor "github.com/golang/protobuf/protoc-gen-go/descriptor"
"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/httprule"
)
func crossLinkFixture(f *descriptor.File) *descriptor.File {
for _, m := range f.Messages {
m.File = f
}
for _, svc := range f.Services {
svc.File = f
for _, m := range svc.Methods {
m.Service = svc
for _, b := range m.Bindings {
b.Method = m
for _, param := range b.PathParams {
param.Method = m
}
}
}
}
return f
}
func TestApplyTemplateHeader(t *testing.T) {
msgdesc := &protodescriptor.DescriptorProto{
Name: proto.String("ExampleMessage"),
}
meth := &protodescriptor.MethodDescriptorProto{
Name: proto.String("Example"),
InputType: proto.String("ExampleMessage"),
OutputType: proto.String("ExampleMessage"),
}
svc := &protodescriptor.ServiceDescriptorProto{
Name: proto.String("ExampleService"),
Method: []*protodescriptor.MethodDescriptorProto{meth},
}
msg := &descriptor.Message{
DescriptorProto: msgdesc,
}
file := descriptor.File{
FileDescriptorProto: &protodescriptor.FileDescriptorProto{
Name: proto.String("example.proto"),
Package: proto.String("example"),
Dependency: []string{"a.example/b/c.proto", "a.example/d/e.proto"},
MessageType: []*protodescriptor.DescriptorProto{msgdesc},
Service: []*protodescriptor.ServiceDescriptorProto{svc},
},
GoPkg: descriptor.GoPackage{
Path: "example.com/path/to/example/example.pb",
Name: "example_pb",
},
Messages: []*descriptor.Message{msg},
Services: []*descriptor.Service{
{
ServiceDescriptorProto: svc,
Methods: []*descriptor.Method{
{
MethodDescriptorProto: meth,
RequestType: msg,
ResponseType: msg,
Bindings: []*descriptor.Binding{
{
HTTPMethod: "GET",
Body: &descriptor.Body{FieldPath: nil},
},
},
},
},
},
},
}
got, err := applyTemplate(param{File: crossLinkFixture(&file)})
if err != nil {
t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err)
return
}
if want := "package example_pb\n"; !strings.Contains(got, want) {
t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
}
}
func TestApplyTemplateRequestWithoutClientStreaming(t *testing.T) {
msgdesc := &protodescriptor.DescriptorProto{
Name: proto.String("ExampleMessage"),
Field: []*protodescriptor.FieldDescriptorProto{
{
Name: proto.String("nested"),
Label: protodescriptor.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
Type: protodescriptor.FieldDescriptorProto_TYPE_MESSAGE.Enum(),
TypeName: proto.String("NestedMessage"),
Number: proto.Int32(1),
},
},
}
nesteddesc := &protodescriptor.DescriptorProto{
Name: proto.String("NestedMessage"),
Field: []*protodescriptor.FieldDescriptorProto{
{
Name: proto.String("int32"),
Label: protodescriptor.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
Type: protodescriptor.FieldDescriptorProto_TYPE_INT32.Enum(),
Number: proto.Int32(1),
},
{
Name: proto.String("bool"),
Label: protodescriptor.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
Type: protodescriptor.FieldDescriptorProto_TYPE_BOOL.Enum(),
Number: proto.Int32(2),
},
},
}
meth := &protodescriptor.MethodDescriptorProto{
Name: proto.String("Echo"),
InputType: proto.String("ExampleMessage"),
OutputType: proto.String("ExampleMessage"),
ClientStreaming: proto.Bool(false),
}
svc := &protodescriptor.ServiceDescriptorProto{
Name: proto.String("ExampleService"),
Method: []*protodescriptor.MethodDescriptorProto{meth},
}
for _, spec := range []struct {
serverStreaming bool
sigWant string
}{
{
serverStreaming: false,
sigWant: `func request_ExampleService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {`,
},
{
serverStreaming: true,
sigWant: `func request_ExampleService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, runtime.ServerMetadata, error) {`,
},
} {
meth.ServerStreaming = proto.Bool(spec.serverStreaming)
msg := &descriptor.Message{
DescriptorProto: msgdesc,
}
nested := &descriptor.Message{
DescriptorProto: nesteddesc,
}
nestedField := &descriptor.Field{
Message: msg,
FieldDescriptorProto: msg.GetField()[0],
}
intField := &descriptor.Field{
Message: nested,
FieldDescriptorProto: nested.GetField()[0],
}
boolField := &descriptor.Field{
Message: nested,
FieldDescriptorProto: nested.GetField()[1],
}
file := descriptor.File{
FileDescriptorProto: &protodescriptor.FileDescriptorProto{
Name: proto.String("example.proto"),
Package: proto.String("example"),
MessageType: []*protodescriptor.DescriptorProto{msgdesc, nesteddesc},
Service: []*protodescriptor.ServiceDescriptorProto{svc},
},
GoPkg: descriptor.GoPackage{
Path: "example.com/path/to/example/example.pb",
Name: "example_pb",
},
Messages: []*descriptor.Message{msg, nested},
Services: []*descriptor.Service{
{
ServiceDescriptorProto: svc,
Methods: []*descriptor.Method{
{
MethodDescriptorProto: meth,
RequestType: msg,
ResponseType: msg,
Bindings: []*descriptor.Binding{
{
HTTPMethod: "POST",
PathTmpl: httprule.Template{
Version: 1,
OpCodes: []int{0, 0},
},
PathParams: []descriptor.Parameter{
{
FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{
{
Name: "nested",
Target: nestedField,
},
{
Name: "int32",
Target: intField,
},
}),
Target: intField,
},
},
Body: &descriptor.Body{
FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{
{
Name: "nested",
Target: nestedField,
},
{
Name: "bool",
Target: boolField,
},
}),
},
},
},
},
},
},
},
}
got, err := applyTemplate(param{File: crossLinkFixture(&file)})
if err != nil {
t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err)
return
}
if want := spec.sigWant; !strings.Contains(got, want) {
t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
}
if want := `marshaler.NewDecoder(req.Body).Decode(&protoReq.GetNested().Bool)`; !strings.Contains(got, want) {
t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
}
if want := `val, ok = pathParams["nested.int32"]`; !strings.Contains(got, want) {
t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
}
if want := `protoReq.GetNested().Int32, err = runtime.Int32P(val)`; !strings.Contains(got, want) {
t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
}
if want := `func RegisterExampleServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error {`; !strings.Contains(got, want) {
t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
}
if want := `pattern_ExampleService_Echo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{0, 0}, []string(nil), ""))`; !strings.Contains(got, want) {
t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
}
}
}
func TestApplyTemplateRequestWithClientStreaming(t *testing.T) {
msgdesc := &protodescriptor.DescriptorProto{
Name: proto.String("ExampleMessage"),
Field: []*protodescriptor.FieldDescriptorProto{
{
Name: proto.String("nested"),
Label: protodescriptor.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
Type: protodescriptor.FieldDescriptorProto_TYPE_MESSAGE.Enum(),
TypeName: proto.String("NestedMessage"),
Number: proto.Int32(1),
},
},
}
nesteddesc := &protodescriptor.DescriptorProto{
Name: proto.String("NestedMessage"),
Field: []*protodescriptor.FieldDescriptorProto{
{
Name: proto.String("int32"),
Label: protodescriptor.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
Type: protodescriptor.FieldDescriptorProto_TYPE_INT32.Enum(),
Number: proto.Int32(1),
},
{
Name: proto.String("bool"),
Label: protodescriptor.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
Type: protodescriptor.FieldDescriptorProto_TYPE_BOOL.Enum(),
Number: proto.Int32(2),
},
},
}
meth := &protodescriptor.MethodDescriptorProto{
Name: proto.String("Echo"),
InputType: proto.String("ExampleMessage"),
OutputType: proto.String("ExampleMessage"),
ClientStreaming: proto.Bool(true),
}
svc := &protodescriptor.ServiceDescriptorProto{
Name: proto.String("ExampleService"),
Method: []*protodescriptor.MethodDescriptorProto{meth},
}
for _, spec := range []struct {
serverStreaming bool
sigWant string
}{
{
serverStreaming: false,
sigWant: `func request_ExampleService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {`,
},
{
serverStreaming: true,
sigWant: `func request_ExampleService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, runtime.ServerMetadata, error) {`,
},
} {
meth.ServerStreaming = proto.Bool(spec.serverStreaming)
msg := &descriptor.Message{
DescriptorProto: msgdesc,
}
nested := &descriptor.Message{
DescriptorProto: nesteddesc,
}
nestedField := &descriptor.Field{
Message: msg,
FieldDescriptorProto: msg.GetField()[0],
}
intField := &descriptor.Field{
Message: nested,
FieldDescriptorProto: nested.GetField()[0],
}
boolField := &descriptor.Field{
Message: nested,
FieldDescriptorProto: nested.GetField()[1],
}
file := descriptor.File{
FileDescriptorProto: &protodescriptor.FileDescriptorProto{
Name: proto.String("example.proto"),
Package: proto.String("example"),
MessageType: []*protodescriptor.DescriptorProto{msgdesc, nesteddesc},
Service: []*protodescriptor.ServiceDescriptorProto{svc},
},
GoPkg: descriptor.GoPackage{
Path: "example.com/path/to/example/example.pb",
Name: "example_pb",
},
Messages: []*descriptor.Message{msg, nested},
Services: []*descriptor.Service{
{
ServiceDescriptorProto: svc,
Methods: []*descriptor.Method{
{
MethodDescriptorProto: meth,
RequestType: msg,
ResponseType: msg,
Bindings: []*descriptor.Binding{
{
HTTPMethod: "POST",
PathTmpl: httprule.Template{
Version: 1,
OpCodes: []int{0, 0},
},
PathParams: []descriptor.Parameter{
{
FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{
{
Name: "nested",
Target: nestedField,
},
{
Name: "int32",
Target: intField,
},
}),
Target: intField,
},
},
Body: &descriptor.Body{
FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{
{
Name: "nested",
Target: nestedField,
},
{
Name: "bool",
Target: boolField,
},
}),
},
},
},
},
},
},
},
}
got, err := applyTemplate(param{File: crossLinkFixture(&file)})
if err != nil {
t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err)
return
}
if want := spec.sigWant; !strings.Contains(got, want) {
t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
}
if want := `marshaler.NewDecoder(req.Body)`; !strings.Contains(got, want) {
t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
}
if want := `func RegisterExampleServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error {`; !strings.Contains(got, want) {
t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
}
if want := `pattern_ExampleService_Echo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{0, 0}, []string(nil), ""))`; !strings.Contains(got, want) {
t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
}
}
}

View File

@@ -0,0 +1,117 @@
package httprule
import (
"github.com/grpc-ecosystem/grpc-gateway/utilities"
)
const (
opcodeVersion = 1
)
// Template is a compiled representation of path templates.
type Template struct {
// Version is the version number of the format.
Version int
// OpCodes is a sequence of operations.
OpCodes []int
// Pool is a constant pool
Pool []string
// Verb is a VERB part in the template.
Verb string
// Fields is a list of field paths bound in this template.
Fields []string
// Original template (example: /v1/a_bit_of_everything)
Template string
}
// Compiler compiles utilities representation of path templates into marshallable operations.
// They can be unmarshalled by runtime.NewPattern.
type Compiler interface {
Compile() Template
}
type op struct {
// code is the opcode of the operation
code utilities.OpCode
// str is a string operand of the code.
// num is ignored if str is not empty.
str string
// num is a numeric operand of the code.
num int
}
func (w wildcard) compile() []op {
return []op{
{code: utilities.OpPush},
}
}
func (w deepWildcard) compile() []op {
return []op{
{code: utilities.OpPushM},
}
}
func (l literal) compile() []op {
return []op{
{
code: utilities.OpLitPush,
str: string(l),
},
}
}
func (v variable) compile() []op {
var ops []op
for _, s := range v.segments {
ops = append(ops, s.compile()...)
}
ops = append(ops, op{
code: utilities.OpConcatN,
num: len(v.segments),
}, op{
code: utilities.OpCapture,
str: v.path,
})
return ops
}
func (t template) Compile() Template {
var rawOps []op
for _, s := range t.segments {
rawOps = append(rawOps, s.compile()...)
}
var (
ops []int
pool []string
fields []string
)
consts := make(map[string]int)
for _, op := range rawOps {
ops = append(ops, int(op.code))
if op.str == "" {
ops = append(ops, op.num)
} else {
if _, ok := consts[op.str]; !ok {
consts[op.str] = len(pool)
pool = append(pool, op.str)
}
ops = append(ops, consts[op.str])
}
if op.code == utilities.OpCapture {
fields = append(fields, op.str)
}
}
return Template{
Version: opcodeVersion,
OpCodes: ops,
Pool: pool,
Verb: t.verb,
Fields: fields,
Template: t.template,
}
}

View File

@@ -0,0 +1,122 @@
package httprule
import (
"reflect"
"testing"
"github.com/grpc-ecosystem/grpc-gateway/utilities"
)
const (
operandFiller = 0
)
func TestCompile(t *testing.T) {
for _, spec := range []struct {
segs []segment
verb string
ops []int
pool []string
fields []string
}{
{},
{
segs: []segment{
wildcard{},
},
ops: []int{int(utilities.OpPush), operandFiller},
},
{
segs: []segment{
deepWildcard{},
},
ops: []int{int(utilities.OpPushM), operandFiller},
},
{
segs: []segment{
literal("v1"),
},
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"v1"},
},
{
segs: []segment{
literal("v1"),
},
verb: "LOCK",
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"v1"},
},
{
segs: []segment{
variable{
path: "name.nested",
segments: []segment{
wildcard{},
},
},
},
ops: []int{
int(utilities.OpPush), operandFiller,
int(utilities.OpConcatN), 1,
int(utilities.OpCapture), 0,
},
pool: []string{"name.nested"},
fields: []string{"name.nested"},
},
{
segs: []segment{
literal("obj"),
variable{
path: "name.nested",
segments: []segment{
literal("a"),
wildcard{},
literal("b"),
},
},
variable{
path: "obj",
segments: []segment{
deepWildcard{},
},
},
},
ops: []int{
int(utilities.OpLitPush), 0,
int(utilities.OpLitPush), 1,
int(utilities.OpPush), operandFiller,
int(utilities.OpLitPush), 2,
int(utilities.OpConcatN), 3,
int(utilities.OpCapture), 3,
int(utilities.OpPushM), operandFiller,
int(utilities.OpConcatN), 1,
int(utilities.OpCapture), 0,
},
pool: []string{"obj", "a", "b", "name.nested"},
fields: []string{"name.nested", "obj"},
},
} {
tmpl := template{
segments: spec.segs,
verb: spec.verb,
}
compiled := tmpl.Compile()
if got, want := compiled.Version, opcodeVersion; got != want {
t.Errorf("tmpl.Compile().Version = %d; want %d; segs=%#v, verb=%q", got, want, spec.segs, spec.verb)
}
if got, want := compiled.OpCodes, spec.ops; !reflect.DeepEqual(got, want) {
t.Errorf("tmpl.Compile().OpCodes = %v; want %v; segs=%#v, verb=%q", got, want, spec.segs, spec.verb)
}
if got, want := compiled.Pool, spec.pool; !reflect.DeepEqual(got, want) {
t.Errorf("tmpl.Compile().Pool = %q; want %q; segs=%#v, verb=%q", got, want, spec.segs, spec.verb)
}
if got, want := compiled.Verb, spec.verb; got != want {
t.Errorf("tmpl.Compile().Verb = %q; want %q; segs=%#v, verb=%q", got, want, spec.segs, spec.verb)
}
if got, want := compiled.Fields, spec.fields; !reflect.DeepEqual(got, want) {
t.Errorf("tmpl.Compile().Fields = %q; want %q; segs=%#v, verb=%q", got, want, spec.segs, spec.verb)
}
}
}

View File

@@ -0,0 +1,351 @@
package httprule
import (
"fmt"
"strings"
"github.com/golang/glog"
)
// InvalidTemplateError indicates that the path template is not valid.
type InvalidTemplateError struct {
tmpl string
msg string
}
func (e InvalidTemplateError) Error() string {
return fmt.Sprintf("%s: %s", e.msg, e.tmpl)
}
// Parse parses the string representation of path template
func Parse(tmpl string) (Compiler, error) {
if !strings.HasPrefix(tmpl, "/") {
return template{}, InvalidTemplateError{tmpl: tmpl, msg: "no leading /"}
}
tokens, verb := tokenize(tmpl[1:])
p := parser{tokens: tokens}
segs, err := p.topLevelSegments()
if err != nil {
return template{}, InvalidTemplateError{tmpl: tmpl, msg: err.Error()}
}
return template{
segments: segs,
verb: verb,
template: tmpl,
}, nil
}
func tokenize(path string) (tokens []string, verb string) {
if path == "" {
return []string{eof}, ""
}
const (
init = iota
field
nested
)
var (
st = init
)
for path != "" {
var idx int
switch st {
case init:
idx = strings.IndexAny(path, "/{")
case field:
idx = strings.IndexAny(path, ".=}")
case nested:
idx = strings.IndexAny(path, "/}")
}
if idx < 0 {
tokens = append(tokens, path)
break
}
switch r := path[idx]; r {
case '/', '.':
case '{':
st = field
case '=':
st = nested
case '}':
st = init
}
if idx == 0 {
tokens = append(tokens, path[idx:idx+1])
} else {
tokens = append(tokens, path[:idx], path[idx:idx+1])
}
path = path[idx+1:]
}
l := len(tokens)
t := tokens[l-1]
if idx := strings.LastIndex(t, ":"); idx == 0 {
tokens, verb = tokens[:l-1], t[1:]
} else if idx > 0 {
tokens[l-1], verb = t[:idx], t[idx+1:]
}
tokens = append(tokens, eof)
return tokens, verb
}
// parser is a parser of the template syntax defined in github.com/googleapis/googleapis/google/api/http.proto.
type parser struct {
tokens []string
accepted []string
}
// topLevelSegments is the target of this parser.
func (p *parser) topLevelSegments() ([]segment, error) {
glog.V(1).Infof("Parsing %q", p.tokens)
segs, err := p.segments()
if err != nil {
return nil, err
}
glog.V(2).Infof("accept segments: %q; %q", p.accepted, p.tokens)
if _, err := p.accept(typeEOF); err != nil {
return nil, fmt.Errorf("unexpected token %q after segments %q", p.tokens[0], strings.Join(p.accepted, ""))
}
glog.V(2).Infof("accept eof: %q; %q", p.accepted, p.tokens)
return segs, nil
}
func (p *parser) segments() ([]segment, error) {
s, err := p.segment()
if err != nil {
return nil, err
}
glog.V(2).Infof("accept segment: %q; %q", p.accepted, p.tokens)
segs := []segment{s}
for {
if _, err := p.accept("/"); err != nil {
return segs, nil
}
s, err := p.segment()
if err != nil {
return segs, err
}
segs = append(segs, s)
glog.V(2).Infof("accept segment: %q; %q", p.accepted, p.tokens)
}
}
func (p *parser) segment() (segment, error) {
if _, err := p.accept("*"); err == nil {
return wildcard{}, nil
}
if _, err := p.accept("**"); err == nil {
return deepWildcard{}, nil
}
if l, err := p.literal(); err == nil {
return l, nil
}
v, err := p.variable()
if err != nil {
return nil, fmt.Errorf("segment neither wildcards, literal or variable: %v", err)
}
return v, err
}
func (p *parser) literal() (segment, error) {
lit, err := p.accept(typeLiteral)
if err != nil {
return nil, err
}
return literal(lit), nil
}
func (p *parser) variable() (segment, error) {
if _, err := p.accept("{"); err != nil {
return nil, err
}
path, err := p.fieldPath()
if err != nil {
return nil, err
}
var segs []segment
if _, err := p.accept("="); err == nil {
segs, err = p.segments()
if err != nil {
return nil, fmt.Errorf("invalid segment in variable %q: %v", path, err)
}
} else {
segs = []segment{wildcard{}}
}
if _, err := p.accept("}"); err != nil {
return nil, fmt.Errorf("unterminated variable segment: %s", path)
}
return variable{
path: path,
segments: segs,
}, nil
}
func (p *parser) fieldPath() (string, error) {
c, err := p.accept(typeIdent)
if err != nil {
return "", err
}
components := []string{c}
for {
if _, err = p.accept("."); err != nil {
return strings.Join(components, "."), nil
}
c, err := p.accept(typeIdent)
if err != nil {
return "", fmt.Errorf("invalid field path component: %v", err)
}
components = append(components, c)
}
}
// A termType is a type of terminal symbols.
type termType string
// These constants define some of valid values of termType.
// They improve readability of parse functions.
//
// You can also use "/", "*", "**", "." or "=" as valid values.
const (
typeIdent = termType("ident")
typeLiteral = termType("literal")
typeEOF = termType("$")
)
const (
// eof is the terminal symbol which always appears at the end of token sequence.
eof = "\u0000"
)
// accept tries to accept a token in "p".
// This function consumes a token and returns it if it matches to the specified "term".
// If it doesn't match, the function does not consume any tokens and return an error.
func (p *parser) accept(term termType) (string, error) {
t := p.tokens[0]
switch term {
case "/", "*", "**", ".", "=", "{", "}":
if t != string(term) {
return "", fmt.Errorf("expected %q but got %q", term, t)
}
case typeEOF:
if t != eof {
return "", fmt.Errorf("expected EOF but got %q", t)
}
case typeIdent:
if err := expectIdent(t); err != nil {
return "", err
}
case typeLiteral:
if err := expectPChars(t); err != nil {
return "", err
}
default:
return "", fmt.Errorf("unknown termType %q", term)
}
p.tokens = p.tokens[1:]
p.accepted = append(p.accepted, t)
return t, nil
}
// expectPChars determines if "t" consists of only pchars defined in RFC3986.
//
// https://www.ietf.org/rfc/rfc3986.txt, P.49
// pchar = unreserved / pct-encoded / sub-delims / ":" / "@"
// unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~"
// sub-delims = "!" / "$" / "&" / "'" / "(" / ")"
// / "*" / "+" / "," / ";" / "="
// pct-encoded = "%" HEXDIG HEXDIG
func expectPChars(t string) error {
const (
init = iota
pct1
pct2
)
st := init
for _, r := range t {
if st != init {
if !isHexDigit(r) {
return fmt.Errorf("invalid hexdigit: %c(%U)", r, r)
}
switch st {
case pct1:
st = pct2
case pct2:
st = init
}
continue
}
// unreserved
switch {
case 'A' <= r && r <= 'Z':
continue
case 'a' <= r && r <= 'z':
continue
case '0' <= r && r <= '9':
continue
}
switch r {
case '-', '.', '_', '~':
// unreserved
case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=':
// sub-delims
case ':', '@':
// rest of pchar
case '%':
// pct-encoded
st = pct1
default:
return fmt.Errorf("invalid character in path segment: %q(%U)", r, r)
}
}
if st != init {
return fmt.Errorf("invalid percent-encoding in %q", t)
}
return nil
}
// expectIdent determines if "ident" is a valid identifier in .proto schema ([[:alpha:]_][[:alphanum:]_]*).
func expectIdent(ident string) error {
if ident == "" {
return fmt.Errorf("empty identifier")
}
for pos, r := range ident {
switch {
case '0' <= r && r <= '9':
if pos == 0 {
return fmt.Errorf("identifier starting with digit: %s", ident)
}
continue
case 'A' <= r && r <= 'Z':
continue
case 'a' <= r && r <= 'z':
continue
case r == '_':
continue
default:
return fmt.Errorf("invalid character %q(%U) in identifier: %s", r, r, ident)
}
}
return nil
}
func isHexDigit(r rune) bool {
switch {
case '0' <= r && r <= '9':
return true
case 'A' <= r && r <= 'F':
return true
case 'a' <= r && r <= 'f':
return true
}
return false
}

View File

@@ -0,0 +1,313 @@
package httprule
import (
"flag"
"fmt"
"reflect"
"testing"
"github.com/golang/glog"
)
func TestTokenize(t *testing.T) {
for _, spec := range []struct {
src string
tokens []string
}{
{
src: "",
tokens: []string{eof},
},
{
src: "v1",
tokens: []string{"v1", eof},
},
{
src: "v1/b",
tokens: []string{"v1", "/", "b", eof},
},
{
src: "v1/endpoint/*",
tokens: []string{"v1", "/", "endpoint", "/", "*", eof},
},
{
src: "v1/endpoint/**",
tokens: []string{"v1", "/", "endpoint", "/", "**", eof},
},
{
src: "v1/b/{bucket_name=*}",
tokens: []string{
"v1", "/",
"b", "/",
"{", "bucket_name", "=", "*", "}",
eof,
},
},
{
src: "v1/b/{bucket_name=buckets/*}",
tokens: []string{
"v1", "/",
"b", "/",
"{", "bucket_name", "=", "buckets", "/", "*", "}",
eof,
},
},
{
src: "v1/b/{bucket_name=buckets/*}/o",
tokens: []string{
"v1", "/",
"b", "/",
"{", "bucket_name", "=", "buckets", "/", "*", "}", "/",
"o",
eof,
},
},
{
src: "v1/b/{bucket_name=buckets/*}/o/{name}",
tokens: []string{
"v1", "/",
"b", "/",
"{", "bucket_name", "=", "buckets", "/", "*", "}", "/",
"o", "/", "{", "name", "}",
eof,
},
},
{
src: "v1/a=b&c=d;e=f:g/endpoint.rdf",
tokens: []string{
"v1", "/",
"a=b&c=d;e=f:g", "/",
"endpoint.rdf",
eof,
},
},
} {
tokens, verb := tokenize(spec.src)
if got, want := tokens, spec.tokens; !reflect.DeepEqual(got, want) {
t.Errorf("tokenize(%q) = %q, _; want %q, _", spec.src, got, want)
}
if got, want := verb, ""; got != want {
t.Errorf("tokenize(%q) = _, %q; want _, %q", spec.src, got, want)
}
src := fmt.Sprintf("%s:%s", spec.src, "LOCK")
tokens, verb = tokenize(src)
if got, want := tokens, spec.tokens; !reflect.DeepEqual(got, want) {
t.Errorf("tokenize(%q) = %q, _; want %q, _", src, got, want)
}
if got, want := verb, "LOCK"; got != want {
t.Errorf("tokenize(%q) = _, %q; want _, %q", src, got, want)
}
}
}
func TestParseSegments(t *testing.T) {
flag.Set("v", "3")
for _, spec := range []struct {
tokens []string
want []segment
}{
{
tokens: []string{"v1", eof},
want: []segment{
literal("v1"),
},
},
{
tokens: []string{"-._~!$&'()*+,;=:@", eof},
want: []segment{
literal("-._~!$&'()*+,;=:@"),
},
},
{
tokens: []string{"%e7%ac%ac%e4%b8%80%e7%89%88", eof},
want: []segment{
literal("%e7%ac%ac%e4%b8%80%e7%89%88"),
},
},
{
tokens: []string{"v1", "/", "*", eof},
want: []segment{
literal("v1"),
wildcard{},
},
},
{
tokens: []string{"v1", "/", "**", eof},
want: []segment{
literal("v1"),
deepWildcard{},
},
},
{
tokens: []string{"{", "name", "}", eof},
want: []segment{
variable{
path: "name",
segments: []segment{
wildcard{},
},
},
},
},
{
tokens: []string{"{", "name", "=", "*", "}", eof},
want: []segment{
variable{
path: "name",
segments: []segment{
wildcard{},
},
},
},
},
{
tokens: []string{"{", "field", ".", "nested", ".", "nested2", "=", "*", "}", eof},
want: []segment{
variable{
path: "field.nested.nested2",
segments: []segment{
wildcard{},
},
},
},
},
{
tokens: []string{"{", "name", "=", "a", "/", "b", "/", "*", "}", eof},
want: []segment{
variable{
path: "name",
segments: []segment{
literal("a"),
literal("b"),
wildcard{},
},
},
},
},
{
tokens: []string{
"v1", "/",
"{",
"name", ".", "nested", ".", "nested2",
"=",
"a", "/", "b", "/", "*",
"}", "/",
"o", "/",
"{",
"another_name",
"=",
"a", "/", "b", "/", "*", "/", "c",
"}", "/",
"**",
eof},
want: []segment{
literal("v1"),
variable{
path: "name.nested.nested2",
segments: []segment{
literal("a"),
literal("b"),
wildcard{},
},
},
literal("o"),
variable{
path: "another_name",
segments: []segment{
literal("a"),
literal("b"),
wildcard{},
literal("c"),
},
},
deepWildcard{},
},
},
} {
p := parser{tokens: spec.tokens}
segs, err := p.topLevelSegments()
if err != nil {
t.Errorf("parser{%q}.segments() failed with %v; want success", spec.tokens, err)
continue
}
if got, want := segs, spec.want; !reflect.DeepEqual(got, want) {
t.Errorf("parser{%q}.segments() = %#v; want %#v", spec.tokens, got, want)
}
if got := p.tokens; len(got) > 0 {
t.Errorf("p.tokens = %q; want []; spec.tokens=%q", got, spec.tokens)
}
}
}
func TestParseSegmentsWithErrors(t *testing.T) {
flag.Set("v", "3")
for _, spec := range []struct {
tokens []string
}{
{
// double slash
tokens: []string{"/", eof},
},
{
// invalid literal
tokens: []string{"a?b", eof},
},
{
// invalid percent-encoding
tokens: []string{"%", eof},
},
{
// invalid percent-encoding
tokens: []string{"%2", eof},
},
{
// invalid percent-encoding
tokens: []string{"a%2z", eof},
},
{
// empty segments
tokens: []string{eof},
},
{
// unterminated variable
tokens: []string{"{", "name", eof},
},
{
// unterminated variable
tokens: []string{"{", "name", "=", eof},
},
{
// unterminated variable
tokens: []string{"{", "name", "=", "*", eof},
},
{
// empty component in field path
tokens: []string{"{", "name", ".", "}", eof},
},
{
// empty component in field path
tokens: []string{"{", "name", ".", ".", "nested", "}", eof},
},
{
// invalid character in identifier
tokens: []string{"{", "field-name", "}", eof},
},
{
// no slash between segments
tokens: []string{"v1", "endpoint", eof},
},
{
// no slash between segments
tokens: []string{"v1", "{", "name", "}", eof},
},
} {
p := parser{tokens: spec.tokens}
segs, err := p.topLevelSegments()
if err == nil {
t.Errorf("parser{%q}.segments() succeeded; want InvalidTemplateError; accepted %#v", spec.tokens, segs)
continue
}
glog.V(1).Info(err)
}
}

View File

@@ -0,0 +1,60 @@
package httprule
import (
"fmt"
"strings"
)
type template struct {
segments []segment
verb string
template string
}
type segment interface {
fmt.Stringer
compile() (ops []op)
}
type wildcard struct{}
type deepWildcard struct{}
type literal string
type variable struct {
path string
segments []segment
}
func (wildcard) String() string {
return "*"
}
func (deepWildcard) String() string {
return "**"
}
func (l literal) String() string {
return string(l)
}
func (v variable) String() string {
var segs []string
for _, s := range v.segments {
segs = append(segs, s.String())
}
return fmt.Sprintf("{%s=%s}", v.path, strings.Join(segs, "/"))
}
func (t template) String() string {
var segs []string
for _, s := range t.segments {
segs = append(segs, s.String())
}
str := strings.Join(segs, "/")
if t.verb != "" {
str = fmt.Sprintf("%s:%s", str, t.verb)
}
return "/" + str
}

View File

@@ -0,0 +1,91 @@
package httprule
import (
"fmt"
"testing"
)
func TestTemplateStringer(t *testing.T) {
for _, spec := range []struct {
segs []segment
want string
}{
{
segs: []segment{
literal("v1"),
},
want: "/v1",
},
{
segs: []segment{
wildcard{},
},
want: "/*",
},
{
segs: []segment{
deepWildcard{},
},
want: "/**",
},
{
segs: []segment{
variable{
path: "name",
segments: []segment{
literal("a"),
},
},
},
want: "/{name=a}",
},
{
segs: []segment{
variable{
path: "name",
segments: []segment{
literal("a"),
wildcard{},
literal("b"),
},
},
},
want: "/{name=a/*/b}",
},
{
segs: []segment{
literal("v1"),
variable{
path: "name",
segments: []segment{
literal("a"),
wildcard{},
literal("b"),
},
},
literal("c"),
variable{
path: "field.nested",
segments: []segment{
wildcard{},
literal("d"),
},
},
wildcard{},
literal("e"),
deepWildcard{},
},
want: "/v1/{name=a/*/b}/c/{field.nested=*/d}/*/e/**",
},
} {
tmpl := template{segments: spec.segs}
if got, want := tmpl.String(), spec.want; got != want {
t.Errorf("%#v.String() = %q; want %q", tmpl, got, want)
}
tmpl.verb = "LOCK"
if got, want := tmpl.String(), fmt.Sprintf("%s:LOCK", spec.want); got != want {
t.Errorf("%#v.String() = %q; want %q", tmpl, got, want)
}
}
}

View File

@@ -0,0 +1,119 @@
// Command protoc-gen-grpc-gateway is a plugin for Google protocol buffer
// compiler to generate a reverse-proxy, which converts incoming RESTful
// HTTP/1 requests gRPC invocation.
// You rarely need to run this program directly. Instead, put this program
// into your $PATH with a name "protoc-gen-grpc-gateway" and run
// protoc --grpc-gateway_out=output_directory path/to/input.proto
//
// See README.md for more details.
package main
import (
"flag"
"io"
"io/ioutil"
"os"
"strings"
"github.com/golang/glog"
"github.com/golang/protobuf/proto"
plugin "github.com/golang/protobuf/protoc-gen-go/plugin"
"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/gengateway"
)
var (
importPrefix = flag.String("import_prefix", "", "prefix to be added to go package paths for imported proto files")
useRequestContext = flag.Bool("request_context", false, "determine whether to use http.Request's context or not")
)
func parseReq(r io.Reader) (*plugin.CodeGeneratorRequest, error) {
glog.V(1).Info("Parsing code generator request")
input, err := ioutil.ReadAll(r)
if err != nil {
glog.Errorf("Failed to read code generator request: %v", err)
return nil, err
}
req := new(plugin.CodeGeneratorRequest)
if err = proto.Unmarshal(input, req); err != nil {
glog.Errorf("Failed to unmarshal code generator request: %v", err)
return nil, err
}
glog.V(1).Info("Parsed code generator request")
return req, nil
}
func main() {
flag.Parse()
defer glog.Flush()
reg := descriptor.NewRegistry()
glog.V(1).Info("Processing code generator request")
req, err := parseReq(os.Stdin)
if err != nil {
glog.Fatal(err)
}
if req.Parameter != nil {
for _, p := range strings.Split(req.GetParameter(), ",") {
spec := strings.SplitN(p, "=", 2)
if len(spec) == 1 {
if err := flag.CommandLine.Set(spec[0], ""); err != nil {
glog.Fatalf("Cannot set flag %s", p)
}
continue
}
name, value := spec[0], spec[1]
if strings.HasPrefix(name, "M") {
reg.AddPkgMap(name[1:], value)
continue
}
if err := flag.CommandLine.Set(name, value); err != nil {
glog.Fatalf("Cannot set flag %s", p)
}
}
}
g := gengateway.New(reg, *useRequestContext)
reg.SetPrefix(*importPrefix)
if err := reg.Load(req); err != nil {
emitError(err)
return
}
var targets []*descriptor.File
for _, target := range req.FileToGenerate {
f, err := reg.LookupFile(target)
if err != nil {
glog.Fatal(err)
}
targets = append(targets, f)
}
out, err := g.Generate(targets)
glog.V(1).Info("Processed code generator request")
if err != nil {
emitError(err)
return
}
emitFiles(out)
}
func emitFiles(out []*plugin.CodeGeneratorResponse_File) {
emitResp(&plugin.CodeGeneratorResponse{File: out})
}
func emitError(err error) {
emitResp(&plugin.CodeGeneratorResponse{Error: proto.String(err.Error())})
}
func emitResp(resp *plugin.CodeGeneratorResponse) {
buf, err := proto.Marshal(resp)
if err != nil {
glog.Fatal(err)
}
if _, err := os.Stdout.Write(buf); err != nil {
glog.Fatal(err)
}
}