Merge pull request #4 from mjgarton/contexts
Allow timeout and cancellation via context.Context
This commit is contained in:
commit
cdf30746f9
22
client.go
22
client.go
@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/context"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
@ -39,7 +40,8 @@ func (s *ServiceEntry) complete() bool {
|
||||
type QueryParam struct {
|
||||
Service string // Service to lookup
|
||||
Domain string // Lookup domain, default "local"
|
||||
Timeout time.Duration // Lookup timeout, default 1 second
|
||||
Context context.Context // Context
|
||||
Timeout time.Duration // Lookup timeout, default 1 second. Ignored if Context is provided
|
||||
Interface *net.Interface // Multicast interface to use
|
||||
Entries chan<- *ServiceEntry // Entries Channel
|
||||
WantUnicastResponse bool // Unicast response desired, as per 5.4 in RFC
|
||||
@ -79,8 +81,15 @@ func Query(params *QueryParam) error {
|
||||
if params.Domain == "" {
|
||||
params.Domain = "local"
|
||||
}
|
||||
if params.Timeout == 0 {
|
||||
params.Timeout = time.Second
|
||||
|
||||
if params.Context == nil {
|
||||
if params.Timeout == 0 {
|
||||
params.Timeout = time.Second
|
||||
}
|
||||
params.Context, _ = context.WithTimeout(context.Background(), params.Timeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Run the query
|
||||
@ -316,9 +325,6 @@ func (c *client) query(params *QueryParam) error {
|
||||
// Map the in-progress responses
|
||||
inprogress := make(map[string]*ServiceEntry)
|
||||
|
||||
// Listen until we reach the timeout
|
||||
finish := time.After(params.Timeout)
|
||||
|
||||
for {
|
||||
select {
|
||||
case resp := <-msgCh:
|
||||
@ -335,7 +341,7 @@ func (c *client) query(params *QueryParam) error {
|
||||
inp.sent = true
|
||||
select {
|
||||
case params.Entries <- inp:
|
||||
case <-finish:
|
||||
case <-params.Context.Done():
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
@ -347,7 +353,7 @@ func (c *client) query(params *QueryParam) error {
|
||||
log.Printf("[ERR] mdns: Failed to query instance %s: %v", inp.Name, err)
|
||||
}
|
||||
}
|
||||
case <-finish:
|
||||
case <-params.Context.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user