Fixes and improved test coverage

This commit is contained in:
Ben Toogood 2020-05-24 20:26:37 +01:00
parent 2729569f66
commit 95703e4565
3 changed files with 61 additions and 43 deletions

View File

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"sort" "sort"
"strings" "strings"
"sync"
"time" "time"
"github.com/micro/go-micro/v2/auth" "github.com/micro/go-micro/v2/auth"
@ -23,9 +22,6 @@ type svc struct {
auth pb.AuthService auth pb.AuthService
rule pb.RulesService rule pb.RulesService
jwt token.Provider jwt token.Provider
rules []*pb.Rule
sync.Mutex
} }
func (s *svc) String() string { func (s *svc) String() string {
@ -53,8 +49,6 @@ func (s *svc) Init(opts ...auth.Option) {
} }
func (s *svc) Options() auth.Options { func (s *svc) Options() auth.Options {
s.Lock()
defer s.Unlock()
return s.options return s.options
} }
@ -110,9 +104,6 @@ func (s *svc) Revoke(role string, res *auth.Resource) error {
// Verify an account has access to a resource // Verify an account has access to a resource
func (s *svc) Verify(acc *auth.Account, res *auth.Resource) error { func (s *svc) Verify(acc *auth.Account, res *auth.Resource) error {
// load the rules if none are loaded
s.loadRulesIfEmpty()
// set the namespace on the resource // set the namespace on the resource
if len(res.Namespace) == 0 { if len(res.Namespace) == 0 {
res.Namespace = s.Options().Namespace res.Namespace = s.Options().Namespace
@ -230,11 +221,14 @@ func accessForRule(rule *pb.Rule, acc *auth.Account, res *auth.Resource) pb.Acce
// listRules gets all the rules from the store which match the filters. // listRules gets all the rules from the store which match the filters.
// filters are namespace, type, name and then endpoint. // filters are namespace, type, name and then endpoint.
func (s *svc) listRules(filters ...string) []*pb.Rule { func (s *svc) listRules(filters ...string) []*pb.Rule {
s.Lock() // load rules using the client cache
defer s.Unlock() allRules, err := s.loadRules()
if err != nil {
return []*pb.Rule{}
}
var rules []*pb.Rule var rules []*pb.Rule
for _, r := range s.rules { for _, r := range allRules {
if len(filters) > 0 && r.Resource.Namespace != filters[0] { if len(filters) > 0 && r.Resource.Namespace != filters[0] {
continue continue
} }
@ -260,27 +254,15 @@ func (s *svc) listRules(filters ...string) []*pb.Rule {
} }
// loadRules retrieves the rules from the auth service // loadRules retrieves the rules from the auth service
func (s *svc) loadRules() { func (s *svc) loadRules() ([]*pb.Rule, error) {
rsp, err := s.rule.List(context.TODO(), &pb.ListRequest{}, client.WithCache(time.Minute)) rsp, err := s.rule.List(context.TODO(), &pb.ListRequest{}, client.WithCache(time.Minute))
s.Lock()
defer s.Unlock()
if err != nil { if err != nil {
log.Errorf("Error listing rules: %v", err) log.Debugf("Error listing rules: %v", err)
return return nil, err
} }
s.rules = rsp.Rules return rsp.Rules, nil
}
func (s *svc) loadRulesIfEmpty() {
s.Lock()
rules := s.rules
s.Unlock()
if len(rules) == 0 {
s.loadRules()
}
} }
func serializeToken(t *pb.Token) *auth.Token { func serializeToken(t *pb.Token) *auth.Token {

View File

@ -2,6 +2,7 @@ package wrapper
import ( import (
"context" "context"
"reflect"
"strings" "strings"
"time" "time"
@ -229,7 +230,7 @@ func AuthHandler(fn func() auth.Auth) server.HandlerWrapper {
} }
type cacheWrapper struct { type cacheWrapper struct {
cache func() *client.Cache cacheFn func() *client.Cache
client.Client client.Client
} }
@ -243,7 +244,7 @@ func (c *cacheWrapper) Call(ctx context.Context, req client.Request, rsp interfa
} }
// if the client doesn't have a cacbe setup don't continue // if the client doesn't have a cacbe setup don't continue
cache := c.cache() cache := c.cacheFn()
if cache == nil { if cache == nil {
return c.Client.Call(ctx, req, rsp, opts...) return c.Client.Call(ctx, req, rsp, opts...)
} }
@ -253,9 +254,15 @@ func (c *cacheWrapper) Call(ctx context.Context, req client.Request, rsp interfa
return c.Client.Call(ctx, req, rsp, opts...) return c.Client.Call(ctx, req, rsp, opts...)
} }
// check to see if there is a response // if the response is nil don't call the cache since we can't assign the response
if rsp == nil {
return c.Client.Call(ctx, req, rsp, opts...)
}
// check to see if there is a response cached, if there is assign it
if r, ok := cache.Get(ctx, &req); ok { if r, ok := cache.Get(ctx, &req); ok {
rsp = r val := reflect.ValueOf(rsp).Elem()
val.Set(reflect.ValueOf(r).Elem())
return nil return nil
} }

View File

@ -2,6 +2,7 @@ package wrapper
import ( import (
"context" "context"
"reflect"
"testing" "testing"
"time" "time"
@ -56,25 +57,33 @@ func TestWrapper(t *testing.T) {
type testClient struct { type testClient struct {
callCount int callCount int
callRsp interface{} callRsp interface{}
cache *client.Cache
client.Client client.Client
} }
func (c *testClient) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { func (c *testClient) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error {
c.callCount++ c.callCount++
rsp = c.callRsp
if c.callRsp != nil {
val := reflect.ValueOf(rsp).Elem()
val.Set(reflect.ValueOf(c.callRsp).Elem())
}
return nil return nil
} }
func (c *testClient) Options() client.Options { type testRsp struct {
return client.Options{Cache: c.cache} value string
} }
func TestCacheWrapper(t *testing.T) { func TestCacheWrapper(t *testing.T) {
req := client.NewRequest("go.micro.service.foo", "Foo.Bar", nil) req := client.NewRequest("go.micro.service.foo", "Foo.Bar", nil)
t.Run("NilCache", func(t *testing.T) { t.Run("NilCache", func(t *testing.T) {
cli := new(testClient) cli := new(testClient)
w := CacheClient(cli)
w := CacheClient(func() *client.Cache {
return nil
}, cli)
// perfroming two requests should increment the call count by two indicating the cache wasn't // perfroming two requests should increment the call count by two indicating the cache wasn't
// used even though the WithCache option was passed. // used even though the WithCache option was passed.
@ -88,7 +97,11 @@ func TestCacheWrapper(t *testing.T) {
t.Run("OptionNotSet", func(t *testing.T) { t.Run("OptionNotSet", func(t *testing.T) {
cli := new(testClient) cli := new(testClient)
w := CacheClient(cli) cache := client.NewCache()
w := CacheClient(func() *client.Cache {
return cache
}, cli)
// perfroming two requests should increment the call count by two since we didn't pass the WithCache // perfroming two requests should increment the call count by two since we didn't pass the WithCache
// option to Call. // option to Call.
@ -101,13 +114,21 @@ func TestCacheWrapper(t *testing.T) {
}) })
t.Run("OptionSet", func(t *testing.T) { t.Run("OptionSet", func(t *testing.T) {
cli := &testClient{callRsp: "foobar", cache: client.NewCache()} val := "foo"
w := CacheClient(cli) cli := &testClient{callRsp: &testRsp{value: val}}
cache := client.NewCache()
w := CacheClient(func() *client.Cache {
return cache
}, cli)
// perfroming two requests should increment the call count by once since the second request should // perfroming two requests should increment the call count by once since the second request should
// have used the cache // have used the cache. The correct value should be set on both responses and no errors should
err1 := w.Call(context.TODO(), req, nil, client.WithCache(time.Minute)) // be returned.
err2 := w.Call(context.TODO(), req, nil, client.WithCache(time.Minute)) rsp1 := &testRsp{}
rsp2 := &testRsp{}
err1 := w.Call(context.TODO(), req, rsp1, client.WithCache(time.Minute))
err2 := w.Call(context.TODO(), req, rsp2, client.WithCache(time.Minute))
if err1 != nil { if err1 != nil {
t.Errorf("Expected nil error, got %v", err1) t.Errorf("Expected nil error, got %v", err1)
@ -115,6 +136,14 @@ func TestCacheWrapper(t *testing.T) {
if err2 != nil { if err2 != nil {
t.Errorf("Expected nil error, got %v", err2) t.Errorf("Expected nil error, got %v", err2)
} }
if rsp1.value != val {
t.Errorf("Expected %v to be assigned to the value, got %v", val, rsp1.value)
}
if rsp2.value != val {
t.Errorf("Expected %v to be assigned to the value, got %v", val, rsp2.value)
}
if cli.callCount != 1 { if cli.callCount != 1 {
t.Errorf("Expected the client to be called 1 time, was actually called %v time(s)", cli.callCount) t.Errorf("Expected the client to be called 1 time, was actually called %v time(s)", cli.callCount)
} }