diff --git a/selector/filter.go b/selector/filter.go index 0547b59e..53c5af39 100644 --- a/selector/filter.go +++ b/selector/filter.go @@ -4,6 +4,25 @@ import ( "github.com/micro/go-micro/registry" ) +// FilterEndpoint is an endpoint based Select Filter which will +// only return services with the endpoint specified. +func FilterEndpoint(name string) Filter { + return func(old []*registry.Service) []*registry.Service { + var services []*registry.Service + + for _, service := range old { + for _, ep := range service.Endpoints { + if ep.Name == name { + services = append(services, service) + break + } + } + } + + return services + } +} + // FilterLabel is a label based Select Filter which will // only return services with the label specified. func FilterLabel(key, val string) Filter { diff --git a/selector/filter_test.go b/selector/filter_test.go index 793c00f9..0b92aa15 100644 --- a/selector/filter_test.go +++ b/selector/filter_test.go @@ -6,6 +6,87 @@ import ( "github.com/micro/go-micro/registry" ) +func TestFilterEndpoint(t *testing.T) { + testData := []struct { + services []*registry.Service + endpoint string + count int + }{ + { + services: []*registry.Service{ + ®istry.Service{ + Name: "test", + Version: "1.0.0", + Endpoints: []*registry.Endpoint{ + ®istry.Endpoint{ + Name: "Foo.Bar", + }, + }, + }, + ®istry.Service{ + Name: "test", + Version: "1.1.0", + Endpoints: []*registry.Endpoint{ + ®istry.Endpoint{ + Name: "Baz.Bar", + }, + }, + }, + }, + endpoint: "Foo.Bar", + count: 1, + }, + { + services: []*registry.Service{ + ®istry.Service{ + Name: "test", + Version: "1.0.0", + Endpoints: []*registry.Endpoint{ + ®istry.Endpoint{ + Name: "Foo.Bar", + }, + }, + }, + ®istry.Service{ + Name: "test", + Version: "1.1.0", + Endpoints: []*registry.Endpoint{ + ®istry.Endpoint{ + Name: "Foo.Bar", + }, + }, + }, + }, + endpoint: "Bar.Baz", + count: 0, + }, + } + + for _, data := range testData { + filter := FilterEndpoint(data.endpoint) + services := filter(data.services) + + if len(services) != data.count { + t.Fatalf("Expected %d services, got %d", data.count, len(services)) + } + + for _, service := range services { + var seen bool + + for _, ep := range service.Endpoints { + if ep.Name == data.endpoint { + seen = true + break + } + } + + if seen == false && data.count > 0 { + t.Fatalf("Expected %d services but seen is %t; result %+v", data.count, seen, services) + } + } + } +} + func TestFilterLabel(t *testing.T) { testData := []struct { services []*registry.Service