move tunnel/resolver into network
This commit is contained in:
503
network/tunnel/mucp/session.go
Normal file
503
network/tunnel/mucp/session.go
Normal file
@@ -0,0 +1,503 @@
|
||||
package mucp
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"encoding/base32"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/micro/go-micro/v3/logger"
|
||||
"github.com/micro/go-micro/v3/transport"
|
||||
"github.com/micro/go-micro/v3/network/tunnel"
|
||||
)
|
||||
|
||||
// session is our pseudo session for transport.Socket
|
||||
type session struct {
|
||||
// the tunnel id
|
||||
tunnel string
|
||||
// the channel name
|
||||
channel string
|
||||
// the session id based on Micro.Tunnel-Session
|
||||
session string
|
||||
// token is the session token
|
||||
token string
|
||||
// closed
|
||||
closed chan bool
|
||||
// remote addr
|
||||
remote string
|
||||
// local addr
|
||||
local string
|
||||
// send chan
|
||||
send chan *message
|
||||
// recv chan
|
||||
recv chan *message
|
||||
// if the discovery worked
|
||||
discovered bool
|
||||
// if the session was accepted
|
||||
accepted bool
|
||||
// outbound marks the session as outbound dialled connection
|
||||
outbound bool
|
||||
// lookback marks the session as a loopback on the inbound
|
||||
loopback bool
|
||||
// mode of the connection
|
||||
mode tunnel.Mode
|
||||
// the dial timeout
|
||||
dialTimeout time.Duration
|
||||
// the read timeout
|
||||
readTimeout time.Duration
|
||||
// the link on which this message was received
|
||||
link string
|
||||
// the error response
|
||||
errChan chan error
|
||||
// key for session encryption
|
||||
key []byte
|
||||
// cipher for session
|
||||
gcm cipher.AEAD
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// message is sent over the send channel
|
||||
type message struct {
|
||||
// type of message
|
||||
typ string
|
||||
// tunnel id
|
||||
tunnel string
|
||||
// channel name
|
||||
channel string
|
||||
// the session id
|
||||
session string
|
||||
// outbound marks the message as outbound
|
||||
outbound bool
|
||||
// loopback marks the message intended for loopback
|
||||
loopback bool
|
||||
// mode of the connection
|
||||
mode tunnel.Mode
|
||||
// the link to send the message on
|
||||
link string
|
||||
// transport data
|
||||
data *transport.Message
|
||||
// the error channel
|
||||
errChan chan error
|
||||
}
|
||||
|
||||
func (s *session) Remote() string {
|
||||
return s.remote
|
||||
}
|
||||
|
||||
func (s *session) Local() string {
|
||||
return s.local
|
||||
}
|
||||
|
||||
func (s *session) Link() string {
|
||||
return s.link
|
||||
}
|
||||
|
||||
func (s *session) Id() string {
|
||||
return s.session
|
||||
}
|
||||
|
||||
func (s *session) Channel() string {
|
||||
return s.channel
|
||||
}
|
||||
|
||||
// newMessage creates a new message based on the session
|
||||
func (s *session) newMessage(typ string) *message {
|
||||
return &message{
|
||||
typ: typ,
|
||||
tunnel: s.tunnel,
|
||||
channel: s.channel,
|
||||
session: s.session,
|
||||
outbound: s.outbound,
|
||||
loopback: s.loopback,
|
||||
mode: s.mode,
|
||||
link: s.link,
|
||||
errChan: s.errChan,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) sendMsg(msg *message) error {
|
||||
select {
|
||||
case <-s.closed:
|
||||
return io.EOF
|
||||
case s.send <- msg:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) wait(msg *message) error {
|
||||
// wait for an error response
|
||||
select {
|
||||
case err := <-msg.errChan:
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case <-s.closed:
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// waitFor waits for the message type required until the timeout specified
|
||||
func (s *session) waitFor(msgType string, timeout time.Duration) (*message, error) {
|
||||
now := time.Now()
|
||||
|
||||
after := func(timeout time.Duration) <-chan time.Time {
|
||||
if timeout < time.Duration(0) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// get the delta
|
||||
d := time.Since(now)
|
||||
|
||||
// dial timeout minus time since
|
||||
wait := timeout - d
|
||||
|
||||
if wait < time.Duration(0) {
|
||||
wait = time.Duration(0)
|
||||
}
|
||||
|
||||
return time.After(wait)
|
||||
}
|
||||
|
||||
// wait for the message type
|
||||
for {
|
||||
select {
|
||||
case msg := <-s.recv:
|
||||
// there may be no message type
|
||||
if len(msgType) == 0 {
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// ignore what we don't want
|
||||
if msg.typ != msgType {
|
||||
if logger.V(logger.DebugLevel, log) {
|
||||
log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// got the message
|
||||
return msg, nil
|
||||
case <-after(timeout):
|
||||
return nil, tunnel.ErrReadTimeout
|
||||
case <-s.closed:
|
||||
// check pending message queue
|
||||
select {
|
||||
case msg := <-s.recv:
|
||||
// there may be no message type
|
||||
if len(msgType) == 0 {
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// ignore what we don't want
|
||||
if msg.typ != msgType {
|
||||
if logger.V(logger.DebugLevel, log) {
|
||||
log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// got the message
|
||||
return msg, nil
|
||||
default:
|
||||
// non blocking
|
||||
}
|
||||
return nil, io.EOF
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Discover attempts to discover the link for a specific channel.
|
||||
// This is only used by the tunnel.Dial when first connecting.
|
||||
func (s *session) Discover() error {
|
||||
// create a new discovery message for this channel
|
||||
msg := s.newMessage("discover")
|
||||
// broadcast the message to all links
|
||||
msg.mode = tunnel.Broadcast
|
||||
// its an outbound connection since we're dialling
|
||||
msg.outbound = true
|
||||
// don't set the link since we don't know where it is
|
||||
msg.link = ""
|
||||
|
||||
// if multicast then set that as session
|
||||
if s.mode == tunnel.Multicast {
|
||||
msg.session = "multicast"
|
||||
}
|
||||
|
||||
// send discover message
|
||||
if err := s.sendMsg(msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// set time now
|
||||
now := time.Now()
|
||||
|
||||
// after strips down the dial timeout
|
||||
after := func() time.Duration {
|
||||
d := time.Since(now)
|
||||
// dial timeout minus time since
|
||||
wait := s.dialTimeout - d
|
||||
// make sure its always > 0
|
||||
if wait < time.Duration(0) {
|
||||
return time.Duration(0)
|
||||
}
|
||||
return wait
|
||||
}
|
||||
|
||||
// the discover message is sent out, now
|
||||
// wait to hear back about the sent message
|
||||
select {
|
||||
case <-time.After(after()):
|
||||
return tunnel.ErrDialTimeout
|
||||
case err := <-s.errChan:
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// bail early if its not unicast
|
||||
// we don't need to wait for the announce
|
||||
if s.mode != tunnel.Unicast {
|
||||
s.discovered = true
|
||||
s.accepted = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// wait for announce
|
||||
_, err := s.waitFor("announce", after())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// set discovered
|
||||
s.discovered = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Open will fire the open message for the session. This is called by the dialler.
|
||||
// This is to indicate that we want to create a new session.
|
||||
func (s *session) Open() error {
|
||||
// create a new message
|
||||
msg := s.newMessage("open")
|
||||
|
||||
// send open message
|
||||
if err := s.sendMsg(msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// wait for an error response for send
|
||||
if err := s.wait(msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// now wait for the accept message to be returned
|
||||
msg, err := s.waitFor("accept", s.dialTimeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// set to accepted
|
||||
s.accepted = true
|
||||
// set link
|
||||
s.link = msg.link
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Accept sends the accept response to an open message from a dialled connection
|
||||
func (s *session) Accept() error {
|
||||
msg := s.newMessage("accept")
|
||||
|
||||
// send the accept message
|
||||
if err := s.sendMsg(msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// wait for send response
|
||||
return s.wait(msg)
|
||||
}
|
||||
|
||||
// Announce sends an announcement to notify that this session exists.
|
||||
// This is primarily used by the listener.
|
||||
func (s *session) Announce() error {
|
||||
msg := s.newMessage("announce")
|
||||
// we don't need an error back
|
||||
msg.errChan = nil
|
||||
// announce to all
|
||||
msg.mode = tunnel.Broadcast
|
||||
// we don't need the link
|
||||
msg.link = ""
|
||||
|
||||
// send announce message
|
||||
return s.sendMsg(msg)
|
||||
}
|
||||
|
||||
// Send is used to send a message
|
||||
func (s *session) Send(m *transport.Message) error {
|
||||
var err error
|
||||
|
||||
s.RLock()
|
||||
gcm := s.gcm
|
||||
s.RUnlock()
|
||||
|
||||
if gcm == nil {
|
||||
gcm, err = newCipher(s.key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.Lock()
|
||||
s.gcm = gcm
|
||||
s.Unlock()
|
||||
}
|
||||
// encrypt the transport message payload
|
||||
body, err := Encrypt(gcm, m.Body)
|
||||
if err != nil {
|
||||
log.Debugf("failed to encrypt message body: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// make copy, without rehash and realloc
|
||||
data := &transport.Message{
|
||||
Header: make(map[string]string, len(m.Header)),
|
||||
Body: body,
|
||||
}
|
||||
|
||||
// encrypt all the headers
|
||||
for k, v := range m.Header {
|
||||
// encrypt the transport message payload
|
||||
val, err := Encrypt(s.gcm, []byte(v))
|
||||
if err != nil {
|
||||
log.Debugf("failed to encrypt message header %s: %v", k, err)
|
||||
return err
|
||||
}
|
||||
// add the encrypted header value
|
||||
data.Header[k] = base32.StdEncoding.EncodeToString(val)
|
||||
}
|
||||
|
||||
// create a new message
|
||||
msg := s.newMessage("session")
|
||||
// set the data
|
||||
msg.data = data
|
||||
|
||||
// if multicast don't set the link
|
||||
if s.mode != tunnel.Unicast {
|
||||
msg.link = ""
|
||||
}
|
||||
|
||||
if logger.V(logger.TraceLevel, log) {
|
||||
log.Tracef("Appending to send backlog: %v", msg)
|
||||
}
|
||||
// send the actual message
|
||||
if err := s.sendMsg(msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// wait for an error response
|
||||
return s.wait(msg)
|
||||
}
|
||||
|
||||
// Recv is used to receive a message
|
||||
func (s *session) Recv(m *transport.Message) error {
|
||||
var msg *message
|
||||
|
||||
msg, err := s.waitFor("", s.readTimeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check the error if one exists
|
||||
select {
|
||||
case err := <-msg.errChan:
|
||||
return err
|
||||
default:
|
||||
}
|
||||
|
||||
if logger.V(logger.TraceLevel, log) {
|
||||
log.Tracef("Received from recv backlog: %v", msg)
|
||||
}
|
||||
|
||||
gcm, err := newCipher([]byte(s.token + s.channel + msg.session))
|
||||
if err != nil {
|
||||
if logger.V(logger.ErrorLevel, log) {
|
||||
log.Errorf("unable to create cipher: %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// decrypt the received payload using the token
|
||||
// we have to used msg.session because multicast has a shared
|
||||
// session id of "multicast" in this session struct on
|
||||
// the listener side
|
||||
msg.data.Body, err = Decrypt(gcm, msg.data.Body)
|
||||
if err != nil {
|
||||
if logger.V(logger.DebugLevel, log) {
|
||||
log.Debugf("failed to decrypt message body: %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// dencrypt all the headers
|
||||
for k, v := range msg.data.Header {
|
||||
// decode the header values
|
||||
h, err := base32.StdEncoding.DecodeString(v)
|
||||
if err != nil {
|
||||
if logger.V(logger.DebugLevel, log) {
|
||||
log.Debugf("failed to decode message header %s: %v", k, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// dencrypt the transport message payload
|
||||
val, err := Decrypt(gcm, h)
|
||||
if err != nil {
|
||||
if logger.V(logger.DebugLevel, log) {
|
||||
log.Debugf("failed to decrypt message header %s: %v", k, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
// add decrypted header value
|
||||
msg.data.Header[k] = string(val)
|
||||
}
|
||||
|
||||
// set the link
|
||||
// TODO: decruft, this is only for multicast
|
||||
// since the session is now a single session
|
||||
// likely provide as part of message.Link()
|
||||
msg.data.Header["Micro-Link"] = msg.link
|
||||
|
||||
// set message
|
||||
*m = *msg.data
|
||||
// return nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the session by sending a close message
|
||||
func (s *session) Close() error {
|
||||
select {
|
||||
case <-s.closed:
|
||||
// no op
|
||||
default:
|
||||
close(s.closed)
|
||||
|
||||
// don't send close on multicast or broadcast
|
||||
if s.mode != tunnel.Unicast {
|
||||
return nil
|
||||
}
|
||||
|
||||
// append to backlog
|
||||
msg := s.newMessage("close")
|
||||
// no error response on close
|
||||
msg.errChan = nil
|
||||
|
||||
// send the close message
|
||||
select {
|
||||
case s.send <- msg:
|
||||
case <-time.After(time.Millisecond * 10):
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user