diff --git a/main.go b/main.go index 7e14fc8..c816eae 100644 --- a/main.go +++ b/main.go @@ -63,8 +63,8 @@ func (g *Generator) Generate(plugin *protogen.Plugin) error { g.reflection = *flagReflection plugin.SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL) - genClient := true - genServer := true + var genClient bool + var genServer bool var genNone bool if strings.Contains(g.components, "server") { diff --git a/util.go b/util.go index 090b8c0..dee640a 100644 --- a/util.go +++ b/util.go @@ -32,7 +32,7 @@ func unexport(s string) string { } func (g *Generator) generateServiceClient(gfile *protogen.GeneratedFile, file *protogen.File, service *protogen.Service) { - serviceName := service.GoName + serviceName := getServiceName(service) // if rule, ok := getMicroApiService(service); ok { // gfile.P("// client wrappers ", strings.Join(rule.ClientWrappers, ", ")) // } @@ -52,7 +52,7 @@ func (g *Generator) generateServiceClient(gfile *protogen.GeneratedFile, file *p } func (g *Generator) generateServiceClientMethods(gfile *protogen.GeneratedFile, service *protogen.Service, component string) { - serviceName := service.GoName + serviceName := getServiceName(service) for _, method := range service.Methods { methodName := fmt.Sprintf("%s.%s", serviceName, method.GoName) if component == "drpc" { @@ -324,7 +324,7 @@ func (g *Generator) generateServiceClientMethods(gfile *protogen.GeneratedFile, } func (g *Generator) generateServiceServer(gfile *protogen.GeneratedFile, file *protogen.File, service *protogen.Service) { - serviceName := service.GoName + serviceName := getServiceName(service) gfile.P("type ", unexport(serviceName), "Server struct {") if g.standalone { gfile.P(file.GoImportPath.Ident(serviceName), "Server") @@ -336,7 +336,7 @@ func (g *Generator) generateServiceServer(gfile *protogen.GeneratedFile, file *p } func (g *Generator) generateServiceServerMethods(gfile *protogen.GeneratedFile, service *protogen.Service) { - serviceName := service.GoName + serviceName := getServiceName(service) for _, method := range service.Methods { generateServerFuncSignature(gfile, serviceName, method, true) if rule, ok := getMicroApiMethod(method); ok { @@ -486,7 +486,7 @@ func (g *Generator) generateServiceServerMethods(gfile *protogen.GeneratedFile, } func (g *Generator) generateServiceRegister(gfile *protogen.GeneratedFile, file *protogen.File, service *protogen.Service, component string) { - serviceName := service.GoName + serviceName := getServiceName(service) if g.standalone { gfile.P("func Register", serviceName, "Server(s ", microServerPackage.Ident("Server"), ", sh ", file.GoImportPath.Ident(serviceName), "Server, opts ...", microOptionsPackage.Ident("Option"), ") error {") } else { @@ -597,7 +597,7 @@ func generateClientSignature(gfile *protogen.GeneratedFile, serviceName string, } func (g *Generator) generateServiceClientInterface(gfile *protogen.GeneratedFile, service *protogen.Service) { - serviceName := service.GoName + serviceName := getServiceName(service) gfile.P("type ", serviceName, "Client interface {") for _, method := range service.Methods { generateClientSignature(gfile, serviceName, method) @@ -607,7 +607,7 @@ func (g *Generator) generateServiceClientInterface(gfile *protogen.GeneratedFile } func (g *Generator) generateServiceServerInterface(gfile *protogen.GeneratedFile, service *protogen.Service) { - serviceName := service.GoName + serviceName := getServiceName(service) gfile.P("type ", serviceName, "Server interface {") for _, method := range service.Methods { generateServerSignature(gfile, serviceName, method, false) @@ -617,7 +617,7 @@ func (g *Generator) generateServiceServerInterface(gfile *protogen.GeneratedFile } func (g *Generator) generateServiceClientStreamInterface(gfile *protogen.GeneratedFile, service *protogen.Service) { - serviceName := service.GoName + serviceName := getServiceName(service) for _, method := range service.Methods { if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { continue @@ -645,7 +645,7 @@ func (g *Generator) generateServiceClientStreamInterface(gfile *protogen.Generat } func (g *Generator) generateServiceServerStreamInterface(gfile *protogen.GeneratedFile, service *protogen.Service) { - serviceName := service.GoName + serviceName := getServiceName(service) for _, method := range service.Methods { if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { continue @@ -807,7 +807,7 @@ func getGoIdentByMessage(messages []*protogen.Message, msg string) (protogen.GoI } func (g *Generator) generateServiceDesc(gfile *protogen.GeneratedFile, file *protogen.File, service *protogen.Service) { - serviceName := service.GoName + serviceName := getServiceName(service) gfile.P("// ", serviceName, "_ServiceDesc", " is the ", grpcPackage.Ident("ServiceDesc"), " for ", serviceName, " service.") gfile.P("// It's only intended for direct use with ", grpcPackage.Ident("RegisterService"), ",") @@ -849,7 +849,7 @@ func (g *Generator) generateServiceDesc(gfile *protogen.GeneratedFile, file *pro } func (g *Generator) generateServiceName(gfile *protogen.GeneratedFile, service *protogen.Service) { - serviceName := service.GoName + serviceName := getServiceName(service) gfile.P("var (") gfile.P(serviceName, "Name", "=", `"`, serviceName, `"`) gfile.P(")") @@ -859,7 +859,7 @@ func (g *Generator) generateServiceEndpoints(gfile *protogen.GeneratedFile, serv if component != "http" { return } - serviceName := service.GoName + serviceName := getServiceName(service) gfile.P("var (") gfile.P(serviceName, "ServerEndpoints = []", microServerHttpPackage.Ident("EndpointMetadata"), "{") @@ -884,3 +884,10 @@ func (g *Generator) generateServiceEndpoints(gfile *protogen.GeneratedFile, serv gfile.P("}") gfile.P(")") } + +func getServiceName(s *protogen.Service) string { + if strings.HasSuffix(s.GoName, "Service") { + return s.GoName + } + return s.GoName + "Service" +}