Allow timeout and cancellation via context.Context

Introduce context.Context to enable cancellation in addition to the
existing timeout functionality.  To retain compatability, Timeout is
still available but will only be used if Context is not set.

We use x/net/context.Context for the moment instead of context.Context
in order to avoid creating a requirement on Go 1.7
This commit is contained in:
Martin Garton 2016-09-29 17:46:10 +01:00
parent 4725bf05ee
commit 8cfcc9a3d9

View File

@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/net/context"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
) )
@ -39,7 +40,8 @@ func (s *ServiceEntry) complete() bool {
type QueryParam struct { type QueryParam struct {
Service string // Service to lookup Service string // Service to lookup
Domain string // Lookup domain, default "local" Domain string // Lookup domain, default "local"
Timeout time.Duration // Lookup timeout, default 1 second Context context.Context // Context
Timeout time.Duration // Lookup timeout, default 1 second. Ignored if Context is provided
Interface *net.Interface // Multicast interface to use Interface *net.Interface // Multicast interface to use
Entries chan<- *ServiceEntry // Entries Channel Entries chan<- *ServiceEntry // Entries Channel
WantUnicastResponse bool // Unicast response desired, as per 5.4 in RFC WantUnicastResponse bool // Unicast response desired, as per 5.4 in RFC
@ -79,8 +81,15 @@ func Query(params *QueryParam) error {
if params.Domain == "" { if params.Domain == "" {
params.Domain = "local" params.Domain = "local"
} }
if params.Timeout == 0 {
params.Timeout = time.Second if params.Context == nil {
if params.Timeout == 0 {
params.Timeout = time.Second
}
params.Context, _ = context.WithTimeout(context.Background(), params.Timeout)
if err != nil {
return err
}
} }
// Run the query // Run the query
@ -316,9 +325,6 @@ func (c *client) query(params *QueryParam) error {
// Map the in-progress responses // Map the in-progress responses
inprogress := make(map[string]*ServiceEntry) inprogress := make(map[string]*ServiceEntry)
// Listen until we reach the timeout
finish := time.After(params.Timeout)
for { for {
select { select {
case resp := <-msgCh: case resp := <-msgCh:
@ -335,7 +341,7 @@ func (c *client) query(params *QueryParam) error {
inp.sent = true inp.sent = true
select { select {
case params.Entries <- inp: case params.Entries <- inp:
case <-finish: case <-params.Context.Done():
return nil return nil
} }
} else { } else {
@ -347,7 +353,7 @@ func (c *client) query(params *QueryParam) error {
log.Printf("[ERR] mdns: Failed to query instance %s: %v", inp.Name, err) log.Printf("[ERR] mdns: Failed to query instance %s: %v", inp.Name, err)
} }
} }
case <-finish: case <-params.Context.Done():
return nil return nil
} }
} }