Auth Wrapper (#1174)

* Auth Wrapper

* Tweak cmd flag

* auth_excludes => auth_exclude

* Make Auth.Excludes variadic

* Use metadata.Get (passes through http and http2 it will go through various case formats)

* fix auth wrapper auth.Auth interface initialisation

Co-authored-by: Asim Aslam <asim@aslam.me>
This commit is contained in:
ben-toogood 2020-02-10 08:26:28 +00:00 committed by GitHub
parent c706afcf04
commit 4401c12e6c
9 changed files with 111 additions and 18 deletions

View File

@ -11,6 +11,8 @@ type Auth interface {
String() string
// Init the auth package
Init(opts ...Option) error
// Options returns the options set
Options() Options
// Generate a new auth Account
Generate(id string, opts ...GenerateOption) (*Account, error)
// Revoke an authorization Account

View File

@ -18,6 +18,11 @@ func (a *noop) Init(...Option) error {
return nil
}
// Options set in init
func (a *noop) Options() Options {
return a.options
}
// Generate a new auth Account
func (a *noop) Generate(id string, ops ...GenerateOption) (*Account, error) {
return nil, nil

View File

@ -33,6 +33,10 @@ func (s *svc) String() string {
return "jwt"
}
func (s *svc) Options() auth.Options {
return s.options
}
func (s *svc) Init(opts ...auth.Option) error {
for _, o := range opts {
o(&s.options)

View File

@ -7,10 +7,18 @@ import (
type Options struct {
PublicKey []byte
PrivateKey []byte
Excludes []string
}
type Option func(o *Options)
// Excludes endpoints from auth
func Excludes(excludes ...string) Option {
return func(o *Options) {
o.Excludes = excludes
}
}
// PublicKey is the JWT public key
func PublicKey(key string) Option {
return func(o *Options) {

View File

@ -37,6 +37,10 @@ func (s *svc) Init(opts ...auth.Option) error {
return nil
}
func (s *svc) Options() auth.Options {
return s.options
}
// Generate a new auth account
func (s *svc) Generate(id string, opts ...auth.GenerateOption) (*auth.Account, error) {
// construct the request

View File

@ -254,6 +254,11 @@ var (
EnvVars: []string{"MICRO_AUTH_PRIVATE_KEY"},
Usage: "Private key for JWT auth (base64 encoded PEM)",
},
&cli.StringSliceFlag{
Name: "auth_exclude",
EnvVars: []string{"MICRO_AUTH_EXCLUDE"},
Usage: "Comma-separated list of endpoints excluded from authentication",
},
}
DefaultBrokers = map[string]func(...broker.Option) broker.Broker{
@ -319,6 +324,7 @@ func init() {
func newCmd(opts ...Option) Cmd {
options := Options{
Auth: &auth.DefaultAuth,
Broker: &broker.DefaultBroker,
Client: &client.DefaultClient,
Registry: &registry.DefaultRegistry,
@ -328,7 +334,6 @@ func newCmd(opts ...Option) Cmd {
Runtime: &runtime.DefaultRuntime,
Store: &store.DefaultStore,
Tracer: &trace.DefaultTracer,
Auth: &auth.DefaultAuth,
Brokers: DefaultBrokers,
Clients: DefaultClients,
@ -379,6 +384,7 @@ func (c *cmd) Options() Options {
func (c *cmd) Before(ctx *cli.Context) error {
// If flags are set then use them otherwise do nothing
var authOpts []auth.Option
var serverOpts []server.Option
var clientOpts []client.Option
@ -414,12 +420,12 @@ func (c *cmd) Before(ctx *cli.Context) error {
// Set the auth
if name := ctx.String("auth"); len(name) > 0 {
r, ok := c.opts.Auths[name]
a, ok := c.opts.Auths[name]
if !ok {
return fmt.Errorf("Unsupported auth: %s", name)
}
*c.opts.Auth = r()
*c.opts.Auth = a()
}
// Set the client
@ -571,24 +577,30 @@ func (c *cmd) Before(ctx *cli.Context) error {
serverOpts = append(serverOpts, server.RegisterInterval(val*time.Second))
}
if len(ctx.String("auth_public_key")) > 0 {
if err := (*c.opts.Auth).Init(auth.PublicKey(ctx.String("auth_public_key"))); err != nil {
log.Fatalf("Error configuring auth: %v", err)
}
}
if len(ctx.String("auth_private_key")) > 0 {
if err := (*c.opts.Auth).Init(auth.PrivateKey(ctx.String("auth_private_key"))); err != nil {
log.Fatalf("Error configuring auth: %v", err)
}
}
if len(ctx.String("runtime_source")) > 0 {
if err := (*c.opts.Runtime).Init(runtime.WithSource(ctx.String("runtime_source"))); err != nil {
log.Fatalf("Error configuring runtime: %v", err)
}
}
if len(ctx.String("auth_public_key")) > 0 {
authOpts = append(authOpts, auth.PublicKey(ctx.String("auth_public_key")))
}
if len(ctx.String("auth_private_key")) > 0 {
authOpts = append(authOpts, auth.PrivateKey(ctx.String("auth_private_key")))
}
if len(ctx.StringSlice("auth_exclude")) > 0 {
authOpts = append(authOpts, auth.Excludes(ctx.StringSlice("auth_exclude")...))
}
if len(authOpts) > 0 {
if err := (*c.opts.Auth).Init(authOpts...); err != nil {
log.Fatalf("Error configuring auth: %v", err)
}
}
// client opts
if r := ctx.Int("client_retries"); r >= 0 {
clientOpts = append(clientOpts, client.Retries(r))

View File

@ -18,6 +18,7 @@ import (
// Options for micro service
type Options struct {
Auth auth.Auth
Broker broker.Broker
Cmd cmd.Cmd
Client client.Client
@ -40,6 +41,7 @@ type Options struct {
func newOptions(opts ...Option) Options {
opt := Options{
Auth: auth.DefaultAuth,
Broker: broker.DefaultBroker,
Cmd: cmd.DefaultCmd,
Client: client.DefaultClient,
@ -127,6 +129,7 @@ func Tracer(t trace.Tracer) Option {
// Auth sets the auth for the service
func Auth(a auth.Auth) Option {
return func(o *Options) {
o.Auth = a
o.Server.Init(server.Auth(a))
}
}

View File

@ -8,6 +8,7 @@ import (
"sync"
"syscall"
"github.com/micro/go-micro/v2/auth"
"github.com/micro/go-micro/v2/client"
"github.com/micro/go-micro/v2/config/cmd"
"github.com/micro/go-micro/v2/debug/profile"
@ -29,11 +30,15 @@ type service struct {
}
func newService(opts ...Option) Service {
service := new(service)
options := newOptions(opts...)
// service name
serviceName := options.Server.Options().Name
// TODO: better accessors
authFn := func() auth.Auth { return service.opts.Auth }
// wrap client to inject From-Service header on any calls
options.Client = wrapper.FromService(serviceName, options.Client)
options.Client = wrapper.TraceCall(serviceName, trace.DefaultTracer, options.Client)
@ -42,11 +47,13 @@ func newService(opts ...Option) Service {
options.Server.Init(
server.WrapHandler(wrapper.HandlerStats(stats.DefaultStats)),
server.WrapHandler(wrapper.TraceHandler(trace.DefaultTracer)),
server.WrapHandler(wrapper.AuthHandler(authFn)),
)
return &service{
opts: options,
}
// set opts
service.opts = options
return service
}
func (s *service) Name() string {
@ -88,6 +95,7 @@ func (s *service) Init(opts ...Option) {
// Initialise the command flags, overriding new service
if err := s.opts.Cmd.Init(
cmd.Auth(&s.opts.Auth),
cmd.Broker(&s.opts.Broker),
cmd.Registry(&s.opts.Registry),
cmd.Transport(&s.opts.Transport),

View File

@ -4,9 +4,11 @@ import (
"context"
"strings"
"github.com/micro/go-micro/v2/auth"
"github.com/micro/go-micro/v2/client"
"github.com/micro/go-micro/v2/debug/stats"
"github.com/micro/go-micro/v2/debug/trace"
"github.com/micro/go-micro/v2/errors"
"github.com/micro/go-micro/v2/metadata"
"github.com/micro/go-micro/v2/server"
)
@ -25,6 +27,7 @@ type traceWrapper struct {
var (
HeaderPrefix = "Micro-"
BearerSchema = "Bearer "
)
func (c *clientWrapper) setHeaders(ctx context.Context) context.Context {
@ -132,3 +135,47 @@ func TraceHandler(t trace.Tracer) server.HandlerWrapper {
}
}
}
// AuthHandler wraps a server handler to perform auth
func AuthHandler(fn func() auth.Auth) server.HandlerWrapper {
return func(h server.HandlerFunc) server.HandlerFunc {
return func(ctx context.Context, req server.Request, rsp interface{}) error {
// get the auth.Auth interface
a := fn()
// Extract endpoint and remove service name prefix
// (e.g. Platform.ListServices => ListServices)
var endpoint string
if ec := strings.Split(req.Endpoint(), "."); len(ec) == 2 {
endpoint = ec[1]
}
// Check for endpoints excluded from auth. If the endpoint
// matches, execute the handler and return
for _, e := range a.Options().Excludes {
if e == endpoint {
return h(ctx, req, rsp)
}
}
// Extract the token if present. Note: if noop is being used
// then the token can be blank without erroring
var token string
if header, ok := metadata.Get(ctx, "Authorization"); ok {
// Ensure the correct scheme is being used
if !strings.HasPrefix(header, BearerSchema) {
return errors.Unauthorized("go.micro.auth", "invalid authorization header. expected Bearer schema")
}
token = header[len(BearerSchema):]
}
// Validate the token
if _, err := a.Validate(token); err != nil {
return err
}
return h(ctx, req, rsp)
}
}
}