Stop watcher when router stops. Drain advert channel when stopping.

This commit is contained in:
Milos Gajdos 2019-07-27 13:08:54 +01:00
parent 002abca61f
commit d8b00e801d
No known key found for this signature in database
GPG Key ID: 8B31058CC55DFD4F
2 changed files with 38 additions and 33 deletions

View File

@ -445,7 +445,6 @@ func (r *router) watchErrors() {
} }
// Run runs the router. // Run runs the router.
// It returns error if the router is already running.
func (r *router) run() { func (r *router) run() {
r.Lock() r.Lock()
defer r.Unlock() defer r.Unlock()

View File

@ -2,7 +2,6 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"sync" "sync"
@ -14,11 +13,6 @@ import (
pb "github.com/micro/go-micro/network/router/proto" pb "github.com/micro/go-micro/network/router/proto"
) )
var (
// ErrNotImplemented means the functionality has not been implemented
ErrNotImplemented = errors.New("not implemented")
)
type svc struct { type svc struct {
opts router.Options opts router.Options
router pb.RouterService router pb.RouterService
@ -71,31 +65,10 @@ func (s *svc) Options() router.Options {
return s.opts return s.opts
} }
// watchErrors watches router errors and takes appropriate actions
func (s *svc) watchErrors() {
var err error
select {
case <-s.exit:
case err = <-s.errChan:
}
s.Lock()
defer s.Unlock()
if s.status.Code != router.Stopped {
// notify all goroutines to finish
close(s.exit)
// TODO" might need to drain some channels here
}
if err != nil {
s.status = router.Status{Code: router.Error, Error: err}
}
}
// watchRouter watches router and send events to all registered watchers // watchRouter watches router and send events to all registered watchers
func (s *svc) watchRouter(stream pb.Router_WatchService) error { func (s *svc) watchRouter(stream pb.Router_WatchService) error {
defer stream.Close() defer stream.Close()
var watchErr error var watchErr error
for { for {
@ -122,6 +95,7 @@ func (s *svc) watchRouter(stream pb.Router_WatchService) error {
Route: route, Route: route,
} }
// TODO: might make this non-blocking
s.RLock() s.RLock()
for _, w := range s.watchers { for _, w := range s.watchers {
select { select {
@ -135,8 +109,31 @@ func (s *svc) watchRouter(stream pb.Router_WatchService) error {
return watchErr return watchErr
} }
// watchErrors watches router errors and takes appropriate actions
func (s *svc) watchErrors() {
var err error
select {
case <-s.exit:
case err = <-s.errChan:
}
s.Lock()
defer s.Unlock()
if s.status.Code != router.Stopped {
// notify all goroutines to finish
close(s.exit)
// drain the advertise channel
for range s.advertChan {
}
}
if err != nil {
s.status = router.Status{Code: router.Error, Error: err}
}
}
// Run runs the router. // Run runs the router.
// It returns error if the router is already running.
func (s *svc) run() { func (s *svc) run() {
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
@ -145,7 +142,7 @@ func (s *svc) run() {
case router.Stopped, router.Error: case router.Stopped, router.Error:
stream, err := s.router.Watch(context.Background(), &pb.WatchRequest{}) stream, err := s.router.Watch(context.Background(), &pb.WatchRequest{})
if err != nil { if err != nil {
s.status = router.Status{Code: router.Error, Error: fmt.Errorf("failed getting router stream: %s", err)} s.status = router.Status{Code: router.Error, Error: fmt.Errorf("failed getting event stream: %s", err)}
return return
} }
@ -425,6 +422,14 @@ func (s *svc) Watch(opts ...router.WatchOption) (router.Watcher, error) {
s.watchers[uuid.New().String()] = w s.watchers[uuid.New().String()] = w
s.Unlock() s.Unlock()
// when the router stops, stop the watcher and exit
s.wg.Add(1)
go func() {
defer s.wg.Done()
<-s.exit
w.Stop()
}()
return w, nil return w, nil
} }
@ -446,8 +451,9 @@ func (s *svc) Stop() error {
if s.status.Code == router.Running || s.status.Code == router.Advertising { if s.status.Code == router.Running || s.status.Code == router.Advertising {
// notify all goroutines to finish // notify all goroutines to finish
close(s.exit) close(s.exit)
// TODO: might need to drain some channels here // drain the advertise channel
for range s.advertChan {
}
// mark the router as Stopped and set its Error to nil // mark the router as Stopped and set its Error to nil
s.status = router.Status{Code: router.Stopped, Error: nil} s.status = router.Status{Code: router.Stopped, Error: nil}
} }