fix codegen for streams

Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>
This commit is contained in:
2025-01-09 18:00:13 +03:00
parent 4633928da5
commit 750c949723
5 changed files with 39 additions and 51 deletions

50
util.go
View File

@@ -51,14 +51,14 @@ func (g *Generator) generateServiceClient(gfile *protogen.GeneratedFile, file *p
gfile.P()
}
func (g *Generator) generateServiceClientMethods(gfile *protogen.GeneratedFile, service *protogen.Service, component string) {
func (g *Generator) generateServiceClientMethods(gfile *protogen.GeneratedFile, file *protogen.File, service *protogen.Service, component string) {
serviceName := service.GoName
for _, method := range service.Methods {
methodName := fmt.Sprintf("%s.%s", serviceName, method.GoName)
if component == "drpc" {
methodName = fmt.Sprintf("%s.%s", method.Parent.Desc.FullName(), method.Desc.Name())
}
g.generateClientFuncSignature(gfile, serviceName, method)
g.generateClientFuncSignature(gfile, serviceName, file, method)
if component == "http" && method.Desc.Options() != nil {
if proto.HasExtension(method.Desc.Options(), v2.E_Openapiv2Operation) {
@@ -262,7 +262,7 @@ func (g *Generator) generateServiceClientMethods(gfile *protogen.GeneratedFile,
gfile.P("}")
}
if method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
if method.Desc.IsStreamingClient() /*&& !method.Desc.IsStreamingServer()*/ {
gfile.P("func (s *", unexport(serviceName), "Client", method.GoName, ") CloseAndRecv() (*", gfile.QualifiedGoIdent(method.Output.GoIdent), ", error) {")
gfile.P("msg := &", gfile.QualifiedGoIdent(method.Output.GoIdent), "{}")
gfile.P("err := s.RecvMsg(msg)")
@@ -432,7 +432,7 @@ func (g *Generator) generateServiceServerMethods(gfile *protogen.GeneratedFile,
gfile.P("}")
}
if method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
if /*method.Desc.IsStreamingClient() && !*/ method.Desc.IsStreamingServer() {
gfile.P("func (s *", unexport(serviceName), method.GoName, "Stream) SendAndClose(msg *", gfile.QualifiedGoIdent(method.Output.GoIdent), ") error {")
gfile.P("err := s.SendMsg(msg)")
gfile.P("if err == nil {")
@@ -493,23 +493,24 @@ func (g *Generator) generateServiceRegister(gfile *protogen.GeneratedFile, file
gfile.P("func Register", serviceName, "Server(s ", microServerPackage.Ident("Server"), ", sh ", serviceName, "Server, opts ...", microServerPackage.Ident("HandlerOption"), ") error {")
}
gfile.P("type ", unexport(serviceName), " interface {")
var generate bool
for _, method := range service.Methods {
generateServerSignature(gfile, serviceName, method, true)
if endpoints, _ := generateEndpoints(method); endpoints != nil {
generate = true
}
}
gfile.P("}")
gfile.P("type ", serviceName, " struct {")
gfile.P(unexport(serviceName))
gfile.P("}")
gfile.P("h := &", unexport(serviceName), "Server{sh}")
gfile.P("var nopts []", microServerPackage.Ident("HandlerOption"))
if component == "http" {
// if g.standalone {
// gfile.P("nopts = append(nopts, ", microServerHttpPackage.Ident("HandlerEndpoints"), "(", file.GoImportPath.Ident(serviceName), "ServerEndpoints))")
// } else {
gfile.P("nopts = append(nopts, ", microServerHttpPackage.Ident("HandlerEndpoints"), "(", serviceName, "ServerEndpoints))")
// }
if component == "http" && generate {
gfile.P("opts = append(opts, ", microServerHttpPackage.Ident("HandlerEndpoints"), "(", serviceName, "ServerEndpoints))")
}
gfile.P("return s.Handle(s.NewHandler(&", serviceName, "{h}, append(nopts, opts...)...))")
gfile.P("return s.Handle(s.NewHandler(&", serviceName, "{h}, opts...))")
gfile.P("}")
}
@@ -557,7 +558,7 @@ func generateServerSignature(gfile *protogen.GeneratedFile, serviceName string,
gfile.P(args...)
}
func (g *Generator) generateClientFuncSignature(gfile *protogen.GeneratedFile, serviceName string, method *protogen.Method) {
func (g *Generator) generateClientFuncSignature(gfile *protogen.GeneratedFile, serviceName string, file *protogen.File, method *protogen.Method) {
args := append([]interface{}{},
"func (c *",
unexport(serviceName),
@@ -572,7 +573,8 @@ func (g *Generator) generateClientFuncSignature(gfile *protogen.GeneratedFile, s
if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
args = append(args, "*", gfile.QualifiedGoIdent(method.Output.GoIdent))
} else {
args = append(args, gfile.QualifiedGoIdent(protogen.GoIdent{GoName: serviceName + "_" + method.GoName + "Client", GoImportPath: method.Output.GoIdent.GoImportPath}))
// TODO
args = append(args, gfile.QualifiedGoIdent(protogen.GoIdent{GoName: serviceName + "_" + method.GoName + "Client", GoImportPath: file.GoImportPath}))
}
args = append(args, ", error) {")
gfile.P(args...)
@@ -627,7 +629,7 @@ func (g *Generator) generateServiceClientStreamInterface(gfile *protogen.Generat
gfile.P("Context() ", contextPackage.Ident("Context"))
gfile.P("SendMsg(msg interface{}) error")
gfile.P("RecvMsg(msg interface{}) error")
if method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
if method.Desc.IsStreamingClient() /*&& !method.Desc.IsStreamingServer()*/ {
gfile.P("CloseAndRecv() (*", gfile.QualifiedGoIdent(method.Output.GoIdent), ", error)")
gfile.P("CloseSend() error")
}
@@ -655,7 +657,7 @@ func (g *Generator) generateServiceServerStreamInterface(gfile *protogen.Generat
gfile.P("Context() ", contextPackage.Ident("Context"))
gfile.P("SendMsg(msg interface{}) error")
gfile.P("RecvMsg(msg interface{}) error")
if method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
if /*method.Desc.IsStreamingClient() && !*/ method.Desc.IsStreamingServer() {
gfile.P("SendAndClose(msg *", gfile.QualifiedGoIdent(method.Output.GoIdent), ") error")
// gfile.P("CloseSend() error")
}
@@ -888,14 +890,18 @@ func (g *Generator) generateServiceEndpoints(gfile *protogen.GeneratedFile, serv
if component != "http" {
return
}
serviceName := service.GoName
gfile.P("var (")
gfile.P(serviceName, "ServerEndpoints = []", microServerHttpPackage.Ident("EndpointMetadata"), "{")
var generate bool
serviceName := service.GoName
for _, method := range service.Methods {
if proto.HasExtension(method.Desc.Options(), api_options.E_Http) {
if endpoints, streaming := generateEndpoints(method); endpoints != nil {
if !generate {
gfile.P("var (")
gfile.P(serviceName, "ServerEndpoints = []", microServerHttpPackage.Ident("EndpointMetadata"), "{")
generate = true
}
for _, ep := range endpoints {
epath, emethod, ebody := getEndpoint(ep)
gfile.P("{")
@@ -910,8 +916,10 @@ func (g *Generator) generateServiceEndpoints(gfile *protogen.GeneratedFile, serv
}
}
gfile.P("}")
gfile.P(")")
if generate {
gfile.P("}")
gfile.P(")")
}
}
func (g *Generator) writeErrors(plugin *protogen.Plugin) error {