support message reference

Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>
This commit is contained in:
Василий Толстов 2022-03-19 15:13:34 +03:00
parent 8844f01558
commit 378c00c16d
5 changed files with 80 additions and 38 deletions

View File

@ -40,13 +40,13 @@ func (g *Generator) httpGenerate(component string, plugin *protogen.Plugin, genC
for _, service := range file.Services {
if genClient {
generateServiceClient(gfile, service)
generateServiceClientMethods(gfile, service, true)
g.generateServiceClient(gfile, service)
g.generateServiceClientMethods(gfile, service, true)
}
if genServer {
generateServiceServer(gfile, service)
generateServiceServerMethods(gfile, service)
generateServiceRegister(gfile, service)
g.generateServiceServerMethods(gfile, service)
g.generateServiceRegister(gfile, service)
}
}
}

View File

@ -45,11 +45,13 @@ type Generator struct {
fieldaligment bool
tagPath string
openapiFile string
plugin *protogen.Plugin
}
func (g *Generator) Generate(plugin *protogen.Plugin) error {
var err error
g.plugin = plugin
g.standalone = *flagStandalone
g.debug = *flagDebug
g.components = *flagComponents

View File

@ -35,14 +35,14 @@ func (g *Generator) microGenerate(component string, plugin *protogen.Plugin, gen
}
// generate services
for _, service := range file.Services {
generateServiceEndpoints(gfile, service)
g.generateServiceEndpoints(gfile, service)
if genClient {
generateServiceClientInterface(gfile, service)
generateServiceClientStreamInterface(gfile, service)
g.generateServiceClientInterface(gfile, service)
g.generateServiceClientStreamInterface(gfile, service)
}
if genServer {
generateServiceServerInterface(gfile, service)
generateServiceServerStreamInterface(gfile, service)
g.generateServiceServerInterface(gfile, service)
g.generateServiceServerStreamInterface(gfile, service)
}
}

8
rpc.go
View File

@ -37,13 +37,13 @@ func (g *Generator) rpcGenerate(component string, plugin *protogen.Plugin, genCl
}
for _, service := range file.Services {
if genClient {
generateServiceClient(gfile, service)
generateServiceClientMethods(gfile, service, false)
g.generateServiceClient(gfile, service)
g.generateServiceClientMethods(gfile, service, false)
}
if genServer {
generateServiceServer(gfile, service)
generateServiceServerMethods(gfile, service)
generateServiceRegister(gfile, service)
g.generateServiceServerMethods(gfile, service)
g.generateServiceRegister(gfile, service)
}
}
}

90
util.go
View File

@ -2,6 +2,7 @@ package main
import (
"fmt"
"log"
"net/http"
"strings"
@ -28,7 +29,7 @@ func unexport(s string) string {
return strings.ToLower(s[:1]) + s[1:]
}
func generateServiceClient(gfile *protogen.GeneratedFile, service *protogen.Service) {
func (g *Generator) generateServiceClient(gfile *protogen.GeneratedFile, service *protogen.Service) {
serviceName := service.GoName
// if rule, ok := getMicroApiService(service); ok {
// gfile.P("// client wrappers ", strings.Join(rule.ClientWrappers, ", "))
@ -44,11 +45,11 @@ func generateServiceClient(gfile *protogen.GeneratedFile, service *protogen.Serv
gfile.P()
}
func generateServiceClientMethods(gfile *protogen.GeneratedFile, service *protogen.Service, http bool) {
func (g *Generator) generateServiceClientMethods(gfile *protogen.GeneratedFile, service *protogen.Service, http bool) {
serviceName := service.GoName
for _, method := range service.Methods {
methodName := fmt.Sprintf("%s.%s", serviceName, method.GoName)
generateClientFuncSignature(gfile, serviceName, method)
g.generateClientFuncSignature(gfile, serviceName, method)
if http && method.Desc.Options() != nil {
if proto.HasExtension(method.Desc.Options(), v2.E_Openapiv2Operation) {
@ -61,20 +62,25 @@ func generateServiceClientMethods(gfile *protogen.GeneratedFile, service *protog
gfile.P("errmap := make(map[string]interface{}, ", len(r.Responses.ResponseCode), ")")
for _, rsp := range r.Responses.ResponseCode {
if schema := rsp.Value.GetJsonReference(); schema != nil {
ref := schema.XRef
if strings.HasPrefix(ref, "."+string(service.Desc.ParentFile().Package())+".") {
ref = strings.TrimPrefix(ref, "."+string(service.Desc.ParentFile().Package())+".")
xref := schema.XRef
if strings.HasPrefix(xref, "."+string(service.Desc.ParentFile().Package())+".") {
xref = strings.TrimPrefix(xref, "."+string(service.Desc.ParentFile().Package())+".")
}
if ref[0] == '.' {
ref = ref[1:]
if xref[0] == '.' {
xref = xref[1:]
}
switch ref {
switch xref {
case "micro.codec.Frame":
gfile.P(`errmap["`, rsp.Name, `"] = &`, microCodecPackage.Ident("Frame"), "{}")
case "micro.errors.Error":
gfile.P(`errmap["`, rsp.Name, `"] = &`, microErrorsPackage.Ident("Error"), "{}")
default:
gfile.P(`errmap["`, rsp.Name, `"] = &`, ref, "{}")
ident, err := g.getGoIdentByXref(strings.TrimPrefix(schema.XRef, "."))
if err != nil {
log.Printf("cant find message by ref %s\n", schema.XRef)
continue
}
gfile.P(`errmap["`, rsp.Name, `"] = &`, gfile.QualifiedGoIdent(ident), "{}")
}
}
}
@ -97,20 +103,25 @@ func generateServiceClientMethods(gfile *protogen.GeneratedFile, service *protog
gfile.P("errmap := make(map[string]interface{}, ", len(resps), ")")
for _, rsp := range resps {
if schema := rsp.Value.GetReference(); schema != nil {
ref := schema.XRef
if strings.HasPrefix(ref, "."+string(service.Desc.ParentFile().Package())+".") {
ref = strings.TrimPrefix(ref, "."+string(service.Desc.ParentFile().Package())+".")
xref := schema.XRef
if strings.HasPrefix(xref, "."+string(service.Desc.ParentFile().Package())+".") {
xref = strings.TrimPrefix(xref, "."+string(service.Desc.ParentFile().Package())+".")
}
if ref[0] == '.' {
ref = ref[1:]
if xref[0] == '.' {
xref = xref[1:]
}
switch ref {
switch xref {
case "micro.codec.Frame":
gfile.P(`errmap["`, rsp.Name, `"] = &`, microCodecPackage.Ident("Frame"), "{}")
case "micro.errors.Error":
gfile.P(`errmap["`, rsp.Name, `"] = &`, microErrorsPackage.Ident("Error"), "{}")
default:
gfile.P(`errmap["`, rsp.Name, `"] = &`, ref, "{}")
ident, err := g.getGoIdentByXref(strings.TrimPrefix(schema.XRef, "."))
if err != nil {
log.Printf("cant find message by ref %s\n", schema.XRef)
continue
}
gfile.P(`errmap["`, rsp.Name, `"] = &`, gfile.QualifiedGoIdent(ident), "{}")
}
}
}
@ -300,7 +311,7 @@ func generateServiceServer(gfile *protogen.GeneratedFile, service *protogen.Serv
gfile.P()
}
func generateServiceServerMethods(gfile *protogen.GeneratedFile, service *protogen.Service) {
func (g *Generator) generateServiceServerMethods(gfile *protogen.GeneratedFile, service *protogen.Service) {
serviceName := service.GoName
for _, method := range service.Methods {
generateServerFuncSignature(gfile, serviceName, method, true)
@ -444,7 +455,7 @@ func generateServiceServerMethods(gfile *protogen.GeneratedFile, service *protog
}
}
func generateServiceRegister(gfile *protogen.GeneratedFile, service *protogen.Service) {
func (g *Generator) generateServiceRegister(gfile *protogen.GeneratedFile, service *protogen.Service) {
serviceName := service.GoName
gfile.P("func Register", serviceName, "Server(s ", microServerPackage.Ident("Server"), ", sh ", serviceName, "Server, opts ...", microServerPackage.Ident("HandlerOption"), ") error {")
gfile.P("type ", unexport(serviceName), " interface {")
@ -508,7 +519,7 @@ func generateServerSignature(gfile *protogen.GeneratedFile, serviceName string,
gfile.P(args...)
}
func generateClientFuncSignature(gfile *protogen.GeneratedFile, serviceName string, method *protogen.Method) {
func (g *Generator) generateClientFuncSignature(gfile *protogen.GeneratedFile, serviceName string, method *protogen.Method) {
args := append([]interface{}{},
"func (c *",
unexport(serviceName),
@ -547,7 +558,7 @@ func generateClientSignature(gfile *protogen.GeneratedFile, serviceName string,
gfile.P(args...)
}
func generateServiceClientInterface(gfile *protogen.GeneratedFile, service *protogen.Service) {
func (g *Generator) generateServiceClientInterface(gfile *protogen.GeneratedFile, service *protogen.Service) {
serviceName := service.GoName
gfile.P("type ", serviceName, "Client interface {")
for _, method := range service.Methods {
@ -557,7 +568,7 @@ func generateServiceClientInterface(gfile *protogen.GeneratedFile, service *prot
gfile.P()
}
func generateServiceServerInterface(gfile *protogen.GeneratedFile, service *protogen.Service) {
func (g *Generator) generateServiceServerInterface(gfile *protogen.GeneratedFile, service *protogen.Service) {
serviceName := service.GoName
gfile.P("type ", serviceName, "Server interface {")
for _, method := range service.Methods {
@ -567,7 +578,7 @@ func generateServiceServerInterface(gfile *protogen.GeneratedFile, service *prot
gfile.P()
}
func generateServiceClientStreamInterface(gfile *protogen.GeneratedFile, service *protogen.Service) {
func (g *Generator) generateServiceClientStreamInterface(gfile *protogen.GeneratedFile, service *protogen.Service) {
serviceName := service.GoName
for _, method := range service.Methods {
if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
@ -594,7 +605,7 @@ func generateServiceClientStreamInterface(gfile *protogen.GeneratedFile, service
}
}
func generateServiceServerStreamInterface(gfile *protogen.GeneratedFile, service *protogen.Service) {
func (g *Generator) generateServiceServerStreamInterface(gfile *protogen.GeneratedFile, service *protogen.Service) {
serviceName := service.GoName
for _, method := range service.Methods {
if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
@ -621,7 +632,7 @@ func generateServiceServerStreamInterface(gfile *protogen.GeneratedFile, service
}
}
func generateServiceEndpoints(gfile *protogen.GeneratedFile, service *protogen.Service) {
func (g *Generator) generateServiceEndpoints(gfile *protogen.GeneratedFile, service *protogen.Service) {
serviceName := service.GoName
gfile.P("var (")
gfile.P(serviceName, "Name", "=", `"`, serviceName, `"`)
@ -755,3 +766,32 @@ func generateEndpoint(gfile *protogen.GeneratedFile, serviceName string, methodN
}
gfile.P(`Handler: "rpc",`)
}
func (g *Generator) getGoIdentByXref(xref string) (protogen.GoIdent, error) {
idx := strings.LastIndex(xref, ".")
pkg := xref[:idx]
msg := xref[idx+1:]
for _, file := range g.plugin.Files {
if strings.Compare(pkg, *(file.Proto.Package)) != 0 {
continue
}
if ident, err := getGoIdentByMessage(file.Messages, msg); err == nil {
return ident, nil
}
}
return protogen.GoIdent{}, fmt.Errorf("not found")
}
func getGoIdentByMessage(messages []*protogen.Message, msg string) (protogen.GoIdent, error) {
for _, message := range messages {
if strings.Compare(msg, message.GoIdent.GoName) == 0 {
return message.GoIdent, nil
}
if len(message.Messages) > 0 {
if ident, err := getGoIdentByMessage(message.Messages, msg); err == nil {
return ident, nil
}
}
}
return protogen.GoIdent{}, fmt.Errorf("not found")
}