From 83c234f21ab2c1cf6420e0a780a380b3c9314463 Mon Sep 17 00:00:00 2001 From: Asim Aslam Date: Mon, 27 Jul 2020 13:22:00 +0100 Subject: [PATCH] v3 refactor (#1868) * Move to v3 Co-authored-by: Ben Toogood --- registry.go | 723 ++++++++++++++++++++++++++++++++++++++++++++++- registry_test.go | 133 +++++++++ table.go | 321 +++++++++++++++++++++ table_test.go | 350 +++++++++++++++++++++++ watcher.go | 52 ++++ 5 files changed, 1575 insertions(+), 4 deletions(-) create mode 100644 registry_test.go create mode 100644 table.go create mode 100644 table_test.go create mode 100644 watcher.go diff --git a/registry.go b/registry.go index 01172b4..b585965 100644 --- a/registry.go +++ b/registry.go @@ -1,8 +1,723 @@ package registry -import "github.com/micro/go-micro/v2/router" +import ( + "errors" + "fmt" + "sort" + "strings" + "sync" + "time" -// NewRouter returns an initialised registry router -func NewRouter(opts ...router.Option) router.Router { - return router.NewRouter(opts...) + "github.com/google/uuid" + "github.com/micro/go-micro/v3/logger" + "github.com/micro/go-micro/v3/registry" + "github.com/micro/go-micro/v3/router" +) + +var ( + // AdvertiseEventsTick is time interval in which the router advertises route updates + AdvertiseEventsTick = 10 * time.Second + // DefaultAdvertTTL is default advertisement TTL + DefaultAdvertTTL = 2 * time.Minute +) + +// rtr implements router interface +type rtr struct { + sync.RWMutex + + running bool + table *table + options router.Options + exit chan bool + eventChan chan *router.Event + + // advert subscribers + sub sync.RWMutex + subscribers map[string]chan *router.Advert +} + +// NewRouter creates new router and returns it +func NewRouter(opts ...router.Option) router.Router { + // get default options + options := router.DefaultOptions() + + // apply requested options + for _, o := range opts { + o(&options) + } + + // construct the router + r := &rtr{ + options: options, + subscribers: make(map[string]chan *router.Advert), + } + + // create the new table, passing the fetchRoute method in as a fallback if + // the table doesn't contain the result for a query. + r.table = newTable(r.fetchRoutes) + return r +} + +// Init initializes router with given options +func (r *rtr) Init(opts ...router.Option) error { + // stop the router before we initialize + if err := r.Close(); err != nil { + return err + } + + r.Lock() + defer r.Unlock() + + for _, o := range opts { + o(&r.options) + } + + return nil +} + +// Options returns router options +func (r *rtr) Options() router.Options { + r.RLock() + defer r.RUnlock() + + options := r.options + + return options +} + +// Table returns routing table +func (r *rtr) Table() router.Table { + r.Lock() + defer r.Unlock() + return r.table +} + +// manageRoute applies action on a given route +func (r *rtr) manageRoute(route router.Route, action string) error { + switch action { + case "create": + if err := r.table.Create(route); err != nil && err != router.ErrDuplicateRoute { + return fmt.Errorf("failed adding route for service %s: %s", route.Service, err) + } + case "delete": + if err := r.table.Delete(route); err != nil && err != router.ErrRouteNotFound { + return fmt.Errorf("failed deleting route for service %s: %s", route.Service, err) + } + case "update": + if err := r.table.Update(route); err != nil { + return fmt.Errorf("failed updating route for service %s: %s", route.Service, err) + } + default: + return fmt.Errorf("failed to manage route for service %s: unknown action %s", route.Service, action) + } + + return nil +} + +// manageServiceRoutes applies action to all routes of the service. +// It returns error of the action fails with error. +func (r *rtr) manageRoutes(service *registry.Service, action, network string) error { + // action is the routing table action + action = strings.ToLower(action) + + // take route action on each service node + for _, node := range service.Nodes { + route := router.Route{ + Service: service.Name, + Address: node.Address, + Gateway: "", + Network: network, + Router: r.options.Id, + Link: router.DefaultLink, + Metric: router.DefaultLocalMetric, + Metadata: node.Metadata, + } + + if err := r.manageRoute(route, action); err != nil { + return err + } + } + + return nil +} + +// manageRegistryRoutes applies action to all routes of each service found in the registry. +// It returns error if either the services failed to be listed or the routing table action fails. +func (r *rtr) manageRegistryRoutes(reg registry.Registry, action string) error { + services, err := reg.ListServices(registry.ListDomain(registry.WildcardDomain)) + if err != nil { + return fmt.Errorf("failed listing services: %v", err) + } + + // add each service node as a separate route + for _, service := range services { + // get the services domain from metadata. Fallback to wildcard. + var domain string + if service.Metadata != nil && len(service.Metadata["domain"]) > 0 { + domain = service.Metadata["domain"] + } else { + domain = registry.WildcardDomain + } + + // get the service to retrieve all its info + srvs, err := reg.GetService(service.Name, registry.GetDomain(domain)) + if err != nil { + continue + } + // manage the routes for all returned services + for _, srv := range srvs { + if err := r.manageRoutes(srv, action, domain); err != nil { + return err + } + } + } + + return nil +} + +// fetchRoutes retrieves all the routes for a given service and creates them in the routing table +func (r *rtr) fetchRoutes(service string) error { + services, err := r.options.Registry.GetService(service, registry.GetDomain(registry.WildcardDomain)) + if err == registry.ErrNotFound { + return nil + } else if err != nil { + return fmt.Errorf("failed getting services: %v", err) + } + + for _, srv := range services { + var domain string + if srv.Metadata != nil && len(srv.Metadata["domain"]) > 0 { + domain = srv.Metadata["domain"] + } else { + domain = registry.WildcardDomain + } + + if err := r.manageRoutes(srv, "create", domain); err != nil { + return err + } + } + + return nil +} + +// watchRegistry watches registry and updates routing table based on the received events. +// It returns error if either the registry watcher fails with error or if the routing table update fails. +func (r *rtr) watchRegistry(w registry.Watcher) error { + exit := make(chan bool) + + defer func() { + close(exit) + }() + + go func() { + defer w.Stop() + + select { + case <-exit: + return + case <-r.exit: + return + } + }() + + for { + res, err := w.Next() + if err != nil { + if err != registry.ErrWatcherStopped { + return err + } + break + } + + if res.Service == nil { + continue + } + + // get the services domain from metadata. Fallback to wildcard. + var domain string + if res.Service.Metadata != nil && len(res.Service.Metadata["domain"]) > 0 { + domain = res.Service.Metadata["domain"] + } else { + domain = registry.WildcardDomain + } + + if err := r.manageRoutes(res.Service, res.Action, domain); err != nil { + return err + } + } + + return nil +} + +// watchTable watches routing table entries and either adds or deletes locally registered service to/from network registry +// It returns error if the locally registered services either fails to be added/deleted to/from network registry. +func (r *rtr) watchTable(w router.Watcher) error { + exit := make(chan bool) + + defer func() { + close(exit) + }() + + // wait in the background for the router to stop + // when the router stops, stop the watcher and exit + go func() { + defer w.Stop() + + select { + case <-r.exit: + return + case <-exit: + return + } + }() + + for { + event, err := w.Next() + if err != nil { + if err != router.ErrWatcherStopped { + return err + } + break + } + + select { + case <-r.exit: + return nil + case r.eventChan <- event: + // process event + } + } + + return nil +} + +// publishAdvert publishes router advert to advert channel +func (r *rtr) publishAdvert(advType router.AdvertType, events []*router.Event) { + a := &router.Advert{ + Id: r.options.Id, + Type: advType, + TTL: DefaultAdvertTTL, + Timestamp: time.Now(), + Events: events, + } + + r.sub.RLock() + for _, sub := range r.subscribers { + // now send the message + select { + case sub <- a: + case <-r.exit: + r.sub.RUnlock() + return + } + } + r.sub.RUnlock() +} + +// adverts maintains a map of router adverts +type adverts map[uint64]*router.Event + +// advertiseEvents advertises routing table events +// It suppresses unhealthy flapping events and advertises healthy events upstream. +func (r *rtr) advertiseEvents() error { + // ticker to periodically scan event for advertising + ticker := time.NewTicker(AdvertiseEventsTick) + defer ticker.Stop() + + // adverts is a map of advert events + adverts := make(adverts) + + // routing table watcher + w, err := r.Watch() + if err != nil { + return err + } + defer w.Stop() + + go func() { + var err error + + for { + select { + case <-r.exit: + return + default: + if w == nil { + // routing table watcher + w, err = r.Watch() + if err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Errorf("Error creating watcher: %v", err) + } + time.Sleep(time.Second) + continue + } + } + + if err := r.watchTable(w); err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Errorf("Error watching table: %v", err) + } + time.Sleep(time.Second) + } + + if w != nil { + // reset + w.Stop() + w = nil + } + } + } + }() + + for { + select { + case <-ticker.C: + // If we're not advertising any events then sip processing them entirely + if r.options.Advertise == router.AdvertiseNone { + continue + } + + var events []*router.Event + + // collect all events which are not flapping + for key, event := range adverts { + // if we only advertise local routes skip processing anything not link local + if r.options.Advertise == router.AdvertiseLocal && event.Route.Link != "local" { + continue + } + + // copy the event and append + e := new(router.Event) + // this is ok, because router.Event only contains builtin types + // and no references so this creates a deep copy of struct Event + *e = *event + events = append(events, e) + // delete the advert from adverts + delete(adverts, key) + } + + // advertise events to subscribers + if len(events) > 0 { + if logger.V(logger.DebugLevel, logger.DefaultLogger) { + logger.Debugf("Router publishing %d events", len(events)) + } + go r.publishAdvert(router.RouteUpdate, events) + } + case e := <-r.eventChan: + // if event is nil, continue + if e == nil { + continue + } + + // If we're not advertising any events then skip processing them entirely + if r.options.Advertise == router.AdvertiseNone { + continue + } + + // if we only advertise local routes skip processing anything not link local + if r.options.Advertise == router.AdvertiseLocal && e.Route.Link != "local" { + continue + } + + if logger.V(logger.DebugLevel, logger.DefaultLogger) { + logger.Debugf("Router processing table event %s for service %s %s", e.Type, e.Route.Service, e.Route.Address) + } + + // check if we have already registered the route + hash := e.Route.Hash() + ev, ok := adverts[hash] + if !ok { + ev = e + adverts[hash] = e + continue + } + + // override the route event only if the previous event was different + if ev.Type != e.Type { + ev = e + } + case <-r.exit: + if w != nil { + w.Stop() + } + return nil + } + } +} + +// drain all the events, only called on Stop +func (r *rtr) drain() { + for { + select { + case <-r.eventChan: + default: + return + } + } +} + +// start the router. Should be called under lock. +func (r *rtr) start() error { + if r.running { + return nil + } + + if r.options.Precache { + // add all local service routes into the routing table + if err := r.manageRegistryRoutes(r.options.Registry, "create"); err != nil { + return fmt.Errorf("failed adding registry routes: %s", err) + } + } + + // add default gateway into routing table + if r.options.Gateway != "" { + // note, the only non-default value is the gateway + route := router.Route{ + Service: "*", + Address: "*", + Gateway: r.options.Gateway, + Network: "*", + Router: r.options.Id, + Link: router.DefaultLink, + Metric: router.DefaultLocalMetric, + } + if err := r.table.Create(route); err != nil { + return fmt.Errorf("failed adding default gateway route: %s", err) + } + } + + // create error and exit channels + r.exit = make(chan bool) + + // registry watcher + w, err := r.options.Registry.Watch(registry.WatchDomain(registry.WildcardDomain)) + if err != nil { + return fmt.Errorf("failed creating registry watcher: %v", err) + } + + go func() { + var err error + + for { + select { + case <-r.exit: + if w != nil { + w.Stop() + } + return + default: + if w == nil { + w, err = r.options.Registry.Watch() + if err != nil { + if logger.V(logger.WarnLevel, logger.DefaultLogger) { + logger.Warnf("failed creating registry watcher: %v", err) + } + time.Sleep(time.Second) + continue + } + } + + if err := r.watchRegistry(w); err != nil { + if logger.V(logger.WarnLevel, logger.DefaultLogger) { + logger.Warnf("Error watching the registry: %v", err) + } + time.Sleep(time.Second) + } + + if w != nil { + w.Stop() + w = nil + } + } + } + }() + + r.running = true + + return nil +} + +// Advertise stars advertising the routes to the network and returns the advertisements channel to consume from. +// If the router is already advertising it returns the channel to consume from. +// It returns error if either the router is not running or if the routing table fails to list the routes to advertise. +func (r *rtr) Advertise() (<-chan *router.Advert, error) { + r.Lock() + defer r.Unlock() + + if r.running { + return nil, errors.New("cannot re-advertise, already running") + } + + // start the router + r.start() + + // we're mutating the subscribers so they need to be locked also + r.sub.Lock() + defer r.sub.Unlock() + + // already advertising + if r.eventChan != nil { + advertChan := make(chan *router.Advert, 128) + r.subscribers[uuid.New().String()] = advertChan + return advertChan, nil + } + + // list all the routes and pack them into even slice to advertise + events, err := r.flushRouteEvents(router.Create) + if err != nil { + return nil, fmt.Errorf("failed to flush routes: %s", err) + } + + // create event channels + r.eventChan = make(chan *router.Event) + + // create advert channel + advertChan := make(chan *router.Advert, 128) + r.subscribers[uuid.New().String()] = advertChan + + // advertise your presence + go r.publishAdvert(router.Announce, events) + + go func() { + select { + case <-r.exit: + return + default: + if err := r.advertiseEvents(); err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Errorf("Error adveritising events: %v", err) + } + } + } + }() + + return advertChan, nil + +} + +// Process updates the routing table using the advertised values +func (r *rtr) Process(a *router.Advert) error { + // NOTE: event sorting might not be necessary + // copy update events intp new slices + events := make([]*router.Event, len(a.Events)) + copy(events, a.Events) + // sort events by timestamp + sort.Slice(events, func(i, j int) bool { + return events[i].Timestamp.Before(events[j].Timestamp) + }) + + if logger.V(logger.TraceLevel, logger.DefaultLogger) { + logger.Tracef("Router %s processing advert from: %s", r.options.Id, a.Id) + } + + for _, event := range events { + // skip if the router is the origin of this route + if event.Route.Router == r.options.Id { + if logger.V(logger.TraceLevel, logger.DefaultLogger) { + logger.Tracef("Router skipping processing its own route: %s", r.options.Id) + } + continue + } + // create a copy of the route + route := event.Route + action := event.Type + + if logger.V(logger.TraceLevel, logger.DefaultLogger) { + logger.Tracef("Router %s applying %s from router %s for service %s %s", r.options.Id, action, route.Router, route.Service, route.Address) + } + + if err := r.manageRoute(route, action.String()); err != nil { + return fmt.Errorf("failed applying action %s to routing table: %s", action, err) + } + } + + return nil +} + +// flushRouteEvents returns a slice of events, one per each route in the routing table +func (r *rtr) flushRouteEvents(evType router.EventType) ([]*router.Event, error) { + // get a list of routes for each service in our routing table + // for the configured advertising strategy + q := []router.QueryOption{ + router.QueryStrategy(r.options.Advertise), + } + + routes, err := r.table.Query(q...) + if err != nil && err != router.ErrRouteNotFound { + return nil, err + } + + if logger.V(logger.DebugLevel, logger.DefaultLogger) { + logger.Debugf("Router advertising %d routes with strategy %s", len(routes), r.options.Advertise) + } + + // build a list of events to advertise + events := make([]*router.Event, len(routes)) + var i int + + for _, route := range routes { + event := &router.Event{ + Type: evType, + Timestamp: time.Now(), + Route: route, + } + events[i] = event + i++ + } + + return events, nil +} + +// Lookup routes in the routing table +func (r *rtr) Lookup(q ...router.QueryOption) ([]router.Route, error) { + return r.Table().Query(q...) +} + +// Watch routes +func (r *rtr) Watch(opts ...router.WatchOption) (router.Watcher, error) { + return r.table.Watch(opts...) +} + +// Close the router +func (r *rtr) Close() error { + r.Lock() + defer r.Unlock() + + select { + case <-r.exit: + return nil + default: + if !r.running { + return nil + } + close(r.exit) + + // extract the events + r.drain() + + r.sub.Lock() + // close advert subscribers + for id, sub := range r.subscribers { + // close the channel + close(sub) + // delete the subscriber + delete(r.subscribers, id) + } + r.sub.Unlock() + } + + // close and remove event chan + if r.eventChan != nil { + close(r.eventChan) + r.eventChan = nil + } + + r.running = false + return nil +} + +// String prints debugging information about router +func (r *rtr) String() string { + return "registry" } diff --git a/registry_test.go b/registry_test.go new file mode 100644 index 0000000..0c1af95 --- /dev/null +++ b/registry_test.go @@ -0,0 +1,133 @@ +package registry + +import ( + "fmt" + "os" + "sync" + "testing" + "time" + + "github.com/micro/go-micro/v3/registry/memory" + "github.com/micro/go-micro/v3/router" +) + +func routerTestSetup() router.Router { + r := memory.NewRegistry() + return NewRouter(router.Registry(r)) +} + +func TestRouterClose(t *testing.T) { + r := routerTestSetup() + + _, err := r.Advertise() + if err != nil { + t.Errorf("failed to start advertising: %v", err) + } + + if err := r.Close(); err != nil { + t.Errorf("failed to stop router: %v", err) + } + if len(os.Getenv("IN_TRAVIS_CI")) == 0 { + t.Logf("TestRouterStartStop STOPPED") + } +} + +func TestRouterAdvertise(t *testing.T) { + r := routerTestSetup() + + // lower the advertise interval + AdvertiseEventsTick = 500 * time.Millisecond + + ch, err := r.Advertise() + if err != nil { + t.Errorf("failed to start advertising: %v", err) + } + + // receive announce event + ann := <-ch + if len(os.Getenv("IN_TRAVIS_CI")) == 0 { + t.Logf("received announce advert: %v", ann) + } + + // Generate random unique routes + nrRoutes := 5 + routes := make([]router.Route, nrRoutes) + route := router.Route{ + Service: "dest.svc", + Address: "dest.addr", + Gateway: "dest.gw", + Network: "dest.network", + Router: "src.router", + Link: "local", + Metric: 10, + } + + for i := 0; i < nrRoutes; i++ { + testRoute := route + testRoute.Service = fmt.Sprintf("%s-%d", route.Service, i) + routes[i] = testRoute + } + + var advertErr error + + createDone := make(chan bool) + errChan := make(chan error) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + wg.Done() + defer close(createDone) + for _, route := range routes { + if len(os.Getenv("IN_TRAVIS_CI")) == 0 { + t.Logf("Creating route %v", route) + } + if err := r.Table().Create(route); err != nil { + if len(os.Getenv("IN_TRAVIS_CI")) == 0 { + t.Logf("Failed to create route: %v", err) + } + errChan <- err + return + } + } + }() + + var adverts int + readDone := make(chan bool) + + wg.Add(1) + go func() { + defer func() { + wg.Done() + readDone <- true + }() + for advert := range ch { + select { + case advertErr = <-errChan: + t.Errorf("failed advertising events: %v", advertErr) + default: + // do nothing for now + if len(os.Getenv("IN_TRAVIS_CI")) == 0 { + t.Logf("Router advert received: %v", advert) + } + adverts += len(advert.Events) + } + return + } + }() + + // done adding routes to routing table + <-createDone + // done reading adverts from the routing table + <-readDone + + if adverts != nrRoutes { + t.Errorf("Expected %d adverts, received: %d", nrRoutes, adverts) + } + + wg.Wait() + + if err := r.Close(); err != nil { + t.Errorf("failed to stop router: %v", err) + } +} diff --git a/table.go b/table.go new file mode 100644 index 0000000..2f1d748 --- /dev/null +++ b/table.go @@ -0,0 +1,321 @@ +package registry + +import ( + "sync" + "time" + + "github.com/google/uuid" + "github.com/micro/go-micro/v3/logger" + "github.com/micro/go-micro/v3/router" +) + +// table is an in-memory routing table +type table struct { + sync.RWMutex + // fetchRoutes for a service + fetchRoutes func(string) error + // routes stores service routes + routes map[string]map[uint64]router.Route + // watchers stores table watchers + watchers map[string]*tableWatcher +} + +// newtable creates a new routing table and returns it +func newTable(fetchRoutes func(string) error, opts ...router.Option) *table { + return &table{ + fetchRoutes: fetchRoutes, + routes: make(map[string]map[uint64]router.Route), + watchers: make(map[string]*tableWatcher), + } +} + +// sendEvent sends events to all subscribed watchers +func (t *table) sendEvent(e *router.Event) { + t.RLock() + defer t.RUnlock() + + if len(e.Id) == 0 { + e.Id = uuid.New().String() + } + + for _, w := range t.watchers { + select { + case w.resChan <- e: + case <-w.done: + // don't block forever + case <-time.After(time.Second): + } + } +} + +// Create creates new route in the routing table +func (t *table) Create(r router.Route) error { + service := r.Service + sum := r.Hash() + + t.Lock() + defer t.Unlock() + + // check if there are any routes in the table for the route destination + if _, ok := t.routes[service]; !ok { + t.routes[service] = make(map[uint64]router.Route) + } + + // add new route to the table for the route destination + if _, ok := t.routes[service][sum]; !ok { + t.routes[service][sum] = r + if logger.V(logger.DebugLevel, logger.DefaultLogger) { + logger.Debugf("Router emitting %s for route: %s", router.Create, r.Address) + } + go t.sendEvent(&router.Event{Type: router.Create, Timestamp: time.Now(), Route: r}) + return nil + } + + return router.ErrDuplicateRoute +} + +// Delete deletes the route from the routing table +func (t *table) Delete(r router.Route) error { + service := r.Service + sum := r.Hash() + + t.Lock() + defer t.Unlock() + + if _, ok := t.routes[service]; !ok { + return router.ErrRouteNotFound + } + + if _, ok := t.routes[service][sum]; !ok { + return router.ErrRouteNotFound + } + + delete(t.routes[service], sum) + if len(t.routes[service]) == 0 { + delete(t.routes, service) + } + if logger.V(logger.DebugLevel, logger.DefaultLogger) { + logger.Debugf("Router emitting %s for route: %s", router.Delete, r.Address) + } + go t.sendEvent(&router.Event{Type: router.Delete, Timestamp: time.Now(), Route: r}) + + return nil +} + +// Update updates routing table with the new route +func (t *table) Update(r router.Route) error { + service := r.Service + sum := r.Hash() + + t.Lock() + defer t.Unlock() + + // check if the route destination has any routes in the table + if _, ok := t.routes[service]; !ok { + t.routes[service] = make(map[uint64]router.Route) + } + + if _, ok := t.routes[service][sum]; !ok { + t.routes[service][sum] = r + if logger.V(logger.DebugLevel, logger.DefaultLogger) { + logger.Debugf("Router emitting %s for route: %s", router.Update, r.Address) + } + go t.sendEvent(&router.Event{Type: router.Update, Timestamp: time.Now(), Route: r}) + return nil + } + + // just update the route, but dont emit Update event + t.routes[service][sum] = r + + return nil +} + +// List returns a list of all routes in the table +func (t *table) List() ([]router.Route, error) { + t.RLock() + defer t.RUnlock() + + var routes []router.Route + for _, rmap := range t.routes { + for _, route := range rmap { + routes = append(routes, route) + } + } + + return routes, nil +} + +// isMatch checks if the route matches given query options +func isMatch(route router.Route, address, gateway, network, rtr string, strategy router.Strategy) bool { + // matches the values provided + match := func(a, b string) bool { + if a == "*" || b == "*" || a == b { + return true + } + return false + } + + // a simple struct to hold our values + type compare struct { + a string + b string + } + + // by default assume we are querying all routes + link := "*" + // if AdvertiseLocal change the link query accordingly + if strategy == router.AdvertiseLocal { + link = "local" + } + + // compare the following values + values := []compare{ + {gateway, route.Gateway}, + {network, route.Network}, + {rtr, route.Router}, + {address, route.Address}, + {link, route.Link}, + } + + for _, v := range values { + // attempt to match each value + if !match(v.a, v.b) { + return false + } + } + + return true +} + +// findRoutes finds all the routes for given network and router and returns them +func findRoutes(routes map[uint64]router.Route, address, gateway, network, rtr string, strategy router.Strategy) []router.Route { + // routeMap stores the routes we're going to advertise + routeMap := make(map[string][]router.Route) + + for _, route := range routes { + if isMatch(route, address, gateway, network, rtr, strategy) { + // add matchihg route to the routeMap + routeKey := route.Service + "@" + route.Network + // append the first found route to routeMap + _, ok := routeMap[routeKey] + if !ok { + routeMap[routeKey] = append(routeMap[routeKey], route) + continue + } + + // if AdvertiseAll, keep appending + if strategy == router.AdvertiseAll || strategy == router.AdvertiseLocal { + routeMap[routeKey] = append(routeMap[routeKey], route) + continue + } + + // now we're going to find the best routes + if strategy == router.AdvertiseBest { + // if the current optimal route metric is higher than routing table route, replace it + if len(routeMap[routeKey]) > 0 { + // NOTE: we know that when AdvertiseBest is set, we only ever have one item in current + if routeMap[routeKey][0].Metric > route.Metric { + routeMap[routeKey][0] = route + continue + } + } + } + } + } + + var results []router.Route + for _, route := range routeMap { + results = append(results, route...) + } + + return results +} + +// Lookup queries routing table and returns all routes that match the lookup query +func (t *table) Query(q ...router.QueryOption) ([]router.Route, error) { + // create new query options + opts := router.NewQuery(q...) + + // create a cwslicelist of query results + results := make([]router.Route, 0, len(t.routes)) + + // if No routes are queried, return early + if opts.Strategy == router.AdvertiseNone { + return results, nil + } + + // readAndFilter routes for this service under read lock. + readAndFilter := func() ([]router.Route, bool) { + t.RLock() + defer t.RUnlock() + + routes, ok := t.routes[opts.Service] + if !ok || len(routes) == 0 { + return nil, false + } + + return findRoutes(routes, opts.Address, opts.Gateway, opts.Network, opts.Router, opts.Strategy), true + } + + if opts.Service != "*" { + // try and load services from the cache + if routes, ok := readAndFilter(); ok { + return routes, nil + } + + // load the cache and try again + if err := t.fetchRoutes(opts.Service); err != nil { + return nil, err + } + + // try again + if routes, ok := readAndFilter(); ok { + return routes, nil + } + + return nil, router.ErrRouteNotFound + } + + // search through all destinations + t.RLock() + for _, routes := range t.routes { + results = append(results, findRoutes(routes, opts.Address, opts.Gateway, opts.Network, opts.Router, opts.Strategy)...) + } + t.RUnlock() + + return results, nil +} + +// Watch returns routing table entry watcher +func (t *table) Watch(opts ...router.WatchOption) (router.Watcher, error) { + // by default watch everything + wopts := router.WatchOptions{ + Service: "*", + } + + for _, o := range opts { + o(&wopts) + } + + w := &tableWatcher{ + id: uuid.New().String(), + opts: wopts, + resChan: make(chan *router.Event, 10), + done: make(chan struct{}), + } + + // when the watcher is stopped delete it + go func() { + <-w.done + t.Lock() + delete(t.watchers, w.id) + t.Unlock() + }() + + // save the watcher + t.Lock() + t.watchers[w.id] = w + t.Unlock() + + return w, nil +} diff --git a/table_test.go b/table_test.go new file mode 100644 index 0000000..5e2590f --- /dev/null +++ b/table_test.go @@ -0,0 +1,350 @@ +package registry + +import ( + "fmt" + "testing" + + "github.com/micro/go-micro/v3/router" +) + +func testSetup() (*table, router.Route) { + routr := NewRouter().(*rtr) + table := newTable(routr.fetchRoutes) + + route := router.Route{ + Service: "dest.svc", + Address: "dest.addr", + Gateway: "dest.gw", + Network: "dest.network", + Router: "src.router", + Link: "det.link", + Metric: 10, + } + + return table, route +} + +func TestCreate(t *testing.T) { + table, route := testSetup() + + if err := table.Create(route); err != nil { + t.Errorf("error adding route: %s", err) + } + + // adds new route for the original destination + route.Gateway = "dest.gw2" + + if err := table.Create(route); err != nil { + t.Errorf("error adding route: %s", err) + } + + // adding the same route under Insert policy must error + if err := table.Create(route); err != router.ErrDuplicateRoute { + t.Errorf("error adding route. Expected error: %s, found: %s", router.ErrDuplicateRoute, err) + } +} + +func TestDelete(t *testing.T) { + table, route := testSetup() + + if err := table.Create(route); err != nil { + t.Errorf("error adding route: %s", err) + } + + // should fail to delete non-existant route + prevSvc := route.Service + route.Service = "randDest" + + if err := table.Delete(route); err != router.ErrRouteNotFound { + t.Errorf("error deleting route. Expected: %s, found: %s", router.ErrRouteNotFound, err) + } + + // we should be able to delete the existing route + route.Service = prevSvc + + if err := table.Delete(route); err != nil { + t.Errorf("error deleting route: %s", err) + } +} + +func TestUpdate(t *testing.T) { + table, route := testSetup() + + if err := table.Create(route); err != nil { + t.Errorf("error adding route: %s", err) + } + + // change the metric of the original route + route.Metric = 200 + + if err := table.Update(route); err != nil { + t.Errorf("error updating route: %s", err) + } + + // this should add a new route + route.Service = "rand.dest" + + if err := table.Update(route); err != nil { + t.Errorf("error updating route: %s", err) + } +} + +func TestList(t *testing.T) { + table, route := testSetup() + + svc := []string{"one.svc", "two.svc", "three.svc"} + + for i := 0; i < len(svc); i++ { + route.Service = svc[i] + if err := table.Create(route); err != nil { + t.Errorf("error adding route: %s", err) + } + } + + routes, err := table.List() + if err != nil { + t.Errorf("error listing routes: %s", err) + } + + if len(routes) != len(svc) { + t.Errorf("incorrect number of routes listed. Expected: %d, found: %d", len(svc), len(routes)) + } +} + +func TestQuery(t *testing.T) { + table, route := testSetup() + + svc := []string{"svc1", "svc2", "svc3", "svc1"} + net := []string{"net1", "net2", "net1", "net3"} + gw := []string{"gw1", "gw2", "gw3", "gw3"} + rtr := []string{"rtr1", "rt2", "rt3", "rtr3"} + + for i := 0; i < len(svc); i++ { + route.Service = svc[i] + route.Network = net[i] + route.Gateway = gw[i] + route.Router = rtr[i] + if err := table.Create(route); err != nil { + t.Errorf("error adding route: %s", err) + } + } + + // return all routes + routes, err := table.Query() + if err != nil { + t.Errorf("error looking up routes: %s", err) + } else if len(routes) == 0 { + t.Errorf("error looking up routes: not found") + } + + // query routes particular network + network := "net1" + + routes, err = table.Query(router.QueryNetwork(network)) + if err != nil { + t.Errorf("error looking up routes: %s", err) + } + + if len(routes) != 2 { + t.Errorf("incorrect number of routes returned. Expected: %d, found: %d", 2, len(routes)) + } + + for _, route := range routes { + if route.Network != network { + t.Errorf("incorrect route returned. Expected network: %s, found: %s", network, route.Network) + } + } + + // query routes for particular gateway + gateway := "gw1" + + routes, err = table.Query(router.QueryGateway(gateway)) + if err != nil { + t.Errorf("error looking up routes: %s", err) + } + + if len(routes) != 1 { + t.Errorf("incorrect number of routes returned. Expected: %d, found: %d", 1, len(routes)) + } + + if routes[0].Gateway != gateway { + t.Errorf("incorrect route returned. Expected gateway: %s, found: %s", gateway, routes[0].Gateway) + } + + // query routes for particular router + rt := "rtr1" + + routes, err = table.Query(router.QueryRouter(rt)) + if err != nil { + t.Errorf("error looking up routes: %s", err) + } + + if len(routes) != 1 { + t.Errorf("incorrect number of routes returned. Expected: %d, found: %d", 1, len(routes)) + } + + if routes[0].Router != rt { + t.Errorf("incorrect route returned. Expected router: %s, found: %s", rt, routes[0].Router) + } + + // query particular gateway and network + query := []router.QueryOption{ + router.QueryGateway(gateway), + router.QueryNetwork(network), + router.QueryRouter(rt), + } + + routes, err = table.Query(query...) + if err != nil { + t.Errorf("error looking up routes: %s", err) + } + + if len(routes) != 1 { + t.Errorf("incorrect number of routes returned. Expected: %d, found: %d", 1, len(routes)) + } + + if routes[0].Gateway != gateway { + t.Errorf("incorrect route returned. Expected gateway: %s, found: %s", gateway, routes[0].Gateway) + } + + if routes[0].Network != network { + t.Errorf("incorrect network returned. Expected network: %s, found: %s", network, routes[0].Network) + } + + if routes[0].Router != rt { + t.Errorf("incorrect route returned. Expected router: %s, found: %s", rt, routes[0].Router) + } + + // non-existen route query + routes, err = table.Query(router.QueryService("foobar")) + if err != router.ErrRouteNotFound { + t.Errorf("error looking up routes. Expected: %s, found: %s", router.ErrRouteNotFound, err) + } + + if len(routes) != 0 { + t.Errorf("incorrect number of routes returned. Expected: %d, found: %d", 0, len(routes)) + } + + // query NO routes + query = []router.QueryOption{ + router.QueryGateway(gateway), + router.QueryNetwork(network), + router.QueryStrategy(router.AdvertiseNone), + } + + routes, err = table.Query(query...) + if err != nil { + t.Errorf("error looking up routes: %s", err) + } + + if len(routes) > 0 { + t.Errorf("incorrect number of routes returned. Expected: %d, found: %d", 0, len(routes)) + } + + // insert local routes to query + for i := 0; i < 2; i++ { + route.Link = "local" + route.Address = fmt.Sprintf("local.route.address-%d", i) + if err := table.Create(route); err != nil { + t.Errorf("error adding route: %s", err) + } + } + + // query local routes + query = []router.QueryOption{ + router.QueryGateway("*"), + router.QueryNetwork("*"), + router.QueryStrategy(router.AdvertiseLocal), + } + + routes, err = table.Query(query...) + if err != nil { + t.Errorf("error looking up routes: %s", err) + } + + if len(routes) != 2 { + t.Errorf("incorrect number of routes returned. Expected: %d, found: %d", 2, len(routes)) + } + + // add two different routes for svcX with different metric + for i := 0; i < 2; i++ { + route.Service = "svcX" + route.Address = fmt.Sprintf("svcX.route.address-%d", i) + route.Metric = int64(100 + i) + if err := table.Create(route); err != nil { + t.Errorf("error adding route: %s", err) + } + } + + // query best routes for svcX + query = []router.QueryOption{ + router.QueryService("svcX"), + router.QueryStrategy(router.AdvertiseBest), + } + + routes, err = table.Query(query...) + if err != nil { + t.Errorf("error looking up routes: %s", err) + } + + if len(routes) != 1 { + t.Errorf("incorrect number of routes returned. Expected: %d, found: %d", 1, len(routes)) + } +} + +func TestFallback(t *testing.T) { + + r := &rtr{ + subscribers: make(map[string]chan *router.Advert), + options: router.DefaultOptions(), + } + route := router.Route{ + Service: "go.micro.service.foo", + Router: r.options.Id, + Link: router.DefaultLink, + Metric: router.DefaultLocalMetric, + } + r.table = newTable(func(s string) error { + r.table.Create(route) + return nil + }) + r.start() + + rts, err := r.Lookup(router.QueryService("go.micro.service.foo")) + if err != nil { + t.Errorf("error looking up service %s", err) + } + if len(rts) != 1 { + t.Errorf("incorrect number of routes returned %d", len(rts)) + } + + // deleting from the table but the next query should invoke the fallback that we passed during new table creation + if err := r.table.Delete(route); err != nil { + t.Errorf("error deleting route %s", err) + } + + rts, err = r.Lookup(router.QueryService("go.micro.service.foo")) + if err != nil { + t.Errorf("error looking up service %s", err) + } + if len(rts) != 1 { + t.Errorf("incorrect number of routes returned %d", len(rts)) + } + +} + +func TestFallbackError(t *testing.T) { + r := &rtr{ + subscribers: make(map[string]chan *router.Advert), + options: router.DefaultOptions(), + } + r.table = newTable(func(s string) error { + return fmt.Errorf("ERROR") + }) + r.start() + _, err := r.Lookup(router.QueryService("go.micro.service.foo")) + if err == nil { + t.Errorf("expected error looking up service but none returned") + } + +} diff --git a/watcher.go b/watcher.go new file mode 100644 index 0000000..edc94ea --- /dev/null +++ b/watcher.go @@ -0,0 +1,52 @@ +package registry + +import ( + "sync" + + "github.com/micro/go-micro/v3/router" +) + +// tableWatcher implements routing table Watcher +type tableWatcher struct { + sync.RWMutex + id string + opts router.WatchOptions + resChan chan *router.Event + done chan struct{} +} + +// Next returns the next noticed action taken on table +// TODO: right now we only allow to watch particular service +func (w *tableWatcher) Next() (*router.Event, error) { + for { + select { + case res := <-w.resChan: + switch w.opts.Service { + case res.Route.Service, "*": + return res, nil + default: + continue + } + case <-w.done: + return nil, router.ErrWatcherStopped + } + } +} + +// Chan returns watcher events channel +func (w *tableWatcher) Chan() (<-chan *router.Event, error) { + return w.resChan, nil +} + +// Stop stops routing table watcher +func (w *tableWatcher) Stop() { + w.Lock() + defer w.Unlock() + + select { + case <-w.done: + return + default: + close(w.done) + } +}