package mdns import ( "context" "fmt" "log" "net" "strings" "sync" "time" "github.com/miekg/dns" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" ) // ServiceEntry is returned after we query for a service type ServiceEntry struct { Name string Host string AddrV4 net.IP AddrV6 net.IP Port int Info string InfoFields []string TTL int Addr net.IP // @Deprecated hasTXT bool sent bool } // complete is used to check if we have all the info we need func (s *ServiceEntry) complete() bool { return (s.AddrV4 != nil || s.AddrV6 != nil || s.Addr != nil) && s.Port != 0 && s.hasTXT } // QueryParam is used to customize how a Lookup is performed type QueryParam struct { Service string // Service to lookup Domain string // Lookup domain, default "local" Context context.Context // Context Timeout time.Duration // Lookup timeout, default 1 second. Ignored if Context is provided Interface *net.Interface // Multicast interface to use Entries chan<- *ServiceEntry // Entries Channel WantUnicastResponse bool // Unicast response desired, as per 5.4 in RFC } // DefaultParams is used to return a default set of QueryParam's func DefaultParams(service string) *QueryParam { return &QueryParam{ Service: service, Domain: "local", Timeout: time.Second, Entries: make(chan *ServiceEntry), WantUnicastResponse: false, // TODO(reddaly): Change this default. } } // Query looks up a given service, in a domain, waiting at most // for a timeout before finishing the query. The results are streamed // to a channel. Sends will not block, so clients should make sure to // either read or buffer. func Query(params *QueryParam) error { // Create a new client client, err := newClient() if err != nil { return err } defer client.Close() // Set the multicast interface if params.Interface != nil { if err := client.setInterface(params.Interface, false); err != nil { return err } } // Ensure defaults are set if params.Domain == "" { params.Domain = "local" } if params.Context == nil { if params.Timeout == 0 { params.Timeout = time.Second } params.Context, _ = context.WithTimeout(context.Background(), params.Timeout) if err != nil { return err } } // Run the query return client.query(params) } // Listen listens indefinitely for multicast updates func Listen(entries chan<- *ServiceEntry, exit chan struct{}) error { // Create a new client client, err := newClient() if err != nil { return err } defer client.Close() client.setInterface(nil, true) // Start listening for response packets msgCh := make(chan *dns.Msg, 32) go client.recv(client.ipv4MulticastConn, msgCh) go client.recv(client.ipv6MulticastConn, msgCh) go client.recv(client.ipv4MulticastConn, msgCh) go client.recv(client.ipv6MulticastConn, msgCh) ip := make(map[string]*ServiceEntry) for { select { case <-exit: return nil case <-client.closedCh: return nil case m := <-msgCh: e := messageToEntry(m, ip) if e == nil { continue } // Check if this entry is complete if e.complete() { if e.sent { continue } e.sent = true entries <- e ip = make(map[string]*ServiceEntry) } else { // Fire off a node specific query m := new(dns.Msg) m.SetQuestion(e.Name, dns.TypePTR) m.RecursionDesired = false if err := client.sendQuery(m); err != nil { log.Printf("[ERR] mdns: Failed to query instance %s: %v", e.Name, err) } } } } return nil } // Lookup is the same as Query, however it uses all the default parameters func Lookup(service string, entries chan<- *ServiceEntry) error { params := DefaultParams(service) params.Entries = entries return Query(params) } // Client provides a query interface that can be used to // search for service providers using mDNS type client struct { ipv4UnicastConn *net.UDPConn ipv6UnicastConn *net.UDPConn ipv4MulticastConn *net.UDPConn ipv6MulticastConn *net.UDPConn closed bool closedCh chan struct{} // TODO(reddaly): This doesn't appear to be used. closeLock sync.Mutex } // NewClient creates a new mdns Client that can be used to query // for records func newClient() (*client, error) { // TODO(reddaly): At least attempt to bind to the port required in the spec. // Create a IPv4 listener uconn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { log.Printf("[ERR] mdns: Failed to bind to udp4 port: %v", err) } uconn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) if err != nil { log.Printf("[ERR] mdns: Failed to bind to udp6 port: %v", err) } if uconn4 == nil && uconn6 == nil { return nil, fmt.Errorf("failed to bind to any unicast udp port") } if uconn4 == nil { uconn4 = &net.UDPConn{} } if uconn6 == nil { uconn6 = &net.UDPConn{} } mconn4, err := net.ListenUDP("udp4", mdnsWildcardAddrIPv4) if err != nil { log.Printf("[ERR] mdns: Failed to bind to udp4 port: %v", err) } mconn6, err := net.ListenUDP("udp6", mdnsWildcardAddrIPv6) if err != nil { log.Printf("[ERR] mdns: Failed to bind to udp6 port: %v", err) } if mconn4 == nil && mconn6 == nil { return nil, fmt.Errorf("failed to bind to any multicast udp port") } if mconn4 == nil { mconn4 = &net.UDPConn{} } if mconn6 == nil { mconn6 = &net.UDPConn{} } p1 := ipv4.NewPacketConn(mconn4) p2 := ipv6.NewPacketConn(mconn6) ifaces, err := net.Interfaces() if err != nil { return nil, err } var errCount1, errCount2 int for _, iface := range ifaces { if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { errCount1++ } if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { errCount2++ } } if len(ifaces) == errCount1 && len(ifaces) == errCount2 { return nil, fmt.Errorf("Failed to join multicast group on all interfaces!") } c := &client{ ipv4MulticastConn: mconn4, ipv6MulticastConn: mconn6, ipv4UnicastConn: uconn4, ipv6UnicastConn: uconn6, closedCh: make(chan struct{}), } return c, nil } // Close is used to cleanup the client func (c *client) Close() error { c.closeLock.Lock() defer c.closeLock.Unlock() if c.closed { return nil } c.closed = true close(c.closedCh) if c.ipv4UnicastConn != nil { c.ipv4UnicastConn.Close() } if c.ipv6UnicastConn != nil { c.ipv6UnicastConn.Close() } if c.ipv4MulticastConn != nil { c.ipv4MulticastConn.Close() } if c.ipv6MulticastConn != nil { c.ipv6MulticastConn.Close() } return nil } // setInterface is used to set the query interface, uses sytem // default if not provided func (c *client) setInterface(iface *net.Interface, loopback bool) error { p := ipv4.NewPacketConn(c.ipv4UnicastConn) if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { return err } p2 := ipv6.NewPacketConn(c.ipv6UnicastConn) if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { return err } p = ipv4.NewPacketConn(c.ipv4MulticastConn) if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { return err } p2 = ipv6.NewPacketConn(c.ipv6MulticastConn) if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { return err } if loopback { p.SetMulticastLoopback(true) p2.SetMulticastLoopback(true) } return nil } // query is used to perform a lookup and stream results func (c *client) query(params *QueryParam) error { // Create the service name serviceAddr := fmt.Sprintf("%s.%s.", trimDot(params.Service), trimDot(params.Domain)) // Start listening for response packets msgCh := make(chan *dns.Msg, 32) go c.recv(c.ipv4UnicastConn, msgCh) go c.recv(c.ipv6UnicastConn, msgCh) go c.recv(c.ipv4MulticastConn, msgCh) go c.recv(c.ipv6MulticastConn, msgCh) // Send the query m := new(dns.Msg) m.SetQuestion(serviceAddr, dns.TypePTR) // RFC 6762, section 18.12. Repurposing of Top Bit of qclass in Question // Section // // In the Question Section of a Multicast DNS query, the top bit of the qclass // field is used to indicate that unicast responses are preferred for this // particular question. (See Section 5.4.) if params.WantUnicastResponse { m.Question[0].Qclass |= 1 << 15 } m.RecursionDesired = false if err := c.sendQuery(m); err != nil { return err } // Map the in-progress responses inprogress := make(map[string]*ServiceEntry) for { select { case resp := <-msgCh: inp := messageToEntry(resp, inprogress) if inp == nil { continue } // Check if this entry is complete if inp.complete() { if inp.sent { continue } inp.sent = true select { case params.Entries <- inp: case <-params.Context.Done(): return nil } } else { // Fire off a node specific query m := new(dns.Msg) m.SetQuestion(inp.Name, dns.TypePTR) m.RecursionDesired = false if err := c.sendQuery(m); err != nil { log.Printf("[ERR] mdns: Failed to query instance %s: %v", inp.Name, err) } } case <-params.Context.Done(): return nil } } } // sendQuery is used to multicast a query out func (c *client) sendQuery(q *dns.Msg) error { buf, err := q.Pack() if err != nil { return err } if c.ipv4UnicastConn != nil { c.ipv4UnicastConn.WriteToUDP(buf, ipv4Addr) } if c.ipv6UnicastConn != nil { c.ipv6UnicastConn.WriteToUDP(buf, ipv6Addr) } return nil } // recv is used to receive until we get a shutdown func (c *client) recv(l *net.UDPConn, msgCh chan *dns.Msg) { if l == nil { return } buf := make([]byte, 65536) for { c.closeLock.Lock() if c.closed { c.closeLock.Unlock() return } c.closeLock.Unlock() n, err := l.Read(buf) if err != nil { continue } msg := new(dns.Msg) if err := msg.Unpack(buf[:n]); err != nil { continue } select { case msgCh <- msg: case <-c.closedCh: return } } } // ensureName is used to ensure the named node is in progress func ensureName(inprogress map[string]*ServiceEntry, name string) *ServiceEntry { if inp, ok := inprogress[name]; ok { return inp } inp := &ServiceEntry{ Name: name, } inprogress[name] = inp return inp } // alias is used to setup an alias between two entries func alias(inprogress map[string]*ServiceEntry, src, dst string) { srcEntry := ensureName(inprogress, src) inprogress[dst] = srcEntry } func messageToEntry(m *dns.Msg, inprogress map[string]*ServiceEntry) *ServiceEntry { var inp *ServiceEntry for _, answer := range append(m.Answer, m.Extra...) { // TODO(reddaly): Check that response corresponds to serviceAddr? switch rr := answer.(type) { case *dns.PTR: // Create new entry for this inp = ensureName(inprogress, rr.Ptr) if inp.complete() { continue } case *dns.SRV: // Check for a target mismatch if rr.Target != rr.Hdr.Name { alias(inprogress, rr.Hdr.Name, rr.Target) } // Get the port inp = ensureName(inprogress, rr.Hdr.Name) if inp.complete() { continue } inp.Host = rr.Target inp.Port = int(rr.Port) case *dns.TXT: // Pull out the txt inp = ensureName(inprogress, rr.Hdr.Name) if inp.complete() { continue } inp.Info = strings.Join(rr.Txt, "|") inp.InfoFields = rr.Txt inp.hasTXT = true case *dns.A: // Pull out the IP inp = ensureName(inprogress, rr.Hdr.Name) if inp.complete() { continue } inp.Addr = rr.A // @Deprecated inp.AddrV4 = rr.A case *dns.AAAA: // Pull out the IP inp = ensureName(inprogress, rr.Hdr.Name) if inp.complete() { continue } inp.Addr = rr.AAAA // @Deprecated inp.AddrV6 = rr.AAAA } if inp != nil { inp.TTL = int(answer.Header().Ttl) } } return inp }