From ae63d44866e00eb06f9449fc7f79b6875716d27e Mon Sep 17 00:00:00 2001 From: Evstigneev Denis Date: Tue, 22 Apr 2025 09:41:22 +0300 Subject: [PATCH] move hooks --- hooks/metadata/metadata.go | 76 ++++++++++++ hooks/recovery/recovery.go | 94 +++++++++++++++ hooks/requestid/requestid.go | 139 +++++++++++++++++++++ hooks/requestid/requestid_test.go | 33 +++++ hooks/validator/validator.go | 194 ++++++++++++++++++++++++++++++ 5 files changed, 536 insertions(+) create mode 100644 hooks/metadata/metadata.go create mode 100644 hooks/recovery/recovery.go create mode 100644 hooks/requestid/requestid.go create mode 100644 hooks/requestid/requestid_test.go create mode 100644 hooks/validator/validator.go diff --git a/hooks/metadata/metadata.go b/hooks/metadata/metadata.go new file mode 100644 index 00000000..99b64c88 --- /dev/null +++ b/hooks/metadata/metadata.go @@ -0,0 +1,76 @@ +package metadata + +import ( + "context" + + "go.unistack.org/micro/v3/client" + "go.unistack.org/micro/v3/metadata" + "go.unistack.org/micro/v3/server" +) + +var DefaultMetadataKeys = []string{"x-request-id"} + +type hook struct { + keys []string +} + +func NewHook(keys ...string) *hook { + return &hook{keys: keys} +} + +func metadataCopy(ctx context.Context, keys []string) context.Context { + if keys == nil { + return ctx + } + if imd, iok := metadata.FromIncomingContext(ctx); iok && imd != nil { + omd, ook := metadata.FromOutgoingContext(ctx) + if !ook || omd == nil { + omd = metadata.New(len(keys)) + } + for _, k := range keys { + if v, ok := imd.Get(k); ok && v != "" { + omd.Set(k, v) + } + } + if !ook { + ctx = metadata.NewOutgoingContext(ctx, omd) + } + } + return ctx +} + +func (w *hook) ClientCall(next client.FuncCall) client.FuncCall { + return func(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { + return next(metadataCopy(ctx, w.keys), req, rsp, opts...) + } +} + +func (w *hook) ClientStream(next client.FuncStream) client.FuncStream { + return func(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) { + return next(metadataCopy(ctx, w.keys), req, opts...) + } +} + +func (w *hook) ClientPublish(next client.FuncPublish) client.FuncPublish { + return func(ctx context.Context, msg client.Message, opts ...client.PublishOption) error { + return next(metadataCopy(ctx, w.keys), msg, opts...) + } +} + +func (w *hook) ClientBatchPublish(next client.FuncBatchPublish) client.FuncBatchPublish { + return func(ctx context.Context, msgs []client.Message, opts ...client.PublishOption) error { + return next(metadataCopy(ctx, w.keys), msgs, opts...) + } +} + +func (w *hook) ServerHandler(next server.FuncHandler) server.FuncHandler { + return func(ctx context.Context, req server.Request, rsp interface{}) error { + return next(metadataCopy(ctx, w.keys), req, rsp) + } +} + +func (w *hook) ServerSubscriber(next server.FuncSubHandler) server.FuncSubHandler { + return func(ctx context.Context, msg server.Message) error { + return next(metadataCopy(ctx, w.keys), msg) + } +} diff --git a/hooks/recovery/recovery.go b/hooks/recovery/recovery.go new file mode 100644 index 00000000..22b2cb71 --- /dev/null +++ b/hooks/recovery/recovery.go @@ -0,0 +1,94 @@ +package recovery + +import ( + "context" + "fmt" + + "go.unistack.org/micro/v3/errors" + "go.unistack.org/micro/v3/server" +) + +func NewOptions(opts ...Option) Options { + options := Options{ + ServerHandlerFn: DefaultServerHandlerFn, + ServerSubscriberFn: DefaultServerSubscriberFn, + } + for _, o := range opts { + o(&options) + } + return options +} + +type Options struct { + ServerHandlerFn func(context.Context, server.Request, interface{}, error) error + ServerSubscriberFn func(context.Context, server.Message, error) error +} + +type Option func(*Options) + +func ServerHandlerFunc(fn func(context.Context, server.Request, interface{}, error) error) Option { + return func(o *Options) { + o.ServerHandlerFn = fn + } +} + +func ServerSubscriberFunc(fn func(context.Context, server.Message, error) error) Option { + return func(o *Options) { + o.ServerSubscriberFn = fn + } +} + +var ( + DefaultServerHandlerFn = func(ctx context.Context, req server.Request, rsp interface{}, err error) error { + return errors.BadRequest("", "%v", err) + } + DefaultServerSubscriberFn = func(ctx context.Context, req server.Message, err error) error { + return errors.BadRequest("", "%v", err) + } +) + +var Hook = NewHook() + +type hook struct { + opts Options +} + +func NewHook(opts ...Option) *hook { + return &hook{opts: NewOptions(opts...)} +} + +func (w *hook) ServerHandler(next server.FuncHandler) server.FuncHandler { + return func(ctx context.Context, req server.Request, rsp interface{}) (err error) { + defer func() { + r := recover() + switch verr := r.(type) { + case nil: + return + case error: + err = w.opts.ServerHandlerFn(ctx, req, rsp, verr) + default: + err = w.opts.ServerHandlerFn(ctx, req, rsp, fmt.Errorf("%v", r)) + } + }() + err = next(ctx, req, rsp) + return err + } +} + +func (w *hook) ServerSubscriber(next server.FuncSubHandler) server.FuncSubHandler { + return func(ctx context.Context, msg server.Message) (err error) { + defer func() { + r := recover() + switch verr := r.(type) { + case nil: + return + case error: + err = w.opts.ServerSubscriberFn(ctx, msg, verr) + default: + err = w.opts.ServerSubscriberFn(ctx, msg, fmt.Errorf("%v", r)) + } + }() + err = next(ctx, msg) + return err + } +} diff --git a/hooks/requestid/requestid.go b/hooks/requestid/requestid.go new file mode 100644 index 00000000..702011f3 --- /dev/null +++ b/hooks/requestid/requestid.go @@ -0,0 +1,139 @@ +package requestid + +import ( + "context" + "net/textproto" + + "go.unistack.org/micro/v3/client" + "go.unistack.org/micro/v3/metadata" + "go.unistack.org/micro/v3/server" + "go.unistack.org/micro/v3/util/id" +) + +type XRequestIDKey struct{} + +// DefaultMetadataKey contains metadata key +var DefaultMetadataKey = textproto.CanonicalMIMEHeaderKey("x-request-id") + +// DefaultMetadataFunc wil be used if user not provide own func to fill metadata +var DefaultMetadataFunc = func(ctx context.Context) (context.Context, error) { + var xid string + + cid, cok := ctx.Value(XRequestIDKey{}).(string) + if cok && cid != "" { + xid = cid + } + + imd, iok := metadata.FromIncomingContext(ctx) + if !iok || imd == nil { + imd = metadata.New(1) + ctx = metadata.NewIncomingContext(ctx, imd) + } + + omd, ook := metadata.FromOutgoingContext(ctx) + if !ook || omd == nil { + omd = metadata.New(1) + ctx = metadata.NewOutgoingContext(ctx, omd) + } + + if xid == "" { + var id string + if id, iok = imd.Get(DefaultMetadataKey); iok && id != "" { + xid = id + } + if id, ook = omd.Get(DefaultMetadataKey); ook && id != "" { + xid = id + } + } + + if xid == "" { + var err error + xid, err = id.New() + if err != nil { + return ctx, err + } + } + + if !cok { + ctx = context.WithValue(ctx, XRequestIDKey{}, xid) + } + + if !iok { + imd.Set(DefaultMetadataKey, xid) + } + + if !ook { + omd.Set(DefaultMetadataKey, xid) + } + + return ctx, nil +} + +type hook struct{} + +func NewHook() *hook { + return &hook{} +} + +func (w *hook) ServerSubscriber(next server.FuncSubHandler) server.FuncSubHandler { + return func(ctx context.Context, msg server.Message) error { + var err error + if xid, ok := msg.Header()[DefaultMetadataKey]; ok { + ctx = context.WithValue(ctx, XRequestIDKey{}, xid) + } + if ctx, err = DefaultMetadataFunc(ctx); err != nil { + return err + } + return next(ctx, msg) + } +} + +func (w *hook) ServerHandler(next server.FuncHandler) server.FuncHandler { + return func(ctx context.Context, req server.Request, rsp interface{}) error { + var err error + if ctx, err = DefaultMetadataFunc(ctx); err != nil { + return err + } + return next(ctx, req, rsp) + } +} + +func (w *hook) ClientBatchPublish(next client.FuncBatchPublish) client.FuncBatchPublish { + return func(ctx context.Context, msgs []client.Message, opts ...client.PublishOption) error { + var err error + if ctx, err = DefaultMetadataFunc(ctx); err != nil { + return err + } + return next(ctx, msgs, opts...) + } +} + +func (w *hook) ClientPublish(next client.FuncPublish) client.FuncPublish { + return func(ctx context.Context, msg client.Message, opts ...client.PublishOption) error { + var err error + if ctx, err = DefaultMetadataFunc(ctx); err != nil { + return err + } + return next(ctx, msg, opts...) + } +} + +func (w *hook) ClientCall(next client.FuncCall) client.FuncCall { + return func(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { + var err error + if ctx, err = DefaultMetadataFunc(ctx); err != nil { + return err + } + return next(ctx, req, rsp, opts...) + } +} + +func (w *hook) ClientStream(next client.FuncStream) client.FuncStream { + return func(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) { + var err error + if ctx, err = DefaultMetadataFunc(ctx); err != nil { + return nil, err + } + return next(ctx, req, opts...) + } +} diff --git a/hooks/requestid/requestid_test.go b/hooks/requestid/requestid_test.go new file mode 100644 index 00000000..ce7fad77 --- /dev/null +++ b/hooks/requestid/requestid_test.go @@ -0,0 +1,33 @@ +package requestid + +import ( + "context" + "testing" + + "go.unistack.org/micro/v3/metadata" +) + +func TestDefaultMetadataFunc(t *testing.T) { + ctx := context.TODO() + + nctx, err := DefaultMetadataFunc(ctx) + if err != nil { + t.Fatalf("%v", err) + } + + imd, ok := metadata.FromIncomingContext(nctx) + if !ok { + t.Fatalf("md missing in incoming context") + } + omd, ok := metadata.FromOutgoingContext(nctx) + if !ok { + t.Fatalf("md missing in outgoing context") + } + + _, iok := imd.Get(DefaultMetadataKey) + _, ook := omd.Get(DefaultMetadataKey) + + if !iok || !ook { + t.Fatalf("missing metadata key value") + } +} diff --git a/hooks/validator/validator.go b/hooks/validator/validator.go new file mode 100644 index 00000000..75381cd7 --- /dev/null +++ b/hooks/validator/validator.go @@ -0,0 +1,194 @@ +package validator + +import ( + "context" + + "go.unistack.org/micro/v3/client" + "go.unistack.org/micro/v3/errors" + "go.unistack.org/micro/v3/server" +) + +var ( + DefaultClientErrorFunc = func(req client.Request, rsp interface{}, err error) error { + if rsp != nil { + return errors.BadGateway(req.Service(), "%v", err) + } + return errors.BadRequest(req.Service(), "%v", err) + } + + DefaultServerErrorFunc = func(req server.Request, rsp interface{}, err error) error { + if rsp != nil { + return errors.BadGateway(req.Service(), "%v", err) + } + return errors.BadRequest(req.Service(), "%v", err) + } + + DefaultPublishErrorFunc = func(msg client.Message, err error) error { + return errors.BadRequest(msg.Topic(), "%v", err) + } + + DefaultSubscribeErrorFunc = func(msg server.Message, err error) error { + return errors.BadRequest(msg.Topic(), "%v", err) + } +) + +type ( + ClientErrorFunc func(client.Request, interface{}, error) error + ServerErrorFunc func(server.Request, interface{}, error) error + PublishErrorFunc func(client.Message, error) error + SubscribeErrorFunc func(server.Message, error) error +) + +// Options struct holds wrapper options +type Options struct { + ClientErrorFn ClientErrorFunc + ServerErrorFn ServerErrorFunc + PublishErrorFn PublishErrorFunc + SubscribeErrorFn SubscribeErrorFunc + ClientValidateResponse bool + ServerValidateResponse bool +} + +// Option func signature +type Option func(*Options) + +func ClientValidateResponse(b bool) Option { + return func(o *Options) { + o.ClientValidateResponse = b + } +} + +func ServerValidateResponse(b bool) Option { + return func(o *Options) { + o.ClientValidateResponse = b + } +} + +func ClientReqErrorFn(fn ClientErrorFunc) Option { + return func(o *Options) { + o.ClientErrorFn = fn + } +} + +func ServerErrorFn(fn ServerErrorFunc) Option { + return func(o *Options) { + o.ServerErrorFn = fn + } +} + +func PublishErrorFn(fn PublishErrorFunc) Option { + return func(o *Options) { + o.PublishErrorFn = fn + } +} + +func SubscribeErrorFn(fn SubscribeErrorFunc) Option { + return func(o *Options) { + o.SubscribeErrorFn = fn + } +} + +func NewOptions(opts ...Option) Options { + options := Options{ + ClientErrorFn: DefaultClientErrorFunc, + ServerErrorFn: DefaultServerErrorFunc, + PublishErrorFn: DefaultPublishErrorFunc, + SubscribeErrorFn: DefaultSubscribeErrorFunc, + } + for _, o := range opts { + o(&options) + } + return options +} + +func NewHook(opts ...Option) *hook { + return &hook{opts: NewOptions(opts...)} +} + +type validator interface { + Validate() error +} + +type hook struct { + opts Options +} + +func (w *hook) ClientCall(next client.FuncCall) client.FuncCall { + return func(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { + if v, ok := req.Body().(validator); ok { + if err := v.Validate(); err != nil { + return w.opts.ClientErrorFn(req, nil, err) + } + } + err := next(ctx, req, rsp, opts...) + if v, ok := rsp.(validator); ok && w.opts.ClientValidateResponse { + if verr := v.Validate(); verr != nil { + return w.opts.ClientErrorFn(req, rsp, verr) + } + } + return err + } +} + +func (w *hook) ClientStream(next client.FuncStream) client.FuncStream { + return func(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) { + if v, ok := req.Body().(validator); ok { + if err := v.Validate(); err != nil { + return nil, w.opts.ClientErrorFn(req, nil, err) + } + } + return next(ctx, req, opts...) + } +} + +func (w *hook) ClientPublish(next client.FuncPublish) client.FuncPublish { + return func(ctx context.Context, msg client.Message, opts ...client.PublishOption) error { + if v, ok := msg.Payload().(validator); ok { + if err := v.Validate(); err != nil { + return w.opts.PublishErrorFn(msg, err) + } + } + return next(ctx, msg, opts...) + } +} + +func (w *hook) ClientBatchPublish(next client.FuncBatchPublish) client.FuncBatchPublish { + return func(ctx context.Context, msgs []client.Message, opts ...client.PublishOption) error { + for _, msg := range msgs { + if v, ok := msg.Payload().(validator); ok { + if err := v.Validate(); err != nil { + return w.opts.PublishErrorFn(msg, err) + } + } + } + return next(ctx, msgs, opts...) + } +} + +func (w *hook) ServerHandler(next server.FuncHandler) server.FuncHandler { + return func(ctx context.Context, req server.Request, rsp interface{}) error { + if v, ok := req.Body().(validator); ok { + if err := v.Validate(); err != nil { + return w.opts.ServerErrorFn(req, nil, err) + } + } + err := next(ctx, req, rsp) + if v, ok := rsp.(validator); ok && w.opts.ServerValidateResponse { + if verr := v.Validate(); verr != nil { + return w.opts.ServerErrorFn(req, rsp, verr) + } + } + return err + } +} + +func (w *hook) ServerSubscriber(next server.FuncSubHandler) server.FuncSubHandler { + return func(ctx context.Context, msg server.Message) error { + if v, ok := msg.Body().(validator); ok { + if err := v.Validate(); err != nil { + return w.opts.SubscribeErrorFn(msg, err) + } + } + return next(ctx, msg) + } +}