504 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			504 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package mucp
 | 
						|
 | 
						|
import (
 | 
						|
	"crypto/cipher"
 | 
						|
	"encoding/base32"
 | 
						|
	"io"
 | 
						|
	"sync"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/micro/go-micro/v3/logger"
 | 
						|
	"github.com/micro/go-micro/v3/network/transport"
 | 
						|
	"github.com/micro/go-micro/v3/network/tunnel"
 | 
						|
)
 | 
						|
 | 
						|
// session is our pseudo session for transport.Socket
 | 
						|
type session struct {
 | 
						|
	// the tunnel id
 | 
						|
	tunnel string
 | 
						|
	// the channel name
 | 
						|
	channel string
 | 
						|
	// the session id based on Micro.Tunnel-Session
 | 
						|
	session string
 | 
						|
	// token is the session token
 | 
						|
	token string
 | 
						|
	// closed
 | 
						|
	closed chan bool
 | 
						|
	// remote addr
 | 
						|
	remote string
 | 
						|
	// local addr
 | 
						|
	local string
 | 
						|
	// send chan
 | 
						|
	send chan *message
 | 
						|
	// recv chan
 | 
						|
	recv chan *message
 | 
						|
	// if the discovery worked
 | 
						|
	discovered bool
 | 
						|
	// if the session was accepted
 | 
						|
	accepted bool
 | 
						|
	// outbound marks the session as outbound dialled connection
 | 
						|
	outbound bool
 | 
						|
	// lookback marks the session as a loopback on the inbound
 | 
						|
	loopback bool
 | 
						|
	// mode of the connection
 | 
						|
	mode tunnel.Mode
 | 
						|
	// the dial timeout
 | 
						|
	dialTimeout time.Duration
 | 
						|
	// the read timeout
 | 
						|
	readTimeout time.Duration
 | 
						|
	// the link on which this message was received
 | 
						|
	link string
 | 
						|
	// the error response
 | 
						|
	errChan chan error
 | 
						|
	// key for session encryption
 | 
						|
	key []byte
 | 
						|
	// cipher for session
 | 
						|
	gcm cipher.AEAD
 | 
						|
	sync.RWMutex
 | 
						|
}
 | 
						|
 | 
						|
// message is sent over the send channel
 | 
						|
type message struct {
 | 
						|
	// type of message
 | 
						|
	typ string
 | 
						|
	// tunnel id
 | 
						|
	tunnel string
 | 
						|
	// channel name
 | 
						|
	channel string
 | 
						|
	// the session id
 | 
						|
	session string
 | 
						|
	// outbound marks the message as outbound
 | 
						|
	outbound bool
 | 
						|
	// loopback marks the message intended for loopback
 | 
						|
	loopback bool
 | 
						|
	// mode of the connection
 | 
						|
	mode tunnel.Mode
 | 
						|
	// the link to send the message on
 | 
						|
	link string
 | 
						|
	// transport data
 | 
						|
	data *transport.Message
 | 
						|
	// the error channel
 | 
						|
	errChan chan error
 | 
						|
}
 | 
						|
 | 
						|
func (s *session) Remote() string {
 | 
						|
	return s.remote
 | 
						|
}
 | 
						|
 | 
						|
func (s *session) Local() string {
 | 
						|
	return s.local
 | 
						|
}
 | 
						|
 | 
						|
func (s *session) Link() string {
 | 
						|
	return s.link
 | 
						|
}
 | 
						|
 | 
						|
func (s *session) Id() string {
 | 
						|
	return s.session
 | 
						|
}
 | 
						|
 | 
						|
func (s *session) Channel() string {
 | 
						|
	return s.channel
 | 
						|
}
 | 
						|
 | 
						|
// newMessage creates a new message based on the session
 | 
						|
func (s *session) newMessage(typ string) *message {
 | 
						|
	return &message{
 | 
						|
		typ:      typ,
 | 
						|
		tunnel:   s.tunnel,
 | 
						|
		channel:  s.channel,
 | 
						|
		session:  s.session,
 | 
						|
		outbound: s.outbound,
 | 
						|
		loopback: s.loopback,
 | 
						|
		mode:     s.mode,
 | 
						|
		link:     s.link,
 | 
						|
		errChan:  s.errChan,
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (s *session) sendMsg(msg *message) error {
 | 
						|
	select {
 | 
						|
	case <-s.closed:
 | 
						|
		return io.EOF
 | 
						|
	case s.send <- msg:
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (s *session) wait(msg *message) error {
 | 
						|
	// wait for an error response
 | 
						|
	select {
 | 
						|
	case err := <-msg.errChan:
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
	case <-s.closed:
 | 
						|
		return io.EOF
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// waitFor waits for the message type required until the timeout specified
 | 
						|
func (s *session) waitFor(msgType string, timeout time.Duration) (*message, error) {
 | 
						|
	now := time.Now()
 | 
						|
 | 
						|
	after := func(timeout time.Duration) <-chan time.Time {
 | 
						|
		if timeout < time.Duration(0) {
 | 
						|
			return nil
 | 
						|
		}
 | 
						|
 | 
						|
		// get the delta
 | 
						|
		d := time.Since(now)
 | 
						|
 | 
						|
		// dial timeout minus time since
 | 
						|
		wait := timeout - d
 | 
						|
 | 
						|
		if wait < time.Duration(0) {
 | 
						|
			wait = time.Duration(0)
 | 
						|
		}
 | 
						|
 | 
						|
		return time.After(wait)
 | 
						|
	}
 | 
						|
 | 
						|
	// wait for the message type
 | 
						|
	for {
 | 
						|
		select {
 | 
						|
		case msg := <-s.recv:
 | 
						|
			// there may be no message type
 | 
						|
			if len(msgType) == 0 {
 | 
						|
				return msg, nil
 | 
						|
			}
 | 
						|
 | 
						|
			// ignore what we don't want
 | 
						|
			if msg.typ != msgType {
 | 
						|
				if logger.V(logger.DebugLevel, log) {
 | 
						|
					log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType)
 | 
						|
				}
 | 
						|
				continue
 | 
						|
			}
 | 
						|
 | 
						|
			// got the message
 | 
						|
			return msg, nil
 | 
						|
		case <-after(timeout):
 | 
						|
			return nil, tunnel.ErrReadTimeout
 | 
						|
		case <-s.closed:
 | 
						|
			// check pending message queue
 | 
						|
			select {
 | 
						|
			case msg := <-s.recv:
 | 
						|
				// there may be no message type
 | 
						|
				if len(msgType) == 0 {
 | 
						|
					return msg, nil
 | 
						|
				}
 | 
						|
 | 
						|
				// ignore what we don't want
 | 
						|
				if msg.typ != msgType {
 | 
						|
					if logger.V(logger.DebugLevel, log) {
 | 
						|
						log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType)
 | 
						|
					}
 | 
						|
					continue
 | 
						|
				}
 | 
						|
 | 
						|
				// got the message
 | 
						|
				return msg, nil
 | 
						|
			default:
 | 
						|
				// non blocking
 | 
						|
			}
 | 
						|
			return nil, io.EOF
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// Discover attempts to discover the link for a specific channel.
 | 
						|
// This is only used by the tunnel.Dial when first connecting.
 | 
						|
func (s *session) Discover() error {
 | 
						|
	// create a new discovery message for this channel
 | 
						|
	msg := s.newMessage("discover")
 | 
						|
	// broadcast the message to all links
 | 
						|
	msg.mode = tunnel.Broadcast
 | 
						|
	// its an outbound connection since we're dialling
 | 
						|
	msg.outbound = true
 | 
						|
	// don't set the link since we don't know where it is
 | 
						|
	msg.link = ""
 | 
						|
 | 
						|
	// if multicast then set that as session
 | 
						|
	if s.mode == tunnel.Multicast {
 | 
						|
		msg.session = "multicast"
 | 
						|
	}
 | 
						|
 | 
						|
	// send discover message
 | 
						|
	if err := s.sendMsg(msg); err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	// set time now
 | 
						|
	now := time.Now()
 | 
						|
 | 
						|
	// after strips down the dial timeout
 | 
						|
	after := func() time.Duration {
 | 
						|
		d := time.Since(now)
 | 
						|
		// dial timeout minus time since
 | 
						|
		wait := s.dialTimeout - d
 | 
						|
		// make sure its always > 0
 | 
						|
		if wait < time.Duration(0) {
 | 
						|
			return time.Duration(0)
 | 
						|
		}
 | 
						|
		return wait
 | 
						|
	}
 | 
						|
 | 
						|
	// the discover message is sent out, now
 | 
						|
	// wait to hear back about the sent message
 | 
						|
	select {
 | 
						|
	case <-time.After(after()):
 | 
						|
		return tunnel.ErrDialTimeout
 | 
						|
	case err := <-s.errChan:
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	// bail early if its not unicast
 | 
						|
	// we don't need to wait for the announce
 | 
						|
	if s.mode != tunnel.Unicast {
 | 
						|
		s.discovered = true
 | 
						|
		s.accepted = true
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
 | 
						|
	// wait for announce
 | 
						|
	_, err := s.waitFor("announce", after())
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	// set discovered
 | 
						|
	s.discovered = true
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// Open will fire the open message for the session. This is called by the dialler.
 | 
						|
// This is to indicate that we want to create a new session.
 | 
						|
func (s *session) Open() error {
 | 
						|
	// create a new message
 | 
						|
	msg := s.newMessage("open")
 | 
						|
 | 
						|
	// send open message
 | 
						|
	if err := s.sendMsg(msg); err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	// wait for an error response for send
 | 
						|
	if err := s.wait(msg); err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	// now wait for the accept message to be returned
 | 
						|
	msg, err := s.waitFor("accept", s.dialTimeout)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	// set to accepted
 | 
						|
	s.accepted = true
 | 
						|
	// set link
 | 
						|
	s.link = msg.link
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// Accept sends the accept response to an open message from a dialled connection
 | 
						|
func (s *session) Accept() error {
 | 
						|
	msg := s.newMessage("accept")
 | 
						|
 | 
						|
	// send the accept message
 | 
						|
	if err := s.sendMsg(msg); err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	// wait for send response
 | 
						|
	return s.wait(msg)
 | 
						|
}
 | 
						|
 | 
						|
// Announce sends an announcement to notify that this session exists.
 | 
						|
// This is primarily used by the listener.
 | 
						|
func (s *session) Announce() error {
 | 
						|
	msg := s.newMessage("announce")
 | 
						|
	// we don't need an error back
 | 
						|
	msg.errChan = nil
 | 
						|
	// announce to all
 | 
						|
	msg.mode = tunnel.Broadcast
 | 
						|
	// we don't need the link
 | 
						|
	msg.link = ""
 | 
						|
 | 
						|
	// send announce message
 | 
						|
	return s.sendMsg(msg)
 | 
						|
}
 | 
						|
 | 
						|
// Send is used to send a message
 | 
						|
func (s *session) Send(m *transport.Message) error {
 | 
						|
	var err error
 | 
						|
 | 
						|
	s.RLock()
 | 
						|
	gcm := s.gcm
 | 
						|
	s.RUnlock()
 | 
						|
 | 
						|
	if gcm == nil {
 | 
						|
		gcm, err = newCipher(s.key)
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
		s.Lock()
 | 
						|
		s.gcm = gcm
 | 
						|
		s.Unlock()
 | 
						|
	}
 | 
						|
	// encrypt the transport message payload
 | 
						|
	body, err := Encrypt(gcm, m.Body)
 | 
						|
	if err != nil {
 | 
						|
		log.Debugf("failed to encrypt message body: %v", err)
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	// make copy, without rehash and realloc
 | 
						|
	data := &transport.Message{
 | 
						|
		Header: make(map[string]string, len(m.Header)),
 | 
						|
		Body:   body,
 | 
						|
	}
 | 
						|
 | 
						|
	// encrypt all the headers
 | 
						|
	for k, v := range m.Header {
 | 
						|
		// encrypt the transport message payload
 | 
						|
		val, err := Encrypt(s.gcm, []byte(v))
 | 
						|
		if err != nil {
 | 
						|
			log.Debugf("failed to encrypt message header %s: %v", k, err)
 | 
						|
			return err
 | 
						|
		}
 | 
						|
		// add the encrypted header value
 | 
						|
		data.Header[k] = base32.StdEncoding.EncodeToString(val)
 | 
						|
	}
 | 
						|
 | 
						|
	// create a new message
 | 
						|
	msg := s.newMessage("session")
 | 
						|
	// set the data
 | 
						|
	msg.data = data
 | 
						|
 | 
						|
	// if multicast don't set the link
 | 
						|
	if s.mode != tunnel.Unicast {
 | 
						|
		msg.link = ""
 | 
						|
	}
 | 
						|
 | 
						|
	if logger.V(logger.TraceLevel, log) {
 | 
						|
		log.Tracef("Appending to send backlog: %v", msg)
 | 
						|
	}
 | 
						|
	// send the actual message
 | 
						|
	if err := s.sendMsg(msg); err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	// wait for an error response
 | 
						|
	return s.wait(msg)
 | 
						|
}
 | 
						|
 | 
						|
// Recv is used to receive a message
 | 
						|
func (s *session) Recv(m *transport.Message) error {
 | 
						|
	var msg *message
 | 
						|
 | 
						|
	msg, err := s.waitFor("", s.readTimeout)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	// check the error if one exists
 | 
						|
	select {
 | 
						|
	case err := <-msg.errChan:
 | 
						|
		return err
 | 
						|
	default:
 | 
						|
	}
 | 
						|
 | 
						|
	if logger.V(logger.TraceLevel, log) {
 | 
						|
		log.Tracef("Received from recv backlog: %v", msg)
 | 
						|
	}
 | 
						|
 | 
						|
	gcm, err := newCipher([]byte(s.token + s.channel + msg.session))
 | 
						|
	if err != nil {
 | 
						|
		if logger.V(logger.ErrorLevel, log) {
 | 
						|
			log.Errorf("unable to create cipher: %v", err)
 | 
						|
		}
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	// decrypt the received payload using the token
 | 
						|
	// we have to used msg.session because multicast has a shared
 | 
						|
	// session id of "multicast" in this session struct on
 | 
						|
	// the listener side
 | 
						|
	msg.data.Body, err = Decrypt(gcm, msg.data.Body)
 | 
						|
	if err != nil {
 | 
						|
		if logger.V(logger.DebugLevel, log) {
 | 
						|
			log.Debugf("failed to decrypt message body: %v", err)
 | 
						|
		}
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	// dencrypt all the headers
 | 
						|
	for k, v := range msg.data.Header {
 | 
						|
		// decode the header values
 | 
						|
		h, err := base32.StdEncoding.DecodeString(v)
 | 
						|
		if err != nil {
 | 
						|
			if logger.V(logger.DebugLevel, log) {
 | 
						|
				log.Debugf("failed to decode message header %s: %v", k, err)
 | 
						|
			}
 | 
						|
			return err
 | 
						|
		}
 | 
						|
 | 
						|
		// dencrypt the transport message payload
 | 
						|
		val, err := Decrypt(gcm, h)
 | 
						|
		if err != nil {
 | 
						|
			if logger.V(logger.DebugLevel, log) {
 | 
						|
				log.Debugf("failed to decrypt message header %s: %v", k, err)
 | 
						|
			}
 | 
						|
			return err
 | 
						|
		}
 | 
						|
		// add decrypted header value
 | 
						|
		msg.data.Header[k] = string(val)
 | 
						|
	}
 | 
						|
 | 
						|
	// set the link
 | 
						|
	// TODO: decruft, this is only for multicast
 | 
						|
	// since the session is now a single session
 | 
						|
	// likely provide as part of message.Link()
 | 
						|
	msg.data.Header["Micro-Link"] = msg.link
 | 
						|
 | 
						|
	// set message
 | 
						|
	*m = *msg.data
 | 
						|
	// return nil
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// Close closes the session by sending a close message
 | 
						|
func (s *session) Close() error {
 | 
						|
	select {
 | 
						|
	case <-s.closed:
 | 
						|
		// no op
 | 
						|
	default:
 | 
						|
		close(s.closed)
 | 
						|
 | 
						|
		// don't send close on multicast or broadcast
 | 
						|
		if s.mode != tunnel.Unicast {
 | 
						|
			return nil
 | 
						|
		}
 | 
						|
 | 
						|
		// append to backlog
 | 
						|
		msg := s.newMessage("close")
 | 
						|
		// no error response on close
 | 
						|
		msg.errChan = nil
 | 
						|
 | 
						|
		// send the close message
 | 
						|
		select {
 | 
						|
		case s.send <- msg:
 | 
						|
		case <-time.After(time.Millisecond * 10):
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 |