update mdns to remove race condition

This commit is contained in:
Asim Aslam 2019-02-01 13:41:11 +00:00
parent 652b1067f5
commit 88e12347d0

View File

@ -2,6 +2,7 @@
package registry package registry
import ( import (
"context"
"net" "net"
"strings" "strings"
"sync" "sync"
@ -194,20 +195,20 @@ func (m *mdnsRegistry) Deregister(service *Service) error {
} }
func (m *mdnsRegistry) GetService(service string) ([]*Service, error) { func (m *mdnsRegistry) GetService(service string) ([]*Service, error) {
p := mdns.DefaultParams(service)
p.Timeout = m.opts.Timeout
entryCh := make(chan *mdns.ServiceEntry, 10)
p.Entries = entryCh
exit := make(chan bool)
defer close(exit)
serviceMap := make(map[string]*Service) serviceMap := make(map[string]*Service)
entries := make(chan *mdns.ServiceEntry, 10)
done := make(chan bool)
p := mdns.DefaultParams(service)
// set context with timeout
p.Context, _ = context.WithTimeout(context.Background(), m.opts.Timeout)
// set entries channel
p.Entries = entries
go func() { go func() {
for { for {
select { select {
case e := <-entryCh: case e := <-entries:
// list record so skip // list record so skip
if p.Service == "_services" { if p.Service == "_services" {
continue continue
@ -243,16 +244,21 @@ func (m *mdnsRegistry) GetService(service string) ([]*Service, error) {
}) })
serviceMap[txt.Version] = s serviceMap[txt.Version] = s
case <-exit: case <-p.Context.Done():
close(done)
return return
} }
} }
}() }()
// execute the query
if err := mdns.Query(p); err != nil { if err := mdns.Query(p); err != nil {
return nil, err return nil, err
} }
// wait for completion
<-done
// create list and return // create list and return
var services []*Service var services []*Service
@ -264,21 +270,22 @@ func (m *mdnsRegistry) GetService(service string) ([]*Service, error) {
} }
func (m *mdnsRegistry) ListServices() ([]*Service, error) { func (m *mdnsRegistry) ListServices() ([]*Service, error) {
p := mdns.DefaultParams("_services")
p.Timeout = m.opts.Timeout
entryCh := make(chan *mdns.ServiceEntry, 10)
p.Entries = entryCh
exit := make(chan bool)
defer close(exit)
serviceMap := make(map[string]bool) serviceMap := make(map[string]bool)
entries := make(chan *mdns.ServiceEntry, 10)
done := make(chan bool)
p := mdns.DefaultParams("_services")
// set context with timeout
p.Context, _ = context.WithTimeout(context.Background(), m.opts.Timeout)
// set entries channel
p.Entries = entries
var services []*Service var services []*Service
go func() { go func() {
for { for {
select { select {
case e := <-entryCh: case e := <-entries:
if e.TTL == 0 { if e.TTL == 0 {
continue continue
} }
@ -288,16 +295,21 @@ func (m *mdnsRegistry) ListServices() ([]*Service, error) {
serviceMap[name] = true serviceMap[name] = true
services = append(services, &Service{Name: name}) services = append(services, &Service{Name: name})
} }
case <-exit: case <-p.Context.Done():
close(done)
return return
} }
} }
}() }()
// execute query
if err := mdns.Query(p); err != nil { if err := mdns.Query(p); err != nil {
return nil, err return nil, err
} }
// wait till done
<-done
return services, nil return services, nil
} }