updates #207
| @@ -12,9 +12,9 @@ import ( | |||||||
|  |  | ||||||
| // Resolver is a DNS network resolve | // Resolver is a DNS network resolve | ||||||
| type Resolver struct { | type Resolver struct { | ||||||
| 	sync.RWMutex |  | ||||||
| 	goresolver *net.Resolver | 	goresolver *net.Resolver | ||||||
| 	Address    string | 	Address    string | ||||||
|  | 	mu         sync.RWMutex | ||||||
| } | } | ||||||
|  |  | ||||||
| // Resolve tries to resolve endpoint address | // Resolve tries to resolve endpoint address | ||||||
| @@ -39,12 +39,12 @@ func (r *Resolver) Resolve(name string) ([]*resolver.Record, error) { | |||||||
| 		return []*resolver.Record{rec}, nil | 		return []*resolver.Record{rec}, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	r.RLock() | 	r.mu.RLock() | ||||||
| 	goresolver := r.goresolver | 	goresolver := r.goresolver | ||||||
| 	r.RUnlock() | 	r.mu.RUnlock() | ||||||
|  |  | ||||||
| 	if goresolver == nil { | 	if goresolver == nil { | ||||||
| 		r.Lock() | 		r.mu.Lock() | ||||||
| 		r.goresolver = &net.Resolver{ | 		r.goresolver = &net.Resolver{ | ||||||
| 			Dial: func(ctx context.Context, _ string, _ string) (net.Conn, error) { | 			Dial: func(ctx context.Context, _ string, _ string) (net.Conn, error) { | ||||||
| 				d := net.Dialer{ | 				d := net.Dialer{ | ||||||
| @@ -53,7 +53,7 @@ func (r *Resolver) Resolve(name string) ([]*resolver.Record, error) { | |||||||
| 				return d.DialContext(ctx, "udp", r.Address) | 				return d.DialContext(ctx, "udp", r.Address) | ||||||
| 			}, | 			}, | ||||||
| 		} | 		} | ||||||
| 		r.Unlock() | 		r.mu.Unlock() | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	addrs, err := goresolver.LookupIP(context.TODO(), "ip", host) | 	addrs, err := goresolver.LookupIP(context.TODO(), "ip", host) | ||||||
|   | |||||||
| @@ -3,7 +3,9 @@ package micro | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"net" | ||||||
| 	"sync" | 	"sync" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/KimMachineGun/automemlimit/memlimit" | 	"github.com/KimMachineGun/automemlimit/memlimit" | ||||||
| 	"go.uber.org/automaxprocs/maxprocs" | 	"go.uber.org/automaxprocs/maxprocs" | ||||||
| @@ -17,6 +19,7 @@ import ( | |||||||
| 	"go.unistack.org/micro/v3/server" | 	"go.unistack.org/micro/v3/server" | ||||||
| 	"go.unistack.org/micro/v3/store" | 	"go.unistack.org/micro/v3/store" | ||||||
| 	"go.unistack.org/micro/v3/tracer" | 	"go.unistack.org/micro/v3/tracer" | ||||||
|  | 	utildns "go.unistack.org/micro/v3/util/dns" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func init() { | func init() { | ||||||
| @@ -30,6 +33,8 @@ func init() { | |||||||
| 			), | 			), | ||||||
| 		), | 		), | ||||||
| 	) | 	) | ||||||
|  |  | ||||||
|  | 	net.DefaultResolver = utildns.NewNetResolver(utildns.Timeout(1 * time.Second)) | ||||||
| } | } | ||||||
|  |  | ||||||
| // Service is an interface that wraps the lower level components. | // Service is an interface that wraps the lower level components. | ||||||
|   | |||||||
							
								
								
									
										377
									
								
								util/dns/cache.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										377
									
								
								util/dns/cache.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,377 @@ | |||||||
|  | package dns | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"math" | ||||||
|  | 	"net" | ||||||
|  | 	"sync" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // DialFunc is a [net.Resolver.Dial] function. | ||||||
|  | type DialFunc func(ctx context.Context, network, address string) (net.Conn, error) | ||||||
|  |  | ||||||
|  | // NewNetResolver creates a caching [net.Resolver] that uses parent to resolve names. | ||||||
|  | func NewNetResolver(opts ...Option) *net.Resolver { | ||||||
|  | 	options := Options{Resolver: &net.Resolver{}} | ||||||
|  |  | ||||||
|  | 	for _, o := range opts { | ||||||
|  | 		o(&options) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &net.Resolver{ | ||||||
|  | 		PreferGo:     true, | ||||||
|  | 		StrictErrors: options.Resolver.StrictErrors, | ||||||
|  | 		Dial:         NewNetDialer(options.Resolver.Dial, append(opts, Resolver(options.Resolver))...), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // NewNetDialer adds caching to a [net.Resolver.Dial] function. | ||||||
|  | func NewNetDialer(parent DialFunc, opts ...Option) DialFunc { | ||||||
|  | 	cache := cache{dial: parent, opts: Options{}} | ||||||
|  | 	for _, o := range opts { | ||||||
|  | 		o(&cache.opts) | ||||||
|  | 	} | ||||||
|  | 	if cache.opts.MaxCacheEntries == 0 { | ||||||
|  | 		cache.opts.MaxCacheEntries = DefaultMaxCacheEntries | ||||||
|  | 	} | ||||||
|  | 	return func(ctx context.Context, network, address string) (net.Conn, error) { | ||||||
|  | 		conn := &dnsConn{} | ||||||
|  | 		conn.roundTrip = cachingRoundTrip(&cache, network, address) | ||||||
|  | 		return conn, nil | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | const DefaultMaxCacheEntries = 300 | ||||||
|  |  | ||||||
|  | // A Option customizes the resolver cache. | ||||||
|  | type Option func(*Options) | ||||||
|  |  | ||||||
|  | type Options struct { | ||||||
|  | 	Resolver        *net.Resolver | ||||||
|  | 	MaxCacheEntries int | ||||||
|  | 	MaxCacheTTL     time.Duration | ||||||
|  | 	MinCacheTTL     time.Duration | ||||||
|  | 	NegativeCache   bool | ||||||
|  | 	PreferIPV4      bool | ||||||
|  | 	PreferIPV6      bool | ||||||
|  | 	Timeout         time.Duration | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // MaxCacheEntries sets the maximum number of entries to cache. | ||||||
|  | // If zero, [DefaultMaxCacheEntries] is used; negative means no limit. | ||||||
|  | func MaxCacheEntries(n int) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.MaxCacheEntries = n | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // MaxCacheTTL sets the maximum time-to-live for entries in the cache. | ||||||
|  | func MaxCacheTTL(td time.Duration) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.MaxCacheTTL = td | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // MinCacheTTL sets the minimum time-to-live for entries in the cache. | ||||||
|  | func MinCacheTTL(td time.Duration) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.MinCacheTTL = td | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // NegativeCache sets whether to cache negative responses. | ||||||
|  | func NegativeCache(b bool) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.NegativeCache = b | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Timeout sets upstream *net.Resolver timeout | ||||||
|  | func Timeout(td time.Duration) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.Timeout = td | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Resolver sets upstream *net.Resolver. | ||||||
|  | func Resolver(r *net.Resolver) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.Resolver = r | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // PreferIPV4 resolve ipv4 records. | ||||||
|  | func PreferIPV4(b bool) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.PreferIPV4 = b | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // PreferIPV6 resolve ipv4 records. | ||||||
|  | func PreferIPV6(b bool) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.PreferIPV6 = b | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type cache struct { | ||||||
|  | 	sync.RWMutex | ||||||
|  |  | ||||||
|  | 	dial    DialFunc | ||||||
|  | 	entries map[string]cacheEntry | ||||||
|  |  | ||||||
|  | 	opts Options | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type cacheEntry struct { | ||||||
|  | 	deadline time.Time | ||||||
|  | 	value    string | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *cache) put(req string, res string) { | ||||||
|  | 	// ignore uncacheable/unparseable answers | ||||||
|  | 	if invalid(req, res) { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// ignore errors (if requested) | ||||||
|  | 	if nameError(res) && !c.opts.NegativeCache { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// ignore uncacheable/unparseable answers | ||||||
|  | 	ttl := getTTL(res) | ||||||
|  | 	if ttl <= 0 { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// adjust TTL | ||||||
|  | 	if ttl < c.opts.MinCacheTTL { | ||||||
|  | 		ttl = c.opts.MinCacheTTL | ||||||
|  | 	} | ||||||
|  | 	// maxTTL overrides minTTL | ||||||
|  | 	if ttl > c.opts.MaxCacheTTL && c.opts.MaxCacheTTL != 0 { | ||||||
|  | 		ttl = c.opts.MaxCacheTTL | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	c.Lock() | ||||||
|  | 	defer c.Unlock() | ||||||
|  | 	if c.entries == nil { | ||||||
|  | 		c.entries = make(map[string]cacheEntry) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// do some cache evition | ||||||
|  | 	var tested, evicted int | ||||||
|  | 	for k, e := range c.entries { | ||||||
|  | 		if time.Until(e.deadline) <= 0 { | ||||||
|  | 			// delete expired entry | ||||||
|  | 			delete(c.entries, k) | ||||||
|  | 			evicted++ | ||||||
|  | 		} | ||||||
|  | 		tested++ | ||||||
|  |  | ||||||
|  | 		if tested < 8 { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		if evicted == 0 && c.opts.MaxCacheEntries > 0 && len(c.entries) >= c.opts.MaxCacheEntries { | ||||||
|  | 			// delete at least one entry | ||||||
|  | 			delete(c.entries, k) | ||||||
|  | 		} | ||||||
|  | 		break | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// remove message IDs | ||||||
|  | 	c.entries[req[2:]] = cacheEntry{ | ||||||
|  | 		deadline: time.Now().Add(ttl), | ||||||
|  | 		value:    res[2:], | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *cache) get(req string) (res string) { | ||||||
|  | 	// ignore invalid messages | ||||||
|  | 	if len(req) < 12 { | ||||||
|  | 		return "" | ||||||
|  | 	} | ||||||
|  | 	if req[2] >= 0x7f { | ||||||
|  | 		return "" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	c.RLock() | ||||||
|  | 	defer c.RUnlock() | ||||||
|  |  | ||||||
|  | 	if c.entries == nil { | ||||||
|  | 		return "" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// remove message ID | ||||||
|  | 	entry, ok := c.entries[req[2:]] | ||||||
|  | 	if ok && time.Until(entry.deadline) > 0 { | ||||||
|  | 		// prepend correct ID | ||||||
|  | 		return req[:2] + entry.value | ||||||
|  | 	} | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func invalid(req string, res string) bool { | ||||||
|  | 	if len(req) < 12 || len(res) < 12 { // header size | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if req[0] != res[0] || req[1] != res[1] { // IDs match | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if req[2] >= 0x7f || res[2] < 0x7f { // query, response | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if req[2]&0x7a != 0 || res[2]&0x7a != 0 { // standard query, not truncated | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if res[3]&0xf != 0 && res[3]&0xf != 3 { // no error, or name error | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func nameError(res string) bool { | ||||||
|  | 	return res[3]&0xf == 3 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func getTTL(msg string) time.Duration { | ||||||
|  | 	ttl := math.MaxInt32 | ||||||
|  |  | ||||||
|  | 	qdcount := getUint16(msg[4:]) | ||||||
|  | 	ancount := getUint16(msg[6:]) | ||||||
|  | 	nscount := getUint16(msg[8:]) | ||||||
|  | 	arcount := getUint16(msg[10:]) | ||||||
|  | 	rdcount := ancount + nscount + arcount | ||||||
|  |  | ||||||
|  | 	msg = msg[12:] // skip header | ||||||
|  |  | ||||||
|  | 	// skip questions | ||||||
|  | 	for i := 0; i < qdcount; i++ { | ||||||
|  | 		name := getNameLen(msg) | ||||||
|  | 		if name < 0 || name+4 > len(msg) { | ||||||
|  | 			return -1 | ||||||
|  | 		} | ||||||
|  | 		msg = msg[name+4:] | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// parse records | ||||||
|  | 	for i := 0; i < rdcount; i++ { | ||||||
|  | 		name := getNameLen(msg) | ||||||
|  | 		if name < 0 || name+10 > len(msg) { | ||||||
|  | 			return -1 | ||||||
|  | 		} | ||||||
|  | 		rtyp := getUint16(msg[name+0:]) | ||||||
|  | 		rttl := getUint32(msg[name+4:]) | ||||||
|  | 		rlen := getUint16(msg[name+8:]) | ||||||
|  | 		if name+10+rlen > len(msg) { | ||||||
|  | 			return -1 | ||||||
|  | 		} | ||||||
|  | 		// skip EDNS OPT since it doesn't have a TTL | ||||||
|  | 		if rtyp != 41 && rttl < ttl { | ||||||
|  | 			ttl = rttl | ||||||
|  | 		} | ||||||
|  | 		msg = msg[name+10+rlen:] | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return time.Duration(ttl) * time.Second | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func getNameLen(msg string) int { | ||||||
|  | 	i := 0 | ||||||
|  | 	for i < len(msg) { | ||||||
|  | 		if msg[i] == 0 { | ||||||
|  | 			// end of name | ||||||
|  | 			i += 1 | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  | 		if msg[i] >= 0xc0 { | ||||||
|  | 			// compressed name | ||||||
|  | 			i += 2 | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  | 		if msg[i] >= 0x40 { | ||||||
|  | 			// reserved | ||||||
|  | 			return -1 | ||||||
|  | 		} | ||||||
|  | 		i += int(msg[i] + 1) | ||||||
|  | 	} | ||||||
|  | 	return i | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func getUint16(s string) int { | ||||||
|  | 	return int(s[1]) | int(s[0])<<8 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func getUint32(s string) int { | ||||||
|  | 	return int(s[3]) | int(s[2])<<8 | int(s[1])<<16 | int(s[0])<<24 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func cachingRoundTrip(cache *cache, network, address string) roundTripper { | ||||||
|  | 	return func(ctx context.Context, req string) (res string, err error) { | ||||||
|  | 		// check cache | ||||||
|  | 		if res := cache.get(req); res != "" { | ||||||
|  | 			return res, nil | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		switch { | ||||||
|  | 		case cache.opts.PreferIPV4 && cache.opts.PreferIPV6: | ||||||
|  | 			network = "udp" | ||||||
|  | 		case cache.opts.PreferIPV4: | ||||||
|  | 			network = "udp4" | ||||||
|  | 		case cache.opts.PreferIPV6: | ||||||
|  | 			network = "udp6" | ||||||
|  | 		default: | ||||||
|  | 			network = "udp" | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if cache.opts.Timeout > 0 { | ||||||
|  | 			var cancel func() | ||||||
|  | 			ctx, cancel = context.WithTimeout(ctx, cache.opts.Timeout) | ||||||
|  | 			defer cancel() | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// dial connection | ||||||
|  | 		var conn net.Conn | ||||||
|  | 		if cache.dial != nil { | ||||||
|  | 			conn, err = cache.dial(ctx, network, address) | ||||||
|  | 		} else { | ||||||
|  | 			var d net.Dialer | ||||||
|  | 			conn, err = d.DialContext(ctx, network, address) | ||||||
|  | 		} | ||||||
|  | 		if err != nil { | ||||||
|  | 			return "", err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		ctx, cancel := context.WithCancel(ctx) | ||||||
|  | 		go func() { | ||||||
|  | 			<-ctx.Done() | ||||||
|  | 			conn.Close() | ||||||
|  | 		}() | ||||||
|  | 		defer cancel() | ||||||
|  |  | ||||||
|  | 		if t, ok := ctx.Deadline(); ok { | ||||||
|  | 			err = conn.SetDeadline(t) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return "", err | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// send request | ||||||
|  | 		err = writeMessage(conn, req) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return "", err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// read response | ||||||
|  | 		res, err = readMessage(conn) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return "", err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// cache response | ||||||
|  | 		cache.put(req, res) | ||||||
|  | 		return res, nil | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										16
									
								
								util/dns/cache_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								util/dns/cache_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,16 @@ | |||||||
|  | package dns | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestCache(t *testing.T) { | ||||||
|  | 	net.DefaultResolver = NewNetResolver(PreferIPV4(true)) | ||||||
|  |  | ||||||
|  | 	addrs, err := net.LookupHost("unistack.org") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	t.Logf("addrs %v", addrs) | ||||||
|  | } | ||||||
							
								
								
									
										178
									
								
								util/dns/conn.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										178
									
								
								util/dns/conn.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,178 @@ | |||||||
|  | package dns | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"context" | ||||||
|  | 	"io" | ||||||
|  | 	"net" | ||||||
|  | 	"strings" | ||||||
|  | 	"sync" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type dnsConn struct { | ||||||
|  | 	sync.Mutex | ||||||
|  |  | ||||||
|  | 	ibuf bytes.Buffer | ||||||
|  | 	obuf bytes.Buffer | ||||||
|  |  | ||||||
|  | 	ctx       context.Context | ||||||
|  | 	cancel    context.CancelFunc | ||||||
|  | 	deadline  time.Time | ||||||
|  | 	roundTrip roundTripper | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type roundTripper func(ctx context.Context, req string) (res string, err error) | ||||||
|  |  | ||||||
|  | func (c *dnsConn) Read(b []byte) (n int, err error) { | ||||||
|  | 	imsg, n, err := c.drainBuffers(b) | ||||||
|  | 	if n != 0 || err != nil { | ||||||
|  | 		return n, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	ctx, cancel := c.childContext() | ||||||
|  | 	omsg, err := c.roundTrip(ctx, imsg) | ||||||
|  | 	cancel() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return 0, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return c.fillBuffer(b, omsg) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *dnsConn) Write(b []byte) (n int, err error) { | ||||||
|  | 	c.Lock() | ||||||
|  | 	defer c.Unlock() | ||||||
|  | 	return c.ibuf.Write(b) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *dnsConn) Close() error { | ||||||
|  | 	c.Lock() | ||||||
|  | 	cancel := c.cancel | ||||||
|  | 	c.Unlock() | ||||||
|  |  | ||||||
|  | 	if cancel != nil { | ||||||
|  | 		cancel() | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *dnsConn) LocalAddr() net.Addr { | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *dnsConn) RemoteAddr() net.Addr { | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *dnsConn) SetDeadline(t time.Time) error { | ||||||
|  | 	c.SetReadDeadline(t) | ||||||
|  | 	c.SetWriteDeadline(t) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *dnsConn) SetReadDeadline(t time.Time) error { | ||||||
|  | 	c.Lock() | ||||||
|  | 	defer c.Unlock() | ||||||
|  | 	c.deadline = t | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *dnsConn) SetWriteDeadline(t time.Time) error { | ||||||
|  | 	// writes do not timeout | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *dnsConn) drainBuffers(b []byte) (string, int, error) { | ||||||
|  | 	c.Lock() | ||||||
|  | 	defer c.Unlock() | ||||||
|  |  | ||||||
|  | 	// drain the output buffer | ||||||
|  | 	if c.obuf.Len() > 0 { | ||||||
|  | 		n, err := c.obuf.Read(b) | ||||||
|  | 		return "", n, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// otherwise, get the next message from the input buffer | ||||||
|  | 	sz := c.ibuf.Next(2) | ||||||
|  | 	if len(sz) < 2 { | ||||||
|  | 		return "", 0, io.ErrUnexpectedEOF | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	size := int64(sz[0])<<8 | int64(sz[1]) | ||||||
|  |  | ||||||
|  | 	var str strings.Builder | ||||||
|  | 	_, err := io.CopyN(&str, &c.ibuf, size) | ||||||
|  | 	if err == io.EOF { | ||||||
|  | 		return "", 0, io.ErrUnexpectedEOF | ||||||
|  | 	} | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", 0, err | ||||||
|  | 	} | ||||||
|  | 	return str.String(), 0, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *dnsConn) fillBuffer(b []byte, str string) (int, error) { | ||||||
|  | 	c.Lock() | ||||||
|  | 	defer c.Unlock() | ||||||
|  | 	c.obuf.WriteByte(byte(len(str) >> 8)) | ||||||
|  | 	c.obuf.WriteByte(byte(len(str))) | ||||||
|  | 	c.obuf.WriteString(str) | ||||||
|  | 	return c.obuf.Read(b) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *dnsConn) childContext() (context.Context, context.CancelFunc) { | ||||||
|  | 	c.Lock() | ||||||
|  | 	defer c.Unlock() | ||||||
|  | 	if c.ctx == nil { | ||||||
|  | 		c.ctx, c.cancel = context.WithCancel(context.Background()) | ||||||
|  | 	} | ||||||
|  | 	return context.WithDeadline(c.ctx, c.deadline) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func writeMessage(conn net.Conn, msg string) error { | ||||||
|  | 	var buf []byte | ||||||
|  | 	if _, ok := conn.(net.PacketConn); ok { | ||||||
|  | 		buf = []byte(msg) | ||||||
|  | 	} else { | ||||||
|  | 		buf = make([]byte, len(msg)+2) | ||||||
|  | 		buf[0] = byte(len(msg) >> 8) | ||||||
|  | 		buf[1] = byte(len(msg)) | ||||||
|  | 		copy(buf[2:], msg) | ||||||
|  | 	} | ||||||
|  | 	// SHOULD do a single write on TCP (RFC 7766, section 8). | ||||||
|  | 	// MUST do a single write on UDP. | ||||||
|  | 	_, err := conn.Write(buf) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func readMessage(c net.Conn) (string, error) { | ||||||
|  | 	if _, ok := c.(net.PacketConn); ok { | ||||||
|  | 		// RFC 1035 specifies 512 as the maximum message size for DNS over UDP. | ||||||
|  | 		// RFC 6891 OTOH suggests 4096 as the maximum payload size for EDNS. | ||||||
|  | 		b := make([]byte, 4096) | ||||||
|  | 		n, err := c.Read(b) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return "", err | ||||||
|  | 		} | ||||||
|  | 		return string(b[:n]), nil | ||||||
|  | 	} else { | ||||||
|  | 		var sz [2]byte | ||||||
|  | 		_, err := io.ReadFull(c, sz[:]) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return "", err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		size := int64(sz[0])<<8 | int64(sz[1]) | ||||||
|  |  | ||||||
|  | 		var str strings.Builder | ||||||
|  | 		_, err = io.CopyN(&str, c, size) | ||||||
|  | 		if err == io.EOF { | ||||||
|  | 			return "", io.ErrUnexpectedEOF | ||||||
|  | 		} | ||||||
|  | 		if err != nil { | ||||||
|  | 			return "", err | ||||||
|  | 		} | ||||||
|  | 		return str.String(), nil | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user