All checks were successful
		
		
	
	test / test (push) Successful in 42s
				
			## Pull Request template Please, go through these steps before clicking submit on this PR. 1. Give a descriptive title to your PR. 2. Provide a description of your changes. 3. Make sure you have some relevant tests. 4. Put `closes #XXXX` in your comment to auto-close the issue that your PR fixes (if applicable). **PLEASE REMOVE THIS TEMPLATE BEFORE SUBMITTING** Reviewed-on: #369 Co-authored-by: Evstigneev Denis <danteevstigneev@yandex.ru> Co-committed-by: Evstigneev Denis <danteevstigneev@yandex.ru>
		
			
				
	
	
		
			184 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			184 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package dns
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"context"
 | |
| 	"io"
 | |
| 	"net"
 | |
| 	"strings"
 | |
| 	"sync"
 | |
| 	"time"
 | |
| )
 | |
| 
 | |
// dnsConn adapts a message-oriented roundTripper to the stream-oriented
// net.Conn interface using DNS-over-TCP framing: every message written to or
// read from the connection is preceded by a 2-byte big-endian length.
//
// Writes only accumulate framed queries in ibuf. A Read that finds obuf empty
// takes the next framed query out of ibuf, exchanges it through roundTrip, and
// frames the response into obuf, from which that Read and any later ones are
// served.
type dnsConn struct {
	ctx       context.Context    // created lazily; every round trip derives from it
	cancel    context.CancelFunc // cancels ctx; invoked by Close
	roundTrip roundTripper       // exchanges one query for one response

	deadline time.Time // read deadline; the zero value means no deadline

	ibuf bytes.Buffer // framed queries written but not yet sent
	obuf bytes.Buffer // framed responses received but not yet read

	sync.Mutex // guards ctx, cancel, deadline, ibuf and obuf
}

var _ net.Conn = (*dnsConn)(nil)

// roundTripper sends the DNS query req and returns the response res. Both
// messages are passed without the 2-byte length prefix used on the stream.
type roundTripper func(ctx context.Context, req string) (res string, err error)

// Read implements net.Conn. It first returns any buffered response bytes.
// When none remain, it removes the next framed query from the write buffer,
// passes it to roundTrip with a context that is bounded by the read deadline
// and canceled by Close, and returns the start of the framed response.
//
// Read returns io.ErrUnexpectedEOF if no complete query has been written.
func (c *dnsConn) Read(b []byte) (n int, err error) {
	// A zero-length read must not consume a query or start a round trip.
	// bytes.Buffer.Read reports (0, nil) for an empty b, which drainBuffers'
	// caller would otherwise mistake for "no buffered response".
	if len(b) == 0 {
		return 0, nil
	}

	imsg, n, err := c.drainBuffers(b)
	if n != 0 || err != nil {
		return n, err
	}

	// The lock is not held here, so a slow round trip does not block
	// Write, Close or deadline changes.
	ctx, cancel := c.childContext()
	omsg, err := c.roundTrip(ctx, imsg)
	cancel()
	if err != nil {
		return 0, err
	}

	return c.fillBuffer(b, omsg)
}

// Write implements net.Conn by queueing b in the pending-query buffer. Nothing
// is sent until a later Read.
func (c *dnsConn) Write(b []byte) (n int, err error) {
	c.Lock()
	defer c.Unlock()
	return c.ibuf.Write(b)
}

// Close implements net.Conn. It cancels the context from which every round
// trip's context is derived, so later (and, if roundTrip honors its context,
// in-flight) round trips fail. It may be called more than once and always
// returns nil.
func (c *dnsConn) Close() error {
	c.Lock()
	// Create the context now if no Read has done so yet. Otherwise a Close
	// before the first Read would find no cancel func, and later Reads would
	// still send queries.
	if c.ctx == nil {
		c.ctx, c.cancel = context.WithCancel(context.Background())
	}
	cancel := c.cancel
	c.Unlock()

	if cancel != nil {
		cancel()
	}
	return nil
}

// LocalAddr implements net.Conn; this connection has no local address.
func (c *dnsConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr implements net.Conn; this connection has no remote address.
func (c *dnsConn) RemoteAddr() net.Addr {
	return nil
}

// SetDeadline implements net.Conn by setting both the read and write
// deadlines to t.
func (c *dnsConn) SetDeadline(t time.Time) error {
	if err := c.SetReadDeadline(t); err != nil {
		return err
	}
	return c.SetWriteDeadline(t)
}

// SetReadDeadline implements net.Conn. The deadline bounds round trips started
// by later Reads; a zero t removes the deadline.
func (c *dnsConn) SetReadDeadline(t time.Time) error {
	c.Lock()
	c.deadline = t
	c.Unlock()
	return nil
}

// SetWriteDeadline implements net.Conn. Writes only append to an in-memory
// buffer, so they never time out and the deadline is ignored.
func (c *dnsConn) SetWriteDeadline(_ time.Time) error {
	// writes do not timeout
	return nil
}

// drainBuffers serves a Read from buffered data, under c's lock.
//
// If a response is buffered in obuf, it copies as much as fits into b and
// returns that byte count. Otherwise it removes the next length-prefixed query
// from ibuf and returns it unframed, with a count of 0; the caller must then
// perform the round trip. A missing or truncated query in ibuf yields
// io.ErrUnexpectedEOF, and the bytes of it that were examined are consumed.
func (c *dnsConn) drainBuffers(b []byte) (string, int, error) {
	c.Lock()
	defer c.Unlock()

	// drain the output buffer
	if c.obuf.Len() > 0 {
		n, err := c.obuf.Read(b)
		return "", n, err
	}

	// otherwise, get the next message from the input buffer
	sz := c.ibuf.Next(2)
	if len(sz) < 2 {
		return "", 0, io.ErrUnexpectedEOF
	}

	size := int64(sz[0])<<8 | int64(sz[1])

	var str strings.Builder
	_, err := io.CopyN(&str, &c.ibuf, size)
	if err == io.EOF {
		// CopyN reports a short copy as io.EOF; here it means the query
		// was cut off.
		return "", 0, io.ErrUnexpectedEOF
	}
	if err != nil {
		return "", 0, err
	}
	return str.String(), 0, nil
}

// fillBuffer frames the response str with a 2-byte big-endian length, appends
// it to obuf, and copies as much of obuf as fits into b.
//
// NOTE(review): the prefix keeps only the low 16 bits of len(str), so a
// response longer than 65535 bytes would be framed incorrectly.
func (c *dnsConn) fillBuffer(b []byte, str string) (int, error) {
	c.Lock()
	defer c.Unlock()
	c.obuf.WriteByte(byte(len(str) >> 8))
	c.obuf.WriteByte(byte(len(str)))
	c.obuf.WriteString(str)
	return c.obuf.Read(b)
}

// childContext returns the context for a single round trip. It is derived
// from the connection's context, which is created here on first use, so Close
// cancels it. It also carries the current read deadline, if one is set.
func (c *dnsConn) childContext() (context.Context, context.CancelFunc) {
	c.Lock()
	defer c.Unlock()
	if c.ctx == nil {
		c.ctx, c.cancel = context.WithCancel(context.Background())
	}
	// net.Conn defines a zero deadline as "no timeout". Passing it to
	// WithDeadline would instead produce an already-expired context.
	if c.deadline.IsZero() {
		return context.WithCancel(c.ctx)
	}
	return context.WithDeadline(c.ctx, c.deadline)
}
 | |
| 
 | |
// writeMessage sends msg on conn as a single DNS message.
//
// Packet-oriented connections (those that also implement net.PacketConn, such
// as UDP) get the raw message. Stream connections get it preceded by its
// length as a 2-byte big-endian integer.
func writeMessage(conn net.Conn, msg string) error {
	_, isPacket := conn.(net.PacketConn)

	var out []byte
	if isPacket {
		out = []byte(msg)
	} else {
		length := len(msg)
		out = make([]byte, 0, length+2)
		out = append(out, byte(length>>8), byte(length))
		out = append(out, msg...)
	}

	// Everything goes out in one Write: a client SHOULD do so on TCP
	// (RFC 7766, section 8) and MUST do so on UDP, where a Write is a datagram.
	_, err := conn.Write(out)
	return err
}
 | |
| 
 | |
// readMessage receives a single DNS message from c.
//
// On packet-oriented connections (those that also implement net.PacketConn,
// such as UDP), one Read of up to 4096 bytes is one message. On stream
// connections the message is preceded by a 2-byte big-endian length. A stream
// that ends partway through the message yields io.ErrUnexpectedEOF.
func readMessage(c net.Conn) (string, error) {
	if _, isPacket := c.(net.PacketConn); isPacket {
		// RFC 1035 caps DNS over UDP at 512 bytes, but RFC 6891 suggests
		// 4096 as the maximum EDNS payload, so size for the larger one.
		datagram := make([]byte, 4096)
		n, err := c.Read(datagram)
		if err != nil {
			return "", err
		}
		return string(datagram[:n]), nil
	}

	var prefix [2]byte
	if _, err := io.ReadFull(c, prefix[:]); err != nil {
		return "", err
	}
	length := int64(prefix[0])<<8 | int64(prefix[1])

	var msg strings.Builder
	switch _, err := io.CopyN(&msg, c, length); {
	case err == io.EOF:
		// CopyN signals a short read with io.EOF; the message was cut off.
		return "", io.ErrUnexpectedEOF
	case err != nil:
		return "", err
	}
	return msg.String(), nil
}
 |