Add the probing, announcements and shutdown. Add client Listen which can be used for a watcher

This commit is contained in:
Asim 2016-05-01 18:55:42 +01:00
parent 2963a9d96a
commit f4bad1caf6
3 changed files with 318 additions and 95 deletions

192
client.go
View File

@ -22,6 +22,7 @@ type ServiceEntry struct {
Port int Port int
Info string Info string
InfoFields []string InfoFields []string
TTL int
Addr net.IP // @Deprecated Addr net.IP // @Deprecated
@ -69,7 +70,7 @@ func Query(params *QueryParam) error {
// Set the multicast interface // Set the multicast interface
if params.Interface != nil { if params.Interface != nil {
if err := client.setInterface(params.Interface); err != nil { if err := client.setInterface(params.Interface, false); err != nil {
return err return err
} }
} }
@ -86,6 +87,62 @@ func Query(params *QueryParam) error {
return client.query(params) return client.query(params)
} }
// Listen listens indefinitely for multicast updates
func Listen(entries chan<- *ServiceEntry, exit chan struct{}) error {
// Create a new client
client, err := newClient()
if err != nil {
return err
}
defer client.Close()
client.setInterface(nil, true)
// Start listening for response packets
msgCh := make(chan *dns.Msg, 32)
go client.recv(client.ipv4MulticastConn, msgCh)
go client.recv(client.ipv6MulticastConn, msgCh)
go client.recv(client.ipv4MulticastConn, msgCh)
go client.recv(client.ipv6MulticastConn, msgCh)
ip := make(map[string]*ServiceEntry)
for {
select {
case <-exit:
return nil
case <-client.closedCh:
return nil
case m := <-msgCh:
e := messageToEntry(m, ip)
if e == nil {
continue
}
// Check if this entry is complete
if e.complete() {
if e.sent {
continue
}
e.sent = true
entries <- e
ip = make(map[string]*ServiceEntry)
} else {
// Fire off a node specific query
m := new(dns.Msg)
m.SetQuestion(e.Name, dns.TypePTR)
m.RecursionDesired = false
if err := client.sendQuery(m); err != nil {
log.Printf("[ERR] mdns: Failed to query instance %s: %v", e.Name, err)
}
}
}
}
return nil
}
// Lookup is the same as Query, however it uses all the default parameters // Lookup is the same as Query, however it uses all the default parameters
func Lookup(service string, entries chan<- *ServiceEntry) error { func Lookup(service string, entries chan<- *ServiceEntry) error {
params := DefaultParams(service) params := DefaultParams(service)
@ -178,23 +235,29 @@ func (c *client) Close() error {
// setInterface is used to set the query interface, uses sytem // setInterface is used to set the query interface, uses sytem
// default if not provided // default if not provided
func (c *client) setInterface(iface *net.Interface) error { func (c *client) setInterface(iface *net.Interface, loopback bool) error {
p := ipv4.NewPacketConn(c.ipv4UnicastConn) p := ipv4.NewPacketConn(c.ipv4UnicastConn)
if err := p.SetMulticastInterface(iface); err != nil { if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
return err return err
} }
p2 := ipv6.NewPacketConn(c.ipv6UnicastConn) p2 := ipv6.NewPacketConn(c.ipv6UnicastConn)
if err := p2.SetMulticastInterface(iface); err != nil { if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
return err return err
} }
p = ipv4.NewPacketConn(c.ipv4MulticastConn) p = ipv4.NewPacketConn(c.ipv4MulticastConn)
if err := p.SetMulticastInterface(iface); err != nil { if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
return err return err
} }
p2 = ipv6.NewPacketConn(c.ipv6MulticastConn) p2 = ipv6.NewPacketConn(c.ipv6MulticastConn)
if err := p2.SetMulticastInterface(iface); err != nil { if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
return err return err
} }
if loopback {
p.SetMulticastLoopback(true)
p2.SetMulticastLoopback(true)
}
return nil return nil
} }
@ -232,64 +295,11 @@ func (c *client) query(params *QueryParam) error {
// Listen until we reach the timeout // Listen until we reach the timeout
finish := time.After(params.Timeout) finish := time.After(params.Timeout)
for { for {
select { select {
case resp := <-msgCh: case resp := <-msgCh:
var inp *ServiceEntry inp := messageToEntry(resp, inprogress)
for _, answer := range append(resp.Answer, resp.Extra...) {
// TODO(reddaly): Check that response corresponds to serviceAddr?
switch rr := answer.(type) {
case *dns.PTR:
// Create new entry for this
inp = ensureName(inprogress, rr.Ptr)
if inp.complete() {
continue
}
case *dns.SRV:
// Check for a target mismatch
if rr.Target != rr.Hdr.Name {
alias(inprogress, rr.Hdr.Name, rr.Target)
}
// Get the port
inp = ensureName(inprogress, rr.Hdr.Name)
if inp.complete() {
continue
}
inp.Host = rr.Target
inp.Port = int(rr.Port)
case *dns.TXT:
// Pull out the txt
inp = ensureName(inprogress, rr.Hdr.Name)
if inp.complete() {
continue
}
inp.Info = strings.Join(rr.Txt, "|")
inp.InfoFields = rr.Txt
inp.hasTXT = true
case *dns.A:
// Pull out the IP
inp = ensureName(inprogress, rr.Hdr.Name)
if inp.complete() {
continue
}
inp.Addr = rr.A // @Deprecated
inp.AddrV4 = rr.A
case *dns.AAAA:
// Pull out the IP
inp = ensureName(inprogress, rr.Hdr.Name)
if inp.complete() {
continue
}
inp.Addr = rr.AAAA // @Deprecated
inp.AddrV6 = rr.AAAA
}
}
if inp == nil { if inp == nil {
continue continue
} }
@ -379,3 +389,63 @@ func alias(inprogress map[string]*ServiceEntry, src, dst string) {
srcEntry := ensureName(inprogress, src) srcEntry := ensureName(inprogress, src)
inprogress[dst] = srcEntry inprogress[dst] = srcEntry
} }
func messageToEntry(m *dns.Msg, inprogress map[string]*ServiceEntry) *ServiceEntry {
var inp *ServiceEntry
for _, answer := range append(m.Answer, m.Extra...) {
// TODO(reddaly): Check that response corresponds to serviceAddr?
switch rr := answer.(type) {
case *dns.PTR:
// Create new entry for this
inp = ensureName(inprogress, rr.Ptr)
if inp.complete() {
continue
}
case *dns.SRV:
// Check for a target mismatch
if rr.Target != rr.Hdr.Name {
alias(inprogress, rr.Hdr.Name, rr.Target)
}
// Get the port
inp = ensureName(inprogress, rr.Hdr.Name)
if inp.complete() {
continue
}
inp.Host = rr.Target
inp.Port = int(rr.Port)
case *dns.TXT:
// Pull out the txt
inp = ensureName(inprogress, rr.Hdr.Name)
if inp.complete() {
continue
}
inp.Info = strings.Join(rr.Txt, "|")
inp.InfoFields = rr.Txt
inp.hasTXT = true
case *dns.A:
// Pull out the IP
inp = ensureName(inprogress, rr.Hdr.Name)
if inp.complete() {
continue
}
inp.Addr = rr.A // @Deprecated
inp.AddrV4 = rr.A
case *dns.AAAA:
// Pull out the IP
inp = ensureName(inprogress, rr.Hdr.Name)
if inp.complete() {
continue
}
inp.Addr = rr.AAAA // @Deprecated
inp.AddrV6 = rr.AAAA
}
if inp != nil {
inp.TTL = int(answer.Header().Ttl)
}
}
return inp
}

192
server.go
View File

@ -3,27 +3,38 @@ package mdns
import ( import (
"fmt" "fmt"
"log" "log"
"math/rand"
"net" "net"
"sync" "sync"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
) "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
const (
ipv4mdns = "224.0.0.251"
ipv6mdns = "ff02::fb"
mdnsPort = 5353
forceUnicastResponses = false
) )
var ( var (
mdnsGroupIPv4 = net.IPv4(224, 0, 0, 251)
mdnsGroupIPv6 = net.ParseIP("ff02::fb")
// mDNS wildcard addresses
mdnsWildcardAddrIPv4 = &net.UDPAddr{
IP: net.ParseIP("224.0.0.0"),
Port: 5353,
}
mdnsWildcardAddrIPv6 = &net.UDPAddr{
IP: net.ParseIP("[ff02::]"),
Port: 5353,
}
// mDNS endpoint addresses
ipv4Addr = &net.UDPAddr{ ipv4Addr = &net.UDPAddr{
IP: net.ParseIP(ipv4mdns), IP: mdnsGroupIPv4,
Port: mdnsPort, Port: 5353,
} }
ipv6Addr = &net.UDPAddr{ ipv6Addr = &net.UDPAddr{
IP: net.ParseIP(ipv6mdns), IP: mdnsGroupIPv6,
Port: mdnsPort, Port: 5353,
} }
) )
@ -54,12 +65,43 @@ type Server struct {
// NewServer is used to create a new mDNS server from a config // NewServer is used to create a new mDNS server from a config
func NewServer(config *Config) (*Server, error) { func NewServer(config *Config) (*Server, error) {
// Create the listeners // Create the listeners
ipv4List, _ := net.ListenMulticastUDP("udp4", config.Iface, ipv4Addr) // Create wildcard connections (because :5353 can be already taken by other apps)
ipv6List, _ := net.ListenMulticastUDP("udp6", config.Iface, ipv6Addr) ipv4List, _ := net.ListenUDP("udp4", mdnsWildcardAddrIPv4)
ipv6List, _ := net.ListenUDP("udp6", mdnsWildcardAddrIPv6)
// Check if we have any listener
if ipv4List == nil && ipv6List == nil { if ipv4List == nil && ipv6List == nil {
return nil, fmt.Errorf("No multicast listeners could be started") return nil, fmt.Errorf("[ERR] mdns: Failed to bind to any udp port!")
}
// Join multicast groups to receive announcements
p1 := ipv4.NewPacketConn(ipv4List)
p2 := ipv6.NewPacketConn(ipv6List)
p1.SetMulticastLoopback(true)
p2.SetMulticastLoopback(true)
if config.Iface != nil {
if err := p1.JoinGroup(config.Iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
return nil, err
}
if err := p2.JoinGroup(config.Iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
return nil, err
}
} else {
ifaces, err := net.Interfaces()
if err != nil {
return nil, err
}
errCount1, errCount2 := 0, 0
for _, iface := range ifaces {
if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
errCount1++
}
if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
errCount2++
}
}
if len(ifaces) == errCount1 && len(ifaces) == errCount2 {
return nil, fmt.Errorf("Failed to join multicast group on all interfaces!")
}
} }
s := &Server{ s := &Server{
@ -77,6 +119,8 @@ func NewServer(config *Config) (*Server, error) {
go s.recv(s.ipv6List) go s.recv(s.ipv6List)
} }
go s.probe()
return s, nil return s, nil
} }
@ -88,8 +132,10 @@ func (s *Server) Shutdown() error {
if s.shutdown { if s.shutdown {
return nil return nil
} }
s.shutdown = true s.shutdown = true
close(s.shutdownCh) close(s.shutdownCh)
s.unregister()
if s.ipv4List != nil { if s.ipv4List != nil {
s.ipv4List.Close() s.ipv4List.Close()
@ -97,6 +143,7 @@ func (s *Server) Shutdown() error {
if s.ipv6List != nil { if s.ipv6List != nil {
s.ipv6List.Close() s.ipv6List.Close()
} }
return nil return nil
} }
@ -222,12 +269,12 @@ func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error {
} }
if mresp := resp(false); mresp != nil { if mresp := resp(false); mresp != nil {
if err := s.sendResponse(mresp, from, false); err != nil { if err := s.sendResponse(mresp, from); err != nil {
return fmt.Errorf("mdns: error sending multicast response: %v", err) return fmt.Errorf("mdns: error sending multicast response: %v", err)
} }
} }
if uresp := resp(true); uresp != nil { if uresp := resp(true); uresp != nil {
if err := s.sendResponse(uresp, from, true); err != nil { if err := s.sendResponse(uresp, from); err != nil {
return fmt.Errorf("mdns: error sending unicast response: %v", err) return fmt.Errorf("mdns: error sending unicast response: %v", err)
} }
} }
@ -256,14 +303,100 @@ func (s *Server) handleQuestion(q dns.Question) (multicastRecs, unicastRecs []dn
// In the Question Section of a Multicast DNS query, the top bit of the // In the Question Section of a Multicast DNS query, the top bit of the
// qclass field is used to indicate that unicast responses are preferred // qclass field is used to indicate that unicast responses are preferred
// for this particular question. (See Section 5.4.) // for this particular question. (See Section 5.4.)
if q.Qclass&(1<<15) != 0 || forceUnicastResponses { if q.Qclass&(1<<15) != 0 {
return nil, records return nil, records
} }
return records, nil return records, nil
} }
func (s *Server) probe() {
sd, ok := s.config.Zone.(*MDNSService)
if !ok {
return
}
name := fmt.Sprintf("%s.%s.%s.", sd.Instance, trimDot(sd.Service), trimDot(sd.Domain))
q := new(dns.Msg)
q.SetQuestion(name, dns.TypePTR)
q.RecursionDesired = false
srv := &dns.SRV{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: defaultTTL,
},
Priority: 0,
Weight: 0,
Port: uint16(sd.Port),
Target: sd.HostName,
}
txt := &dns.TXT{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: defaultTTL,
},
Txt: sd.TXT,
}
q.Ns = []dns.RR{srv, txt}
randomizer := rand.New(rand.NewSource(time.Now().UnixNano()))
for i := 0; i < 3; i++ {
if err := s.multicastResponse(q); err != nil {
log.Println("[ERR] mdns: failed to send probe:", err.Error())
}
time.Sleep(time.Duration(randomizer.Intn(250)) * time.Millisecond)
}
resp := new(dns.Msg)
resp.MsgHdr.Response = true
// set for query
q.SetQuestion(name, dns.TypeANY)
resp.Answer = append(resp.Answer, s.config.Zone.Records(q.Question[0])...)
// reset
q.SetQuestion(name, dns.TypePTR)
// From RFC6762
// The Multicast DNS responder MUST send at least two unsolicited
// responses, one second apart. To provide increased robustness against
// packet loss, a responder MAY send up to eight unsolicited responses,
// provided that the interval between unsolicited responses increases by
// at least a factor of two with every response sent.
timeout := 1 * time.Second
for i := 0; i < 3; i++ {
if err := s.multicastResponse(resp); err != nil {
log.Println("[ERR] mdns: failed to send announcement:", err.Error())
}
time.Sleep(timeout)
timeout *= 2
}
}
// multicastResponse us used to send a multicast response packet
func (s *Server) multicastResponse(msg *dns.Msg) error {
buf, err := msg.Pack()
if err != nil {
return err
}
if s.ipv4List != nil {
s.ipv4List.WriteToUDP(buf, ipv4Addr)
}
if s.ipv6List != nil {
s.ipv6List.WriteToUDP(buf, ipv6Addr)
}
return nil
}
// sendResponse is used to send a response packet // sendResponse is used to send a response packet
func (s *Server) sendResponse(resp *dns.Msg, from net.Addr, unicast bool) error { func (s *Server) sendResponse(resp *dns.Msg, from net.Addr) error {
// TODO(reddaly): Respect the unicast argument, and allow sending responses // TODO(reddaly): Respect the unicast argument, and allow sending responses
// over multicast. // over multicast.
buf, err := resp.Pack() buf, err := resp.Pack()
@ -281,3 +414,22 @@ func (s *Server) sendResponse(resp *dns.Msg, from net.Addr, unicast bool) error
return err return err
} }
} }
func (s *Server) unregister() error {
sd, ok := s.config.Zone.(*MDNSService)
if !ok {
return nil
}
sd.TTL = 0
name := fmt.Sprintf("%s.%s.%s.", sd.Instance, trimDot(sd.Service), trimDot(sd.Domain))
q := new(dns.Msg)
q.SetQuestion(name, dns.TypeANY)
resp := new(dns.Msg)
resp.MsgHdr.Response = true
resp.Answer = append(resp.Answer, s.config.Zone.Records(q.Question[0])...)
return s.multicastResponse(resp)
}

29
zone.go
View File

@ -23,14 +23,14 @@ type Zone interface {
// MDNSService is used to export a named service by implementing a Zone // MDNSService is used to export a named service by implementing a Zone
type MDNSService struct { type MDNSService struct {
Instance string // Instance name (e.g. "hostService name") Instance string // Instance name (e.g. "hostService name")
Service string // Service name (e.g. "_http._tcp.") Service string // Service name (e.g. "_http._tcp.")
Domain string // If blank, assumes "local" Domain string // If blank, assumes "local"
HostName string // Host machine DNS name (e.g. "mymachine.net.") HostName string // Host machine DNS name (e.g. "mymachine.net.")
Port int // Service Port Port int // Service Port
IPs []net.IP // IP addresses for the service's host IPs []net.IP // IP addresses for the service's host
TXT []string // Service TXT records TXT []string // Service TXT records
TTL uint32
serviceAddr string // Fully qualified service address serviceAddr string // Fully qualified service address
instanceAddr string // Fully qualified instance address instanceAddr string // Fully qualified instance address
enumAddr string // _services._dns-sd._udp.<domain> enumAddr string // _services._dns-sd._udp.<domain>
@ -122,6 +122,7 @@ func NewMDNSService(instance, service, domain, hostName string, port int, ips []
Port: port, Port: port,
IPs: ips, IPs: ips,
TXT: txt, TXT: txt,
TTL: defaultTTL,
serviceAddr: fmt.Sprintf("%s.%s.", trimDot(service), trimDot(domain)), serviceAddr: fmt.Sprintf("%s.%s.", trimDot(service), trimDot(domain)),
instanceAddr: fmt.Sprintf("%s.%s.%s.", instance, trimDot(service), trimDot(domain)), instanceAddr: fmt.Sprintf("%s.%s.%s.", instance, trimDot(service), trimDot(domain)),
enumAddr: fmt.Sprintf("_services._dns-sd._udp.%s.", trimDot(domain)), enumAddr: fmt.Sprintf("_services._dns-sd._udp.%s.", trimDot(domain)),
@ -162,7 +163,7 @@ func (m *MDNSService) serviceEnum(q dns.Question) []dns.RR {
Name: q.Name, Name: q.Name,
Rrtype: dns.TypePTR, Rrtype: dns.TypePTR,
Class: dns.ClassINET, Class: dns.ClassINET,
Ttl: defaultTTL, Ttl: m.TTL,
}, },
Ptr: m.serviceAddr, Ptr: m.serviceAddr,
} }
@ -184,7 +185,7 @@ func (m *MDNSService) serviceRecords(q dns.Question) []dns.RR {
Name: q.Name, Name: q.Name,
Rrtype: dns.TypePTR, Rrtype: dns.TypePTR,
Class: dns.ClassINET, Class: dns.ClassINET,
Ttl: defaultTTL, Ttl: m.TTL,
}, },
Ptr: m.instanceAddr, Ptr: m.instanceAddr,
} }
@ -229,7 +230,7 @@ func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR {
Name: m.HostName, Name: m.HostName,
Rrtype: dns.TypeA, Rrtype: dns.TypeA,
Class: dns.ClassINET, Class: dns.ClassINET,
Ttl: defaultTTL, Ttl: m.TTL,
}, },
A: ip4, A: ip4,
}) })
@ -254,7 +255,7 @@ func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR {
Name: m.HostName, Name: m.HostName,
Rrtype: dns.TypeAAAA, Rrtype: dns.TypeAAAA,
Class: dns.ClassINET, Class: dns.ClassINET,
Ttl: defaultTTL, Ttl: m.TTL,
}, },
AAAA: ip16, AAAA: ip16,
}) })
@ -269,7 +270,7 @@ func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR {
Name: q.Name, Name: q.Name,
Rrtype: dns.TypeSRV, Rrtype: dns.TypeSRV,
Class: dns.ClassINET, Class: dns.ClassINET,
Ttl: defaultTTL, Ttl: m.TTL,
}, },
Priority: 10, Priority: 10,
Weight: 1, Weight: 1,
@ -297,7 +298,7 @@ func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR {
Name: q.Name, Name: q.Name,
Rrtype: dns.TypeTXT, Rrtype: dns.TypeTXT,
Class: dns.ClassINET, Class: dns.ClassINET,
Ttl: defaultTTL, Ttl: m.TTL,
}, },
Txt: m.TXT, Txt: m.TXT,
} }