diff --git a/client/options.go b/client/options.go index 4eab7b0c..0cdcff95 100644 --- a/client/options.go +++ b/client/options.go @@ -68,8 +68,8 @@ type CallOptions struct { SelectOptions []selector.SelectOption // Stream timeout for the stream StreamTimeout time.Duration - // Use the services own auth token - ServiceToken bool + // Use the auth token as the authorization header + AuthToken bool // Network to lookup the route within Network string @@ -336,11 +336,11 @@ func WithDialTimeout(d time.Duration) CallOption { } } -// WithServiceToken is a CallOption which overrides the +// WithAuthToken is a CallOption which overrides the // authorization header with the services own auth token -func WithServiceToken() CallOption { +func WithAuthToken() CallOption { return func(o *CallOptions) { - o.ServiceToken = true + o.AuthToken = true } } diff --git a/util/wrapper/wrapper.go b/util/wrapper/wrapper.go index 061fdcec..83bd8910 100644 --- a/util/wrapper/wrapper.go +++ b/util/wrapper/wrapper.go @@ -2,338 +2,10 @@ package wrapper import ( "context" - "reflect" - "strings" - "github.com/micro/go-micro/v3/auth" "github.com/micro/go-micro/v3/client" - "github.com/micro/go-micro/v3/debug/stats" - "github.com/micro/go-micro/v3/debug/trace" - "github.com/micro/go-micro/v3/errors" - "github.com/micro/go-micro/v3/metadata" - "github.com/micro/go-micro/v3/server" ) -type fromServiceWrapper struct { - client.Client - - // headers to inject - headers metadata.Metadata -} - -var ( - HeaderPrefix = "Micro-" -) - -func (f *fromServiceWrapper) setHeaders(ctx context.Context) context.Context { - // don't overwrite keys - return metadata.MergeContext(ctx, f.headers, false) -} - -func (f *fromServiceWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { - ctx = f.setHeaders(ctx) - return f.Client.Call(ctx, req, rsp, opts...) -} - -func (f *fromServiceWrapper) Stream(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) { - ctx = f.setHeaders(ctx) - return f.Client.Stream(ctx, req, opts...) -} - -func (f *fromServiceWrapper) Publish(ctx context.Context, p client.Message, opts ...client.PublishOption) error { - ctx = f.setHeaders(ctx) - return f.Client.Publish(ctx, p, opts...) -} - -// FromService wraps a client to inject service and auth metadata -func FromService(name string, c client.Client) client.Client { - return &fromServiceWrapper{ - c, - metadata.Metadata{ - HeaderPrefix + "From-Service": name, - }, - } -} - -// HandlerStats wraps a server handler to generate request/error stats -func HandlerStats(stats stats.Stats) server.HandlerWrapper { - // return a handler wrapper - return func(h server.HandlerFunc) server.HandlerFunc { - // return a function that returns a function - return func(ctx context.Context, req server.Request, rsp interface{}) error { - // execute the handler - err := h(ctx, req, rsp) - // record the stats - stats.Record(err) - // return the error - return err - } - } -} - -type traceWrapper struct { - client.Client - - name string - trace trace.Tracer -} - -func (c *traceWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { - newCtx, s := c.trace.Start(ctx, req.Service()+"."+req.Endpoint()) - - s.Type = trace.SpanTypeRequestOutbound - err := c.Client.Call(newCtx, req, rsp, opts...) - if err != nil { - s.Metadata["error"] = err.Error() - } - - // finish the trace - c.trace.Finish(s) - - return err -} - -// TraceCall is a call tracing wrapper -func TraceCall(name string, t trace.Tracer, c client.Client) client.Client { - return &traceWrapper{ - name: name, - trace: t, - Client: c, - } -} - -// TraceHandler wraps a server handler to perform tracing -func TraceHandler(t trace.Tracer) server.HandlerWrapper { - // return a handler wrapper - return func(h server.HandlerFunc) server.HandlerFunc { - // return a function that returns a function - return func(ctx context.Context, req server.Request, rsp interface{}) error { - // don't store traces for debug - if strings.HasPrefix(req.Endpoint(), "Debug.") { - return h(ctx, req, rsp) - } - - // get the span - newCtx, s := t.Start(ctx, req.Service()+"."+req.Endpoint()) - s.Type = trace.SpanTypeRequestInbound - - err := h(newCtx, req, rsp) - if err != nil { - s.Metadata["error"] = err.Error() - } - - // finish - t.Finish(s) - - return err - } - } -} - -type authWrapper struct { - client.Client - auth func() auth.Auth -} - -func (a *authWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { - ctx = a.wrapContext(ctx, opts...) - return a.Client.Call(ctx, req, rsp, opts...) -} - -func (a *authWrapper) Stream(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) { - ctx = a.wrapContext(ctx, opts...) - return a.Client.Stream(ctx, req, opts...) -} - -func (a *authWrapper) wrapContext(ctx context.Context, opts ...client.CallOption) context.Context { - // parse the options - var options client.CallOptions - for _, o := range opts { - o(&options) - } - - // check to see if the authorization header has already been set. - // We dont't override the header unless the ServiceToken option has - // been specified or the header wasn't provided - if _, ok := metadata.Get(ctx, "Authorization"); ok && !options.ServiceToken { - return ctx - } - - // if auth is nil we won't be able to get an access token, so we execute - // the request without one. - aa := a.auth() - if aa == nil { - return ctx - } - - // set the namespace header if it has not been set (e.g. on a service to service request) - if _, ok := metadata.Get(ctx, "Micro-Namespace"); !ok { - ctx = metadata.Set(ctx, "Micro-Namespace", aa.Options().Issuer) - } - - // check to see if we have a valid access token - aaOpts := aa.Options() - if aaOpts.Token != nil && !aaOpts.Token.Expired() { - ctx = metadata.Set(ctx, "Authorization", auth.BearerScheme+aaOpts.Token.AccessToken) - return ctx - } - - // call without an auth token - return ctx -} - -// AuthClient wraps requests with the auth header -func AuthClient(auth func() auth.Auth, c client.Client) client.Client { - return &authWrapper{c, auth} -} - -func AuthHandlerNamespace(ns string) AuthHandlerOption { - return func(o *AuthHandlerOptions) { - o.Namespace = ns - } -} - -type AuthHandlerOption func(o *AuthHandlerOptions) - -type AuthHandlerOptions struct { - Namespace string -} - -// AuthHandler wraps a server handler to perform auth -func AuthHandler(fn func() auth.Auth, opts ...AuthHandlerOption) server.HandlerWrapper { - return func(h server.HandlerFunc) server.HandlerFunc { - return func(ctx context.Context, req server.Request, rsp interface{}) error { - // parse the options - options := AuthHandlerOptions{} - for _, o := range opts { - o(&options) - } - - // get the auth.Auth interface - a := fn() - - // Check for debug endpoints which should be excluded from auth - if strings.HasPrefix(req.Endpoint(), "Debug.") { - return h(ctx, req, rsp) - } - - // Extract the token if the header is present. We will inspect the token regardless of if it's - // present or not since noop auth will return a blank account upon Inspecting a blank token. - var token string - if header, ok := metadata.Get(ctx, "Authorization"); ok { - // Ensure the correct scheme is being used - if !strings.HasPrefix(header, auth.BearerScheme) { - return errors.Unauthorized(req.Service(), "invalid authorization header. expected Bearer schema") - } - - // Strip the bearer scheme prefix - token = strings.TrimPrefix(header, auth.BearerScheme) - } - - // Inspect the token and decode an account - account, _ := a.Inspect(token) - - // Extract the namespace header - ns, ok := metadata.Get(ctx, "Micro-Namespace") - if !ok { - ns = a.Options().Issuer - ctx = metadata.Set(ctx, "Micro-Namespace", ns) - } - - // Check the issuer matches the services namespace. TODO: Stop allowing micro to access - // any namespace and instead check for the server issuer. - if account != nil && account.Issuer != ns && account.Issuer != "micro" { - return errors.Forbidden(req.Service(), "Account was issued by %v, not %v", account.Issuer, ns) - } - - // construct the resource - res := &auth.Resource{ - Type: "service", - Name: req.Service(), - Endpoint: req.Endpoint(), - } - - // Normal services set the namespace to prevent it being overriden - // by setting the Namespace header, however this isn't the case for - // the proxy which uses the namespace header when routing requests - // to a specific network. - if len(options.Namespace) == 0 { - options.Namespace = ns - } - - // Verify the caller has access to the resource. - err := a.Verify(account, res, auth.VerifyNamespace(options.Namespace)) - if err == auth.ErrForbidden && account != nil { - return errors.Forbidden(req.Service(), "Forbidden call made to %v:%v by %v", req.Service(), req.Endpoint(), account.ID) - } else if err == auth.ErrForbidden { - return errors.Unauthorized(req.Service(), "Unauthorized call made to %v:%v", req.Service(), req.Endpoint()) - } else if err != nil { - return errors.InternalServerError(req.Service(), "Error authorizing request: %v", err) - } - - // There is an account, set it in the context - if account != nil { - ctx = auth.ContextWithAccount(ctx, account) - } - - // The user is authorised, allow the call - return h(ctx, req, rsp) - } - } -} - -type cacheWrapper struct { - cacheFn func() *client.Cache - client.Client -} - -// Call executes the request. If the CacheExpiry option was set, the response will be cached using -// a hash of the metadata and request as the key. -func (c *cacheWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { - // parse the options - var options client.CallOptions - for _, o := range opts { - o(&options) - } - - // if the client doesn't have a cacbe setup don't continue - cache := c.cacheFn() - if cache == nil { - return c.Client.Call(ctx, req, rsp, opts...) - } - - // if the cache expiry is not set, execute the call without the cache - if options.CacheExpiry == 0 { - return c.Client.Call(ctx, req, rsp, opts...) - } - - // if the response is nil don't call the cache since we can't assign the response - if rsp == nil { - return c.Client.Call(ctx, req, rsp, opts...) - } - - // check to see if there is a response cached, if there is assign it - if r, ok := cache.Get(ctx, req); ok { - val := reflect.ValueOf(rsp).Elem() - val.Set(reflect.ValueOf(r).Elem()) - return nil - } - - // don't cache the result if there was an error - if err := c.Client.Call(ctx, req, rsp, opts...); err != nil { - return err - } - - // set the result in the cache - cache.Set(ctx, req, rsp, options.CacheExpiry) - return nil -} - -// CacheClient wraps requests with the cache wrapper -func CacheClient(cacheFn func() *client.Cache, c client.Client) client.Client { - return &cacheWrapper{cacheFn, c} -} - type staticClient struct { address string client.Client diff --git a/util/wrapper/wrapper_static_client_test.go b/util/wrapper/wrapper_static_client_test.go deleted file mode 100644 index dafa17fe..00000000 --- a/util/wrapper/wrapper_static_client_test.go +++ /dev/null @@ -1,82 +0,0 @@ -package wrapper_test - -import ( - "context" - "testing" - - "github.com/micro/go-micro/v3/broker" - bmemory "github.com/micro/go-micro/v3/broker/memory" - "github.com/micro/go-micro/v3/client" - "github.com/micro/go-micro/v3/client/grpc" - rmemory "github.com/micro/go-micro/v3/registry/memory" - "github.com/micro/go-micro/v3/router" - rtreg "github.com/micro/go-micro/v3/router/registry" - "github.com/micro/go-micro/v3/server" - grpcsrv "github.com/micro/go-micro/v3/server/grpc" - tmemory "github.com/micro/go-micro/v3/transport/memory" - wrapper "github.com/micro/go-micro/v3/util/wrapper" -) - -type TestFoo struct { -} - -type TestReq struct{} - -type TestRsp struct { - Data string -} - -func (h *TestFoo) Bar(ctx context.Context, req *TestReq, rsp *TestRsp) error { - rsp.Data = "pass" - return nil -} - -func TestStaticClientWrapper(t *testing.T) { - var err error - - req := grpc.NewClient().NewRequest( - "go.micro.service.foo", - "TestFoo.Bar", - &TestReq{}, - client.WithContentType("application/json"), - ) - rsp := &TestRsp{} - - reg := rmemory.NewRegistry() - brk := bmemory.NewBroker(broker.Registry(reg)) - tr := tmemory.NewTransport() - rtr := rtreg.NewRouter(router.Registry(reg)) - - srv := grpcsrv.NewServer( - server.Broker(brk), - server.Registry(reg), - server.Name("go.micro.service.foo"), - server.Address("127.0.0.1:0"), - server.Transport(tr), - ) - if err = srv.Handle(srv.NewHandler(&TestFoo{})); err != nil { - t.Fatal(err) - } - - if err = srv.Start(); err != nil { - t.Fatal(err) - } - - cli := grpc.NewClient( - client.Router(rtr), - client.Broker(brk), - client.Transport(tr), - ) - - w1 := wrapper.StaticClient("xxx_localhost:12345", cli) - if err = w1.Call(context.TODO(), req, nil); err == nil { - t.Fatal("address xxx_#localhost:12345 must not exists and call must be failed") - } - - w2 := wrapper.StaticClient(srv.Options().Address, cli) - if err = w2.Call(context.TODO(), req, rsp); err != nil { - t.Fatal(err) - } else if rsp.Data != "pass" { - t.Fatalf("something wrong with response: %#+v", rsp) - } -} diff --git a/util/wrapper/wrapper_test.go b/util/wrapper/wrapper_test.go index 2e9535f0..dafa17fe 100644 --- a/util/wrapper/wrapper_test.go +++ b/util/wrapper/wrapper_test.go @@ -1,447 +1,82 @@ -package wrapper +package wrapper_test import ( "context" - "net/http" - "reflect" "testing" - "time" - "github.com/micro/go-micro/v3/auth" + "github.com/micro/go-micro/v3/broker" + bmemory "github.com/micro/go-micro/v3/broker/memory" "github.com/micro/go-micro/v3/client" "github.com/micro/go-micro/v3/client/grpc" - "github.com/micro/go-micro/v3/errors" - "github.com/micro/go-micro/v3/metadata" + rmemory "github.com/micro/go-micro/v3/registry/memory" + "github.com/micro/go-micro/v3/router" + rtreg "github.com/micro/go-micro/v3/router/registry" "github.com/micro/go-micro/v3/server" + grpcsrv "github.com/micro/go-micro/v3/server/grpc" + tmemory "github.com/micro/go-micro/v3/transport/memory" + wrapper "github.com/micro/go-micro/v3/util/wrapper" ) -func TestWrapper(t *testing.T) { - testData := []struct { - existing metadata.Metadata - headers metadata.Metadata - overwrite bool - }{ - { - existing: metadata.Metadata{}, - headers: metadata.Metadata{ - "Foo": "bar", - }, - overwrite: true, - }, - { - existing: metadata.Metadata{ - "Foo": "bar", - }, - headers: metadata.Metadata{ - "Foo": "baz", - }, - overwrite: false, - }, - } - - for _, d := range testData { - c := &fromServiceWrapper{ - headers: d.headers, - } - - ctx := metadata.NewContext(context.Background(), d.existing) - ctx = c.setHeaders(ctx) - md, _ := metadata.FromContext(ctx) - - for k, v := range d.headers { - if d.overwrite && md[k] != v { - t.Fatalf("Expected %s=%s got %s=%s", k, v, k, md[k]) - } - if !d.overwrite && md[k] != d.existing[k] { - t.Fatalf("Expected %s=%s got %s=%s", k, d.existing[k], k, md[k]) - } - } - } +type TestFoo struct { } -type testAuth struct { - verifyCount int - inspectCount int - issuer string - inspectAccount *auth.Account - verifyError error +type TestReq struct{} - auth.Auth +type TestRsp struct { + Data string } -func (a *testAuth) Verify(acc *auth.Account, res *auth.Resource, opts ...auth.VerifyOption) error { - a.verifyCount = a.verifyCount + 1 - return a.verifyError -} - -func (a *testAuth) Inspect(token string) (*auth.Account, error) { - a.inspectCount = a.inspectCount + 1 - return a.inspectAccount, nil -} - -func (a *testAuth) Options() auth.Options { - return auth.Options{Issuer: a.issuer} -} - -type testRequest struct { - service string - endpoint string - - server.Request -} - -func (r testRequest) Service() string { - return r.service -} - -func (r testRequest) Endpoint() string { - return r.endpoint -} - -func TestAuthHandler(t *testing.T) { - h := func(ctx context.Context, req server.Request, rsp interface{}) error { - return nil - } - - debugReq := testRequest{service: "go.micro.service.foo", endpoint: "Debug.Foo"} - serviceReq := testRequest{service: "go.micro.service.foo", endpoint: "Foo.Bar"} - - // Debug endpoints should be excluded from auth so auth.Verify should never get called - t.Run("DebugEndpoint", func(t *testing.T) { - a := testAuth{} - handler := AuthHandler(func() auth.Auth { - return &a - }) - - err := handler(h)(context.TODO(), debugReq, nil) - if err != nil { - t.Errorf("Expected nil error but got %v", err) - } - if a.verifyCount != 0 { - t.Errorf("Did not expect verify to be called") - } - }) - - // If the Authorization header is invalid, an error should be returned and verify not called - t.Run("InvalidAuthorizationHeader", func(t *testing.T) { - a := testAuth{} - handler := AuthHandler(func() auth.Auth { - return &a - }) - - ctx := metadata.Set(context.TODO(), "Authorization", "Invalid") - err := handler(h)(ctx, serviceReq, nil) - if verr, ok := err.(*errors.Error); !ok || verr.Code != http.StatusUnauthorized { - t.Errorf("Expected unauthorized error but got %v", err) - } - if a.inspectCount != 0 { - t.Errorf("Did not expect inspect to be called") - } - }) - - // If the Authorization header is valid, no error should be returned and verify should called - t.Run("ValidAuthorizationHeader", func(t *testing.T) { - a := testAuth{} - handler := AuthHandler(func() auth.Auth { - return &a - }) - - ctx := metadata.Set(context.TODO(), "Authorization", auth.BearerScheme+"Token") - err := handler(h)(ctx, serviceReq, nil) - if err != nil { - t.Errorf("Expected nil error but got %v", err) - } - if a.inspectCount != 1 { - t.Errorf("Expected inspect to be called") - } - }) - - // If the issuer header was not set on the request, the wrapper should set it to the auths - // own issuer - t.Run("BlankIssuerHeader", func(t *testing.T) { - a := testAuth{issuer: "myissuer"} - handler := AuthHandler(func() auth.Auth { - return &a - }) - - inCtx := context.TODO() - h := func(ctx context.Context, req server.Request, rsp interface{}) error { - inCtx = ctx - return nil - } - - err := handler(h)(inCtx, serviceReq, nil) - if err != nil { - t.Errorf("Expected nil error but got %v", err) - } - if ns, _ := metadata.Get(inCtx, "Micro-Namespace"); ns != a.issuer { - t.Errorf("Expected issuer to be set to %v but was %v", a.issuer, ns) - } - }) - t.Run("ValidIssuerHeader", func(t *testing.T) { - a := testAuth{issuer: "myissuer"} - handler := AuthHandler(func() auth.Auth { - return &a - }) - - inNs := "reqissuer" - inCtx := metadata.Set(context.TODO(), "Micro-Namespace", inNs) - h := func(ctx context.Context, req server.Request, rsp interface{}) error { - inCtx = ctx - return nil - } - - err := handler(h)(inCtx, serviceReq, nil) - if err != nil { - t.Errorf("Expected nil error but got %v", err) - } - if ns, _ := metadata.Get(inCtx, "Micro-Namespace"); ns != inNs { - t.Errorf("Expected issuer to remain as %v but was set to %v", inNs, ns) - } - }) - - // If the callers account was set but the issuer didn't match that of the request, the request - // should be forbidden - t.Run("InvalidAccountIssuer", func(t *testing.T) { - a := testAuth{ - issuer: "validissuer", - inspectAccount: &auth.Account{Issuer: "invalidissuer"}, - } - - handler := AuthHandler(func() auth.Auth { - return &a - }) - - ctx := metadata.Set(context.TODO(), "Authorization", auth.BearerScheme+"Token") - err := handler(h)(ctx, serviceReq, nil) - if verr, ok := err.(*errors.Error); !ok || verr.Code != http.StatusForbidden { - t.Errorf("Expected forbidden error but got %v", err) - } - }) - t.Run("ValidAccountIssuer", func(t *testing.T) { - a := testAuth{ - issuer: "validissuer", - inspectAccount: &auth.Account{Issuer: "validissuer"}, - } - - handler := AuthHandler(func() auth.Auth { - return &a - }) - - ctx := metadata.Set(context.TODO(), "Authorization", auth.BearerScheme+"Token") - err := handler(h)(ctx, serviceReq, nil) - if err != nil { - t.Errorf("Expected nil error but got %v", err) - } - }) - - // If the caller had a nil account and verify returns an error, the request should be unauthorised - t.Run("NilAccountUnauthorized", func(t *testing.T) { - a := testAuth{verifyError: auth.ErrForbidden} - - handler := AuthHandler(func() auth.Auth { - return &a - }) - - err := handler(h)(context.TODO(), serviceReq, nil) - if verr, ok := err.(*errors.Error); !ok || verr.Code != http.StatusUnauthorized { - t.Errorf("Expected unauthorizard error but got %v", err) - } - }) - t.Run("AccountForbidden", func(t *testing.T) { - a := testAuth{verifyError: auth.ErrForbidden, inspectAccount: &auth.Account{}} - - handler := AuthHandler(func() auth.Auth { - return &a - }) - - ctx := metadata.Set(context.TODO(), "Authorization", auth.BearerScheme+"Token") - err := handler(h)(ctx, serviceReq, nil) - if verr, ok := err.(*errors.Error); !ok || verr.Code != http.StatusForbidden { - t.Errorf("Expected forbidden error but got %v", err) - } - }) - t.Run("AccountValid", func(t *testing.T) { - a := testAuth{inspectAccount: &auth.Account{}} - - handler := AuthHandler(func() auth.Auth { - return &a - }) - - ctx := metadata.Set(context.TODO(), "Authorization", auth.BearerScheme+"Token") - err := handler(h)(ctx, serviceReq, nil) - if err != nil { - t.Errorf("Expected nil error but got %v", err) - } - }) - - // If an account is returned from inspecting the token, it should be set in the context - t.Run("ContextWithAccount", func(t *testing.T) { - accID := "myaccountid" - a := testAuth{inspectAccount: &auth.Account{ID: accID}} - - handler := AuthHandler(func() auth.Auth { - return &a - }) - - inCtx := metadata.Set(context.TODO(), "Authorization", auth.BearerScheme+"Token") - h := func(ctx context.Context, req server.Request, rsp interface{}) error { - inCtx = ctx - return nil - } - - err := handler(h)(inCtx, serviceReq, nil) - if err != nil { - t.Errorf("Expected nil error but got %v", err) - } - if acc, ok := auth.AccountFromContext(inCtx); !ok { - t.Errorf("Expected an account to be set in the context") - } else if acc.ID != accID { - t.Errorf("Expected the account in the context to have the ID %v but it actually had %v", accID, acc.ID) - } - }) - - // If verify returns an error the handler should not be called - t.Run("HandlerNotCalled", func(t *testing.T) { - a := testAuth{verifyError: auth.ErrForbidden} - - handler := AuthHandler(func() auth.Auth { - return &a - }) - - var handlerCalled bool - h := func(ctx context.Context, req server.Request, rsp interface{}) error { - handlerCalled = true - return nil - } - - ctx := metadata.Set(context.TODO(), "Authorization", auth.BearerScheme+"Token") - err := handler(h)(ctx, serviceReq, nil) - if verr, ok := err.(*errors.Error); !ok || verr.Code != http.StatusUnauthorized { - t.Errorf("Expected unauthorizard error but got %v", err) - } - if handlerCalled { - t.Errorf("Expected the handler to not be called") - } - }) - - // If verify does not return an error the handler should be called - t.Run("HandlerNotCalled", func(t *testing.T) { - a := testAuth{} - - handler := AuthHandler(func() auth.Auth { - return &a - }) - - var handlerCalled bool - h := func(ctx context.Context, req server.Request, rsp interface{}) error { - handlerCalled = true - return nil - } - - ctx := metadata.Set(context.TODO(), "Authorization", auth.BearerScheme+"Token") - err := handler(h)(ctx, serviceReq, nil) - if err != nil { - t.Errorf("Expected nil error but got %v", err) - } - if !handlerCalled { - t.Errorf("Expected the handler be called") - } - }) -} - -type testClient struct { - callCount int - callRsp interface{} - client.Client -} - -func (c *testClient) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { - c.callCount++ - - if c.callRsp != nil { - val := reflect.ValueOf(rsp).Elem() - val.Set(reflect.ValueOf(c.callRsp).Elem()) - } - +func (h *TestFoo) Bar(ctx context.Context, req *TestReq, rsp *TestRsp) error { + rsp.Data = "pass" return nil } -type testRsp struct { - value string -} - -func TestCacheWrapper(t *testing.T) { - req := grpc.NewClient().NewRequest("go.micro.service.foo", "Foo.Bar", nil) - - t.Run("NilCache", func(t *testing.T) { - cli := new(testClient) - - w := CacheClient(func() *client.Cache { - return nil - }, cli) - - // performing two requests should increment the call count by two indicating the cache wasn't - // used even though the WithCache option was passed. - w.Call(context.TODO(), req, nil, client.WithCache(time.Minute)) - w.Call(context.TODO(), req, nil, client.WithCache(time.Minute)) - - if cli.callCount != 2 { - t.Errorf("Expected the client to have been called twice") - } - }) - - t.Run("OptionNotSet", func(t *testing.T) { - cli := new(testClient) - cache := client.NewCache() - - w := CacheClient(func() *client.Cache { - return cache - }, cli) - - // performing two requests should increment the call count by two since we didn't pass the WithCache - // option to Call. - w.Call(context.TODO(), req, nil) - w.Call(context.TODO(), req, nil) - - if cli.callCount != 2 { - t.Errorf("Expected the client to have been called twice") - } - }) - - t.Run("OptionSet", func(t *testing.T) { - val := "foo" - cli := &testClient{callRsp: &testRsp{value: val}} - cache := client.NewCache() - - w := CacheClient(func() *client.Cache { - return cache - }, cli) - - // performing two requests should increment the call count by once since the second request should - // have used the cache. The correct value should be set on both responses and no errors should - // be returned. - rsp1 := &testRsp{} - rsp2 := &testRsp{} - err1 := w.Call(context.TODO(), req, rsp1, client.WithCache(time.Minute)) - err2 := w.Call(context.TODO(), req, rsp2, client.WithCache(time.Minute)) - - if err1 != nil { - t.Errorf("Expected nil error, got %v", err1) - } - if err2 != nil { - t.Errorf("Expected nil error, got %v", err2) - } - - if rsp1.value != val { - t.Errorf("Expected %v to be assigned to the value, got %v", val, rsp1.value) - } - if rsp2.value != val { - t.Errorf("Expected %v to be assigned to the value, got %v", val, rsp2.value) - } - - if cli.callCount != 1 { - t.Errorf("Expected the client to be called 1 time, was actually called %v time(s)", cli.callCount) - } - }) +func TestStaticClientWrapper(t *testing.T) { + var err error + + req := grpc.NewClient().NewRequest( + "go.micro.service.foo", + "TestFoo.Bar", + &TestReq{}, + client.WithContentType("application/json"), + ) + rsp := &TestRsp{} + + reg := rmemory.NewRegistry() + brk := bmemory.NewBroker(broker.Registry(reg)) + tr := tmemory.NewTransport() + rtr := rtreg.NewRouter(router.Registry(reg)) + + srv := grpcsrv.NewServer( + server.Broker(brk), + server.Registry(reg), + server.Name("go.micro.service.foo"), + server.Address("127.0.0.1:0"), + server.Transport(tr), + ) + if err = srv.Handle(srv.NewHandler(&TestFoo{})); err != nil { + t.Fatal(err) + } + + if err = srv.Start(); err != nil { + t.Fatal(err) + } + + cli := grpc.NewClient( + client.Router(rtr), + client.Broker(brk), + client.Transport(tr), + ) + + w1 := wrapper.StaticClient("xxx_localhost:12345", cli) + if err = w1.Call(context.TODO(), req, nil); err == nil { + t.Fatal("address xxx_#localhost:12345 must not exists and call must be failed") + } + + w2 := wrapper.StaticClient(srv.Options().Address, cli) + if err = w2.Call(context.TODO(), req, rsp); err != nil { + t.Fatal(err) + } else if rsp.Data != "pass" { + t.Fatalf("something wrong with response: %#+v", rsp) + } }