diff --git a/mock.go b/mock.go index 70c4f98..6a5d334 100644 --- a/mock.go +++ b/mock.go @@ -61,6 +61,7 @@ type ExpectedRequest struct { delay time.Duration rsp interface{} req client.Request + rspct string } // WillDelayFor allows to specify duration for which it will delay result. May @@ -77,8 +78,9 @@ func (e *ExpectedRequest) WillReturnError(err error) *ExpectedRequest { } // WillReturnResponse allows to set a response for expected client.Call -func (e *ExpectedRequest) WillReturnResponse(rsp interface{}) *ExpectedRequest { +func (e *ExpectedRequest) WillReturnResponse(ct string, rsp interface{}) *ExpectedRequest { e.rsp = rsp + e.rspct = ct return e } @@ -133,11 +135,6 @@ func (c *MockClient) Call(ctx context.Context, req client.Request, rsp interface ct = options.ContentType } - cf, err := c.newCodec(ct) - if err != nil { - return errors.BadRequest("go.micro.client", err.Error()) - } - for _, e := range c.expected { er, ok := e.(*ExpectedRequest) if !ok { @@ -166,6 +163,10 @@ func (c *MockClient) Call(ctx context.Context, req client.Request, rsp interface src := er.req.Body() switch reqbody := er.req.Body().(type) { case []byte: + cf, err := c.newCodec(ct) + if err != nil { + return errors.BadRequest("go.micro.client", err.Error()) + } src, err = rutil.Zero(req.Body()) if err == nil { err = cf.Unmarshal(reqbody, src) @@ -173,10 +174,6 @@ func (c *MockClient) Call(ctx context.Context, req client.Request, rsp interface if err != nil { return errors.BadRequest("go.micro.client", err.Error()) } - case client.Request: - break - default: - return errors.BadRequest("go.micro.client", "unknown request passed: %v", reqbody) } if !reflect.DeepEqual(req.Body(), src) { @@ -189,6 +186,13 @@ func (c *MockClient) Call(ctx context.Context, req client.Request, rsp interface switch rspbody := er.rsp.(type) { case []byte: + if er.rspct != "" { + ct = er.rspct + } + cf, err := c.newCodec(ct) + if err != nil { + return errors.BadRequest("go.micro.client", err.Error()) + } if err = cf.Unmarshal(rspbody, rsp); err != nil { return errors.BadRequest("go.micro.client", err.Error()) }