update mdns to remove race condition
This commit is contained in:
		| @@ -2,6 +2,7 @@ | ||||
| package registry | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"net" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| @@ -194,20 +195,20 @@ func (m *mdnsRegistry) Deregister(service *Service) error { | ||||
| } | ||||
|  | ||||
| func (m *mdnsRegistry) GetService(service string) ([]*Service, error) { | ||||
| 	p := mdns.DefaultParams(service) | ||||
| 	p.Timeout = m.opts.Timeout | ||||
| 	entryCh := make(chan *mdns.ServiceEntry, 10) | ||||
| 	p.Entries = entryCh | ||||
|  | ||||
| 	exit := make(chan bool) | ||||
| 	defer close(exit) | ||||
|  | ||||
| 	serviceMap := make(map[string]*Service) | ||||
| 	entries := make(chan *mdns.ServiceEntry, 10) | ||||
| 	done := make(chan bool) | ||||
|  | ||||
| 	p := mdns.DefaultParams(service) | ||||
| 	// set context with timeout | ||||
| 	p.Context, _ = context.WithTimeout(context.Background(), m.opts.Timeout) | ||||
| 	// set entries channel | ||||
| 	p.Entries = entries | ||||
|  | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			select { | ||||
| 			case e := <-entryCh: | ||||
| 			case e := <-entries: | ||||
| 				// list record so skip | ||||
| 				if p.Service == "_services" { | ||||
| 					continue | ||||
| @@ -243,16 +244,21 @@ func (m *mdnsRegistry) GetService(service string) ([]*Service, error) { | ||||
| 				}) | ||||
|  | ||||
| 				serviceMap[txt.Version] = s | ||||
| 			case <-exit: | ||||
| 			case <-p.Context.Done(): | ||||
| 				close(done) | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	// execute the query | ||||
| 	if err := mdns.Query(p); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	// wait for completion | ||||
| 	<-done | ||||
|  | ||||
| 	// create list and return | ||||
| 	var services []*Service | ||||
|  | ||||
| @@ -264,21 +270,22 @@ func (m *mdnsRegistry) GetService(service string) ([]*Service, error) { | ||||
| } | ||||
|  | ||||
| func (m *mdnsRegistry) ListServices() ([]*Service, error) { | ||||
| 	p := mdns.DefaultParams("_services") | ||||
| 	p.Timeout = m.opts.Timeout | ||||
| 	entryCh := make(chan *mdns.ServiceEntry, 10) | ||||
| 	p.Entries = entryCh | ||||
|  | ||||
| 	exit := make(chan bool) | ||||
| 	defer close(exit) | ||||
|  | ||||
| 	serviceMap := make(map[string]bool) | ||||
| 	entries := make(chan *mdns.ServiceEntry, 10) | ||||
| 	done := make(chan bool) | ||||
|  | ||||
| 	p := mdns.DefaultParams("_services") | ||||
| 	// set context with timeout | ||||
| 	p.Context, _ = context.WithTimeout(context.Background(), m.opts.Timeout) | ||||
| 	// set entries channel | ||||
| 	p.Entries = entries | ||||
|  | ||||
| 	var services []*Service | ||||
|  | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			select { | ||||
| 			case e := <-entryCh: | ||||
| 			case e := <-entries: | ||||
| 				if e.TTL == 0 { | ||||
| 					continue | ||||
| 				} | ||||
| @@ -288,16 +295,21 @@ func (m *mdnsRegistry) ListServices() ([]*Service, error) { | ||||
| 					serviceMap[name] = true | ||||
| 					services = append(services, &Service{Name: name}) | ||||
| 				} | ||||
| 			case <-exit: | ||||
| 			case <-p.Context.Done(): | ||||
| 				close(done) | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	// execute query | ||||
| 	if err := mdns.Query(p); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	// wait till done | ||||
| 	<-done | ||||
|  | ||||
| 	return services, nil | ||||
| } | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user