From 8cfcc9a3d976f965d4fbfb2cdb628279a321481c Mon Sep 17 00:00:00 2001 From: Martin Garton Date: Thu, 29 Sep 2016 17:46:10 +0100 Subject: [PATCH] Allow timeout and cancellation via context.Context Introduce context.Context to enable cancellation in addition to the existing timeout functionality. To retain compatability, Timeout is still available but will only be used if Context is not set. We use x/net/context.Context for the moment instead of context.Context in order to avoid creating a requirement on Go 1.7 --- client.go | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/client.go b/client.go index f4f0af0..6f10663 100644 --- a/client.go +++ b/client.go @@ -9,6 +9,7 @@ import ( "time" "github.com/miekg/dns" + "golang.org/x/net/context" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" ) @@ -39,7 +40,8 @@ func (s *ServiceEntry) complete() bool { type QueryParam struct { Service string // Service to lookup Domain string // Lookup domain, default "local" - Timeout time.Duration // Lookup timeout, default 1 second + Context context.Context // Context + Timeout time.Duration // Lookup timeout, default 1 second. Ignored if Context is provided Interface *net.Interface // Multicast interface to use Entries chan<- *ServiceEntry // Entries Channel WantUnicastResponse bool // Unicast response desired, as per 5.4 in RFC @@ -79,8 +81,15 @@ func Query(params *QueryParam) error { if params.Domain == "" { params.Domain = "local" } - if params.Timeout == 0 { - params.Timeout = time.Second + + if params.Context == nil { + if params.Timeout == 0 { + params.Timeout = time.Second + } + params.Context, _ = context.WithTimeout(context.Background(), params.Timeout) + if err != nil { + return err + } } // Run the query @@ -316,9 +325,6 @@ func (c *client) query(params *QueryParam) error { // Map the in-progress responses inprogress := make(map[string]*ServiceEntry) - // Listen until we reach the timeout - finish := time.After(params.Timeout) - for { select { case resp := <-msgCh: @@ -335,7 +341,7 @@ func (c *client) query(params *QueryParam) error { inp.sent = true select { case params.Entries <- inp: - case <-finish: + case <-params.Context.Done(): return nil } } else { @@ -347,7 +353,7 @@ func (c *client) query(params *QueryParam) error { log.Printf("[ERR] mdns: Failed to query instance %s: %v", inp.Name, err) } } - case <-finish: + case <-params.Context.Done(): return nil } }