add ability to filter routes based on headers
This commit is contained in:
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/micro/go-micro/codec/bytes"
|
||||
"github.com/micro/go-micro/config/options"
|
||||
"github.com/micro/go-micro/errors"
|
||||
"github.com/micro/go-micro/metadata"
|
||||
"github.com/micro/go-micro/proxy"
|
||||
"github.com/micro/go-micro/router"
|
||||
"github.com/micro/go-micro/server"
|
||||
@@ -105,6 +106,63 @@ func toSlice(r map[uint64]router.Route) []router.Route {
|
||||
return routes
|
||||
}
|
||||
|
||||
func (p *Proxy) filterRoutes(ctx context.Context, routes []router.Route) []router.Route {
|
||||
md, ok := metadata.FromContext(ctx)
|
||||
if !ok {
|
||||
return routes
|
||||
}
|
||||
|
||||
var filteredRoutes []router.Route
|
||||
|
||||
// filter the routes based on our headers
|
||||
for _, route := range routes {
|
||||
// process only routes for this id
|
||||
if id := md["Micro-Router"]; len(id) > 0 {
|
||||
if route.Router != id {
|
||||
// skip routes that don't mwatch
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// only process routes with this network
|
||||
if net := md["Micro-Network"]; len(net) > 0 {
|
||||
if route.Network != net {
|
||||
// skip routes that don't mwatch
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// process only this gateway
|
||||
if gw := md["Micro-Gateway"]; len(gw) > 0 {
|
||||
// if the gateway matches our address
|
||||
// special case, take the routes with no gateway
|
||||
// TODO: should we strip the gateway from the context?
|
||||
if gw == p.Router.Options().Address {
|
||||
if len(route.Gateway) > 0 && route.Gateway != gw {
|
||||
continue
|
||||
}
|
||||
// otherwise its a local route and we're keeping it
|
||||
} else {
|
||||
// gateway does not match our own
|
||||
if route.Gateway != gw {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: address based filtering
|
||||
// address := md["Micro-Address"]
|
||||
|
||||
// TODO: label based filtering
|
||||
// requires new field in routing table : route.Labels
|
||||
|
||||
// passed the filter checks
|
||||
filteredRoutes = append(filteredRoutes, route)
|
||||
}
|
||||
|
||||
return filteredRoutes
|
||||
}
|
||||
|
||||
func (p *Proxy) getLink(r router.Route) (client.Client, error) {
|
||||
if r.Link == "local" || len(p.Links) == 0 {
|
||||
return p.Client, nil
|
||||
@@ -116,13 +174,14 @@ func (p *Proxy) getLink(r router.Route) (client.Client, error) {
|
||||
return l, nil
|
||||
}
|
||||
|
||||
func (p *Proxy) getRoute(service string) ([]router.Route, error) {
|
||||
func (p *Proxy) getRoute(ctx context.Context, service string) ([]router.Route, error) {
|
||||
// lookup the route cache first
|
||||
p.Lock()
|
||||
cached, ok := p.Routes[service]
|
||||
if ok {
|
||||
p.Unlock()
|
||||
return toSlice(cached), nil
|
||||
routes := toSlice(cached)
|
||||
return p.filterRoutes(ctx, routes), nil
|
||||
}
|
||||
p.Unlock()
|
||||
|
||||
@@ -132,7 +191,7 @@ func (p *Proxy) getRoute(service string) ([]router.Route, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
return p.filterRoutes(ctx, routes), nil
|
||||
}
|
||||
|
||||
func (p *Proxy) cacheRoutes(service string) ([]router.Route, error) {
|
||||
@@ -255,7 +314,7 @@ func (p *Proxy) ServeRequest(ctx context.Context, req server.Request, rsp server
|
||||
addresses = []string{p.Endpoint}
|
||||
} else {
|
||||
// get route for endpoint from router
|
||||
addr, err := p.getRoute(p.Endpoint)
|
||||
addr, err := p.getRoute(ctx, p.Endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -267,7 +326,7 @@ func (p *Proxy) ServeRequest(ctx context.Context, req server.Request, rsp server
|
||||
} else {
|
||||
// no endpoint was specified just lookup the route
|
||||
// get route for endpoint from router
|
||||
addr, err := p.getRoute(service)
|
||||
addr, err := p.getRoute(ctx, service)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
Reference in New Issue
Block a user