diff --git a/client/options.go b/client/options.go index 28a4e9c4..0ab07aa5 100644 --- a/client/options.go +++ b/client/options.go @@ -187,9 +187,9 @@ func DialTimeout(d time.Duration) Option { // Call Options -func WithSelectOption(so selector.SelectOption) CallOption { +func WithSelectOption(so ...selector.SelectOption) CallOption { return func(o *CallOptions) { - o.SelectOptions = append(o.SelectOptions, so) + o.SelectOptions = append(o.SelectOptions, so...) } } diff --git a/examples/client/dc_filter/dc_filter.go b/examples/client/dc_filter/dc_filter.go index 49d52a0e..9b018281 100644 --- a/examples/client/dc_filter/dc_filter.go +++ b/examples/client/dc_filter/dc_filter.go @@ -41,7 +41,7 @@ func (dc *dcWrapper) Call(ctx context.Context, req client.Request, rsp interface } callOptions := append(opts, client.WithSelectOption( - selector.Filter(filter), + selector.WithFilter(filter), )) fmt.Printf("[DC Wrapper] filtering for datacenter %s\n", md["datacenter"]) diff --git a/selector/filter.go b/selector/filter.go new file mode 100644 index 00000000..53c5af39 --- /dev/null +++ b/selector/filter.go @@ -0,0 +1,73 @@ +package selector + +import ( + "github.com/micro/go-micro/registry" +) + +// FilterEndpoint is an endpoint based Select Filter which will +// only return services with the endpoint specified. +func FilterEndpoint(name string) Filter { + return func(old []*registry.Service) []*registry.Service { + var services []*registry.Service + + for _, service := range old { + for _, ep := range service.Endpoints { + if ep.Name == name { + services = append(services, service) + break + } + } + } + + return services + } +} + +// FilterLabel is a label based Select Filter which will +// only return services with the label specified. +func FilterLabel(key, val string) Filter { + return func(old []*registry.Service) []*registry.Service { + var services []*registry.Service + + for _, service := range old { + serv := new(registry.Service) + var nodes []*registry.Node + + for _, node := range service.Nodes { + if node.Metadata == nil { + continue + } + + if node.Metadata[key] == val { + nodes = append(nodes, node) + } + } + + // only add service if there's some nodes + if len(nodes) > 0 { + // copy + *serv = *service + serv.Nodes = nodes + services = append(services, serv) + } + } + + return services + } +} + +// FilterVersion is a version based Select Filter which will +// only return services with the version specified. +func FilterVersion(version string) Filter { + return func(old []*registry.Service) []*registry.Service { + var services []*registry.Service + + for _, service := range old { + if service.Version == version { + services = append(services, service) + } + } + + return services + } +} diff --git a/selector/filter_test.go b/selector/filter_test.go new file mode 100644 index 00000000..0b92aa15 --- /dev/null +++ b/selector/filter_test.go @@ -0,0 +1,239 @@ +package selector + +import ( + "testing" + + "github.com/micro/go-micro/registry" +) + +func TestFilterEndpoint(t *testing.T) { + testData := []struct { + services []*registry.Service + endpoint string + count int + }{ + { + services: []*registry.Service{ + ®istry.Service{ + Name: "test", + Version: "1.0.0", + Endpoints: []*registry.Endpoint{ + ®istry.Endpoint{ + Name: "Foo.Bar", + }, + }, + }, + ®istry.Service{ + Name: "test", + Version: "1.1.0", + Endpoints: []*registry.Endpoint{ + ®istry.Endpoint{ + Name: "Baz.Bar", + }, + }, + }, + }, + endpoint: "Foo.Bar", + count: 1, + }, + { + services: []*registry.Service{ + ®istry.Service{ + Name: "test", + Version: "1.0.0", + Endpoints: []*registry.Endpoint{ + ®istry.Endpoint{ + Name: "Foo.Bar", + }, + }, + }, + ®istry.Service{ + Name: "test", + Version: "1.1.0", + Endpoints: []*registry.Endpoint{ + ®istry.Endpoint{ + Name: "Foo.Bar", + }, + }, + }, + }, + endpoint: "Bar.Baz", + count: 0, + }, + } + + for _, data := range testData { + filter := FilterEndpoint(data.endpoint) + services := filter(data.services) + + if len(services) != data.count { + t.Fatalf("Expected %d services, got %d", data.count, len(services)) + } + + for _, service := range services { + var seen bool + + for _, ep := range service.Endpoints { + if ep.Name == data.endpoint { + seen = true + break + } + } + + if seen == false && data.count > 0 { + t.Fatalf("Expected %d services but seen is %t; result %+v", data.count, seen, services) + } + } + } +} + +func TestFilterLabel(t *testing.T) { + testData := []struct { + services []*registry.Service + label [2]string + count int + }{ + { + services: []*registry.Service{ + ®istry.Service{ + Name: "test", + Version: "1.0.0", + Nodes: []*registry.Node{ + ®istry.Node{ + Id: "test-1", + Address: "localhost", + Metadata: map[string]string{ + "foo": "bar", + }, + }, + }, + }, + ®istry.Service{ + Name: "test", + Version: "1.1.0", + Nodes: []*registry.Node{ + ®istry.Node{ + Id: "test-2", + Address: "localhost", + Metadata: map[string]string{ + "foo": "baz", + }, + }, + }, + }, + }, + label: [2]string{"foo", "bar"}, + count: 1, + }, + { + services: []*registry.Service{ + ®istry.Service{ + Name: "test", + Version: "1.0.0", + Nodes: []*registry.Node{ + ®istry.Node{ + Id: "test-1", + Address: "localhost", + }, + }, + }, + ®istry.Service{ + Name: "test", + Version: "1.1.0", + Nodes: []*registry.Node{ + ®istry.Node{ + Id: "test-2", + Address: "localhost", + }, + }, + }, + }, + label: [2]string{"foo", "bar"}, + count: 0, + }, + } + + for _, data := range testData { + filter := FilterLabel(data.label[0], data.label[1]) + services := filter(data.services) + + if len(services) != data.count { + t.Fatalf("Expected %d services, got %d", data.count, len(services)) + } + + for _, service := range services { + var seen bool + + for _, node := range service.Nodes { + if node.Metadata[data.label[0]] != data.label[1] { + t.Fatal("Expected %s=%s but got %s=%s for service %+v node %+v", + data.label[0], data.label[1], data.label[0], node.Metadata[data.label[0]], service, node) + } + seen = true + } + + if !seen { + t.Fatalf("Expected node for %s=%s but saw none; results %+v", data.label[0], data.label[1], service) + } + } + } +} + +func TestFilterVersion(t *testing.T) { + testData := []struct { + services []*registry.Service + version string + count int + }{ + { + services: []*registry.Service{ + ®istry.Service{ + Name: "test", + Version: "1.0.0", + }, + ®istry.Service{ + Name: "test", + Version: "1.1.0", + }, + }, + version: "1.0.0", + count: 1, + }, + { + services: []*registry.Service{ + ®istry.Service{ + Name: "test", + Version: "1.0.0", + }, + ®istry.Service{ + Name: "test", + Version: "1.1.0", + }, + }, + version: "2.0.0", + count: 0, + }, + } + + for _, data := range testData { + filter := FilterVersion(data.version) + services := filter(data.services) + + if len(services) != data.count { + t.Fatalf("Expected %d services, got %d", data.count, len(services)) + } + + var seen bool + + for _, service := range services { + if service.Version != data.version { + t.Fatalf("Expected version %s, got %s", data.version, service.Version) + } + seen = true + } + + if seen == false && data.count > 0 { + t.Fatalf("Expected %d services but seen is %t; result %+v", data.count, seen, services) + } + } +} diff --git a/selector/options.go b/selector/options.go index 13973489..47bc81c0 100644 --- a/selector/options.go +++ b/selector/options.go @@ -15,7 +15,7 @@ type Options struct { } type SelectOptions struct { - Filters []SelectFilter + Filters []Filter // Other options for implementations of the interface // can be stored in a context @@ -35,10 +35,10 @@ func Registry(r registry.Registry) Option { } } -// Filter adds a filter function to the list of filters +// WithFilter adds a filter function to the list of filters // used during the Select call. -func Filter(fn SelectFilter) SelectOption { +func WithFilter(fn ...Filter) SelectOption { return func(o *SelectOptions) { - o.Filters = append(o.Filters, fn) + o.Filters = append(o.Filters, fn...) } } diff --git a/selector/selector.go b/selector/selector.go index 551b0a41..42ccd300 100644 --- a/selector/selector.go +++ b/selector/selector.go @@ -82,8 +82,8 @@ type Selector interface { // based on the selector's algorithm type Next func() (*registry.Node, error) -// SelectFilter is used to filter a service during the selection process -type SelectFilter func([]*registry.Service) []*registry.Service +// Filter is used to filter a service during the selection process +type Filter func([]*registry.Service) []*registry.Service var ( DefaultSelector = newRandomSelector()