From 5d14970a55ab954c05f4dceabd0b7f6a146e6995 Mon Sep 17 00:00:00 2001 From: Ben Toogood Date: Wed, 20 May 2020 16:11:34 +0100 Subject: [PATCH] Fix nil account bug --- auth/auth.go | 5 +++-- auth/auth_test.go | 6 +++--- util/wrapper/wrapper.go | 4 ++-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index 0a03eda5..3651fa92 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -4,6 +4,7 @@ package auth import ( "context" "errors" + "strings" "time" ) @@ -60,13 +61,13 @@ type Account struct { } // HasScope returns a boolean indicating if the account has the given scope -func (a *Account) HasScope(scope string) bool { +func (a *Account) HasScope(scopes ...string) bool { if a.Scopes == nil { return false } for _, s := range a.Scopes { - if s == scope { + if s == strings.Join(scopes, ".") { return true } } diff --git a/auth/auth_test.go b/auth/auth_test.go index 50f3a990..7985ff76 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -3,15 +3,15 @@ package auth import "testing" func TestHasScope(t *testing.T) { - if new(Account).HasScope("namespace.foo") { + if new(Account).HasScope("namespace", "foo") { t.Errorf("Expected the blank account to not have a role") } acc := Account{Scopes: []string{"namespace.foo"}} - if !acc.HasScope("namespace.foo") { + if !acc.HasScope("namespace", "foo") { t.Errorf("Expected the account to have the namespace.foo role") } - if acc.HasScope("namespace.bar") { + if acc.HasScope("namespace", "bar") { t.Errorf("Expected the account to not have the namespace.bar role") } } diff --git a/util/wrapper/wrapper.go b/util/wrapper/wrapper.go index de62288d..bf0f4a3a 100644 --- a/util/wrapper/wrapper.go +++ b/util/wrapper/wrapper.go @@ -208,14 +208,14 @@ func AuthHandler(fn func() auth.Auth) server.HandlerWrapper { // Verify the caller has access to the resource err := a.Verify(account, res) - if err != nil && len(account.ID) > 0 { + if err != nil && account != nil { return errors.Forbidden(req.Service(), "Forbidden call made to %v:%v by %v", req.Service(), req.Endpoint(), account.ID) } else if err != nil { return errors.Unauthorized(req.Service(), "Unauthorised call made to %v:%v", req.Service(), req.Endpoint()) } // There is an account, set it in the context - if len(account.ID) > 0 { + if account != nil { ctx = auth.ContextWithAccount(ctx, account) }