Adding ability to set multicast query interface

This commit is contained in:
Armon Dadgar 2014-02-25 14:16:06 -08:00
parent dab97f2090
commit 9215784091
2 changed files with 69 additions and 15 deletions

View File

@ -1,6 +1,8 @@
package mdns package mdns
import ( import (
"code.google.com/p/go.net/ipv4"
"code.google.com/p/go.net/ipv6"
"fmt" "fmt"
"github.com/miekg/dns" "github.com/miekg/dns"
"log" "log"
@ -26,11 +28,30 @@ func (s *ServiceEntry) complete() bool {
return s.Addr != nil && s.Port != 0 && s.hasTXT return s.Addr != nil && s.Port != 0 && s.hasTXT
} }
// LookupDomain looks up a given service, in a domain, waiting at most // QueryParam is used to customize how a Lookup is performed
type QueryParam struct {
Service string // Service to lookup
Domain string // Lookup domain, default "local"
Timeout time.Duration // Lookup timeout, default 1 second
Interface *net.Interface // Multicast interface to use
Entries chan<- *ServiceEntry // Entries Channel
}
// DefaultParams is used to return a default set of QueryParam's
func DefaultParams(service string) *QueryParam {
return &QueryParam{
Service: service,
Domain: "local",
Timeout: time.Second,
Entries: make(chan *ServiceEntry),
}
}
// Query looks up a given service, in a domain, waiting at most
// for a timeout before finishing the query. The results are streamed // for a timeout before finishing the query. The results are streamed
// to a channel. Sends will not block, so clients should make sure to // to a channel. Sends will not block, so clients should make sure to
// either read or buffer. // either read or buffer.
func LookupDomain(service, domain string, timeout time.Duration, entries chan<- *ServiceEntry) error { func Query(params *QueryParam) error {
// Create a new client // Create a new client
client, err := newClient() client, err := newClient()
if err != nil { if err != nil {
@ -38,17 +59,27 @@ func LookupDomain(service, domain string, timeout time.Duration, entries chan<-
} }
defer client.Close() defer client.Close()
// Create the query name // Set the multicast interface
serviceAddr := fmt.Sprintf("%s.%s.", trimDot(service), trimDot(domain)) if params.Interface != nil {
if err := client.setInterface(params.Interface); err != nil {
// Run the query return err
return client.query(serviceAddr, timeout, entries) }
} }
// Lookup is the same as LookupDomain, however it only searches in the "local" // Ensure a timeout
// domain, and uses a one second lookup timeout. if params.Timeout == 0 {
params.Timeout = time.Second
}
// Run the query
return client.query(params)
}
// Lookup is the same as Query, however it uses all the default parameters
func Lookup(service string, entries chan<- *ServiceEntry) error { func Lookup(service string, entries chan<- *ServiceEntry) error {
return LookupDomain(service, "local", time.Second, entries) params := DefaultParams(service)
params.Entries = entries
return Query(params)
} }
// Client provides a query interface that can be used to // Client provides a query interface that can be used to
@ -107,8 +138,25 @@ func (c *client) Close() error {
return nil return nil
} }
// setInterface is used to set the query interface, uses sytem
// default if not provided
func (c *client) setInterface(iface *net.Interface) error {
p := ipv4.NewPacketConn(c.ipv4List)
if err := p.SetMulticastInterface(iface); err != nil {
return err
}
p2 := ipv6.NewPacketConn(c.ipv6List)
if err := p2.SetMulticastInterface(iface); err != nil {
return err
}
return nil
}
// query is used to perform a lookup and stream results // query is used to perform a lookup and stream results
func (c *client) query(service string, timeout time.Duration, entries chan<- *ServiceEntry) error { func (c *client) query(params *QueryParam) error {
// Create the service name
serviceAddr := fmt.Sprintf("%s.%s.", trimDot(params.Service), trimDot(params.Domain))
// Start listening for response packets // Start listening for response packets
msgCh := make(chan *dns.Msg, 32) msgCh := make(chan *dns.Msg, 32)
go c.recv(c.ipv4List, msgCh) go c.recv(c.ipv4List, msgCh)
@ -116,7 +164,7 @@ func (c *client) query(service string, timeout time.Duration, entries chan<- *Se
// Send the query // Send the query
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion(service, dns.TypeANY) m.SetQuestion(serviceAddr, dns.TypeANY)
if err := c.sendQuery(m); err != nil { if err := c.sendQuery(m); err != nil {
return nil return nil
} }
@ -125,7 +173,7 @@ func (c *client) query(service string, timeout time.Duration, entries chan<- *Se
inprogress := make(map[string]*ServiceEntry) inprogress := make(map[string]*ServiceEntry)
// Listen until we reach the timeout // Listen until we reach the timeout
finish := time.After(timeout) finish := time.After(params.Timeout)
for { for {
select { select {
case resp := <-msgCh: case resp := <-msgCh:
@ -163,7 +211,7 @@ func (c *client) query(service string, timeout time.Duration, entries chan<- *Se
if inp.complete() && !inp.sent { if inp.complete() && !inp.sent {
inp.sent = true inp.sent = true
select { select {
case entries <- inp: case params.Entries <- inp:
default: default:
} }
} else { } else {

View File

@ -8,7 +8,7 @@ import (
func TestServer_StartStop(t *testing.T) { func TestServer_StartStop(t *testing.T) {
s := makeService(t) s := makeService(t)
serv, err := NewServer(&Config{s}) serv, err := NewServer(&Config{Zone: s})
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -19,7 +19,7 @@ func TestServer_Lookup(t *testing.T) {
s := makeService(t) s := makeService(t)
s.Service = "_foobar._tcp" s.Service = "_foobar._tcp"
s.Init() s.Init()
serv, err := NewServer(&Config{s}) serv, err := NewServer(&Config{Zone: s})
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -49,7 +49,13 @@ func TestServer_Lookup(t *testing.T) {
} }
}() }()
err = LookupDomain("_foobar._tcp", "local", 50*time.Millisecond, entries) params := &QueryParam{
Service: "_foobar._tcp",
Domain: "local",
Timeout: 50 * time.Millisecond,
Entries: entries,
}
err = Query(params)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }