diff --git a/client.go b/client.go index 83c6968..b6d2ff1 100644 --- a/client.go +++ b/client.go @@ -22,6 +22,7 @@ type ServiceEntry struct { Port int Info string InfoFields []string + TTL int Addr net.IP // @Deprecated @@ -69,7 +70,7 @@ func Query(params *QueryParam) error { // Set the multicast interface if params.Interface != nil { - if err := client.setInterface(params.Interface); err != nil { + if err := client.setInterface(params.Interface, false); err != nil { return err } } @@ -86,6 +87,62 @@ func Query(params *QueryParam) error { return client.query(params) } +// Listen listens indefinitely for multicast updates +func Listen(entries chan<- *ServiceEntry, exit chan struct{}) error { + // Create a new client + client, err := newClient() + if err != nil { + return err + } + defer client.Close() + + client.setInterface(nil, true) + + // Start listening for response packets + msgCh := make(chan *dns.Msg, 32) + + go client.recv(client.ipv4MulticastConn, msgCh) + go client.recv(client.ipv6MulticastConn, msgCh) + go client.recv(client.ipv4MulticastConn, msgCh) + go client.recv(client.ipv6MulticastConn, msgCh) + + ip := make(map[string]*ServiceEntry) + + for { + select { + case <-exit: + return nil + case <-client.closedCh: + return nil + case m := <-msgCh: + e := messageToEntry(m, ip) + if e == nil { + continue + } + + // Check if this entry is complete + if e.complete() { + if e.sent { + continue + } + e.sent = true + entries <- e + ip = make(map[string]*ServiceEntry) + } else { + // Fire off a node specific query + m := new(dns.Msg) + m.SetQuestion(e.Name, dns.TypePTR) + m.RecursionDesired = false + if err := client.sendQuery(m); err != nil { + log.Printf("[ERR] mdns: Failed to query instance %s: %v", e.Name, err) + } + } + } + } + + return nil +} + // Lookup is the same as Query, however it uses all the default parameters func Lookup(service string, entries chan<- *ServiceEntry) error { params := DefaultParams(service) @@ -178,23 +235,29 @@ func (c *client) Close() error { // setInterface is used to set the query interface, uses sytem // default if not provided -func (c *client) setInterface(iface *net.Interface) error { +func (c *client) setInterface(iface *net.Interface, loopback bool) error { p := ipv4.NewPacketConn(c.ipv4UnicastConn) - if err := p.SetMulticastInterface(iface); err != nil { + if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { return err } p2 := ipv6.NewPacketConn(c.ipv6UnicastConn) - if err := p2.SetMulticastInterface(iface); err != nil { + if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { return err } p = ipv4.NewPacketConn(c.ipv4MulticastConn) - if err := p.SetMulticastInterface(iface); err != nil { + if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { return err } p2 = ipv6.NewPacketConn(c.ipv6MulticastConn) - if err := p2.SetMulticastInterface(iface); err != nil { + if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { return err } + + if loopback { + p.SetMulticastLoopback(true) + p2.SetMulticastLoopback(true) + } + return nil } @@ -232,64 +295,11 @@ func (c *client) query(params *QueryParam) error { // Listen until we reach the timeout finish := time.After(params.Timeout) + for { select { case resp := <-msgCh: - var inp *ServiceEntry - for _, answer := range append(resp.Answer, resp.Extra...) { - // TODO(reddaly): Check that response corresponds to serviceAddr? - switch rr := answer.(type) { - case *dns.PTR: - // Create new entry for this - inp = ensureName(inprogress, rr.Ptr) - if inp.complete() { - continue - } - - case *dns.SRV: - // Check for a target mismatch - if rr.Target != rr.Hdr.Name { - alias(inprogress, rr.Hdr.Name, rr.Target) - } - - // Get the port - inp = ensureName(inprogress, rr.Hdr.Name) - if inp.complete() { - continue - } - inp.Host = rr.Target - inp.Port = int(rr.Port) - - case *dns.TXT: - // Pull out the txt - inp = ensureName(inprogress, rr.Hdr.Name) - if inp.complete() { - continue - } - inp.Info = strings.Join(rr.Txt, "|") - inp.InfoFields = rr.Txt - inp.hasTXT = true - - case *dns.A: - // Pull out the IP - inp = ensureName(inprogress, rr.Hdr.Name) - if inp.complete() { - continue - } - inp.Addr = rr.A // @Deprecated - inp.AddrV4 = rr.A - - case *dns.AAAA: - // Pull out the IP - inp = ensureName(inprogress, rr.Hdr.Name) - if inp.complete() { - continue - } - inp.Addr = rr.AAAA // @Deprecated - inp.AddrV6 = rr.AAAA - } - } - + inp := messageToEntry(resp, inprogress) if inp == nil { continue } @@ -379,3 +389,63 @@ func alias(inprogress map[string]*ServiceEntry, src, dst string) { srcEntry := ensureName(inprogress, src) inprogress[dst] = srcEntry } + +func messageToEntry(m *dns.Msg, inprogress map[string]*ServiceEntry) *ServiceEntry { + var inp *ServiceEntry + + for _, answer := range append(m.Answer, m.Extra...) { + // TODO(reddaly): Check that response corresponds to serviceAddr? + switch rr := answer.(type) { + case *dns.PTR: + // Create new entry for this + inp = ensureName(inprogress, rr.Ptr) + if inp.complete() { + continue + } + case *dns.SRV: + // Check for a target mismatch + if rr.Target != rr.Hdr.Name { + alias(inprogress, rr.Hdr.Name, rr.Target) + } + + // Get the port + inp = ensureName(inprogress, rr.Hdr.Name) + if inp.complete() { + continue + } + inp.Host = rr.Target + inp.Port = int(rr.Port) + case *dns.TXT: + // Pull out the txt + inp = ensureName(inprogress, rr.Hdr.Name) + if inp.complete() { + continue + } + inp.Info = strings.Join(rr.Txt, "|") + inp.InfoFields = rr.Txt + inp.hasTXT = true + case *dns.A: + // Pull out the IP + inp = ensureName(inprogress, rr.Hdr.Name) + if inp.complete() { + continue + } + inp.Addr = rr.A // @Deprecated + inp.AddrV4 = rr.A + case *dns.AAAA: + // Pull out the IP + inp = ensureName(inprogress, rr.Hdr.Name) + if inp.complete() { + continue + } + inp.Addr = rr.AAAA // @Deprecated + inp.AddrV6 = rr.AAAA + } + + if inp != nil { + inp.TTL = int(answer.Header().Ttl) + } + } + + return inp +} diff --git a/server.go b/server.go index 541187d..4cde1a5 100644 --- a/server.go +++ b/server.go @@ -3,27 +3,38 @@ package mdns import ( "fmt" "log" + "math/rand" "net" "sync" + "time" "github.com/miekg/dns" -) - -const ( - ipv4mdns = "224.0.0.251" - ipv6mdns = "ff02::fb" - mdnsPort = 5353 - forceUnicastResponses = false + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" ) var ( + mdnsGroupIPv4 = net.IPv4(224, 0, 0, 251) + mdnsGroupIPv6 = net.ParseIP("ff02::fb") + + // mDNS wildcard addresses + mdnsWildcardAddrIPv4 = &net.UDPAddr{ + IP: net.ParseIP("224.0.0.0"), + Port: 5353, + } + mdnsWildcardAddrIPv6 = &net.UDPAddr{ + IP: net.ParseIP("[ff02::]"), + Port: 5353, + } + + // mDNS endpoint addresses ipv4Addr = &net.UDPAddr{ - IP: net.ParseIP(ipv4mdns), - Port: mdnsPort, + IP: mdnsGroupIPv4, + Port: 5353, } ipv6Addr = &net.UDPAddr{ - IP: net.ParseIP(ipv6mdns), - Port: mdnsPort, + IP: mdnsGroupIPv6, + Port: 5353, } ) @@ -54,12 +65,43 @@ type Server struct { // NewServer is used to create a new mDNS server from a config func NewServer(config *Config) (*Server, error) { // Create the listeners - ipv4List, _ := net.ListenMulticastUDP("udp4", config.Iface, ipv4Addr) - ipv6List, _ := net.ListenMulticastUDP("udp6", config.Iface, ipv6Addr) - - // Check if we have any listener + // Create wildcard connections (because :5353 can be already taken by other apps) + ipv4List, _ := net.ListenUDP("udp4", mdnsWildcardAddrIPv4) + ipv6List, _ := net.ListenUDP("udp6", mdnsWildcardAddrIPv6) if ipv4List == nil && ipv6List == nil { - return nil, fmt.Errorf("No multicast listeners could be started") + return nil, fmt.Errorf("[ERR] mdns: Failed to bind to any udp port!") + } + + // Join multicast groups to receive announcements + p1 := ipv4.NewPacketConn(ipv4List) + p2 := ipv6.NewPacketConn(ipv6List) + p1.SetMulticastLoopback(true) + p2.SetMulticastLoopback(true) + + if config.Iface != nil { + if err := p1.JoinGroup(config.Iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { + return nil, err + } + if err := p2.JoinGroup(config.Iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { + return nil, err + } + } else { + ifaces, err := net.Interfaces() + if err != nil { + return nil, err + } + errCount1, errCount2 := 0, 0 + for _, iface := range ifaces { + if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { + errCount1++ + } + if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { + errCount2++ + } + } + if len(ifaces) == errCount1 && len(ifaces) == errCount2 { + return nil, fmt.Errorf("Failed to join multicast group on all interfaces!") + } } s := &Server{ @@ -77,6 +119,8 @@ func NewServer(config *Config) (*Server, error) { go s.recv(s.ipv6List) } + go s.probe() + return s, nil } @@ -88,8 +132,10 @@ func (s *Server) Shutdown() error { if s.shutdown { return nil } + s.shutdown = true close(s.shutdownCh) + s.unregister() if s.ipv4List != nil { s.ipv4List.Close() @@ -97,6 +143,7 @@ func (s *Server) Shutdown() error { if s.ipv6List != nil { s.ipv6List.Close() } + return nil } @@ -222,12 +269,12 @@ func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error { } if mresp := resp(false); mresp != nil { - if err := s.sendResponse(mresp, from, false); err != nil { + if err := s.sendResponse(mresp, from); err != nil { return fmt.Errorf("mdns: error sending multicast response: %v", err) } } if uresp := resp(true); uresp != nil { - if err := s.sendResponse(uresp, from, true); err != nil { + if err := s.sendResponse(uresp, from); err != nil { return fmt.Errorf("mdns: error sending unicast response: %v", err) } } @@ -256,14 +303,100 @@ func (s *Server) handleQuestion(q dns.Question) (multicastRecs, unicastRecs []dn // In the Question Section of a Multicast DNS query, the top bit of the // qclass field is used to indicate that unicast responses are preferred // for this particular question. (See Section 5.4.) - if q.Qclass&(1<<15) != 0 || forceUnicastResponses { + if q.Qclass&(1<<15) != 0 { return nil, records } return records, nil } +func (s *Server) probe() { + sd, ok := s.config.Zone.(*MDNSService) + if !ok { + return + } + + name := fmt.Sprintf("%s.%s.%s.", sd.Instance, trimDot(sd.Service), trimDot(sd.Domain)) + + q := new(dns.Msg) + q.SetQuestion(name, dns.TypePTR) + q.RecursionDesired = false + + srv := &dns.SRV{ + Hdr: dns.RR_Header{ + Name: name, + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + Ttl: defaultTTL, + }, + Priority: 0, + Weight: 0, + Port: uint16(sd.Port), + Target: sd.HostName, + } + txt := &dns.TXT{ + Hdr: dns.RR_Header{ + Name: name, + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + Ttl: defaultTTL, + }, + Txt: sd.TXT, + } + q.Ns = []dns.RR{srv, txt} + + randomizer := rand.New(rand.NewSource(time.Now().UnixNano())) + + for i := 0; i < 3; i++ { + if err := s.multicastResponse(q); err != nil { + log.Println("[ERR] mdns: failed to send probe:", err.Error()) + } + time.Sleep(time.Duration(randomizer.Intn(250)) * time.Millisecond) + } + + resp := new(dns.Msg) + resp.MsgHdr.Response = true + + // set for query + q.SetQuestion(name, dns.TypeANY) + + resp.Answer = append(resp.Answer, s.config.Zone.Records(q.Question[0])...) + + // reset + q.SetQuestion(name, dns.TypePTR) + + // From RFC6762 + // The Multicast DNS responder MUST send at least two unsolicited + // responses, one second apart. To provide increased robustness against + // packet loss, a responder MAY send up to eight unsolicited responses, + // provided that the interval between unsolicited responses increases by + // at least a factor of two with every response sent. + timeout := 1 * time.Second + for i := 0; i < 3; i++ { + if err := s.multicastResponse(resp); err != nil { + log.Println("[ERR] mdns: failed to send announcement:", err.Error()) + } + time.Sleep(timeout) + timeout *= 2 + } +} + +// multicastResponse us used to send a multicast response packet +func (s *Server) multicastResponse(msg *dns.Msg) error { + buf, err := msg.Pack() + if err != nil { + return err + } + if s.ipv4List != nil { + s.ipv4List.WriteToUDP(buf, ipv4Addr) + } + if s.ipv6List != nil { + s.ipv6List.WriteToUDP(buf, ipv6Addr) + } + return nil +} + // sendResponse is used to send a response packet -func (s *Server) sendResponse(resp *dns.Msg, from net.Addr, unicast bool) error { +func (s *Server) sendResponse(resp *dns.Msg, from net.Addr) error { // TODO(reddaly): Respect the unicast argument, and allow sending responses // over multicast. buf, err := resp.Pack() @@ -281,3 +414,22 @@ func (s *Server) sendResponse(resp *dns.Msg, from net.Addr, unicast bool) error return err } } + +func (s *Server) unregister() error { + sd, ok := s.config.Zone.(*MDNSService) + if !ok { + return nil + } + + sd.TTL = 0 + name := fmt.Sprintf("%s.%s.%s.", sd.Instance, trimDot(sd.Service), trimDot(sd.Domain)) + + q := new(dns.Msg) + q.SetQuestion(name, dns.TypeANY) + + resp := new(dns.Msg) + resp.MsgHdr.Response = true + resp.Answer = append(resp.Answer, s.config.Zone.Records(q.Question[0])...) + + return s.multicastResponse(resp) +} diff --git a/zone.go b/zone.go index 6f442c7..1ef04d1 100644 --- a/zone.go +++ b/zone.go @@ -23,14 +23,14 @@ type Zone interface { // MDNSService is used to export a named service by implementing a Zone type MDNSService struct { - Instance string // Instance name (e.g. "hostService name") - Service string // Service name (e.g. "_http._tcp.") - Domain string // If blank, assumes "local" - HostName string // Host machine DNS name (e.g. "mymachine.net.") - Port int // Service Port - IPs []net.IP // IP addresses for the service's host - TXT []string // Service TXT records - + Instance string // Instance name (e.g. "hostService name") + Service string // Service name (e.g. "_http._tcp.") + Domain string // If blank, assumes "local" + HostName string // Host machine DNS name (e.g. "mymachine.net.") + Port int // Service Port + IPs []net.IP // IP addresses for the service's host + TXT []string // Service TXT records + TTL uint32 serviceAddr string // Fully qualified service address instanceAddr string // Fully qualified instance address enumAddr string // _services._dns-sd._udp. @@ -122,6 +122,7 @@ func NewMDNSService(instance, service, domain, hostName string, port int, ips [] Port: port, IPs: ips, TXT: txt, + TTL: defaultTTL, serviceAddr: fmt.Sprintf("%s.%s.", trimDot(service), trimDot(domain)), instanceAddr: fmt.Sprintf("%s.%s.%s.", instance, trimDot(service), trimDot(domain)), enumAddr: fmt.Sprintf("_services._dns-sd._udp.%s.", trimDot(domain)), @@ -162,7 +163,7 @@ func (m *MDNSService) serviceEnum(q dns.Question) []dns.RR { Name: q.Name, Rrtype: dns.TypePTR, Class: dns.ClassINET, - Ttl: defaultTTL, + Ttl: m.TTL, }, Ptr: m.serviceAddr, } @@ -184,7 +185,7 @@ func (m *MDNSService) serviceRecords(q dns.Question) []dns.RR { Name: q.Name, Rrtype: dns.TypePTR, Class: dns.ClassINET, - Ttl: defaultTTL, + Ttl: m.TTL, }, Ptr: m.instanceAddr, } @@ -229,7 +230,7 @@ func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR { Name: m.HostName, Rrtype: dns.TypeA, Class: dns.ClassINET, - Ttl: defaultTTL, + Ttl: m.TTL, }, A: ip4, }) @@ -254,7 +255,7 @@ func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR { Name: m.HostName, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, - Ttl: defaultTTL, + Ttl: m.TTL, }, AAAA: ip16, }) @@ -269,7 +270,7 @@ func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR { Name: q.Name, Rrtype: dns.TypeSRV, Class: dns.ClassINET, - Ttl: defaultTTL, + Ttl: m.TTL, }, Priority: 10, Weight: 1, @@ -297,7 +298,7 @@ func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR { Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, - Ttl: defaultTTL, + Ttl: m.TTL, }, Txt: m.TXT, }