Add the probing, announcements and shutdown. Add client Listen which can be used for a watcher

This commit is contained in:
Asim 2016-05-01 18:55:42 +01:00
parent 2963a9d96a
commit f4bad1caf6
3 changed files with 318 additions and 95 deletions

192
client.go
View File

@ -22,6 +22,7 @@ type ServiceEntry struct {
Port int
Info string
InfoFields []string
TTL int
Addr net.IP // @Deprecated
@ -69,7 +70,7 @@ func Query(params *QueryParam) error {
// Set the multicast interface
if params.Interface != nil {
if err := client.setInterface(params.Interface); err != nil {
if err := client.setInterface(params.Interface, false); err != nil {
return err
}
}
@ -86,6 +87,62 @@ func Query(params *QueryParam) error {
return client.query(params)
}
// Listen listens indefinitely for multicast updates
func Listen(entries chan<- *ServiceEntry, exit chan struct{}) error {
// Create a new client
client, err := newClient()
if err != nil {
return err
}
defer client.Close()
client.setInterface(nil, true)
// Start listening for response packets
msgCh := make(chan *dns.Msg, 32)
go client.recv(client.ipv4MulticastConn, msgCh)
go client.recv(client.ipv6MulticastConn, msgCh)
go client.recv(client.ipv4MulticastConn, msgCh)
go client.recv(client.ipv6MulticastConn, msgCh)
ip := make(map[string]*ServiceEntry)
for {
select {
case <-exit:
return nil
case <-client.closedCh:
return nil
case m := <-msgCh:
e := messageToEntry(m, ip)
if e == nil {
continue
}
// Check if this entry is complete
if e.complete() {
if e.sent {
continue
}
e.sent = true
entries <- e
ip = make(map[string]*ServiceEntry)
} else {
// Fire off a node specific query
m := new(dns.Msg)
m.SetQuestion(e.Name, dns.TypePTR)
m.RecursionDesired = false
if err := client.sendQuery(m); err != nil {
log.Printf("[ERR] mdns: Failed to query instance %s: %v", e.Name, err)
}
}
}
}
return nil
}
// Lookup is the same as Query, however it uses all the default parameters
func Lookup(service string, entries chan<- *ServiceEntry) error {
params := DefaultParams(service)
@ -178,23 +235,29 @@ func (c *client) Close() error {
// setInterface is used to set the query interface, uses sytem
// default if not provided
func (c *client) setInterface(iface *net.Interface) error {
func (c *client) setInterface(iface *net.Interface, loopback bool) error {
p := ipv4.NewPacketConn(c.ipv4UnicastConn)
if err := p.SetMulticastInterface(iface); err != nil {
if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
return err
}
p2 := ipv6.NewPacketConn(c.ipv6UnicastConn)
if err := p2.SetMulticastInterface(iface); err != nil {
if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
return err
}
p = ipv4.NewPacketConn(c.ipv4MulticastConn)
if err := p.SetMulticastInterface(iface); err != nil {
if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
return err
}
p2 = ipv6.NewPacketConn(c.ipv6MulticastConn)
if err := p2.SetMulticastInterface(iface); err != nil {
if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
return err
}
if loopback {
p.SetMulticastLoopback(true)
p2.SetMulticastLoopback(true)
}
return nil
}
@ -232,64 +295,11 @@ func (c *client) query(params *QueryParam) error {
// Listen until we reach the timeout
finish := time.After(params.Timeout)
for {
select {
case resp := <-msgCh:
var inp *ServiceEntry
for _, answer := range append(resp.Answer, resp.Extra...) {
// TODO(reddaly): Check that response corresponds to serviceAddr?
switch rr := answer.(type) {
case *dns.PTR:
// Create new entry for this
inp = ensureName(inprogress, rr.Ptr)
if inp.complete() {
continue
}
case *dns.SRV:
// Check for a target mismatch
if rr.Target != rr.Hdr.Name {
alias(inprogress, rr.Hdr.Name, rr.Target)
}
// Get the port
inp = ensureName(inprogress, rr.Hdr.Name)
if inp.complete() {
continue
}
inp.Host = rr.Target
inp.Port = int(rr.Port)
case *dns.TXT:
// Pull out the txt
inp = ensureName(inprogress, rr.Hdr.Name)
if inp.complete() {
continue
}
inp.Info = strings.Join(rr.Txt, "|")
inp.InfoFields = rr.Txt
inp.hasTXT = true
case *dns.A:
// Pull out the IP
inp = ensureName(inprogress, rr.Hdr.Name)
if inp.complete() {
continue
}
inp.Addr = rr.A // @Deprecated
inp.AddrV4 = rr.A
case *dns.AAAA:
// Pull out the IP
inp = ensureName(inprogress, rr.Hdr.Name)
if inp.complete() {
continue
}
inp.Addr = rr.AAAA // @Deprecated
inp.AddrV6 = rr.AAAA
}
}
inp := messageToEntry(resp, inprogress)
if inp == nil {
continue
}
@ -379,3 +389,63 @@ func alias(inprogress map[string]*ServiceEntry, src, dst string) {
srcEntry := ensureName(inprogress, src)
inprogress[dst] = srcEntry
}
func messageToEntry(m *dns.Msg, inprogress map[string]*ServiceEntry) *ServiceEntry {
var inp *ServiceEntry
for _, answer := range append(m.Answer, m.Extra...) {
// TODO(reddaly): Check that response corresponds to serviceAddr?
switch rr := answer.(type) {
case *dns.PTR:
// Create new entry for this
inp = ensureName(inprogress, rr.Ptr)
if inp.complete() {
continue
}
case *dns.SRV:
// Check for a target mismatch
if rr.Target != rr.Hdr.Name {
alias(inprogress, rr.Hdr.Name, rr.Target)
}
// Get the port
inp = ensureName(inprogress, rr.Hdr.Name)
if inp.complete() {
continue
}
inp.Host = rr.Target
inp.Port = int(rr.Port)
case *dns.TXT:
// Pull out the txt
inp = ensureName(inprogress, rr.Hdr.Name)
if inp.complete() {
continue
}
inp.Info = strings.Join(rr.Txt, "|")
inp.InfoFields = rr.Txt
inp.hasTXT = true
case *dns.A:
// Pull out the IP
inp = ensureName(inprogress, rr.Hdr.Name)
if inp.complete() {
continue
}
inp.Addr = rr.A // @Deprecated
inp.AddrV4 = rr.A
case *dns.AAAA:
// Pull out the IP
inp = ensureName(inprogress, rr.Hdr.Name)
if inp.complete() {
continue
}
inp.Addr = rr.AAAA // @Deprecated
inp.AddrV6 = rr.AAAA
}
if inp != nil {
inp.TTL = int(answer.Header().Ttl)
}
}
return inp
}

192
server.go
View File

@ -3,27 +3,38 @@ package mdns
import (
"fmt"
"log"
"math/rand"
"net"
"sync"
"time"
"github.com/miekg/dns"
)
const (
ipv4mdns = "224.0.0.251"
ipv6mdns = "ff02::fb"
mdnsPort = 5353
forceUnicastResponses = false
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
var (
mdnsGroupIPv4 = net.IPv4(224, 0, 0, 251)
mdnsGroupIPv6 = net.ParseIP("ff02::fb")
// mDNS wildcard addresses
mdnsWildcardAddrIPv4 = &net.UDPAddr{
IP: net.ParseIP("224.0.0.0"),
Port: 5353,
}
mdnsWildcardAddrIPv6 = &net.UDPAddr{
IP: net.ParseIP("[ff02::]"),
Port: 5353,
}
// mDNS endpoint addresses
ipv4Addr = &net.UDPAddr{
IP: net.ParseIP(ipv4mdns),
Port: mdnsPort,
IP: mdnsGroupIPv4,
Port: 5353,
}
ipv6Addr = &net.UDPAddr{
IP: net.ParseIP(ipv6mdns),
Port: mdnsPort,
IP: mdnsGroupIPv6,
Port: 5353,
}
)
@ -54,12 +65,43 @@ type Server struct {
// NewServer is used to create a new mDNS server from a config
func NewServer(config *Config) (*Server, error) {
// Create the listeners
ipv4List, _ := net.ListenMulticastUDP("udp4", config.Iface, ipv4Addr)
ipv6List, _ := net.ListenMulticastUDP("udp6", config.Iface, ipv6Addr)
// Check if we have any listener
// Create wildcard connections (because :5353 can be already taken by other apps)
ipv4List, _ := net.ListenUDP("udp4", mdnsWildcardAddrIPv4)
ipv6List, _ := net.ListenUDP("udp6", mdnsWildcardAddrIPv6)
if ipv4List == nil && ipv6List == nil {
return nil, fmt.Errorf("No multicast listeners could be started")
return nil, fmt.Errorf("[ERR] mdns: Failed to bind to any udp port!")
}
// Join multicast groups to receive announcements
p1 := ipv4.NewPacketConn(ipv4List)
p2 := ipv6.NewPacketConn(ipv6List)
p1.SetMulticastLoopback(true)
p2.SetMulticastLoopback(true)
if config.Iface != nil {
if err := p1.JoinGroup(config.Iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
return nil, err
}
if err := p2.JoinGroup(config.Iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
return nil, err
}
} else {
ifaces, err := net.Interfaces()
if err != nil {
return nil, err
}
errCount1, errCount2 := 0, 0
for _, iface := range ifaces {
if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
errCount1++
}
if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
errCount2++
}
}
if len(ifaces) == errCount1 && len(ifaces) == errCount2 {
return nil, fmt.Errorf("Failed to join multicast group on all interfaces!")
}
}
s := &Server{
@ -77,6 +119,8 @@ func NewServer(config *Config) (*Server, error) {
go s.recv(s.ipv6List)
}
go s.probe()
return s, nil
}
@ -88,8 +132,10 @@ func (s *Server) Shutdown() error {
if s.shutdown {
return nil
}
s.shutdown = true
close(s.shutdownCh)
s.unregister()
if s.ipv4List != nil {
s.ipv4List.Close()
@ -97,6 +143,7 @@ func (s *Server) Shutdown() error {
if s.ipv6List != nil {
s.ipv6List.Close()
}
return nil
}
@ -222,12 +269,12 @@ func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error {
}
if mresp := resp(false); mresp != nil {
if err := s.sendResponse(mresp, from, false); err != nil {
if err := s.sendResponse(mresp, from); err != nil {
return fmt.Errorf("mdns: error sending multicast response: %v", err)
}
}
if uresp := resp(true); uresp != nil {
if err := s.sendResponse(uresp, from, true); err != nil {
if err := s.sendResponse(uresp, from); err != nil {
return fmt.Errorf("mdns: error sending unicast response: %v", err)
}
}
@ -256,14 +303,100 @@ func (s *Server) handleQuestion(q dns.Question) (multicastRecs, unicastRecs []dn
// In the Question Section of a Multicast DNS query, the top bit of the
// qclass field is used to indicate that unicast responses are preferred
// for this particular question. (See Section 5.4.)
if q.Qclass&(1<<15) != 0 || forceUnicastResponses {
if q.Qclass&(1<<15) != 0 {
return nil, records
}
return records, nil
}
func (s *Server) probe() {
sd, ok := s.config.Zone.(*MDNSService)
if !ok {
return
}
name := fmt.Sprintf("%s.%s.%s.", sd.Instance, trimDot(sd.Service), trimDot(sd.Domain))
q := new(dns.Msg)
q.SetQuestion(name, dns.TypePTR)
q.RecursionDesired = false
srv := &dns.SRV{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: defaultTTL,
},
Priority: 0,
Weight: 0,
Port: uint16(sd.Port),
Target: sd.HostName,
}
txt := &dns.TXT{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: defaultTTL,
},
Txt: sd.TXT,
}
q.Ns = []dns.RR{srv, txt}
randomizer := rand.New(rand.NewSource(time.Now().UnixNano()))
for i := 0; i < 3; i++ {
if err := s.multicastResponse(q); err != nil {
log.Println("[ERR] mdns: failed to send probe:", err.Error())
}
time.Sleep(time.Duration(randomizer.Intn(250)) * time.Millisecond)
}
resp := new(dns.Msg)
resp.MsgHdr.Response = true
// set for query
q.SetQuestion(name, dns.TypeANY)
resp.Answer = append(resp.Answer, s.config.Zone.Records(q.Question[0])...)
// reset
q.SetQuestion(name, dns.TypePTR)
// From RFC6762
// The Multicast DNS responder MUST send at least two unsolicited
// responses, one second apart. To provide increased robustness against
// packet loss, a responder MAY send up to eight unsolicited responses,
// provided that the interval between unsolicited responses increases by
// at least a factor of two with every response sent.
timeout := 1 * time.Second
for i := 0; i < 3; i++ {
if err := s.multicastResponse(resp); err != nil {
log.Println("[ERR] mdns: failed to send announcement:", err.Error())
}
time.Sleep(timeout)
timeout *= 2
}
}
// multicastResponse us used to send a multicast response packet
func (s *Server) multicastResponse(msg *dns.Msg) error {
buf, err := msg.Pack()
if err != nil {
return err
}
if s.ipv4List != nil {
s.ipv4List.WriteToUDP(buf, ipv4Addr)
}
if s.ipv6List != nil {
s.ipv6List.WriteToUDP(buf, ipv6Addr)
}
return nil
}
// sendResponse is used to send a response packet
func (s *Server) sendResponse(resp *dns.Msg, from net.Addr, unicast bool) error {
func (s *Server) sendResponse(resp *dns.Msg, from net.Addr) error {
// TODO(reddaly): Respect the unicast argument, and allow sending responses
// over multicast.
buf, err := resp.Pack()
@ -281,3 +414,22 @@ func (s *Server) sendResponse(resp *dns.Msg, from net.Addr, unicast bool) error
return err
}
}
func (s *Server) unregister() error {
sd, ok := s.config.Zone.(*MDNSService)
if !ok {
return nil
}
sd.TTL = 0
name := fmt.Sprintf("%s.%s.%s.", sd.Instance, trimDot(sd.Service), trimDot(sd.Domain))
q := new(dns.Msg)
q.SetQuestion(name, dns.TypeANY)
resp := new(dns.Msg)
resp.MsgHdr.Response = true
resp.Answer = append(resp.Answer, s.config.Zone.Records(q.Question[0])...)
return s.multicastResponse(resp)
}

15
zone.go
View File

@ -30,7 +30,7 @@ type MDNSService struct {
Port int // Service Port
IPs []net.IP // IP addresses for the service's host
TXT []string // Service TXT records
TTL uint32
serviceAddr string // Fully qualified service address
instanceAddr string // Fully qualified instance address
enumAddr string // _services._dns-sd._udp.<domain>
@ -122,6 +122,7 @@ func NewMDNSService(instance, service, domain, hostName string, port int, ips []
Port: port,
IPs: ips,
TXT: txt,
TTL: defaultTTL,
serviceAddr: fmt.Sprintf("%s.%s.", trimDot(service), trimDot(domain)),
instanceAddr: fmt.Sprintf("%s.%s.%s.", instance, trimDot(service), trimDot(domain)),
enumAddr: fmt.Sprintf("_services._dns-sd._udp.%s.", trimDot(domain)),
@ -162,7 +163,7 @@ func (m *MDNSService) serviceEnum(q dns.Question) []dns.RR {
Name: q.Name,
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: defaultTTL,
Ttl: m.TTL,
},
Ptr: m.serviceAddr,
}
@ -184,7 +185,7 @@ func (m *MDNSService) serviceRecords(q dns.Question) []dns.RR {
Name: q.Name,
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: defaultTTL,
Ttl: m.TTL,
},
Ptr: m.instanceAddr,
}
@ -229,7 +230,7 @@ func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR {
Name: m.HostName,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: defaultTTL,
Ttl: m.TTL,
},
A: ip4,
})
@ -254,7 +255,7 @@ func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR {
Name: m.HostName,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: defaultTTL,
Ttl: m.TTL,
},
AAAA: ip16,
})
@ -269,7 +270,7 @@ func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR {
Name: q.Name,
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: defaultTTL,
Ttl: m.TTL,
},
Priority: 10,
Weight: 1,
@ -297,7 +298,7 @@ func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR {
Name: q.Name,
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: defaultTTL,
Ttl: m.TTL,
},
Txt: m.TXT,
}