2024-12-03 01:11:08 +03:00
|
|
|
package dns
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"math"
|
|
|
|
"net"
|
|
|
|
"sync"
|
|
|
|
"time"
|
|
|
|
)
|
|
|
|
|
|
|
|
// DialFunc is a [net.Resolver.Dial] function: given the network and address
// of a DNS server, it returns a connection to that server.
type DialFunc func(ctx context.Context, network, address string) (net.Conn, error)
|
|
|
|
|
|
|
|
// NewNetResolver creates a caching [net.Resolver] that uses parent to resolve names.
|
|
|
|
func NewNetResolver(opts ...Option) *net.Resolver {
|
|
|
|
options := Options{Resolver: &net.Resolver{}}
|
|
|
|
|
|
|
|
for _, o := range opts {
|
|
|
|
o(&options)
|
|
|
|
}
|
|
|
|
|
|
|
|
return &net.Resolver{
|
|
|
|
PreferGo: true,
|
|
|
|
StrictErrors: options.Resolver.StrictErrors,
|
|
|
|
Dial: NewNetDialer(options.Resolver.Dial, append(opts, Resolver(options.Resolver))...),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// NewNetDialer adds caching to a [net.Resolver.Dial] function.
|
|
|
|
func NewNetDialer(parent DialFunc, opts ...Option) DialFunc {
|
|
|
|
cache := cache{dial: parent, opts: Options{}}
|
|
|
|
for _, o := range opts {
|
|
|
|
o(&cache.opts)
|
|
|
|
}
|
|
|
|
if cache.opts.MaxCacheEntries == 0 {
|
|
|
|
cache.opts.MaxCacheEntries = DefaultMaxCacheEntries
|
|
|
|
}
|
2024-12-09 13:06:43 +03:00
|
|
|
return func(_ context.Context, network, address string) (net.Conn, error) {
|
2024-12-03 01:11:08 +03:00
|
|
|
conn := &dnsConn{}
|
|
|
|
conn.roundTrip = cachingRoundTrip(&cache, network, address)
|
|
|
|
return conn, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// DefaultMaxCacheEntries is the cache size limit used by [NewNetDialer] when
// MaxCacheEntries is zero (the default).
const DefaultMaxCacheEntries = 300
|
|
|
|
|
|
|
|
// An Option customizes the resolver cache built by [NewNetResolver] or
// [NewNetDialer]. Options are applied in the order they are given.
type Option func(*Options)
|
|
|
|
|
|
|
|
// Options holds the settings configured through [Option] functions.
type Options struct {
	// Resolver is the upstream resolver; NewNetResolver uses its Dial
	// function and copies its StrictErrors setting.
	Resolver *net.Resolver
	// MaxCacheEntries bounds the number of cached responses. Zero is
	// replaced by DefaultMaxCacheEntries in NewNetDialer; negative means no
	// limit.
	MaxCacheEntries int
	// MaxCacheTTL caps how long a response stays cached and takes precedence
	// over MinCacheTTL; zero means no cap.
	MaxCacheTTL time.Duration
	// MinCacheTTL raises shorter (positive) response TTLs to this value.
	MinCacheTTL time.Duration
	// NegativeCache enables caching of name-error (NXDOMAIN) responses.
	NegativeCache bool
	// PreferIPV4 restricts connections to upstream DNS servers to IPv4.
	PreferIPV4 bool
	// PreferIPV6 restricts connections to upstream DNS servers to IPv6.
	// Setting both PreferIPV4 and PreferIPV6 leaves the family unrestricted.
	PreferIPV6 bool
	// Timeout bounds each upstream exchange made on a cache miss; zero or
	// negative adds no limit beyond the caller's context.
	Timeout time.Duration
}
|
|
|
|
|
|
|
|
// MaxCacheEntries sets the maximum number of entries to cache.
|
|
|
|
// If zero, [DefaultMaxCacheEntries] is used; negative means no limit.
|
|
|
|
func MaxCacheEntries(n int) Option {
|
|
|
|
return func(o *Options) {
|
|
|
|
o.MaxCacheEntries = n
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// MaxCacheTTL sets the maximum time-to-live for entries in the cache.
|
|
|
|
func MaxCacheTTL(td time.Duration) Option {
|
|
|
|
return func(o *Options) {
|
|
|
|
o.MaxCacheTTL = td
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// MinCacheTTL sets the minimum time-to-live for entries in the cache.
|
|
|
|
func MinCacheTTL(td time.Duration) Option {
|
|
|
|
return func(o *Options) {
|
|
|
|
o.MinCacheTTL = td
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// NegativeCache sets whether to cache negative responses.
|
|
|
|
func NegativeCache(b bool) Option {
|
|
|
|
return func(o *Options) {
|
|
|
|
o.NegativeCache = b
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Timeout sets upstream *net.Resolver timeout
|
|
|
|
func Timeout(td time.Duration) Option {
|
|
|
|
return func(o *Options) {
|
|
|
|
o.Timeout = td
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Resolver sets upstream *net.Resolver.
|
|
|
|
func Resolver(r *net.Resolver) Option {
|
|
|
|
return func(o *Options) {
|
|
|
|
o.Resolver = r
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// PreferIPV4 resolve ipv4 records.
|
|
|
|
func PreferIPV4(b bool) Option {
|
|
|
|
return func(o *Options) {
|
|
|
|
o.PreferIPV4 = b
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// PreferIPV6 resolve ipv4 records.
|
|
|
|
func PreferIPV6(b bool) Option {
|
|
|
|
return func(o *Options) {
|
|
|
|
o.PreferIPV6 = b
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// cache holds DNS responses keyed by their query with the message ID removed.
// One cache is shared by every connection created by a NewNetDialer call.
type cache struct {
	// entries maps an ID-stripped request to its ID-stripped response and
	// expiry; it is nil until the first put.
	entries map[string]cacheEntry

	// dial opens connections to the upstream server; nil means a zero
	// net.Dialer is used instead.
	dial DialFunc

	// opts is fixed once NewNetDialer returns.
	opts Options

	// RWMutex guards entries: get takes the read lock, put the write lock.
	sync.RWMutex
}
|
|
|
|
|
|
|
|
// cacheEntry is one cached response.
type cacheEntry struct {
	// deadline is the instant after which the entry is no longer served.
	deadline time.Time
	// value is the response without its leading 2-byte message ID.
	value string
}
|
|
|
|
|
|
|
|
func (c *cache) put(req string, res string) {
|
|
|
|
// ignore uncacheable/unparseable answers
|
|
|
|
if invalid(req, res) {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// ignore errors (if requested)
|
|
|
|
if nameError(res) && !c.opts.NegativeCache {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// ignore uncacheable/unparseable answers
|
|
|
|
ttl := getTTL(res)
|
|
|
|
if ttl <= 0 {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// adjust TTL
|
|
|
|
if ttl < c.opts.MinCacheTTL {
|
|
|
|
ttl = c.opts.MinCacheTTL
|
|
|
|
}
|
|
|
|
// maxTTL overrides minTTL
|
|
|
|
if ttl > c.opts.MaxCacheTTL && c.opts.MaxCacheTTL != 0 {
|
|
|
|
ttl = c.opts.MaxCacheTTL
|
|
|
|
}
|
|
|
|
|
|
|
|
c.Lock()
|
|
|
|
defer c.Unlock()
|
|
|
|
if c.entries == nil {
|
|
|
|
c.entries = make(map[string]cacheEntry)
|
|
|
|
}
|
|
|
|
|
|
|
|
// do some cache evition
|
|
|
|
var tested, evicted int
|
|
|
|
for k, e := range c.entries {
|
|
|
|
if time.Until(e.deadline) <= 0 {
|
|
|
|
// delete expired entry
|
|
|
|
delete(c.entries, k)
|
|
|
|
evicted++
|
|
|
|
}
|
|
|
|
tested++
|
|
|
|
|
|
|
|
if tested < 8 {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
if evicted == 0 && c.opts.MaxCacheEntries > 0 && len(c.entries) >= c.opts.MaxCacheEntries {
|
|
|
|
// delete at least one entry
|
|
|
|
delete(c.entries, k)
|
|
|
|
}
|
|
|
|
break
|
|
|
|
}
|
|
|
|
|
|
|
|
// remove message IDs
|
|
|
|
c.entries[req[2:]] = cacheEntry{
|
|
|
|
deadline: time.Now().Add(ttl),
|
|
|
|
value: res[2:],
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *cache) get(req string) (res string) {
|
|
|
|
// ignore invalid messages
|
|
|
|
if len(req) < 12 {
|
|
|
|
return ""
|
|
|
|
}
|
|
|
|
if req[2] >= 0x7f {
|
|
|
|
return ""
|
|
|
|
}
|
|
|
|
|
|
|
|
c.RLock()
|
|
|
|
defer c.RUnlock()
|
|
|
|
|
|
|
|
if c.entries == nil {
|
|
|
|
return ""
|
|
|
|
}
|
|
|
|
|
|
|
|
// remove message ID
|
|
|
|
entry, ok := c.entries[req[2:]]
|
|
|
|
if ok && time.Until(entry.deadline) > 0 {
|
|
|
|
// prepend correct ID
|
|
|
|
return req[:2] + entry.value
|
|
|
|
}
|
|
|
|
return ""
|
|
|
|
}
|
|
|
|
|
|
|
|
// invalid reports whether the req/res pair must not be cached.
//
// The pair is valid only when all of the following hold:
//   - Both messages contain a full 12-byte header and share a message ID.
//   - req is a query and res is a response (QR bit).
//   - Both are standard queries (OPCODE 0) without the TC (truncated) bit.
//   - res carries RCODE 0 (NOERROR) or 3 (NXDOMAIN).
func invalid(req string, res string) bool {
	const (
		headerLen = 12
		qrBit     = 0x80 // byte 2: 0 = query, 1 = response
		opcodeTC  = 0x7a // byte 2: OPCODE (0x78) and TC (0x02)
		rcodeMask = 0x0f // byte 3: response code
	)

	switch {
	case len(req) < headerLen || len(res) < headerLen:
		return true
	case req[0] != res[0] || req[1] != res[1]: // IDs must match
		return true
	case req[2]&qrBit != 0 || res[2]&qrBit == 0: // query, response
		return true
	case req[2]&opcodeTC != 0 || res[2]&opcodeTC != 0: // standard, not truncated
		return true
	}

	// no error, or name error
	rcode := res[3] & rcodeMask
	return rcode != 0 && rcode != 3
}
|
|
|
|
|
|
|
|
// nameError reports whether res carries RCODE 3 (NXDOMAIN, "name error"),
// i.e. whether the low nibble of header byte 3 equals 3.
// res must be at least 4 bytes long.
func nameError(res string) bool {
	const nxdomain = 3
	rcode := res[3] & 0x0f
	return rcode == nxdomain
}
|
|
|
|
|
|
|
|
func getTTL(msg string) time.Duration {
|
|
|
|
ttl := math.MaxInt32
|
|
|
|
|
|
|
|
qdcount := getUint16(msg[4:])
|
|
|
|
ancount := getUint16(msg[6:])
|
|
|
|
nscount := getUint16(msg[8:])
|
|
|
|
arcount := getUint16(msg[10:])
|
|
|
|
rdcount := ancount + nscount + arcount
|
|
|
|
|
|
|
|
msg = msg[12:] // skip header
|
|
|
|
|
|
|
|
// skip questions
|
|
|
|
for i := 0; i < qdcount; i++ {
|
|
|
|
name := getNameLen(msg)
|
|
|
|
if name < 0 || name+4 > len(msg) {
|
|
|
|
return -1
|
|
|
|
}
|
|
|
|
msg = msg[name+4:]
|
|
|
|
}
|
|
|
|
|
|
|
|
// parse records
|
|
|
|
for i := 0; i < rdcount; i++ {
|
|
|
|
name := getNameLen(msg)
|
|
|
|
if name < 0 || name+10 > len(msg) {
|
|
|
|
return -1
|
|
|
|
}
|
|
|
|
rtyp := getUint16(msg[name+0:])
|
|
|
|
rttl := getUint32(msg[name+4:])
|
|
|
|
rlen := getUint16(msg[name+8:])
|
|
|
|
if name+10+rlen > len(msg) {
|
|
|
|
return -1
|
|
|
|
}
|
|
|
|
// skip EDNS OPT since it doesn't have a TTL
|
|
|
|
if rtyp != 41 && rttl < ttl {
|
|
|
|
ttl = rttl
|
|
|
|
}
|
|
|
|
msg = msg[name+10+rlen:]
|
|
|
|
}
|
|
|
|
|
|
|
|
return time.Duration(ttl) * time.Second
|
|
|
|
}
|
|
|
|
|
|
|
|
// getNameLen returns the encoded length in bytes of the domain name at the
// start of msg.
//
// The name ends at a zero-length root label or at a 2-byte compression
// pointer (label byte >= 0xc0). It returns -1 for a reserved label type
// (0x40-0xbf). If msg runs out before the name ends, the returned length is
// >= len(msg); callers must bounds-check it.
func getNameLen(msg string) int {
	n := 0
	for n < len(msg) {
		switch label := msg[n]; {
		case label == 0:
			// root label terminates the name
			return n + 1
		case label >= 0xc0:
			// compression pointer occupies two bytes and ends the name
			return n + 2
		case label >= 0x40:
			// reserved label type
			return -1
		default:
			// length byte followed by that many characters
			n += 1 + int(label)
		}
	}
	return n
}
|
|
|
|
|
|
|
|
// getUint16 decodes the big-endian (network order) 16-bit value in the first
// two bytes of s.
func getUint16(s string) int {
	hi, lo := int(s[0]), int(s[1])
	return hi<<8 + lo
}
|
|
|
|
|
|
|
|
// getUint32 decodes the big-endian (network order) 32-bit value in the first
// four bytes of s.
func getUint32(s string) int {
	v := 0
	for i := 0; i < 4; i++ {
		v = v<<8 | int(s[i])
	}
	return v
}
|
|
|
|
|
|
|
|
func cachingRoundTrip(cache *cache, network, address string) roundTripper {
|
|
|
|
return func(ctx context.Context, req string) (res string, err error) {
|
|
|
|
// check cache
|
2024-12-09 13:06:43 +03:00
|
|
|
if res = cache.get(req); res != "" {
|
2024-12-03 01:11:08 +03:00
|
|
|
return res, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
switch {
|
|
|
|
case cache.opts.PreferIPV4 && cache.opts.PreferIPV6:
|
|
|
|
network = "udp"
|
|
|
|
case cache.opts.PreferIPV4:
|
|
|
|
network = "udp4"
|
|
|
|
case cache.opts.PreferIPV6:
|
|
|
|
network = "udp6"
|
|
|
|
default:
|
|
|
|
network = "udp"
|
|
|
|
}
|
|
|
|
|
|
|
|
if cache.opts.Timeout > 0 {
|
|
|
|
var cancel func()
|
|
|
|
ctx, cancel = context.WithTimeout(ctx, cache.opts.Timeout)
|
|
|
|
defer cancel()
|
|
|
|
}
|
|
|
|
|
|
|
|
// dial connection
|
|
|
|
var conn net.Conn
|
|
|
|
if cache.dial != nil {
|
|
|
|
conn, err = cache.dial(ctx, network, address)
|
|
|
|
} else {
|
|
|
|
var d net.Dialer
|
|
|
|
conn, err = d.DialContext(ctx, network, address)
|
|
|
|
}
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
|
|
go func() {
|
|
|
|
<-ctx.Done()
|
|
|
|
conn.Close()
|
|
|
|
}()
|
|
|
|
defer cancel()
|
|
|
|
|
|
|
|
if t, ok := ctx.Deadline(); ok {
|
|
|
|
err = conn.SetDeadline(t)
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// send request
|
|
|
|
err = writeMessage(conn, req)
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
|
|
|
// read response
|
|
|
|
res, err = readMessage(conn)
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
|
|
|
// cache response
|
|
|
|
cache.put(req, res)
|
|
|
|
return res, nil
|
|
|
|
}
|
|
|
|
}
|