package mucp

import (
	"crypto/cipher"
	"encoding/base32"
	"io"
	"sync"
	"time"

	"github.com/unistack-org/micro/v3/logger"
	"github.com/unistack-org/micro/v3/transport"
	"github.com/unistack-org/micro/v3/tunnel"
)

// session is our pseudo session for transport.Socket.
// It holds the addressing, state flags, timeouts and encryption material
// of one tunnel session. Outgoing messages are pushed onto send and
// incoming messages are read from recv.
type session struct {
	// the tunnel id
	tunnel string
	// the channel name
	channel string
	// the session id based on Micro.Tunnel-Session
	session string
	// token is the session token; Recv combines it with the channel and
	// the received message's session id to build the decryption key
	token string
	// closed is closed by Close to signal that the session has ended
	closed chan bool
	// remote addr
	remote string
	// local addr
	local string
	// send chan; messages built by this session are queued here
	send chan *message
	// recv chan; waitFor reads incoming messages from here
	recv chan *message
	// if the discovery worked; set by Discover
	discovered bool
	// if the session was accepted; set by Open, and by Discover for
	// non-unicast sessions
	accepted bool
	// outbound marks the session as outbound dialled connection
	outbound bool
	// loopback marks the session as a loopback on the inbound
	loopback bool
	// mode of the connection
	mode tunnel.Mode
	// the dial timeout; bounds Discover and the wait for "accept" in Open
	dialTimeout time.Duration
	// the read timeout; bounds each Recv
	readTimeout time.Duration
	// the link on which this message was received;
	// Open overwrites it with the link of the accept message
	link string
	// the error response; handed to every message built by newMessage
	errChan chan error
	// key for session encryption; Send uses it to build the cipher
	key []byte
	// cipher for session; created lazily by Send on first use
	gcm cipher.AEAD
	// Send takes this lock when reading or storing gcm
	sync.RWMutex
}

// message is sent over the send channel.
// It is also the unit delivered on a session's recv channel.
type message struct {
	// type of message; the session itself creates "discover", "open",
	// "accept", "announce", "session" and "close" messages
	typ string
	// tunnel id
	tunnel string
	// channel name
	channel string
	// the session id
	session string
	// outbound marks the message as outbound
	outbound bool
	// loopback marks the message intended for loopback
	loopback bool
	// mode of the connection
	mode tunnel.Mode
	// the link to send the message on; empty when no specific link is set
	link string
	// transport data; only populated by Send in this file
	data *transport.Message
	// the error channel; the result of handling the message is read from
	// here, and it is set to nil when no response is wanted
	errChan chan error
}

// Remote returns the remote address of the session.
func (s *session) Remote() string {
	return s.remote
}

// Local returns the local address of the session.
func (s *session) Local() string {
	return s.local
}

// Link returns the link currently recorded for the session.
func (s *session) Link() string {
	return s.link
}

// Id returns the session id.
func (s *session) Id() string {
	return s.session
}

// Channel returns the name of the channel the session belongs to.
func (s *session) Channel() string {
	return s.channel
}

// newMessage builds a message of the given type, pre-populated with the
// session's tunnel, channel, session id, direction flags, mode, link and
// error channel. The data field is left unset; callers adjust individual
// fields afterwards as needed.
func (s *session) newMessage(typ string) *message {
	m := new(message)
	m.typ = typ

	// addressing
	m.tunnel = s.tunnel
	m.channel = s.channel
	m.session = s.session
	m.link = s.link

	// direction and delivery mode
	m.outbound = s.outbound
	m.loopback = s.loopback
	m.mode = s.mode

	// responses for this message arrive on the session's error channel
	m.errChan = s.errChan

	return m
}

// sendMsg queues msg on the session's send channel. It blocks until the
// send succeeds, returning nil, or until the session is closed, in which
// case it returns io.EOF.
func (s *session) sendMsg(msg *message) error {
	select {
	case s.send <- msg:
		return nil
	case <-s.closed:
	}

	// the session was closed before the message could be queued
	return io.EOF
}

// wait blocks until a response for msg arrives on its error channel and
// returns that response (nil on success). If the session is closed first
// it returns io.EOF.
func (s *session) wait(msg *message) error {
	select {
	case <-s.closed:
		return io.EOF
	case err := <-msg.errChan:
		// err is nil when the message was handled without error
		return err
	}
}

// waitFor waits for a message of the required type until the timeout
// specified.
//
// An empty msgType accepts the first message received. Messages of any
// other type are logged at debug level and dropped. A negative timeout
// disables the timeout entirely; a zero timeout expires immediately.
//
// It returns tunnel.ErrReadTimeout when the timeout elapses and io.EOF
// when the session is closed and no matching message is pending.
func (s *session) waitFor(msgType string, timeout time.Duration) (*message, error) {
	// A single timer covers the whole wait and is released as soon as we
	// return, instead of allocating a fresh time.After on every loop
	// iteration and leaving it pending until it fires.
	// A nil channel never becomes ready, which is what a negative timeout
	// asks for.
	var expired <-chan time.Time
	if timeout >= time.Duration(0) {
		timer := time.NewTimer(timeout)
		defer timer.Stop()
		expired = timer.C
	}

	// wanted reports whether msg is the message we are waiting for
	wanted := func(msg *message) bool {
		// there may be no message type
		if len(msgType) == 0 || msg.typ == msgType {
			return true
		}

		// ignore what we don't want
		if logger.V(logger.DebugLevel, log) {
			log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType)
		}
		return false
	}

	// wait for the message type
	for {
		select {
		case msg := <-s.recv:
			if wanted(msg) {
				return msg, nil
			}
		case <-expired:
			return nil, tunnel.ErrReadTimeout
		case <-s.closed:
			// check pending message queue without blocking
			select {
			case msg := <-s.recv:
				if wanted(msg) {
					return msg, nil
				}
				// keep draining whatever is still queued
				continue
			default:
			}
			return nil, io.EOF
		}
	}
}

// Discover broadcasts a "discover" message for the session's channel and
// waits, within the dial timeout, for the send result and (for unicast
// sessions only) for an "announce" message in reply.
// This is only used by the tunnel.Dial when first connecting.
//
// It returns tunnel.ErrDialTimeout if no send result arrives in time, or
// whatever error the send or the wait for the announce produced.
func (s *session) Discover() error {
	// build the discovery message: broadcast to all links, outbound since
	// we are dialling, and without a link since we don't know where it is
	msg := s.newMessage("discover")
	msg.link = ""
	msg.outbound = true
	msg.mode = tunnel.Broadcast

	// multicast sessions use the fixed "multicast" session id
	if s.mode == tunnel.Multicast {
		msg.session = "multicast"
	}

	if err := s.sendMsg(msg); err != nil {
		return err
	}

	// the dial timeout budget starts once the message has been queued
	start := time.Now()

	// remaining reports how much of the dial timeout is left, never negative
	remaining := func() time.Duration {
		left := s.dialTimeout - time.Since(start)
		if left < time.Duration(0) {
			left = time.Duration(0)
		}
		return left
	}

	// wait to hear back about the sent message
	select {
	case err := <-s.errChan:
		if err != nil {
			return err
		}
	case <-time.After(remaining()):
		return tunnel.ErrDialTimeout
	}

	if s.mode == tunnel.Unicast {
		// unicast needs an announce before it counts as discovered
		if _, err := s.waitFor("announce", remaining()); err != nil {
			return err
		}
		s.discovered = true
		return nil
	}

	// anything other than unicast does not wait for the announce and is
	// treated as accepted straight away
	s.discovered = true
	s.accepted = true
	return nil
}

// Open sends the "open" message for the session and waits for the peer's
// "accept". This is called by the dialler to indicate that we want to
// create a new session. On success the session is marked accepted and its
// link is taken from the accept message.
func (s *session) Open() error {
	open := s.newMessage("open")

	// queue the open message
	err := s.sendMsg(open)
	if err != nil {
		return err
	}

	// wait for the result of the send
	if err = s.wait(open); err != nil {
		return err
	}

	// the accept has to arrive within the dial timeout
	accept, err := s.waitFor("accept", s.dialTimeout)
	if err != nil {
		return err
	}

	// remember where the accept came from and flag the session as accepted
	s.link = accept.link
	s.accepted = true

	return nil
}

// Accept replies to an "open" message from a dialled connection with an
// "accept" message and returns the result of sending it.
func (s *session) Accept() error {
	reply := s.newMessage("accept")

	// queue the accept, then wait for the send response
	err := s.sendMsg(reply)
	if err == nil {
		err = s.wait(reply)
	}

	return err
}

// Announce broadcasts an "announce" message to notify that this session
// exists. This is primarily used by the listener. No send response is
// requested; only the error from queueing the message is returned.
func (s *session) Announce() error {
	announce := s.newMessage("announce")

	// broadcast to everyone, so no specific link is needed
	announce.link = ""
	announce.mode = tunnel.Broadcast
	// nobody waits for a response to an announce
	announce.errChan = nil

	return s.sendMsg(announce)
}

// Send encrypts m and sends it as a "session" message.
//
// The body and every header value are encrypted with the session cipher,
// which is created from the session key on first use and cached on the
// session. Header values are additionally base32 encoded. The caller's
// message is not modified; an encrypted copy is sent instead.
//
// It returns any cipher or encryption error, io.EOF if the session is
// closed, or the send response reported for the message.
func (s *session) Send(m *transport.Message) error {
	var err error

	// take a snapshot of the cached cipher under the read lock
	s.RLock()
	gcm := s.gcm
	s.RUnlock()

	// lazily create and cache the cipher
	if gcm == nil {
		gcm, err = newCipher(s.key)
		if err != nil {
			return err
		}
		s.Lock()
		s.gcm = gcm
		s.Unlock()
	}

	// encrypt the transport message payload
	body, err := Encrypt(gcm, m.Body)
	if err != nil {
		if logger.V(logger.DebugLevel, log) {
			log.Debugf("failed to encrypt message body: %v", err)
		}
		return err
	}

	// make copy, without rehash and realloc
	data := &transport.Message{
		Header: make(map[string]string, len(m.Header)),
		Body:   body,
	}

	// encrypt all the headers
	for k, v := range m.Header {
		// use the local snapshot, not s.gcm: reading the field here
		// without the lock would race with another Send storing it
		val, err := Encrypt(gcm, []byte(v))
		if err != nil {
			if logger.V(logger.DebugLevel, log) {
				log.Debugf("failed to encrypt message header %s: %v", k, err)
			}
			return err
		}
		// add the encrypted header value
		data.Header[k] = base32.StdEncoding.EncodeToString(val)
	}

	// create a new message
	msg := s.newMessage("session")
	// set the data
	msg.data = data

	// only unicast messages are pinned to a link
	if s.mode != tunnel.Unicast {
		msg.link = ""
	}

	if logger.V(logger.TraceLevel, log) {
		log.Tracef("Appending to send backlog: %v", msg)
	}
	// send the actual message
	if err := s.sendMsg(msg); err != nil {
		return err
	}

	// wait for an error response
	return s.wait(msg)
}

// Recv waits up to the read timeout for the next message, decrypts its
// body and headers and copies the result into m.
//
// If the received message already has a value pending on its error
// channel, that value is returned and m is left untouched. Otherwise the
// header "Micro-Link" of the result is set to the link of the received
// message.
//
// It returns the error from waiting (tunnel.ErrReadTimeout, io.EOF), or
// any cipher, decoding or decryption error.
func (s *session) Recv(m *transport.Message) error {
	msg, err := s.waitFor("", s.readTimeout)
	if err != nil {
		return err
	}

	// check the error if one exists; a nil error channel never becomes
	// ready, so the default branch is taken
	select {
	case err := <-msg.errChan:
		return err
	default:
	}

	if logger.V(logger.TraceLevel, log) {
		log.Tracef("Received from recv backlog: %v", msg)
	}

	// build the cipher from the token, channel and session id.
	// we have to use msg.session because multicast has a shared
	// session id of "multicast" in this session struct on
	// the listener side
	gcm, err := newCipher([]byte(s.token + s.channel + msg.session))
	if err != nil {
		if logger.V(logger.ErrorLevel, log) {
			log.Errorf("unable to create cipher: %v", err)
		}
		return err
	}

	// decrypt the received payload
	msg.data.Body, err = Decrypt(gcm, msg.data.Body)
	if err != nil {
		if logger.V(logger.DebugLevel, log) {
			log.Debugf("failed to decrypt message body: %v", err)
		}
		return err
	}

	// decrypt all the headers in place
	for k, v := range msg.data.Header {
		// decode the header values
		h, err := base32.StdEncoding.DecodeString(v)
		if err != nil {
			if logger.V(logger.DebugLevel, log) {
				log.Debugf("failed to decode message header %s: %v", k, err)
			}
			return err
		}

		// decrypt the header value
		val, err := Decrypt(gcm, h)
		if err != nil {
			if logger.V(logger.DebugLevel, log) {
				log.Debugf("failed to decrypt message header %s: %v", k, err)
			}
			return err
		}
		// add decrypted header value
		msg.data.Header[k] = string(val)
	}

	// a message without headers may carry a nil map, and writing to a
	// nil map panics
	if msg.data.Header == nil {
		msg.data.Header = make(map[string]string, 1)
	}

	// set the link
	// TODO: decruft, this is only for multicast
	// since the session is now a single session
	// likely provide as part of message.Link()
	msg.data.Header["Micro-Link"] = msg.link

	// set message
	*m = *msg.data

	return nil
}

// Close marks the session as closed and, for unicast sessions, makes a
// best-effort attempt to send a "close" message. Calling Close on an
// already closed session is a no-op. It always returns nil.
func (s *session) Close() error {
	// nothing to do if we are already closed
	select {
	case <-s.closed:
		return nil
	default:
	}

	close(s.closed)

	// don't send close on multicast or broadcast
	if s.mode != tunnel.Unicast {
		return nil
	}

	// build the close message; no error response is wanted for it
	bye := s.newMessage("close")
	bye.errChan = nil

	// try to queue the close message, giving up after a short grace period
	select {
	case <-time.After(time.Millisecond * 10):
	case s.send <- bye:
	}

	return nil
}