package dns

import (
	"bytes"
	"context"
	"io"
	"net"
	"strings"
	"sync"
	"time"
)

// dnsConn implements the net.Conn methods on top of a roundTripper.
//
// Write only buffers the caller's bytes. Read parses them as queries framed
// with a two-byte big-endian length prefix, exchanges each query through
// roundTrip, and hands the reply back with the same framing.
type dnsConn struct {
	// Mutex guards ibuf, obuf, ctx, cancel and deadline. roundTrip is read
	// without holding it.
	sync.Mutex

	// ibuf holds bytes passed to Write, not yet parsed into queries by Read.
	// obuf holds framed reply bytes that Read has not yet returned.
	ibuf bytes.Buffer
	obuf bytes.Buffer

	// ctx is the parent of every round-trip context. childContext creates it
	// lazily, and Close cancels it through cancel. deadline is the read
	// deadline that childContext applies to each round trip. roundTrip
	// performs one query/reply exchange and is called by Read without the
	// lock held.
	ctx       context.Context
	cancel    context.CancelFunc
	deadline  time.Time
	roundTrip roundTripper
}

// roundTripper performs a single DNS exchange. Read passes it the query
// without its length prefix and frames the returned reply itself. ctx comes
// from dnsConn.childContext; implementations should presumably honor its
// cancellation and deadline.
type roundTripper func(ctx context.Context, query string) (reply string, err error)

// Read implements net.Conn.
//
// It first hands out any bytes still pending from a previous reply. When none
// remain, it takes the next length-prefixed query queued by Write and
// exchanges it through roundTrip, using a per-call context from childContext.
// It then returns the start of the reply, framed with its own two-byte length
// prefix. If no complete query is queued, the error is io.ErrUnexpectedEOF.
func (c *dnsConn) Read(b []byte) (n int, err error) {
	// A zero-length read has nothing to deliver. Without this guard, pending
	// reply bytes make drainBuffers report ("", 0, nil). That result would be
	// mistaken for an empty query and trigger a spurious round trip.
	if len(b) == 0 {
		return 0, nil
	}

	imsg, n, err := c.drainBuffers(b)
	if n != 0 || err != nil {
		return n, err
	}

	ctx, cancel := c.childContext()
	omsg, err := c.roundTrip(ctx, imsg)
	cancel()
	if err != nil {
		return 0, err
	}

	return c.fillBuffer(b, omsg)
}

// Write implements net.Conn by appending b to the input buffer. Nothing is
// parsed here. Read later expects the buffered bytes to be queries, each
// preceded by a two-byte big-endian length. bytes.Buffer.Write reports a nil
// error; it panics rather than failing when it cannot grow.
func (c *dnsConn) Write(b []byte) (n int, err error) {
	c.Lock()
	defer c.Unlock()
	n, err = c.ibuf.Write(b)
	return n, err
}

// Close implements net.Conn. It cancels the connection-wide context that
// every round-trip context derives from. A roundTripper that honors
// cancellation therefore aborts any in-flight exchange and fails any started
// afterwards. Buffered data is left untouched, and Close always returns nil.
func (c *dnsConn) Close() error {
	c.Lock()
	// If no Read has created the context yet, create it now so it can be
	// canceled. Otherwise a Read issued after Close would lazily create a
	// fresh, live context and still perform a round trip.
	if c.ctx == nil {
		c.ctx, c.cancel = context.WithCancel(context.Background())
	}
	cancel := c.cancel
	c.Unlock()

	if cancel != nil {
		cancel()
	}
	return nil
}

// LocalAddr implements net.Conn. The connection has no local endpoint, so it
// returns a nil net.Addr; callers must not call methods on the result.
func (c *dnsConn) LocalAddr() net.Addr { return nil }

// RemoteAddr implements net.Conn. The peer is reached only through
// roundTrip, so there is no address to report and a nil net.Addr is returned.
func (c *dnsConn) RemoteAddr() net.Addr { return nil }

// SetDeadline implements net.Conn by applying t as the read deadline and
// then forwarding it to SetWriteDeadline, which ignores it.
func (c *dnsConn) SetDeadline(t time.Time) error {
	if err := c.SetReadDeadline(t); err != nil {
		return err
	}
	return c.SetWriteDeadline(t)
}

// SetReadDeadline records t as the deadline for round trips started by
// subsequent Read calls. An exchange already in flight keeps the deadline it
// started with. It always returns nil.
func (c *dnsConn) SetReadDeadline(t time.Time) error {
	c.Lock()
	c.deadline = t
	c.Unlock()
	return nil
}

// SetWriteDeadline is a no-op that always returns nil. Write only appends to
// an in-memory buffer, so there is nothing for a deadline to bound.
func (c *dnsConn) SetWriteDeadline(time.Time) error {
	return nil
}

// drainBuffers performs Read's buffer work under the connection lock.
//
// If part of a previous reply is still queued in obuf, it copies as much as
// fits into b and returns that count with an empty query. Otherwise it pops
// the next query, framed by a two-byte big-endian length, from ibuf. It then
// returns the payload with a zero count, signalling Read to do a round trip.
// A missing or incomplete query yields io.ErrUnexpectedEOF; whatever bytes it
// occupied are consumed anyway.
//
// NOTE: when len(b) == 0 and obuf is non-empty, the result ("", 0, nil) is
// indistinguishable from an empty query.
func (c *dnsConn) drainBuffers(b []byte) (string, int, error) {
	c.Lock()
	defer c.Unlock()

	if c.obuf.Len() != 0 {
		copied, err := c.obuf.Read(b)
		return "", copied, err
	}

	prefix := c.ibuf.Next(2)
	if len(prefix) != 2 {
		return "", 0, io.ErrUnexpectedEOF
	}
	want := int(prefix[0])<<8 | int(prefix[1])

	payload := c.ibuf.Next(want)
	if len(payload) != want {
		return "", 0, io.ErrUnexpectedEOF
	}
	return string(payload), 0, nil
}

// fillBuffer frames the reply str with a two-byte big-endian length prefix
// and appends it to obuf behind anything already queued. It then copies as
// much of obuf as fits into b.
//
// NOTE(review): only the low 16 bits of len(str) are written, so a reply
// longer than 65535 bytes would get a wrong prefix.
func (c *dnsConn) fillBuffer(b []byte, str string) (int, error) {
	prefix := [2]byte{byte(len(str) >> 8), byte(len(str))}

	c.Lock()
	defer c.Unlock()
	// bytes.Buffer writes always report a nil error, so results are ignored.
	c.obuf.Write(prefix[:])
	c.obuf.WriteString(str)
	return c.obuf.Read(b)
}

// childContext returns a context for a single round trip, together with its
// cancel function, which the caller must invoke.
//
// The context derives from the connection-wide context, which is created here
// on first use and canceled by Close. It carries the current read deadline
// when one is set.
func (c *dnsConn) childContext() (context.Context, context.CancelFunc) {
	c.Lock()
	defer c.Unlock()
	if c.ctx == nil {
		c.ctx, c.cancel = context.WithCancel(context.Background())
	}
	// Per net.Conn, a zero deadline means "never time out". Passing the zero
	// time to context.WithDeadline would instead produce an already-expired
	// context and fail every round trip.
	if c.deadline.IsZero() {
		return context.WithCancel(c.ctx)
	}
	return context.WithDeadline(c.ctx, c.deadline)
}

// writeMessage sends msg over conn in exactly one Write call.
//
// Stream connections get a two-byte big-endian length prefix before the
// message. Connections that also implement net.PacketConn carry the message
// unframed, one message per write.
func writeMessage(conn net.Conn, msg string) error {
	_, isPacket := conn.(net.PacketConn)

	var frame []byte
	if isPacket {
		frame = []byte(msg)
	} else {
		frame = make([]byte, 0, 2+len(msg))
		frame = append(frame, byte(len(msg)>>8), byte(len(msg)))
		frame = append(frame, msg...)
	}

	// SHOULD do a single write on TCP (RFC 7766, section 8).
	// MUST do a single write on UDP.
	if _, err := conn.Write(frame); err != nil {
		return err
	}
	return nil
}

// readMessage reads one DNS message from c.
//
// A connection that also implements net.PacketConn gets a single Read into
// a 4096-byte buffer; larger datagrams are presumably truncated by the
// underlying transport. A stream connection is expected to carry a two-byte
// big-endian length before each message. On a stream, an end of input before
// any prefix byte yields io.EOF. An end partway through the prefix or payload
// yields io.ErrUnexpectedEOF.
func readMessage(c net.Conn) (string, error) {
	if _, isPacket := c.(net.PacketConn); isPacket {
		// RFC 1035 specifies 512 as the maximum message size for DNS over UDP.
		// RFC 6891, on the other hand, suggests 4096 as the EDNS payload size.
		datagram := make([]byte, 4096)
		n, err := c.Read(datagram)
		if err != nil {
			return "", err
		}
		return string(datagram[:n]), nil
	}

	var prefix [2]byte
	if _, err := io.ReadFull(c, prefix[:]); err != nil {
		return "", err
	}
	size := int64(prefix[0])<<8 | int64(prefix[1])

	var msg strings.Builder
	switch _, err := io.CopyN(&msg, c, size); {
	case err == io.EOF:
		return "", io.ErrUnexpectedEOF
	case err != nil:
		return "", err
	}
	return msg.String(), nil
}