Merge pull request #1595 from micro/auth-client-wrapper
Auth Client Wrapper
This commit is contained in:
		| @@ -79,6 +79,13 @@ func Credentials(id, secret string) Option { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ClientToken sets the auth token to use when making requests | ||||
| func ClientToken(token *Token) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.Token = token | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Provider set the auth provider | ||||
| func Provider(p provider.Provider) Option { | ||||
| 	return func(o *Options) { | ||||
|   | ||||
| @@ -70,35 +70,6 @@ func (s *svc) Init(opts ...auth.Option) { | ||||
| 			s.loadRules() | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	// we have client credentials and must load a new token | ||||
| 	// periodically | ||||
| 	if len(s.options.ID) > 0 || len(s.options.Secret) > 0 { | ||||
| 		// get a token immediately | ||||
| 		s.refreshToken() | ||||
|  | ||||
| 		go func() { | ||||
| 			tokenTimer := time.NewTicker(time.Minute) | ||||
|  | ||||
| 			for { | ||||
| 				<-tokenTimer.C | ||||
|  | ||||
| 				// Do not get a new token if the current one has more than three | ||||
| 				// minutes remaining. We do 3 minutes to allow multiple retires in | ||||
| 				// the case one request fails | ||||
| 				t := s.Options().Token | ||||
| 				if t != nil && t.Expiry.Unix() > time.Now().Add(time.Minute*3).Unix() { | ||||
| 					continue | ||||
| 				} | ||||
|  | ||||
| 				// jitter for up to 5 seconds, this stops | ||||
| 				// all the services calling the auth service | ||||
| 				// at the exact same time | ||||
| 				time.Sleep(jitter.Do(time.Second * 5)) | ||||
| 				s.refreshToken() | ||||
| 			} | ||||
| 		}() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *svc) Options() auth.Options { | ||||
| @@ -313,33 +284,6 @@ func (s *svc) loadRules() { | ||||
| 	s.rules = rsp.Rules | ||||
| } | ||||
|  | ||||
| // refreshToken generates a new token for the service to use when making calls | ||||
| func (s *svc) refreshToken() { | ||||
| 	req := &pb.TokenRequest{ | ||||
| 		TokenExpiry: int64((time.Minute * 15).Seconds()), | ||||
| 	} | ||||
|  | ||||
| 	if s.Options().Token == nil { | ||||
| 		// we do not have a token, use the credentials to get one | ||||
| 		req.Id = s.Options().ID | ||||
| 		req.Secret = s.Options().Secret | ||||
| 	} else { | ||||
| 		// we have a token, refresh it | ||||
| 		req.RefreshToken = s.Options().Token.RefreshToken | ||||
| 	} | ||||
|  | ||||
| 	rsp, err := s.auth.Token(context.TODO(), req) | ||||
| 	s.Lock() | ||||
| 	defer s.Unlock() | ||||
|  | ||||
| 	if err != nil { | ||||
| 		log.Errorf("Error generating token: %v", err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	s.options.Token = serializeToken(rsp.Token) | ||||
| } | ||||
|  | ||||
| func serializeToken(t *pb.Token) *auth.Token { | ||||
| 	return &auth.Token{ | ||||
| 		AccessToken:  t.AccessToken, | ||||
|   | ||||
| @@ -10,7 +10,6 @@ import ( | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/micro/go-micro/v2/auth" | ||||
| 	"github.com/micro/go-micro/v2/broker" | ||||
| 	"github.com/micro/go-micro/v2/client" | ||||
| 	"github.com/micro/go-micro/v2/client/selector" | ||||
| @@ -18,7 +17,6 @@ import ( | ||||
| 	"github.com/micro/go-micro/v2/errors" | ||||
| 	"github.com/micro/go-micro/v2/metadata" | ||||
| 	"github.com/micro/go-micro/v2/registry" | ||||
| 	"github.com/micro/go-micro/v2/util/config" | ||||
| 	pnet "github.com/micro/go-micro/v2/util/net" | ||||
|  | ||||
| 	"google.golang.org/grpc" | ||||
| @@ -117,13 +115,6 @@ func (g *grpcClient) call(ctx context.Context, node *registry.Node, req client.R | ||||
| 	// set the content type for the request | ||||
| 	header["x-content-type"] = req.ContentType() | ||||
|  | ||||
| 	// set the authorization header | ||||
| 	if opts.ServiceToken || len(header["authorization"]) == 0 { | ||||
| 		if h := g.authorizationHeader(); len(h) > 0 { | ||||
| 			header["authorization"] = h | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	md := gmetadata.New(header) | ||||
| 	ctx = gmetadata.NewOutgoingContext(ctx, md) | ||||
|  | ||||
| @@ -202,13 +193,6 @@ func (g *grpcClient) stream(ctx context.Context, node *registry.Node, req client | ||||
| 	// set the content type for the request | ||||
| 	header["x-content-type"] = req.ContentType() | ||||
|  | ||||
| 	// set the authorization header | ||||
| 	if opts.ServiceToken || len(header["authorization"]) == 0 { | ||||
| 		if h := g.authorizationHeader(); len(h) > 0 { | ||||
| 			header["authorization"] = h | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	md := gmetadata.New(header) | ||||
| 	ctx = gmetadata.NewOutgoingContext(ctx, md) | ||||
|  | ||||
| @@ -295,26 +279,6 @@ func (g *grpcClient) stream(ctx context.Context, node *registry.Node, req client | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (g *grpcClient) authorizationHeader() string { | ||||
| 	// if the caller specifies using service token or no token | ||||
| 	// was passed with the request, set the service token | ||||
| 	var srvToken string | ||||
| 	if g.opts.Auth != nil && g.opts.Auth.Options().Token != nil { | ||||
| 		srvToken = g.opts.Auth.Options().Token.AccessToken | ||||
| 	} | ||||
| 	if len(srvToken) > 0 { | ||||
| 		return auth.BearerScheme + srvToken | ||||
| 	} | ||||
|  | ||||
| 	// fall back to using the authorization token set in config, | ||||
| 	// this enables the CLI to provide a token | ||||
| 	if token, err := config.Get("micro", "auth", "token"); err == nil && len(token) > 0 { | ||||
| 		return auth.BearerScheme + token | ||||
| 	} | ||||
|  | ||||
| 	return "" | ||||
| } | ||||
|  | ||||
| func (g *grpcClient) poolMaxStreams() int { | ||||
| 	if g.opts.Context == nil { | ||||
| 		return DefaultPoolMaxStreams | ||||
|   | ||||
| @@ -39,8 +39,9 @@ func newService(opts ...Option) Service { | ||||
| 	authFn := func() auth.Auth { return options.Server.Options().Auth } | ||||
|  | ||||
| 	// wrap client to inject From-Service header on any calls | ||||
| 	options.Client = wrapper.FromService(serviceName, options.Client, authFn) | ||||
| 	options.Client = wrapper.FromService(serviceName, options.Client) | ||||
| 	options.Client = wrapper.TraceCall(serviceName, trace.DefaultTracer, options.Client) | ||||
| 	options.Client = wrapper.AuthClient(serviceName, options.Server.Options().Id, authFn, options.Client) | ||||
|  | ||||
| 	// wrap the server to provide handler stats | ||||
| 	options.Server.Init( | ||||
|   | ||||
| @@ -2,7 +2,9 @@ package wrapper | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/micro/go-micro/v2/auth" | ||||
| 	"github.com/micro/go-micro/v2/client" | ||||
| @@ -11,68 +13,44 @@ import ( | ||||
| 	"github.com/micro/go-micro/v2/errors" | ||||
| 	"github.com/micro/go-micro/v2/metadata" | ||||
| 	"github.com/micro/go-micro/v2/server" | ||||
| 	"github.com/micro/go-micro/v2/util/config" | ||||
| ) | ||||
|  | ||||
| type clientWrapper struct { | ||||
| type fromServiceWrapper struct { | ||||
| 	client.Client | ||||
|  | ||||
| 	// Auth interface | ||||
| 	auth func() auth.Auth | ||||
| 	// headers to inject | ||||
| 	headers metadata.Metadata | ||||
| } | ||||
|  | ||||
| type traceWrapper struct { | ||||
| 	client.Client | ||||
|  | ||||
| 	name  string | ||||
| 	trace trace.Tracer | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	HeaderPrefix = "Micro-" | ||||
| ) | ||||
|  | ||||
| func (c *clientWrapper) setHeaders(ctx context.Context) context.Context { | ||||
| func (f *fromServiceWrapper) setHeaders(ctx context.Context) context.Context { | ||||
| 	// don't overwrite keys | ||||
| 	return metadata.MergeContext(ctx, c.headers, false) | ||||
| 	return metadata.MergeContext(ctx, f.headers, false) | ||||
| } | ||||
|  | ||||
| func (c *clientWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { | ||||
| 	ctx = c.setHeaders(ctx) | ||||
| 	return c.Client.Call(ctx, req, rsp, opts...) | ||||
| func (f *fromServiceWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { | ||||
| 	ctx = f.setHeaders(ctx) | ||||
| 	return f.Client.Call(ctx, req, rsp, opts...) | ||||
| } | ||||
|  | ||||
| func (c *clientWrapper) Stream(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) { | ||||
| 	ctx = c.setHeaders(ctx) | ||||
| 	return c.Client.Stream(ctx, req, opts...) | ||||
| func (f *fromServiceWrapper) Stream(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) { | ||||
| 	ctx = f.setHeaders(ctx) | ||||
| 	return f.Client.Stream(ctx, req, opts...) | ||||
| } | ||||
|  | ||||
| func (c *clientWrapper) Publish(ctx context.Context, p client.Message, opts ...client.PublishOption) error { | ||||
| 	ctx = c.setHeaders(ctx) | ||||
| 	return c.Client.Publish(ctx, p, opts...) | ||||
| } | ||||
|  | ||||
| func (c *traceWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { | ||||
| 	newCtx, s := c.trace.Start(ctx, req.Service()+"."+req.Endpoint()) | ||||
|  | ||||
| 	s.Type = trace.SpanTypeRequestOutbound | ||||
| 	err := c.Client.Call(newCtx, req, rsp, opts...) | ||||
| 	if err != nil { | ||||
| 		s.Metadata["error"] = err.Error() | ||||
| 	} | ||||
|  | ||||
| 	// finish the trace | ||||
| 	c.trace.Finish(s) | ||||
|  | ||||
| 	return err | ||||
| func (f *fromServiceWrapper) Publish(ctx context.Context, p client.Message, opts ...client.PublishOption) error { | ||||
| 	ctx = f.setHeaders(ctx) | ||||
| 	return f.Client.Publish(ctx, p, opts...) | ||||
| } | ||||
|  | ||||
| // FromService wraps a client to inject service and auth metadata | ||||
| func FromService(name string, c client.Client, fn func() auth.Auth) client.Client { | ||||
| 	return &clientWrapper{ | ||||
| func FromService(name string, c client.Client) client.Client { | ||||
| 	return &fromServiceWrapper{ | ||||
| 		c, | ||||
| 		fn, | ||||
| 		metadata.Metadata{ | ||||
| 			HeaderPrefix + "From-Service": name, | ||||
| 		}, | ||||
| @@ -95,6 +73,28 @@ func HandlerStats(stats stats.Stats) server.HandlerWrapper { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type traceWrapper struct { | ||||
| 	client.Client | ||||
|  | ||||
| 	name  string | ||||
| 	trace trace.Tracer | ||||
| } | ||||
|  | ||||
| func (c *traceWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { | ||||
| 	newCtx, s := c.trace.Start(ctx, req.Service()+"."+req.Endpoint()) | ||||
|  | ||||
| 	s.Type = trace.SpanTypeRequestOutbound | ||||
| 	err := c.Client.Call(newCtx, req, rsp, opts...) | ||||
| 	if err != nil { | ||||
| 		s.Metadata["error"] = err.Error() | ||||
| 	} | ||||
|  | ||||
| 	// finish the trace | ||||
| 	c.trace.Finish(s) | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| // TraceCall is a call tracing wrapper | ||||
| func TraceCall(name string, t trace.Tracer, c client.Client) client.Client { | ||||
| 	return &traceWrapper{ | ||||
| @@ -132,6 +132,104 @@ func TraceHandler(t trace.Tracer) server.HandlerWrapper { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type authWrapper struct { | ||||
| 	client.Client | ||||
| 	name string | ||||
| 	id   string | ||||
| 	auth func() auth.Auth | ||||
| } | ||||
|  | ||||
| func (a *authWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { | ||||
| 	// parse the options | ||||
| 	var options client.CallOptions | ||||
| 	for _, o := range opts { | ||||
| 		o(&options) | ||||
| 	} | ||||
|  | ||||
| 	// check to see if the authorization header has already been set. | ||||
| 	// We dont't override the header unless the ServiceToken option has | ||||
| 	// been specified or the header wasn't provided | ||||
| 	if _, ok := metadata.Get(ctx, "Authorization"); ok && !options.ServiceToken { | ||||
| 		return a.Client.Call(ctx, req, rsp, opts...) | ||||
| 	} | ||||
|  | ||||
| 	// if auth is nil we won't be able to get an access token, so we execute | ||||
| 	// the request without one. | ||||
| 	aa := a.auth() | ||||
| 	if a == nil { | ||||
| 		return a.Client.Call(ctx, req, rsp, opts...) | ||||
| 	} | ||||
|  | ||||
| 	// performs the call with the authorization token provided | ||||
| 	callWithToken := func(token string) error { | ||||
| 		ctx := metadata.Set(ctx, "Authorization", auth.BearerScheme+token) | ||||
| 		return a.Client.Call(ctx, req, rsp, opts...) | ||||
| 	} | ||||
|  | ||||
| 	// check to see if we have a valid access token | ||||
| 	aaOpts := aa.Options() | ||||
| 	if aaOpts.Token != nil && aaOpts.Token.Expiry.Unix() > time.Now().Unix() { | ||||
| 		return callWithToken(aaOpts.Token.AccessToken) | ||||
| 	} | ||||
|  | ||||
| 	// if we have a refresh token we can use this to generate another access token | ||||
| 	if aaOpts.Token != nil { | ||||
| 		tok, err := aa.Token(auth.WithToken(aaOpts.Token.RefreshToken)) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		aa.Init(auth.ClientToken(tok)) | ||||
| 		return callWithToken(tok.AccessToken) | ||||
| 	} | ||||
|  | ||||
| 	// if we have credentials we can generate a new token for the account | ||||
| 	if len(aaOpts.ID) > 0 && len(aaOpts.Secret) > 0 { | ||||
| 		tok, err := aa.Token(auth.WithCredentials(aaOpts.ID, aaOpts.Secret)) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		aa.Init(auth.ClientToken(tok)) | ||||
| 		return callWithToken(tok.AccessToken) | ||||
| 	} | ||||
|  | ||||
| 	// check to see if a token was provided in config, this is normally used for | ||||
| 	// setting the token when calling via the cli | ||||
| 	if token, err := config.Get("micro", "auth", "token"); err == nil && len(token) > 0 { | ||||
| 		return callWithToken(token) | ||||
| 	} | ||||
|  | ||||
| 	// determine the type of service from the name. we do this so we can allocate | ||||
| 	// different roles depending on the type of services. e.g. we don't want web | ||||
| 	// services talking directly to the runtime. TODO: find a better way to determine | ||||
| 	// the type of service | ||||
| 	serviceType := "service" | ||||
| 	if strings.Contains(a.name, "api") { | ||||
| 		serviceType = "api" | ||||
| 	} else if strings.Contains(a.name, "web") { | ||||
| 		serviceType = "web" | ||||
| 	} | ||||
|  | ||||
| 	// generate a new auth account for the service | ||||
| 	name := fmt.Sprintf("%v-%v", a.name, a.id) | ||||
| 	acc, err := aa.Generate(name, auth.WithNamespace(aaOpts.Namespace), auth.WithRoles(serviceType)) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	token, err := aa.Token(auth.WithCredentials(acc.ID, acc.Secret)) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	aa.Init(auth.ClientToken(token)) | ||||
|  | ||||
| 	// use the token to execute the request | ||||
| 	return callWithToken(token.AccessToken) | ||||
| } | ||||
|  | ||||
| // AuthClient wraps requests with the auth header | ||||
| func AuthClient(name string, id string, auth func() auth.Auth, c client.Client) client.Client { | ||||
| 	return &authWrapper{c, name, id, auth} | ||||
| } | ||||
|  | ||||
| // AuthHandler wraps a server handler to perform auth | ||||
| func AuthHandler(fn func() auth.Auth) server.HandlerWrapper { | ||||
| 	return func(h server.HandlerFunc) server.HandlerFunc { | ||||
|   | ||||
| @@ -4,7 +4,6 @@ import ( | ||||
| 	"context" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/micro/go-micro/v2/auth" | ||||
| 	"github.com/micro/go-micro/v2/metadata" | ||||
| ) | ||||
|  | ||||
| @@ -33,8 +32,7 @@ func TestWrapper(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	for _, d := range testData { | ||||
| 		c := &clientWrapper{ | ||||
| 			auth:    func() auth.Auth { return nil }, | ||||
| 		c := &fromServiceWrapper{ | ||||
| 			headers: d.headers, | ||||
| 		} | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user