diff --git a/resolver/dns/dns.go b/resolver/dns/dns.go index e80a9c8b..35c80c67 100644 --- a/resolver/dns/dns.go +++ b/resolver/dns/dns.go @@ -12,9 +12,9 @@ import ( // Resolver is a DNS network resolve type Resolver struct { - sync.RWMutex goresolver *net.Resolver Address string + mu sync.RWMutex } // Resolve tries to resolve endpoint address @@ -39,12 +39,12 @@ func (r *Resolver) Resolve(name string) ([]*resolver.Record, error) { return []*resolver.Record{rec}, nil } - r.RLock() + r.mu.RLock() goresolver := r.goresolver - r.RUnlock() + r.mu.RUnlock() if goresolver == nil { - r.Lock() + r.mu.Lock() r.goresolver = &net.Resolver{ Dial: func(ctx context.Context, _ string, _ string) (net.Conn, error) { d := net.Dialer{ @@ -53,7 +53,7 @@ func (r *Resolver) Resolve(name string) ([]*resolver.Record, error) { return d.DialContext(ctx, "udp", r.Address) }, } - r.Unlock() + r.mu.Unlock() } addrs, err := goresolver.LookupIP(context.TODO(), "ip", host) diff --git a/service.go b/service.go index 295e8c82..3c782035 100644 --- a/service.go +++ b/service.go @@ -3,7 +3,9 @@ package micro import ( "fmt" + "net" "sync" + "time" "github.com/KimMachineGun/automemlimit/memlimit" "go.uber.org/automaxprocs/maxprocs" @@ -17,6 +19,7 @@ import ( "go.unistack.org/micro/v3/server" "go.unistack.org/micro/v3/store" "go.unistack.org/micro/v3/tracer" + utildns "go.unistack.org/micro/v3/util/dns" ) func init() { @@ -30,6 +33,8 @@ func init() { ), ), ) + + net.DefaultResolver = utildns.NewNetResolver(utildns.Timeout(1 * time.Second)) } // Service is an interface that wraps the lower level components. diff --git a/util/dns/cache.go b/util/dns/cache.go new file mode 100644 index 00000000..051cc041 --- /dev/null +++ b/util/dns/cache.go @@ -0,0 +1,377 @@ +package dns + +import ( + "context" + "math" + "net" + "sync" + "time" +) + +// DialFunc is a [net.Resolver.Dial] function. +type DialFunc func(ctx context.Context, network, address string) (net.Conn, error) + +// NewNetResolver creates a caching [net.Resolver] that uses parent to resolve names. +func NewNetResolver(opts ...Option) *net.Resolver { + options := Options{Resolver: &net.Resolver{}} + + for _, o := range opts { + o(&options) + } + + return &net.Resolver{ + PreferGo: true, + StrictErrors: options.Resolver.StrictErrors, + Dial: NewNetDialer(options.Resolver.Dial, append(opts, Resolver(options.Resolver))...), + } +} + +// NewNetDialer adds caching to a [net.Resolver.Dial] function. +func NewNetDialer(parent DialFunc, opts ...Option) DialFunc { + cache := cache{dial: parent, opts: Options{}} + for _, o := range opts { + o(&cache.opts) + } + if cache.opts.MaxCacheEntries == 0 { + cache.opts.MaxCacheEntries = DefaultMaxCacheEntries + } + return func(ctx context.Context, network, address string) (net.Conn, error) { + conn := &dnsConn{} + conn.roundTrip = cachingRoundTrip(&cache, network, address) + return conn, nil + } +} + +const DefaultMaxCacheEntries = 300 + +// A Option customizes the resolver cache. +type Option func(*Options) + +type Options struct { + Resolver *net.Resolver + MaxCacheEntries int + MaxCacheTTL time.Duration + MinCacheTTL time.Duration + NegativeCache bool + PreferIPV4 bool + PreferIPV6 bool + Timeout time.Duration +} + +// MaxCacheEntries sets the maximum number of entries to cache. +// If zero, [DefaultMaxCacheEntries] is used; negative means no limit. +func MaxCacheEntries(n int) Option { + return func(o *Options) { + o.MaxCacheEntries = n + } +} + +// MaxCacheTTL sets the maximum time-to-live for entries in the cache. +func MaxCacheTTL(td time.Duration) Option { + return func(o *Options) { + o.MaxCacheTTL = td + } +} + +// MinCacheTTL sets the minimum time-to-live for entries in the cache. +func MinCacheTTL(td time.Duration) Option { + return func(o *Options) { + o.MinCacheTTL = td + } +} + +// NegativeCache sets whether to cache negative responses. +func NegativeCache(b bool) Option { + return func(o *Options) { + o.NegativeCache = b + } +} + +// Timeout sets upstream *net.Resolver timeout +func Timeout(td time.Duration) Option { + return func(o *Options) { + o.Timeout = td + } +} + +// Resolver sets upstream *net.Resolver. +func Resolver(r *net.Resolver) Option { + return func(o *Options) { + o.Resolver = r + } +} + +// PreferIPV4 resolve ipv4 records. +func PreferIPV4(b bool) Option { + return func(o *Options) { + o.PreferIPV4 = b + } +} + +// PreferIPV6 resolve ipv4 records. +func PreferIPV6(b bool) Option { + return func(o *Options) { + o.PreferIPV6 = b + } +} + +type cache struct { + sync.RWMutex + + dial DialFunc + entries map[string]cacheEntry + + opts Options +} + +type cacheEntry struct { + deadline time.Time + value string +} + +func (c *cache) put(req string, res string) { + // ignore uncacheable/unparseable answers + if invalid(req, res) { + return + } + + // ignore errors (if requested) + if nameError(res) && !c.opts.NegativeCache { + return + } + + // ignore uncacheable/unparseable answers + ttl := getTTL(res) + if ttl <= 0 { + return + } + + // adjust TTL + if ttl < c.opts.MinCacheTTL { + ttl = c.opts.MinCacheTTL + } + // maxTTL overrides minTTL + if ttl > c.opts.MaxCacheTTL && c.opts.MaxCacheTTL != 0 { + ttl = c.opts.MaxCacheTTL + } + + c.Lock() + defer c.Unlock() + if c.entries == nil { + c.entries = make(map[string]cacheEntry) + } + + // do some cache evition + var tested, evicted int + for k, e := range c.entries { + if time.Until(e.deadline) <= 0 { + // delete expired entry + delete(c.entries, k) + evicted++ + } + tested++ + + if tested < 8 { + continue + } + if evicted == 0 && c.opts.MaxCacheEntries > 0 && len(c.entries) >= c.opts.MaxCacheEntries { + // delete at least one entry + delete(c.entries, k) + } + break + } + + // remove message IDs + c.entries[req[2:]] = cacheEntry{ + deadline: time.Now().Add(ttl), + value: res[2:], + } +} + +func (c *cache) get(req string) (res string) { + // ignore invalid messages + if len(req) < 12 { + return "" + } + if req[2] >= 0x7f { + return "" + } + + c.RLock() + defer c.RUnlock() + + if c.entries == nil { + return "" + } + + // remove message ID + entry, ok := c.entries[req[2:]] + if ok && time.Until(entry.deadline) > 0 { + // prepend correct ID + return req[:2] + entry.value + } + return "" +} + +func invalid(req string, res string) bool { + if len(req) < 12 || len(res) < 12 { // header size + return true + } + if req[0] != res[0] || req[1] != res[1] { // IDs match + return true + } + if req[2] >= 0x7f || res[2] < 0x7f { // query, response + return true + } + if req[2]&0x7a != 0 || res[2]&0x7a != 0 { // standard query, not truncated + return true + } + if res[3]&0xf != 0 && res[3]&0xf != 3 { // no error, or name error + return true + } + return false +} + +func nameError(res string) bool { + return res[3]&0xf == 3 +} + +func getTTL(msg string) time.Duration { + ttl := math.MaxInt32 + + qdcount := getUint16(msg[4:]) + ancount := getUint16(msg[6:]) + nscount := getUint16(msg[8:]) + arcount := getUint16(msg[10:]) + rdcount := ancount + nscount + arcount + + msg = msg[12:] // skip header + + // skip questions + for i := 0; i < qdcount; i++ { + name := getNameLen(msg) + if name < 0 || name+4 > len(msg) { + return -1 + } + msg = msg[name+4:] + } + + // parse records + for i := 0; i < rdcount; i++ { + name := getNameLen(msg) + if name < 0 || name+10 > len(msg) { + return -1 + } + rtyp := getUint16(msg[name+0:]) + rttl := getUint32(msg[name+4:]) + rlen := getUint16(msg[name+8:]) + if name+10+rlen > len(msg) { + return -1 + } + // skip EDNS OPT since it doesn't have a TTL + if rtyp != 41 && rttl < ttl { + ttl = rttl + } + msg = msg[name+10+rlen:] + } + + return time.Duration(ttl) * time.Second +} + +func getNameLen(msg string) int { + i := 0 + for i < len(msg) { + if msg[i] == 0 { + // end of name + i += 1 + break + } + if msg[i] >= 0xc0 { + // compressed name + i += 2 + break + } + if msg[i] >= 0x40 { + // reserved + return -1 + } + i += int(msg[i] + 1) + } + return i +} + +func getUint16(s string) int { + return int(s[1]) | int(s[0])<<8 +} + +func getUint32(s string) int { + return int(s[3]) | int(s[2])<<8 | int(s[1])<<16 | int(s[0])<<24 +} + +func cachingRoundTrip(cache *cache, network, address string) roundTripper { + return func(ctx context.Context, req string) (res string, err error) { + // check cache + if res := cache.get(req); res != "" { + return res, nil + } + + switch { + case cache.opts.PreferIPV4 && cache.opts.PreferIPV6: + network = "udp" + case cache.opts.PreferIPV4: + network = "udp4" + case cache.opts.PreferIPV6: + network = "udp6" + default: + network = "udp" + } + + if cache.opts.Timeout > 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, cache.opts.Timeout) + defer cancel() + } + + // dial connection + var conn net.Conn + if cache.dial != nil { + conn, err = cache.dial(ctx, network, address) + } else { + var d net.Dialer + conn, err = d.DialContext(ctx, network, address) + } + if err != nil { + return "", err + } + + ctx, cancel := context.WithCancel(ctx) + go func() { + <-ctx.Done() + conn.Close() + }() + defer cancel() + + if t, ok := ctx.Deadline(); ok { + err = conn.SetDeadline(t) + if err != nil { + return "", err + } + } + + // send request + err = writeMessage(conn, req) + if err != nil { + return "", err + } + + // read response + res, err = readMessage(conn) + if err != nil { + return "", err + } + + // cache response + cache.put(req, res) + return res, nil + } +} diff --git a/util/dns/cache_test.go b/util/dns/cache_test.go new file mode 100644 index 00000000..6e2fb1d1 --- /dev/null +++ b/util/dns/cache_test.go @@ -0,0 +1,16 @@ +package dns + +import ( + "net" + "testing" +) + +func TestCache(t *testing.T) { + net.DefaultResolver = NewNetResolver(PreferIPV4(true)) + + addrs, err := net.LookupHost("unistack.org") + if err != nil { + t.Fatal(err) + } + t.Logf("addrs %v", addrs) +} diff --git a/util/dns/conn.go b/util/dns/conn.go new file mode 100644 index 00000000..f6057a96 --- /dev/null +++ b/util/dns/conn.go @@ -0,0 +1,178 @@ +package dns + +import ( + "bytes" + "context" + "io" + "net" + "strings" + "sync" + "time" +) + +type dnsConn struct { + sync.Mutex + + ibuf bytes.Buffer + obuf bytes.Buffer + + ctx context.Context + cancel context.CancelFunc + deadline time.Time + roundTrip roundTripper +} + +type roundTripper func(ctx context.Context, req string) (res string, err error) + +func (c *dnsConn) Read(b []byte) (n int, err error) { + imsg, n, err := c.drainBuffers(b) + if n != 0 || err != nil { + return n, err + } + + ctx, cancel := c.childContext() + omsg, err := c.roundTrip(ctx, imsg) + cancel() + if err != nil { + return 0, err + } + + return c.fillBuffer(b, omsg) +} + +func (c *dnsConn) Write(b []byte) (n int, err error) { + c.Lock() + defer c.Unlock() + return c.ibuf.Write(b) +} + +func (c *dnsConn) Close() error { + c.Lock() + cancel := c.cancel + c.Unlock() + + if cancel != nil { + cancel() + } + return nil +} + +func (c *dnsConn) LocalAddr() net.Addr { + return nil +} + +func (c *dnsConn) RemoteAddr() net.Addr { + return nil +} + +func (c *dnsConn) SetDeadline(t time.Time) error { + c.SetReadDeadline(t) + c.SetWriteDeadline(t) + return nil +} + +func (c *dnsConn) SetReadDeadline(t time.Time) error { + c.Lock() + defer c.Unlock() + c.deadline = t + return nil +} + +func (c *dnsConn) SetWriteDeadline(t time.Time) error { + // writes do not timeout + return nil +} + +func (c *dnsConn) drainBuffers(b []byte) (string, int, error) { + c.Lock() + defer c.Unlock() + + // drain the output buffer + if c.obuf.Len() > 0 { + n, err := c.obuf.Read(b) + return "", n, err + } + + // otherwise, get the next message from the input buffer + sz := c.ibuf.Next(2) + if len(sz) < 2 { + return "", 0, io.ErrUnexpectedEOF + } + + size := int64(sz[0])<<8 | int64(sz[1]) + + var str strings.Builder + _, err := io.CopyN(&str, &c.ibuf, size) + if err == io.EOF { + return "", 0, io.ErrUnexpectedEOF + } + if err != nil { + return "", 0, err + } + return str.String(), 0, nil +} + +func (c *dnsConn) fillBuffer(b []byte, str string) (int, error) { + c.Lock() + defer c.Unlock() + c.obuf.WriteByte(byte(len(str) >> 8)) + c.obuf.WriteByte(byte(len(str))) + c.obuf.WriteString(str) + return c.obuf.Read(b) +} + +func (c *dnsConn) childContext() (context.Context, context.CancelFunc) { + c.Lock() + defer c.Unlock() + if c.ctx == nil { + c.ctx, c.cancel = context.WithCancel(context.Background()) + } + return context.WithDeadline(c.ctx, c.deadline) +} + +func writeMessage(conn net.Conn, msg string) error { + var buf []byte + if _, ok := conn.(net.PacketConn); ok { + buf = []byte(msg) + } else { + buf = make([]byte, len(msg)+2) + buf[0] = byte(len(msg) >> 8) + buf[1] = byte(len(msg)) + copy(buf[2:], msg) + } + // SHOULD do a single write on TCP (RFC 7766, section 8). + // MUST do a single write on UDP. + _, err := conn.Write(buf) + return err +} + +func readMessage(c net.Conn) (string, error) { + if _, ok := c.(net.PacketConn); ok { + // RFC 1035 specifies 512 as the maximum message size for DNS over UDP. + // RFC 6891 OTOH suggests 4096 as the maximum payload size for EDNS. + b := make([]byte, 4096) + n, err := c.Read(b) + if err != nil { + return "", err + } + return string(b[:n]), nil + } else { + var sz [2]byte + _, err := io.ReadFull(c, sz[:]) + if err != nil { + return "", err + } + + size := int64(sz[0])<<8 | int64(sz[1]) + + var str strings.Builder + _, err = io.CopyN(&str, c, size) + if err == io.EOF { + return "", io.ErrUnexpectedEOF + } + if err != nil { + return "", err + } + return str.String(), nil + } +}