Merge pull request #1595 from micro/auth-client-wrapper
Auth Client Wrapper
This commit is contained in:
		| @@ -79,6 +79,13 @@ func Credentials(id, secret string) Option { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // ClientToken sets the auth token to use when making requests | ||||||
|  | func ClientToken(token *Token) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.Token = token | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| // Provider set the auth provider | // Provider set the auth provider | ||||||
| func Provider(p provider.Provider) Option { | func Provider(p provider.Provider) Option { | ||||||
| 	return func(o *Options) { | 	return func(o *Options) { | ||||||
|   | |||||||
| @@ -70,35 +70,6 @@ func (s *svc) Init(opts ...auth.Option) { | |||||||
| 			s.loadRules() | 			s.loadRules() | ||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
| 	// we have client credentials and must load a new token |  | ||||||
| 	// periodically |  | ||||||
| 	if len(s.options.ID) > 0 || len(s.options.Secret) > 0 { |  | ||||||
| 		// get a token immediately |  | ||||||
| 		s.refreshToken() |  | ||||||
|  |  | ||||||
| 		go func() { |  | ||||||
| 			tokenTimer := time.NewTicker(time.Minute) |  | ||||||
|  |  | ||||||
| 			for { |  | ||||||
| 				<-tokenTimer.C |  | ||||||
|  |  | ||||||
| 				// Do not get a new token if the current one has more than three |  | ||||||
| 				// minutes remaining. We do 3 minutes to allow multiple retires in |  | ||||||
| 				// the case one request fails |  | ||||||
| 				t := s.Options().Token |  | ||||||
| 				if t != nil && t.Expiry.Unix() > time.Now().Add(time.Minute*3).Unix() { |  | ||||||
| 					continue |  | ||||||
| 				} |  | ||||||
|  |  | ||||||
| 				// jitter for up to 5 seconds, this stops |  | ||||||
| 				// all the services calling the auth service |  | ||||||
| 				// at the exact same time |  | ||||||
| 				time.Sleep(jitter.Do(time.Second * 5)) |  | ||||||
| 				s.refreshToken() |  | ||||||
| 			} |  | ||||||
| 		}() |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *svc) Options() auth.Options { | func (s *svc) Options() auth.Options { | ||||||
| @@ -313,33 +284,6 @@ func (s *svc) loadRules() { | |||||||
| 	s.rules = rsp.Rules | 	s.rules = rsp.Rules | ||||||
| } | } | ||||||
|  |  | ||||||
| // refreshToken generates a new token for the service to use when making calls |  | ||||||
| func (s *svc) refreshToken() { |  | ||||||
| 	req := &pb.TokenRequest{ |  | ||||||
| 		TokenExpiry: int64((time.Minute * 15).Seconds()), |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if s.Options().Token == nil { |  | ||||||
| 		// we do not have a token, use the credentials to get one |  | ||||||
| 		req.Id = s.Options().ID |  | ||||||
| 		req.Secret = s.Options().Secret |  | ||||||
| 	} else { |  | ||||||
| 		// we have a token, refresh it |  | ||||||
| 		req.RefreshToken = s.Options().Token.RefreshToken |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	rsp, err := s.auth.Token(context.TODO(), req) |  | ||||||
| 	s.Lock() |  | ||||||
| 	defer s.Unlock() |  | ||||||
|  |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Errorf("Error generating token: %v", err) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	s.options.Token = serializeToken(rsp.Token) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func serializeToken(t *pb.Token) *auth.Token { | func serializeToken(t *pb.Token) *auth.Token { | ||||||
| 	return &auth.Token{ | 	return &auth.Token{ | ||||||
| 		AccessToken:  t.AccessToken, | 		AccessToken:  t.AccessToken, | ||||||
|   | |||||||
| @@ -10,7 +10,6 @@ import ( | |||||||
| 	"sync/atomic" | 	"sync/atomic" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/micro/go-micro/v2/auth" |  | ||||||
| 	"github.com/micro/go-micro/v2/broker" | 	"github.com/micro/go-micro/v2/broker" | ||||||
| 	"github.com/micro/go-micro/v2/client" | 	"github.com/micro/go-micro/v2/client" | ||||||
| 	"github.com/micro/go-micro/v2/client/selector" | 	"github.com/micro/go-micro/v2/client/selector" | ||||||
| @@ -18,7 +17,6 @@ import ( | |||||||
| 	"github.com/micro/go-micro/v2/errors" | 	"github.com/micro/go-micro/v2/errors" | ||||||
| 	"github.com/micro/go-micro/v2/metadata" | 	"github.com/micro/go-micro/v2/metadata" | ||||||
| 	"github.com/micro/go-micro/v2/registry" | 	"github.com/micro/go-micro/v2/registry" | ||||||
| 	"github.com/micro/go-micro/v2/util/config" |  | ||||||
| 	pnet "github.com/micro/go-micro/v2/util/net" | 	pnet "github.com/micro/go-micro/v2/util/net" | ||||||
|  |  | ||||||
| 	"google.golang.org/grpc" | 	"google.golang.org/grpc" | ||||||
| @@ -117,13 +115,6 @@ func (g *grpcClient) call(ctx context.Context, node *registry.Node, req client.R | |||||||
| 	// set the content type for the request | 	// set the content type for the request | ||||||
| 	header["x-content-type"] = req.ContentType() | 	header["x-content-type"] = req.ContentType() | ||||||
|  |  | ||||||
| 	// set the authorization header |  | ||||||
| 	if opts.ServiceToken || len(header["authorization"]) == 0 { |  | ||||||
| 		if h := g.authorizationHeader(); len(h) > 0 { |  | ||||||
| 			header["authorization"] = h |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	md := gmetadata.New(header) | 	md := gmetadata.New(header) | ||||||
| 	ctx = gmetadata.NewOutgoingContext(ctx, md) | 	ctx = gmetadata.NewOutgoingContext(ctx, md) | ||||||
|  |  | ||||||
| @@ -202,13 +193,6 @@ func (g *grpcClient) stream(ctx context.Context, node *registry.Node, req client | |||||||
| 	// set the content type for the request | 	// set the content type for the request | ||||||
| 	header["x-content-type"] = req.ContentType() | 	header["x-content-type"] = req.ContentType() | ||||||
|  |  | ||||||
| 	// set the authorization header |  | ||||||
| 	if opts.ServiceToken || len(header["authorization"]) == 0 { |  | ||||||
| 		if h := g.authorizationHeader(); len(h) > 0 { |  | ||||||
| 			header["authorization"] = h |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	md := gmetadata.New(header) | 	md := gmetadata.New(header) | ||||||
| 	ctx = gmetadata.NewOutgoingContext(ctx, md) | 	ctx = gmetadata.NewOutgoingContext(ctx, md) | ||||||
|  |  | ||||||
| @@ -295,26 +279,6 @@ func (g *grpcClient) stream(ctx context.Context, node *registry.Node, req client | |||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (g *grpcClient) authorizationHeader() string { |  | ||||||
| 	// if the caller specifies using service token or no token |  | ||||||
| 	// was passed with the request, set the service token |  | ||||||
| 	var srvToken string |  | ||||||
| 	if g.opts.Auth != nil && g.opts.Auth.Options().Token != nil { |  | ||||||
| 		srvToken = g.opts.Auth.Options().Token.AccessToken |  | ||||||
| 	} |  | ||||||
| 	if len(srvToken) > 0 { |  | ||||||
| 		return auth.BearerScheme + srvToken |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// fall back to using the authorization token set in config, |  | ||||||
| 	// this enables the CLI to provide a token |  | ||||||
| 	if token, err := config.Get("micro", "auth", "token"); err == nil && len(token) > 0 { |  | ||||||
| 		return auth.BearerScheme + token |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return "" |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (g *grpcClient) poolMaxStreams() int { | func (g *grpcClient) poolMaxStreams() int { | ||||||
| 	if g.opts.Context == nil { | 	if g.opts.Context == nil { | ||||||
| 		return DefaultPoolMaxStreams | 		return DefaultPoolMaxStreams | ||||||
|   | |||||||
| @@ -39,8 +39,9 @@ func newService(opts ...Option) Service { | |||||||
| 	authFn := func() auth.Auth { return options.Server.Options().Auth } | 	authFn := func() auth.Auth { return options.Server.Options().Auth } | ||||||
|  |  | ||||||
| 	// wrap client to inject From-Service header on any calls | 	// wrap client to inject From-Service header on any calls | ||||||
| 	options.Client = wrapper.FromService(serviceName, options.Client, authFn) | 	options.Client = wrapper.FromService(serviceName, options.Client) | ||||||
| 	options.Client = wrapper.TraceCall(serviceName, trace.DefaultTracer, options.Client) | 	options.Client = wrapper.TraceCall(serviceName, trace.DefaultTracer, options.Client) | ||||||
|  | 	options.Client = wrapper.AuthClient(serviceName, options.Server.Options().Id, authFn, options.Client) | ||||||
|  |  | ||||||
| 	// wrap the server to provide handler stats | 	// wrap the server to provide handler stats | ||||||
| 	options.Server.Init( | 	options.Server.Init( | ||||||
|   | |||||||
| @@ -2,7 +2,9 @@ package wrapper | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
|  | 	"fmt" | ||||||
| 	"strings" | 	"strings" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/micro/go-micro/v2/auth" | 	"github.com/micro/go-micro/v2/auth" | ||||||
| 	"github.com/micro/go-micro/v2/client" | 	"github.com/micro/go-micro/v2/client" | ||||||
| @@ -11,68 +13,44 @@ import ( | |||||||
| 	"github.com/micro/go-micro/v2/errors" | 	"github.com/micro/go-micro/v2/errors" | ||||||
| 	"github.com/micro/go-micro/v2/metadata" | 	"github.com/micro/go-micro/v2/metadata" | ||||||
| 	"github.com/micro/go-micro/v2/server" | 	"github.com/micro/go-micro/v2/server" | ||||||
|  | 	"github.com/micro/go-micro/v2/util/config" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type clientWrapper struct { | type fromServiceWrapper struct { | ||||||
| 	client.Client | 	client.Client | ||||||
|  |  | ||||||
| 	// Auth interface |  | ||||||
| 	auth func() auth.Auth |  | ||||||
| 	// headers to inject | 	// headers to inject | ||||||
| 	headers metadata.Metadata | 	headers metadata.Metadata | ||||||
| } | } | ||||||
|  |  | ||||||
| type traceWrapper struct { |  | ||||||
| 	client.Client |  | ||||||
|  |  | ||||||
| 	name  string |  | ||||||
| 	trace trace.Tracer |  | ||||||
| } |  | ||||||
|  |  | ||||||
| var ( | var ( | ||||||
| 	HeaderPrefix = "Micro-" | 	HeaderPrefix = "Micro-" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func (c *clientWrapper) setHeaders(ctx context.Context) context.Context { | func (f *fromServiceWrapper) setHeaders(ctx context.Context) context.Context { | ||||||
| 	// don't overwrite keys | 	// don't overwrite keys | ||||||
| 	return metadata.MergeContext(ctx, c.headers, false) | 	return metadata.MergeContext(ctx, f.headers, false) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *clientWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { | func (f *fromServiceWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { | ||||||
| 	ctx = c.setHeaders(ctx) | 	ctx = f.setHeaders(ctx) | ||||||
| 	return c.Client.Call(ctx, req, rsp, opts...) | 	return f.Client.Call(ctx, req, rsp, opts...) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *clientWrapper) Stream(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) { | func (f *fromServiceWrapper) Stream(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) { | ||||||
| 	ctx = c.setHeaders(ctx) | 	ctx = f.setHeaders(ctx) | ||||||
| 	return c.Client.Stream(ctx, req, opts...) | 	return f.Client.Stream(ctx, req, opts...) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *clientWrapper) Publish(ctx context.Context, p client.Message, opts ...client.PublishOption) error { | func (f *fromServiceWrapper) Publish(ctx context.Context, p client.Message, opts ...client.PublishOption) error { | ||||||
| 	ctx = c.setHeaders(ctx) | 	ctx = f.setHeaders(ctx) | ||||||
| 	return c.Client.Publish(ctx, p, opts...) | 	return f.Client.Publish(ctx, p, opts...) | ||||||
| } |  | ||||||
|  |  | ||||||
| func (c *traceWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { |  | ||||||
| 	newCtx, s := c.trace.Start(ctx, req.Service()+"."+req.Endpoint()) |  | ||||||
|  |  | ||||||
| 	s.Type = trace.SpanTypeRequestOutbound |  | ||||||
| 	err := c.Client.Call(newCtx, req, rsp, opts...) |  | ||||||
| 	if err != nil { |  | ||||||
| 		s.Metadata["error"] = err.Error() |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// finish the trace |  | ||||||
| 	c.trace.Finish(s) |  | ||||||
|  |  | ||||||
| 	return err |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // FromService wraps a client to inject service and auth metadata | // FromService wraps a client to inject service and auth metadata | ||||||
| func FromService(name string, c client.Client, fn func() auth.Auth) client.Client { | func FromService(name string, c client.Client) client.Client { | ||||||
| 	return &clientWrapper{ | 	return &fromServiceWrapper{ | ||||||
| 		c, | 		c, | ||||||
| 		fn, |  | ||||||
| 		metadata.Metadata{ | 		metadata.Metadata{ | ||||||
| 			HeaderPrefix + "From-Service": name, | 			HeaderPrefix + "From-Service": name, | ||||||
| 		}, | 		}, | ||||||
| @@ -95,6 +73,28 @@ func HandlerStats(stats stats.Stats) server.HandlerWrapper { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type traceWrapper struct { | ||||||
|  | 	client.Client | ||||||
|  |  | ||||||
|  | 	name  string | ||||||
|  | 	trace trace.Tracer | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *traceWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { | ||||||
|  | 	newCtx, s := c.trace.Start(ctx, req.Service()+"."+req.Endpoint()) | ||||||
|  |  | ||||||
|  | 	s.Type = trace.SpanTypeRequestOutbound | ||||||
|  | 	err := c.Client.Call(newCtx, req, rsp, opts...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		s.Metadata["error"] = err.Error() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// finish the trace | ||||||
|  | 	c.trace.Finish(s) | ||||||
|  |  | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
| // TraceCall is a call tracing wrapper | // TraceCall is a call tracing wrapper | ||||||
| func TraceCall(name string, t trace.Tracer, c client.Client) client.Client { | func TraceCall(name string, t trace.Tracer, c client.Client) client.Client { | ||||||
| 	return &traceWrapper{ | 	return &traceWrapper{ | ||||||
| @@ -132,6 +132,104 @@ func TraceHandler(t trace.Tracer) server.HandlerWrapper { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type authWrapper struct { | ||||||
|  | 	client.Client | ||||||
|  | 	name string | ||||||
|  | 	id   string | ||||||
|  | 	auth func() auth.Auth | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a *authWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { | ||||||
|  | 	// parse the options | ||||||
|  | 	var options client.CallOptions | ||||||
|  | 	for _, o := range opts { | ||||||
|  | 		o(&options) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// check to see if the authorization header has already been set. | ||||||
|  | 	// We dont't override the header unless the ServiceToken option has | ||||||
|  | 	// been specified or the header wasn't provided | ||||||
|  | 	if _, ok := metadata.Get(ctx, "Authorization"); ok && !options.ServiceToken { | ||||||
|  | 		return a.Client.Call(ctx, req, rsp, opts...) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// if auth is nil we won't be able to get an access token, so we execute | ||||||
|  | 	// the request without one. | ||||||
|  | 	aa := a.auth() | ||||||
|  | 	if a == nil { | ||||||
|  | 		return a.Client.Call(ctx, req, rsp, opts...) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// performs the call with the authorization token provided | ||||||
|  | 	callWithToken := func(token string) error { | ||||||
|  | 		ctx := metadata.Set(ctx, "Authorization", auth.BearerScheme+token) | ||||||
|  | 		return a.Client.Call(ctx, req, rsp, opts...) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// check to see if we have a valid access token | ||||||
|  | 	aaOpts := aa.Options() | ||||||
|  | 	if aaOpts.Token != nil && aaOpts.Token.Expiry.Unix() > time.Now().Unix() { | ||||||
|  | 		return callWithToken(aaOpts.Token.AccessToken) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// if we have a refresh token we can use this to generate another access token | ||||||
|  | 	if aaOpts.Token != nil { | ||||||
|  | 		tok, err := aa.Token(auth.WithToken(aaOpts.Token.RefreshToken)) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		aa.Init(auth.ClientToken(tok)) | ||||||
|  | 		return callWithToken(tok.AccessToken) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// if we have credentials we can generate a new token for the account | ||||||
|  | 	if len(aaOpts.ID) > 0 && len(aaOpts.Secret) > 0 { | ||||||
|  | 		tok, err := aa.Token(auth.WithCredentials(aaOpts.ID, aaOpts.Secret)) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		aa.Init(auth.ClientToken(tok)) | ||||||
|  | 		return callWithToken(tok.AccessToken) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// check to see if a token was provided in config, this is normally used for | ||||||
|  | 	// setting the token when calling via the cli | ||||||
|  | 	if token, err := config.Get("micro", "auth", "token"); err == nil && len(token) > 0 { | ||||||
|  | 		return callWithToken(token) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// determine the type of service from the name. we do this so we can allocate | ||||||
|  | 	// different roles depending on the type of services. e.g. we don't want web | ||||||
|  | 	// services talking directly to the runtime. TODO: find a better way to determine | ||||||
|  | 	// the type of service | ||||||
|  | 	serviceType := "service" | ||||||
|  | 	if strings.Contains(a.name, "api") { | ||||||
|  | 		serviceType = "api" | ||||||
|  | 	} else if strings.Contains(a.name, "web") { | ||||||
|  | 		serviceType = "web" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// generate a new auth account for the service | ||||||
|  | 	name := fmt.Sprintf("%v-%v", a.name, a.id) | ||||||
|  | 	acc, err := aa.Generate(name, auth.WithNamespace(aaOpts.Namespace), auth.WithRoles(serviceType)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	token, err := aa.Token(auth.WithCredentials(acc.ID, acc.Secret)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	aa.Init(auth.ClientToken(token)) | ||||||
|  |  | ||||||
|  | 	// use the token to execute the request | ||||||
|  | 	return callWithToken(token.AccessToken) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // AuthClient wraps requests with the auth header | ||||||
|  | func AuthClient(name string, id string, auth func() auth.Auth, c client.Client) client.Client { | ||||||
|  | 	return &authWrapper{c, name, id, auth} | ||||||
|  | } | ||||||
|  |  | ||||||
| // AuthHandler wraps a server handler to perform auth | // AuthHandler wraps a server handler to perform auth | ||||||
| func AuthHandler(fn func() auth.Auth) server.HandlerWrapper { | func AuthHandler(fn func() auth.Auth) server.HandlerWrapper { | ||||||
| 	return func(h server.HandlerFunc) server.HandlerFunc { | 	return func(h server.HandlerFunc) server.HandlerFunc { | ||||||
|   | |||||||
| @@ -4,7 +4,6 @@ import ( | |||||||
| 	"context" | 	"context" | ||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	"github.com/micro/go-micro/v2/auth" |  | ||||||
| 	"github.com/micro/go-micro/v2/metadata" | 	"github.com/micro/go-micro/v2/metadata" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -33,8 +32,7 @@ func TestWrapper(t *testing.T) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	for _, d := range testData { | 	for _, d := range testData { | ||||||
| 		c := &clientWrapper{ | 		c := &fromServiceWrapper{ | ||||||
| 			auth:    func() auth.Auth { return nil }, |  | ||||||
| 			headers: d.headers, | 			headers: d.headers, | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user