From bfbb2bcbc1d0e499494c92020fa03ee0dff74646 Mon Sep 17 00:00:00 2001 From: Vasiliy Tolstov Date: Sat, 3 Dec 2022 02:17:05 +0300 Subject: [PATCH] enable rsp validation Signed-off-by: Vasiliy Tolstov --- validator.go | 40 +++++++++++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/validator.go b/validator.go index 17d0f0e..017c01d 100644 --- a/validator.go +++ b/validator.go @@ -29,28 +29,40 @@ func NewClientCallWrapper() client.CallWrapper { return func(fn client.CallFunc) client.CallFunc { return func(ctx context.Context, addr string, req client.Request, rsp interface{}, opts client.CallOptions) error { if v, ok := req.Body().(validator); ok { - if err := v.Validate(); err != nil { - return errors.BadRequest(req.Service(), "%v", err) + if verr := v.Validate(); verr != nil { + return errors.BadRequest(req.Service(), "%v", verr) } } - return fn(ctx, addr, req, rsp, opts) + err := fn(ctx, addr, req, rsp, opts) + if v, ok := rsp.(validator); ok { + if verr := v.Validate(); verr != nil { + return errors.BadGateway(req.Service(), "%v", verr) + } + } + return err } } } func (w *wrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { if v, ok := req.Body().(validator); ok { - if err := v.Validate(); err != nil { - return errors.BadRequest(req.Service(), "%v", err) + if verr := v.Validate(); verr != nil { + return errors.BadRequest(req.Service(), "%v", verr) } } - return w.Client.Call(ctx, req, rsp, opts...) + err := w.Client.Call(ctx, req, rsp, opts...) + if v, ok := rsp.(validator); ok { + if verr := v.Validate(); verr != nil { + return errors.BadGateway(req.Service(), "%v", verr) + } + } + return err } func (w *wrapper) Stream(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) { if v, ok := req.Body().(validator); ok { - if err := v.Validate(); err != nil { - return nil, errors.BadRequest(req.Service(), "%v", err) + if verr := v.Validate(); verr != nil { + return nil, errors.BadRequest(req.Service(), "%v", verr) } } return w.Client.Stream(ctx, req, opts...) @@ -69,11 +81,17 @@ func NewServerHandlerWrapper() server.HandlerWrapper { return func(fn server.HandlerFunc) server.HandlerFunc { return func(ctx context.Context, req server.Request, rsp interface{}) error { if v, ok := req.Body().(validator); ok { - if err := v.Validate(); err != nil { - return errors.BadRequest(req.Service(), "%v", err) + if verr := v.Validate(); verr != nil { + return errors.BadRequest(req.Service(), "%v", verr) } } - return fn(ctx, req, rsp) + err := fn(ctx, req, rsp) + if v, ok := rsp.(validator); ok { + if verr := v.Validate(); verr != nil { + return errors.BadGateway(req.Service(), "%v", verr) + } + } + return err } } }