From cb1679fd8d25741ee75e206555290093fec765a5 Mon Sep 17 00:00:00 2001 From: Milos Gajdos Date: Mon, 12 Aug 2019 18:18:17 +0100 Subject: [PATCH] Add Start method to router Added Start to router packages. Fixed potential deadlocks. --- router/default.go | 145 ++++++++++++++++++++------------------ router/router.go | 11 ++- router/service/service.go | 46 ++++++++---- 3 files changed, 117 insertions(+), 85 deletions(-) diff --git a/router/default.go b/router/default.go index 498539a1..1dec7a51 100644 --- a/router/default.go +++ b/router/default.go @@ -43,10 +43,9 @@ var ( // router implements default router type router struct { sync.RWMutex - // embed the table - table *table opts Options status Status + table *table exit chan struct{} errChan chan error eventChan chan *Event @@ -67,33 +66,41 @@ func newRouter(opts ...Option) Router { o(&options) } - r := &router{ - table: newTable(), + // set initial status to Stopped + status := Status{Code: Stopped, Error: nil} + + return &router{ opts: options, - status: Status{Code: Stopped, Error: nil}, + status: status, + table: newTable(), advertWg: &sync.WaitGroup{}, wg: &sync.WaitGroup{}, subscribers: make(map[string]chan *Advert), } - - go r.run() - - return r } // Init initializes router with given options func (r *router) Init(opts ...Option) error { + r.Lock() + defer r.Unlock() + for _, o := range opts { o(&r.opts) } + return nil } // Options returns router options func (r *router) Options() Options { - return r.opts + r.Lock() + opts := r.opts + r.Unlock() + + return opts } +// Table returns routing table func (r *router) Table() Table { return r.table } @@ -475,11 +482,12 @@ func (r *router) watchErrors() { r.Lock() defer r.Unlock() + // if the router is not stopped, stop it if r.status.Code != Stopped { // notify all goroutines to finish close(r.exit) - // drain the advertise channel only if advertising + // drain the advertise channel only if the router is advertising if r.status.Code == Advertising { // drain the event channel for range r.eventChan { @@ -495,69 +503,67 @@ func (r *router) watchErrors() { } } -// Run runs the router. -func (r *router) run() { +// Start starts the router +func (r *router) Start() error { r.Lock() defer r.Unlock() - switch r.status.Code { - case Stopped, Error: - // add all local service routes into the routing table - if err := r.manageRegistryRoutes(r.opts.Registry, "create"); err != nil { - r.status = Status{Code: Error, Error: fmt.Errorf("failed adding registry routes: %s", err)} - return - } - - // add default gateway into routing table - if r.opts.Gateway != "" { - // note, the only non-default value is the gateway - route := Route{ - Service: "*", - Address: "*", - Gateway: r.opts.Gateway, - Network: "*", - Metric: DefaultLocalMetric, - } - if err := r.table.Create(route); err != nil { - r.status = Status{Code: Error, Error: fmt.Errorf("failed adding default gateway route: %s", err)} - return - } - } - - // create error and exit channels - r.errChan = make(chan error, 1) - r.exit = make(chan struct{}) - - // registry watcher - regWatcher, err := r.opts.Registry.Watch() - if err != nil { - r.status = Status{Code: Error, Error: fmt.Errorf("failed creating registry watcher: %v", err)} - return - } - - r.wg.Add(1) - go func() { - defer r.wg.Done() - select { - case r.errChan <- r.watchRegistry(regWatcher): - case <-r.exit: - } - }() - - // watch for errors and cleanup - r.wg.Add(1) - go func() { - defer r.wg.Done() - r.watchErrors() - }() - - // mark router as Running and set its Error to nil - r.status = Status{Code: Running, Error: nil} - - return + // add all local service routes into the routing table + if err := r.manageRegistryRoutes(r.opts.Registry, "create"); err != nil { + e := fmt.Errorf("failed adding registry routes: %s", err) + r.status = Status{Code: Error, Error: e} + return e } - return + // add default gateway into routing table + if r.opts.Gateway != "" { + // note, the only non-default value is the gateway + route := Route{ + Service: "*", + Address: "*", + Gateway: r.opts.Gateway, + Network: "*", + Metric: DefaultLocalMetric, + } + if err := r.table.Create(route); err != nil { + e := fmt.Errorf("failed adding default gateway route: %s", err) + r.status = Status{Code: Error, Error: e} + return e + } + } + + // create error and exit channels + r.errChan = make(chan error, 1) + r.exit = make(chan struct{}) + + // registry watcher + regWatcher, err := r.opts.Registry.Watch() + if err != nil { + e := fmt.Errorf("failed creating registry watcher: %v", err) + r.status = Status{Code: Error, Error: e} + return e + } + + r.wg.Add(1) + go func() { + defer r.wg.Done() + select { + case r.errChan <- r.watchRegistry(regWatcher): + case <-r.exit: + } + }() + + // watch for errors and cleanup + r.wg.Add(1) + go func() { + defer r.wg.Done() + r.watchErrors() + }() + + // mark router as Running + r.status = Status{Code: Running, Error: nil} + + return nil } // Advertise stars advertising the routes to the network and returns the advertisements channel to consume from. @@ -578,6 +584,7 @@ func (r *router) Advertise() (<-chan *Advert, error) { if err != nil { return nil, fmt.Errorf("failed listing routes: %s", err) } + // collect all the added routes before we attempt to add default gateway events := make([]*Event, len(routes)) for i, route := range routes { diff --git a/router/router.go b/router/router.go index af6f8bc3..7d7ab0de 100644 --- a/router/router.go +++ b/router/router.go @@ -21,6 +21,8 @@ type Router interface { Lookup(Query) ([]Route, error) // Watch returns a watcher which tracks updates to the routing table Watch(opts ...WatchOption) (Watcher, error) + // Start starts the router + Start() error // Status returns router status Status() Status // Stop stops the router @@ -76,10 +78,15 @@ func (s StatusCode) String() string { // Status is router status type Status struct { - // Error is router error - Error error // Code defines router status Code StatusCode + // Error contains error description + Error error +} + +// String returns human readable status +func (s Status) String() string { + return s.Code.String() } // AdvertType is route advertisement type diff --git a/router/service/service.go b/router/service/service.go index b06de9c8..77619762 100644 --- a/router/service/service.go +++ b/router/service/service.go @@ -43,9 +43,16 @@ func NewRouter(opts ...router.Option) router.Router { cli = options.Client } + // set the status to Stopped + status := &router.Status{ + Code: router.Stopped, + Error: nil, + } + // NOTE: should we have Client/Service option in router.Options? s := &svc{ opts: options, + status: status, router: pb.NewRouterService(router.DefaultName, cli), } @@ -63,21 +70,43 @@ func NewRouter(opts ...router.Option) router.Router { // Init initializes router with given options func (s *svc) Init(opts ...router.Option) error { + s.Lock() + defer s.Unlock() + for _, o := range opts { o(&s.opts) } + return nil } // Options returns router options func (s *svc) Options() router.Options { - return s.opts + s.Lock() + opts := s.opts + s.Unlock() + + return opts } +// Table returns routing table func (s *svc) Table() router.Table { return s.table } +// Start starts the service +func (s *svc) Start() error { + s.Lock() + defer s.Unlock() + + s.status = &router.Status{ + Code: router.Running, + Error: nil, + } + + return nil +} + func (s *svc) advertiseEvents(advertChan chan *router.Advert, stream pb.Router_AdvertiseService) error { go func() { <-s.exit @@ -140,10 +169,7 @@ func (s *svc) Advertise() (<-chan *router.Advert, error) { s.Lock() defer s.Unlock() - // get the status - status := s.Status() - - switch status.Code { + switch s.status.Code { case router.Running, router.Advertising: stream, err := s.router.Advertise(context.Background(), &pb.Request{}, s.callOpts...) if err != nil { @@ -154,15 +180,7 @@ func (s *svc) Advertise() (<-chan *router.Advert, error) { go s.advertiseEvents(advertChan, stream) return advertChan, nil case router.Stopped: - // check if our router is stopped - select { - case <-s.exit: - s.exit = make(chan bool) - // call advertise again - return s.Advertise() - default: - return nil, fmt.Errorf("not running") - } + return nil, fmt.Errorf("not running") } return nil, fmt.Errorf("error: %s", s.status.Error)