diff --git a/gobreaker.go b/gobreaker.go index 516d45b..50b7c88 100644 --- a/gobreaker.go +++ b/gobreaker.go @@ -1,28 +1,92 @@ package gobreaker import ( - "github.com/micro/go-micro/client" - "github.com/sony/gobreaker" - "context" + "sync" + + "github.com/micro/go-micro/client" + "github.com/micro/go-micro/errors" + "github.com/sony/gobreaker" +) + +type BreakerMethod int + +const ( + BreakService BreakerMethod = iota + BreakServiceEndpoint ) type clientWrapper struct { - cb *gobreaker.CircuitBreaker + bs gobreaker.Settings + bm BreakerMethod + cbs map[string]*gobreaker.TwoStepCircuitBreaker + mu sync.Mutex client.Client } func (c *clientWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { - _, err := c.cb.Execute(func() (interface{}, error) { - cerr := c.Client.Call(ctx, req, rsp, opts...) - return nil, cerr - }) + var svc string + + switch c.bm { + case BreakService: + svc = req.Service() + case BreakServiceEndpoint: + svc = req.Service() + "." + req.Endpoint() + } + + c.mu.Lock() + cb, ok := c.cbs[svc] + if !ok { + cb = gobreaker.NewTwoStepCircuitBreaker(c.bs) + c.cbs[svc] = cb + } + c.mu.Unlock() + + cbAllow, err := cb.Allow() + if err != nil { + return errors.New(req.Service(), err.Error(), 502) + } + + if err = c.Client.Call(ctx, req, rsp, opts...); err == nil { + cbAllow(true) + return nil + } + + switch err.(type) { + case *errors.Error: + break + default: + err = errors.New(req.Service(), err.Error(), 503) + } + + if err.(*errors.Error).Code >= 500 { + cbAllow(false) + } else { + cbAllow(true) + } + return err } -// NewClientWrapper takes a *gobreaker.CircuitBreaker and returns a client Wrapper. -func NewClientWrapper(cb *gobreaker.CircuitBreaker) client.Wrapper { +// NewClientWrapper returns a client Wrapper. +func NewClientWrapper() client.Wrapper { return func(c client.Client) client.Client { - return &clientWrapper{cb, c} + w := &clientWrapper{} + w.bs = gobreaker.Settings{} + w.cbs = make(map[string]*gobreaker.TwoStepCircuitBreaker) + w.Client = c + return w + } +} + +// NewCustomClientWrapper takes a gobreaker.Settings and BreakerMethod. Returns a client Wrapper. +func NewCustomClientWrapper(bs gobreaker.Settings, bm BreakerMethod) client.Wrapper { + return func(c client.Client) client.Client { + w := &clientWrapper{} + w.bm = bm + w.bs = bs + w.cbs = make(map[string]*gobreaker.TwoStepCircuitBreaker) + w.Client = c + return w } } diff --git a/gobreaker_test.go b/gobreaker_test.go index 824c816..e026dfb 100644 --- a/gobreaker_test.go +++ b/gobreaker_test.go @@ -1,14 +1,14 @@ package gobreaker import ( + "context" "testing" "github.com/micro/go-micro/client" + "github.com/micro/go-micro/errors" "github.com/micro/go-micro/registry/memory" "github.com/micro/go-micro/selector" "github.com/sony/gobreaker" - - "context" ) func TestBreaker(t *testing.T) { @@ -20,8 +20,43 @@ func TestBreaker(t *testing.T) { // set the selector client.Selector(s), // add the breaker wrapper - client.Wrap(NewClientWrapper( - gobreaker.NewCircuitBreaker(gobreaker.Settings{}), + client.Wrap(NewClientWrapper()), + ) + + req := c.NewRequest("test.service", "Test.Method", map[string]string{ + "foo": "bar", + }, client.WithContentType("application/json")) + + var rsp map[string]interface{} + + // Force to point of trip + for i := 0; i < 6; i++ { + c.Call(context.TODO(), req, rsp) + } + + err := c.Call(context.TODO(), req, rsp) + if err == nil { + t.Error("Expecting tripped breaker, got nil error") + } + + merr := err.(*errors.Error) + if merr.Code != 502 { + t.Errorf("Expecting tripped breaker, got %v", err) + } +} + +func TestCustomBreaker(t *testing.T) { + // setup + r := memory.NewRegistry() + s := selector.NewSelector(selector.Registry(r)) + + c := client.NewClient( + // set the selector + client.Selector(s), + // add the breaker wrapper + client.Wrap(NewCustomClientWrapper( + gobreaker.Settings{}, + BreakService, )), ) @@ -41,7 +76,8 @@ func TestBreaker(t *testing.T) { t.Error("Expecting tripped breaker, got nil error") } - if err.Error() != "circuit breaker is open" { + merr := err.(*errors.Error) + if merr.Code != 502 { t.Errorf("Expecting tripped breaker, got %v", err) } }