Fixes and improved test coverage
This commit is contained in:
parent
2729569f66
commit
95703e4565
@ -5,7 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/micro/go-micro/v2/auth"
|
"github.com/micro/go-micro/v2/auth"
|
||||||
@ -23,9 +22,6 @@ type svc struct {
|
|||||||
auth pb.AuthService
|
auth pb.AuthService
|
||||||
rule pb.RulesService
|
rule pb.RulesService
|
||||||
jwt token.Provider
|
jwt token.Provider
|
||||||
|
|
||||||
rules []*pb.Rule
|
|
||||||
sync.Mutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *svc) String() string {
|
func (s *svc) String() string {
|
||||||
@ -53,8 +49,6 @@ func (s *svc) Init(opts ...auth.Option) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *svc) Options() auth.Options {
|
func (s *svc) Options() auth.Options {
|
||||||
s.Lock()
|
|
||||||
defer s.Unlock()
|
|
||||||
return s.options
|
return s.options
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -110,9 +104,6 @@ func (s *svc) Revoke(role string, res *auth.Resource) error {
|
|||||||
|
|
||||||
// Verify an account has access to a resource
|
// Verify an account has access to a resource
|
||||||
func (s *svc) Verify(acc *auth.Account, res *auth.Resource) error {
|
func (s *svc) Verify(acc *auth.Account, res *auth.Resource) error {
|
||||||
// load the rules if none are loaded
|
|
||||||
s.loadRulesIfEmpty()
|
|
||||||
|
|
||||||
// set the namespace on the resource
|
// set the namespace on the resource
|
||||||
if len(res.Namespace) == 0 {
|
if len(res.Namespace) == 0 {
|
||||||
res.Namespace = s.Options().Namespace
|
res.Namespace = s.Options().Namespace
|
||||||
@ -230,11 +221,14 @@ func accessForRule(rule *pb.Rule, acc *auth.Account, res *auth.Resource) pb.Acce
|
|||||||
// listRules gets all the rules from the store which match the filters.
|
// listRules gets all the rules from the store which match the filters.
|
||||||
// filters are namespace, type, name and then endpoint.
|
// filters are namespace, type, name and then endpoint.
|
||||||
func (s *svc) listRules(filters ...string) []*pb.Rule {
|
func (s *svc) listRules(filters ...string) []*pb.Rule {
|
||||||
s.Lock()
|
// load rules using the client cache
|
||||||
defer s.Unlock()
|
allRules, err := s.loadRules()
|
||||||
|
if err != nil {
|
||||||
|
return []*pb.Rule{}
|
||||||
|
}
|
||||||
|
|
||||||
var rules []*pb.Rule
|
var rules []*pb.Rule
|
||||||
for _, r := range s.rules {
|
for _, r := range allRules {
|
||||||
if len(filters) > 0 && r.Resource.Namespace != filters[0] {
|
if len(filters) > 0 && r.Resource.Namespace != filters[0] {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -260,27 +254,15 @@ func (s *svc) listRules(filters ...string) []*pb.Rule {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// loadRules retrieves the rules from the auth service
|
// loadRules retrieves the rules from the auth service
|
||||||
func (s *svc) loadRules() {
|
func (s *svc) loadRules() ([]*pb.Rule, error) {
|
||||||
rsp, err := s.rule.List(context.TODO(), &pb.ListRequest{}, client.WithCache(time.Minute))
|
rsp, err := s.rule.List(context.TODO(), &pb.ListRequest{}, client.WithCache(time.Minute))
|
||||||
s.Lock()
|
|
||||||
defer s.Unlock()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Error listing rules: %v", err)
|
log.Debugf("Error listing rules: %v", err)
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.rules = rsp.Rules
|
return rsp.Rules, nil
|
||||||
}
|
|
||||||
|
|
||||||
func (s *svc) loadRulesIfEmpty() {
|
|
||||||
s.Lock()
|
|
||||||
rules := s.rules
|
|
||||||
s.Unlock()
|
|
||||||
|
|
||||||
if len(rules) == 0 {
|
|
||||||
s.loadRules()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func serializeToken(t *pb.Token) *auth.Token {
|
func serializeToken(t *pb.Token) *auth.Token {
|
||||||
|
@ -2,6 +2,7 @@ package wrapper
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -229,7 +230,7 @@ func AuthHandler(fn func() auth.Auth) server.HandlerWrapper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type cacheWrapper struct {
|
type cacheWrapper struct {
|
||||||
cache func() *client.Cache
|
cacheFn func() *client.Cache
|
||||||
client.Client
|
client.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -243,7 +244,7 @@ func (c *cacheWrapper) Call(ctx context.Context, req client.Request, rsp interfa
|
|||||||
}
|
}
|
||||||
|
|
||||||
// if the client doesn't have a cacbe setup don't continue
|
// if the client doesn't have a cacbe setup don't continue
|
||||||
cache := c.cache()
|
cache := c.cacheFn()
|
||||||
if cache == nil {
|
if cache == nil {
|
||||||
return c.Client.Call(ctx, req, rsp, opts...)
|
return c.Client.Call(ctx, req, rsp, opts...)
|
||||||
}
|
}
|
||||||
@ -253,9 +254,15 @@ func (c *cacheWrapper) Call(ctx context.Context, req client.Request, rsp interfa
|
|||||||
return c.Client.Call(ctx, req, rsp, opts...)
|
return c.Client.Call(ctx, req, rsp, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// check to see if there is a response
|
// if the response is nil don't call the cache since we can't assign the response
|
||||||
|
if rsp == nil {
|
||||||
|
return c.Client.Call(ctx, req, rsp, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// check to see if there is a response cached, if there is assign it
|
||||||
if r, ok := cache.Get(ctx, &req); ok {
|
if r, ok := cache.Get(ctx, &req); ok {
|
||||||
rsp = r
|
val := reflect.ValueOf(rsp).Elem()
|
||||||
|
val.Set(reflect.ValueOf(r).Elem())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ package wrapper
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -56,25 +57,33 @@ func TestWrapper(t *testing.T) {
|
|||||||
type testClient struct {
|
type testClient struct {
|
||||||
callCount int
|
callCount int
|
||||||
callRsp interface{}
|
callRsp interface{}
|
||||||
cache *client.Cache
|
|
||||||
client.Client
|
client.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *testClient) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error {
|
func (c *testClient) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error {
|
||||||
c.callCount++
|
c.callCount++
|
||||||
rsp = c.callRsp
|
|
||||||
|
if c.callRsp != nil {
|
||||||
|
val := reflect.ValueOf(rsp).Elem()
|
||||||
|
val.Set(reflect.ValueOf(c.callRsp).Elem())
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *testClient) Options() client.Options {
|
type testRsp struct {
|
||||||
return client.Options{Cache: c.cache}
|
value string
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCacheWrapper(t *testing.T) {
|
func TestCacheWrapper(t *testing.T) {
|
||||||
req := client.NewRequest("go.micro.service.foo", "Foo.Bar", nil)
|
req := client.NewRequest("go.micro.service.foo", "Foo.Bar", nil)
|
||||||
|
|
||||||
t.Run("NilCache", func(t *testing.T) {
|
t.Run("NilCache", func(t *testing.T) {
|
||||||
cli := new(testClient)
|
cli := new(testClient)
|
||||||
w := CacheClient(cli)
|
|
||||||
|
w := CacheClient(func() *client.Cache {
|
||||||
|
return nil
|
||||||
|
}, cli)
|
||||||
|
|
||||||
// perfroming two requests should increment the call count by two indicating the cache wasn't
|
// perfroming two requests should increment the call count by two indicating the cache wasn't
|
||||||
// used even though the WithCache option was passed.
|
// used even though the WithCache option was passed.
|
||||||
@ -88,7 +97,11 @@ func TestCacheWrapper(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("OptionNotSet", func(t *testing.T) {
|
t.Run("OptionNotSet", func(t *testing.T) {
|
||||||
cli := new(testClient)
|
cli := new(testClient)
|
||||||
w := CacheClient(cli)
|
cache := client.NewCache()
|
||||||
|
|
||||||
|
w := CacheClient(func() *client.Cache {
|
||||||
|
return cache
|
||||||
|
}, cli)
|
||||||
|
|
||||||
// perfroming two requests should increment the call count by two since we didn't pass the WithCache
|
// perfroming two requests should increment the call count by two since we didn't pass the WithCache
|
||||||
// option to Call.
|
// option to Call.
|
||||||
@ -101,13 +114,21 @@ func TestCacheWrapper(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("OptionSet", func(t *testing.T) {
|
t.Run("OptionSet", func(t *testing.T) {
|
||||||
cli := &testClient{callRsp: "foobar", cache: client.NewCache()}
|
val := "foo"
|
||||||
w := CacheClient(cli)
|
cli := &testClient{callRsp: &testRsp{value: val}}
|
||||||
|
cache := client.NewCache()
|
||||||
|
|
||||||
|
w := CacheClient(func() *client.Cache {
|
||||||
|
return cache
|
||||||
|
}, cli)
|
||||||
|
|
||||||
// perfroming two requests should increment the call count by once since the second request should
|
// perfroming two requests should increment the call count by once since the second request should
|
||||||
// have used the cache
|
// have used the cache. The correct value should be set on both responses and no errors should
|
||||||
err1 := w.Call(context.TODO(), req, nil, client.WithCache(time.Minute))
|
// be returned.
|
||||||
err2 := w.Call(context.TODO(), req, nil, client.WithCache(time.Minute))
|
rsp1 := &testRsp{}
|
||||||
|
rsp2 := &testRsp{}
|
||||||
|
err1 := w.Call(context.TODO(), req, rsp1, client.WithCache(time.Minute))
|
||||||
|
err2 := w.Call(context.TODO(), req, rsp2, client.WithCache(time.Minute))
|
||||||
|
|
||||||
if err1 != nil {
|
if err1 != nil {
|
||||||
t.Errorf("Expected nil error, got %v", err1)
|
t.Errorf("Expected nil error, got %v", err1)
|
||||||
@ -115,6 +136,14 @@ func TestCacheWrapper(t *testing.T) {
|
|||||||
if err2 != nil {
|
if err2 != nil {
|
||||||
t.Errorf("Expected nil error, got %v", err2)
|
t.Errorf("Expected nil error, got %v", err2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if rsp1.value != val {
|
||||||
|
t.Errorf("Expected %v to be assigned to the value, got %v", val, rsp1.value)
|
||||||
|
}
|
||||||
|
if rsp2.value != val {
|
||||||
|
t.Errorf("Expected %v to be assigned to the value, got %v", val, rsp2.value)
|
||||||
|
}
|
||||||
|
|
||||||
if cli.callCount != 1 {
|
if cli.callCount != 1 {
|
||||||
t.Errorf("Expected the client to be called 1 time, was actually called %v time(s)", cli.callCount)
|
t.Errorf("Expected the client to be called 1 time, was actually called %v time(s)", cli.callCount)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user