diff --git a/auth/auth.go b/auth/auth.go index e34051fa..d18b0e38 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -90,8 +90,6 @@ type Token struct { const ( // DefaultNamespace used for auth DefaultNamespace = "go.micro" - // NamespaceKey is the key used when storing the namespace in metadata - NamespaceKey = "Micro-Namespace" // MetadataKey is the key used when storing the account in metadata MetadataKey = "auth-account" // TokenCookieName is the name of the cookie which stores the auth token diff --git a/auth/options.go b/auth/options.go index d3c0f4ea..3cd1bc69 100644 --- a/auth/options.go +++ b/auth/options.go @@ -7,6 +7,19 @@ import ( "github.com/micro/go-micro/v2/store" ) +func NewOptions(opts ...Option) Options { + var options Options + for _, o := range opts { + o(&options) + } + + if len(options.Namespace) == 0 { + options.Namespace = DefaultNamespace + } + + return options +} + type Options struct { // Namespace the service belongs to Namespace string diff --git a/auth/service/service.go b/auth/service/service.go index cb1740c7..73e3f80a 100644 --- a/auth/service/service.go +++ b/auth/service/service.go @@ -18,9 +18,7 @@ import ( // NewAuth returns a new instance of the Auth service func NewAuth(opts ...auth.Option) auth.Auth { - svc := new(svc) - svc.Init(opts...) - return svc + return &svc{options: auth.NewOptions(opts...)} } // svc is the service implementation of the Auth interface @@ -55,8 +53,9 @@ func (s *svc) Init(opts ...auth.Option) { } // load rules periodically from the auth service - ruleTimer := time.NewTicker(time.Second * 30) go func() { + ruleTimer := time.NewTicker(time.Second * 30) + // load rules immediately on startup s.loadRules() @@ -74,10 +73,11 @@ func (s *svc) Init(opts ...auth.Option) { // we have client credentials and must load a new token // periodically if len(s.options.ID) > 0 || len(s.options.Secret) > 0 { - tokenTimer := time.NewTicker(time.Minute) + // get a token immediately + s.refreshToken() go func() { - s.refreshToken() + tokenTimer := time.NewTicker(time.Minute) for { <-tokenTimer.C @@ -180,11 +180,15 @@ func (s *svc) Verify(acc *auth.Account, res *auth.Resource) error { } } - // set a default account id to log + // set a default account id / namespace to log logID := acc.ID if len(logID) == 0 { logID = "[no account]" } + logNamespace := acc.Namespace + if len(logNamespace) == 0 { + logNamespace = "[no namespace]" + } for _, q := range queries { for _, rule := range s.listRules(q...) { @@ -192,17 +196,17 @@ func (s *svc) Verify(acc *auth.Account, res *auth.Resource) error { case pb.Access_UNKNOWN: continue // rule did not specify access, check the next rule case pb.Access_GRANTED: - log.Infof("%v:%v granted access to %v:%v:%v:%v by rule %v", acc.Namespace, logID, res.Namespace, res.Type, res.Name, res.Endpoint, rule.Id) + log.Tracef("%v:%v granted access to %v:%v:%v:%v by rule %v", logNamespace, logID, res.Namespace, res.Type, res.Name, res.Endpoint, rule.Id) return nil // rule grants the account access to the resource case pb.Access_DENIED: - log.Infof("%v:%v denied access to %v:%v:%v:%v by rule %v", acc.Namespace, logID, res.Namespace, res.Type, res.Name, res.Endpoint, rule.Id) + log.Tracef("%v:%v denied access to %v:%v:%v:%v by rule %v", logNamespace, logID, res.Namespace, res.Type, res.Name, res.Endpoint, rule.Id) return auth.ErrForbidden // rule denies access to the resource } } } // no rules were found for the resource, default to denying access - log.Infof("%v:%v denied access to %v:%v:%v:%v by lack of rule (%v rules found for namespace)", acc.Namespace, logID, res.Namespace, res.Type, res.Name, res.Endpoint, len(s.listRules(res.Namespace))) + log.Tracef("%v:%v denied access to %v:%v:%v:%v by lack of rule (%v rules found for namespace)", logNamespace, logID, res.Namespace, res.Type, res.Name, res.Endpoint, len(s.listRules(res.Namespace))) return auth.ErrForbidden } diff --git a/registry/cache/cache.go b/registry/cache/cache.go index 7714f980..2799f2d4 100644 --- a/registry/cache/cache.go +++ b/registry/cache/cache.go @@ -419,7 +419,7 @@ func (c *cache) watch(w registry.Watcher) error { } } -func (c *cache) GetService(service string) ([]*registry.Service, error) { +func (c *cache) GetService(service string, opts ...registry.GetOption) ([]*registry.Service, error) { // get the service services, err := c.get(service) if err != nil { diff --git a/registry/etcd/etcd.go b/registry/etcd/etcd.go index 8b7a65e1..8cbda8cf 100644 --- a/registry/etcd/etcd.go +++ b/registry/etcd/etcd.go @@ -276,7 +276,7 @@ func (e *etcdRegistry) registerNode(s *registry.Service, node *registry.Node, op return nil } -func (e *etcdRegistry) Deregister(s *registry.Service) error { +func (e *etcdRegistry) Deregister(s *registry.Service, opts ...registry.DeregisterOption) error { if len(s.Nodes) == 0 { return errors.New("Require at least one node") } @@ -322,7 +322,7 @@ func (e *etcdRegistry) Register(s *registry.Service, opts ...registry.RegisterOp return gerr } -func (e *etcdRegistry) GetService(name string) ([]*registry.Service, error) { +func (e *etcdRegistry) GetService(name string, opts ...registry.GetOption) ([]*registry.Service, error) { ctx, cancel := context.WithTimeout(context.Background(), e.options.Timeout) defer cancel() @@ -362,7 +362,7 @@ func (e *etcdRegistry) GetService(name string) ([]*registry.Service, error) { return services, nil } -func (e *etcdRegistry) ListServices() ([]*registry.Service, error) { +func (e *etcdRegistry) ListServices(opts ...registry.ListOption) ([]*registry.Service, error) { versions := make(map[string]*registry.Service) ctx, cancel := context.WithTimeout(context.Background(), e.options.Timeout) diff --git a/registry/mdns_registry.go b/registry/mdns_registry.go index 18758878..fde00f37 100644 --- a/registry/mdns_registry.go +++ b/registry/mdns_registry.go @@ -269,7 +269,7 @@ func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error return gerr } -func (m *mdnsRegistry) Deregister(service *Service) error { +func (m *mdnsRegistry) Deregister(service *Service, opts ...DeregisterOption) error { m.Lock() defer m.Unlock() @@ -304,7 +304,7 @@ func (m *mdnsRegistry) Deregister(service *Service) error { return nil } -func (m *mdnsRegistry) GetService(service string) ([]*Service, error) { +func (m *mdnsRegistry) GetService(service string, opts ...GetOption) ([]*Service, error) { serviceMap := make(map[string]*Service) entries := make(chan *mdns.ServiceEntry, 10) done := make(chan bool) @@ -396,7 +396,7 @@ func (m *mdnsRegistry) GetService(service string) ([]*Service, error) { return services, nil } -func (m *mdnsRegistry) ListServices() ([]*Service, error) { +func (m *mdnsRegistry) ListServices(opts ...ListOption) ([]*Service, error) { serviceMap := make(map[string]bool) entries := make(chan *mdns.ServiceEntry, 10) done := make(chan bool) diff --git a/registry/memory/memory.go b/registry/memory/memory.go index cde49b96..44f22439 100644 --- a/registry/memory/memory.go +++ b/registry/memory/memory.go @@ -207,7 +207,7 @@ func (m *Registry) Register(s *registry.Service, opts ...registry.RegisterOption return nil } -func (m *Registry) Deregister(s *registry.Service) error { +func (m *Registry) Deregister(s *registry.Service, opts ...registry.DeregisterOption) error { m.Lock() defer m.Unlock() @@ -240,7 +240,7 @@ func (m *Registry) Deregister(s *registry.Service) error { return nil } -func (m *Registry) GetService(name string) ([]*registry.Service, error) { +func (m *Registry) GetService(name string, opts ...registry.GetOption) ([]*registry.Service, error) { m.RLock() defer m.RUnlock() @@ -259,7 +259,7 @@ func (m *Registry) GetService(name string) ([]*registry.Service, error) { return services, nil } -func (m *Registry) ListServices() ([]*registry.Service, error) { +func (m *Registry) ListServices(opts ...registry.ListOption) ([]*registry.Service, error) { m.RLock() defer m.RUnlock() diff --git a/registry/options.go b/registry/options.go index d3a54856..03fd5eb4 100644 --- a/registry/options.go +++ b/registry/options.go @@ -32,6 +32,18 @@ type WatchOptions struct { Context context.Context } +type DeregisterOptions struct { + Context context.Context +} + +type GetOptions struct { + Context context.Context +} + +type ListOptions struct { + Context context.Context +} + // Addrs is the registry addresses to use func Addrs(addrs ...string) Option { return func(o *Options) { @@ -65,9 +77,39 @@ func RegisterTTL(t time.Duration) RegisterOption { } } +func RegisterContext(ctx context.Context) RegisterOption { + return func(o *RegisterOptions) { + o.Context = ctx + } +} + // Watch a service func WatchService(name string) WatchOption { return func(o *WatchOptions) { o.Service = name } } + +func WatchContext(ctx context.Context) WatchOption { + return func(o *WatchOptions) { + o.Context = ctx + } +} + +func DeregisterContext(ctx context.Context) DeregisterOption { + return func(o *DeregisterOptions) { + o.Context = ctx + } +} + +func GetContext(ctx context.Context) GetOption { + return func(o *GetOptions) { + o.Context = ctx + } +} + +func ListContext(ctx context.Context) ListOption { + return func(o *ListOptions) { + o.Context = ctx + } +} diff --git a/registry/registry.go b/registry/registry.go index 291a1988..dce9b431 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -21,9 +21,9 @@ type Registry interface { Init(...Option) error Options() Options Register(*Service, ...RegisterOption) error - Deregister(*Service) error - GetService(string) ([]*Service, error) - ListServices() ([]*Service, error) + Deregister(*Service, ...DeregisterOption) error + GetService(string, ...GetOption) ([]*Service, error) + ListServices(...ListOption) ([]*Service, error) Watch(...WatchOption) (Watcher, error) String() string } @@ -61,6 +61,12 @@ type RegisterOption func(*RegisterOptions) type WatchOption func(*WatchOptions) +type DeregisterOption func(*DeregisterOptions) + +type GetOption func(*GetOptions) + +type ListOption func(*ListOptions) + // Register a service node. Additionally supply options such as TTL. func Register(s *Service, opts ...RegisterOption) error { return DefaultRegistry.Register(s, opts...) diff --git a/registry/service/service.go b/registry/service/service.go index e266347f..acf42d3a 100644 --- a/registry/service/service.go +++ b/registry/service/service.go @@ -58,13 +58,16 @@ func (s *serviceRegistry) Register(srv *registry.Service, opts ...registry.Regis for _, o := range opts { o(&options) } + if options.Context == nil { + options.Context = context.TODO() + } // encode srv into protobuf and pack Register TTL into it pbSrv := ToProto(srv) pbSrv.Options.Ttl = int64(options.TTL.Seconds()) // register the service - _, err := s.client.Register(context.TODO(), pbSrv, s.callOpts()...) + _, err := s.client.Register(options.Context, pbSrv, s.callOpts()...) if err != nil { return err } @@ -72,17 +75,33 @@ func (s *serviceRegistry) Register(srv *registry.Service, opts ...registry.Regis return nil } -func (s *serviceRegistry) Deregister(srv *registry.Service) error { +func (s *serviceRegistry) Deregister(srv *registry.Service, opts ...registry.DeregisterOption) error { + var options registry.DeregisterOptions + for _, o := range opts { + o(&options) + } + if options.Context == nil { + options.Context = context.TODO() + } + // deregister the service - _, err := s.client.Deregister(context.TODO(), ToProto(srv), s.callOpts()...) + _, err := s.client.Deregister(options.Context, ToProto(srv), s.callOpts()...) if err != nil { return err } return nil } -func (s *serviceRegistry) GetService(name string) ([]*registry.Service, error) { - rsp, err := s.client.GetService(context.TODO(), &pb.GetRequest{ +func (s *serviceRegistry) GetService(name string, opts ...registry.GetOption) ([]*registry.Service, error) { + var options registry.GetOptions + for _, o := range opts { + o(&options) + } + if options.Context == nil { + options.Context = context.TODO() + } + + rsp, err := s.client.GetService(options.Context, &pb.GetRequest{ Service: name, }, s.callOpts()...) @@ -97,8 +116,16 @@ func (s *serviceRegistry) GetService(name string) ([]*registry.Service, error) { return services, nil } -func (s *serviceRegistry) ListServices() ([]*registry.Service, error) { - rsp, err := s.client.ListServices(context.TODO(), &pb.ListRequest{}, s.callOpts()...) +func (s *serviceRegistry) ListServices(opts ...registry.ListOption) ([]*registry.Service, error) { + var options registry.ListOptions + for _, o := range opts { + o(&options) + } + if options.Context == nil { + options.Context = context.TODO() + } + + rsp, err := s.client.ListServices(options.Context, &pb.ListRequest{}, s.callOpts()...) if err != nil { return nil, err } @@ -116,8 +143,11 @@ func (s *serviceRegistry) Watch(opts ...registry.WatchOption) (registry.Watcher, for _, o := range opts { o(&options) } + if options.Context == nil { + options.Context = context.TODO() + } - stream, err := s.client.Watch(context.TODO(), &pb.WatchRequest{ + stream, err := s.client.Watch(options.Context, &pb.WatchRequest{ Service: options.Service, }, s.callOpts()...) diff --git a/util/wrapper/wrapper.go b/util/wrapper/wrapper.go index 13ae5531..e5ce4bb1 100644 --- a/util/wrapper/wrapper.go +++ b/util/wrapper/wrapper.go @@ -159,7 +159,7 @@ func AuthHandler(fn func() auth.Auth) server.HandlerWrapper { // Inspect the token and get the account account, err := a.Inspect(token) if err != nil { - account = &auth.Account{} + account = &auth.Account{Namespace: a.Options().Namespace} } // construct the resource