diff --git a/client.go b/client.go index f4f0af0..6f10663 100644 --- a/client.go +++ b/client.go @@ -9,6 +9,7 @@ import ( "time" "github.com/miekg/dns" + "golang.org/x/net/context" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" ) @@ -39,7 +40,8 @@ func (s *ServiceEntry) complete() bool { type QueryParam struct { Service string // Service to lookup Domain string // Lookup domain, default "local" - Timeout time.Duration // Lookup timeout, default 1 second + Context context.Context // Context + Timeout time.Duration // Lookup timeout, default 1 second. Ignored if Context is provided Interface *net.Interface // Multicast interface to use Entries chan<- *ServiceEntry // Entries Channel WantUnicastResponse bool // Unicast response desired, as per 5.4 in RFC @@ -79,8 +81,15 @@ func Query(params *QueryParam) error { if params.Domain == "" { params.Domain = "local" } - if params.Timeout == 0 { - params.Timeout = time.Second + + if params.Context == nil { + if params.Timeout == 0 { + params.Timeout = time.Second + } + params.Context, _ = context.WithTimeout(context.Background(), params.Timeout) + if err != nil { + return err + } } // Run the query @@ -316,9 +325,6 @@ func (c *client) query(params *QueryParam) error { // Map the in-progress responses inprogress := make(map[string]*ServiceEntry) - // Listen until we reach the timeout - finish := time.After(params.Timeout) - for { select { case resp := <-msgCh: @@ -335,7 +341,7 @@ func (c *client) query(params *QueryParam) error { inp.sent = true select { case params.Entries <- inp: - case <-finish: + case <-params.Context.Done(): return nil } } else { @@ -347,7 +353,7 @@ func (c *client) query(params *QueryParam) error { log.Printf("[ERR] mdns: Failed to query instance %s: %v", inp.Name, err) } } - case <-finish: + case <-params.Context.Done(): return nil } }