diff --git a/client.go b/client.go index 1d2f5b4..e490088 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,8 @@ package mdns import ( + "code.google.com/p/go.net/ipv4" + "code.google.com/p/go.net/ipv6" "fmt" "github.com/miekg/dns" "log" @@ -26,11 +28,30 @@ func (s *ServiceEntry) complete() bool { return s.Addr != nil && s.Port != 0 && s.hasTXT } -// LookupDomain looks up a given service, in a domain, waiting at most +// QueryParam is used to customize how a Lookup is performed +type QueryParam struct { + Service string // Service to lookup + Domain string // Lookup domain, default "local" + Timeout time.Duration // Lookup timeout, default 1 second + Interface *net.Interface // Multicast interface to use + Entries chan<- *ServiceEntry // Entries Channel +} + +// DefaultParams is used to return a default set of QueryParam's +func DefaultParams(service string) *QueryParam { + return &QueryParam{ + Service: service, + Domain: "local", + Timeout: time.Second, + Entries: make(chan *ServiceEntry), + } +} + +// Query looks up a given service, in a domain, waiting at most // for a timeout before finishing the query. The results are streamed // to a channel. Sends will not block, so clients should make sure to // either read or buffer. -func LookupDomain(service, domain string, timeout time.Duration, entries chan<- *ServiceEntry) error { +func Query(params *QueryParam) error { // Create a new client client, err := newClient() if err != nil { @@ -38,17 +59,27 @@ func LookupDomain(service, domain string, timeout time.Duration, entries chan<- } defer client.Close() - // Create the query name - serviceAddr := fmt.Sprintf("%s.%s.", trimDot(service), trimDot(domain)) + // Set the multicast interface + if params.Interface != nil { + if err := client.setInterface(params.Interface); err != nil { + return err + } + } + + // Ensure a timeout + if params.Timeout == 0 { + params.Timeout = time.Second + } // Run the query - return client.query(serviceAddr, timeout, entries) + return client.query(params) } -// Lookup is the same as LookupDomain, however it only searches in the "local" -// domain, and uses a one second lookup timeout. +// Lookup is the same as Query, however it uses all the default parameters func Lookup(service string, entries chan<- *ServiceEntry) error { - return LookupDomain(service, "local", time.Second, entries) + params := DefaultParams(service) + params.Entries = entries + return Query(params) } // Client provides a query interface that can be used to @@ -107,8 +138,25 @@ func (c *client) Close() error { return nil } +// setInterface is used to set the query interface, uses sytem +// default if not provided +func (c *client) setInterface(iface *net.Interface) error { + p := ipv4.NewPacketConn(c.ipv4List) + if err := p.SetMulticastInterface(iface); err != nil { + return err + } + p2 := ipv6.NewPacketConn(c.ipv6List) + if err := p2.SetMulticastInterface(iface); err != nil { + return err + } + return nil +} + // query is used to perform a lookup and stream results -func (c *client) query(service string, timeout time.Duration, entries chan<- *ServiceEntry) error { +func (c *client) query(params *QueryParam) error { + // Create the service name + serviceAddr := fmt.Sprintf("%s.%s.", trimDot(params.Service), trimDot(params.Domain)) + // Start listening for response packets msgCh := make(chan *dns.Msg, 32) go c.recv(c.ipv4List, msgCh) @@ -116,7 +164,7 @@ func (c *client) query(service string, timeout time.Duration, entries chan<- *Se // Send the query m := new(dns.Msg) - m.SetQuestion(service, dns.TypeANY) + m.SetQuestion(serviceAddr, dns.TypeANY) if err := c.sendQuery(m); err != nil { return nil } @@ -125,7 +173,7 @@ func (c *client) query(service string, timeout time.Duration, entries chan<- *Se inprogress := make(map[string]*ServiceEntry) // Listen until we reach the timeout - finish := time.After(timeout) + finish := time.After(params.Timeout) for { select { case resp := <-msgCh: @@ -163,7 +211,7 @@ func (c *client) query(service string, timeout time.Duration, entries chan<- *Se if inp.complete() && !inp.sent { inp.sent = true select { - case entries <- inp: + case params.Entries <- inp: default: } } else { diff --git a/server_test.go b/server_test.go index 6af0d83..57b21cf 100644 --- a/server_test.go +++ b/server_test.go @@ -8,7 +8,7 @@ import ( func TestServer_StartStop(t *testing.T) { s := makeService(t) - serv, err := NewServer(&Config{s}) + serv, err := NewServer(&Config{Zone: s}) if err != nil { t.Fatalf("err: %v", err) } @@ -19,7 +19,7 @@ func TestServer_Lookup(t *testing.T) { s := makeService(t) s.Service = "_foobar._tcp" s.Init() - serv, err := NewServer(&Config{s}) + serv, err := NewServer(&Config{Zone: s}) if err != nil { t.Fatalf("err: %v", err) } @@ -49,7 +49,13 @@ func TestServer_Lookup(t *testing.T) { } }() - err = LookupDomain("_foobar._tcp", "local", 50*time.Millisecond, entries) + params := &QueryParam{ + Service: "_foobar._tcp", + Domain: "local", + Timeout: 50 * time.Millisecond, + Entries: entries, + } + err = Query(params) if err != nil { t.Fatalf("err: %v", err) }