Rename Tunnel ID to Channel
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -25,7 +24,10 @@ type tun struct {
|
||||
|
||||
sync.RWMutex
|
||||
|
||||
// tunnel token
|
||||
// the unique id for this tunnel
|
||||
id string
|
||||
|
||||
// tunnel token for authentication
|
||||
token string
|
||||
|
||||
// to indicate if we're connected or not
|
||||
@@ -37,8 +39,8 @@ type tun struct {
|
||||
// close channel
|
||||
closed chan bool
|
||||
|
||||
// a map of sockets based on Micro-Tunnel-Id
|
||||
sockets map[string]*socket
|
||||
// a map of sessions based on Micro-Tunnel-Channel
|
||||
sessions map[string]*session
|
||||
|
||||
// outbound links
|
||||
links map[string]*link
|
||||
@@ -47,25 +49,6 @@ type tun struct {
|
||||
listener transport.Listener
|
||||
}
|
||||
|
||||
type link struct {
|
||||
transport.Socket
|
||||
// unique id of this link e.g uuid
|
||||
// which we define for ourselves
|
||||
id string
|
||||
// whether its a loopback connection
|
||||
// this flag is used by the transport listener
|
||||
// which accepts inbound quic connections
|
||||
loopback bool
|
||||
// whether its actually connected
|
||||
// dialled side sets it to connected
|
||||
// after sending the message. the
|
||||
// listener waits for the connect
|
||||
connected bool
|
||||
// the last time we received a keepalive
|
||||
// on this link from the remote side
|
||||
lastKeepAlive time.Time
|
||||
}
|
||||
|
||||
// create new tunnel on top of a link
|
||||
func newTunnel(opts ...Option) *tun {
|
||||
options := DefaultOptions()
|
||||
@@ -74,12 +57,13 @@ func newTunnel(opts ...Option) *tun {
|
||||
}
|
||||
|
||||
return &tun{
|
||||
options: options,
|
||||
token: uuid.New().String(),
|
||||
send: make(chan *message, 128),
|
||||
closed: make(chan bool),
|
||||
sockets: make(map[string]*socket),
|
||||
links: make(map[string]*link),
|
||||
options: options,
|
||||
id: options.Id,
|
||||
token: options.Token,
|
||||
send: make(chan *message, 128),
|
||||
closed: make(chan bool),
|
||||
sessions: make(map[string]*session),
|
||||
links: make(map[string]*link),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,51 +77,48 @@ func (t *tun) Init(opts ...Option) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// getSocket returns a socket from the internal socket map.
|
||||
// It does this based on the Micro-Tunnel-Id and Micro-Tunnel-Session
|
||||
func (t *tun) getSocket(id, session string) (*socket, bool) {
|
||||
// get the socket
|
||||
// getSession returns a session from the internal session map.
|
||||
// It does this based on the Micro-Tunnel-Channel and Micro-Tunnel-Session
|
||||
func (t *tun) getSession(channel, session string) (*session, bool) {
|
||||
// get the session
|
||||
t.RLock()
|
||||
s, ok := t.sockets[id+session]
|
||||
s, ok := t.sessions[channel+session]
|
||||
t.RUnlock()
|
||||
return s, ok
|
||||
}
|
||||
|
||||
// newSocket creates a new socket and saves it
|
||||
func (t *tun) newSocket(id, session string) (*socket, bool) {
|
||||
// hash the id
|
||||
h := sha256.New()
|
||||
h.Write([]byte(id))
|
||||
id = fmt.Sprintf("%x", h.Sum(nil))
|
||||
|
||||
// new socket
|
||||
s := &socket{
|
||||
id: id,
|
||||
session: session,
|
||||
// newSession creates a new session and saves it
|
||||
func (t *tun) newSession(channel, sessionId string) (*session, bool) {
|
||||
// new session
|
||||
s := &session{
|
||||
id: t.id,
|
||||
channel: channel,
|
||||
session: sessionId,
|
||||
closed: make(chan bool),
|
||||
recv: make(chan *message, 128),
|
||||
send: t.send,
|
||||
wait: make(chan bool),
|
||||
errChan: make(chan error, 1),
|
||||
}
|
||||
|
||||
// save socket
|
||||
// save session
|
||||
t.Lock()
|
||||
_, ok := t.sockets[id+session]
|
||||
_, ok := t.sessions[channel+sessionId]
|
||||
if ok {
|
||||
// socket already exists
|
||||
// session already exists
|
||||
t.Unlock()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
t.sockets[id+session] = s
|
||||
t.sessions[channel+sessionId] = s
|
||||
t.Unlock()
|
||||
|
||||
// return socket
|
||||
// return session
|
||||
return s, true
|
||||
}
|
||||
|
||||
// TODO: use tunnel id as part of the session
|
||||
func (t *tun) newSession() string {
|
||||
func (t *tun) newSessionId() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
@@ -168,10 +149,10 @@ func (t *tun) monitor() {
|
||||
}
|
||||
}
|
||||
|
||||
// process outgoing messages sent by all local sockets
|
||||
// process outgoing messages sent by all local sessions
|
||||
func (t *tun) process() {
|
||||
// manage the send buffer
|
||||
// all pseudo sockets throw everything down this
|
||||
// all pseudo sessions throw everything down this
|
||||
for {
|
||||
select {
|
||||
case msg := <-t.send:
|
||||
@@ -190,6 +171,9 @@ func (t *tun) process() {
|
||||
// set the tunnel id on the outgoing message
|
||||
newMsg.Header["Micro-Tunnel-Id"] = msg.id
|
||||
|
||||
// set the tunnel channel on the outgoing message
|
||||
newMsg.Header["Micro-Tunnel-Channel"] = msg.channel
|
||||
|
||||
// set the session id
|
||||
newMsg.Header["Micro-Tunnel-Session"] = msg.session
|
||||
|
||||
@@ -203,10 +187,14 @@ func (t *tun) process() {
|
||||
log.Debugf("No links to send to")
|
||||
}
|
||||
|
||||
var sent bool
|
||||
var err error
|
||||
|
||||
for node, link := range t.links {
|
||||
// if the link is not connected skip it
|
||||
if !link.connected {
|
||||
log.Debugf("Link for node %s not connected", node)
|
||||
err = errors.New("link not connected")
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -214,6 +202,7 @@ func (t *tun) process() {
|
||||
// this is where we explicitly set the link
|
||||
// in a message received via the listen method
|
||||
if len(msg.link) > 0 && link.id != msg.link {
|
||||
err = errors.New("link not found")
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -221,25 +210,41 @@ func (t *tun) process() {
|
||||
// and the message is being sent outbound via
|
||||
// a dialled connection don't use this link
|
||||
if link.loopback && msg.outbound {
|
||||
err = errors.New("link is loopback")
|
||||
continue
|
||||
}
|
||||
|
||||
// if the message was being returned by the loopback listener
|
||||
// send it back up the loopback link only
|
||||
if msg.loopback && !link.loopback {
|
||||
err = errors.New("link is not loopback")
|
||||
continue
|
||||
}
|
||||
|
||||
// send the message via the current link
|
||||
log.Debugf("Sending %+v to %s", newMsg, node)
|
||||
if err := link.Send(newMsg); err != nil {
|
||||
log.Debugf("Tunnel error sending %+v to %s: %v", newMsg, node, err)
|
||||
if errr := link.Send(newMsg); errr != nil {
|
||||
log.Debugf("Tunnel error sending %+v to %s: %v", newMsg, node, errr)
|
||||
err = errors.New(errr.Error())
|
||||
delete(t.links, node)
|
||||
continue
|
||||
}
|
||||
// is sent
|
||||
sent = true
|
||||
}
|
||||
|
||||
t.Unlock()
|
||||
|
||||
var gerr error
|
||||
if !sent {
|
||||
gerr = err
|
||||
}
|
||||
|
||||
// return error non blocking
|
||||
select {
|
||||
case msg.errChan <- gerr:
|
||||
default:
|
||||
}
|
||||
case <-t.closed:
|
||||
return
|
||||
}
|
||||
@@ -267,17 +272,23 @@ func (t *tun) listen(link *link) {
|
||||
return
|
||||
}
|
||||
|
||||
// always ensure we have the correct auth token
|
||||
// TODO: segment the tunnel based on token
|
||||
// e.g use it as the basis
|
||||
token := msg.Header["Micro-Tunnel-Token"]
|
||||
if token != t.token {
|
||||
log.Debugf("Tunnel link %s received invalid token %s", token)
|
||||
return
|
||||
}
|
||||
|
||||
switch msg.Header["Micro-Tunnel"] {
|
||||
case "connect":
|
||||
log.Debugf("Tunnel link %s received connect message", link.Remote())
|
||||
// check the Micro-Tunnel-Token
|
||||
token, ok := msg.Header["Micro-Tunnel-Token"]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
id := msg.Header["Micro-Tunnel-Id"]
|
||||
|
||||
// are we connecting to ourselves?
|
||||
if token == t.token {
|
||||
if id == t.id {
|
||||
link.loopback = true
|
||||
loopback = true
|
||||
}
|
||||
@@ -318,76 +329,77 @@ func (t *tun) listen(link *link) {
|
||||
return
|
||||
}
|
||||
|
||||
// strip message header
|
||||
delete(msg.Header, "Micro-Tunnel")
|
||||
|
||||
// the tunnel id
|
||||
id := msg.Header["Micro-Tunnel-Id"]
|
||||
delete(msg.Header, "Micro-Tunnel-Id")
|
||||
|
||||
// the tunnel channel
|
||||
channel := msg.Header["Micro-Tunnel-Channel"]
|
||||
// the session id
|
||||
session := msg.Header["Micro-Tunnel-Session"]
|
||||
delete(msg.Header, "Micro-Tunnel-Session")
|
||||
sessionId := msg.Header["Micro-Tunnel-Session"]
|
||||
|
||||
// strip token header
|
||||
delete(msg.Header, "Micro-Tunnel-Token")
|
||||
// strip tunnel message header
|
||||
for k, _ := range msg.Header {
|
||||
if strings.HasPrefix(k, "Micro-Tunnel") {
|
||||
delete(msg.Header, k)
|
||||
}
|
||||
}
|
||||
|
||||
// if the session id is blank there's nothing we can do
|
||||
// TODO: check this is the case, is there any reason
|
||||
// why we'd have a blank session? Is the tunnel
|
||||
// used for some other purpose?
|
||||
if len(id) == 0 || len(session) == 0 {
|
||||
if len(channel) == 0 || len(sessionId) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var s *socket
|
||||
var s *session
|
||||
var exists bool
|
||||
|
||||
// If its a loopback connection then we've enabled link direction
|
||||
// listening side is used for listening, the dialling side for dialling
|
||||
switch {
|
||||
case loopback:
|
||||
s, exists = t.getSocket(id, "listener")
|
||||
s, exists = t.getSession(channel, "listener")
|
||||
default:
|
||||
// get the socket based on the tunnel id and session
|
||||
// get the session based on the tunnel id and session
|
||||
// this could be something we dialed in which case
|
||||
// we have a session for it otherwise its a listener
|
||||
s, exists = t.getSocket(id, session)
|
||||
s, exists = t.getSession(channel, sessionId)
|
||||
if !exists {
|
||||
// try get it based on just the tunnel id
|
||||
// the assumption here is that a listener
|
||||
// has no session but its set a listener session
|
||||
s, exists = t.getSocket(id, "listener")
|
||||
s, exists = t.getSession(channel, "listener")
|
||||
}
|
||||
}
|
||||
|
||||
// bail if no socket has been found
|
||||
// bail if no session has been found
|
||||
if !exists {
|
||||
log.Debugf("Tunnel skipping no socket exists")
|
||||
log.Debugf("Tunnel skipping no session exists")
|
||||
// drop it, we don't care about
|
||||
// messages we don't know about
|
||||
continue
|
||||
}
|
||||
log.Debugf("Tunnel using socket %s %s", s.id, s.session)
|
||||
|
||||
// is the socket closed?
|
||||
log.Debugf("Tunnel using session %s %s", s.channel, s.session)
|
||||
|
||||
// is the session closed?
|
||||
select {
|
||||
case <-s.closed:
|
||||
// closed
|
||||
delete(t.sockets, id)
|
||||
delete(t.sessions, channel)
|
||||
continue
|
||||
default:
|
||||
// process
|
||||
}
|
||||
|
||||
// is the socket new?
|
||||
// is the session new?
|
||||
select {
|
||||
// if its new the socket is actually blocked waiting
|
||||
// if its new the session is actually blocked waiting
|
||||
// for a connection. so we check if its waiting.
|
||||
case <-s.wait:
|
||||
// if its waiting e.g its new then we close it
|
||||
default:
|
||||
// set remote address of the socket
|
||||
// set remote address of the session
|
||||
s.remote = msg.Header["Remote"]
|
||||
close(s.wait)
|
||||
}
|
||||
@@ -401,10 +413,12 @@ func (t *tun) listen(link *link) {
|
||||
// construct the internal message
|
||||
imsg := &message{
|
||||
id: id,
|
||||
session: session,
|
||||
channel: channel,
|
||||
session: sessionId,
|
||||
data: tmsg,
|
||||
link: link.id,
|
||||
loopback: loopback,
|
||||
errChan: make(chan error, 1),
|
||||
}
|
||||
|
||||
// append to recv backlog
|
||||
@@ -431,6 +445,7 @@ func (t *tun) keepalive(link *link) {
|
||||
if err := link.Send(&transport.Message{
|
||||
Header: map[string]string{
|
||||
"Micro-Tunnel": "keepalive",
|
||||
"Micro-Tunnel-Id": t.id,
|
||||
"Micro-Tunnel-Token": t.token,
|
||||
},
|
||||
}); err != nil {
|
||||
@@ -458,6 +473,7 @@ func (t *tun) setupLink(node string) (*link, error) {
|
||||
if err := c.Send(&transport.Message{
|
||||
Header: map[string]string{
|
||||
"Micro-Tunnel": "connect",
|
||||
"Micro-Tunnel-Id": t.id,
|
||||
"Micro-Tunnel-Token": t.token,
|
||||
},
|
||||
}); err != nil {
|
||||
@@ -568,6 +584,7 @@ func (t *tun) close() error {
|
||||
link.Send(&transport.Message{
|
||||
Header: map[string]string{
|
||||
"Micro-Tunnel": "close",
|
||||
"Micro-Tunnel-Id": t.id,
|
||||
"Micro-Tunnel-Token": t.token,
|
||||
},
|
||||
})
|
||||
@@ -603,10 +620,10 @@ func (t *tun) Close() error {
|
||||
case <-t.closed:
|
||||
return nil
|
||||
default:
|
||||
// close all the sockets
|
||||
for id, s := range t.sockets {
|
||||
// close all the sessions
|
||||
for id, s := range t.sessions {
|
||||
s.Close()
|
||||
delete(t.sockets, id)
|
||||
delete(t.sessions, id)
|
||||
}
|
||||
// close the connection
|
||||
close(t.closed)
|
||||
@@ -622,52 +639,50 @@ func (t *tun) Close() error {
|
||||
}
|
||||
|
||||
// Dial an address
|
||||
func (t *tun) Dial(addr string) (Conn, error) {
|
||||
log.Debugf("Tunnel dialing %s", addr)
|
||||
c, ok := t.newSocket(addr, t.newSession())
|
||||
func (t *tun) Dial(channel string) (Session, error) {
|
||||
log.Debugf("Tunnel dialing %s", channel)
|
||||
c, ok := t.newSession(channel, t.newSessionId())
|
||||
if !ok {
|
||||
return nil, errors.New("error dialing " + addr)
|
||||
return nil, errors.New("error dialing " + channel)
|
||||
}
|
||||
// set remote
|
||||
c.remote = addr
|
||||
c.remote = channel
|
||||
// set local
|
||||
c.local = "local"
|
||||
// outbound socket
|
||||
// outbound session
|
||||
c.outbound = true
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Accept a connection on the address
|
||||
func (t *tun) Listen(addr string) (Listener, error) {
|
||||
log.Debugf("Tunnel listening on %s", addr)
|
||||
// create a new socket by hashing the address
|
||||
c, ok := t.newSocket(addr, "listener")
|
||||
func (t *tun) Listen(channel string) (Listener, error) {
|
||||
log.Debugf("Tunnel listening on %s", channel)
|
||||
// create a new session by hashing the address
|
||||
c, ok := t.newSession(channel, "listener")
|
||||
if !ok {
|
||||
return nil, errors.New("already listening on " + addr)
|
||||
return nil, errors.New("already listening on " + channel)
|
||||
}
|
||||
|
||||
// set remote. it will be replaced by the first message received
|
||||
c.remote = "remote"
|
||||
// set local
|
||||
c.local = addr
|
||||
c.local = channel
|
||||
|
||||
tl := &tunListener{
|
||||
addr: addr,
|
||||
channel: channel,
|
||||
// the accept channel
|
||||
accept: make(chan *socket, 128),
|
||||
accept: make(chan *session, 128),
|
||||
// the channel to close
|
||||
closed: make(chan bool),
|
||||
// tunnel closed channel
|
||||
tunClosed: t.closed,
|
||||
// the connection
|
||||
conn: c,
|
||||
// the listener socket
|
||||
socket: c,
|
||||
// the listener session
|
||||
session: c,
|
||||
}
|
||||
|
||||
// this kicks off the internal message processor
|
||||
// for the listener so it can create pseudo sockets
|
||||
// for the listener so it can create pseudo sessions
|
||||
// per session if they do not exist or pass messages
|
||||
// to the existign sessions
|
||||
go tl.process()
|
||||
|
||||
Reference in New Issue
Block a user