All checks were successful
		
		
	
	test / test (push) Successful in 42s
				
			## Pull Request template Please, go through these steps before clicking submit on this PR. 1. Give a descriptive title to your PR. 2. Provide a description of your changes. 3. Make sure you have some relevant tests. 4. Put `closes #XXXX` in your comment to auto-close the issue that your PR fixes (if applicable). **PLEASE REMOVE THIS TEMPLATE BEFORE SUBMITTING** Reviewed-on: #369 Co-authored-by: Evstigneev Denis <danteevstigneev@yandex.ru> Co-committed-by: Evstigneev Denis <danteevstigneev@yandex.ru>
		
			
				
	
	
		
			184 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			184 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package dns
 | 
						|
 | 
						|
import (
 | 
						|
	"bytes"
 | 
						|
	"context"
 | 
						|
	"io"
 | 
						|
	"net"
 | 
						|
	"strings"
 | 
						|
	"sync"
 | 
						|
	"time"
 | 
						|
)
 | 
						|
 | 
						|
type dnsConn struct {
 | 
						|
	ctx       context.Context
 | 
						|
	cancel    context.CancelFunc
 | 
						|
	roundTrip roundTripper
 | 
						|
 | 
						|
	deadline time.Time
 | 
						|
 | 
						|
	ibuf bytes.Buffer
 | 
						|
	obuf bytes.Buffer
 | 
						|
 | 
						|
	sync.Mutex
 | 
						|
}
 | 
						|
 | 
						|
// roundTripper exchanges one message for one reply: it receives the request
// req together with a context and returns the response res or an error.
type roundTripper func(ctx context.Context, req string) (res string, err error)
 | 
						|
 | 
						|
func (c *dnsConn) Read(b []byte) (n int, err error) {
 | 
						|
	imsg, n, err := c.drainBuffers(b)
 | 
						|
	if n != 0 || err != nil {
 | 
						|
		return n, err
 | 
						|
	}
 | 
						|
 | 
						|
	ctx, cancel := c.childContext()
 | 
						|
	omsg, err := c.roundTrip(ctx, imsg)
 | 
						|
	cancel()
 | 
						|
	if err != nil {
 | 
						|
		return 0, err
 | 
						|
	}
 | 
						|
 | 
						|
	return c.fillBuffer(b, omsg)
 | 
						|
}
 | 
						|
 | 
						|
func (c *dnsConn) Write(b []byte) (n int, err error) {
 | 
						|
	c.Lock()
 | 
						|
	defer c.Unlock()
 | 
						|
	return c.ibuf.Write(b)
 | 
						|
}
 | 
						|
 | 
						|
func (c *dnsConn) Close() error {
 | 
						|
	c.Lock()
 | 
						|
	cancel := c.cancel
 | 
						|
	c.Unlock()
 | 
						|
 | 
						|
	if cancel != nil {
 | 
						|
		cancel()
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (c *dnsConn) LocalAddr() net.Addr {
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (c *dnsConn) RemoteAddr() net.Addr {
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (c *dnsConn) SetDeadline(t time.Time) error {
 | 
						|
	var err error
 | 
						|
	if err = c.SetReadDeadline(t); err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	if err = c.SetWriteDeadline(t); err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (c *dnsConn) SetReadDeadline(t time.Time) error {
 | 
						|
	c.Lock()
 | 
						|
	c.deadline = t
 | 
						|
	c.Unlock()
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (c *dnsConn) SetWriteDeadline(_ time.Time) error {
 | 
						|
	// writes do not timeout
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (c *dnsConn) drainBuffers(b []byte) (string, int, error) {
 | 
						|
	c.Lock()
 | 
						|
	defer c.Unlock()
 | 
						|
 | 
						|
	// drain the output buffer
 | 
						|
	if c.obuf.Len() > 0 {
 | 
						|
		n, err := c.obuf.Read(b)
 | 
						|
		return "", n, err
 | 
						|
	}
 | 
						|
 | 
						|
	// otherwise, get the next message from the input buffer
 | 
						|
	sz := c.ibuf.Next(2)
 | 
						|
	if len(sz) < 2 {
 | 
						|
		return "", 0, io.ErrUnexpectedEOF
 | 
						|
	}
 | 
						|
 | 
						|
	size := int64(sz[0])<<8 | int64(sz[1])
 | 
						|
 | 
						|
	var str strings.Builder
 | 
						|
	_, err := io.CopyN(&str, &c.ibuf, size)
 | 
						|
	if err == io.EOF {
 | 
						|
		return "", 0, io.ErrUnexpectedEOF
 | 
						|
	}
 | 
						|
	if err != nil {
 | 
						|
		return "", 0, err
 | 
						|
	}
 | 
						|
	return str.String(), 0, nil
 | 
						|
}
 | 
						|
 | 
						|
func (c *dnsConn) fillBuffer(b []byte, str string) (int, error) {
 | 
						|
	c.Lock()
 | 
						|
	defer c.Unlock()
 | 
						|
	c.obuf.WriteByte(byte(len(str) >> 8))
 | 
						|
	c.obuf.WriteByte(byte(len(str)))
 | 
						|
	c.obuf.WriteString(str)
 | 
						|
	return c.obuf.Read(b)
 | 
						|
}
 | 
						|
 | 
						|
func (c *dnsConn) childContext() (context.Context, context.CancelFunc) {
 | 
						|
	c.Lock()
 | 
						|
	defer c.Unlock()
 | 
						|
	if c.ctx == nil {
 | 
						|
		c.ctx, c.cancel = context.WithCancel(context.Background())
 | 
						|
	}
 | 
						|
	return context.WithDeadline(c.ctx, c.deadline)
 | 
						|
}
 | 
						|
 | 
						|
// messageTooLargeError is returned by writeMessage when a message cannot be
// framed with the two-byte length prefix used on stream transports.
type messageTooLargeError struct{}

// Error implements the error interface.
func (messageTooLargeError) Error() string {
	return "dns: message too large for two-byte length prefix"
}

// writeMessage sends msg on conn with a single Write call.
//
// If conn also implements net.PacketConn the message is written as is.
// Otherwise it is preceded by its length as two big-endian bytes. A message
// longer than 65535 bytes cannot be represented by that prefix; in that case
// nothing is written and a messageTooLargeError is returned, rather than
// sending a frame whose prefix disagrees with its body.
//
// Any error from conn.Write is returned unchanged.
func writeMessage(conn net.Conn, msg string) error {
	var buf []byte
	if _, ok := conn.(net.PacketConn); ok {
		buf = []byte(msg)
	} else {
		if len(msg) > 0xFFFF {
			return messageTooLargeError{}
		}
		buf = make([]byte, len(msg)+2)
		buf[0] = byte(len(msg) >> 8)
		buf[1] = byte(len(msg))
		copy(buf[2:], msg)
	}
	// SHOULD do a single write on TCP (RFC 7766, section 8).
	// MUST do a single write on UDP.
	_, err := conn.Write(buf)
	return err
}
 | 
						|
 | 
						|
// readMessage receives one message from c.
//
// If c also implements net.PacketConn, a single Read of up to 4096 bytes is
// performed and its contents are the message. Otherwise a two-byte
// big-endian length is read first, followed by exactly that many bytes. If
// the stream ends before the announced length has arrived, the error is
// io.ErrUnexpectedEOF. Other errors are returned unchanged.
func readMessage(c net.Conn) (string, error) {
	if _, isPacket := c.(net.PacketConn); isPacket {
		// RFC 1035 specifies 512 as the maximum message size for DNS over UDP.
		// RFC 6891 OTOH suggests 4096 as the maximum payload size for EDNS.
		datagram := make([]byte, 4096)
		n, err := c.Read(datagram)
		if err != nil {
			return "", err
		}
		return string(datagram[:n]), nil
	}

	var prefix [2]byte
	if _, err := io.ReadFull(c, prefix[:]); err != nil {
		return "", err
	}
	length := int64(prefix[0])<<8 | int64(prefix[1])

	var body strings.Builder
	switch _, err := io.CopyN(&body, c, length); err {
	case nil:
		return body.String(), nil
	case io.EOF:
		// The peer stopped sending before the announced length was reached.
		return "", io.ErrUnexpectedEOF
	default:
		return "", err
	}
}
 |