many lint fixes and optimizations #17
23
util/mdns/.gitignore
vendored
23
util/mdns/.gitignore
vendored
@ -1,23 +0,0 @@
|
|||||||
# Compiled Object files, Static and Dynamic libs (Shared Objects)
|
|
||||||
*.o
|
|
||||||
*.a
|
|
||||||
*.so
|
|
||||||
|
|
||||||
# Folders
|
|
||||||
_obj
|
|
||||||
_test
|
|
||||||
|
|
||||||
# Architecture specific extensions/prefixes
|
|
||||||
*.[568vq]
|
|
||||||
[568vq].out
|
|
||||||
|
|
||||||
*.cgo1.go
|
|
||||||
*.cgo2.c
|
|
||||||
_cgo_defun.c
|
|
||||||
_cgo_gotypes.go
|
|
||||||
_cgo_export.*
|
|
||||||
|
|
||||||
_testmain.go
|
|
||||||
|
|
||||||
*.exe
|
|
||||||
*.test
|
|
@ -1,511 +0,0 @@
|
|||||||
package mdns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
"golang.org/x/net/ipv4"
|
|
||||||
"golang.org/x/net/ipv6"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ServiceEntry is returned after we query for a service
|
|
||||||
type ServiceEntry struct {
|
|
||||||
Name string
|
|
||||||
Host string
|
|
||||||
AddrV4 net.IP
|
|
||||||
AddrV6 net.IP
|
|
||||||
Port int
|
|
||||||
Info string
|
|
||||||
InfoFields []string
|
|
||||||
TTL int
|
|
||||||
Type uint16
|
|
||||||
|
|
||||||
Addr net.IP // @Deprecated
|
|
||||||
|
|
||||||
hasTXT bool
|
|
||||||
sent bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// complete is used to check if we have all the info we need
|
|
||||||
func (s *ServiceEntry) complete() bool {
|
|
||||||
|
|
||||||
return (len(s.AddrV4) > 0 || len(s.AddrV6) > 0 || len(s.Addr) > 0) && s.Port != 0 && s.hasTXT
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryParam is used to customize how a Lookup is performed
|
|
||||||
type QueryParam struct {
|
|
||||||
Service string // Service to lookup
|
|
||||||
Domain string // Lookup domain, default "local"
|
|
||||||
Type uint16 // Lookup type, defaults to dns.TypePTR
|
|
||||||
Context context.Context // Context
|
|
||||||
Timeout time.Duration // Lookup timeout, default 1 second. Ignored if Context is provided
|
|
||||||
Interface *net.Interface // Multicast interface to use
|
|
||||||
Entries chan<- *ServiceEntry // Entries Channel
|
|
||||||
WantUnicastResponse bool // Unicast response desired, as per 5.4 in RFC
|
|
||||||
}
|
|
||||||
|
|
||||||
// DefaultParams is used to return a default set of QueryParam's
|
|
||||||
func DefaultParams(service string) *QueryParam {
|
|
||||||
return &QueryParam{
|
|
||||||
Service: service,
|
|
||||||
Domain: "local",
|
|
||||||
Timeout: time.Second,
|
|
||||||
Entries: make(chan *ServiceEntry),
|
|
||||||
WantUnicastResponse: false, // TODO(reddaly): Change this default.
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query looks up a given service, in a domain, waiting at most
|
|
||||||
// for a timeout before finishing the query. The results are streamed
|
|
||||||
// to a channel. Sends will not block, so clients should make sure to
|
|
||||||
// either read or buffer.
|
|
||||||
func Query(params *QueryParam) error {
|
|
||||||
// Create a new client
|
|
||||||
client, err := newClient()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer client.Close()
|
|
||||||
|
|
||||||
// Set the multicast interface
|
|
||||||
if params.Interface != nil {
|
|
||||||
if err := client.setInterface(params.Interface, false); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure defaults are set
|
|
||||||
if params.Domain == "" {
|
|
||||||
params.Domain = "local"
|
|
||||||
}
|
|
||||||
|
|
||||||
if params.Context == nil {
|
|
||||||
var cancel context.CancelFunc
|
|
||||||
if params.Timeout == 0 {
|
|
||||||
params.Timeout = time.Second
|
|
||||||
}
|
|
||||||
params.Context, cancel = context.WithTimeout(context.Background(), params.Timeout)
|
|
||||||
defer cancel()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run the query
|
|
||||||
return client.query(params)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Listen listens indefinitely for multicast updates
|
|
||||||
func Listen(entries chan<- *ServiceEntry, exit chan struct{}) error {
|
|
||||||
// Create a new client
|
|
||||||
client, err := newClient()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer client.Close()
|
|
||||||
|
|
||||||
client.setInterface(nil, true)
|
|
||||||
|
|
||||||
// Start listening for response packets
|
|
||||||
msgCh := make(chan *dns.Msg, 32)
|
|
||||||
|
|
||||||
go client.recv(client.ipv4UnicastConn, msgCh)
|
|
||||||
go client.recv(client.ipv6UnicastConn, msgCh)
|
|
||||||
go client.recv(client.ipv4MulticastConn, msgCh)
|
|
||||||
go client.recv(client.ipv6MulticastConn, msgCh)
|
|
||||||
|
|
||||||
ip := make(map[string]*ServiceEntry)
|
|
||||||
|
|
||||||
loop:
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-exit:
|
|
||||||
break loop
|
|
||||||
case <-client.closedCh:
|
|
||||||
break loop
|
|
||||||
case m := <-msgCh:
|
|
||||||
e := messageToEntry(m, ip)
|
|
||||||
if e == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if this entry is complete
|
|
||||||
if e.complete() {
|
|
||||||
if e.sent {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
e.sent = true
|
|
||||||
entries <- e
|
|
||||||
ip = make(map[string]*ServiceEntry)
|
|
||||||
} else {
|
|
||||||
// Fire off a node specific query
|
|
||||||
m := new(dns.Msg)
|
|
||||||
m.SetQuestion(e.Name, dns.TypePTR)
|
|
||||||
m.RecursionDesired = false
|
|
||||||
if err := client.sendQuery(m); err != nil {
|
|
||||||
log.Printf("[ERR] mdns: Failed to query instance %s: %v", e.Name, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Lookup is the same as Query, however it uses all the default parameters
|
|
||||||
func Lookup(service string, entries chan<- *ServiceEntry) error {
|
|
||||||
params := DefaultParams(service)
|
|
||||||
params.Entries = entries
|
|
||||||
return Query(params)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Client provides a query interface that can be used to
|
|
||||||
// search for service providers using mDNS
|
|
||||||
type client struct {
|
|
||||||
ipv4UnicastConn *net.UDPConn
|
|
||||||
ipv6UnicastConn *net.UDPConn
|
|
||||||
|
|
||||||
ipv4MulticastConn *net.UDPConn
|
|
||||||
ipv6MulticastConn *net.UDPConn
|
|
||||||
|
|
||||||
closed bool
|
|
||||||
closedCh chan struct{} // TODO(reddaly): This doesn't appear to be used.
|
|
||||||
closeLock sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewClient creates a new mdns Client that can be used to query
|
|
||||||
// for records
|
|
||||||
func newClient() (*client, error) {
|
|
||||||
// TODO(reddaly): At least attempt to bind to the port required in the spec.
|
|
||||||
// Create a IPv4 listener
|
|
||||||
uconn4, err4 := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
|
||||||
uconn6, err6 := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
|
||||||
if err4 != nil && err6 != nil {
|
|
||||||
log.Printf("[ERR] mdns: Failed to bind to udp port: %v %v", err4, err6)
|
|
||||||
}
|
|
||||||
|
|
||||||
if uconn4 == nil && uconn6 == nil {
|
|
||||||
return nil, fmt.Errorf("failed to bind to any unicast udp port")
|
|
||||||
}
|
|
||||||
|
|
||||||
if uconn4 == nil {
|
|
||||||
uconn4 = &net.UDPConn{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if uconn6 == nil {
|
|
||||||
uconn6 = &net.UDPConn{}
|
|
||||||
}
|
|
||||||
|
|
||||||
mconn4, err4 := net.ListenUDP("udp4", mdnsWildcardAddrIPv4)
|
|
||||||
mconn6, err6 := net.ListenUDP("udp6", mdnsWildcardAddrIPv6)
|
|
||||||
if err4 != nil && err6 != nil {
|
|
||||||
log.Printf("[ERR] mdns: Failed to bind to udp port: %v %v", err4, err6)
|
|
||||||
}
|
|
||||||
|
|
||||||
if mconn4 == nil && mconn6 == nil {
|
|
||||||
return nil, fmt.Errorf("failed to bind to any multicast udp port")
|
|
||||||
}
|
|
||||||
|
|
||||||
if mconn4 == nil {
|
|
||||||
mconn4 = &net.UDPConn{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if mconn6 == nil {
|
|
||||||
mconn6 = &net.UDPConn{}
|
|
||||||
}
|
|
||||||
|
|
||||||
p1 := ipv4.NewPacketConn(mconn4)
|
|
||||||
p2 := ipv6.NewPacketConn(mconn6)
|
|
||||||
p1.SetMulticastLoopback(true)
|
|
||||||
p2.SetMulticastLoopback(true)
|
|
||||||
|
|
||||||
ifaces, err := net.Interfaces()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var errCount1, errCount2 int
|
|
||||||
|
|
||||||
for _, iface := range ifaces {
|
|
||||||
if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
|
|
||||||
errCount1++
|
|
||||||
}
|
|
||||||
if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
|
|
||||||
errCount2++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(ifaces) == errCount1 && len(ifaces) == errCount2 {
|
|
||||||
return nil, fmt.Errorf("Failed to join multicast group on all interfaces!")
|
|
||||||
}
|
|
||||||
|
|
||||||
c := &client{
|
|
||||||
ipv4MulticastConn: mconn4,
|
|
||||||
ipv6MulticastConn: mconn6,
|
|
||||||
ipv4UnicastConn: uconn4,
|
|
||||||
ipv6UnicastConn: uconn6,
|
|
||||||
closedCh: make(chan struct{}),
|
|
||||||
}
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close is used to cleanup the client
|
|
||||||
func (c *client) Close() error {
|
|
||||||
c.closeLock.Lock()
|
|
||||||
defer c.closeLock.Unlock()
|
|
||||||
|
|
||||||
if c.closed {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
c.closed = true
|
|
||||||
|
|
||||||
close(c.closedCh)
|
|
||||||
|
|
||||||
if c.ipv4UnicastConn != nil {
|
|
||||||
c.ipv4UnicastConn.Close()
|
|
||||||
}
|
|
||||||
if c.ipv6UnicastConn != nil {
|
|
||||||
c.ipv6UnicastConn.Close()
|
|
||||||
}
|
|
||||||
if c.ipv4MulticastConn != nil {
|
|
||||||
c.ipv4MulticastConn.Close()
|
|
||||||
}
|
|
||||||
if c.ipv6MulticastConn != nil {
|
|
||||||
c.ipv6MulticastConn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// setInterface is used to set the query interface, uses system
|
|
||||||
// default if not provided
|
|
||||||
func (c *client) setInterface(iface *net.Interface, loopback bool) error {
|
|
||||||
p := ipv4.NewPacketConn(c.ipv4UnicastConn)
|
|
||||||
if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
p2 := ipv6.NewPacketConn(c.ipv6UnicastConn)
|
|
||||||
if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
p = ipv4.NewPacketConn(c.ipv4MulticastConn)
|
|
||||||
if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
p2 = ipv6.NewPacketConn(c.ipv6MulticastConn)
|
|
||||||
if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if loopback {
|
|
||||||
p.SetMulticastLoopback(true)
|
|
||||||
p2.SetMulticastLoopback(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// query is used to perform a lookup and stream results
|
|
||||||
func (c *client) query(params *QueryParam) error {
|
|
||||||
// Create the service name
|
|
||||||
serviceAddr := fmt.Sprintf("%s.%s.", trimDot(params.Service), trimDot(params.Domain))
|
|
||||||
|
|
||||||
// Start listening for response packets
|
|
||||||
msgCh := make(chan *dns.Msg, 32)
|
|
||||||
go c.recv(c.ipv4UnicastConn, msgCh)
|
|
||||||
go c.recv(c.ipv6UnicastConn, msgCh)
|
|
||||||
go c.recv(c.ipv4MulticastConn, msgCh)
|
|
||||||
go c.recv(c.ipv6MulticastConn, msgCh)
|
|
||||||
|
|
||||||
// Send the query
|
|
||||||
m := new(dns.Msg)
|
|
||||||
if params.Type == dns.TypeNone {
|
|
||||||
m.SetQuestion(serviceAddr, dns.TypePTR)
|
|
||||||
} else {
|
|
||||||
m.SetQuestion(serviceAddr, params.Type)
|
|
||||||
}
|
|
||||||
// RFC 6762, section 18.12. Repurposing of Top Bit of qclass in Question
|
|
||||||
// Section
|
|
||||||
//
|
|
||||||
// In the Question Section of a Multicast DNS query, the top bit of the qclass
|
|
||||||
// field is used to indicate that unicast responses are preferred for this
|
|
||||||
// particular question. (See Section 5.4.)
|
|
||||||
if params.WantUnicastResponse {
|
|
||||||
m.Question[0].Qclass |= 1 << 15
|
|
||||||
}
|
|
||||||
m.RecursionDesired = false
|
|
||||||
if err := c.sendQuery(m); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Map the in-progress responses
|
|
||||||
inprogress := make(map[string]*ServiceEntry)
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case resp := <-msgCh:
|
|
||||||
inp := messageToEntry(resp, inprogress)
|
|
||||||
|
|
||||||
if inp == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if len(resp.Question) == 0 || resp.Question[0].Name != m.Question[0].Name {
|
|
||||||
// discard anything which we've not asked for
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if this entry is complete
|
|
||||||
if inp.complete() {
|
|
||||||
if inp.sent {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
inp.sent = true
|
|
||||||
select {
|
|
||||||
case params.Entries <- inp:
|
|
||||||
case <-params.Context.Done():
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Fire off a node specific query
|
|
||||||
m := new(dns.Msg)
|
|
||||||
m.SetQuestion(inp.Name, inp.Type)
|
|
||||||
m.RecursionDesired = false
|
|
||||||
if err := c.sendQuery(m); err != nil {
|
|
||||||
log.Printf("[ERR] mdns: Failed to query instance %s: %v", inp.Name, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case <-params.Context.Done():
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// sendQuery is used to multicast a query out
|
|
||||||
func (c *client) sendQuery(q *dns.Msg) error {
|
|
||||||
buf, err := q.Pack()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if c.ipv4UnicastConn != nil {
|
|
||||||
c.ipv4UnicastConn.WriteToUDP(buf, ipv4Addr)
|
|
||||||
}
|
|
||||||
if c.ipv6UnicastConn != nil {
|
|
||||||
c.ipv6UnicastConn.WriteToUDP(buf, ipv6Addr)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// recv is used to receive until we get a shutdown
|
|
||||||
func (c *client) recv(l *net.UDPConn, msgCh chan *dns.Msg) {
|
|
||||||
if l == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
buf := make([]byte, 65536)
|
|
||||||
for {
|
|
||||||
c.closeLock.Lock()
|
|
||||||
if c.closed {
|
|
||||||
c.closeLock.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.closeLock.Unlock()
|
|
||||||
n, err := l.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
msg := new(dns.Msg)
|
|
||||||
if err := msg.Unpack(buf[:n]); err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case msgCh <- msg:
|
|
||||||
case <-c.closedCh:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ensureName is used to ensure the named node is in progress
|
|
||||||
func ensureName(inprogress map[string]*ServiceEntry, name string, typ uint16) *ServiceEntry {
|
|
||||||
if inp, ok := inprogress[name]; ok {
|
|
||||||
return inp
|
|
||||||
}
|
|
||||||
inp := &ServiceEntry{
|
|
||||||
Name: name,
|
|
||||||
Type: typ,
|
|
||||||
}
|
|
||||||
inprogress[name] = inp
|
|
||||||
return inp
|
|
||||||
}
|
|
||||||
|
|
||||||
// alias is used to setup an alias between two entries
|
|
||||||
func alias(inprogress map[string]*ServiceEntry, src, dst string, typ uint16) {
|
|
||||||
srcEntry := ensureName(inprogress, src, typ)
|
|
||||||
inprogress[dst] = srcEntry
|
|
||||||
}
|
|
||||||
|
|
||||||
func messageToEntry(m *dns.Msg, inprogress map[string]*ServiceEntry) *ServiceEntry {
|
|
||||||
var inp *ServiceEntry
|
|
||||||
|
|
||||||
for _, answer := range append(m.Answer, m.Extra...) {
|
|
||||||
// TODO(reddaly): Check that response corresponds to serviceAddr?
|
|
||||||
switch rr := answer.(type) {
|
|
||||||
case *dns.PTR:
|
|
||||||
// Create new entry for this
|
|
||||||
inp = ensureName(inprogress, rr.Ptr, rr.Hdr.Rrtype)
|
|
||||||
if inp.complete() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
case *dns.SRV:
|
|
||||||
// Check for a target mismatch
|
|
||||||
if rr.Target != rr.Hdr.Name {
|
|
||||||
alias(inprogress, rr.Hdr.Name, rr.Target, rr.Hdr.Rrtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the port
|
|
||||||
inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype)
|
|
||||||
if inp.complete() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
inp.Host = rr.Target
|
|
||||||
inp.Port = int(rr.Port)
|
|
||||||
case *dns.TXT:
|
|
||||||
// Pull out the txt
|
|
||||||
inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype)
|
|
||||||
if inp.complete() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
inp.Info = strings.Join(rr.Txt, "|")
|
|
||||||
inp.InfoFields = rr.Txt
|
|
||||||
inp.hasTXT = true
|
|
||||||
case *dns.A:
|
|
||||||
// Pull out the IP
|
|
||||||
inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype)
|
|
||||||
if inp.complete() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
inp.Addr = rr.A // @Deprecated
|
|
||||||
inp.AddrV4 = rr.A
|
|
||||||
case *dns.AAAA:
|
|
||||||
// Pull out the IP
|
|
||||||
inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype)
|
|
||||||
if inp.complete() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
inp.Addr = rr.AAAA // @Deprecated
|
|
||||||
inp.AddrV6 = rr.AAAA
|
|
||||||
}
|
|
||||||
|
|
||||||
if inp != nil {
|
|
||||||
inp.TTL = int(answer.Header().Ttl)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return inp
|
|
||||||
}
|
|
@ -1,84 +0,0 @@
|
|||||||
package mdns
|
|
||||||
|
|
||||||
import "github.com/miekg/dns"
|
|
||||||
|
|
||||||
// DNSSDService is a service that complies with the DNS-SD (RFC 6762) and MDNS
|
|
||||||
// (RFC 6762) specs for local, multicast-DNS-based discovery.
|
|
||||||
//
|
|
||||||
// DNSSDService implements the Zone interface and wraps an MDNSService instance.
|
|
||||||
// To deploy an mDNS service that is compliant with DNS-SD, it's recommended to
|
|
||||||
// register only the wrapped instance with the server.
|
|
||||||
//
|
|
||||||
// Example usage:
|
|
||||||
// service := &mdns.DNSSDService{
|
|
||||||
// MDNSService: &mdns.MDNSService{
|
|
||||||
// Instance: "My Foobar Service",
|
|
||||||
// Service: "_foobar._tcp",
|
|
||||||
// Port: 8000,
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// server, err := mdns.NewServer(&mdns.Config{Zone: service})
|
|
||||||
// if err != nil {
|
|
||||||
// log.Fatalf("Error creating server: %v", err)
|
|
||||||
// }
|
|
||||||
// defer server.Shutdown()
|
|
||||||
type DNSSDService struct {
|
|
||||||
MDNSService *MDNSService
|
|
||||||
}
|
|
||||||
|
|
||||||
// Records returns DNS records in response to a DNS question.
|
|
||||||
//
|
|
||||||
// This function returns the DNS response of the underlying MDNSService
|
|
||||||
// instance. It also returns a PTR record for a request for "
|
|
||||||
// _services._dns-sd._udp.<Domain>", as described in section 9 of RFC 6763
|
|
||||||
// ("Service Type Enumeration"), to allow browsing of the underlying MDNSService
|
|
||||||
// instance.
|
|
||||||
func (s *DNSSDService) Records(q dns.Question) []dns.RR {
|
|
||||||
var recs []dns.RR
|
|
||||||
if q.Name == "_services._dns-sd._udp."+s.MDNSService.Domain+"." {
|
|
||||||
recs = s.dnssdMetaQueryRecords(q)
|
|
||||||
}
|
|
||||||
return append(recs, s.MDNSService.Records(q)...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// dnssdMetaQueryRecords returns the DNS records in response to a "meta-query"
|
|
||||||
// issued to browse for DNS-SD services, as per section 9. of RFC6763.
|
|
||||||
//
|
|
||||||
// A meta-query has a name of the form "_services._dns-sd._udp.<Domain>" where
|
|
||||||
// Domain is a fully-qualified domain, such as "local."
|
|
||||||
func (s *DNSSDService) dnssdMetaQueryRecords(q dns.Question) []dns.RR {
|
|
||||||
// Intended behavior, as described in the RFC:
|
|
||||||
// ...it may be useful for network administrators to find the list of
|
|
||||||
// advertised service types on the network, even if those Service Names
|
|
||||||
// are just opaque identifiers and not particularly informative in
|
|
||||||
// isolation.
|
|
||||||
//
|
|
||||||
// For this purpose, a special meta-query is defined. A DNS query for PTR
|
|
||||||
// records with the name "_services._dns-sd._udp.<Domain>" yields a set of
|
|
||||||
// PTR records, where the rdata of each PTR record is the two-abel
|
|
||||||
// <Service> name, plus the same domain, e.g., "_http._tcp.<Domain>".
|
|
||||||
// Including the domain in the PTR rdata allows for slightly better name
|
|
||||||
// compression in Unicast DNS responses, but only the first two labels are
|
|
||||||
// relevant for the purposes of service type enumeration. These two-label
|
|
||||||
// service types can then be used to construct subsequent Service Instance
|
|
||||||
// Enumeration PTR queries, in this <Domain> or others, to discover
|
|
||||||
// instances of that service type.
|
|
||||||
return []dns.RR{
|
|
||||||
&dns.PTR{
|
|
||||||
Hdr: dns.RR_Header{
|
|
||||||
Name: q.Name,
|
|
||||||
Rrtype: dns.TypePTR,
|
|
||||||
Class: dns.ClassINET,
|
|
||||||
Ttl: defaultTTL,
|
|
||||||
},
|
|
||||||
Ptr: s.MDNSService.serviceAddr,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Announcement returns DNS records that should be broadcast during the initial
|
|
||||||
// availability of the service, as described in section 8.3 of RFC 6762.
|
|
||||||
// TODO(reddaly): Add this when Announcement is added to the mdns.Zone interface.
|
|
||||||
//func (s *DNSSDService) Announcement() []dns.RR {
|
|
||||||
// return s.MDNSService.Announcement()
|
|
||||||
//}
|
|
@ -1,69 +0,0 @@
|
|||||||
package mdns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockMDNSService struct{}
|
|
||||||
|
|
||||||
func (s *mockMDNSService) Records(q dns.Question) []dns.RR {
|
|
||||||
return []dns.RR{
|
|
||||||
&dns.PTR{
|
|
||||||
Hdr: dns.RR_Header{
|
|
||||||
Name: "fakerecord",
|
|
||||||
Rrtype: dns.TypePTR,
|
|
||||||
Class: dns.ClassINET,
|
|
||||||
Ttl: 42,
|
|
||||||
},
|
|
||||||
Ptr: "fake.local.",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *mockMDNSService) Announcement() []dns.RR {
|
|
||||||
return []dns.RR{
|
|
||||||
&dns.PTR{
|
|
||||||
Hdr: dns.RR_Header{
|
|
||||||
Name: "fakeannounce",
|
|
||||||
Rrtype: dns.TypePTR,
|
|
||||||
Class: dns.ClassINET,
|
|
||||||
Ttl: 42,
|
|
||||||
},
|
|
||||||
Ptr: "fake.local.",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDNSSDServiceRecords(t *testing.T) {
|
|
||||||
s := &DNSSDService{
|
|
||||||
MDNSService: &MDNSService{
|
|
||||||
serviceAddr: "_foobar._tcp.local.",
|
|
||||||
Domain: "local",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
q := dns.Question{
|
|
||||||
Name: "_services._dns-sd._udp.local.",
|
|
||||||
Qtype: dns.TypePTR,
|
|
||||||
Qclass: dns.ClassINET,
|
|
||||||
}
|
|
||||||
recs := s.Records(q)
|
|
||||||
if got, want := len(recs), 1; got != want {
|
|
||||||
t.Fatalf("s.Records(%v) returned %v records, want %v", q, got, want)
|
|
||||||
}
|
|
||||||
|
|
||||||
want := dns.RR(&dns.PTR{
|
|
||||||
Hdr: dns.RR_Header{
|
|
||||||
Name: "_services._dns-sd._udp.local.",
|
|
||||||
Rrtype: dns.TypePTR,
|
|
||||||
Class: dns.ClassINET,
|
|
||||||
Ttl: defaultTTL,
|
|
||||||
},
|
|
||||||
Ptr: "_foobar._tcp.local.",
|
|
||||||
})
|
|
||||||
if got := recs[0]; !reflect.DeepEqual(got, want) {
|
|
||||||
t.Errorf("s.Records()[0] = %v, want %v", got, want)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,527 +0,0 @@
|
|||||||
package mdns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"math/rand"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
"github.com/unistack-org/micro/v3/logger"
|
|
||||||
"golang.org/x/net/ipv4"
|
|
||||||
"golang.org/x/net/ipv6"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
mdnsGroupIPv4 = net.ParseIP("224.0.0.251")
|
|
||||||
mdnsGroupIPv6 = net.ParseIP("ff02::fb")
|
|
||||||
|
|
||||||
// mDNS wildcard addresses
|
|
||||||
mdnsWildcardAddrIPv4 = &net.UDPAddr{
|
|
||||||
IP: net.ParseIP("224.0.0.0"),
|
|
||||||
Port: 5353,
|
|
||||||
}
|
|
||||||
mdnsWildcardAddrIPv6 = &net.UDPAddr{
|
|
||||||
IP: net.ParseIP("ff02::"),
|
|
||||||
Port: 5353,
|
|
||||||
}
|
|
||||||
|
|
||||||
// mDNS endpoint addresses
|
|
||||||
ipv4Addr = &net.UDPAddr{
|
|
||||||
IP: mdnsGroupIPv4,
|
|
||||||
Port: 5353,
|
|
||||||
}
|
|
||||||
ipv6Addr = &net.UDPAddr{
|
|
||||||
IP: mdnsGroupIPv6,
|
|
||||||
Port: 5353,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
// GetMachineIP is a func which returns the outbound IP of this machine.
|
|
||||||
// Used by the server to determine whether to attempt send the response on a local address
|
|
||||||
type GetMachineIP func() net.IP
|
|
||||||
|
|
||||||
// Config is used to configure the mDNS server
|
|
||||||
type Config struct {
|
|
||||||
// Zone must be provided to support responding to queries
|
|
||||||
Zone Zone
|
|
||||||
|
|
||||||
// Iface if provided binds the multicast listener to the given
|
|
||||||
// interface. If not provided, the system default multicase interface
|
|
||||||
// is used.
|
|
||||||
Iface *net.Interface
|
|
||||||
|
|
||||||
// Port If it is not 0, replace the port 5353 with this port number.
|
|
||||||
Port int
|
|
||||||
|
|
||||||
// GetMachineIP is a function to return the IP of the local machine
|
|
||||||
GetMachineIP GetMachineIP
|
|
||||||
// LocalhostChecking if enabled asks the server to also send responses to 0.0.0.0 if the target IP
|
|
||||||
// is this host (as defined by GetMachineIP). Useful in case machine is on a VPN which blocks comms on non standard ports
|
|
||||||
LocalhostChecking bool
|
|
||||||
|
|
||||||
Context context.Context
|
|
||||||
}
|
|
||||||
|
|
||||||
// Server is an mDNS server used to listen for mDNS queries and respond if we
|
|
||||||
// have a matching local record
|
|
||||||
type Server struct {
|
|
||||||
config *Config
|
|
||||||
|
|
||||||
ipv4List *net.UDPConn
|
|
||||||
ipv6List *net.UDPConn
|
|
||||||
|
|
||||||
shutdown bool
|
|
||||||
shutdownCh chan struct{}
|
|
||||||
shutdownLock sync.Mutex
|
|
||||||
wg sync.WaitGroup
|
|
||||||
|
|
||||||
outboundIP net.IP
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewServer is used to create a new mDNS server from a config
|
|
||||||
func NewServer(config *Config) (*Server, error) {
|
|
||||||
setCustomPort(config.Port)
|
|
||||||
|
|
||||||
// Create the listeners
|
|
||||||
// Create wildcard connections (because :5353 can be already taken by other apps)
|
|
||||||
ipv4List, _ := net.ListenUDP("udp4", mdnsWildcardAddrIPv4)
|
|
||||||
ipv6List, _ := net.ListenUDP("udp6", mdnsWildcardAddrIPv6)
|
|
||||||
if ipv4List == nil && ipv6List == nil {
|
|
||||||
return nil, fmt.Errorf("[ERR] mdns: Failed to bind to any udp port!")
|
|
||||||
}
|
|
||||||
|
|
||||||
if ipv4List == nil {
|
|
||||||
ipv4List = &net.UDPConn{}
|
|
||||||
}
|
|
||||||
if ipv6List == nil {
|
|
||||||
ipv6List = &net.UDPConn{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Join multicast groups to receive announcements
|
|
||||||
p1 := ipv4.NewPacketConn(ipv4List)
|
|
||||||
p2 := ipv6.NewPacketConn(ipv6List)
|
|
||||||
p1.SetMulticastLoopback(true)
|
|
||||||
p2.SetMulticastLoopback(true)
|
|
||||||
|
|
||||||
if config.Iface != nil {
|
|
||||||
if err := p1.JoinGroup(config.Iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := p2.JoinGroup(config.Iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ifaces, err := net.Interfaces()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
errCount1, errCount2 := 0, 0
|
|
||||||
for _, iface := range ifaces {
|
|
||||||
if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
|
|
||||||
errCount1++
|
|
||||||
}
|
|
||||||
if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
|
|
||||||
errCount2++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(ifaces) == errCount1 && len(ifaces) == errCount2 {
|
|
||||||
return nil, fmt.Errorf("Failed to join multicast group on all interfaces!")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ipFunc := getOutboundIP
|
|
||||||
if config.GetMachineIP != nil {
|
|
||||||
ipFunc = config.GetMachineIP
|
|
||||||
}
|
|
||||||
|
|
||||||
s := &Server{
|
|
||||||
config: config,
|
|
||||||
ipv4List: ipv4List,
|
|
||||||
ipv6List: ipv6List,
|
|
||||||
shutdownCh: make(chan struct{}),
|
|
||||||
outboundIP: ipFunc(),
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.config.Context == nil {
|
|
||||||
s.config.Context = context.Background()
|
|
||||||
}
|
|
||||||
|
|
||||||
go s.recv(s.ipv4List)
|
|
||||||
go s.recv(s.ipv6List)
|
|
||||||
|
|
||||||
s.wg.Add(1)
|
|
||||||
go s.probe()
|
|
||||||
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Shutdown is used to shutdown the listener
|
|
||||||
func (s *Server) Shutdown() error {
|
|
||||||
s.shutdownLock.Lock()
|
|
||||||
defer s.shutdownLock.Unlock()
|
|
||||||
|
|
||||||
if s.shutdown {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
s.shutdown = true
|
|
||||||
close(s.shutdownCh)
|
|
||||||
if err := s.unregister(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.ipv4List != nil {
|
|
||||||
s.ipv4List.Close()
|
|
||||||
}
|
|
||||||
if s.ipv6List != nil {
|
|
||||||
s.ipv6List.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
s.wg.Wait()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// recv is a long running routine to receive packets from an interface
|
|
||||||
func (s *Server) recv(c *net.UDPConn) {
|
|
||||||
if c == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
buf := make([]byte, 65536)
|
|
||||||
for {
|
|
||||||
s.shutdownLock.Lock()
|
|
||||||
if s.shutdown {
|
|
||||||
s.shutdownLock.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.shutdownLock.Unlock()
|
|
||||||
n, from, err := c.ReadFrom(buf)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := s.parsePacket(buf[:n], from); err != nil {
|
|
||||||
logger.Errorf(s.config.Context, "[ERR] mdns: Failed to handle query: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// parsePacket is used to parse an incoming packet
|
|
||||||
func (s *Server) parsePacket(packet []byte, from net.Addr) error {
|
|
||||||
var msg dns.Msg
|
|
||||||
if err := msg.Unpack(packet); err != nil {
|
|
||||||
logger.Errorf(s.config.Context, "[ERR] mdns: Failed to unpack packet: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// TODO: This is a bit of a hack
|
|
||||||
// We decided to ignore some mDNS answers for the time being
|
|
||||||
// See: https://tools.ietf.org/html/rfc6762#section-7.2
|
|
||||||
msg.Truncated = false
|
|
||||||
return s.handleQuery(&msg, from)
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleQuery is used to handle an incoming query
|
|
||||||
func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error {
|
|
||||||
if query.Opcode != dns.OpcodeQuery {
|
|
||||||
// "In both multicast query and multicast response messages, the OPCODE MUST
|
|
||||||
// be zero on transmission (only standard queries are currently supported
|
|
||||||
// over multicast). Multicast DNS messages received with an OPCODE other
|
|
||||||
// than zero MUST be silently ignored." Note: OpcodeQuery == 0
|
|
||||||
return fmt.Errorf("mdns: received query with non-zero Opcode %v: %v", query.Opcode, *query)
|
|
||||||
}
|
|
||||||
if query.Rcode != 0 {
|
|
||||||
// "In both multicast query and multicast response messages, the Response
|
|
||||||
// Code MUST be zero on transmission. Multicast DNS messages received with
|
|
||||||
// non-zero Response Codes MUST be silently ignored."
|
|
||||||
return fmt.Errorf("mdns: received query with non-zero Rcode %v: %v", query.Rcode, *query)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(reddaly): Handle "TC (Truncated) Bit":
|
|
||||||
// In query messages, if the TC bit is set, it means that additional
|
|
||||||
// Known-Answer records may be following shortly. A responder SHOULD
|
|
||||||
// record this fact, and wait for those additional Known-Answer records,
|
|
||||||
// before deciding whether to respond. If the TC bit is clear, it means
|
|
||||||
// that the querying host has no additional Known Answers.
|
|
||||||
if query.Truncated {
|
|
||||||
return fmt.Errorf("[ERR] mdns: support for DNS requests with high truncated bit not implemented: %v", *query)
|
|
||||||
}
|
|
||||||
|
|
||||||
unicastAnswer := make([]dns.RR, 0, len(query.Question))
|
|
||||||
multicastAnswer := make([]dns.RR, 0, len(query.Question))
|
|
||||||
|
|
||||||
// Handle each question
|
|
||||||
for _, q := range query.Question {
|
|
||||||
mrecs, urecs := s.handleQuestion(q)
|
|
||||||
multicastAnswer = append(multicastAnswer, mrecs...)
|
|
||||||
unicastAnswer = append(unicastAnswer, urecs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// See section 18 of RFC 6762 for rules about DNS headers.
|
|
||||||
resp := func(unicast bool) *dns.Msg {
|
|
||||||
// 18.1: ID (Query Identifier)
|
|
||||||
// 0 for multicast response, query.Id for unicast response
|
|
||||||
id := uint16(0)
|
|
||||||
if unicast {
|
|
||||||
id = query.Id
|
|
||||||
}
|
|
||||||
|
|
||||||
var answer []dns.RR
|
|
||||||
if unicast {
|
|
||||||
answer = unicastAnswer
|
|
||||||
} else {
|
|
||||||
answer = multicastAnswer
|
|
||||||
}
|
|
||||||
if len(answer) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return &dns.Msg{
|
|
||||||
MsgHdr: dns.MsgHdr{
|
|
||||||
Id: id,
|
|
||||||
|
|
||||||
// 18.2: QR (Query/Response) Bit - must be set to 1 in response.
|
|
||||||
Response: true,
|
|
||||||
|
|
||||||
// 18.3: OPCODE - must be zero in response (OpcodeQuery == 0)
|
|
||||||
Opcode: dns.OpcodeQuery,
|
|
||||||
|
|
||||||
// 18.4: AA (Authoritative Answer) Bit - must be set to 1
|
|
||||||
Authoritative: true,
|
|
||||||
|
|
||||||
// The following fields must all be set to 0:
|
|
||||||
// 18.5: TC (TRUNCATED) Bit
|
|
||||||
// 18.6: RD (Recursion Desired) Bit
|
|
||||||
// 18.7: RA (Recursion Available) Bit
|
|
||||||
// 18.8: Z (Zero) Bit
|
|
||||||
// 18.9: AD (Authentic Data) Bit
|
|
||||||
// 18.10: CD (Checking Disabled) Bit
|
|
||||||
// 18.11: RCODE (Response Code)
|
|
||||||
},
|
|
||||||
// 18.12 pertains to questions (handled by handleQuestion)
|
|
||||||
// 18.13 pertains to resource records (handled by handleQuestion)
|
|
||||||
|
|
||||||
// 18.14: Name Compression - responses should be compressed (though see
|
|
||||||
// caveats in the RFC), so set the Compress bit (part of the dns library
|
|
||||||
// API, not part of the DNS packet) to true.
|
|
||||||
Compress: true,
|
|
||||||
Question: query.Question,
|
|
||||||
Answer: answer,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if mresp := resp(false); mresp != nil {
|
|
||||||
if err := s.sendResponse(mresp, from); err != nil {
|
|
||||||
return fmt.Errorf("mdns: error sending multicast response: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if uresp := resp(true); uresp != nil {
|
|
||||||
if err := s.sendResponse(uresp, from); err != nil {
|
|
||||||
return fmt.Errorf("mdns: error sending unicast response: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleQuestion is used to handle an incoming question
|
|
||||||
//
|
|
||||||
// The response to a question may be transmitted over multicast, unicast, or
|
|
||||||
// both. The return values are DNS records for each transmission type.
|
|
||||||
func (s *Server) handleQuestion(q dns.Question) (multicastRecs, unicastRecs []dns.RR) {
|
|
||||||
records := s.config.Zone.Records(q)
|
|
||||||
if len(records) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle unicast and multicast responses.
|
|
||||||
// TODO(reddaly): The decision about sending over unicast vs. multicast is not
|
|
||||||
// yet fully compliant with RFC 6762. For example, the unicast bit should be
|
|
||||||
// ignored if the records in question are close to TTL expiration. For now,
|
|
||||||
// we just use the unicast bit to make the decision, as per the spec:
|
|
||||||
// RFC 6762, section 18.12. Repurposing of Top Bit of qclass in Question
|
|
||||||
// Section
|
|
||||||
//
|
|
||||||
// In the Question Section of a Multicast DNS query, the top bit of the
|
|
||||||
// qclass field is used to indicate that unicast responses are preferred
|
|
||||||
// for this particular question. (See Section 5.4.)
|
|
||||||
if q.Qclass&(1<<15) != 0 {
|
|
||||||
return nil, records
|
|
||||||
}
|
|
||||||
return records, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) probe() {
|
|
||||||
defer s.wg.Done()
|
|
||||||
|
|
||||||
sd, ok := s.config.Zone.(*MDNSService)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
name := fmt.Sprintf("%s.%s.%s.", sd.Instance, trimDot(sd.Service), trimDot(sd.Domain))
|
|
||||||
|
|
||||||
q := new(dns.Msg)
|
|
||||||
q.SetQuestion(name, dns.TypePTR)
|
|
||||||
q.RecursionDesired = false
|
|
||||||
|
|
||||||
srv := &dns.SRV{
|
|
||||||
Hdr: dns.RR_Header{
|
|
||||||
Name: name,
|
|
||||||
Rrtype: dns.TypeSRV,
|
|
||||||
Class: dns.ClassINET,
|
|
||||||
Ttl: defaultTTL,
|
|
||||||
},
|
|
||||||
Priority: 0,
|
|
||||||
Weight: 0,
|
|
||||||
Port: uint16(sd.Port),
|
|
||||||
Target: sd.HostName,
|
|
||||||
}
|
|
||||||
txt := &dns.TXT{
|
|
||||||
Hdr: dns.RR_Header{
|
|
||||||
Name: name,
|
|
||||||
Rrtype: dns.TypeTXT,
|
|
||||||
Class: dns.ClassINET,
|
|
||||||
Ttl: defaultTTL,
|
|
||||||
},
|
|
||||||
Txt: sd.TXT,
|
|
||||||
}
|
|
||||||
q.Ns = []dns.RR{srv, txt}
|
|
||||||
|
|
||||||
randomizer := rand.New(rand.NewSource(time.Now().UnixNano()))
|
|
||||||
|
|
||||||
for i := 0; i < 3; i++ {
|
|
||||||
if err := s.SendMulticast(q); err != nil {
|
|
||||||
logger.Errorf(s.config.Context, "[ERR] mdns: failed to send probe: %v", err)
|
|
||||||
}
|
|
||||||
time.Sleep(time.Duration(randomizer.Intn(250)) * time.Millisecond)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := new(dns.Msg)
|
|
||||||
resp.MsgHdr.Response = true
|
|
||||||
|
|
||||||
// set for query
|
|
||||||
q.SetQuestion(name, dns.TypeANY)
|
|
||||||
|
|
||||||
resp.Answer = append(resp.Answer, s.config.Zone.Records(q.Question[0])...)
|
|
||||||
|
|
||||||
// reset
|
|
||||||
q.SetQuestion(name, dns.TypePTR)
|
|
||||||
|
|
||||||
// From RFC6762
|
|
||||||
// The Multicast DNS responder MUST send at least two unsolicited
|
|
||||||
// responses, one second apart. To provide increased robustness against
|
|
||||||
// packet loss, a responder MAY send up to eight unsolicited responses,
|
|
||||||
// provided that the interval between unsolicited responses increases by
|
|
||||||
// at least a factor of two with every response sent.
|
|
||||||
timeout := 1 * time.Second
|
|
||||||
timer := time.NewTimer(timeout)
|
|
||||||
for i := 0; i < 3; i++ {
|
|
||||||
if err := s.SendMulticast(resp); err != nil {
|
|
||||||
logger.Errorf(s.config.Context, "[ERR] mdns: failed to send announcement: %v", err)
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-timer.C:
|
|
||||||
timeout *= 2
|
|
||||||
timer.Reset(timeout)
|
|
||||||
case <-s.shutdownCh:
|
|
||||||
timer.Stop()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendMulticast us used to send a multicast response packet
|
|
||||||
func (s *Server) SendMulticast(msg *dns.Msg) error {
|
|
||||||
buf, err := msg.Pack()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if s.ipv4List != nil {
|
|
||||||
s.ipv4List.WriteToUDP(buf, ipv4Addr)
|
|
||||||
}
|
|
||||||
if s.ipv6List != nil {
|
|
||||||
s.ipv6List.WriteToUDP(buf, ipv6Addr)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// sendResponse is used to send a response packet
|
|
||||||
func (s *Server) sendResponse(resp *dns.Msg, from net.Addr) error {
|
|
||||||
// TODO(reddaly): Respect the unicast argument, and allow sending responses
|
|
||||||
// over multicast.
|
|
||||||
buf, err := resp.Pack()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Determine the socket to send from
|
|
||||||
addr := from.(*net.UDPAddr)
|
|
||||||
conn := s.ipv4List
|
|
||||||
backupTarget := net.IPv4zero
|
|
||||||
|
|
||||||
if addr.IP.To4() == nil {
|
|
||||||
conn = s.ipv6List
|
|
||||||
backupTarget = net.IPv6zero
|
|
||||||
}
|
|
||||||
_, err = conn.WriteToUDP(buf, addr)
|
|
||||||
// If the address we're responding to is this machine then we can also attempt sending on 0.0.0.0
|
|
||||||
// This covers the case where this machine is using a VPN and certain ports are blocked so the response never gets there
|
|
||||||
// Sending two responses is OK
|
|
||||||
if s.config.LocalhostChecking && addr.IP.Equal(s.outboundIP) {
|
|
||||||
// ignore any errors, this is best efforts
|
|
||||||
conn.WriteToUDP(buf, &net.UDPAddr{IP: backupTarget, Port: addr.Port})
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) unregister() error {
|
|
||||||
sd, ok := s.config.Zone.(*MDNSService)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
atomic.StoreUint32(&sd.TTL, 0)
|
|
||||||
name := fmt.Sprintf("%s.%s.%s.", sd.Instance, trimDot(sd.Service), trimDot(sd.Domain))
|
|
||||||
|
|
||||||
q := new(dns.Msg)
|
|
||||||
q.SetQuestion(name, dns.TypeANY)
|
|
||||||
|
|
||||||
resp := new(dns.Msg)
|
|
||||||
resp.MsgHdr.Response = true
|
|
||||||
resp.Answer = append(resp.Answer, s.config.Zone.Records(q.Question[0])...)
|
|
||||||
|
|
||||||
return s.SendMulticast(resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func setCustomPort(port int) {
|
|
||||||
if port != 0 {
|
|
||||||
if mdnsWildcardAddrIPv4.Port != port {
|
|
||||||
mdnsWildcardAddrIPv4.Port = port
|
|
||||||
}
|
|
||||||
if mdnsWildcardAddrIPv6.Port != port {
|
|
||||||
mdnsWildcardAddrIPv6.Port = port
|
|
||||||
}
|
|
||||||
if ipv4Addr.Port != port {
|
|
||||||
ipv4Addr.Port = port
|
|
||||||
}
|
|
||||||
if ipv6Addr.Port != port {
|
|
||||||
ipv6Addr.Port = port
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// getOutboundIP returns the IP address of this machine as seen when dialling out
|
|
||||||
func getOutboundIP() net.IP {
|
|
||||||
conn, err := net.Dial("udp", "8.8.8.8:80")
|
|
||||||
if err != nil {
|
|
||||||
// no net connectivity maybe so fallback
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
localAddr := conn.LocalAddr().(*net.UDPAddr)
|
|
||||||
|
|
||||||
return localAddr.IP
|
|
||||||
}
|
|
@ -1,61 +0,0 @@
|
|||||||
package mdns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestServer_StartStop(t *testing.T) {
|
|
||||||
s := makeService(t)
|
|
||||||
serv, err := NewServer(&Config{Zone: s, LocalhostChecking: true})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("err: %v", err)
|
|
||||||
}
|
|
||||||
defer serv.Shutdown()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServer_Lookup(t *testing.T) {
|
|
||||||
serv, err := NewServer(&Config{Zone: makeServiceWithServiceName(t, "_foobar._tcp"), LocalhostChecking: true})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("err: %v", err)
|
|
||||||
}
|
|
||||||
defer serv.Shutdown()
|
|
||||||
|
|
||||||
entries := make(chan *ServiceEntry, 1)
|
|
||||||
found := false
|
|
||||||
doneCh := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case e := <-entries:
|
|
||||||
if e.Name != "hostname._foobar._tcp.local." {
|
|
||||||
t.Fatalf("bad: %v", e)
|
|
||||||
}
|
|
||||||
if e.Port != 80 {
|
|
||||||
t.Fatalf("bad: %v", e)
|
|
||||||
}
|
|
||||||
if e.Info != "Local web server" {
|
|
||||||
t.Fatalf("bad: %v", e)
|
|
||||||
}
|
|
||||||
found = true
|
|
||||||
|
|
||||||
case <-time.After(80 * time.Millisecond):
|
|
||||||
t.Fatalf("timeout")
|
|
||||||
}
|
|
||||||
close(doneCh)
|
|
||||||
}()
|
|
||||||
|
|
||||||
params := &QueryParam{
|
|
||||||
Service: "_foobar._tcp",
|
|
||||||
Domain: "local",
|
|
||||||
Timeout: 50 * time.Millisecond,
|
|
||||||
Entries: entries,
|
|
||||||
}
|
|
||||||
err = Query(params)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("err: %v", err)
|
|
||||||
}
|
|
||||||
<-doneCh
|
|
||||||
if !found {
|
|
||||||
t.Fatalf("record not found")
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,309 +0,0 @@
|
|||||||
package mdns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"sync/atomic"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// defaultTTL is the default TTL value in returned DNS records in seconds.
|
|
||||||
defaultTTL = 120
|
|
||||||
)
|
|
||||||
|
|
||||||
// Zone is the interface used to integrate with the server and
|
|
||||||
// to serve records dynamically
|
|
||||||
type Zone interface {
|
|
||||||
// Records returns DNS records in response to a DNS question.
|
|
||||||
Records(q dns.Question) []dns.RR
|
|
||||||
}
|
|
||||||
|
|
||||||
// MDNSService is used to export a named service by implementing a Zone
|
|
||||||
type MDNSService struct {
|
|
||||||
Instance string // Instance name (e.g. "hostService name")
|
|
||||||
Service string // Service name (e.g. "_http._tcp.")
|
|
||||||
Domain string // If blank, assumes "local"
|
|
||||||
HostName string // Host machine DNS name (e.g. "mymachine.net.")
|
|
||||||
Port int // Service Port
|
|
||||||
IPs []net.IP // IP addresses for the service's host
|
|
||||||
TXT []string // Service TXT records
|
|
||||||
TTL uint32
|
|
||||||
serviceAddr string // Fully qualified service address
|
|
||||||
instanceAddr string // Fully qualified instance address
|
|
||||||
enumAddr string // _services._dns-sd._udp.<domain>
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateFQDN returns an error if the passed string is not a fully qualified
|
|
||||||
// hdomain name (more specifically, a hostname).
|
|
||||||
func validateFQDN(s string) error {
|
|
||||||
if len(s) == 0 {
|
|
||||||
return fmt.Errorf("FQDN must not be blank")
|
|
||||||
}
|
|
||||||
if s[len(s)-1] != '.' {
|
|
||||||
return fmt.Errorf("FQDN must end in period: %s", s)
|
|
||||||
}
|
|
||||||
// TODO(reddaly): Perform full validation.
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMDNSService returns a new instance of MDNSService.
|
|
||||||
//
|
|
||||||
// If domain, hostName, or ips is set to the zero value, then a default value
|
|
||||||
// will be inferred from the operating system.
|
|
||||||
//
|
|
||||||
// TODO(reddaly): This interface may need to change to account for "unique
|
|
||||||
// record" conflict rules of the mDNS protocol. Upon startup, the server should
|
|
||||||
// check to ensure that the instance name does not conflict with other instance
|
|
||||||
// names, and, if required, select a new name. There may also be conflicting
|
|
||||||
// hostName A/AAAA records.
|
|
||||||
func NewMDNSService(instance, service, domain, hostName string, port int, ips []net.IP, txt []string) (*MDNSService, error) {
|
|
||||||
// Sanity check inputs
|
|
||||||
if instance == "" {
|
|
||||||
return nil, fmt.Errorf("missing service instance name")
|
|
||||||
}
|
|
||||||
if service == "" {
|
|
||||||
return nil, fmt.Errorf("missing service name")
|
|
||||||
}
|
|
||||||
if port == 0 {
|
|
||||||
return nil, fmt.Errorf("missing service port")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set default domain
|
|
||||||
if domain == "" {
|
|
||||||
domain = "local."
|
|
||||||
}
|
|
||||||
if err := validateFQDN(domain); err != nil {
|
|
||||||
return nil, fmt.Errorf("domain %q is not a fully-qualified domain name: %v", domain, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get host information if no host is specified.
|
|
||||||
if hostName == "" {
|
|
||||||
var err error
|
|
||||||
hostName, err = os.Hostname()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("could not determine host: %v", err)
|
|
||||||
}
|
|
||||||
hostName = fmt.Sprintf("%s.", hostName)
|
|
||||||
}
|
|
||||||
if err := validateFQDN(hostName); err != nil {
|
|
||||||
return nil, fmt.Errorf("hostName %q is not a fully-qualified domain name: %v", hostName, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(ips) == 0 {
|
|
||||||
var err error
|
|
||||||
ips, err = net.LookupIP(trimDot(hostName))
|
|
||||||
if err != nil {
|
|
||||||
// Try appending the host domain suffix and lookup again
|
|
||||||
// (required for Linux-based hosts)
|
|
||||||
tmpHostName := fmt.Sprintf("%s%s", hostName, domain)
|
|
||||||
|
|
||||||
ips, err = net.LookupIP(trimDot(tmpHostName))
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("could not determine host IP addresses for %s", hostName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, ip := range ips {
|
|
||||||
if ip.To4() == nil && ip.To16() == nil {
|
|
||||||
return nil, fmt.Errorf("invalid IP address in IPs list: %v", ip)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &MDNSService{
|
|
||||||
Instance: instance,
|
|
||||||
Service: service,
|
|
||||||
Domain: domain,
|
|
||||||
HostName: hostName,
|
|
||||||
Port: port,
|
|
||||||
IPs: ips,
|
|
||||||
TXT: txt,
|
|
||||||
TTL: defaultTTL,
|
|
||||||
serviceAddr: fmt.Sprintf("%s.%s.", trimDot(service), trimDot(domain)),
|
|
||||||
instanceAddr: fmt.Sprintf("%s.%s.%s.", instance, trimDot(service), trimDot(domain)),
|
|
||||||
enumAddr: fmt.Sprintf("_services._dns-sd._udp.%s.", trimDot(domain)),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// trimDot is used to trim the dots from the start or end of a string
|
|
||||||
func trimDot(s string) string {
|
|
||||||
return strings.Trim(s, ".")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Records returns DNS records in response to a DNS question.
|
|
||||||
func (m *MDNSService) Records(q dns.Question) []dns.RR {
|
|
||||||
switch q.Name {
|
|
||||||
case m.enumAddr:
|
|
||||||
return m.serviceEnum(q)
|
|
||||||
case m.serviceAddr:
|
|
||||||
return m.serviceRecords(q)
|
|
||||||
case m.instanceAddr:
|
|
||||||
return m.instanceRecords(q)
|
|
||||||
case m.HostName:
|
|
||||||
if q.Qtype == dns.TypeA || q.Qtype == dns.TypeAAAA {
|
|
||||||
return m.instanceRecords(q)
|
|
||||||
}
|
|
||||||
fallthrough
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MDNSService) serviceEnum(q dns.Question) []dns.RR {
|
|
||||||
switch q.Qtype {
|
|
||||||
case dns.TypeANY:
|
|
||||||
fallthrough
|
|
||||||
case dns.TypePTR:
|
|
||||||
rr := &dns.PTR{
|
|
||||||
Hdr: dns.RR_Header{
|
|
||||||
Name: q.Name,
|
|
||||||
Rrtype: dns.TypePTR,
|
|
||||||
Class: dns.ClassINET,
|
|
||||||
Ttl: atomic.LoadUint32(&m.TTL),
|
|
||||||
},
|
|
||||||
Ptr: m.serviceAddr,
|
|
||||||
}
|
|
||||||
return []dns.RR{rr}
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// serviceRecords is called when the query matches the service name
|
|
||||||
func (m *MDNSService) serviceRecords(q dns.Question) []dns.RR {
|
|
||||||
switch q.Qtype {
|
|
||||||
case dns.TypeANY:
|
|
||||||
fallthrough
|
|
||||||
case dns.TypePTR:
|
|
||||||
// Build a PTR response for the service
|
|
||||||
rr := &dns.PTR{
|
|
||||||
Hdr: dns.RR_Header{
|
|
||||||
Name: q.Name,
|
|
||||||
Rrtype: dns.TypePTR,
|
|
||||||
Class: dns.ClassINET,
|
|
||||||
Ttl: atomic.LoadUint32(&m.TTL),
|
|
||||||
},
|
|
||||||
Ptr: m.instanceAddr,
|
|
||||||
}
|
|
||||||
servRec := []dns.RR{rr}
|
|
||||||
|
|
||||||
// Get the instance records
|
|
||||||
instRecs := m.instanceRecords(dns.Question{
|
|
||||||
Name: m.instanceAddr,
|
|
||||||
Qtype: dns.TypeANY,
|
|
||||||
})
|
|
||||||
|
|
||||||
// Return the service record with the instance records
|
|
||||||
return append(servRec, instRecs...)
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// serviceRecords is called when the query matches the instance name
|
|
||||||
func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR {
|
|
||||||
switch q.Qtype {
|
|
||||||
case dns.TypeANY:
|
|
||||||
// Get the SRV, which includes A and AAAA
|
|
||||||
recs := m.instanceRecords(dns.Question{
|
|
||||||
Name: m.instanceAddr,
|
|
||||||
Qtype: dns.TypeSRV,
|
|
||||||
})
|
|
||||||
|
|
||||||
// Add the TXT record
|
|
||||||
recs = append(recs, m.instanceRecords(dns.Question{
|
|
||||||
Name: m.instanceAddr,
|
|
||||||
Qtype: dns.TypeTXT,
|
|
||||||
})...)
|
|
||||||
return recs
|
|
||||||
|
|
||||||
case dns.TypeA:
|
|
||||||
var rr []dns.RR
|
|
||||||
for _, ip := range m.IPs {
|
|
||||||
if ip4 := ip.To4(); ip4 != nil {
|
|
||||||
rr = append(rr, &dns.A{
|
|
||||||
Hdr: dns.RR_Header{
|
|
||||||
Name: m.HostName,
|
|
||||||
Rrtype: dns.TypeA,
|
|
||||||
Class: dns.ClassINET,
|
|
||||||
Ttl: atomic.LoadUint32(&m.TTL),
|
|
||||||
},
|
|
||||||
A: ip4,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return rr
|
|
||||||
|
|
||||||
case dns.TypeAAAA:
|
|
||||||
var rr []dns.RR
|
|
||||||
for _, ip := range m.IPs {
|
|
||||||
if ip.To4() != nil {
|
|
||||||
// TODO(reddaly): IPv4 addresses could be encoded in IPv6 format and
|
|
||||||
// putinto AAAA records, but the current logic puts ipv4-encodable
|
|
||||||
// addresses into the A records exclusively. Perhaps this should be
|
|
||||||
// configurable?
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if ip16 := ip.To16(); ip16 != nil {
|
|
||||||
rr = append(rr, &dns.AAAA{
|
|
||||||
Hdr: dns.RR_Header{
|
|
||||||
Name: m.HostName,
|
|
||||||
Rrtype: dns.TypeAAAA,
|
|
||||||
Class: dns.ClassINET,
|
|
||||||
Ttl: atomic.LoadUint32(&m.TTL),
|
|
||||||
},
|
|
||||||
AAAA: ip16,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return rr
|
|
||||||
|
|
||||||
case dns.TypeSRV:
|
|
||||||
// Create the SRV Record
|
|
||||||
srv := &dns.SRV{
|
|
||||||
Hdr: dns.RR_Header{
|
|
||||||
Name: q.Name,
|
|
||||||
Rrtype: dns.TypeSRV,
|
|
||||||
Class: dns.ClassINET,
|
|
||||||
Ttl: atomic.LoadUint32(&m.TTL),
|
|
||||||
},
|
|
||||||
Priority: 10,
|
|
||||||
Weight: 1,
|
|
||||||
Port: uint16(m.Port),
|
|
||||||
Target: m.HostName,
|
|
||||||
}
|
|
||||||
recs := []dns.RR{srv}
|
|
||||||
|
|
||||||
// Add the A record
|
|
||||||
recs = append(recs, m.instanceRecords(dns.Question{
|
|
||||||
Name: m.instanceAddr,
|
|
||||||
Qtype: dns.TypeA,
|
|
||||||
})...)
|
|
||||||
|
|
||||||
// Add the AAAA record
|
|
||||||
recs = append(recs, m.instanceRecords(dns.Question{
|
|
||||||
Name: m.instanceAddr,
|
|
||||||
Qtype: dns.TypeAAAA,
|
|
||||||
})...)
|
|
||||||
return recs
|
|
||||||
|
|
||||||
case dns.TypeTXT:
|
|
||||||
txt := &dns.TXT{
|
|
||||||
Hdr: dns.RR_Header{
|
|
||||||
Name: q.Name,
|
|
||||||
Rrtype: dns.TypeTXT,
|
|
||||||
Class: dns.ClassINET,
|
|
||||||
Ttl: atomic.LoadUint32(&m.TTL),
|
|
||||||
},
|
|
||||||
Txt: m.TXT,
|
|
||||||
}
|
|
||||||
return []dns.RR{txt}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -1,275 +0,0 @@
|
|||||||
package mdns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"net"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
)
|
|
||||||
|
|
||||||
func makeService(t *testing.T) *MDNSService {
|
|
||||||
return makeServiceWithServiceName(t, "_http._tcp")
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeServiceWithServiceName(t *testing.T, service string) *MDNSService {
|
|
||||||
m, err := NewMDNSService(
|
|
||||||
"hostname",
|
|
||||||
service,
|
|
||||||
"local.",
|
|
||||||
"testhost.",
|
|
||||||
80, // port
|
|
||||||
[]net.IP{net.IP([]byte{192, 168, 0, 42}), net.ParseIP("2620:0:1000:1900:b0c2:d0b2:c411:18bc")},
|
|
||||||
[]string{"Local web server"}) // TXT
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("err: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewMDNSService_BadParams(t *testing.T) {
|
|
||||||
for _, test := range []struct {
|
|
||||||
testName string
|
|
||||||
hostName string
|
|
||||||
domain string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"NewMDNSService should fail when passed hostName that is not a legal fully-qualified domain name",
|
|
||||||
"hostname", // not legal FQDN - should be "hostname." or "hostname.local.", etc.
|
|
||||||
"local.", // legal
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"NewMDNSService should fail when passed domain that is not a legal fully-qualified domain name",
|
|
||||||
"hostname.", // legal
|
|
||||||
"local", // should be "local."
|
|
||||||
},
|
|
||||||
} {
|
|
||||||
_, err := NewMDNSService(
|
|
||||||
"instance name",
|
|
||||||
"_http._tcp",
|
|
||||||
test.domain,
|
|
||||||
test.hostName,
|
|
||||||
80, // port
|
|
||||||
[]net.IP{net.IP([]byte{192, 168, 0, 42})},
|
|
||||||
[]string{"Local web server"}) // TXT
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("%s: error expected, but got none", test.testName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMDNSService_BadAddr(t *testing.T) {
|
|
||||||
s := makeService(t)
|
|
||||||
q := dns.Question{
|
|
||||||
Name: "random",
|
|
||||||
Qtype: dns.TypeANY,
|
|
||||||
}
|
|
||||||
recs := s.Records(q)
|
|
||||||
if len(recs) != 0 {
|
|
||||||
t.Fatalf("bad: %v", recs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMDNSService_ServiceAddr(t *testing.T) {
|
|
||||||
s := makeService(t)
|
|
||||||
q := dns.Question{
|
|
||||||
Name: "_http._tcp.local.",
|
|
||||||
Qtype: dns.TypeANY,
|
|
||||||
}
|
|
||||||
recs := s.Records(q)
|
|
||||||
if got, want := len(recs), 5; got != want {
|
|
||||||
t.Fatalf("got %d records, want %d: %v", got, want, recs)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ptr, ok := recs[0].(*dns.PTR); !ok {
|
|
||||||
t.Errorf("recs[0] should be PTR record, got: %v, all records: %v", recs[0], recs)
|
|
||||||
} else if got, want := ptr.Ptr, "hostname._http._tcp.local."; got != want {
|
|
||||||
t.Fatalf("bad PTR record %v: got %v, want %v", ptr, got, want)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := recs[1].(*dns.SRV); !ok {
|
|
||||||
t.Errorf("recs[1] should be SRV record, got: %v, all reccords: %v", recs[1], recs)
|
|
||||||
}
|
|
||||||
if _, ok := recs[2].(*dns.A); !ok {
|
|
||||||
t.Errorf("recs[2] should be A record, got: %v, all records: %v", recs[2], recs)
|
|
||||||
}
|
|
||||||
if _, ok := recs[3].(*dns.AAAA); !ok {
|
|
||||||
t.Errorf("recs[3] should be AAAA record, got: %v, all records: %v", recs[3], recs)
|
|
||||||
}
|
|
||||||
if _, ok := recs[4].(*dns.TXT); !ok {
|
|
||||||
t.Errorf("recs[4] should be TXT record, got: %v, all records: %v", recs[4], recs)
|
|
||||||
}
|
|
||||||
|
|
||||||
q.Qtype = dns.TypePTR
|
|
||||||
if recs2 := s.Records(q); !reflect.DeepEqual(recs, recs2) {
|
|
||||||
t.Fatalf("PTR question should return same result as ANY question: ANY => %v, PTR => %v", recs, recs2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMDNSService_InstanceAddr_ANY(t *testing.T) {
|
|
||||||
s := makeService(t)
|
|
||||||
q := dns.Question{
|
|
||||||
Name: "hostname._http._tcp.local.",
|
|
||||||
Qtype: dns.TypeANY,
|
|
||||||
}
|
|
||||||
recs := s.Records(q)
|
|
||||||
if len(recs) != 4 {
|
|
||||||
t.Fatalf("bad: %v", recs)
|
|
||||||
}
|
|
||||||
if _, ok := recs[0].(*dns.SRV); !ok {
|
|
||||||
t.Fatalf("bad: %v", recs[0])
|
|
||||||
}
|
|
||||||
if _, ok := recs[1].(*dns.A); !ok {
|
|
||||||
t.Fatalf("bad: %v", recs[1])
|
|
||||||
}
|
|
||||||
if _, ok := recs[2].(*dns.AAAA); !ok {
|
|
||||||
t.Fatalf("bad: %v", recs[2])
|
|
||||||
}
|
|
||||||
if _, ok := recs[3].(*dns.TXT); !ok {
|
|
||||||
t.Fatalf("bad: %v", recs[3])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMDNSService_InstanceAddr_SRV(t *testing.T) {
|
|
||||||
s := makeService(t)
|
|
||||||
q := dns.Question{
|
|
||||||
Name: "hostname._http._tcp.local.",
|
|
||||||
Qtype: dns.TypeSRV,
|
|
||||||
}
|
|
||||||
recs := s.Records(q)
|
|
||||||
if len(recs) != 3 {
|
|
||||||
t.Fatalf("bad: %v", recs)
|
|
||||||
}
|
|
||||||
srv, ok := recs[0].(*dns.SRV)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("bad: %v", recs[0])
|
|
||||||
}
|
|
||||||
if _, ok := recs[1].(*dns.A); !ok {
|
|
||||||
t.Fatalf("bad: %v", recs[1])
|
|
||||||
}
|
|
||||||
if _, ok := recs[2].(*dns.AAAA); !ok {
|
|
||||||
t.Fatalf("bad: %v", recs[2])
|
|
||||||
}
|
|
||||||
|
|
||||||
if srv.Port != uint16(s.Port) {
|
|
||||||
t.Fatalf("bad: %v", recs[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMDNSService_InstanceAddr_A(t *testing.T) {
|
|
||||||
s := makeService(t)
|
|
||||||
q := dns.Question{
|
|
||||||
Name: "hostname._http._tcp.local.",
|
|
||||||
Qtype: dns.TypeA,
|
|
||||||
}
|
|
||||||
recs := s.Records(q)
|
|
||||||
if len(recs) != 1 {
|
|
||||||
t.Fatalf("bad: %v", recs)
|
|
||||||
}
|
|
||||||
a, ok := recs[0].(*dns.A)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("bad: %v", recs[0])
|
|
||||||
}
|
|
||||||
if !bytes.Equal(a.A, []byte{192, 168, 0, 42}) {
|
|
||||||
t.Fatalf("bad: %v", recs[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMDNSService_InstanceAddr_AAAA(t *testing.T) {
|
|
||||||
s := makeService(t)
|
|
||||||
q := dns.Question{
|
|
||||||
Name: "hostname._http._tcp.local.",
|
|
||||||
Qtype: dns.TypeAAAA,
|
|
||||||
}
|
|
||||||
recs := s.Records(q)
|
|
||||||
if len(recs) != 1 {
|
|
||||||
t.Fatalf("bad: %v", recs)
|
|
||||||
}
|
|
||||||
a4, ok := recs[0].(*dns.AAAA)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("bad: %v", recs[0])
|
|
||||||
}
|
|
||||||
ip6 := net.ParseIP("2620:0:1000:1900:b0c2:d0b2:c411:18bc")
|
|
||||||
if got := len(ip6); got != net.IPv6len {
|
|
||||||
t.Fatalf("test IP failed to parse (len = %d, want %d)", got, net.IPv6len)
|
|
||||||
}
|
|
||||||
if !bytes.Equal(a4.AAAA, ip6) {
|
|
||||||
t.Fatalf("bad: %v", recs[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMDNSService_InstanceAddr_TXT(t *testing.T) {
|
|
||||||
s := makeService(t)
|
|
||||||
q := dns.Question{
|
|
||||||
Name: "hostname._http._tcp.local.",
|
|
||||||
Qtype: dns.TypeTXT,
|
|
||||||
}
|
|
||||||
recs := s.Records(q)
|
|
||||||
if len(recs) != 1 {
|
|
||||||
t.Fatalf("bad: %v", recs)
|
|
||||||
}
|
|
||||||
txt, ok := recs[0].(*dns.TXT)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("bad: %v", recs[0])
|
|
||||||
}
|
|
||||||
if got, want := txt.Txt, s.TXT; !reflect.DeepEqual(got, want) {
|
|
||||||
t.Fatalf("TXT record mismatch for %v: got %v, want %v", recs[0], got, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMDNSService_HostNameQuery(t *testing.T) {
|
|
||||||
s := makeService(t)
|
|
||||||
for _, test := range []struct {
|
|
||||||
q dns.Question
|
|
||||||
want []dns.RR
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
dns.Question{Name: "testhost.", Qtype: dns.TypeA},
|
|
||||||
[]dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{
|
|
||||||
Name: "testhost.",
|
|
||||||
Rrtype: dns.TypeA,
|
|
||||||
Class: dns.ClassINET,
|
|
||||||
Ttl: 120,
|
|
||||||
},
|
|
||||||
A: net.IP([]byte{192, 168, 0, 42}),
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
dns.Question{Name: "testhost.", Qtype: dns.TypeAAAA},
|
|
||||||
[]dns.RR{&dns.AAAA{
|
|
||||||
Hdr: dns.RR_Header{
|
|
||||||
Name: "testhost.",
|
|
||||||
Rrtype: dns.TypeAAAA,
|
|
||||||
Class: dns.ClassINET,
|
|
||||||
Ttl: 120,
|
|
||||||
},
|
|
||||||
AAAA: net.ParseIP("2620:0:1000:1900:b0c2:d0b2:c411:18bc"),
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
} {
|
|
||||||
if got := s.Records(test.q); !reflect.DeepEqual(got, test.want) {
|
|
||||||
t.Errorf("hostname query failed: s.Records(%v) = %v, want %v", test.q, got, test.want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMDNSService_serviceEnum_PTR(t *testing.T) {
|
|
||||||
s := makeService(t)
|
|
||||||
q := dns.Question{
|
|
||||||
Name: "_services._dns-sd._udp.local.",
|
|
||||||
Qtype: dns.TypePTR,
|
|
||||||
}
|
|
||||||
recs := s.Records(q)
|
|
||||||
if len(recs) != 1 {
|
|
||||||
t.Fatalf("bad: %v", recs)
|
|
||||||
}
|
|
||||||
if ptr, ok := recs[0].(*dns.PTR); !ok {
|
|
||||||
t.Errorf("recs[0] should be PTR record, got: %v, all records: %v", recs[0], recs)
|
|
||||||
} else if got, want := ptr.Ptr, "_http._tcp.local."; got != want {
|
|
||||||
t.Fatalf("bad PTR record %v: got %v, want %v", ptr, got, want)
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user