504 lines
12 KiB
Go
504 lines
12 KiB
Go
package mdns
|
|
|
|
import (
	"context"
	"fmt"
	"net"
	"strings"
	"sync"

	"go.unistack.org/micro/v3/logger"
	"golang.org/x/net/dns/dnsmessage"
	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
)
|
|
|
|
// ServiceEntry describes a service instance discovered over mDNS. Query and
// Listen only hand an entry to the caller once complete reports true for it.
type ServiceEntry struct {
	Name string
	Host string

	AddrV4, AddrV6 net.IP

	Port int

	Info       string
	InfoFields []string

	TTL  int
	Type uint16

	// hasTXT records whether TXT data has been seen; complete requires it.
	hasTXT bool
	// sent records whether the entry was already delivered, so it is not
	// delivered twice.
	sent bool
}
|
|
|
|
// complete is used to check if we have all the info we need
|
|
func (s *ServiceEntry) complete() bool {
|
|
return (s.AddrV4 != nil || s.AddrV6 != nil) && s.Port != 0 && s.hasTXT
|
|
}
|
|
|
|
// QueryParam is used to customize how a Lookup is performed
|
|
type QueryParam struct {
|
|
Service string // Service to lookup
|
|
Domain string // Lookup domain, default "local"
|
|
Type dnsmessage.Type // Lookup type, defaults to dns.TypePTR
|
|
Interface *net.Interface // Multicast interface to use
|
|
Entries chan<- *ServiceEntry // Entries Channel
|
|
WantUnicastResponse bool // Unicast response desired, as per 5.4 in RFC
|
|
}
|
|
|
|
// DefaultParams is used to return a default set of QueryParam's
|
|
func DefaultParams(service string) *QueryParam {
|
|
qp := &QueryParam{
|
|
Service: service,
|
|
Domain: "local",
|
|
Entries: make(chan *ServiceEntry),
|
|
WantUnicastResponse: false,
|
|
}
|
|
return qp
|
|
}
|
|
|
|
// Query looks up a given service, in a domain, waiting at most
|
|
// for a timeout before finishing the query. The results are streamed
|
|
// to a channel. Sends will not block, so clients should make sure to
|
|
// either read or buffer.
|
|
func Query(ctx context.Context, params *QueryParam) error {
|
|
// Create a new client
|
|
client, err := newClient()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer client.Close()
|
|
|
|
// Set the multicast interface
|
|
if params.Interface != nil {
|
|
if err := client.setInterface(params.Interface, false); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Ensure defaults are set
|
|
if params.Domain == "" {
|
|
params.Domain = "local"
|
|
}
|
|
|
|
// Run the query
|
|
return client.query(ctx, params)
|
|
}
|
|
|
|
// Listen listens indefinitely for multicast updates
|
|
func Listen(entries chan<- *ServiceEntry, exit chan struct{}) error {
|
|
// Create a new client
|
|
client, err := newClient()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer client.Close()
|
|
|
|
client.setInterface(nil, true)
|
|
|
|
// Start listening for response packets
|
|
msgCh := make(chan []byte, 32)
|
|
|
|
go client.recv(client.ipv4UnicastConn, msgCh)
|
|
go client.recv(client.ipv6UnicastConn, msgCh)
|
|
go client.recv(client.ipv4MulticastConn, msgCh)
|
|
go client.recv(client.ipv6MulticastConn, msgCh)
|
|
|
|
sentry := make(map[string]*ServiceEntry)
|
|
|
|
for {
|
|
select {
|
|
case <-exit:
|
|
return nil
|
|
case <-client.closedCh:
|
|
return nil
|
|
case msg := <-msgCh:
|
|
fmt.Printf("%#+v\n", msg)
|
|
entry := messageToEntry(msg, sentry)
|
|
if entry == nil {
|
|
continue
|
|
}
|
|
|
|
// Check if this entry is complete
|
|
if entry.complete() {
|
|
if entry.sent {
|
|
continue
|
|
}
|
|
entry.sent = true
|
|
entries <- entry
|
|
sentry = make(map[string]*ServiceEntry)
|
|
} else {
|
|
// Fire off a node specific query
|
|
/*
|
|
h:
|
|
-&dnsmessage.Header{RecursionDesired: false}
|
|
m := dnsmessage.NewBuilder()
|
|
m.SetQuestion(e.Name, dns.TypePTR)
|
|
if err := client.sendQuery(m); err != nil {
|
|
logger.Errorf("[ERR] mdns: Failed to query instance %s: %v", e.Name, err)
|
|
}
|
|
*/
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
/*
|
|
// Lookup is the same as Query, however it uses all the default parameters
|
|
func Lookup(service string, entries chan<- *ServiceEntry) error {
|
|
params := DefaultParams(service)
|
|
params.Entries = entries
|
|
return Query(params)
|
|
}
|
|
*/
|
|
// client provides a query interface that can be used to search for service
// providers using mDNS.
type client struct {
	// Unicast sockets; sendQuery writes queries through these.
	ipv4UnicastConn, ipv6UnicastConn *net.UDPConn

	// Sockets bound to the mDNS wildcard addresses that join the groups.
	ipv4MulticastConn, ipv6MulticastConn *net.UDPConn

	// closed and closedCh are flipped together by Close while holding
	// closeLock; receive loops watch closedCh to know when to stop.
	closed    bool
	closedCh  chan struct{}
	closeLock sync.RWMutex
}
|
|
|
|
// NewClient creates a new mdns Client that can be used to query
|
|
// for records
|
|
func newClient() (*client, error) {
|
|
// TODO(reddaly): At least attempt to bind to the port required in the spec.
|
|
// Create a IPv4 listener
|
|
uconn4, err4 := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
|
uconn6, err6 := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
|
if err4 != nil && err6 != nil {
|
|
logger.Errorf(context.TODO(), "[ERR] mdns: Failed to bind to udp port: %v %v", err4, err6)
|
|
}
|
|
|
|
if uconn4 == nil && uconn6 == nil {
|
|
return nil, fmt.Errorf("failed to bind to any unicast udp port")
|
|
}
|
|
|
|
if uconn4 == nil {
|
|
uconn4 = &net.UDPConn{}
|
|
}
|
|
|
|
if uconn6 == nil {
|
|
uconn6 = &net.UDPConn{}
|
|
}
|
|
|
|
mconn4, err4 := net.ListenUDP("udp4", mdnsWildcardAddrIPv4)
|
|
mconn6, err6 := net.ListenUDP("udp6", mdnsWildcardAddrIPv6)
|
|
if err4 != nil && err6 != nil {
|
|
logger.Errorf(context.TODO(), "[ERR] mdns: Failed to bind to udp port: %v %v", err4, err6)
|
|
}
|
|
|
|
if mconn4 == nil && mconn6 == nil {
|
|
return nil, fmt.Errorf("failed to bind to any multicast udp port")
|
|
}
|
|
|
|
if mconn4 == nil {
|
|
mconn4 = &net.UDPConn{}
|
|
}
|
|
|
|
if mconn6 == nil {
|
|
mconn6 = &net.UDPConn{}
|
|
}
|
|
|
|
p1 := ipv4.NewPacketConn(mconn4)
|
|
p2 := ipv6.NewPacketConn(mconn6)
|
|
p1.SetMulticastLoopback(true)
|
|
p2.SetMulticastLoopback(true)
|
|
|
|
ifaces, err := net.Interfaces()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var errCount1, errCount2 int
|
|
|
|
for _, iface := range ifaces {
|
|
if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
|
|
errCount1++
|
|
}
|
|
if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
|
|
errCount2++
|
|
}
|
|
}
|
|
|
|
if len(ifaces) == errCount1 && len(ifaces) == errCount2 {
|
|
return nil, fmt.Errorf("Failed to join multicast group on all interfaces!")
|
|
}
|
|
|
|
c := &client{
|
|
ipv4MulticastConn: mconn4,
|
|
ipv6MulticastConn: mconn6,
|
|
ipv4UnicastConn: uconn4,
|
|
ipv6UnicastConn: uconn6,
|
|
closedCh: make(chan struct{}),
|
|
}
|
|
return c, nil
|
|
}
|
|
|
|
// Close is used to cleanup the client
|
|
func (c *client) Close() error {
|
|
c.closeLock.Lock()
|
|
defer c.closeLock.Unlock()
|
|
|
|
if c.closed {
|
|
return nil
|
|
}
|
|
c.closed = true
|
|
|
|
close(c.closedCh)
|
|
|
|
if c.ipv4UnicastConn != nil {
|
|
c.ipv4UnicastConn.Close()
|
|
}
|
|
if c.ipv6UnicastConn != nil {
|
|
c.ipv6UnicastConn.Close()
|
|
}
|
|
if c.ipv4MulticastConn != nil {
|
|
c.ipv4MulticastConn.Close()
|
|
}
|
|
if c.ipv6MulticastConn != nil {
|
|
c.ipv6MulticastConn.Close()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// setInterface is used to set the query interface, uses sytem
|
|
// default if not provided
|
|
func (c *client) setInterface(iface *net.Interface, loopback bool) error {
|
|
p := ipv4.NewPacketConn(c.ipv4UnicastConn)
|
|
if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
|
|
return err
|
|
}
|
|
p2 := ipv6.NewPacketConn(c.ipv6UnicastConn)
|
|
if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
|
|
return err
|
|
}
|
|
p = ipv4.NewPacketConn(c.ipv4MulticastConn)
|
|
if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
|
|
return err
|
|
}
|
|
p2 = ipv6.NewPacketConn(c.ipv6MulticastConn)
|
|
if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
|
|
return err
|
|
}
|
|
|
|
if loopback {
|
|
p.SetMulticastLoopback(true)
|
|
p2.SetMulticastLoopback(true)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// query is used to perform a lookup and stream results
|
|
func (c *client) query(ctx context.Context, params *QueryParam) error {
|
|
// Create the service name
|
|
serviceAddr := fmt.Sprintf("%s.%s.", trimDot(params.Service), trimDot(params.Domain))
|
|
|
|
// Start listening for response packets
|
|
msgCh := make(chan []byte, 32)
|
|
go c.recv(c.ipv4UnicastConn, msgCh)
|
|
go c.recv(c.ipv6UnicastConn, msgCh)
|
|
go c.recv(c.ipv4MulticastConn, msgCh)
|
|
go c.recv(c.ipv6MulticastConn, msgCh)
|
|
|
|
// buf := make([]byte, 2, 514)
|
|
hdr := dnsmessage.Header{RecursionDesired: false}
|
|
b := dnsmessage.NewBuilder(nil, hdr)
|
|
// b.EnableCompression()
|
|
name, err := dnsmessage.NewName(serviceAddr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
q := dnsmessage.Question{Name: name, Class: dnsmessage.ClassINET}
|
|
if params.Type == 0 {
|
|
q.Type = dnsmessage.TypePTR
|
|
} else {
|
|
q.Type = params.Type
|
|
}
|
|
// q.Class |= 1 << 15
|
|
if err = b.StartQuestions(); err != nil {
|
|
return err
|
|
}
|
|
if err = b.Question(q); err != nil {
|
|
return err
|
|
}
|
|
bbuf, err := b.Finish()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Send the query
|
|
// RFC 6762, section 18.12. Repurposing of Top Bit of qclass in Question
|
|
// Section
|
|
//
|
|
// In the Question Section of a Multicast DNS query, the top bit of the qclass
|
|
// field is used to indicate that unicast responses are preferred for this
|
|
// particular question. (See Section 5.4.)
|
|
if err := c.sendQuery(bbuf); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Map the in-progress responses
|
|
inprogress := make(map[string]*ServiceEntry)
|
|
|
|
for {
|
|
select {
|
|
case rsp := <-msgCh:
|
|
inp := messageToEntry(rsp, inprogress)
|
|
if inp == nil {
|
|
continue
|
|
}
|
|
|
|
// Check if this entry is complete
|
|
if inp.complete() {
|
|
if inp.sent {
|
|
continue
|
|
}
|
|
inp.sent = true
|
|
select {
|
|
case params.Entries <- inp:
|
|
case <-ctx.Done():
|
|
return nil
|
|
}
|
|
} else {
|
|
// Fire off a node specific query
|
|
// m := new(dns.Msg)
|
|
// m.SetQuestion(inp.Name, inp.Type)
|
|
// m.RecursionDesired = false
|
|
var buf []byte
|
|
if err := c.sendQuery(buf); err != nil {
|
|
logger.Errorf(context.TODO(), "[ERR] mdns: Failed to query instance %s: %v", inp.Name, err)
|
|
}
|
|
}
|
|
case <-ctx.Done():
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
// sendQuery is used to multicast a query out
|
|
func (c *client) sendQuery(buf []byte) error {
|
|
if c.ipv4UnicastConn != nil {
|
|
c.ipv4UnicastConn.WriteToUDP(buf, ipv4Addr)
|
|
}
|
|
if c.ipv6UnicastConn != nil {
|
|
c.ipv6UnicastConn.WriteToUDP(buf, ipv6Addr)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// recv is used to receive until we get a shutdown
|
|
func (c *client) recv(l *net.UDPConn, msgCh chan []byte) {
|
|
if l == nil {
|
|
return
|
|
}
|
|
buf := make([]byte, 65536)
|
|
for {
|
|
select {
|
|
case <-c.closedCh:
|
|
return
|
|
default:
|
|
c.closeLock.Lock()
|
|
if c.closed {
|
|
c.closeLock.Unlock()
|
|
return
|
|
}
|
|
c.closeLock.Unlock()
|
|
n, err := l.Read(buf)
|
|
if err != nil {
|
|
if logger.V(logger.DebugLevel) {
|
|
logger.Debug(context.TODO(), err)
|
|
}
|
|
continue
|
|
}
|
|
msgCh <- buf[:n]
|
|
}
|
|
}
|
|
}
|
|
|
|
/*
|
|
// ensureName is used to ensure the named node is in progress
|
|
func ensureName(inprogress map[string]*ServiceEntry, name string, typ uint16) *ServiceEntry {
|
|
if inp, ok := inprogress[name]; ok {
|
|
return inp
|
|
}
|
|
inp := &ServiceEntry{
|
|
Name: name,
|
|
Type: typ,
|
|
}
|
|
inprogress[name] = inp
|
|
return inp
|
|
}
|
|
|
|
// alias is used to setup an alias between two entries
|
|
func alias(inprogress map[string]*ServiceEntry, src, dst string, typ uint16) {
|
|
srcEntry := ensureName(inprogress, src, typ)
|
|
inprogress[dst] = srcEntry
|
|
}
|
|
*/
|
|
func messageToEntry(m []byte, inprogress map[string]*ServiceEntry) *ServiceEntry {
|
|
var inp *ServiceEntry
|
|
/*
|
|
for _, answer := range append(m.Answers, m.Additionals...) {
|
|
// TODO(reddaly): Check that response corresponds to serviceAddr?
|
|
switch answer.Header.Type {
|
|
case dnsmessage.TypePTR:
|
|
rr := answer.Body.(*dnsmessage.PTRResource)
|
|
// Create new entry for this
|
|
inp = ensureName(inprogress, rr.Ptr, rr.Hdr.Rrtype)
|
|
if inp.complete() {
|
|
continue
|
|
}
|
|
case dnsmessage.TypeSRV:
|
|
// Check for a target mismatch
|
|
if rr.Target != rr.Hdr.Name {
|
|
alias(inprogress, rr.Hdr.Name, rr.Target, rr.Hdr.Rrtype)
|
|
}
|
|
|
|
// Get the port
|
|
inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype)
|
|
if inp.complete() {
|
|
continue
|
|
}
|
|
inp.Host = rr.Target
|
|
inp.Port = int(rr.Port)
|
|
case dnsmessage.TypeTXT:
|
|
// Pull out the txt
|
|
inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype)
|
|
if inp.complete() {
|
|
continue
|
|
}
|
|
inp.Info = strings.Join(rr.Txt, "|")
|
|
inp.InfoFields = rr.Txt
|
|
inp.hasTXT = true
|
|
case dnsmessage.TypeA:
|
|
// Pull out the IP
|
|
inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype)
|
|
if inp.complete() {
|
|
continue
|
|
}
|
|
inp.AddrV4 = rr.A
|
|
case dnsmessage.TypeAAAA:
|
|
// Pull out the IP
|
|
inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype)
|
|
if inp.complete() {
|
|
continue
|
|
}
|
|
inp.AddrV6 = rr.AAAA
|
|
}
|
|
|
|
if inp != nil {
|
|
inp.TTL = int(answer.Header().Ttl)
|
|
}
|
|
}
|
|
*/
|
|
return inp
|
|
}
|