QueryStrategy to allow querying routes based on Advertising Strategy
This commit is contained in:
		| @@ -15,6 +15,8 @@ type QueryOptions struct { | |||||||
| 	Network string | 	Network string | ||||||
| 	// Router is router id | 	// Router is router id | ||||||
| 	Router string | 	Router string | ||||||
|  | 	// Strategy is routing strategy | ||||||
|  | 	Strategy Strategy | ||||||
| } | } | ||||||
|  |  | ||||||
| // QueryService sets service to query | // QueryService sets service to query | ||||||
| @@ -52,6 +54,13 @@ func QueryRouter(r string) QueryOption { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // QueryStrategy sets strategy to query | ||||||
|  | func QueryStrategy(s Strategy) QueryOption { | ||||||
|  | 	return func(o *QueryOptions) { | ||||||
|  | 		o.Strategy = s | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| // NewQuery creates new query and returns it | // NewQuery creates new query and returns it | ||||||
| func NewQuery(opts ...QueryOption) QueryOptions { | func NewQuery(opts ...QueryOption) QueryOptions { | ||||||
| 	// default options | 	// default options | ||||||
| @@ -61,6 +70,7 @@ func NewQuery(opts ...QueryOption) QueryOptions { | |||||||
| 		Gateway:  "*", | 		Gateway:  "*", | ||||||
| 		Network:  "*", | 		Network:  "*", | ||||||
| 		Router:   "*", | 		Router:   "*", | ||||||
|  | 		Strategy: AdvertiseAll, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	for _, o := range opts { | 	for _, o := range opts { | ||||||
|   | |||||||
| @@ -146,6 +146,7 @@ type Advert struct { | |||||||
| // Strategy is route advertisement strategy | // Strategy is route advertisement strategy | ||||||
| type Strategy int | type Strategy int | ||||||
|  |  | ||||||
|  | // TODO: remove the "Advertise" prefix from these | ||||||
| const ( | const ( | ||||||
| 	// AdvertiseAll advertises all routes to the network | 	// AdvertiseAll advertises all routes to the network | ||||||
| 	AdvertiseAll Strategy = iota | 	AdvertiseAll Strategy = iota | ||||||
|   | |||||||
| @@ -135,7 +135,7 @@ func (t *table) List() ([]Route, error) { | |||||||
| } | } | ||||||
|  |  | ||||||
| // isMatch checks if the route matches given query options | // isMatch checks if the route matches given query options | ||||||
| func isMatch(route Route, address, gateway, network, router string) bool { | func isMatch(route Route, address, gateway, network, router string, strategy Strategy) bool { | ||||||
| 	// matches the values provided | 	// matches the values provided | ||||||
| 	match := func(a, b string) bool { | 	match := func(a, b string) bool { | ||||||
| 		if a == "*" || a == b { | 		if a == "*" || a == b { | ||||||
| @@ -150,12 +150,20 @@ func isMatch(route Route, address, gateway, network, router string) bool { | |||||||
| 		b string | 		b string | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// by default assume we are querying all routes | ||||||
|  | 	link := "*" | ||||||
|  | 	// if AdvertiseLocal change the link query accordingly | ||||||
|  | 	if strategy == AdvertiseLocal { | ||||||
|  | 		link = "local" | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// compare the following values | 	// compare the following values | ||||||
| 	values := []compare{ | 	values := []compare{ | ||||||
| 		{gateway, route.Gateway}, | 		{gateway, route.Gateway}, | ||||||
| 		{network, route.Network}, | 		{network, route.Network}, | ||||||
| 		{router, route.Router}, | 		{router, route.Router}, | ||||||
| 		{address, route.Address}, | 		{address, route.Address}, | ||||||
|  | 		{link, route.Link}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	for _, v := range values { | 	for _, v := range values { | ||||||
| @@ -169,13 +177,46 @@ func isMatch(route Route, address, gateway, network, router string) bool { | |||||||
| } | } | ||||||
|  |  | ||||||
| // findRoutes finds all the routes for given network and router and returns them | // findRoutes finds all the routes for given network and router and returns them | ||||||
| func findRoutes(routes map[uint64]Route, address, gateway, network, router string) []Route { | func findRoutes(routes map[uint64]Route, address, gateway, network, router string, strategy Strategy) []Route { | ||||||
| 	var results []Route | 	// routeMap stores the routes we're going to advertise | ||||||
|  | 	routeMap := make(map[string][]Route) | ||||||
|  |  | ||||||
| 	for _, route := range routes { | 	for _, route := range routes { | ||||||
| 		if isMatch(route, address, gateway, network, router) { | 		if isMatch(route, address, gateway, network, router, strategy) { | ||||||
| 			results = append(results, route) | 			// add matchihg route to the routeMap | ||||||
|  | 			routeKey := route.Service + "@" + route.Network | ||||||
|  | 			// append the first found route to routeMap | ||||||
|  | 			_, ok := routeMap[routeKey] | ||||||
|  | 			if !ok { | ||||||
|  | 				routeMap[routeKey] = append(routeMap[routeKey], route) | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// if AdvertiseAll, keep appending | ||||||
|  | 			if strategy == AdvertiseAll || strategy == AdvertiseLocal { | ||||||
|  | 				routeMap[routeKey] = append(routeMap[routeKey], route) | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// now we're going to find the best routes | ||||||
|  | 			if strategy == AdvertiseBest { | ||||||
|  | 				// if the current optimal route metric is higher than routing table route, replace it | ||||||
|  | 				if len(routeMap[routeKey]) > 0 { | ||||||
|  | 					// NOTE: we know that when AdvertiseBest is set, we only ever have one item in current | ||||||
|  | 					if routeMap[routeKey][0].Metric > route.Metric { | ||||||
|  | 						routeMap[routeKey][0] = route | ||||||
|  | 						continue | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var results []Route | ||||||
|  | 	for _, route := range routeMap { | ||||||
|  | 		results = append(results, route...) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	return results | 	return results | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -187,17 +228,24 @@ func (t *table) Query(q ...QueryOption) ([]Route, error) { | |||||||
| 	// create new query options | 	// create new query options | ||||||
| 	opts := NewQuery(q...) | 	opts := NewQuery(q...) | ||||||
|  |  | ||||||
|  | 	// create a cwslicelist of query results | ||||||
|  | 	results := make([]Route, 0, len(t.routes)) | ||||||
|  |  | ||||||
|  | 	// if No routes are queried, return early | ||||||
|  | 	if opts.Strategy == AdvertiseNone { | ||||||
|  | 		return results, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	if opts.Service != "*" { | 	if opts.Service != "*" { | ||||||
| 		if _, ok := t.routes[opts.Service]; !ok { | 		if _, ok := t.routes[opts.Service]; !ok { | ||||||
| 			return nil, ErrRouteNotFound | 			return nil, ErrRouteNotFound | ||||||
| 		} | 		} | ||||||
| 		return findRoutes(t.routes[opts.Service], opts.Address, opts.Gateway, opts.Network, opts.Router), nil | 		return findRoutes(t.routes[opts.Service], opts.Address, opts.Gateway, opts.Network, opts.Router, opts.Strategy), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	results := make([]Route, 0, len(t.routes)) |  | ||||||
| 	// search through all destinations | 	// search through all destinations | ||||||
| 	for _, routes := range t.routes { | 	for _, routes := range t.routes { | ||||||
| 		results = append(results, findRoutes(routes, opts.Address, opts.Gateway, opts.Network, opts.Router)...) | 		results = append(results, findRoutes(routes, opts.Address, opts.Gateway, opts.Network, opts.Router, opts.Strategy)...) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return results, nil | 	return results, nil | ||||||
|   | |||||||
| @@ -1,6 +1,9 @@ | |||||||
| package router | package router | ||||||
|  |  | ||||||
| import "testing" | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
| func testSetup() (*table, Route) { | func testSetup() (*table, Route) { | ||||||
| 	table := newTable() | 	table := newTable() | ||||||
| @@ -108,10 +111,10 @@ func TestList(t *testing.T) { | |||||||
| func TestQuery(t *testing.T) { | func TestQuery(t *testing.T) { | ||||||
| 	table, route := testSetup() | 	table, route := testSetup() | ||||||
|  |  | ||||||
| 	svc := []string{"svc1", "svc2", "svc3"} | 	svc := []string{"svc1", "svc2", "svc3", "svc1"} | ||||||
| 	net := []string{"net1", "net2", "net1"} | 	net := []string{"net1", "net2", "net1", "net3"} | ||||||
| 	gw := []string{"gw1", "gw2", "gw3"} | 	gw := []string{"gw1", "gw2", "gw3", "gw3"} | ||||||
| 	rtr := []string{"rtr1", "rt2", "rt3"} | 	rtr := []string{"rtr1", "rt2", "rt3", "rtr3"} | ||||||
|  |  | ||||||
| 	for i := 0; i < len(svc); i++ { | 	for i := 0; i < len(svc); i++ { | ||||||
| 		route.Service = svc[i] | 		route.Service = svc[i] | ||||||
| @@ -218,4 +221,70 @@ func TestQuery(t *testing.T) { | |||||||
| 	if len(routes) != 0 { | 	if len(routes) != 0 { | ||||||
| 		t.Errorf("incorrect number of routes returned. Expected: %d, found: %d", 0, len(routes)) | 		t.Errorf("incorrect number of routes returned. Expected: %d, found: %d", 0, len(routes)) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// query NO routes | ||||||
|  | 	query = []QueryOption{ | ||||||
|  | 		QueryGateway(gateway), | ||||||
|  | 		QueryNetwork(network), | ||||||
|  | 		QueryStrategy(AdvertiseNone), | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	routes, err = table.Query(query...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Errorf("error looking up routes: %s", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if len(routes) > 0 { | ||||||
|  | 		t.Errorf("incorrect number of routes returned. Expected: %d, found: %d", 0, len(routes)) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// insert local routes to query | ||||||
|  | 	for i := 0; i < 2; i++ { | ||||||
|  | 		route.Link = "local" | ||||||
|  | 		route.Address = fmt.Sprintf("local.route.address-%d", i) | ||||||
|  | 		if err := table.Create(route); err != nil { | ||||||
|  | 			t.Errorf("error adding route: %s", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// query local routes | ||||||
|  | 	query = []QueryOption{ | ||||||
|  | 		QueryGateway("*"), | ||||||
|  | 		QueryNetwork("*"), | ||||||
|  | 		QueryStrategy(AdvertiseLocal), | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	routes, err = table.Query(query...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Errorf("error looking up routes: %s", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if len(routes) != 2 { | ||||||
|  | 		t.Errorf("incorrect number of routes returned. Expected: %d, found: %d", 2, len(routes)) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// add two different routes for svcX with different metric | ||||||
|  | 	for i := 0; i < 2; i++ { | ||||||
|  | 		route.Service = "svcX" | ||||||
|  | 		route.Address = fmt.Sprintf("svcX.route.address-%d", i) | ||||||
|  | 		route.Metric = int64(100 + i) | ||||||
|  | 		if err := table.Create(route); err != nil { | ||||||
|  | 			t.Errorf("error adding route: %s", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// query best routes for svcX | ||||||
|  | 	query = []QueryOption{ | ||||||
|  | 		QueryService("svcX"), | ||||||
|  | 		QueryStrategy(AdvertiseBest), | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	routes, err = table.Query(query...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Errorf("error looking up routes: %s", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if len(routes) != 1 { | ||||||
|  | 		t.Errorf("incorrect number of routes returned. Expected: %d, found: %d", 1, len(routes)) | ||||||
|  | 	} | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user