Merge pull request #15 from gonzojive/refactor-client

Refactor client.go.
This commit is contained in:
Armon Dadgar 2014-10-16 10:39:12 -07:00
commit 6c44326b32

107
client.go
View File

@ -1,15 +1,16 @@
package mdns package mdns
import ( import (
"code.google.com/p/go.net/ipv4"
"code.google.com/p/go.net/ipv6"
"fmt" "fmt"
"github.com/miekg/dns"
"log" "log"
"net" "net"
"strings" "strings"
"sync" "sync"
"time" "time"
"code.google.com/p/go.net/ipv4"
"code.google.com/p/go.net/ipv6"
"github.com/miekg/dns"
) )
// ServiceEntry is returned after we query for a service // ServiceEntry is returned after we query for a service
@ -39,6 +40,7 @@ type QueryParam struct {
Timeout time.Duration // Lookup timeout, default 1 second Timeout time.Duration // Lookup timeout, default 1 second
Interface *net.Interface // Multicast interface to use Interface *net.Interface // Multicast interface to use
Entries chan<- *ServiceEntry // Entries Channel Entries chan<- *ServiceEntry // Entries Channel
WantUnicastResponse bool // Unicast response desired, as per 5.4 in RFC
} }
// DefaultParams is used to return a default set of QueryParam's // DefaultParams is used to return a default set of QueryParam's
@ -48,6 +50,7 @@ func DefaultParams(service string) *QueryParam {
Domain: "local", Domain: "local",
Timeout: time.Second, Timeout: time.Second,
Entries: make(chan *ServiceEntry), Entries: make(chan *ServiceEntry),
WantUnicastResponse: false, // TODO(reddaly): Change this default.
} }
} }
@ -92,34 +95,53 @@ func Lookup(service string, entries chan<- *ServiceEntry) error {
// Client provides a query interface that can be used to // Client provides a query interface that can be used to
// search for service providers using mDNS // search for service providers using mDNS
type client struct { type client struct {
ipv4List *net.UDPConn ipv4UnicastConn *net.UDPConn
ipv6List *net.UDPConn ipv6UnicastConn *net.UDPConn
ipv4MulticastConn *net.UDPConn
ipv6MulticastConn *net.UDPConn
closed bool closed bool
closedCh chan struct{} closedCh chan struct{} // TODO(reddaly): This doesn't appear to be used.
closeLock sync.Mutex closeLock sync.Mutex
} }
// NewClient creates a new mdns Client that can be used to query // NewClient creates a new mdns Client that can be used to query
// for records // for records
func newClient() (*client, error) { func newClient() (*client, error) {
// TODO(reddaly): At least attempt to bind to the port required in the spec.
// Create a IPv4 listener // Create a IPv4 listener
ipv4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) uconn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil { if err != nil {
log.Printf("[ERR] mdns: Failed to bind to udp4 port: %v", err) log.Printf("[ERR] mdns: Failed to bind to udp4 port: %v", err)
} }
ipv6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) uconn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil { if err != nil {
log.Printf("[ERR] mdns: Failed to bind to udp6 port: %v", err) log.Printf("[ERR] mdns: Failed to bind to udp6 port: %v", err)
} }
if ipv4 == nil && ipv6 == nil { if uconn4 == nil && uconn6 == nil {
return nil, fmt.Errorf("Failed to bind to any udp port!") return nil, fmt.Errorf("failed to bind to any unicast udp port")
}
mconn4, err := net.ListenMulticastUDP("udp4", nil, ipv4Addr)
if err != nil {
log.Printf("[ERR] mdns: Failed to bind to udp4 port: %v", err)
}
mconn6, err := net.ListenMulticastUDP("udp6", nil, ipv6Addr)
if err != nil {
log.Printf("[ERR] mdns: Failed to bind to udp6 port: %v", err)
}
if mconn4 == nil && mconn6 == nil {
return nil, fmt.Errorf("failed to bind to any multicast udp port")
} }
c := &client{ c := &client{
ipv4List: ipv4, ipv4MulticastConn: mconn4,
ipv6List: ipv6, ipv6MulticastConn: mconn6,
ipv4UnicastConn: uconn4,
ipv6UnicastConn: uconn6,
closedCh: make(chan struct{}), closedCh: make(chan struct{}),
} }
return c, nil return c, nil
@ -134,25 +156,42 @@ func (c *client) Close() error {
return nil return nil
} }
c.closed = true c.closed = true
log.Printf("[INFO] mdns: Closing client %v", *c)
close(c.closedCh) close(c.closedCh)
if c.ipv4List != nil { if c.ipv4UnicastConn != nil {
c.ipv4List.Close() c.ipv4UnicastConn.Close()
} }
if c.ipv6List != nil { if c.ipv6UnicastConn != nil {
c.ipv6List.Close() c.ipv6UnicastConn.Close()
} }
if c.ipv4MulticastConn != nil {
c.ipv4MulticastConn.Close()
}
if c.ipv6MulticastConn != nil {
c.ipv6MulticastConn.Close()
}
return nil return nil
} }
// setInterface is used to set the query interface, uses sytem // setInterface is used to set the query interface, uses sytem
// default if not provided // default if not provided
func (c *client) setInterface(iface *net.Interface) error { func (c *client) setInterface(iface *net.Interface) error {
p := ipv4.NewPacketConn(c.ipv4List) p := ipv4.NewPacketConn(c.ipv4UnicastConn)
if err := p.SetMulticastInterface(iface); err != nil { if err := p.SetMulticastInterface(iface); err != nil {
return err return err
} }
p2 := ipv6.NewPacketConn(c.ipv6List) p2 := ipv6.NewPacketConn(c.ipv6UnicastConn)
if err := p2.SetMulticastInterface(iface); err != nil {
return err
}
p = ipv4.NewPacketConn(c.ipv4MulticastConn)
if err := p.SetMulticastInterface(iface); err != nil {
return err
}
p2 = ipv6.NewPacketConn(c.ipv6MulticastConn)
if err := p2.SetMulticastInterface(iface); err != nil { if err := p2.SetMulticastInterface(iface); err != nil {
return err return err
} }
@ -166,15 +205,26 @@ func (c *client) query(params *QueryParam) error {
// Start listening for response packets // Start listening for response packets
msgCh := make(chan *dns.Msg, 32) msgCh := make(chan *dns.Msg, 32)
go c.recv(c.ipv4List, msgCh) go c.recv(c.ipv4UnicastConn, msgCh)
go c.recv(c.ipv6List, msgCh) go c.recv(c.ipv6UnicastConn, msgCh)
go c.recv(c.ipv4MulticastConn, msgCh)
go c.recv(c.ipv6MulticastConn, msgCh)
// Send the query // Send the query
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion(serviceAddr, dns.TypePTR) m.SetQuestion(serviceAddr, dns.TypePTR)
// RFC 6762, section 18.12. Repurposing of Top Bit of qclass in Question
// Section
//
// In the Question Section of a Multicast DNS query, the top bit of the qclass
// field is used to indicate that unicast responses are preferred for this
// particular question. (See Section 5.4.)
if params.WantUnicastResponse {
m.Question[0].Qclass |= 1 << 15
}
m.RecursionDesired = false m.RecursionDesired = false
if err := c.sendQuery(m); err != nil { if err := c.sendQuery(m); err != nil {
return nil return err
} }
// Map the in-progress responses // Map the in-progress responses
@ -187,6 +237,7 @@ func (c *client) query(params *QueryParam) error {
case resp := <-msgCh: case resp := <-msgCh:
var inp *ServiceEntry var inp *ServiceEntry
for _, answer := range resp.Answer { for _, answer := range resp.Answer {
// TODO(reddaly): Check that response corresponds to serviceAddr?
switch rr := answer.(type) { switch rr := answer.(type) {
case *dns.PTR: case *dns.PTR:
// Create new entry for this // Create new entry for this
@ -223,6 +274,10 @@ func (c *client) query(params *QueryParam) error {
} }
} }
if inp == nil {
continue
}
// Check if this entry is complete // Check if this entry is complete
if inp.complete() { if inp.complete() {
if inp.sent { if inp.sent {
@ -246,7 +301,6 @@ func (c *client) query(params *QueryParam) error {
return nil return nil
} }
} }
return nil
} }
// sendQuery is used to multicast a query out // sendQuery is used to multicast a query out
@ -255,11 +309,11 @@ func (c *client) sendQuery(q *dns.Msg) error {
if err != nil { if err != nil {
return err return err
} }
if c.ipv4List != nil { if c.ipv4UnicastConn != nil {
c.ipv4List.WriteTo(buf, ipv4Addr) c.ipv4UnicastConn.WriteToUDP(buf, ipv4Addr)
} }
if c.ipv6List != nil { if c.ipv6UnicastConn != nil {
c.ipv6List.WriteTo(buf, ipv6Addr) c.ipv6UnicastConn.WriteToUDP(buf, ipv6Addr)
} }
return nil return nil
} }
@ -273,6 +327,7 @@ func (c *client) recv(l *net.UDPConn, msgCh chan *dns.Msg) {
for !c.closed { for !c.closed {
n, err := l.Read(buf) n, err := l.Read(buf)
if err != nil { if err != nil {
log.Printf("[ERR] mdns: Failed to read packet: %v", err)
continue continue
} }
msg := new(dns.Msg) msg := new(dns.Msg)