Fix nil account bug

This commit is contained in:
Ben Toogood 2020-05-20 16:11:34 +01:00
parent f6d9416a9e
commit 5d14970a55
3 changed files with 8 additions and 7 deletions

View File

@ -4,6 +4,7 @@ package auth
import ( import (
"context" "context"
"errors" "errors"
"strings"
"time" "time"
) )
@ -60,13 +61,13 @@ type Account struct {
} }
// HasScope returns a boolean indicating if the account has the given scope // HasScope returns a boolean indicating if the account has the given scope
func (a *Account) HasScope(scope string) bool { func (a *Account) HasScope(scopes ...string) bool {
if a.Scopes == nil { if a.Scopes == nil {
return false return false
} }
for _, s := range a.Scopes { for _, s := range a.Scopes {
if s == scope { if s == strings.Join(scopes, ".") {
return true return true
} }
} }

View File

@ -3,15 +3,15 @@ package auth
import "testing" import "testing"
func TestHasScope(t *testing.T) { func TestHasScope(t *testing.T) {
if new(Account).HasScope("namespace.foo") { if new(Account).HasScope("namespace", "foo") {
t.Errorf("Expected the blank account to not have a role") t.Errorf("Expected the blank account to not have a role")
} }
acc := Account{Scopes: []string{"namespace.foo"}} acc := Account{Scopes: []string{"namespace.foo"}}
if !acc.HasScope("namespace.foo") { if !acc.HasScope("namespace", "foo") {
t.Errorf("Expected the account to have the namespace.foo role") t.Errorf("Expected the account to have the namespace.foo role")
} }
if acc.HasScope("namespace.bar") { if acc.HasScope("namespace", "bar") {
t.Errorf("Expected the account to not have the namespace.bar role") t.Errorf("Expected the account to not have the namespace.bar role")
} }
} }

View File

@ -208,14 +208,14 @@ func AuthHandler(fn func() auth.Auth) server.HandlerWrapper {
// Verify the caller has access to the resource // Verify the caller has access to the resource
err := a.Verify(account, res) err := a.Verify(account, res)
if err != nil && len(account.ID) > 0 { if err != nil && account != nil {
return errors.Forbidden(req.Service(), "Forbidden call made to %v:%v by %v", req.Service(), req.Endpoint(), account.ID) return errors.Forbidden(req.Service(), "Forbidden call made to %v:%v by %v", req.Service(), req.Endpoint(), account.ID)
} else if err != nil { } else if err != nil {
return errors.Unauthorized(req.Service(), "Unauthorised call made to %v:%v", req.Service(), req.Endpoint()) return errors.Unauthorized(req.Service(), "Unauthorised call made to %v:%v", req.Service(), req.Endpoint())
} }
// There is an account, set it in the context // There is an account, set it in the context
if len(account.ID) > 0 { if account != nil {
ctx = auth.ContextWithAccount(ctx, account) ctx = auth.ContextWithAccount(ctx, account)
} }