This commit is contained in:
Asim 2016-04-29 23:34:13 +01:00
parent 9d85cf22f9
commit b8712a0af4
3 changed files with 33 additions and 4 deletions

View File

@ -150,6 +150,7 @@ func newClient() (*client, error) {
// Close is used to cleanup the client
func (c *client) Close() error {
log.Printf("[INFO] mdns: Closing client %v", *c)
c.closeLock.Lock()
defer c.closeLock.Unlock()
@ -158,7 +159,6 @@ func (c *client) Close() error {
}
c.closed = true
log.Printf("[INFO] mdns: Closing client %v", *c)
close(c.closedCh)
if c.ipv4UnicastConn != nil {
@ -243,6 +243,9 @@ func (c *client) query(params *QueryParam) error {
case *dns.PTR:
// Create new entry for this
inp = ensureName(inprogress, rr.Ptr)
if inp.complete() {
continue
}
case *dns.SRV:
// Check for a target mismatch
@ -252,12 +255,18 @@ func (c *client) query(params *QueryParam) error {
// Get the port
inp = ensureName(inprogress, rr.Hdr.Name)
if inp.complete() {
continue
}
inp.Host = rr.Target
inp.Port = int(rr.Port)
case *dns.TXT:
// Pull out the txt
inp = ensureName(inprogress, rr.Hdr.Name)
if inp.complete() {
continue
}
inp.Info = strings.Join(rr.Txt, "|")
inp.InfoFields = rr.Txt
inp.hasTXT = true
@ -265,12 +274,18 @@ func (c *client) query(params *QueryParam) error {
case *dns.A:
// Pull out the IP
inp = ensureName(inprogress, rr.Hdr.Name)
if inp.complete() {
continue
}
inp.Addr = rr.A // @Deprecated
inp.AddrV4 = rr.A
case *dns.AAAA:
// Pull out the IP
inp = ensureName(inprogress, rr.Hdr.Name)
if inp.complete() {
continue
}
inp.Addr = rr.AAAA // @Deprecated
inp.AddrV6 = rr.AAAA
}
@ -288,7 +303,6 @@ func (c *client) query(params *QueryParam) error {
inp.sent = true
select {
case params.Entries <- inp:
default:
}
} else {
// Fire off a node specific query
@ -326,7 +340,13 @@ func (c *client) recv(l *net.UDPConn, msgCh chan *dns.Msg) {
return
}
buf := make([]byte, 65536)
for !c.closed {
for {
c.closeLock.Lock()
if c.closed {
c.closeLock.Unlock()
return
}
c.closeLock.Unlock()
n, err := l.Read(buf)
if err != nil {
log.Printf("[ERR] mdns: Failed to read packet: %v", err)

View File

@ -107,7 +107,13 @@ func (s *Server) recv(c *net.UDPConn) {
return
}
buf := make([]byte, 65536)
for !s.shutdown {
for {
s.shutdownLock.Lock()
if s.shutdown {
s.shutdownLock.Unlock()
return
}
s.shutdownLock.Unlock()
n, from, err := c.ReadFrom(buf)
if err != nil {
continue

View File

@ -23,6 +23,7 @@ func TestServer_Lookup(t *testing.T) {
entries := make(chan *ServiceEntry, 1)
found := false
doneCh := make(chan struct{})
go func() {
select {
case e := <-entries:
@ -40,6 +41,7 @@ func TestServer_Lookup(t *testing.T) {
case <-time.After(80 * time.Millisecond):
t.Fatalf("timeout")
}
close(doneCh)
}()
params := &QueryParam{
@ -52,6 +54,7 @@ func TestServer_Lookup(t *testing.T) {
if err != nil {
t.Fatalf("err: %v", err)
}
<-doneCh
if !found {
t.Fatalf("record not found")
}