diff --git a/client/mock/mock.go b/client/mock/mock.go index 6ee8f1f8..ba182478 100644 --- a/client/mock/mock.go +++ b/client/mock/mock.go @@ -84,10 +84,24 @@ func (m *MockClient) Call(ctx context.Context, req client.Request, rsp interface response := r.Response if t := reflect.TypeOf(r.Response); t.Kind() == reflect.Func { var request []reflect.Value - if t.NumIn() == 1 { + switch t.NumIn() { + case 1: + // one input params: (req) request = append(request, reflect.ValueOf(req.Body())) + case 2: + // two input params: (ctx, req) + request = append(request, reflect.ValueOf(ctx), reflect.ValueOf(req.Body())) + } + + responseValue := reflect.ValueOf(r.Response).Call(request) + response = responseValue[0].Interface() + if len(responseValue) == 2 { + // make it possible to return error in response function + respErr, ok := responseValue[1].Interface().(error) + if ok && respErr != nil { + return respErr + } } - response = reflect.ValueOf(r.Response).Call(request)[0].Interface() } v.Set(reflect.ValueOf(response)) diff --git a/client/mock/mock_test.go b/client/mock/mock_test.go index 4f00f8a9..0730610f 100644 --- a/client/mock/mock_test.go +++ b/client/mock/mock_test.go @@ -2,6 +2,7 @@ package mock import ( "context" + "fmt" "testing" "github.com/micro/go-micro/errors" @@ -24,6 +25,12 @@ func TestClient(t *testing.T) { } return "wrong" }}, + {Endpoint: "Foo.FuncWithRequestContextAndResponse", Response: func(ctx context.Context, req interface{}) string { + return "something" + }}, + {Endpoint: "Foo.FuncWithRequestContextAndResponseError", Response: func(ctx context.Context, req interface{}) (string, error) { + return "something", fmt.Errorf("mock error") + }}, } c := NewClient(Response("go.mock", response)) @@ -35,7 +42,9 @@ func TestClient(t *testing.T) { err := c.Call(context.TODO(), req, &rsp) if err != r.Error { - t.Fatalf("Expecter error %v got %v", r.Error, err) + if r.Endpoint != "Foo.FuncWithRequestContextAndResponseError" { + t.Fatalf("Expecter error %v got %v", r.Error, err) + } } t.Log(rsp)