replace default go resolver with caching resolver
Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>
This commit is contained in:
parent
36b7b9f5fb
commit
bf4143cde5
@ -12,9 +12,9 @@ import (
|
||||
|
||||
// Resolver is a DNS network resolve
|
||||
type Resolver struct {
|
||||
sync.RWMutex
|
||||
goresolver *net.Resolver
|
||||
Address string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Resolve tries to resolve endpoint address
|
||||
@ -39,12 +39,12 @@ func (r *Resolver) Resolve(name string) ([]*resolver.Record, error) {
|
||||
return []*resolver.Record{rec}, nil
|
||||
}
|
||||
|
||||
r.RLock()
|
||||
r.mu.RLock()
|
||||
goresolver := r.goresolver
|
||||
r.RUnlock()
|
||||
r.mu.RUnlock()
|
||||
|
||||
if goresolver == nil {
|
||||
r.Lock()
|
||||
r.mu.Lock()
|
||||
r.goresolver = &net.Resolver{
|
||||
Dial: func(ctx context.Context, _ string, _ string) (net.Conn, error) {
|
||||
d := net.Dialer{
|
||||
@ -53,7 +53,7 @@ func (r *Resolver) Resolve(name string) ([]*resolver.Record, error) {
|
||||
return d.DialContext(ctx, "udp", r.Address)
|
||||
},
|
||||
}
|
||||
r.Unlock()
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
addrs, err := goresolver.LookupIP(context.TODO(), "ip", host)
|
||||
|
@ -3,7 +3,9 @@ package micro
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/KimMachineGun/automemlimit/memlimit"
|
||||
"go.uber.org/automaxprocs/maxprocs"
|
||||
@ -17,6 +19,7 @@ import (
|
||||
"go.unistack.org/micro/v3/server"
|
||||
"go.unistack.org/micro/v3/store"
|
||||
"go.unistack.org/micro/v3/tracer"
|
||||
utildns "go.unistack.org/micro/v3/util/dns"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -30,6 +33,8 @@ func init() {
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
net.DefaultResolver = utildns.NewNetResolver(utildns.Timeout(1 * time.Second))
|
||||
}
|
||||
|
||||
// Service is an interface that wraps the lower level components.
|
||||
|
377
util/dns/cache.go
Normal file
377
util/dns/cache.go
Normal file
@ -0,0 +1,377 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DialFunc is a [net.Resolver.Dial] function.
|
||||
type DialFunc func(ctx context.Context, network, address string) (net.Conn, error)
|
||||
|
||||
// NewNetResolver creates a caching [net.Resolver] that uses parent to resolve names.
|
||||
func NewNetResolver(opts ...Option) *net.Resolver {
|
||||
options := Options{Resolver: &net.Resolver{}}
|
||||
|
||||
for _, o := range opts {
|
||||
o(&options)
|
||||
}
|
||||
|
||||
return &net.Resolver{
|
||||
PreferGo: true,
|
||||
StrictErrors: options.Resolver.StrictErrors,
|
||||
Dial: NewNetDialer(options.Resolver.Dial, append(opts, Resolver(options.Resolver))...),
|
||||
}
|
||||
}
|
||||
|
||||
// NewNetDialer adds caching to a [net.Resolver.Dial] function.
|
||||
func NewNetDialer(parent DialFunc, opts ...Option) DialFunc {
|
||||
cache := cache{dial: parent, opts: Options{}}
|
||||
for _, o := range opts {
|
||||
o(&cache.opts)
|
||||
}
|
||||
if cache.opts.MaxCacheEntries == 0 {
|
||||
cache.opts.MaxCacheEntries = DefaultMaxCacheEntries
|
||||
}
|
||||
return func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
conn := &dnsConn{}
|
||||
conn.roundTrip = cachingRoundTrip(&cache, network, address)
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
const DefaultMaxCacheEntries = 300
|
||||
|
||||
// A Option customizes the resolver cache.
|
||||
type Option func(*Options)
|
||||
|
||||
type Options struct {
|
||||
Resolver *net.Resolver
|
||||
MaxCacheEntries int
|
||||
MaxCacheTTL time.Duration
|
||||
MinCacheTTL time.Duration
|
||||
NegativeCache bool
|
||||
PreferIPV4 bool
|
||||
PreferIPV6 bool
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// MaxCacheEntries sets the maximum number of entries to cache.
|
||||
// If zero, [DefaultMaxCacheEntries] is used; negative means no limit.
|
||||
func MaxCacheEntries(n int) Option {
|
||||
return func(o *Options) {
|
||||
o.MaxCacheEntries = n
|
||||
}
|
||||
}
|
||||
|
||||
// MaxCacheTTL sets the maximum time-to-live for entries in the cache.
|
||||
func MaxCacheTTL(td time.Duration) Option {
|
||||
return func(o *Options) {
|
||||
o.MaxCacheTTL = td
|
||||
}
|
||||
}
|
||||
|
||||
// MinCacheTTL sets the minimum time-to-live for entries in the cache.
|
||||
func MinCacheTTL(td time.Duration) Option {
|
||||
return func(o *Options) {
|
||||
o.MinCacheTTL = td
|
||||
}
|
||||
}
|
||||
|
||||
// NegativeCache sets whether to cache negative responses.
|
||||
func NegativeCache(b bool) Option {
|
||||
return func(o *Options) {
|
||||
o.NegativeCache = b
|
||||
}
|
||||
}
|
||||
|
||||
// Timeout sets upstream *net.Resolver timeout
|
||||
func Timeout(td time.Duration) Option {
|
||||
return func(o *Options) {
|
||||
o.Timeout = td
|
||||
}
|
||||
}
|
||||
|
||||
// Resolver sets upstream *net.Resolver.
|
||||
func Resolver(r *net.Resolver) Option {
|
||||
return func(o *Options) {
|
||||
o.Resolver = r
|
||||
}
|
||||
}
|
||||
|
||||
// PreferIPV4 resolve ipv4 records.
|
||||
func PreferIPV4(b bool) Option {
|
||||
return func(o *Options) {
|
||||
o.PreferIPV4 = b
|
||||
}
|
||||
}
|
||||
|
||||
// PreferIPV6 resolve ipv4 records.
|
||||
func PreferIPV6(b bool) Option {
|
||||
return func(o *Options) {
|
||||
o.PreferIPV6 = b
|
||||
}
|
||||
}
|
||||
|
||||
type cache struct {
|
||||
sync.RWMutex
|
||||
|
||||
dial DialFunc
|
||||
entries map[string]cacheEntry
|
||||
|
||||
opts Options
|
||||
}
|
||||
|
||||
type cacheEntry struct {
|
||||
deadline time.Time
|
||||
value string
|
||||
}
|
||||
|
||||
func (c *cache) put(req string, res string) {
|
||||
// ignore uncacheable/unparseable answers
|
||||
if invalid(req, res) {
|
||||
return
|
||||
}
|
||||
|
||||
// ignore errors (if requested)
|
||||
if nameError(res) && !c.opts.NegativeCache {
|
||||
return
|
||||
}
|
||||
|
||||
// ignore uncacheable/unparseable answers
|
||||
ttl := getTTL(res)
|
||||
if ttl <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// adjust TTL
|
||||
if ttl < c.opts.MinCacheTTL {
|
||||
ttl = c.opts.MinCacheTTL
|
||||
}
|
||||
// maxTTL overrides minTTL
|
||||
if ttl > c.opts.MaxCacheTTL && c.opts.MaxCacheTTL != 0 {
|
||||
ttl = c.opts.MaxCacheTTL
|
||||
}
|
||||
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
if c.entries == nil {
|
||||
c.entries = make(map[string]cacheEntry)
|
||||
}
|
||||
|
||||
// do some cache evition
|
||||
var tested, evicted int
|
||||
for k, e := range c.entries {
|
||||
if time.Until(e.deadline) <= 0 {
|
||||
// delete expired entry
|
||||
delete(c.entries, k)
|
||||
evicted++
|
||||
}
|
||||
tested++
|
||||
|
||||
if tested < 8 {
|
||||
continue
|
||||
}
|
||||
if evicted == 0 && c.opts.MaxCacheEntries > 0 && len(c.entries) >= c.opts.MaxCacheEntries {
|
||||
// delete at least one entry
|
||||
delete(c.entries, k)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// remove message IDs
|
||||
c.entries[req[2:]] = cacheEntry{
|
||||
deadline: time.Now().Add(ttl),
|
||||
value: res[2:],
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cache) get(req string) (res string) {
|
||||
// ignore invalid messages
|
||||
if len(req) < 12 {
|
||||
return ""
|
||||
}
|
||||
if req[2] >= 0x7f {
|
||||
return ""
|
||||
}
|
||||
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
|
||||
if c.entries == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// remove message ID
|
||||
entry, ok := c.entries[req[2:]]
|
||||
if ok && time.Until(entry.deadline) > 0 {
|
||||
// prepend correct ID
|
||||
return req[:2] + entry.value
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func invalid(req string, res string) bool {
|
||||
if len(req) < 12 || len(res) < 12 { // header size
|
||||
return true
|
||||
}
|
||||
if req[0] != res[0] || req[1] != res[1] { // IDs match
|
||||
return true
|
||||
}
|
||||
if req[2] >= 0x7f || res[2] < 0x7f { // query, response
|
||||
return true
|
||||
}
|
||||
if req[2]&0x7a != 0 || res[2]&0x7a != 0 { // standard query, not truncated
|
||||
return true
|
||||
}
|
||||
if res[3]&0xf != 0 && res[3]&0xf != 3 { // no error, or name error
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func nameError(res string) bool {
|
||||
return res[3]&0xf == 3
|
||||
}
|
||||
|
||||
func getTTL(msg string) time.Duration {
|
||||
ttl := math.MaxInt32
|
||||
|
||||
qdcount := getUint16(msg[4:])
|
||||
ancount := getUint16(msg[6:])
|
||||
nscount := getUint16(msg[8:])
|
||||
arcount := getUint16(msg[10:])
|
||||
rdcount := ancount + nscount + arcount
|
||||
|
||||
msg = msg[12:] // skip header
|
||||
|
||||
// skip questions
|
||||
for i := 0; i < qdcount; i++ {
|
||||
name := getNameLen(msg)
|
||||
if name < 0 || name+4 > len(msg) {
|
||||
return -1
|
||||
}
|
||||
msg = msg[name+4:]
|
||||
}
|
||||
|
||||
// parse records
|
||||
for i := 0; i < rdcount; i++ {
|
||||
name := getNameLen(msg)
|
||||
if name < 0 || name+10 > len(msg) {
|
||||
return -1
|
||||
}
|
||||
rtyp := getUint16(msg[name+0:])
|
||||
rttl := getUint32(msg[name+4:])
|
||||
rlen := getUint16(msg[name+8:])
|
||||
if name+10+rlen > len(msg) {
|
||||
return -1
|
||||
}
|
||||
// skip EDNS OPT since it doesn't have a TTL
|
||||
if rtyp != 41 && rttl < ttl {
|
||||
ttl = rttl
|
||||
}
|
||||
msg = msg[name+10+rlen:]
|
||||
}
|
||||
|
||||
return time.Duration(ttl) * time.Second
|
||||
}
|
||||
|
||||
func getNameLen(msg string) int {
|
||||
i := 0
|
||||
for i < len(msg) {
|
||||
if msg[i] == 0 {
|
||||
// end of name
|
||||
i += 1
|
||||
break
|
||||
}
|
||||
if msg[i] >= 0xc0 {
|
||||
// compressed name
|
||||
i += 2
|
||||
break
|
||||
}
|
||||
if msg[i] >= 0x40 {
|
||||
// reserved
|
||||
return -1
|
||||
}
|
||||
i += int(msg[i] + 1)
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
func getUint16(s string) int {
|
||||
return int(s[1]) | int(s[0])<<8
|
||||
}
|
||||
|
||||
func getUint32(s string) int {
|
||||
return int(s[3]) | int(s[2])<<8 | int(s[1])<<16 | int(s[0])<<24
|
||||
}
|
||||
|
||||
func cachingRoundTrip(cache *cache, network, address string) roundTripper {
|
||||
return func(ctx context.Context, req string) (res string, err error) {
|
||||
// check cache
|
||||
if res := cache.get(req); res != "" {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case cache.opts.PreferIPV4 && cache.opts.PreferIPV6:
|
||||
network = "udp"
|
||||
case cache.opts.PreferIPV4:
|
||||
network = "udp4"
|
||||
case cache.opts.PreferIPV6:
|
||||
network = "udp6"
|
||||
default:
|
||||
network = "udp"
|
||||
}
|
||||
|
||||
if cache.opts.Timeout > 0 {
|
||||
var cancel func()
|
||||
ctx, cancel = context.WithTimeout(ctx, cache.opts.Timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// dial connection
|
||||
var conn net.Conn
|
||||
if cache.dial != nil {
|
||||
conn, err = cache.dial(ctx, network, address)
|
||||
} else {
|
||||
var d net.Dialer
|
||||
conn, err = d.DialContext(ctx, network, address)
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
conn.Close()
|
||||
}()
|
||||
defer cancel()
|
||||
|
||||
if t, ok := ctx.Deadline(); ok {
|
||||
err = conn.SetDeadline(t)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
// send request
|
||||
err = writeMessage(conn, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// read response
|
||||
res, err = readMessage(conn)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// cache response
|
||||
cache.put(req, res)
|
||||
return res, nil
|
||||
}
|
||||
}
|
16
util/dns/cache_test.go
Normal file
16
util/dns/cache_test.go
Normal file
@ -0,0 +1,16 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCache(t *testing.T) {
|
||||
net.DefaultResolver = NewNetResolver(PreferIPV4(true))
|
||||
|
||||
addrs, err := net.LookupHost("unistack.org")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("addrs %v", addrs)
|
||||
}
|
178
util/dns/conn.go
Normal file
178
util/dns/conn.go
Normal file
@ -0,0 +1,178 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type dnsConn struct {
|
||||
sync.Mutex
|
||||
|
||||
ibuf bytes.Buffer
|
||||
obuf bytes.Buffer
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
deadline time.Time
|
||||
roundTrip roundTripper
|
||||
}
|
||||
|
||||
type roundTripper func(ctx context.Context, req string) (res string, err error)
|
||||
|
||||
func (c *dnsConn) Read(b []byte) (n int, err error) {
|
||||
imsg, n, err := c.drainBuffers(b)
|
||||
if n != 0 || err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
ctx, cancel := c.childContext()
|
||||
omsg, err := c.roundTrip(ctx, imsg)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return c.fillBuffer(b, omsg)
|
||||
}
|
||||
|
||||
func (c *dnsConn) Write(b []byte) (n int, err error) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
return c.ibuf.Write(b)
|
||||
}
|
||||
|
||||
func (c *dnsConn) Close() error {
|
||||
c.Lock()
|
||||
cancel := c.cancel
|
||||
c.Unlock()
|
||||
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *dnsConn) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *dnsConn) RemoteAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *dnsConn) SetDeadline(t time.Time) error {
|
||||
c.SetReadDeadline(t)
|
||||
c.SetWriteDeadline(t)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *dnsConn) SetReadDeadline(t time.Time) error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
c.deadline = t
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *dnsConn) SetWriteDeadline(t time.Time) error {
|
||||
// writes do not timeout
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *dnsConn) drainBuffers(b []byte) (string, int, error) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
// drain the output buffer
|
||||
if c.obuf.Len() > 0 {
|
||||
n, err := c.obuf.Read(b)
|
||||
return "", n, err
|
||||
}
|
||||
|
||||
// otherwise, get the next message from the input buffer
|
||||
sz := c.ibuf.Next(2)
|
||||
if len(sz) < 2 {
|
||||
return "", 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
size := int64(sz[0])<<8 | int64(sz[1])
|
||||
|
||||
var str strings.Builder
|
||||
_, err := io.CopyN(&str, &c.ibuf, size)
|
||||
if err == io.EOF {
|
||||
return "", 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
return str.String(), 0, nil
|
||||
}
|
||||
|
||||
func (c *dnsConn) fillBuffer(b []byte, str string) (int, error) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
c.obuf.WriteByte(byte(len(str) >> 8))
|
||||
c.obuf.WriteByte(byte(len(str)))
|
||||
c.obuf.WriteString(str)
|
||||
return c.obuf.Read(b)
|
||||
}
|
||||
|
||||
func (c *dnsConn) childContext() (context.Context, context.CancelFunc) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
if c.ctx == nil {
|
||||
c.ctx, c.cancel = context.WithCancel(context.Background())
|
||||
}
|
||||
return context.WithDeadline(c.ctx, c.deadline)
|
||||
}
|
||||
|
||||
func writeMessage(conn net.Conn, msg string) error {
|
||||
var buf []byte
|
||||
if _, ok := conn.(net.PacketConn); ok {
|
||||
buf = []byte(msg)
|
||||
} else {
|
||||
buf = make([]byte, len(msg)+2)
|
||||
buf[0] = byte(len(msg) >> 8)
|
||||
buf[1] = byte(len(msg))
|
||||
copy(buf[2:], msg)
|
||||
}
|
||||
// SHOULD do a single write on TCP (RFC 7766, section 8).
|
||||
// MUST do a single write on UDP.
|
||||
_, err := conn.Write(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
func readMessage(c net.Conn) (string, error) {
|
||||
if _, ok := c.(net.PacketConn); ok {
|
||||
// RFC 1035 specifies 512 as the maximum message size for DNS over UDP.
|
||||
// RFC 6891 OTOH suggests 4096 as the maximum payload size for EDNS.
|
||||
b := make([]byte, 4096)
|
||||
n, err := c.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(b[:n]), nil
|
||||
} else {
|
||||
var sz [2]byte
|
||||
_, err := io.ReadFull(c, sz[:])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
size := int64(sz[0])<<8 | int64(sz[1])
|
||||
|
||||
var str strings.Builder
|
||||
_, err = io.CopyN(&str, c, size)
|
||||
if err == io.EOF {
|
||||
return "", io.ErrUnexpectedEOF
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return str.String(), nil
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user