package mdns import ( "context" "fmt" "net" "sync" "go.unistack.org/micro/v3/logger" "golang.org/x/net/dns/dnsmessage" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" ) // ServiceEntry is returned after we query for a service type ServiceEntry struct { Name string Host string AddrV4 net.IP AddrV6 net.IP Port int Info string InfoFields []string TTL int Type uint16 hasTXT bool sent bool } // complete is used to check if we have all the info we need func (s *ServiceEntry) complete() bool { return (s.AddrV4 != nil || s.AddrV6 != nil) && s.Port != 0 && s.hasTXT } // QueryParam is used to customize how a Lookup is performed type QueryParam struct { Service string // Service to lookup Domain string // Lookup domain, default "local" Type dnsmessage.Type // Lookup type, defaults to dns.TypePTR Interface *net.Interface // Multicast interface to use Entries chan<- *ServiceEntry // Entries Channel WantUnicastResponse bool // Unicast response desired, as per 5.4 in RFC } // DefaultParams is used to return a default set of QueryParam's func DefaultParams(service string) *QueryParam { qp := &QueryParam{ Service: service, Domain: "local", Entries: make(chan *ServiceEntry), WantUnicastResponse: false, } return qp } // Query looks up a given service, in a domain, waiting at most // for a timeout before finishing the query. The results are streamed // to a channel. Sends will not block, so clients should make sure to // either read or buffer. func Query(ctx context.Context, params *QueryParam) error { // Create a new client client, err := newClient() if err != nil { return err } defer client.Close() // Set the multicast interface if params.Interface != nil { if err := client.setInterface(params.Interface, false); err != nil { return err } } // Ensure defaults are set if params.Domain == "" { params.Domain = "local" } // Run the query return client.query(ctx, params) } // Listen listens indefinitely for multicast updates func Listen(entries chan<- *ServiceEntry, exit chan struct{}) error { // Create a new client client, err := newClient() if err != nil { return err } defer client.Close() client.setInterface(nil, true) // Start listening for response packets msgCh := make(chan []byte, 32) go client.recv(client.ipv4UnicastConn, msgCh) go client.recv(client.ipv6UnicastConn, msgCh) go client.recv(client.ipv4MulticastConn, msgCh) go client.recv(client.ipv6MulticastConn, msgCh) sentry := make(map[string]*ServiceEntry) for { select { case <-exit: return nil case <-client.closedCh: return nil case msg := <-msgCh: fmt.Printf("%#+v\n", msg) entry := messageToEntry(msg, sentry) if entry == nil { continue } // Check if this entry is complete if entry.complete() { if entry.sent { continue } entry.sent = true entries <- entry sentry = make(map[string]*ServiceEntry) } else { // Fire off a node specific query /* h: -&dnsmessage.Header{RecursionDesired: false} m := dnsmessage.NewBuilder() m.SetQuestion(e.Name, dns.TypePTR) if err := client.sendQuery(m); err != nil { logger.Errorf("[ERR] mdns: Failed to query instance %s: %v", e.Name, err) } */ } } } return nil } /* // Lookup is the same as Query, however it uses all the default parameters func Lookup(service string, entries chan<- *ServiceEntry) error { params := DefaultParams(service) params.Entries = entries return Query(params) } */ // Client provides a query interface that can be used to // search for service providers using mDNS type client struct { ipv4UnicastConn *net.UDPConn ipv6UnicastConn *net.UDPConn ipv4MulticastConn *net.UDPConn ipv6MulticastConn *net.UDPConn closed bool closedCh chan struct{} // TODO(reddaly): This doesn't appear to be used. closeLock sync.RWMutex } // NewClient creates a new mdns Client that can be used to query // for records func newClient() (*client, error) { // TODO(reddaly): At least attempt to bind to the port required in the spec. // Create a IPv4 listener uconn4, err4 := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) uconn6, err6 := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) if err4 != nil && err6 != nil { logger.Errorf(context.TODO(), "[ERR] mdns: Failed to bind to udp port: %v %v", err4, err6) } if uconn4 == nil && uconn6 == nil { return nil, fmt.Errorf("failed to bind to any unicast udp port") } if uconn4 == nil { uconn4 = &net.UDPConn{} } if uconn6 == nil { uconn6 = &net.UDPConn{} } mconn4, err4 := net.ListenUDP("udp4", mdnsWildcardAddrIPv4) mconn6, err6 := net.ListenUDP("udp6", mdnsWildcardAddrIPv6) if err4 != nil && err6 != nil { logger.Errorf(context.TODO(), "[ERR] mdns: Failed to bind to udp port: %v %v", err4, err6) } if mconn4 == nil && mconn6 == nil { return nil, fmt.Errorf("failed to bind to any multicast udp port") } if mconn4 == nil { mconn4 = &net.UDPConn{} } if mconn6 == nil { mconn6 = &net.UDPConn{} } p1 := ipv4.NewPacketConn(mconn4) p2 := ipv6.NewPacketConn(mconn6) p1.SetMulticastLoopback(true) p2.SetMulticastLoopback(true) ifaces, err := net.Interfaces() if err != nil { return nil, err } var errCount1, errCount2 int for _, iface := range ifaces { if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { errCount1++ } if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { errCount2++ } } if len(ifaces) == errCount1 && len(ifaces) == errCount2 { return nil, fmt.Errorf("Failed to join multicast group on all interfaces!") } c := &client{ ipv4MulticastConn: mconn4, ipv6MulticastConn: mconn6, ipv4UnicastConn: uconn4, ipv6UnicastConn: uconn6, closedCh: make(chan struct{}), } return c, nil } // Close is used to cleanup the client func (c *client) Close() error { c.closeLock.Lock() defer c.closeLock.Unlock() if c.closed { return nil } c.closed = true close(c.closedCh) if c.ipv4UnicastConn != nil { c.ipv4UnicastConn.Close() } if c.ipv6UnicastConn != nil { c.ipv6UnicastConn.Close() } if c.ipv4MulticastConn != nil { c.ipv4MulticastConn.Close() } if c.ipv6MulticastConn != nil { c.ipv6MulticastConn.Close() } return nil } // setInterface is used to set the query interface, uses sytem // default if not provided func (c *client) setInterface(iface *net.Interface, loopback bool) error { p := ipv4.NewPacketConn(c.ipv4UnicastConn) if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { return err } p2 := ipv6.NewPacketConn(c.ipv6UnicastConn) if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { return err } p = ipv4.NewPacketConn(c.ipv4MulticastConn) if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { return err } p2 = ipv6.NewPacketConn(c.ipv6MulticastConn) if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { return err } if loopback { p.SetMulticastLoopback(true) p2.SetMulticastLoopback(true) } return nil } // query is used to perform a lookup and stream results func (c *client) query(ctx context.Context, params *QueryParam) error { // Create the service name serviceAddr := fmt.Sprintf("%s.%s.", trimDot(params.Service), trimDot(params.Domain)) // Start listening for response packets msgCh := make(chan []byte, 32) go c.recv(c.ipv4UnicastConn, msgCh) go c.recv(c.ipv6UnicastConn, msgCh) go c.recv(c.ipv4MulticastConn, msgCh) go c.recv(c.ipv6MulticastConn, msgCh) // buf := make([]byte, 2, 514) hdr := dnsmessage.Header{RecursionDesired: false} b := dnsmessage.NewBuilder(nil, hdr) // b.EnableCompression() name, err := dnsmessage.NewName(serviceAddr) if err != nil { return err } q := dnsmessage.Question{Name: name, Class: dnsmessage.ClassINET} if params.Type == 0 { q.Type = dnsmessage.TypePTR } else { q.Type = params.Type } // q.Class |= 1 << 15 if err = b.StartQuestions(); err != nil { return err } if err = b.Question(q); err != nil { return err } bbuf, err := b.Finish() if err != nil { return err } // Send the query // RFC 6762, section 18.12. Repurposing of Top Bit of qclass in Question // Section // // In the Question Section of a Multicast DNS query, the top bit of the qclass // field is used to indicate that unicast responses are preferred for this // particular question. (See Section 5.4.) if err := c.sendQuery(bbuf); err != nil { return err } // Map the in-progress responses inprogress := make(map[string]*ServiceEntry) for { select { case rsp := <-msgCh: inp := messageToEntry(rsp, inprogress) if inp == nil { continue } // Check if this entry is complete if inp.complete() { if inp.sent { continue } inp.sent = true select { case params.Entries <- inp: case <-ctx.Done(): return nil } } else { // Fire off a node specific query // m := new(dns.Msg) // m.SetQuestion(inp.Name, inp.Type) // m.RecursionDesired = false var buf []byte if err := c.sendQuery(buf); err != nil { logger.Errorf(context.TODO(), "[ERR] mdns: Failed to query instance %s: %v", inp.Name, err) } } case <-ctx.Done(): return nil } } } // sendQuery is used to multicast a query out func (c *client) sendQuery(buf []byte) error { if c.ipv4UnicastConn != nil { c.ipv4UnicastConn.WriteToUDP(buf, ipv4Addr) } if c.ipv6UnicastConn != nil { c.ipv6UnicastConn.WriteToUDP(buf, ipv6Addr) } return nil } // recv is used to receive until we get a shutdown func (c *client) recv(l *net.UDPConn, msgCh chan []byte) { if l == nil { return } buf := make([]byte, 65536) for { select { case <-c.closedCh: return default: c.closeLock.Lock() if c.closed { c.closeLock.Unlock() return } c.closeLock.Unlock() n, err := l.Read(buf) if err != nil { if logger.V(logger.DebugLevel) { logger.Debug(context.TODO(), err) } continue } msgCh <- buf[:n] } } } /* // ensureName is used to ensure the named node is in progress func ensureName(inprogress map[string]*ServiceEntry, name string, typ uint16) *ServiceEntry { if inp, ok := inprogress[name]; ok { return inp } inp := &ServiceEntry{ Name: name, Type: typ, } inprogress[name] = inp return inp } // alias is used to setup an alias between two entries func alias(inprogress map[string]*ServiceEntry, src, dst string, typ uint16) { srcEntry := ensureName(inprogress, src, typ) inprogress[dst] = srcEntry } */ func messageToEntry(m []byte, inprogress map[string]*ServiceEntry) *ServiceEntry { var inp *ServiceEntry /* for _, answer := range append(m.Answers, m.Additionals...) { // TODO(reddaly): Check that response corresponds to serviceAddr? switch answer.Header.Type { case dnsmessage.TypePTR: rr := answer.Body.(*dnsmessage.PTRResource) // Create new entry for this inp = ensureName(inprogress, rr.Ptr, rr.Hdr.Rrtype) if inp.complete() { continue } case dnsmessage.TypeSRV: // Check for a target mismatch if rr.Target != rr.Hdr.Name { alias(inprogress, rr.Hdr.Name, rr.Target, rr.Hdr.Rrtype) } // Get the port inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) if inp.complete() { continue } inp.Host = rr.Target inp.Port = int(rr.Port) case dnsmessage.TypeTXT: // Pull out the txt inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) if inp.complete() { continue } inp.Info = strings.Join(rr.Txt, "|") inp.InfoFields = rr.Txt inp.hasTXT = true case dnsmessage.TypeA: // Pull out the IP inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) if inp.complete() { continue } inp.AddrV4 = rr.A case dnsmessage.TypeAAAA: // Pull out the IP inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) if inp.complete() { continue } inp.AddrV6 = rr.AAAA } if inp != nil { inp.TTL = int(answer.Header().Ttl) } } */ return inp }