package mock

import (
	"context"
	"fmt"
	"reflect"
	"strings"
	"sync"
	"time"

	"go.unistack.org/micro/v3/client"
	"go.unistack.org/micro/v3/codec"
	"go.unistack.org/micro/v3/errors"
	rutil "go.unistack.org/micro/v3/util/reflect"
)

var _ client.Client = (*MockClient)(nil)

// MockClient is a client.Client whose calls are answered from expectations
// registered by the test (see ExpectRequest).
type MockClient struct {
	opts client.Options
	// mu guards expected.
	mu       sync.Mutex
	expected []expectation
}

// newCodec returns the codec registered in the client options for the media
// type of ct. Media-type parameters (everything from the first ';') and the
// optional whitespace around the media type are ignored, so
// "application/json ; charset=utf-8" resolves to the "application/json" codec.
// codec.ErrUnknownContentType is returned when no codec is registered.
func (c *MockClient) newCodec(ct string) (codec.Codec, error) {
	if idx := strings.IndexByte(ct, ';'); idx >= 0 {
		ct = ct[:idx]
	}
	// RFC 7231 allows optional whitespace before ';'; without trimming it the
	// map lookup below would miss.
	ct = strings.TrimSpace(ct)
	if cc, ok := c.opts.Codecs[ct]; ok {
		return cc, nil
	}
	return nil, codec.ErrUnknownContentType
}

// expectation is implemented by every kind of registered expectation.
// Lock/Unlock guard the fulfilled state.
type expectation interface {
	fulfilled() bool
	Lock()
	Unlock()
	String() string
}

// commonExpectation holds the state shared by all expectations and satisfies
// the Lock/Unlock/fulfilled part of the expectation interface.
type commonExpectation struct {
	sync.Mutex
	// triggered is set once a call has matched this expectation.
	triggered bool
	// err is returned to the caller instead of a response when non-nil.
	err error
}

// fulfilled reports whether the expectation has been triggered. Callers are
// expected to hold the embedded lock.
func (e *commonExpectation) fulfilled() bool {
	return e.triggered
}

// ExpectedRequest is used to manage client.Call expectations.
// Returned by *MockClient.ExpectRequest.
type ExpectedRequest struct {
	commonExpectation
	// delay is how long the call is held before it returns.
	delay time.Duration
	// rsp is the response configured with WillReturnResponse.
	rsp interface{}
	// req is the request the call must match.
	req client.Request
	// rspct overrides the content type used to decode a []byte rsp.
	rspct string
}

// WillDelayFor specifies a duration by which the result is delayed. May be
// used together with Context.
func (e *ExpectedRequest) WillDelayFor(duration time.Duration) *ExpectedRequest { e.delay = duration return e } // WillReturnError allows to set an error for expected client.Call func (e *ExpectedRequest) WillReturnError(err error) *ExpectedRequest { e.err = err return e } // WillReturnResponse allows to set a response for expected client.Call func (e *ExpectedRequest) WillReturnResponse(ct string, rsp interface{}) *ExpectedRequest { e.rsp = rsp e.rspct = ct return e } // String returns string representation func (e *ExpectedRequest) String() string { msg := "ExpectedRequest => expecting client.Call request" if e.err != nil { msg += fmt.Sprintf(", which should return error: %s", e.err) } if e.rsp != nil { msg += fmt.Sprintf(", which should return rsp: %v", e.rsp) } return msg } func (c *MockClient) ExpectationsWereMet() error { for _, e := range c.expected { e.Lock() fulfilled := e.fulfilled() e.Unlock() if !fulfilled { return fmt.Errorf("there is a remaining expectation which was not matched: %s", e) } } return nil } func (c *MockClient) ExpectRequest(req client.Request) *ExpectedRequest { e := &ExpectedRequest{req: req} c.expected = append(c.expected, e) return e } func (c *MockClient) BatchPublish(ctx context.Context, msgs []client.Message, opts ...client.PublishOption) error { return nil } func (c *MockClient) Publish(ctx context.Context, msg client.Message, opts ...client.PublishOption) error { return nil // c.opts.Broker.Publish() } func (c *MockClient) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { c.mu.Lock() defer c.mu.Unlock() options := client.NewCallOptions(opts...) 
ct := req.ContentType() if len(options.ContentType) > 0 { ct = options.ContentType } for _, e := range c.expected { er, ok := e.(*ExpectedRequest) if !ok { continue } if er.delay > 0 { time.Sleep(er.delay) } if er.req.Service() != req.Service() || er.req.Method() != req.Method() { continue } er.triggered = true if er.err != nil { return er.err } if er.req == nil { return errors.BadRequest("go.micro.client", "empty request passed") } src := er.req.Body() switch reqbody := er.req.Body().(type) { case []byte: cf, err := c.newCodec(ct) if err != nil { return errors.BadRequest("go.micro.client", err.Error()) } src, err = rutil.Zero(req.Body()) if err == nil { err = cf.Unmarshal(reqbody, src) } if err != nil { return errors.BadRequest("go.micro.client", err.Error()) } } if !reflect.DeepEqual(req.Body(), src) { return errors.BadRequest("go.micro.client", "unexpected request %v != %v", req.Body(), src) } if er.rsp == nil { return nil } switch rspbody := er.rsp.(type) { case []byte: if er.rspct != "" { ct = er.rspct } cf, err := c.newCodec(ct) if err != nil { return errors.BadRequest("go.micro.client", err.Error()) } if err = cf.Unmarshal(rspbody, rsp); err != nil { return errors.BadRequest("go.micro.client", err.Error()) } return nil } v := reflect.ValueOf(rsp) if t := reflect.TypeOf(rsp); t.Kind() == reflect.Ptr { v = reflect.Indirect(v) } response := er.rsp if t := reflect.TypeOf(er.rsp); t.Kind() == reflect.Func { response = reflect.ValueOf(er.rsp).Call([]reflect.Value{})[0].Interface() } v.Set(reflect.ValueOf(response)) return nil } return fmt.Errorf("can't find service %s", req.Method()) } func (c *MockClient) Stream(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) { return nil, nil } func (c *MockClient) Init(opts ...client.Option) error { for _, o := range opts { o(&c.opts) } return nil } func (c *MockClient) String() string { return "mock" } func (c *MockClient) Name() string { return c.opts.Name } func (c *MockClient) 
Options() client.Options { return c.opts } func (c *MockClient) NewMessage(topic string, msg interface{}, opts ...client.MessageOption) client.Message { return nil } func (c *MockClient) NewRequest(service, method string, req interface{}, opts ...client.RequestOption) client.Request { return newRequest(service, method, req, c.opts.ContentType, opts...) } func NewClient(opts ...client.Option) *MockClient { options := client.NewOptions(opts...) return &MockClient{opts: options} }