diff --git a/main.go b/main.go index c816eae..bca7cf4 100644 --- a/main.go +++ b/main.go @@ -113,6 +113,11 @@ func (g *Generator) Generate(plugin *protogen.Plugin) error { } + if err = g.writeErrors(plugin); err != nil { + plugin.Error(err) + return err + } + if err = g.astGenerate(plugin); err != nil { plugin.Error(err) return err diff --git a/util.go b/util.go index dee640a..a2e912e 100644 --- a/util.go +++ b/util.go @@ -792,6 +792,35 @@ func (g *Generator) getGoIdentByXref(xref string) (protogen.GoIdent, error) { return protogen.GoIdent{}, fmt.Errorf("not found") } +func (g *Generator) getMessageByXref(xref string) (*protogen.Message, error) { + idx := strings.LastIndex(xref, ".") + pkg := xref[:idx] + msg := xref[idx+1:] + for _, file := range g.plugin.Files { + if strings.Compare(pkg, *(file.Proto.Package)) != 0 { + continue + } + if pmsg, err := getProtoMessage(file.Messages, msg); err == nil { + return pmsg, nil + } + } + return nil, fmt.Errorf("not found") +} + +func getProtoMessage(messages []*protogen.Message, msg string) (*protogen.Message, error) { + for _, message := range messages { + if strings.Compare(msg, message.GoIdent.GoName) == 0 { + return message, nil + } + if len(message.Messages) > 0 { + if pmsg, err := getProtoMessage(message.Messages, msg); err == nil { + return pmsg, nil + } + } + } + return nil, fmt.Errorf("not found") +} + func getGoIdentByMessage(messages []*protogen.Message, msg string) (protogen.GoIdent, error) { for _, message := range messages { if strings.Compare(msg, message.GoIdent.GoName) == 0 { @@ -891,3 +920,101 @@ func getServiceName(s *protogen.Service) string { } return s.GoName + "Service" } + +func (g *Generator) writeErrors(plugin *protogen.Plugin) error { + errorsMap := make(map[string]struct{}) + + for _, file := range plugin.Files { + for _, service := range file.Services { + for _, method := range service.Methods { + if method.Desc.Options() != nil { + if proto.HasExtension(method.Desc.Options(), v2.E_Openapiv2Operation) { + opts := proto.GetExtension(method.Desc.Options(), v2.E_Openapiv2Operation) + if opts != nil { + r := opts.(*v2.Operation) + if r.Responses == nil { + continue + } + + for _, rsp := range r.Responses.ResponseCode { + if schema := rsp.Value.GetJsonReference(); schema != nil { + xref := schema.XRef + if xref[0] == '.' { + xref = xref[1:] + } + errorsMap[xref] = struct{}{} + } + } + } + } + if proto.HasExtension(method.Desc.Options(), v3.E_Openapiv3Operation) { + opts := proto.GetExtension(method.Desc.Options(), v3.E_Openapiv3Operation) + if opts != nil { + r := opts.(*v3.Operation) + if r.Responses == nil { + continue + } + resps := r.Responses.ResponseOrReference + if r.Responses.GetDefault() != nil { + resps = append(resps, &v3.NamedResponseOrReference{Name: "default", Value: r.Responses.GetDefault()}) + } + for _, rsp := range resps { + if schema := rsp.Value.GetReference(); schema != nil { + xref := schema.XRef + if xref[0] == '.' { + xref = xref[1:] + } + errorsMap[xref] = struct{}{} + } + } + } + } + } + } + } + } + + var gfile *protogen.GeneratedFile + if len(errorsMap) > 0 { + gfile = plugin.NewGeneratedFile("micro_errors.pb.go", ".") + var packageName string + + for _, file := range plugin.Files { + if !file.Generate { + continue + } + if len(file.Services) == 0 { + continue + } + packageName = string(file.GoPackageName) + break + } + + gfile.P("// Code generated by protoc-gen-go-micro. DO NOT EDIT.") + gfile.P("// protoc-gen-go-micro version: " + versionComment) + gfile.P() + gfile.P("package ", packageName) + gfile.P() + + gfile.Import(protojsonPackage) + + gfile.P("var (") + gfile.P("marshaler = ", protojsonPackage.Ident("MarshalOptions"), "{}") + gfile.P(")") + } + + for xref := range errorsMap { + msg, err := g.getMessageByXref(xref) + if err != nil { + return err + } + + gfile.P(`func (m *`, msg.GoIdent.GoName, `) Error() string {`) + gfile.P(`buf, _ := marshaler.Marshal(m)`) + gfile.P("return string(buf)") + gfile.P(`}`) + // log.Printf("xref %#+v %v\n", msg.GoIdent.GoName, err) + } + + return nil +} diff --git a/variables.go b/variables.go index 2abaf10..f0eb021 100644 --- a/variables.go +++ b/variables.go @@ -20,6 +20,7 @@ var ( microErrorsPackage = protogen.GoImportPath("go.unistack.org/micro/v4/errors") microOptionsPackage = protogen.GoImportPath("go.unistack.org/micro/v4/options") grpcPackage = protogen.GoImportPath("google.golang.org/grpc") + protojsonPackage = protogen.GoImportPath("google.golang.org/protobuf/encoding/protojson") timePackage = protogen.GoImportPath("time") deprecationComment = "// Deprecated: Do not use." versionComment = "v4.0.2"