many lint fixes and optimizations #17
							
								
								
									
										23
									
								
								util/mdns/.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										23
									
								
								util/mdns/.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -1,23 +0,0 @@ | |||||||
| # Compiled Object files, Static and Dynamic libs (Shared Objects) |  | ||||||
| *.o |  | ||||||
| *.a |  | ||||||
| *.so |  | ||||||
|  |  | ||||||
| # Folders |  | ||||||
| _obj |  | ||||||
| _test |  | ||||||
|  |  | ||||||
| # Architecture specific extensions/prefixes |  | ||||||
| *.[568vq] |  | ||||||
| [568vq].out |  | ||||||
|  |  | ||||||
| *.cgo1.go |  | ||||||
| *.cgo2.c |  | ||||||
| _cgo_defun.c |  | ||||||
| _cgo_gotypes.go |  | ||||||
| _cgo_export.* |  | ||||||
|  |  | ||||||
| _testmain.go |  | ||||||
|  |  | ||||||
| *.exe |  | ||||||
| *.test |  | ||||||
| @@ -1,511 +0,0 @@ | |||||||
| package mdns |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"context" |  | ||||||
| 	"fmt" |  | ||||||
| 	"log" |  | ||||||
| 	"net" |  | ||||||
| 	"strings" |  | ||||||
| 	"sync" |  | ||||||
| 	"time" |  | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" |  | ||||||
| 	"golang.org/x/net/ipv4" |  | ||||||
| 	"golang.org/x/net/ipv6" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // ServiceEntry is returned after we query for a service |  | ||||||
| type ServiceEntry struct { |  | ||||||
| 	Name       string |  | ||||||
| 	Host       string |  | ||||||
| 	AddrV4     net.IP |  | ||||||
| 	AddrV6     net.IP |  | ||||||
| 	Port       int |  | ||||||
| 	Info       string |  | ||||||
| 	InfoFields []string |  | ||||||
| 	TTL        int |  | ||||||
| 	Type       uint16 |  | ||||||
|  |  | ||||||
| 	Addr net.IP // @Deprecated |  | ||||||
|  |  | ||||||
| 	hasTXT bool |  | ||||||
| 	sent   bool |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // complete is used to check if we have all the info we need |  | ||||||
| func (s *ServiceEntry) complete() bool { |  | ||||||
|  |  | ||||||
| 	return (len(s.AddrV4) > 0 || len(s.AddrV6) > 0 || len(s.Addr) > 0) && s.Port != 0 && s.hasTXT |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // QueryParam is used to customize how a Lookup is performed |  | ||||||
| type QueryParam struct { |  | ||||||
| 	Service             string               // Service to lookup |  | ||||||
| 	Domain              string               // Lookup domain, default "local" |  | ||||||
| 	Type                uint16               // Lookup type, defaults to dns.TypePTR |  | ||||||
| 	Context             context.Context      // Context |  | ||||||
| 	Timeout             time.Duration        // Lookup timeout, default 1 second. Ignored if Context is provided |  | ||||||
| 	Interface           *net.Interface       // Multicast interface to use |  | ||||||
| 	Entries             chan<- *ServiceEntry // Entries Channel |  | ||||||
| 	WantUnicastResponse bool                 // Unicast response desired, as per 5.4 in RFC |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // DefaultParams is used to return a default set of QueryParam's |  | ||||||
| func DefaultParams(service string) *QueryParam { |  | ||||||
| 	return &QueryParam{ |  | ||||||
| 		Service:             service, |  | ||||||
| 		Domain:              "local", |  | ||||||
| 		Timeout:             time.Second, |  | ||||||
| 		Entries:             make(chan *ServiceEntry), |  | ||||||
| 		WantUnicastResponse: false, // TODO(reddaly): Change this default. |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Query looks up a given service, in a domain, waiting at most |  | ||||||
| // for a timeout before finishing the query. The results are streamed |  | ||||||
| // to a channel. Sends will not block, so clients should make sure to |  | ||||||
| // either read or buffer. |  | ||||||
| func Query(params *QueryParam) error { |  | ||||||
| 	// Create a new client |  | ||||||
| 	client, err := newClient() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	defer client.Close() |  | ||||||
|  |  | ||||||
| 	// Set the multicast interface |  | ||||||
| 	if params.Interface != nil { |  | ||||||
| 		if err := client.setInterface(params.Interface, false); err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Ensure defaults are set |  | ||||||
| 	if params.Domain == "" { |  | ||||||
| 		params.Domain = "local" |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if params.Context == nil { |  | ||||||
| 		var cancel context.CancelFunc |  | ||||||
| 		if params.Timeout == 0 { |  | ||||||
| 			params.Timeout = time.Second |  | ||||||
| 		} |  | ||||||
| 		params.Context, cancel = context.WithTimeout(context.Background(), params.Timeout) |  | ||||||
| 		defer cancel() |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Run the query |  | ||||||
| 	return client.query(params) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Listen listens indefinitely for multicast updates |  | ||||||
| func Listen(entries chan<- *ServiceEntry, exit chan struct{}) error { |  | ||||||
| 	// Create a new client |  | ||||||
| 	client, err := newClient() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	defer client.Close() |  | ||||||
|  |  | ||||||
| 	client.setInterface(nil, true) |  | ||||||
|  |  | ||||||
| 	// Start listening for response packets |  | ||||||
| 	msgCh := make(chan *dns.Msg, 32) |  | ||||||
|  |  | ||||||
| 	go client.recv(client.ipv4UnicastConn, msgCh) |  | ||||||
| 	go client.recv(client.ipv6UnicastConn, msgCh) |  | ||||||
| 	go client.recv(client.ipv4MulticastConn, msgCh) |  | ||||||
| 	go client.recv(client.ipv6MulticastConn, msgCh) |  | ||||||
|  |  | ||||||
| 	ip := make(map[string]*ServiceEntry) |  | ||||||
|  |  | ||||||
| loop: |  | ||||||
| 	for { |  | ||||||
| 		select { |  | ||||||
| 		case <-exit: |  | ||||||
| 			break loop |  | ||||||
| 		case <-client.closedCh: |  | ||||||
| 			break loop |  | ||||||
| 		case m := <-msgCh: |  | ||||||
| 			e := messageToEntry(m, ip) |  | ||||||
| 			if e == nil { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			// Check if this entry is complete |  | ||||||
| 			if e.complete() { |  | ||||||
| 				if e.sent { |  | ||||||
| 					continue |  | ||||||
| 				} |  | ||||||
| 				e.sent = true |  | ||||||
| 				entries <- e |  | ||||||
| 				ip = make(map[string]*ServiceEntry) |  | ||||||
| 			} else { |  | ||||||
| 				// Fire off a node specific query |  | ||||||
| 				m := new(dns.Msg) |  | ||||||
| 				m.SetQuestion(e.Name, dns.TypePTR) |  | ||||||
| 				m.RecursionDesired = false |  | ||||||
| 				if err := client.sendQuery(m); err != nil { |  | ||||||
| 					log.Printf("[ERR] mdns: Failed to query instance %s: %v", e.Name, err) |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Lookup is the same as Query, however it uses all the default parameters |  | ||||||
| func Lookup(service string, entries chan<- *ServiceEntry) error { |  | ||||||
| 	params := DefaultParams(service) |  | ||||||
| 	params.Entries = entries |  | ||||||
| 	return Query(params) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Client provides a query interface that can be used to |  | ||||||
| // search for service providers using mDNS |  | ||||||
| type client struct { |  | ||||||
| 	ipv4UnicastConn *net.UDPConn |  | ||||||
| 	ipv6UnicastConn *net.UDPConn |  | ||||||
|  |  | ||||||
| 	ipv4MulticastConn *net.UDPConn |  | ||||||
| 	ipv6MulticastConn *net.UDPConn |  | ||||||
|  |  | ||||||
| 	closed    bool |  | ||||||
| 	closedCh  chan struct{} // TODO(reddaly): This doesn't appear to be used. |  | ||||||
| 	closeLock sync.Mutex |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // NewClient creates a new mdns Client that can be used to query |  | ||||||
| // for records |  | ||||||
| func newClient() (*client, error) { |  | ||||||
| 	// TODO(reddaly): At least attempt to bind to the port required in the spec. |  | ||||||
| 	// Create a IPv4 listener |  | ||||||
| 	uconn4, err4 := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) |  | ||||||
| 	uconn6, err6 := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) |  | ||||||
| 	if err4 != nil && err6 != nil { |  | ||||||
| 		log.Printf("[ERR] mdns: Failed to bind to udp port: %v %v", err4, err6) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if uconn4 == nil && uconn6 == nil { |  | ||||||
| 		return nil, fmt.Errorf("failed to bind to any unicast udp port") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if uconn4 == nil { |  | ||||||
| 		uconn4 = &net.UDPConn{} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if uconn6 == nil { |  | ||||||
| 		uconn6 = &net.UDPConn{} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	mconn4, err4 := net.ListenUDP("udp4", mdnsWildcardAddrIPv4) |  | ||||||
| 	mconn6, err6 := net.ListenUDP("udp6", mdnsWildcardAddrIPv6) |  | ||||||
| 	if err4 != nil && err6 != nil { |  | ||||||
| 		log.Printf("[ERR] mdns: Failed to bind to udp port: %v %v", err4, err6) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if mconn4 == nil && mconn6 == nil { |  | ||||||
| 		return nil, fmt.Errorf("failed to bind to any multicast udp port") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if mconn4 == nil { |  | ||||||
| 		mconn4 = &net.UDPConn{} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if mconn6 == nil { |  | ||||||
| 		mconn6 = &net.UDPConn{} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	p1 := ipv4.NewPacketConn(mconn4) |  | ||||||
| 	p2 := ipv6.NewPacketConn(mconn6) |  | ||||||
| 	p1.SetMulticastLoopback(true) |  | ||||||
| 	p2.SetMulticastLoopback(true) |  | ||||||
|  |  | ||||||
| 	ifaces, err := net.Interfaces() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	var errCount1, errCount2 int |  | ||||||
|  |  | ||||||
| 	for _, iface := range ifaces { |  | ||||||
| 		if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { |  | ||||||
| 			errCount1++ |  | ||||||
| 		} |  | ||||||
| 		if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { |  | ||||||
| 			errCount2++ |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if len(ifaces) == errCount1 && len(ifaces) == errCount2 { |  | ||||||
| 		return nil, fmt.Errorf("Failed to join multicast group on all interfaces!") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	c := &client{ |  | ||||||
| 		ipv4MulticastConn: mconn4, |  | ||||||
| 		ipv6MulticastConn: mconn6, |  | ||||||
| 		ipv4UnicastConn:   uconn4, |  | ||||||
| 		ipv6UnicastConn:   uconn6, |  | ||||||
| 		closedCh:          make(chan struct{}), |  | ||||||
| 	} |  | ||||||
| 	return c, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Close is used to cleanup the client |  | ||||||
| func (c *client) Close() error { |  | ||||||
| 	c.closeLock.Lock() |  | ||||||
| 	defer c.closeLock.Unlock() |  | ||||||
|  |  | ||||||
| 	if c.closed { |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
| 	c.closed = true |  | ||||||
|  |  | ||||||
| 	close(c.closedCh) |  | ||||||
|  |  | ||||||
| 	if c.ipv4UnicastConn != nil { |  | ||||||
| 		c.ipv4UnicastConn.Close() |  | ||||||
| 	} |  | ||||||
| 	if c.ipv6UnicastConn != nil { |  | ||||||
| 		c.ipv6UnicastConn.Close() |  | ||||||
| 	} |  | ||||||
| 	if c.ipv4MulticastConn != nil { |  | ||||||
| 		c.ipv4MulticastConn.Close() |  | ||||||
| 	} |  | ||||||
| 	if c.ipv6MulticastConn != nil { |  | ||||||
| 		c.ipv6MulticastConn.Close() |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // setInterface is used to set the query interface, uses system |  | ||||||
| // default if not provided |  | ||||||
| func (c *client) setInterface(iface *net.Interface, loopback bool) error { |  | ||||||
| 	p := ipv4.NewPacketConn(c.ipv4UnicastConn) |  | ||||||
| 	if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	p2 := ipv6.NewPacketConn(c.ipv6UnicastConn) |  | ||||||
| 	if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	p = ipv4.NewPacketConn(c.ipv4MulticastConn) |  | ||||||
| 	if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	p2 = ipv6.NewPacketConn(c.ipv6MulticastConn) |  | ||||||
| 	if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if loopback { |  | ||||||
| 		p.SetMulticastLoopback(true) |  | ||||||
| 		p2.SetMulticastLoopback(true) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // query is used to perform a lookup and stream results |  | ||||||
| func (c *client) query(params *QueryParam) error { |  | ||||||
| 	// Create the service name |  | ||||||
| 	serviceAddr := fmt.Sprintf("%s.%s.", trimDot(params.Service), trimDot(params.Domain)) |  | ||||||
|  |  | ||||||
| 	// Start listening for response packets |  | ||||||
| 	msgCh := make(chan *dns.Msg, 32) |  | ||||||
| 	go c.recv(c.ipv4UnicastConn, msgCh) |  | ||||||
| 	go c.recv(c.ipv6UnicastConn, msgCh) |  | ||||||
| 	go c.recv(c.ipv4MulticastConn, msgCh) |  | ||||||
| 	go c.recv(c.ipv6MulticastConn, msgCh) |  | ||||||
|  |  | ||||||
| 	// Send the query |  | ||||||
| 	m := new(dns.Msg) |  | ||||||
| 	if params.Type == dns.TypeNone { |  | ||||||
| 		m.SetQuestion(serviceAddr, dns.TypePTR) |  | ||||||
| 	} else { |  | ||||||
| 		m.SetQuestion(serviceAddr, params.Type) |  | ||||||
| 	} |  | ||||||
| 	// RFC 6762, section 18.12.  Repurposing of Top Bit of qclass in Question |  | ||||||
| 	// Section |  | ||||||
| 	// |  | ||||||
| 	// In the Question Section of a Multicast DNS query, the top bit of the qclass |  | ||||||
| 	// field is used to indicate that unicast responses are preferred for this |  | ||||||
| 	// particular question.  (See Section 5.4.) |  | ||||||
| 	if params.WantUnicastResponse { |  | ||||||
| 		m.Question[0].Qclass |= 1 << 15 |  | ||||||
| 	} |  | ||||||
| 	m.RecursionDesired = false |  | ||||||
| 	if err := c.sendQuery(m); err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Map the in-progress responses |  | ||||||
| 	inprogress := make(map[string]*ServiceEntry) |  | ||||||
|  |  | ||||||
| 	for { |  | ||||||
| 		select { |  | ||||||
| 		case resp := <-msgCh: |  | ||||||
| 			inp := messageToEntry(resp, inprogress) |  | ||||||
|  |  | ||||||
| 			if inp == nil { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			if len(resp.Question) == 0 || resp.Question[0].Name != m.Question[0].Name { |  | ||||||
| 				// discard anything which we've not asked for |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			// Check if this entry is complete |  | ||||||
| 			if inp.complete() { |  | ||||||
| 				if inp.sent { |  | ||||||
| 					continue |  | ||||||
| 				} |  | ||||||
|  |  | ||||||
| 				inp.sent = true |  | ||||||
| 				select { |  | ||||||
| 				case params.Entries <- inp: |  | ||||||
| 				case <-params.Context.Done(): |  | ||||||
| 					return nil |  | ||||||
| 				} |  | ||||||
| 			} else { |  | ||||||
| 				// Fire off a node specific query |  | ||||||
| 				m := new(dns.Msg) |  | ||||||
| 				m.SetQuestion(inp.Name, inp.Type) |  | ||||||
| 				m.RecursionDesired = false |  | ||||||
| 				if err := c.sendQuery(m); err != nil { |  | ||||||
| 					log.Printf("[ERR] mdns: Failed to query instance %s: %v", inp.Name, err) |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		case <-params.Context.Done(): |  | ||||||
| 			return nil |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // sendQuery is used to multicast a query out |  | ||||||
| func (c *client) sendQuery(q *dns.Msg) error { |  | ||||||
| 	buf, err := q.Pack() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	if c.ipv4UnicastConn != nil { |  | ||||||
| 		c.ipv4UnicastConn.WriteToUDP(buf, ipv4Addr) |  | ||||||
| 	} |  | ||||||
| 	if c.ipv6UnicastConn != nil { |  | ||||||
| 		c.ipv6UnicastConn.WriteToUDP(buf, ipv6Addr) |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // recv is used to receive until we get a shutdown |  | ||||||
| func (c *client) recv(l *net.UDPConn, msgCh chan *dns.Msg) { |  | ||||||
| 	if l == nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	buf := make([]byte, 65536) |  | ||||||
| 	for { |  | ||||||
| 		c.closeLock.Lock() |  | ||||||
| 		if c.closed { |  | ||||||
| 			c.closeLock.Unlock() |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 		c.closeLock.Unlock() |  | ||||||
| 		n, err := l.Read(buf) |  | ||||||
| 		if err != nil { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		msg := new(dns.Msg) |  | ||||||
| 		if err := msg.Unpack(buf[:n]); err != nil { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		select { |  | ||||||
| 		case msgCh <- msg: |  | ||||||
| 		case <-c.closedCh: |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // ensureName is used to ensure the named node is in progress |  | ||||||
| func ensureName(inprogress map[string]*ServiceEntry, name string, typ uint16) *ServiceEntry { |  | ||||||
| 	if inp, ok := inprogress[name]; ok { |  | ||||||
| 		return inp |  | ||||||
| 	} |  | ||||||
| 	inp := &ServiceEntry{ |  | ||||||
| 		Name: name, |  | ||||||
| 		Type: typ, |  | ||||||
| 	} |  | ||||||
| 	inprogress[name] = inp |  | ||||||
| 	return inp |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // alias is used to setup an alias between two entries |  | ||||||
| func alias(inprogress map[string]*ServiceEntry, src, dst string, typ uint16) { |  | ||||||
| 	srcEntry := ensureName(inprogress, src, typ) |  | ||||||
| 	inprogress[dst] = srcEntry |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func messageToEntry(m *dns.Msg, inprogress map[string]*ServiceEntry) *ServiceEntry { |  | ||||||
| 	var inp *ServiceEntry |  | ||||||
|  |  | ||||||
| 	for _, answer := range append(m.Answer, m.Extra...) { |  | ||||||
| 		// TODO(reddaly): Check that response corresponds to serviceAddr? |  | ||||||
| 		switch rr := answer.(type) { |  | ||||||
| 		case *dns.PTR: |  | ||||||
| 			// Create new entry for this |  | ||||||
| 			inp = ensureName(inprogress, rr.Ptr, rr.Hdr.Rrtype) |  | ||||||
| 			if inp.complete() { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 		case *dns.SRV: |  | ||||||
| 			// Check for a target mismatch |  | ||||||
| 			if rr.Target != rr.Hdr.Name { |  | ||||||
| 				alias(inprogress, rr.Hdr.Name, rr.Target, rr.Hdr.Rrtype) |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			// Get the port |  | ||||||
| 			inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) |  | ||||||
| 			if inp.complete() { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			inp.Host = rr.Target |  | ||||||
| 			inp.Port = int(rr.Port) |  | ||||||
| 		case *dns.TXT: |  | ||||||
| 			// Pull out the txt |  | ||||||
| 			inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) |  | ||||||
| 			if inp.complete() { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			inp.Info = strings.Join(rr.Txt, "|") |  | ||||||
| 			inp.InfoFields = rr.Txt |  | ||||||
| 			inp.hasTXT = true |  | ||||||
| 		case *dns.A: |  | ||||||
| 			// Pull out the IP |  | ||||||
| 			inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) |  | ||||||
| 			if inp.complete() { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			inp.Addr = rr.A // @Deprecated |  | ||||||
| 			inp.AddrV4 = rr.A |  | ||||||
| 		case *dns.AAAA: |  | ||||||
| 			// Pull out the IP |  | ||||||
| 			inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) |  | ||||||
| 			if inp.complete() { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			inp.Addr = rr.AAAA // @Deprecated |  | ||||||
| 			inp.AddrV6 = rr.AAAA |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		if inp != nil { |  | ||||||
| 			inp.TTL = int(answer.Header().Ttl) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return inp |  | ||||||
| } |  | ||||||
| @@ -1,84 +0,0 @@ | |||||||
| package mdns |  | ||||||
|  |  | ||||||
| import "github.com/miekg/dns" |  | ||||||
|  |  | ||||||
| // DNSSDService is a service that complies with the DNS-SD (RFC 6762) and MDNS |  | ||||||
| // (RFC 6762) specs for local, multicast-DNS-based discovery. |  | ||||||
| // |  | ||||||
| // DNSSDService implements the Zone interface and wraps an MDNSService instance. |  | ||||||
| // To deploy an mDNS service that is compliant with DNS-SD, it's recommended to |  | ||||||
| // register only the wrapped instance with the server. |  | ||||||
| // |  | ||||||
| // Example usage: |  | ||||||
| //     service := &mdns.DNSSDService{ |  | ||||||
| //       MDNSService: &mdns.MDNSService{ |  | ||||||
| // 	       Instance: "My Foobar Service", |  | ||||||
| // 	       Service: "_foobar._tcp", |  | ||||||
| // 	       Port:    8000, |  | ||||||
| //        } |  | ||||||
| //      } |  | ||||||
| //      server, err := mdns.NewServer(&mdns.Config{Zone: service}) |  | ||||||
| //      if err != nil { |  | ||||||
| //        log.Fatalf("Error creating server: %v", err) |  | ||||||
| //      } |  | ||||||
| //      defer server.Shutdown() |  | ||||||
| type DNSSDService struct { |  | ||||||
| 	MDNSService *MDNSService |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Records returns DNS records in response to a DNS question. |  | ||||||
| // |  | ||||||
| // This function returns the DNS response of the underlying MDNSService |  | ||||||
| // instance.  It also returns a PTR record for a request for " |  | ||||||
| // _services._dns-sd._udp.<Domain>", as described in section 9 of RFC 6763 |  | ||||||
| // ("Service Type Enumeration"), to allow browsing of the underlying MDNSService |  | ||||||
| // instance. |  | ||||||
| func (s *DNSSDService) Records(q dns.Question) []dns.RR { |  | ||||||
| 	var recs []dns.RR |  | ||||||
| 	if q.Name == "_services._dns-sd._udp."+s.MDNSService.Domain+"." { |  | ||||||
| 		recs = s.dnssdMetaQueryRecords(q) |  | ||||||
| 	} |  | ||||||
| 	return append(recs, s.MDNSService.Records(q)...) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // dnssdMetaQueryRecords returns the DNS records in response to a "meta-query" |  | ||||||
| // issued to browse for DNS-SD services, as per section 9. of RFC6763. |  | ||||||
| // |  | ||||||
| // A meta-query has a name of the form "_services._dns-sd._udp.<Domain>" where |  | ||||||
| // Domain is a fully-qualified domain, such as "local." |  | ||||||
| func (s *DNSSDService) dnssdMetaQueryRecords(q dns.Question) []dns.RR { |  | ||||||
| 	// Intended behavior, as described in the RFC: |  | ||||||
| 	//     ...it may be useful for network administrators to find the list of |  | ||||||
| 	//     advertised service types on the network, even if those Service Names |  | ||||||
| 	//     are just opaque identifiers and not particularly informative in |  | ||||||
| 	//     isolation. |  | ||||||
| 	// |  | ||||||
| 	//     For this purpose, a special meta-query is defined.  A DNS query for PTR |  | ||||||
| 	//     records with the name "_services._dns-sd._udp.<Domain>" yields a set of |  | ||||||
| 	//     PTR records, where the rdata of each PTR record is the two-abel |  | ||||||
| 	//     <Service> name, plus the same domain, e.g., "_http._tcp.<Domain>". |  | ||||||
| 	//     Including the domain in the PTR rdata allows for slightly better name |  | ||||||
| 	//     compression in Unicast DNS responses, but only the first two labels are |  | ||||||
| 	//     relevant for the purposes of service type enumeration.  These two-label |  | ||||||
| 	//     service types can then be used to construct subsequent Service Instance |  | ||||||
| 	//     Enumeration PTR queries, in this <Domain> or others, to discover |  | ||||||
| 	//     instances of that service type. |  | ||||||
| 	return []dns.RR{ |  | ||||||
| 		&dns.PTR{ |  | ||||||
| 			Hdr: dns.RR_Header{ |  | ||||||
| 				Name:   q.Name, |  | ||||||
| 				Rrtype: dns.TypePTR, |  | ||||||
| 				Class:  dns.ClassINET, |  | ||||||
| 				Ttl:    defaultTTL, |  | ||||||
| 			}, |  | ||||||
| 			Ptr: s.MDNSService.serviceAddr, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Announcement returns DNS records that should be broadcast during the initial |  | ||||||
| // availability of the service, as described in section 8.3 of RFC 6762. |  | ||||||
| // TODO(reddaly): Add this when Announcement is added to the mdns.Zone interface. |  | ||||||
| //func (s *DNSSDService) Announcement() []dns.RR { |  | ||||||
| //	return s.MDNSService.Announcement() |  | ||||||
| //} |  | ||||||
| @@ -1,69 +0,0 @@ | |||||||
| package mdns |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"reflect" |  | ||||||
| 	"testing" |  | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| type mockMDNSService struct{} |  | ||||||
|  |  | ||||||
| func (s *mockMDNSService) Records(q dns.Question) []dns.RR { |  | ||||||
| 	return []dns.RR{ |  | ||||||
| 		&dns.PTR{ |  | ||||||
| 			Hdr: dns.RR_Header{ |  | ||||||
| 				Name:   "fakerecord", |  | ||||||
| 				Rrtype: dns.TypePTR, |  | ||||||
| 				Class:  dns.ClassINET, |  | ||||||
| 				Ttl:    42, |  | ||||||
| 			}, |  | ||||||
| 			Ptr: "fake.local.", |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (s *mockMDNSService) Announcement() []dns.RR { |  | ||||||
| 	return []dns.RR{ |  | ||||||
| 		&dns.PTR{ |  | ||||||
| 			Hdr: dns.RR_Header{ |  | ||||||
| 				Name:   "fakeannounce", |  | ||||||
| 				Rrtype: dns.TypePTR, |  | ||||||
| 				Class:  dns.ClassINET, |  | ||||||
| 				Ttl:    42, |  | ||||||
| 			}, |  | ||||||
| 			Ptr: "fake.local.", |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestDNSSDServiceRecords(t *testing.T) { |  | ||||||
| 	s := &DNSSDService{ |  | ||||||
| 		MDNSService: &MDNSService{ |  | ||||||
| 			serviceAddr: "_foobar._tcp.local.", |  | ||||||
| 			Domain:      "local", |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| 	q := dns.Question{ |  | ||||||
| 		Name:   "_services._dns-sd._udp.local.", |  | ||||||
| 		Qtype:  dns.TypePTR, |  | ||||||
| 		Qclass: dns.ClassINET, |  | ||||||
| 	} |  | ||||||
| 	recs := s.Records(q) |  | ||||||
| 	if got, want := len(recs), 1; got != want { |  | ||||||
| 		t.Fatalf("s.Records(%v) returned %v records, want %v", q, got, want) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	want := dns.RR(&dns.PTR{ |  | ||||||
| 		Hdr: dns.RR_Header{ |  | ||||||
| 			Name:   "_services._dns-sd._udp.local.", |  | ||||||
| 			Rrtype: dns.TypePTR, |  | ||||||
| 			Class:  dns.ClassINET, |  | ||||||
| 			Ttl:    defaultTTL, |  | ||||||
| 		}, |  | ||||||
| 		Ptr: "_foobar._tcp.local.", |  | ||||||
| 	}) |  | ||||||
| 	if got := recs[0]; !reflect.DeepEqual(got, want) { |  | ||||||
| 		t.Errorf("s.Records()[0] = %v, want %v", got, want) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| @@ -1,527 +0,0 @@ | |||||||
| package mdns |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"context" |  | ||||||
| 	"fmt" |  | ||||||
| 	"math/rand" |  | ||||||
| 	"net" |  | ||||||
| 	"sync" |  | ||||||
| 	"sync/atomic" |  | ||||||
| 	"time" |  | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" |  | ||||||
| 	"github.com/unistack-org/micro/v3/logger" |  | ||||||
| 	"golang.org/x/net/ipv4" |  | ||||||
| 	"golang.org/x/net/ipv6" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| var ( |  | ||||||
| 	mdnsGroupIPv4 = net.ParseIP("224.0.0.251") |  | ||||||
| 	mdnsGroupIPv6 = net.ParseIP("ff02::fb") |  | ||||||
|  |  | ||||||
| 	// mDNS wildcard addresses |  | ||||||
| 	mdnsWildcardAddrIPv4 = &net.UDPAddr{ |  | ||||||
| 		IP:   net.ParseIP("224.0.0.0"), |  | ||||||
| 		Port: 5353, |  | ||||||
| 	} |  | ||||||
| 	mdnsWildcardAddrIPv6 = &net.UDPAddr{ |  | ||||||
| 		IP:   net.ParseIP("ff02::"), |  | ||||||
| 		Port: 5353, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// mDNS endpoint addresses |  | ||||||
| 	ipv4Addr = &net.UDPAddr{ |  | ||||||
| 		IP:   mdnsGroupIPv4, |  | ||||||
| 		Port: 5353, |  | ||||||
| 	} |  | ||||||
| 	ipv6Addr = &net.UDPAddr{ |  | ||||||
| 		IP:   mdnsGroupIPv6, |  | ||||||
| 		Port: 5353, |  | ||||||
| 	} |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // GetMachineIP is a func which returns the outbound IP of this machine. |  | ||||||
| // Used by the server to determine whether to attempt send the response on a local address |  | ||||||
| type GetMachineIP func() net.IP |  | ||||||
|  |  | ||||||
| // Config is used to configure the mDNS server |  | ||||||
| type Config struct { |  | ||||||
| 	// Zone must be provided to support responding to queries |  | ||||||
| 	Zone Zone |  | ||||||
|  |  | ||||||
| 	// Iface if provided binds the multicast listener to the given |  | ||||||
| 	// interface. If not provided, the system default multicase interface |  | ||||||
| 	// is used. |  | ||||||
| 	Iface *net.Interface |  | ||||||
|  |  | ||||||
| 	// Port If it is not 0, replace the port 5353 with this port number. |  | ||||||
| 	Port int |  | ||||||
|  |  | ||||||
| 	// GetMachineIP is a function to return the IP of the local machine |  | ||||||
| 	GetMachineIP GetMachineIP |  | ||||||
| 	// LocalhostChecking if enabled asks the server to also send responses to 0.0.0.0 if the target IP |  | ||||||
| 	// is this host (as defined by GetMachineIP). Useful in case machine is on a VPN which blocks comms on non standard ports |  | ||||||
| 	LocalhostChecking bool |  | ||||||
|  |  | ||||||
| 	Context context.Context |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Server is an mDNS server used to listen for mDNS queries and respond if we |  | ||||||
| // have a matching local record |  | ||||||
| type Server struct { |  | ||||||
| 	config *Config |  | ||||||
|  |  | ||||||
| 	ipv4List *net.UDPConn |  | ||||||
| 	ipv6List *net.UDPConn |  | ||||||
|  |  | ||||||
| 	shutdown     bool |  | ||||||
| 	shutdownCh   chan struct{} |  | ||||||
| 	shutdownLock sync.Mutex |  | ||||||
| 	wg           sync.WaitGroup |  | ||||||
|  |  | ||||||
| 	outboundIP net.IP |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // NewServer is used to create a new mDNS server from a config |  | ||||||
| func NewServer(config *Config) (*Server, error) { |  | ||||||
| 	setCustomPort(config.Port) |  | ||||||
|  |  | ||||||
| 	// Create the listeners |  | ||||||
| 	// Create wildcard connections (because :5353 can be already taken by other apps) |  | ||||||
| 	ipv4List, _ := net.ListenUDP("udp4", mdnsWildcardAddrIPv4) |  | ||||||
| 	ipv6List, _ := net.ListenUDP("udp6", mdnsWildcardAddrIPv6) |  | ||||||
| 	if ipv4List == nil && ipv6List == nil { |  | ||||||
| 		return nil, fmt.Errorf("[ERR] mdns: Failed to bind to any udp port!") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if ipv4List == nil { |  | ||||||
| 		ipv4List = &net.UDPConn{} |  | ||||||
| 	} |  | ||||||
| 	if ipv6List == nil { |  | ||||||
| 		ipv6List = &net.UDPConn{} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Join multicast groups to receive announcements |  | ||||||
| 	p1 := ipv4.NewPacketConn(ipv4List) |  | ||||||
| 	p2 := ipv6.NewPacketConn(ipv6List) |  | ||||||
| 	p1.SetMulticastLoopback(true) |  | ||||||
| 	p2.SetMulticastLoopback(true) |  | ||||||
|  |  | ||||||
| 	if config.Iface != nil { |  | ||||||
| 		if err := p1.JoinGroup(config.Iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { |  | ||||||
| 			return nil, err |  | ||||||
| 		} |  | ||||||
| 		if err := p2.JoinGroup(config.Iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { |  | ||||||
| 			return nil, err |  | ||||||
| 		} |  | ||||||
| 	} else { |  | ||||||
| 		ifaces, err := net.Interfaces() |  | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, err |  | ||||||
| 		} |  | ||||||
| 		errCount1, errCount2 := 0, 0 |  | ||||||
| 		for _, iface := range ifaces { |  | ||||||
| 			if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { |  | ||||||
| 				errCount1++ |  | ||||||
| 			} |  | ||||||
| 			if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { |  | ||||||
| 				errCount2++ |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 		if len(ifaces) == errCount1 && len(ifaces) == errCount2 { |  | ||||||
| 			return nil, fmt.Errorf("Failed to join multicast group on all interfaces!") |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	ipFunc := getOutboundIP |  | ||||||
| 	if config.GetMachineIP != nil { |  | ||||||
| 		ipFunc = config.GetMachineIP |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	s := &Server{ |  | ||||||
| 		config:     config, |  | ||||||
| 		ipv4List:   ipv4List, |  | ||||||
| 		ipv6List:   ipv6List, |  | ||||||
| 		shutdownCh: make(chan struct{}), |  | ||||||
| 		outboundIP: ipFunc(), |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if s.config.Context == nil { |  | ||||||
| 		s.config.Context = context.Background() |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	go s.recv(s.ipv4List) |  | ||||||
| 	go s.recv(s.ipv6List) |  | ||||||
|  |  | ||||||
| 	s.wg.Add(1) |  | ||||||
| 	go s.probe() |  | ||||||
|  |  | ||||||
| 	return s, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Shutdown is used to shutdown the listener |  | ||||||
| func (s *Server) Shutdown() error { |  | ||||||
| 	s.shutdownLock.Lock() |  | ||||||
| 	defer s.shutdownLock.Unlock() |  | ||||||
|  |  | ||||||
| 	if s.shutdown { |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	s.shutdown = true |  | ||||||
| 	close(s.shutdownCh) |  | ||||||
| 	if err := s.unregister(); err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if s.ipv4List != nil { |  | ||||||
| 		s.ipv4List.Close() |  | ||||||
| 	} |  | ||||||
| 	if s.ipv6List != nil { |  | ||||||
| 		s.ipv6List.Close() |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	s.wg.Wait() |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // recv is a long running routine to receive packets from an interface |  | ||||||
| func (s *Server) recv(c *net.UDPConn) { |  | ||||||
| 	if c == nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	buf := make([]byte, 65536) |  | ||||||
| 	for { |  | ||||||
| 		s.shutdownLock.Lock() |  | ||||||
| 		if s.shutdown { |  | ||||||
| 			s.shutdownLock.Unlock() |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 		s.shutdownLock.Unlock() |  | ||||||
| 		n, from, err := c.ReadFrom(buf) |  | ||||||
| 		if err != nil { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		if err := s.parsePacket(buf[:n], from); err != nil { |  | ||||||
| 			logger.Errorf(s.config.Context, "[ERR] mdns: Failed to handle query: %v", err) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // parsePacket is used to parse an incoming packet |  | ||||||
| func (s *Server) parsePacket(packet []byte, from net.Addr) error { |  | ||||||
| 	var msg dns.Msg |  | ||||||
| 	if err := msg.Unpack(packet); err != nil { |  | ||||||
| 		logger.Errorf(s.config.Context, "[ERR] mdns: Failed to unpack packet: %v", err) |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	// TODO: This is a bit of a hack |  | ||||||
| 	// We decided to ignore some mDNS answers for the time being |  | ||||||
| 	// See: https://tools.ietf.org/html/rfc6762#section-7.2 |  | ||||||
| 	msg.Truncated = false |  | ||||||
| 	return s.handleQuery(&msg, from) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // handleQuery is used to handle an incoming query |  | ||||||
| func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error { |  | ||||||
| 	if query.Opcode != dns.OpcodeQuery { |  | ||||||
| 		// "In both multicast query and multicast response messages, the OPCODE MUST |  | ||||||
| 		// be zero on transmission (only standard queries are currently supported |  | ||||||
| 		// over multicast).  Multicast DNS messages received with an OPCODE other |  | ||||||
| 		// than zero MUST be silently ignored."  Note: OpcodeQuery == 0 |  | ||||||
| 		return fmt.Errorf("mdns: received query with non-zero Opcode %v: %v", query.Opcode, *query) |  | ||||||
| 	} |  | ||||||
| 	if query.Rcode != 0 { |  | ||||||
| 		// "In both multicast query and multicast response messages, the Response |  | ||||||
| 		// Code MUST be zero on transmission.  Multicast DNS messages received with |  | ||||||
| 		// non-zero Response Codes MUST be silently ignored." |  | ||||||
| 		return fmt.Errorf("mdns: received query with non-zero Rcode %v: %v", query.Rcode, *query) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// TODO(reddaly): Handle "TC (Truncated) Bit": |  | ||||||
| 	//    In query messages, if the TC bit is set, it means that additional |  | ||||||
| 	//    Known-Answer records may be following shortly.  A responder SHOULD |  | ||||||
| 	//    record this fact, and wait for those additional Known-Answer records, |  | ||||||
| 	//    before deciding whether to respond.  If the TC bit is clear, it means |  | ||||||
| 	//    that the querying host has no additional Known Answers. |  | ||||||
| 	if query.Truncated { |  | ||||||
| 		return fmt.Errorf("[ERR] mdns: support for DNS requests with high truncated bit not implemented: %v", *query) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	unicastAnswer := make([]dns.RR, 0, len(query.Question)) |  | ||||||
| 	multicastAnswer := make([]dns.RR, 0, len(query.Question)) |  | ||||||
|  |  | ||||||
| 	// Handle each question |  | ||||||
| 	for _, q := range query.Question { |  | ||||||
| 		mrecs, urecs := s.handleQuestion(q) |  | ||||||
| 		multicastAnswer = append(multicastAnswer, mrecs...) |  | ||||||
| 		unicastAnswer = append(unicastAnswer, urecs...) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// See section 18 of RFC 6762 for rules about DNS headers. |  | ||||||
| 	resp := func(unicast bool) *dns.Msg { |  | ||||||
| 		// 18.1: ID (Query Identifier) |  | ||||||
| 		// 0 for multicast response, query.Id for unicast response |  | ||||||
| 		id := uint16(0) |  | ||||||
| 		if unicast { |  | ||||||
| 			id = query.Id |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		var answer []dns.RR |  | ||||||
| 		if unicast { |  | ||||||
| 			answer = unicastAnswer |  | ||||||
| 		} else { |  | ||||||
| 			answer = multicastAnswer |  | ||||||
| 		} |  | ||||||
| 		if len(answer) == 0 { |  | ||||||
| 			return nil |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		return &dns.Msg{ |  | ||||||
| 			MsgHdr: dns.MsgHdr{ |  | ||||||
| 				Id: id, |  | ||||||
|  |  | ||||||
| 				// 18.2: QR (Query/Response) Bit - must be set to 1 in response. |  | ||||||
| 				Response: true, |  | ||||||
|  |  | ||||||
| 				// 18.3: OPCODE - must be zero in response (OpcodeQuery == 0) |  | ||||||
| 				Opcode: dns.OpcodeQuery, |  | ||||||
|  |  | ||||||
| 				// 18.4: AA (Authoritative Answer) Bit - must be set to 1 |  | ||||||
| 				Authoritative: true, |  | ||||||
|  |  | ||||||
| 				// The following fields must all be set to 0: |  | ||||||
| 				// 18.5: TC (TRUNCATED) Bit |  | ||||||
| 				// 18.6: RD (Recursion Desired) Bit |  | ||||||
| 				// 18.7: RA (Recursion Available) Bit |  | ||||||
| 				// 18.8: Z (Zero) Bit |  | ||||||
| 				// 18.9: AD (Authentic Data) Bit |  | ||||||
| 				// 18.10: CD (Checking Disabled) Bit |  | ||||||
| 				// 18.11: RCODE (Response Code) |  | ||||||
| 			}, |  | ||||||
| 			// 18.12 pertains to questions (handled by handleQuestion) |  | ||||||
| 			// 18.13 pertains to resource records (handled by handleQuestion) |  | ||||||
|  |  | ||||||
| 			// 18.14: Name Compression - responses should be compressed (though see |  | ||||||
| 			// caveats in the RFC), so set the Compress bit (part of the dns library |  | ||||||
| 			// API, not part of the DNS packet) to true. |  | ||||||
| 			Compress: true, |  | ||||||
| 			Question: query.Question, |  | ||||||
| 			Answer:   answer, |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if mresp := resp(false); mresp != nil { |  | ||||||
| 		if err := s.sendResponse(mresp, from); err != nil { |  | ||||||
| 			return fmt.Errorf("mdns: error sending multicast response: %v", err) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	if uresp := resp(true); uresp != nil { |  | ||||||
| 		if err := s.sendResponse(uresp, from); err != nil { |  | ||||||
| 			return fmt.Errorf("mdns: error sending unicast response: %v", err) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // handleQuestion is used to handle an incoming question |  | ||||||
| // |  | ||||||
| // The response to a question may be transmitted over multicast, unicast, or |  | ||||||
| // both.  The return values are DNS records for each transmission type. |  | ||||||
| func (s *Server) handleQuestion(q dns.Question) (multicastRecs, unicastRecs []dns.RR) { |  | ||||||
| 	records := s.config.Zone.Records(q) |  | ||||||
| 	if len(records) == 0 { |  | ||||||
| 		return nil, nil |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Handle unicast and multicast responses. |  | ||||||
| 	// TODO(reddaly): The decision about sending over unicast vs. multicast is not |  | ||||||
| 	// yet fully compliant with RFC 6762.  For example, the unicast bit should be |  | ||||||
| 	// ignored if the records in question are close to TTL expiration.  For now, |  | ||||||
| 	// we just use the unicast bit to make the decision, as per the spec: |  | ||||||
| 	//     RFC 6762, section 18.12.  Repurposing of Top Bit of qclass in Question |  | ||||||
| 	//     Section |  | ||||||
| 	// |  | ||||||
| 	//     In the Question Section of a Multicast DNS query, the top bit of the |  | ||||||
| 	//     qclass field is used to indicate that unicast responses are preferred |  | ||||||
| 	//     for this particular question.  (See Section 5.4.) |  | ||||||
| 	if q.Qclass&(1<<15) != 0 { |  | ||||||
| 		return nil, records |  | ||||||
| 	} |  | ||||||
| 	return records, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (s *Server) probe() { |  | ||||||
| 	defer s.wg.Done() |  | ||||||
|  |  | ||||||
| 	sd, ok := s.config.Zone.(*MDNSService) |  | ||||||
| 	if !ok { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	name := fmt.Sprintf("%s.%s.%s.", sd.Instance, trimDot(sd.Service), trimDot(sd.Domain)) |  | ||||||
|  |  | ||||||
| 	q := new(dns.Msg) |  | ||||||
| 	q.SetQuestion(name, dns.TypePTR) |  | ||||||
| 	q.RecursionDesired = false |  | ||||||
|  |  | ||||||
| 	srv := &dns.SRV{ |  | ||||||
| 		Hdr: dns.RR_Header{ |  | ||||||
| 			Name:   name, |  | ||||||
| 			Rrtype: dns.TypeSRV, |  | ||||||
| 			Class:  dns.ClassINET, |  | ||||||
| 			Ttl:    defaultTTL, |  | ||||||
| 		}, |  | ||||||
| 		Priority: 0, |  | ||||||
| 		Weight:   0, |  | ||||||
| 		Port:     uint16(sd.Port), |  | ||||||
| 		Target:   sd.HostName, |  | ||||||
| 	} |  | ||||||
| 	txt := &dns.TXT{ |  | ||||||
| 		Hdr: dns.RR_Header{ |  | ||||||
| 			Name:   name, |  | ||||||
| 			Rrtype: dns.TypeTXT, |  | ||||||
| 			Class:  dns.ClassINET, |  | ||||||
| 			Ttl:    defaultTTL, |  | ||||||
| 		}, |  | ||||||
| 		Txt: sd.TXT, |  | ||||||
| 	} |  | ||||||
| 	q.Ns = []dns.RR{srv, txt} |  | ||||||
|  |  | ||||||
| 	randomizer := rand.New(rand.NewSource(time.Now().UnixNano())) |  | ||||||
|  |  | ||||||
| 	for i := 0; i < 3; i++ { |  | ||||||
| 		if err := s.SendMulticast(q); err != nil { |  | ||||||
| 			logger.Errorf(s.config.Context, "[ERR] mdns: failed to send probe: %v", err) |  | ||||||
| 		} |  | ||||||
| 		time.Sleep(time.Duration(randomizer.Intn(250)) * time.Millisecond) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	resp := new(dns.Msg) |  | ||||||
| 	resp.MsgHdr.Response = true |  | ||||||
|  |  | ||||||
| 	// set for query |  | ||||||
| 	q.SetQuestion(name, dns.TypeANY) |  | ||||||
|  |  | ||||||
| 	resp.Answer = append(resp.Answer, s.config.Zone.Records(q.Question[0])...) |  | ||||||
|  |  | ||||||
| 	// reset |  | ||||||
| 	q.SetQuestion(name, dns.TypePTR) |  | ||||||
|  |  | ||||||
| 	// From RFC6762 |  | ||||||
| 	//    The Multicast DNS responder MUST send at least two unsolicited |  | ||||||
| 	//    responses, one second apart. To provide increased robustness against |  | ||||||
| 	//    packet loss, a responder MAY send up to eight unsolicited responses, |  | ||||||
| 	//    provided that the interval between unsolicited responses increases by |  | ||||||
| 	//    at least a factor of two with every response sent. |  | ||||||
| 	timeout := 1 * time.Second |  | ||||||
| 	timer := time.NewTimer(timeout) |  | ||||||
| 	for i := 0; i < 3; i++ { |  | ||||||
| 		if err := s.SendMulticast(resp); err != nil { |  | ||||||
| 			logger.Errorf(s.config.Context, "[ERR] mdns: failed to send announcement: %v", err) |  | ||||||
| 		} |  | ||||||
| 		select { |  | ||||||
| 		case <-timer.C: |  | ||||||
| 			timeout *= 2 |  | ||||||
| 			timer.Reset(timeout) |  | ||||||
| 		case <-s.shutdownCh: |  | ||||||
| 			timer.Stop() |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // SendMulticast us used to send a multicast response packet |  | ||||||
| func (s *Server) SendMulticast(msg *dns.Msg) error { |  | ||||||
| 	buf, err := msg.Pack() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	if s.ipv4List != nil { |  | ||||||
| 		s.ipv4List.WriteToUDP(buf, ipv4Addr) |  | ||||||
| 	} |  | ||||||
| 	if s.ipv6List != nil { |  | ||||||
| 		s.ipv6List.WriteToUDP(buf, ipv6Addr) |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // sendResponse is used to send a response packet |  | ||||||
| func (s *Server) sendResponse(resp *dns.Msg, from net.Addr) error { |  | ||||||
| 	// TODO(reddaly): Respect the unicast argument, and allow sending responses |  | ||||||
| 	// over multicast. |  | ||||||
| 	buf, err := resp.Pack() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Determine the socket to send from |  | ||||||
| 	addr := from.(*net.UDPAddr) |  | ||||||
| 	conn := s.ipv4List |  | ||||||
| 	backupTarget := net.IPv4zero |  | ||||||
|  |  | ||||||
| 	if addr.IP.To4() == nil { |  | ||||||
| 		conn = s.ipv6List |  | ||||||
| 		backupTarget = net.IPv6zero |  | ||||||
| 	} |  | ||||||
| 	_, err = conn.WriteToUDP(buf, addr) |  | ||||||
| 	// If the address we're responding to is this machine then we can also attempt sending on 0.0.0.0 |  | ||||||
| 	// This covers the case where this machine is using a VPN and certain ports are blocked so the response never gets there |  | ||||||
| 	// Sending two responses is OK |  | ||||||
| 	if s.config.LocalhostChecking && addr.IP.Equal(s.outboundIP) { |  | ||||||
| 		// ignore any errors, this is best efforts |  | ||||||
| 		conn.WriteToUDP(buf, &net.UDPAddr{IP: backupTarget, Port: addr.Port}) |  | ||||||
| 	} |  | ||||||
| 	return err |  | ||||||
|  |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (s *Server) unregister() error { |  | ||||||
| 	sd, ok := s.config.Zone.(*MDNSService) |  | ||||||
| 	if !ok { |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	atomic.StoreUint32(&sd.TTL, 0) |  | ||||||
| 	name := fmt.Sprintf("%s.%s.%s.", sd.Instance, trimDot(sd.Service), trimDot(sd.Domain)) |  | ||||||
|  |  | ||||||
| 	q := new(dns.Msg) |  | ||||||
| 	q.SetQuestion(name, dns.TypeANY) |  | ||||||
|  |  | ||||||
| 	resp := new(dns.Msg) |  | ||||||
| 	resp.MsgHdr.Response = true |  | ||||||
| 	resp.Answer = append(resp.Answer, s.config.Zone.Records(q.Question[0])...) |  | ||||||
|  |  | ||||||
| 	return s.SendMulticast(resp) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func setCustomPort(port int) { |  | ||||||
| 	if port != 0 { |  | ||||||
| 		if mdnsWildcardAddrIPv4.Port != port { |  | ||||||
| 			mdnsWildcardAddrIPv4.Port = port |  | ||||||
| 		} |  | ||||||
| 		if mdnsWildcardAddrIPv6.Port != port { |  | ||||||
| 			mdnsWildcardAddrIPv6.Port = port |  | ||||||
| 		} |  | ||||||
| 		if ipv4Addr.Port != port { |  | ||||||
| 			ipv4Addr.Port = port |  | ||||||
| 		} |  | ||||||
| 		if ipv6Addr.Port != port { |  | ||||||
| 			ipv6Addr.Port = port |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // getOutboundIP returns the IP address of this machine as seen when dialling out |  | ||||||
| func getOutboundIP() net.IP { |  | ||||||
| 	conn, err := net.Dial("udp", "8.8.8.8:80") |  | ||||||
| 	if err != nil { |  | ||||||
| 		// no net connectivity maybe so fallback |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
| 	defer conn.Close() |  | ||||||
|  |  | ||||||
| 	localAddr := conn.LocalAddr().(*net.UDPAddr) |  | ||||||
|  |  | ||||||
| 	return localAddr.IP |  | ||||||
| } |  | ||||||
| @@ -1,61 +0,0 @@ | |||||||
| package mdns |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"testing" |  | ||||||
| 	"time" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func TestServer_StartStop(t *testing.T) { |  | ||||||
| 	s := makeService(t) |  | ||||||
| 	serv, err := NewServer(&Config{Zone: s, LocalhostChecking: true}) |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatalf("err: %v", err) |  | ||||||
| 	} |  | ||||||
| 	defer serv.Shutdown() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestServer_Lookup(t *testing.T) { |  | ||||||
| 	serv, err := NewServer(&Config{Zone: makeServiceWithServiceName(t, "_foobar._tcp"), LocalhostChecking: true}) |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatalf("err: %v", err) |  | ||||||
| 	} |  | ||||||
| 	defer serv.Shutdown() |  | ||||||
|  |  | ||||||
| 	entries := make(chan *ServiceEntry, 1) |  | ||||||
| 	found := false |  | ||||||
| 	doneCh := make(chan struct{}) |  | ||||||
| 	go func() { |  | ||||||
| 		select { |  | ||||||
| 		case e := <-entries: |  | ||||||
| 			if e.Name != "hostname._foobar._tcp.local." { |  | ||||||
| 				t.Fatalf("bad: %v", e) |  | ||||||
| 			} |  | ||||||
| 			if e.Port != 80 { |  | ||||||
| 				t.Fatalf("bad: %v", e) |  | ||||||
| 			} |  | ||||||
| 			if e.Info != "Local web server" { |  | ||||||
| 				t.Fatalf("bad: %v", e) |  | ||||||
| 			} |  | ||||||
| 			found = true |  | ||||||
|  |  | ||||||
| 		case <-time.After(80 * time.Millisecond): |  | ||||||
| 			t.Fatalf("timeout") |  | ||||||
| 		} |  | ||||||
| 		close(doneCh) |  | ||||||
| 	}() |  | ||||||
|  |  | ||||||
| 	params := &QueryParam{ |  | ||||||
| 		Service: "_foobar._tcp", |  | ||||||
| 		Domain:  "local", |  | ||||||
| 		Timeout: 50 * time.Millisecond, |  | ||||||
| 		Entries: entries, |  | ||||||
| 	} |  | ||||||
| 	err = Query(params) |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatalf("err: %v", err) |  | ||||||
| 	} |  | ||||||
| 	<-doneCh |  | ||||||
| 	if !found { |  | ||||||
| 		t.Fatalf("record not found") |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| @@ -1,309 +0,0 @@ | |||||||
| package mdns |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"fmt" |  | ||||||
| 	"net" |  | ||||||
| 	"os" |  | ||||||
| 	"strings" |  | ||||||
| 	"sync/atomic" |  | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	// defaultTTL is the default TTL value in returned DNS records in seconds. |  | ||||||
| 	defaultTTL = 120 |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // Zone is the interface used to integrate with the server and |  | ||||||
| // to serve records dynamically |  | ||||||
| type Zone interface { |  | ||||||
| 	// Records returns DNS records in response to a DNS question. |  | ||||||
| 	Records(q dns.Question) []dns.RR |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // MDNSService is used to export a named service by implementing a Zone |  | ||||||
| type MDNSService struct { |  | ||||||
| 	Instance     string   // Instance name (e.g. "hostService name") |  | ||||||
| 	Service      string   // Service name (e.g. "_http._tcp.") |  | ||||||
| 	Domain       string   // If blank, assumes "local" |  | ||||||
| 	HostName     string   // Host machine DNS name (e.g. "mymachine.net.") |  | ||||||
| 	Port         int      // Service Port |  | ||||||
| 	IPs          []net.IP // IP addresses for the service's host |  | ||||||
| 	TXT          []string // Service TXT records |  | ||||||
| 	TTL          uint32 |  | ||||||
| 	serviceAddr  string // Fully qualified service address |  | ||||||
| 	instanceAddr string // Fully qualified instance address |  | ||||||
| 	enumAddr     string // _services._dns-sd._udp.<domain> |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // validateFQDN returns an error if the passed string is not a fully qualified |  | ||||||
| // hdomain name (more specifically, a hostname). |  | ||||||
| func validateFQDN(s string) error { |  | ||||||
| 	if len(s) == 0 { |  | ||||||
| 		return fmt.Errorf("FQDN must not be blank") |  | ||||||
| 	} |  | ||||||
| 	if s[len(s)-1] != '.' { |  | ||||||
| 		return fmt.Errorf("FQDN must end in period: %s", s) |  | ||||||
| 	} |  | ||||||
| 	// TODO(reddaly): Perform full validation. |  | ||||||
|  |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // NewMDNSService returns a new instance of MDNSService. |  | ||||||
| // |  | ||||||
| // If domain, hostName, or ips is set to the zero value, then a default value |  | ||||||
| // will be inferred from the operating system. |  | ||||||
| // |  | ||||||
| // TODO(reddaly): This interface may need to change to account for "unique |  | ||||||
| // record" conflict rules of the mDNS protocol.  Upon startup, the server should |  | ||||||
| // check to ensure that the instance name does not conflict with other instance |  | ||||||
| // names, and, if required, select a new name.  There may also be conflicting |  | ||||||
| // hostName A/AAAA records. |  | ||||||
| func NewMDNSService(instance, service, domain, hostName string, port int, ips []net.IP, txt []string) (*MDNSService, error) { |  | ||||||
| 	// Sanity check inputs |  | ||||||
| 	if instance == "" { |  | ||||||
| 		return nil, fmt.Errorf("missing service instance name") |  | ||||||
| 	} |  | ||||||
| 	if service == "" { |  | ||||||
| 		return nil, fmt.Errorf("missing service name") |  | ||||||
| 	} |  | ||||||
| 	if port == 0 { |  | ||||||
| 		return nil, fmt.Errorf("missing service port") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Set default domain |  | ||||||
| 	if domain == "" { |  | ||||||
| 		domain = "local." |  | ||||||
| 	} |  | ||||||
| 	if err := validateFQDN(domain); err != nil { |  | ||||||
| 		return nil, fmt.Errorf("domain %q is not a fully-qualified domain name: %v", domain, err) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Get host information if no host is specified. |  | ||||||
| 	if hostName == "" { |  | ||||||
| 		var err error |  | ||||||
| 		hostName, err = os.Hostname() |  | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, fmt.Errorf("could not determine host: %v", err) |  | ||||||
| 		} |  | ||||||
| 		hostName = fmt.Sprintf("%s.", hostName) |  | ||||||
| 	} |  | ||||||
| 	if err := validateFQDN(hostName); err != nil { |  | ||||||
| 		return nil, fmt.Errorf("hostName %q is not a fully-qualified domain name: %v", hostName, err) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if len(ips) == 0 { |  | ||||||
| 		var err error |  | ||||||
| 		ips, err = net.LookupIP(trimDot(hostName)) |  | ||||||
| 		if err != nil { |  | ||||||
| 			// Try appending the host domain suffix and lookup again |  | ||||||
| 			// (required for Linux-based hosts) |  | ||||||
| 			tmpHostName := fmt.Sprintf("%s%s", hostName, domain) |  | ||||||
|  |  | ||||||
| 			ips, err = net.LookupIP(trimDot(tmpHostName)) |  | ||||||
|  |  | ||||||
| 			if err != nil { |  | ||||||
| 				return nil, fmt.Errorf("could not determine host IP addresses for %s", hostName) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	for _, ip := range ips { |  | ||||||
| 		if ip.To4() == nil && ip.To16() == nil { |  | ||||||
| 			return nil, fmt.Errorf("invalid IP address in IPs list: %v", ip) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return &MDNSService{ |  | ||||||
| 		Instance:     instance, |  | ||||||
| 		Service:      service, |  | ||||||
| 		Domain:       domain, |  | ||||||
| 		HostName:     hostName, |  | ||||||
| 		Port:         port, |  | ||||||
| 		IPs:          ips, |  | ||||||
| 		TXT:          txt, |  | ||||||
| 		TTL:          defaultTTL, |  | ||||||
| 		serviceAddr:  fmt.Sprintf("%s.%s.", trimDot(service), trimDot(domain)), |  | ||||||
| 		instanceAddr: fmt.Sprintf("%s.%s.%s.", instance, trimDot(service), trimDot(domain)), |  | ||||||
| 		enumAddr:     fmt.Sprintf("_services._dns-sd._udp.%s.", trimDot(domain)), |  | ||||||
| 	}, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // trimDot is used to trim the dots from the start or end of a string |  | ||||||
| func trimDot(s string) string { |  | ||||||
| 	return strings.Trim(s, ".") |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Records returns DNS records in response to a DNS question. |  | ||||||
| func (m *MDNSService) Records(q dns.Question) []dns.RR { |  | ||||||
| 	switch q.Name { |  | ||||||
| 	case m.enumAddr: |  | ||||||
| 		return m.serviceEnum(q) |  | ||||||
| 	case m.serviceAddr: |  | ||||||
| 		return m.serviceRecords(q) |  | ||||||
| 	case m.instanceAddr: |  | ||||||
| 		return m.instanceRecords(q) |  | ||||||
| 	case m.HostName: |  | ||||||
| 		if q.Qtype == dns.TypeA || q.Qtype == dns.TypeAAAA { |  | ||||||
| 			return m.instanceRecords(q) |  | ||||||
| 		} |  | ||||||
| 		fallthrough |  | ||||||
| 	default: |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (m *MDNSService) serviceEnum(q dns.Question) []dns.RR { |  | ||||||
| 	switch q.Qtype { |  | ||||||
| 	case dns.TypeANY: |  | ||||||
| 		fallthrough |  | ||||||
| 	case dns.TypePTR: |  | ||||||
| 		rr := &dns.PTR{ |  | ||||||
| 			Hdr: dns.RR_Header{ |  | ||||||
| 				Name:   q.Name, |  | ||||||
| 				Rrtype: dns.TypePTR, |  | ||||||
| 				Class:  dns.ClassINET, |  | ||||||
| 				Ttl:    atomic.LoadUint32(&m.TTL), |  | ||||||
| 			}, |  | ||||||
| 			Ptr: m.serviceAddr, |  | ||||||
| 		} |  | ||||||
| 		return []dns.RR{rr} |  | ||||||
| 	default: |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // serviceRecords is called when the query matches the service name |  | ||||||
| func (m *MDNSService) serviceRecords(q dns.Question) []dns.RR { |  | ||||||
| 	switch q.Qtype { |  | ||||||
| 	case dns.TypeANY: |  | ||||||
| 		fallthrough |  | ||||||
| 	case dns.TypePTR: |  | ||||||
| 		// Build a PTR response for the service |  | ||||||
| 		rr := &dns.PTR{ |  | ||||||
| 			Hdr: dns.RR_Header{ |  | ||||||
| 				Name:   q.Name, |  | ||||||
| 				Rrtype: dns.TypePTR, |  | ||||||
| 				Class:  dns.ClassINET, |  | ||||||
| 				Ttl:    atomic.LoadUint32(&m.TTL), |  | ||||||
| 			}, |  | ||||||
| 			Ptr: m.instanceAddr, |  | ||||||
| 		} |  | ||||||
| 		servRec := []dns.RR{rr} |  | ||||||
|  |  | ||||||
| 		// Get the instance records |  | ||||||
| 		instRecs := m.instanceRecords(dns.Question{ |  | ||||||
| 			Name:  m.instanceAddr, |  | ||||||
| 			Qtype: dns.TypeANY, |  | ||||||
| 		}) |  | ||||||
|  |  | ||||||
| 		// Return the service record with the instance records |  | ||||||
| 		return append(servRec, instRecs...) |  | ||||||
| 	default: |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // serviceRecords is called when the query matches the instance name |  | ||||||
| func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR { |  | ||||||
| 	switch q.Qtype { |  | ||||||
| 	case dns.TypeANY: |  | ||||||
| 		// Get the SRV, which includes A and AAAA |  | ||||||
| 		recs := m.instanceRecords(dns.Question{ |  | ||||||
| 			Name:  m.instanceAddr, |  | ||||||
| 			Qtype: dns.TypeSRV, |  | ||||||
| 		}) |  | ||||||
|  |  | ||||||
| 		// Add the TXT record |  | ||||||
| 		recs = append(recs, m.instanceRecords(dns.Question{ |  | ||||||
| 			Name:  m.instanceAddr, |  | ||||||
| 			Qtype: dns.TypeTXT, |  | ||||||
| 		})...) |  | ||||||
| 		return recs |  | ||||||
|  |  | ||||||
| 	case dns.TypeA: |  | ||||||
| 		var rr []dns.RR |  | ||||||
| 		for _, ip := range m.IPs { |  | ||||||
| 			if ip4 := ip.To4(); ip4 != nil { |  | ||||||
| 				rr = append(rr, &dns.A{ |  | ||||||
| 					Hdr: dns.RR_Header{ |  | ||||||
| 						Name:   m.HostName, |  | ||||||
| 						Rrtype: dns.TypeA, |  | ||||||
| 						Class:  dns.ClassINET, |  | ||||||
| 						Ttl:    atomic.LoadUint32(&m.TTL), |  | ||||||
| 					}, |  | ||||||
| 					A: ip4, |  | ||||||
| 				}) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 		return rr |  | ||||||
|  |  | ||||||
| 	case dns.TypeAAAA: |  | ||||||
| 		var rr []dns.RR |  | ||||||
| 		for _, ip := range m.IPs { |  | ||||||
| 			if ip.To4() != nil { |  | ||||||
| 				// TODO(reddaly): IPv4 addresses could be encoded in IPv6 format and |  | ||||||
| 				// putinto AAAA records, but the current logic puts ipv4-encodable |  | ||||||
| 				// addresses into the A records exclusively.  Perhaps this should be |  | ||||||
| 				// configurable? |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			if ip16 := ip.To16(); ip16 != nil { |  | ||||||
| 				rr = append(rr, &dns.AAAA{ |  | ||||||
| 					Hdr: dns.RR_Header{ |  | ||||||
| 						Name:   m.HostName, |  | ||||||
| 						Rrtype: dns.TypeAAAA, |  | ||||||
| 						Class:  dns.ClassINET, |  | ||||||
| 						Ttl:    atomic.LoadUint32(&m.TTL), |  | ||||||
| 					}, |  | ||||||
| 					AAAA: ip16, |  | ||||||
| 				}) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 		return rr |  | ||||||
|  |  | ||||||
| 	case dns.TypeSRV: |  | ||||||
| 		// Create the SRV Record |  | ||||||
| 		srv := &dns.SRV{ |  | ||||||
| 			Hdr: dns.RR_Header{ |  | ||||||
| 				Name:   q.Name, |  | ||||||
| 				Rrtype: dns.TypeSRV, |  | ||||||
| 				Class:  dns.ClassINET, |  | ||||||
| 				Ttl:    atomic.LoadUint32(&m.TTL), |  | ||||||
| 			}, |  | ||||||
| 			Priority: 10, |  | ||||||
| 			Weight:   1, |  | ||||||
| 			Port:     uint16(m.Port), |  | ||||||
| 			Target:   m.HostName, |  | ||||||
| 		} |  | ||||||
| 		recs := []dns.RR{srv} |  | ||||||
|  |  | ||||||
| 		// Add the A record |  | ||||||
| 		recs = append(recs, m.instanceRecords(dns.Question{ |  | ||||||
| 			Name:  m.instanceAddr, |  | ||||||
| 			Qtype: dns.TypeA, |  | ||||||
| 		})...) |  | ||||||
|  |  | ||||||
| 		// Add the AAAA record |  | ||||||
| 		recs = append(recs, m.instanceRecords(dns.Question{ |  | ||||||
| 			Name:  m.instanceAddr, |  | ||||||
| 			Qtype: dns.TypeAAAA, |  | ||||||
| 		})...) |  | ||||||
| 		return recs |  | ||||||
|  |  | ||||||
| 	case dns.TypeTXT: |  | ||||||
| 		txt := &dns.TXT{ |  | ||||||
| 			Hdr: dns.RR_Header{ |  | ||||||
| 				Name:   q.Name, |  | ||||||
| 				Rrtype: dns.TypeTXT, |  | ||||||
| 				Class:  dns.ClassINET, |  | ||||||
| 				Ttl:    atomic.LoadUint32(&m.TTL), |  | ||||||
| 			}, |  | ||||||
| 			Txt: m.TXT, |  | ||||||
| 		} |  | ||||||
| 		return []dns.RR{txt} |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| @@ -1,275 +0,0 @@ | |||||||
| package mdns |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"bytes" |  | ||||||
| 	"net" |  | ||||||
| 	"reflect" |  | ||||||
| 	"testing" |  | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func makeService(t *testing.T) *MDNSService { |  | ||||||
| 	return makeServiceWithServiceName(t, "_http._tcp") |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func makeServiceWithServiceName(t *testing.T, service string) *MDNSService { |  | ||||||
| 	m, err := NewMDNSService( |  | ||||||
| 		"hostname", |  | ||||||
| 		service, |  | ||||||
| 		"local.", |  | ||||||
| 		"testhost.", |  | ||||||
| 		80, // port |  | ||||||
| 		[]net.IP{net.IP([]byte{192, 168, 0, 42}), net.ParseIP("2620:0:1000:1900:b0c2:d0b2:c411:18bc")}, |  | ||||||
| 		[]string{"Local web server"}) // TXT |  | ||||||
|  |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatalf("err: %v", err) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return m |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestNewMDNSService_BadParams(t *testing.T) { |  | ||||||
| 	for _, test := range []struct { |  | ||||||
| 		testName string |  | ||||||
| 		hostName string |  | ||||||
| 		domain   string |  | ||||||
| 	}{ |  | ||||||
| 		{ |  | ||||||
| 			"NewMDNSService should fail when passed hostName that is not a legal fully-qualified domain name", |  | ||||||
| 			"hostname", // not legal FQDN - should be "hostname." or "hostname.local.", etc. |  | ||||||
| 			"local.",   // legal |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			"NewMDNSService should fail when passed domain that is not a legal fully-qualified domain name", |  | ||||||
| 			"hostname.", // legal |  | ||||||
| 			"local",     // should be "local." |  | ||||||
| 		}, |  | ||||||
| 	} { |  | ||||||
| 		_, err := NewMDNSService( |  | ||||||
| 			"instance name", |  | ||||||
| 			"_http._tcp", |  | ||||||
| 			test.domain, |  | ||||||
| 			test.hostName, |  | ||||||
| 			80, // port |  | ||||||
| 			[]net.IP{net.IP([]byte{192, 168, 0, 42})}, |  | ||||||
| 			[]string{"Local web server"}) // TXT |  | ||||||
| 		if err == nil { |  | ||||||
| 			t.Fatalf("%s: error expected, but got none", test.testName) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestMDNSService_BadAddr(t *testing.T) { |  | ||||||
| 	s := makeService(t) |  | ||||||
| 	q := dns.Question{ |  | ||||||
| 		Name:  "random", |  | ||||||
| 		Qtype: dns.TypeANY, |  | ||||||
| 	} |  | ||||||
| 	recs := s.Records(q) |  | ||||||
| 	if len(recs) != 0 { |  | ||||||
| 		t.Fatalf("bad: %v", recs) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestMDNSService_ServiceAddr(t *testing.T) { |  | ||||||
| 	s := makeService(t) |  | ||||||
| 	q := dns.Question{ |  | ||||||
| 		Name:  "_http._tcp.local.", |  | ||||||
| 		Qtype: dns.TypeANY, |  | ||||||
| 	} |  | ||||||
| 	recs := s.Records(q) |  | ||||||
| 	if got, want := len(recs), 5; got != want { |  | ||||||
| 		t.Fatalf("got %d records, want %d: %v", got, want, recs) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if ptr, ok := recs[0].(*dns.PTR); !ok { |  | ||||||
| 		t.Errorf("recs[0] should be PTR record, got: %v, all records: %v", recs[0], recs) |  | ||||||
| 	} else if got, want := ptr.Ptr, "hostname._http._tcp.local."; got != want { |  | ||||||
| 		t.Fatalf("bad PTR record %v: got %v, want %v", ptr, got, want) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if _, ok := recs[1].(*dns.SRV); !ok { |  | ||||||
| 		t.Errorf("recs[1] should be SRV record, got: %v, all reccords: %v", recs[1], recs) |  | ||||||
| 	} |  | ||||||
| 	if _, ok := recs[2].(*dns.A); !ok { |  | ||||||
| 		t.Errorf("recs[2] should be A record, got: %v, all records: %v", recs[2], recs) |  | ||||||
| 	} |  | ||||||
| 	if _, ok := recs[3].(*dns.AAAA); !ok { |  | ||||||
| 		t.Errorf("recs[3] should be AAAA record, got: %v, all records: %v", recs[3], recs) |  | ||||||
| 	} |  | ||||||
| 	if _, ok := recs[4].(*dns.TXT); !ok { |  | ||||||
| 		t.Errorf("recs[4] should be TXT record, got: %v, all records: %v", recs[4], recs) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	q.Qtype = dns.TypePTR |  | ||||||
| 	if recs2 := s.Records(q); !reflect.DeepEqual(recs, recs2) { |  | ||||||
| 		t.Fatalf("PTR question should return same result as ANY question: ANY => %v, PTR => %v", recs, recs2) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestMDNSService_InstanceAddr_ANY(t *testing.T) { |  | ||||||
| 	s := makeService(t) |  | ||||||
| 	q := dns.Question{ |  | ||||||
| 		Name:  "hostname._http._tcp.local.", |  | ||||||
| 		Qtype: dns.TypeANY, |  | ||||||
| 	} |  | ||||||
| 	recs := s.Records(q) |  | ||||||
| 	if len(recs) != 4 { |  | ||||||
| 		t.Fatalf("bad: %v", recs) |  | ||||||
| 	} |  | ||||||
| 	if _, ok := recs[0].(*dns.SRV); !ok { |  | ||||||
| 		t.Fatalf("bad: %v", recs[0]) |  | ||||||
| 	} |  | ||||||
| 	if _, ok := recs[1].(*dns.A); !ok { |  | ||||||
| 		t.Fatalf("bad: %v", recs[1]) |  | ||||||
| 	} |  | ||||||
| 	if _, ok := recs[2].(*dns.AAAA); !ok { |  | ||||||
| 		t.Fatalf("bad: %v", recs[2]) |  | ||||||
| 	} |  | ||||||
| 	if _, ok := recs[3].(*dns.TXT); !ok { |  | ||||||
| 		t.Fatalf("bad: %v", recs[3]) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestMDNSService_InstanceAddr_SRV(t *testing.T) { |  | ||||||
| 	s := makeService(t) |  | ||||||
| 	q := dns.Question{ |  | ||||||
| 		Name:  "hostname._http._tcp.local.", |  | ||||||
| 		Qtype: dns.TypeSRV, |  | ||||||
| 	} |  | ||||||
| 	recs := s.Records(q) |  | ||||||
| 	if len(recs) != 3 { |  | ||||||
| 		t.Fatalf("bad: %v", recs) |  | ||||||
| 	} |  | ||||||
| 	srv, ok := recs[0].(*dns.SRV) |  | ||||||
| 	if !ok { |  | ||||||
| 		t.Fatalf("bad: %v", recs[0]) |  | ||||||
| 	} |  | ||||||
| 	if _, ok := recs[1].(*dns.A); !ok { |  | ||||||
| 		t.Fatalf("bad: %v", recs[1]) |  | ||||||
| 	} |  | ||||||
| 	if _, ok := recs[2].(*dns.AAAA); !ok { |  | ||||||
| 		t.Fatalf("bad: %v", recs[2]) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if srv.Port != uint16(s.Port) { |  | ||||||
| 		t.Fatalf("bad: %v", recs[0]) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestMDNSService_InstanceAddr_A(t *testing.T) { |  | ||||||
| 	s := makeService(t) |  | ||||||
| 	q := dns.Question{ |  | ||||||
| 		Name:  "hostname._http._tcp.local.", |  | ||||||
| 		Qtype: dns.TypeA, |  | ||||||
| 	} |  | ||||||
| 	recs := s.Records(q) |  | ||||||
| 	if len(recs) != 1 { |  | ||||||
| 		t.Fatalf("bad: %v", recs) |  | ||||||
| 	} |  | ||||||
| 	a, ok := recs[0].(*dns.A) |  | ||||||
| 	if !ok { |  | ||||||
| 		t.Fatalf("bad: %v", recs[0]) |  | ||||||
| 	} |  | ||||||
| 	if !bytes.Equal(a.A, []byte{192, 168, 0, 42}) { |  | ||||||
| 		t.Fatalf("bad: %v", recs[0]) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestMDNSService_InstanceAddr_AAAA(t *testing.T) { |  | ||||||
| 	s := makeService(t) |  | ||||||
| 	q := dns.Question{ |  | ||||||
| 		Name:  "hostname._http._tcp.local.", |  | ||||||
| 		Qtype: dns.TypeAAAA, |  | ||||||
| 	} |  | ||||||
| 	recs := s.Records(q) |  | ||||||
| 	if len(recs) != 1 { |  | ||||||
| 		t.Fatalf("bad: %v", recs) |  | ||||||
| 	} |  | ||||||
| 	a4, ok := recs[0].(*dns.AAAA) |  | ||||||
| 	if !ok { |  | ||||||
| 		t.Fatalf("bad: %v", recs[0]) |  | ||||||
| 	} |  | ||||||
| 	ip6 := net.ParseIP("2620:0:1000:1900:b0c2:d0b2:c411:18bc") |  | ||||||
| 	if got := len(ip6); got != net.IPv6len { |  | ||||||
| 		t.Fatalf("test IP failed to parse (len = %d, want %d)", got, net.IPv6len) |  | ||||||
| 	} |  | ||||||
| 	if !bytes.Equal(a4.AAAA, ip6) { |  | ||||||
| 		t.Fatalf("bad: %v", recs[0]) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestMDNSService_InstanceAddr_TXT(t *testing.T) { |  | ||||||
| 	s := makeService(t) |  | ||||||
| 	q := dns.Question{ |  | ||||||
| 		Name:  "hostname._http._tcp.local.", |  | ||||||
| 		Qtype: dns.TypeTXT, |  | ||||||
| 	} |  | ||||||
| 	recs := s.Records(q) |  | ||||||
| 	if len(recs) != 1 { |  | ||||||
| 		t.Fatalf("bad: %v", recs) |  | ||||||
| 	} |  | ||||||
| 	txt, ok := recs[0].(*dns.TXT) |  | ||||||
| 	if !ok { |  | ||||||
| 		t.Fatalf("bad: %v", recs[0]) |  | ||||||
| 	} |  | ||||||
| 	if got, want := txt.Txt, s.TXT; !reflect.DeepEqual(got, want) { |  | ||||||
| 		t.Fatalf("TXT record mismatch for %v: got %v, want %v", recs[0], got, want) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestMDNSService_HostNameQuery(t *testing.T) { |  | ||||||
| 	s := makeService(t) |  | ||||||
| 	for _, test := range []struct { |  | ||||||
| 		q    dns.Question |  | ||||||
| 		want []dns.RR |  | ||||||
| 	}{ |  | ||||||
| 		{ |  | ||||||
| 			dns.Question{Name: "testhost.", Qtype: dns.TypeA}, |  | ||||||
| 			[]dns.RR{&dns.A{ |  | ||||||
| 				Hdr: dns.RR_Header{ |  | ||||||
| 					Name:   "testhost.", |  | ||||||
| 					Rrtype: dns.TypeA, |  | ||||||
| 					Class:  dns.ClassINET, |  | ||||||
| 					Ttl:    120, |  | ||||||
| 				}, |  | ||||||
| 				A: net.IP([]byte{192, 168, 0, 42}), |  | ||||||
| 			}}, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			dns.Question{Name: "testhost.", Qtype: dns.TypeAAAA}, |  | ||||||
| 			[]dns.RR{&dns.AAAA{ |  | ||||||
| 				Hdr: dns.RR_Header{ |  | ||||||
| 					Name:   "testhost.", |  | ||||||
| 					Rrtype: dns.TypeAAAA, |  | ||||||
| 					Class:  dns.ClassINET, |  | ||||||
| 					Ttl:    120, |  | ||||||
| 				}, |  | ||||||
| 				AAAA: net.ParseIP("2620:0:1000:1900:b0c2:d0b2:c411:18bc"), |  | ||||||
| 			}}, |  | ||||||
| 		}, |  | ||||||
| 	} { |  | ||||||
| 		if got := s.Records(test.q); !reflect.DeepEqual(got, test.want) { |  | ||||||
| 			t.Errorf("hostname query failed: s.Records(%v) = %v, want %v", test.q, got, test.want) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestMDNSService_serviceEnum_PTR(t *testing.T) { |  | ||||||
| 	s := makeService(t) |  | ||||||
| 	q := dns.Question{ |  | ||||||
| 		Name:  "_services._dns-sd._udp.local.", |  | ||||||
| 		Qtype: dns.TypePTR, |  | ||||||
| 	} |  | ||||||
| 	recs := s.Records(q) |  | ||||||
| 	if len(recs) != 1 { |  | ||||||
| 		t.Fatalf("bad: %v", recs) |  | ||||||
| 	} |  | ||||||
| 	if ptr, ok := recs[0].(*dns.PTR); !ok { |  | ||||||
| 		t.Errorf("recs[0] should be PTR record, got: %v, all records: %v", recs[0], recs) |  | ||||||
| 	} else if got, want := ptr.Ptr, "_http._tcp.local."; got != want { |  | ||||||
| 		t.Fatalf("bad PTR record %v: got %v, want %v", ptr, got, want) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
		Reference in New Issue
	
	Block a user