From 88e12347d0a08167b0160dfbb52b326e571a2511 Mon Sep 17 00:00:00 2001 From: Asim Aslam Date: Fri, 1 Feb 2019 13:41:11 +0000 Subject: [PATCH] update mdns to remove race condition --- registry/mdns_registry.go | 52 ++++++++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 20 deletions(-) diff --git a/registry/mdns_registry.go b/registry/mdns_registry.go index 70f42736..4a3d09ec 100644 --- a/registry/mdns_registry.go +++ b/registry/mdns_registry.go @@ -2,6 +2,7 @@ package registry import ( + "context" "net" "strings" "sync" @@ -194,20 +195,20 @@ func (m *mdnsRegistry) Deregister(service *Service) error { } func (m *mdnsRegistry) GetService(service string) ([]*Service, error) { - p := mdns.DefaultParams(service) - p.Timeout = m.opts.Timeout - entryCh := make(chan *mdns.ServiceEntry, 10) - p.Entries = entryCh - - exit := make(chan bool) - defer close(exit) - serviceMap := make(map[string]*Service) + entries := make(chan *mdns.ServiceEntry, 10) + done := make(chan bool) + + p := mdns.DefaultParams(service) + // set context with timeout + p.Context, _ = context.WithTimeout(context.Background(), m.opts.Timeout) + // set entries channel + p.Entries = entries go func() { for { select { - case e := <-entryCh: + case e := <-entries: // list record so skip if p.Service == "_services" { continue @@ -243,16 +244,21 @@ func (m *mdnsRegistry) GetService(service string) ([]*Service, error) { }) serviceMap[txt.Version] = s - case <-exit: + case <-p.Context.Done(): + close(done) return } } }() + // execute the query if err := mdns.Query(p); err != nil { return nil, err } + // wait for completion + <-done + // create list and return var services []*Service @@ -264,21 +270,22 @@ func (m *mdnsRegistry) GetService(service string) ([]*Service, error) { } func (m *mdnsRegistry) ListServices() ([]*Service, error) { - p := mdns.DefaultParams("_services") - p.Timeout = m.opts.Timeout - entryCh := make(chan *mdns.ServiceEntry, 10) - p.Entries = entryCh - - exit := make(chan bool) - defer close(exit) - serviceMap := make(map[string]bool) + entries := make(chan *mdns.ServiceEntry, 10) + done := make(chan bool) + + p := mdns.DefaultParams("_services") + // set context with timeout + p.Context, _ = context.WithTimeout(context.Background(), m.opts.Timeout) + // set entries channel + p.Entries = entries + var services []*Service go func() { for { select { - case e := <-entryCh: + case e := <-entries: if e.TTL == 0 { continue } @@ -288,16 +295,21 @@ func (m *mdnsRegistry) ListServices() ([]*Service, error) { serviceMap[name] = true services = append(services, &Service{Name: name}) } - case <-exit: + case <-p.Context.Done(): + close(done) return } } }() + // execute query if err := mdns.Query(p); err != nil { return nil, err } + // wait till done + <-done + return services, nil }