Add context options to the runtime

This commit is contained in:
Ben Toogood 2020-04-14 12:32:59 +01:00
parent 0c75a0306b
commit e17825474f
10 changed files with 112 additions and 43 deletions

View File

@ -90,8 +90,6 @@ type Token struct {
const ( const (
// DefaultNamespace used for auth // DefaultNamespace used for auth
DefaultNamespace = "go.micro" DefaultNamespace = "go.micro"
// NamespaceKey is the key used when storing the namespace in metadata
NamespaceKey = "Micro-Namespace"
// MetadataKey is the key used when storing the account in metadata // MetadataKey is the key used when storing the account in metadata
MetadataKey = "auth-account" MetadataKey = "auth-account"
// TokenCookieName is the name of the cookie which stores the auth token // TokenCookieName is the name of the cookie which stores the auth token
@ -133,11 +131,3 @@ func ContextWithAccount(ctx context.Context, account *Account) (context.Context,
// generate a new context with the MetadataKey set // generate a new context with the MetadataKey set
return metadata.Set(ctx, MetadataKey, string(bytes)), nil return metadata.Set(ctx, MetadataKey, string(bytes)), nil
} }
// NamespaceFromContext gets the namespace from the context
func NamespaceFromContext(ctx context.Context) string {
if ns, ok := metadata.Get(ctx, NamespaceKey); ok {
return ns
}
return DefaultNamespace
}

View File

@ -53,8 +53,9 @@ func (s *svc) Init(opts ...auth.Option) {
} }
// load rules periodically from the auth service // load rules periodically from the auth service
ruleTimer := time.NewTicker(time.Second * 30)
go func() { go func() {
ruleTimer := time.NewTicker(time.Second * 30)
// load rules immediately on startup // load rules immediately on startup
s.loadRules() s.loadRules()
@ -72,10 +73,11 @@ func (s *svc) Init(opts ...auth.Option) {
// we have client credentials and must load a new token // we have client credentials and must load a new token
// periodically // periodically
if len(s.options.ID) > 0 || len(s.options.Secret) > 0 { if len(s.options.ID) > 0 || len(s.options.Secret) > 0 {
tokenTimer := time.NewTicker(time.Minute) // get a token immediately
s.refreshToken()
go func() { go func() {
s.refreshToken() tokenTimer := time.NewTicker(time.Minute)
for { for {
<-tokenTimer.C <-tokenTimer.C
@ -178,11 +180,15 @@ func (s *svc) Verify(acc *auth.Account, res *auth.Resource) error {
} }
} }
// set a default account id to log // set a default account id / namespace to log
logID := acc.ID logID := acc.ID
if len(logID) == 0 { if len(logID) == 0 {
logID = "[no account]" logID = "[no account]"
} }
logNamespace := acc.Namespace
if len(logNamespace) == 0 {
logNamespace = "[no namespace]"
}
for _, q := range queries { for _, q := range queries {
for _, rule := range s.listRules(q...) { for _, rule := range s.listRules(q...) {
@ -190,17 +196,17 @@ func (s *svc) Verify(acc *auth.Account, res *auth.Resource) error {
case pb.Access_UNKNOWN: case pb.Access_UNKNOWN:
continue // rule did not specify access, check the next rule continue // rule did not specify access, check the next rule
case pb.Access_GRANTED: case pb.Access_GRANTED:
log.Infof("%v:%v granted access to %v:%v:%v:%v by rule %v", acc.Namespace, logID, res.Namespace, res.Type, res.Name, res.Endpoint, rule.Id) log.Tracef("%v:%v granted access to %v:%v:%v:%v by rule %v", logNamespace, logID, res.Namespace, res.Type, res.Name, res.Endpoint, rule.Id)
return nil // rule grants the account access to the resource return nil // rule grants the account access to the resource
case pb.Access_DENIED: case pb.Access_DENIED:
log.Infof("%v:%v denied access to %v:%v:%v:%v by rule %v", acc.Namespace, logID, res.Namespace, res.Type, res.Name, res.Endpoint, rule.Id) log.Tracef("%v:%v denied access to %v:%v:%v:%v by rule %v", logNamespace, logID, res.Namespace, res.Type, res.Name, res.Endpoint, rule.Id)
return auth.ErrForbidden // rule denies access to the resource return auth.ErrForbidden // rule denies access to the resource
} }
} }
} }
// no rules were found for the resource, default to denying access // no rules were found for the resource, default to denying access
log.Infof("%v:%v denied access to %v:%v:%v:%v by lack of rule (%v rules found for namespace)", acc.Namespace, logID, res.Namespace, res.Type, res.Name, res.Endpoint, len(s.listRules(res.Namespace))) log.Tracef("%v:%v denied access to %v:%v:%v:%v by lack of rule (%v rules found for namespace)", logNamespace, logID, res.Namespace, res.Type, res.Name, res.Endpoint, len(s.listRules(res.Namespace)))
return auth.ErrForbidden return auth.ErrForbidden
} }

View File

@ -419,7 +419,7 @@ func (c *cache) watch(w registry.Watcher) error {
} }
} }
func (c *cache) GetService(service string) ([]*registry.Service, error) { func (c *cache) GetService(service string, opts ...registry.GetOption) ([]*registry.Service, error) {
// get the service // get the service
services, err := c.get(service) services, err := c.get(service)
if err != nil { if err != nil {

View File

@ -276,7 +276,7 @@ func (e *etcdRegistry) registerNode(s *registry.Service, node *registry.Node, op
return nil return nil
} }
func (e *etcdRegistry) Deregister(s *registry.Service) error { func (e *etcdRegistry) Deregister(s *registry.Service, opts ...registry.DeregisterOption) error {
if len(s.Nodes) == 0 { if len(s.Nodes) == 0 {
return errors.New("Require at least one node") return errors.New("Require at least one node")
} }
@ -322,7 +322,7 @@ func (e *etcdRegistry) Register(s *registry.Service, opts ...registry.RegisterOp
return gerr return gerr
} }
func (e *etcdRegistry) GetService(name string) ([]*registry.Service, error) { func (e *etcdRegistry) GetService(name string, opts ...registry.GetOption) ([]*registry.Service, error) {
ctx, cancel := context.WithTimeout(context.Background(), e.options.Timeout) ctx, cancel := context.WithTimeout(context.Background(), e.options.Timeout)
defer cancel() defer cancel()
@ -362,7 +362,7 @@ func (e *etcdRegistry) GetService(name string) ([]*registry.Service, error) {
return services, nil return services, nil
} }
func (e *etcdRegistry) ListServices() ([]*registry.Service, error) { func (e *etcdRegistry) ListServices(opts ...registry.ListOption) ([]*registry.Service, error) {
versions := make(map[string]*registry.Service) versions := make(map[string]*registry.Service)
ctx, cancel := context.WithTimeout(context.Background(), e.options.Timeout) ctx, cancel := context.WithTimeout(context.Background(), e.options.Timeout)

View File

@ -269,7 +269,7 @@ func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error
return gerr return gerr
} }
func (m *mdnsRegistry) Deregister(service *Service) error { func (m *mdnsRegistry) Deregister(service *Service, opts ...DeregisterOption) error {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
@ -304,7 +304,7 @@ func (m *mdnsRegistry) Deregister(service *Service) error {
return nil return nil
} }
func (m *mdnsRegistry) GetService(service string) ([]*Service, error) { func (m *mdnsRegistry) GetService(service string, opts ...GetOption) ([]*Service, error) {
serviceMap := make(map[string]*Service) serviceMap := make(map[string]*Service)
entries := make(chan *mdns.ServiceEntry, 10) entries := make(chan *mdns.ServiceEntry, 10)
done := make(chan bool) done := make(chan bool)
@ -396,7 +396,7 @@ func (m *mdnsRegistry) GetService(service string) ([]*Service, error) {
return services, nil return services, nil
} }
func (m *mdnsRegistry) ListServices() ([]*Service, error) { func (m *mdnsRegistry) ListServices(opts ...ListOption) ([]*Service, error) {
serviceMap := make(map[string]bool) serviceMap := make(map[string]bool)
entries := make(chan *mdns.ServiceEntry, 10) entries := make(chan *mdns.ServiceEntry, 10)
done := make(chan bool) done := make(chan bool)

View File

@ -207,7 +207,7 @@ func (m *Registry) Register(s *registry.Service, opts ...registry.RegisterOption
return nil return nil
} }
func (m *Registry) Deregister(s *registry.Service) error { func (m *Registry) Deregister(s *registry.Service, opts ...registry.DeregisterOption) error {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
@ -240,7 +240,7 @@ func (m *Registry) Deregister(s *registry.Service) error {
return nil return nil
} }
func (m *Registry) GetService(name string) ([]*registry.Service, error) { func (m *Registry) GetService(name string, opts ...registry.GetOption) ([]*registry.Service, error) {
m.RLock() m.RLock()
defer m.RUnlock() defer m.RUnlock()
@ -259,7 +259,7 @@ func (m *Registry) GetService(name string) ([]*registry.Service, error) {
return services, nil return services, nil
} }
func (m *Registry) ListServices() ([]*registry.Service, error) { func (m *Registry) ListServices(opts ...registry.ListOption) ([]*registry.Service, error) {
m.RLock() m.RLock()
defer m.RUnlock() defer m.RUnlock()

View File

@ -32,6 +32,18 @@ type WatchOptions struct {
Context context.Context Context context.Context
} }
type DeregisterOptions struct {
Context context.Context
}
type GetOptions struct {
Context context.Context
}
type ListOptions struct {
Context context.Context
}
// Addrs is the registry addresses to use // Addrs is the registry addresses to use
func Addrs(addrs ...string) Option { func Addrs(addrs ...string) Option {
return func(o *Options) { return func(o *Options) {
@ -65,9 +77,39 @@ func RegisterTTL(t time.Duration) RegisterOption {
} }
} }
func RegisterContext(ctx context.Context) RegisterOption {
return func(o *RegisterOptions) {
o.Context = ctx
}
}
// Watch a service // Watch a service
func WatchService(name string) WatchOption { func WatchService(name string) WatchOption {
return func(o *WatchOptions) { return func(o *WatchOptions) {
o.Service = name o.Service = name
} }
} }
func WatchContext(ctx context.Context) WatchOption {
return func(o *WatchOptions) {
o.Context = ctx
}
}
func DeregisterContext(ctx context.Context) DeregisterOption {
return func(o *DeregisterOptions) {
o.Context = ctx
}
}
func GetContext(ctx context.Context) GetOption {
return func(o *GetOptions) {
o.Context = ctx
}
}
func ListContext(ctx context.Context) ListOption {
return func(o *ListOptions) {
o.Context = ctx
}
}

View File

@ -21,9 +21,9 @@ type Registry interface {
Init(...Option) error Init(...Option) error
Options() Options Options() Options
Register(*Service, ...RegisterOption) error Register(*Service, ...RegisterOption) error
Deregister(*Service) error Deregister(*Service, ...DeregisterOption) error
GetService(string) ([]*Service, error) GetService(string, ...GetOption) ([]*Service, error)
ListServices() ([]*Service, error) ListServices(...ListOption) ([]*Service, error)
Watch(...WatchOption) (Watcher, error) Watch(...WatchOption) (Watcher, error)
String() string String() string
} }
@ -61,6 +61,12 @@ type RegisterOption func(*RegisterOptions)
type WatchOption func(*WatchOptions) type WatchOption func(*WatchOptions)
type DeregisterOption func(*DeregisterOptions)
type GetOption func(*GetOptions)
type ListOption func(*ListOptions)
// Register a service node. Additionally supply options such as TTL. // Register a service node. Additionally supply options such as TTL.
func Register(s *Service, opts ...RegisterOption) error { func Register(s *Service, opts ...RegisterOption) error {
return DefaultRegistry.Register(s, opts...) return DefaultRegistry.Register(s, opts...)

View File

@ -58,13 +58,16 @@ func (s *serviceRegistry) Register(srv *registry.Service, opts ...registry.Regis
for _, o := range opts { for _, o := range opts {
o(&options) o(&options)
} }
if options.Context == nil {
options.Context = context.TODO()
}
// encode srv into protobuf and pack Register TTL into it // encode srv into protobuf and pack Register TTL into it
pbSrv := ToProto(srv) pbSrv := ToProto(srv)
pbSrv.Options.Ttl = int64(options.TTL.Seconds()) pbSrv.Options.Ttl = int64(options.TTL.Seconds())
// register the service // register the service
_, err := s.client.Register(context.TODO(), pbSrv, s.callOpts()...) _, err := s.client.Register(options.Context, pbSrv, s.callOpts()...)
if err != nil { if err != nil {
return err return err
} }
@ -72,17 +75,33 @@ func (s *serviceRegistry) Register(srv *registry.Service, opts ...registry.Regis
return nil return nil
} }
func (s *serviceRegistry) Deregister(srv *registry.Service) error { func (s *serviceRegistry) Deregister(srv *registry.Service, opts ...registry.DeregisterOption) error {
var options registry.DeregisterOptions
for _, o := range opts {
o(&options)
}
if options.Context == nil {
options.Context = context.TODO()
}
// deregister the service // deregister the service
_, err := s.client.Deregister(context.TODO(), ToProto(srv), s.callOpts()...) _, err := s.client.Deregister(options.Context, ToProto(srv), s.callOpts()...)
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
func (s *serviceRegistry) GetService(name string) ([]*registry.Service, error) { func (s *serviceRegistry) GetService(name string, opts ...registry.GetOption) ([]*registry.Service, error) {
rsp, err := s.client.GetService(context.TODO(), &pb.GetRequest{ var options registry.GetOptions
for _, o := range opts {
o(&options)
}
if options.Context == nil {
options.Context = context.TODO()
}
rsp, err := s.client.GetService(options.Context, &pb.GetRequest{
Service: name, Service: name,
}, s.callOpts()...) }, s.callOpts()...)
@ -97,8 +116,16 @@ func (s *serviceRegistry) GetService(name string) ([]*registry.Service, error) {
return services, nil return services, nil
} }
func (s *serviceRegistry) ListServices() ([]*registry.Service, error) { func (s *serviceRegistry) ListServices(opts ...registry.ListOption) ([]*registry.Service, error) {
rsp, err := s.client.ListServices(context.TODO(), &pb.ListRequest{}, s.callOpts()...) var options registry.ListOptions
for _, o := range opts {
o(&options)
}
if options.Context == nil {
options.Context = context.TODO()
}
rsp, err := s.client.ListServices(options.Context, &pb.ListRequest{}, s.callOpts()...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -116,8 +143,11 @@ func (s *serviceRegistry) Watch(opts ...registry.WatchOption) (registry.Watcher,
for _, o := range opts { for _, o := range opts {
o(&options) o(&options)
} }
if options.Context == nil {
options.Context = context.TODO()
}
stream, err := s.client.Watch(context.TODO(), &pb.WatchRequest{ stream, err := s.client.Watch(options.Context, &pb.WatchRequest{
Service: options.Service, Service: options.Service,
}, s.callOpts()...) }, s.callOpts()...)

View File

@ -183,11 +183,6 @@ func AuthHandler(fn func() auth.Auth) server.HandlerWrapper {
return err return err
} }
// Set the namespace in the context
if _, ok := metadata.Get(ctx, auth.NamespaceKey); !ok {
ctx = metadata.Set(ctx, auth.NamespaceKey, a.Options().Namespace)
}
// The user is authorised, allow the call // The user is authorised, allow the call
return h(ctx, req, rsp) return h(ctx, req, rsp)
} }