From e17825474fcd8c5c22134c30ba1a0372074b8c47 Mon Sep 17 00:00:00 2001 From: Ben Toogood Date: Tue, 14 Apr 2020 12:32:59 +0100 Subject: [PATCH] Add context options to the runtime --- auth/auth.go | 10 -------- auth/service/service.go | 20 ++++++++++------ registry/cache/cache.go | 2 +- registry/etcd/etcd.go | 6 ++--- registry/mdns_registry.go | 6 ++--- registry/memory/memory.go | 6 ++--- registry/options.go | 42 +++++++++++++++++++++++++++++++++ registry/registry.go | 12 +++++++--- registry/service/service.go | 46 ++++++++++++++++++++++++++++++------- util/wrapper/wrapper.go | 5 ---- 10 files changed, 112 insertions(+), 43 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index 64b29d81..d18b0e38 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -90,8 +90,6 @@ type Token struct { const ( // DefaultNamespace used for auth DefaultNamespace = "go.micro" - // NamespaceKey is the key used when storing the namespace in metadata - NamespaceKey = "Micro-Namespace" // MetadataKey is the key used when storing the account in metadata MetadataKey = "auth-account" // TokenCookieName is the name of the cookie which stores the auth token @@ -133,11 +131,3 @@ func ContextWithAccount(ctx context.Context, account *Account) (context.Context, // generate a new context with the MetadataKey set return metadata.Set(ctx, MetadataKey, string(bytes)), nil } - -// NamespaceFromContext gets the namespace from the context -func NamespaceFromContext(ctx context.Context) string { - if ns, ok := metadata.Get(ctx, NamespaceKey); ok { - return ns - } - return DefaultNamespace -} diff --git a/auth/service/service.go b/auth/service/service.go index 5a59f988..73e3f80a 100644 --- a/auth/service/service.go +++ b/auth/service/service.go @@ -53,8 +53,9 @@ func (s *svc) Init(opts ...auth.Option) { } // load rules periodically from the auth service - ruleTimer := time.NewTicker(time.Second * 30) go func() { + ruleTimer := time.NewTicker(time.Second * 30) + // load rules immediately on startup s.loadRules() @@ -72,10 +73,11 @@ func (s *svc) Init(opts ...auth.Option) { // we have client credentials and must load a new token // periodically if len(s.options.ID) > 0 || len(s.options.Secret) > 0 { - tokenTimer := time.NewTicker(time.Minute) + // get a token immediately + s.refreshToken() go func() { - s.refreshToken() + tokenTimer := time.NewTicker(time.Minute) for { <-tokenTimer.C @@ -178,11 +180,15 @@ func (s *svc) Verify(acc *auth.Account, res *auth.Resource) error { } } - // set a default account id to log + // set a default account id / namespace to log logID := acc.ID if len(logID) == 0 { logID = "[no account]" } + logNamespace := acc.Namespace + if len(logNamespace) == 0 { + logNamespace = "[no namespace]" + } for _, q := range queries { for _, rule := range s.listRules(q...) { @@ -190,17 +196,17 @@ func (s *svc) Verify(acc *auth.Account, res *auth.Resource) error { case pb.Access_UNKNOWN: continue // rule did not specify access, check the next rule case pb.Access_GRANTED: - log.Infof("%v:%v granted access to %v:%v:%v:%v by rule %v", acc.Namespace, logID, res.Namespace, res.Type, res.Name, res.Endpoint, rule.Id) + log.Tracef("%v:%v granted access to %v:%v:%v:%v by rule %v", logNamespace, logID, res.Namespace, res.Type, res.Name, res.Endpoint, rule.Id) return nil // rule grants the account access to the resource case pb.Access_DENIED: - log.Infof("%v:%v denied access to %v:%v:%v:%v by rule %v", acc.Namespace, logID, res.Namespace, res.Type, res.Name, res.Endpoint, rule.Id) + log.Tracef("%v:%v denied access to %v:%v:%v:%v by rule %v", logNamespace, logID, res.Namespace, res.Type, res.Name, res.Endpoint, rule.Id) return auth.ErrForbidden // rule denies access to the resource } } } // no rules were found for the resource, default to denying access - log.Infof("%v:%v denied access to %v:%v:%v:%v by lack of rule (%v rules found for namespace)", acc.Namespace, logID, res.Namespace, res.Type, res.Name, res.Endpoint, len(s.listRules(res.Namespace))) + log.Tracef("%v:%v denied access to %v:%v:%v:%v by lack of rule (%v rules found for namespace)", logNamespace, logID, res.Namespace, res.Type, res.Name, res.Endpoint, len(s.listRules(res.Namespace))) return auth.ErrForbidden } diff --git a/registry/cache/cache.go b/registry/cache/cache.go index 7714f980..2799f2d4 100644 --- a/registry/cache/cache.go +++ b/registry/cache/cache.go @@ -419,7 +419,7 @@ func (c *cache) watch(w registry.Watcher) error { } } -func (c *cache) GetService(service string) ([]*registry.Service, error) { +func (c *cache) GetService(service string, opts ...registry.GetOption) ([]*registry.Service, error) { // get the service services, err := c.get(service) if err != nil { diff --git a/registry/etcd/etcd.go b/registry/etcd/etcd.go index 8b7a65e1..8cbda8cf 100644 --- a/registry/etcd/etcd.go +++ b/registry/etcd/etcd.go @@ -276,7 +276,7 @@ func (e *etcdRegistry) registerNode(s *registry.Service, node *registry.Node, op return nil } -func (e *etcdRegistry) Deregister(s *registry.Service) error { +func (e *etcdRegistry) Deregister(s *registry.Service, opts ...registry.DeregisterOption) error { if len(s.Nodes) == 0 { return errors.New("Require at least one node") } @@ -322,7 +322,7 @@ func (e *etcdRegistry) Register(s *registry.Service, opts ...registry.RegisterOp return gerr } -func (e *etcdRegistry) GetService(name string) ([]*registry.Service, error) { +func (e *etcdRegistry) GetService(name string, opts ...registry.GetOption) ([]*registry.Service, error) { ctx, cancel := context.WithTimeout(context.Background(), e.options.Timeout) defer cancel() @@ -362,7 +362,7 @@ func (e *etcdRegistry) GetService(name string) ([]*registry.Service, error) { return services, nil } -func (e *etcdRegistry) ListServices() ([]*registry.Service, error) { +func (e *etcdRegistry) ListServices(opts ...registry.ListOption) ([]*registry.Service, error) { versions := make(map[string]*registry.Service) ctx, cancel := context.WithTimeout(context.Background(), e.options.Timeout) diff --git a/registry/mdns_registry.go b/registry/mdns_registry.go index 18758878..fde00f37 100644 --- a/registry/mdns_registry.go +++ b/registry/mdns_registry.go @@ -269,7 +269,7 @@ func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error return gerr } -func (m *mdnsRegistry) Deregister(service *Service) error { +func (m *mdnsRegistry) Deregister(service *Service, opts ...DeregisterOption) error { m.Lock() defer m.Unlock() @@ -304,7 +304,7 @@ func (m *mdnsRegistry) Deregister(service *Service) error { return nil } -func (m *mdnsRegistry) GetService(service string) ([]*Service, error) { +func (m *mdnsRegistry) GetService(service string, opts ...GetOption) ([]*Service, error) { serviceMap := make(map[string]*Service) entries := make(chan *mdns.ServiceEntry, 10) done := make(chan bool) @@ -396,7 +396,7 @@ func (m *mdnsRegistry) GetService(service string) ([]*Service, error) { return services, nil } -func (m *mdnsRegistry) ListServices() ([]*Service, error) { +func (m *mdnsRegistry) ListServices(opts ...ListOption) ([]*Service, error) { serviceMap := make(map[string]bool) entries := make(chan *mdns.ServiceEntry, 10) done := make(chan bool) diff --git a/registry/memory/memory.go b/registry/memory/memory.go index cde49b96..44f22439 100644 --- a/registry/memory/memory.go +++ b/registry/memory/memory.go @@ -207,7 +207,7 @@ func (m *Registry) Register(s *registry.Service, opts ...registry.RegisterOption return nil } -func (m *Registry) Deregister(s *registry.Service) error { +func (m *Registry) Deregister(s *registry.Service, opts ...registry.DeregisterOption) error { m.Lock() defer m.Unlock() @@ -240,7 +240,7 @@ func (m *Registry) Deregister(s *registry.Service) error { return nil } -func (m *Registry) GetService(name string) ([]*registry.Service, error) { +func (m *Registry) GetService(name string, opts ...registry.GetOption) ([]*registry.Service, error) { m.RLock() defer m.RUnlock() @@ -259,7 +259,7 @@ func (m *Registry) GetService(name string) ([]*registry.Service, error) { return services, nil } -func (m *Registry) ListServices() ([]*registry.Service, error) { +func (m *Registry) ListServices(opts ...registry.ListOption) ([]*registry.Service, error) { m.RLock() defer m.RUnlock() diff --git a/registry/options.go b/registry/options.go index d3a54856..03fd5eb4 100644 --- a/registry/options.go +++ b/registry/options.go @@ -32,6 +32,18 @@ type WatchOptions struct { Context context.Context } +type DeregisterOptions struct { + Context context.Context +} + +type GetOptions struct { + Context context.Context +} + +type ListOptions struct { + Context context.Context +} + // Addrs is the registry addresses to use func Addrs(addrs ...string) Option { return func(o *Options) { @@ -65,9 +77,39 @@ func RegisterTTL(t time.Duration) RegisterOption { } } +func RegisterContext(ctx context.Context) RegisterOption { + return func(o *RegisterOptions) { + o.Context = ctx + } +} + // Watch a service func WatchService(name string) WatchOption { return func(o *WatchOptions) { o.Service = name } } + +func WatchContext(ctx context.Context) WatchOption { + return func(o *WatchOptions) { + o.Context = ctx + } +} + +func DeregisterContext(ctx context.Context) DeregisterOption { + return func(o *DeregisterOptions) { + o.Context = ctx + } +} + +func GetContext(ctx context.Context) GetOption { + return func(o *GetOptions) { + o.Context = ctx + } +} + +func ListContext(ctx context.Context) ListOption { + return func(o *ListOptions) { + o.Context = ctx + } +} diff --git a/registry/registry.go b/registry/registry.go index 291a1988..dce9b431 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -21,9 +21,9 @@ type Registry interface { Init(...Option) error Options() Options Register(*Service, ...RegisterOption) error - Deregister(*Service) error - GetService(string) ([]*Service, error) - ListServices() ([]*Service, error) + Deregister(*Service, ...DeregisterOption) error + GetService(string, ...GetOption) ([]*Service, error) + ListServices(...ListOption) ([]*Service, error) Watch(...WatchOption) (Watcher, error) String() string } @@ -61,6 +61,12 @@ type RegisterOption func(*RegisterOptions) type WatchOption func(*WatchOptions) +type DeregisterOption func(*DeregisterOptions) + +type GetOption func(*GetOptions) + +type ListOption func(*ListOptions) + // Register a service node. Additionally supply options such as TTL. func Register(s *Service, opts ...RegisterOption) error { return DefaultRegistry.Register(s, opts...) diff --git a/registry/service/service.go b/registry/service/service.go index e266347f..acf42d3a 100644 --- a/registry/service/service.go +++ b/registry/service/service.go @@ -58,13 +58,16 @@ func (s *serviceRegistry) Register(srv *registry.Service, opts ...registry.Regis for _, o := range opts { o(&options) } + if options.Context == nil { + options.Context = context.TODO() + } // encode srv into protobuf and pack Register TTL into it pbSrv := ToProto(srv) pbSrv.Options.Ttl = int64(options.TTL.Seconds()) // register the service - _, err := s.client.Register(context.TODO(), pbSrv, s.callOpts()...) + _, err := s.client.Register(options.Context, pbSrv, s.callOpts()...) if err != nil { return err } @@ -72,17 +75,33 @@ func (s *serviceRegistry) Register(srv *registry.Service, opts ...registry.Regis return nil } -func (s *serviceRegistry) Deregister(srv *registry.Service) error { +func (s *serviceRegistry) Deregister(srv *registry.Service, opts ...registry.DeregisterOption) error { + var options registry.DeregisterOptions + for _, o := range opts { + o(&options) + } + if options.Context == nil { + options.Context = context.TODO() + } + // deregister the service - _, err := s.client.Deregister(context.TODO(), ToProto(srv), s.callOpts()...) + _, err := s.client.Deregister(options.Context, ToProto(srv), s.callOpts()...) if err != nil { return err } return nil } -func (s *serviceRegistry) GetService(name string) ([]*registry.Service, error) { - rsp, err := s.client.GetService(context.TODO(), &pb.GetRequest{ +func (s *serviceRegistry) GetService(name string, opts ...registry.GetOption) ([]*registry.Service, error) { + var options registry.GetOptions + for _, o := range opts { + o(&options) + } + if options.Context == nil { + options.Context = context.TODO() + } + + rsp, err := s.client.GetService(options.Context, &pb.GetRequest{ Service: name, }, s.callOpts()...) @@ -97,8 +116,16 @@ func (s *serviceRegistry) GetService(name string) ([]*registry.Service, error) { return services, nil } -func (s *serviceRegistry) ListServices() ([]*registry.Service, error) { - rsp, err := s.client.ListServices(context.TODO(), &pb.ListRequest{}, s.callOpts()...) +func (s *serviceRegistry) ListServices(opts ...registry.ListOption) ([]*registry.Service, error) { + var options registry.ListOptions + for _, o := range opts { + o(&options) + } + if options.Context == nil { + options.Context = context.TODO() + } + + rsp, err := s.client.ListServices(options.Context, &pb.ListRequest{}, s.callOpts()...) if err != nil { return nil, err } @@ -116,8 +143,11 @@ func (s *serviceRegistry) Watch(opts ...registry.WatchOption) (registry.Watcher, for _, o := range opts { o(&options) } + if options.Context == nil { + options.Context = context.TODO() + } - stream, err := s.client.Watch(context.TODO(), &pb.WatchRequest{ + stream, err := s.client.Watch(options.Context, &pb.WatchRequest{ Service: options.Service, }, s.callOpts()...) diff --git a/util/wrapper/wrapper.go b/util/wrapper/wrapper.go index 414a4c6b..e5ce4bb1 100644 --- a/util/wrapper/wrapper.go +++ b/util/wrapper/wrapper.go @@ -183,11 +183,6 @@ func AuthHandler(fn func() auth.Auth) server.HandlerWrapper { return err } - // Set the namespace in the context - if _, ok := metadata.Get(ctx, auth.NamespaceKey); !ok { - ctx = metadata.Set(ctx, auth.NamespaceKey, a.Options().Namespace) - } - // The user is authorised, allow the call return h(ctx, req, rsp) }