Merge pull request #4 from mjgarton/contexts

Allow timeout and cancellation via context.Context
This commit is contained in:
Asim Aslam 2016-09-29 17:56:50 +01:00 committed by GitHub
commit cdf30746f9

View File

@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/net/context"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
) )
@ -39,7 +40,8 @@ func (s *ServiceEntry) complete() bool {
type QueryParam struct { type QueryParam struct {
Service string // Service to lookup Service string // Service to lookup
Domain string // Lookup domain, default "local" Domain string // Lookup domain, default "local"
Timeout time.Duration // Lookup timeout, default 1 second Context context.Context // Context
Timeout time.Duration // Lookup timeout, default 1 second. Ignored if Context is provided
Interface *net.Interface // Multicast interface to use Interface *net.Interface // Multicast interface to use
Entries chan<- *ServiceEntry // Entries Channel Entries chan<- *ServiceEntry // Entries Channel
WantUnicastResponse bool // Unicast response desired, as per 5.4 in RFC WantUnicastResponse bool // Unicast response desired, as per 5.4 in RFC
@ -79,8 +81,15 @@ func Query(params *QueryParam) error {
if params.Domain == "" { if params.Domain == "" {
params.Domain = "local" params.Domain = "local"
} }
if params.Timeout == 0 {
params.Timeout = time.Second if params.Context == nil {
if params.Timeout == 0 {
params.Timeout = time.Second
}
params.Context, _ = context.WithTimeout(context.Background(), params.Timeout)
if err != nil {
return err
}
} }
// Run the query // Run the query
@ -316,9 +325,6 @@ func (c *client) query(params *QueryParam) error {
// Map the in-progress responses // Map the in-progress responses
inprogress := make(map[string]*ServiceEntry) inprogress := make(map[string]*ServiceEntry)
// Listen until we reach the timeout
finish := time.After(params.Timeout)
for { for {
select { select {
case resp := <-msgCh: case resp := <-msgCh:
@ -335,7 +341,7 @@ func (c *client) query(params *QueryParam) error {
inp.sent = true inp.sent = true
select { select {
case params.Entries <- inp: case params.Entries <- inp:
case <-finish: case <-params.Context.Done():
return nil return nil
} }
} else { } else {
@ -347,7 +353,7 @@ func (c *client) query(params *QueryParam) error {
log.Printf("[ERR] mdns: Failed to query instance %s: %v", inp.Name, err) log.Printf("[ERR] mdns: Failed to query instance %s: %v", inp.Name, err)
} }
} }
case <-finish: case <-params.Context.Done():
return nil return nil
} }
} }