Merge pull request #15 from gonzojive/refactor-client

Refactor client.go.
This commit is contained in:
Armon Dadgar 2014-10-16 10:39:12 -07:00
commit 6c44326b32

107
client.go
View File

@ -1,15 +1,16 @@
package mdns
import (
"code.google.com/p/go.net/ipv4"
"code.google.com/p/go.net/ipv6"
"fmt"
"github.com/miekg/dns"
"log"
"net"
"strings"
"sync"
"time"
"code.google.com/p/go.net/ipv4"
"code.google.com/p/go.net/ipv6"
"github.com/miekg/dns"
)
// ServiceEntry is returned after we query for a service
@ -39,6 +40,7 @@ type QueryParam struct {
Timeout time.Duration // Lookup timeout, default 1 second
Interface *net.Interface // Multicast interface to use
Entries chan<- *ServiceEntry // Entries Channel
WantUnicastResponse bool // Unicast response desired, as per 5.4 in RFC
}
// DefaultParams is used to return a default set of QueryParam's
@ -48,6 +50,7 @@ func DefaultParams(service string) *QueryParam {
Domain: "local",
Timeout: time.Second,
Entries: make(chan *ServiceEntry),
WantUnicastResponse: false, // TODO(reddaly): Change this default.
}
}
@ -92,34 +95,53 @@ func Lookup(service string, entries chan<- *ServiceEntry) error {
// Client provides a query interface that can be used to
// search for service providers using mDNS
type client struct {
ipv4List *net.UDPConn
ipv6List *net.UDPConn
ipv4UnicastConn *net.UDPConn
ipv6UnicastConn *net.UDPConn
ipv4MulticastConn *net.UDPConn
ipv6MulticastConn *net.UDPConn
closed bool
closedCh chan struct{}
closedCh chan struct{} // TODO(reddaly): This doesn't appear to be used.
closeLock sync.Mutex
}
// NewClient creates a new mdns Client that can be used to query
// for records
func newClient() (*client, error) {
// TODO(reddaly): At least attempt to bind to the port required in the spec.
// Create a IPv4 listener
ipv4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
uconn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
log.Printf("[ERR] mdns: Failed to bind to udp4 port: %v", err)
}
ipv6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
uconn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
log.Printf("[ERR] mdns: Failed to bind to udp6 port: %v", err)
}
if ipv4 == nil && ipv6 == nil {
return nil, fmt.Errorf("Failed to bind to any udp port!")
if uconn4 == nil && uconn6 == nil {
return nil, fmt.Errorf("failed to bind to any unicast udp port")
}
mconn4, err := net.ListenMulticastUDP("udp4", nil, ipv4Addr)
if err != nil {
log.Printf("[ERR] mdns: Failed to bind to udp4 port: %v", err)
}
mconn6, err := net.ListenMulticastUDP("udp6", nil, ipv6Addr)
if err != nil {
log.Printf("[ERR] mdns: Failed to bind to udp6 port: %v", err)
}
if mconn4 == nil && mconn6 == nil {
return nil, fmt.Errorf("failed to bind to any multicast udp port")
}
c := &client{
ipv4List: ipv4,
ipv6List: ipv6,
ipv4MulticastConn: mconn4,
ipv6MulticastConn: mconn6,
ipv4UnicastConn: uconn4,
ipv6UnicastConn: uconn6,
closedCh: make(chan struct{}),
}
return c, nil
@ -134,25 +156,42 @@ func (c *client) Close() error {
return nil
}
c.closed = true
log.Printf("[INFO] mdns: Closing client %v", *c)
close(c.closedCh)
if c.ipv4List != nil {
c.ipv4List.Close()
if c.ipv4UnicastConn != nil {
c.ipv4UnicastConn.Close()
}
if c.ipv6List != nil {
c.ipv6List.Close()
if c.ipv6UnicastConn != nil {
c.ipv6UnicastConn.Close()
}
if c.ipv4MulticastConn != nil {
c.ipv4MulticastConn.Close()
}
if c.ipv6MulticastConn != nil {
c.ipv6MulticastConn.Close()
}
return nil
}
// setInterface is used to set the query interface, uses sytem
// default if not provided
func (c *client) setInterface(iface *net.Interface) error {
p := ipv4.NewPacketConn(c.ipv4List)
p := ipv4.NewPacketConn(c.ipv4UnicastConn)
if err := p.SetMulticastInterface(iface); err != nil {
return err
}
p2 := ipv6.NewPacketConn(c.ipv6List)
p2 := ipv6.NewPacketConn(c.ipv6UnicastConn)
if err := p2.SetMulticastInterface(iface); err != nil {
return err
}
p = ipv4.NewPacketConn(c.ipv4MulticastConn)
if err := p.SetMulticastInterface(iface); err != nil {
return err
}
p2 = ipv6.NewPacketConn(c.ipv6MulticastConn)
if err := p2.SetMulticastInterface(iface); err != nil {
return err
}
@ -166,15 +205,26 @@ func (c *client) query(params *QueryParam) error {
// Start listening for response packets
msgCh := make(chan *dns.Msg, 32)
go c.recv(c.ipv4List, msgCh)
go c.recv(c.ipv6List, msgCh)
go c.recv(c.ipv4UnicastConn, msgCh)
go c.recv(c.ipv6UnicastConn, msgCh)
go c.recv(c.ipv4MulticastConn, msgCh)
go c.recv(c.ipv6MulticastConn, msgCh)
// Send the query
m := new(dns.Msg)
m.SetQuestion(serviceAddr, dns.TypePTR)
// RFC 6762, section 18.12. Repurposing of Top Bit of qclass in Question
// Section
//
// In the Question Section of a Multicast DNS query, the top bit of the qclass
// field is used to indicate that unicast responses are preferred for this
// particular question. (See Section 5.4.)
if params.WantUnicastResponse {
m.Question[0].Qclass |= 1 << 15
}
m.RecursionDesired = false
if err := c.sendQuery(m); err != nil {
return nil
return err
}
// Map the in-progress responses
@ -187,6 +237,7 @@ func (c *client) query(params *QueryParam) error {
case resp := <-msgCh:
var inp *ServiceEntry
for _, answer := range resp.Answer {
// TODO(reddaly): Check that response corresponds to serviceAddr?
switch rr := answer.(type) {
case *dns.PTR:
// Create new entry for this
@ -223,6 +274,10 @@ func (c *client) query(params *QueryParam) error {
}
}
if inp == nil {
continue
}
// Check if this entry is complete
if inp.complete() {
if inp.sent {
@ -246,7 +301,6 @@ func (c *client) query(params *QueryParam) error {
return nil
}
}
return nil
}
// sendQuery is used to multicast a query out
@ -255,11 +309,11 @@ func (c *client) sendQuery(q *dns.Msg) error {
if err != nil {
return err
}
if c.ipv4List != nil {
c.ipv4List.WriteTo(buf, ipv4Addr)
if c.ipv4UnicastConn != nil {
c.ipv4UnicastConn.WriteToUDP(buf, ipv4Addr)
}
if c.ipv6List != nil {
c.ipv6List.WriteTo(buf, ipv6Addr)
if c.ipv6UnicastConn != nil {
c.ipv6UnicastConn.WriteToUDP(buf, ipv6Addr)
}
return nil
}
@ -273,6 +327,7 @@ func (c *client) recv(l *net.UDPConn, msgCh chan *dns.Msg) {
for !c.closed {
n, err := l.Read(buf)
if err != nil {
log.Printf("[ERR] mdns: Failed to read packet: %v", err)
continue
}
msg := new(dns.Msg)