diff --git a/client.go b/client.go index ab1c80d..697709f 100644 --- a/client.go +++ b/client.go @@ -150,6 +150,7 @@ func newClient() (*client, error) { // Close is used to cleanup the client func (c *client) Close() error { + log.Printf("[INFO] mdns: Closing client %v", *c) c.closeLock.Lock() defer c.closeLock.Unlock() @@ -158,7 +159,6 @@ func (c *client) Close() error { } c.closed = true - log.Printf("[INFO] mdns: Closing client %v", *c) close(c.closedCh) if c.ipv4UnicastConn != nil { @@ -243,6 +243,9 @@ func (c *client) query(params *QueryParam) error { case *dns.PTR: // Create new entry for this inp = ensureName(inprogress, rr.Ptr) + if inp.complete() { + continue + } case *dns.SRV: // Check for a target mismatch @@ -252,12 +255,18 @@ func (c *client) query(params *QueryParam) error { // Get the port inp = ensureName(inprogress, rr.Hdr.Name) + if inp.complete() { + continue + } inp.Host = rr.Target inp.Port = int(rr.Port) case *dns.TXT: // Pull out the txt inp = ensureName(inprogress, rr.Hdr.Name) + if inp.complete() { + continue + } inp.Info = strings.Join(rr.Txt, "|") inp.InfoFields = rr.Txt inp.hasTXT = true @@ -265,12 +274,18 @@ func (c *client) query(params *QueryParam) error { case *dns.A: // Pull out the IP inp = ensureName(inprogress, rr.Hdr.Name) + if inp.complete() { + continue + } inp.Addr = rr.A // @Deprecated inp.AddrV4 = rr.A case *dns.AAAA: // Pull out the IP inp = ensureName(inprogress, rr.Hdr.Name) + if inp.complete() { + continue + } inp.Addr = rr.AAAA // @Deprecated inp.AddrV6 = rr.AAAA } @@ -288,7 +303,6 @@ func (c *client) query(params *QueryParam) error { inp.sent = true select { case params.Entries <- inp: - default: } } else { // Fire off a node specific query @@ -326,7 +340,13 @@ func (c *client) recv(l *net.UDPConn, msgCh chan *dns.Msg) { return } buf := make([]byte, 65536) - for !c.closed { + for { + c.closeLock.Lock() + if c.closed { + c.closeLock.Unlock() + return + } + c.closeLock.Unlock() n, err := l.Read(buf) if err != nil { log.Printf("[ERR] mdns: Failed to read packet: %v", err) diff --git a/server.go b/server.go index 55988c4..18f4d4d 100644 --- a/server.go +++ b/server.go @@ -107,7 +107,13 @@ func (s *Server) recv(c *net.UDPConn) { return } buf := make([]byte, 65536) - for !s.shutdown { + for { + s.shutdownLock.Lock() + if s.shutdown { + s.shutdownLock.Unlock() + return + } + s.shutdownLock.Unlock() n, from, err := c.ReadFrom(buf) if err != nil { continue diff --git a/server_test.go b/server_test.go index 3f20bc6..6fb00fa 100644 --- a/server_test.go +++ b/server_test.go @@ -23,6 +23,7 @@ func TestServer_Lookup(t *testing.T) { entries := make(chan *ServiceEntry, 1) found := false + doneCh := make(chan struct{}) go func() { select { case e := <-entries: @@ -40,6 +41,7 @@ func TestServer_Lookup(t *testing.T) { case <-time.After(80 * time.Millisecond): t.Fatalf("timeout") } + close(doneCh) }() params := &QueryParam{ @@ -52,6 +54,7 @@ func TestServer_Lookup(t *testing.T) { if err != nil { t.Fatalf("err: %v", err) } + <-doneCh if !found { t.Fatalf("record not found") }