diff --git a/registry/consul_registry_test.go b/registry/consul_registry_test.go index 262b15bb..7441a22c 100644 --- a/registry/consul_registry_test.go +++ b/registry/consul_registry_test.go @@ -4,18 +4,74 @@ import ( "bytes" "encoding/json" "errors" - "fmt" - "io/ioutil" + "net" "net/http" "testing" consul "github.com/hashicorp/consul/api" ) -func TestConsul_GetService_WithError(t *testing.T) { - cr := newConsulTestRegistry(&mockTransport{ - err: errors.New("client-error"), +type mockRegistry struct { + body []byte + status int + err error + url string +} + +func encodeData(obj interface{}) ([]byte, error) { + buf := bytes.NewBuffer(nil) + enc := json.NewEncoder(buf) + if err := enc.Encode(obj); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func newMockServer(rg *mockRegistry, l net.Listener) error { + mux := http.NewServeMux() + mux.HandleFunc(rg.url, func(w http.ResponseWriter, r *http.Request) { + if rg.err != nil { + http.Error(w, rg.err.Error(), 500) + return + } + w.WriteHeader(rg.status) + w.Write(rg.body) }) + return http.Serve(l, mux) +} + +func newConsulTestRegistry(r *mockRegistry) (*consulRegistry, func()) { + l, err := net.Listen("tcp", ":0") + if err != nil { + // blurgh?!! + panic(err.Error()) + } + cfg := consul.DefaultConfig() + cfg.Address = l.Addr().String() + cl, _ := consul.NewClient(cfg) + + go newMockServer(r, l) + + return &consulRegistry{ + Address: cfg.Address, + Client: cl, + register: make(map[string]uint64), + }, func() { + l.Close() + } +} + +func newServiceList(svc []*consul.ServiceEntry) []byte { + bts, _ := encodeData(svc) + return bts +} + +func TestConsul_GetService_WithError(t *testing.T) { + cr, cl := newConsulTestRegistry(&mockRegistry{ + err: errors.New("client-error"), + url: "/v1/health/service/service-name", + }) + defer cl() if _, err := cr.GetService("test-service"); err == nil { t.Fatalf("Expected error not to be `nil`") @@ -41,11 +97,12 @@ func TestConsul_GetService_WithHealthyServiceNodes(t *testing.T) { ), } - cr := newConsulTestRegistry(&mockTransport{ + cr, cl := newConsulTestRegistry(&mockRegistry{ status: 200, body: newServiceList(svcs), url: "/v1/health/service/service-name", }) + defer cl() svc, _ := cr.GetService("service-name") if exp, act := 1, len(svc); exp != act { @@ -76,11 +133,12 @@ func TestConsul_GetService_WithUnhealthyServiceNode(t *testing.T) { ), } - cr := newConsulTestRegistry(&mockTransport{ + cr, cl := newConsulTestRegistry(&mockRegistry{ status: 200, body: newServiceList(svcs), url: "/v1/health/service/service-name", }) + defer cl() svc, _ := cr.GetService("service-name") if exp, act := 1, len(svc); exp != act { @@ -111,11 +169,12 @@ func TestConsul_GetService_WithUnhealthyServiceNodes(t *testing.T) { ), } - cr := newConsulTestRegistry(&mockTransport{ + cr, cl := newConsulTestRegistry(&mockRegistry{ status: 200, body: newServiceList(svcs), url: "/v1/health/service/service-name", }) + defer cl() svc, _ := cr.GetService("service-name") if exp, act := 1, len(svc); exp != act { @@ -126,58 +185,3 @@ func TestConsul_GetService_WithUnhealthyServiceNodes(t *testing.T) { t.Fatalf("Expected len of nodes to be `%d`, got `%d`.", exp, act) } } - -func newServiceList(svc []*consul.ServiceEntry) []byte { - bts, _ := encodeData(svc) - return bts -} - -func newConsulTestRegistry(t *mockTransport) *consulRegistry { - cfg := &consul.Config{ - HttpClient: mockHttpClient(t), - } - cl, _ := consul.NewClient(cfg) - - return &consulRegistry{ - Address: cfg.Address, - Client: cl, - register: make(map[string]uint64), - } -} - -func mockHttpClient(t *mockTransport) *http.Client { - return &http.Client{ - Transport: t, - } -} - -type mockTransport struct { - body []byte - status int - err error - url string -} - -func encodeData(obj interface{}) ([]byte, error) { - buf := bytes.NewBuffer(nil) - enc := json.NewEncoder(buf) - if err := enc.Encode(obj); err != nil { - return nil, err - } - return buf.Bytes(), nil -} - -func (t *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { - if t.err != nil { - return nil, t.err - } - - if t.url != "" && fmt.Sprintf("http://127.0.0.1:8500%s", t.url) != req.URL.String() { - return nil, errors.New("URLs do not match") - } - - return &http.Response{ - StatusCode: t.status, - Body: ioutil.NopCloser(bytes.NewReader(t.body)), - }, nil -}