Merge pull request #668 from milosgajdos83/tun-token-loopback

[WIP] Tunnel loopback connections
This commit is contained in:
Asim Aslam 2019-08-14 14:32:18 +01:00 committed by GitHub
commit 15975d2903
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 113 additions and 51 deletions

View File

@ -17,6 +17,9 @@ type tun struct {
sync.RWMutex
// tunnel token
token string
// to indicate if we're connected or not
connected bool
@ -50,6 +53,7 @@ func newTunnel(opts ...Option) *tun {
return &tun{
options: options,
token: uuid.New().String(),
send: make(chan *message, 128),
closed: make(chan bool),
sockets: make(map[string]*socket),
@ -57,6 +61,14 @@ func newTunnel(opts ...Option) *tun {
}
}
// Init initializes tunnel options
func (t *tun) Init(opts ...Option) error {
for _, o := range opts {
o(&t.options)
}
return nil
}
// getSocket returns a socket from the internal socket map.
// It does this based on the Micro-Tunnel-Id and Micro-Tunnel-Session
func (t *tun) getSocket(id, session string) (*socket, bool) {
@ -92,6 +104,7 @@ func (t *tun) newSocket(id, session string) (*socket, bool) {
t.Unlock()
return nil, false
}
t.sockets[id+session] = s
t.Unlock()
@ -126,6 +139,9 @@ func (t *tun) process() {
// set the session id
newMsg.Header["Micro-Tunnel-Session"] = msg.session
// set the tunnel token
newMsg.Header["Micro-Tunnel-Token"] = t.token
// send the message via the interface
t.RLock()
if len(t.links) == 0 {
@ -144,7 +160,10 @@ func (t *tun) process() {
}
// process incoming messages
func (t *tun) listen(link transport.Socket, listener bool) {
func (t *tun) listen(link transport.Socket) {
// loopback flag
var loopback bool
for {
// process anything via the net interface
msg := new(transport.Message)
@ -155,10 +174,21 @@ func (t *tun) listen(link transport.Socket, listener bool) {
}
switch msg.Header["Micro-Tunnel"] {
case "connect", "close":
// TODO: handle the connect/close message
// maybe used to create the dial/listen sockets
// or report io.EOF or maybe to kill the link
case "connect":
// check the Micro-Tunnel-Token
token, ok := msg.Header["Micro-Tunnel-Token"]
if !ok {
continue
}
// are we connecting to ourselves?
if token == t.token {
loopback = true
}
continue
case "close":
// TODO: handle the close message
// maybe report io.EOF or kill the link
continue
}
@ -182,18 +212,23 @@ func (t *tun) listen(link transport.Socket, listener bool) {
var exists bool
log.Debugf("Received %+v from %s", msg, link.Remote())
// get the socket based on the tunnel id and session
// this could be something we dialed in which case
// we have a session for it otherwise its a listener
s, exists = t.getSocket(id, session)
if !exists {
// try get it based on just the tunnel id
// the assumption here is that a listener
// has no session but its set a listener session
s, exists = t.getSocket(id, "listener")
}
// no socket in existence
switch {
case loopback:
s, exists = t.getSocket(id, "listener")
default:
// get the socket based on the tunnel id and session
// this could be something we dialed in which case
// we have a session for it otherwise its a listener
s, exists = t.getSocket(id, session)
if !exists {
// try get it based on just the tunnel id
// the assumption here is that a listener
// has no session but its set a listener session
s, exists = t.getSocket(id, "listener")
}
}
// bail if no socket has been found
if !exists {
log.Debugf("Tunnel skipping no socket exists")
// drop it, we don't care about
@ -246,6 +281,7 @@ func (t *tun) listen(link transport.Socket, listener bool) {
}
}
// connect the tunnel to all the nodes and listen for incoming tunnel connections
func (t *tun) connect() error {
l, err := t.options.Transport.Listen(t.options.Address)
if err != nil {
@ -277,7 +313,7 @@ func (t *tun) connect() error {
}()
// listen for inbound messages
t.listen(sock, true)
t.listen(sock)
})
t.Lock()
@ -306,14 +342,15 @@ func (t *tun) connect() error {
if err := c.Send(&transport.Message{
Header: map[string]string{
"Micro-Tunnel": "connect",
"Micro-Tunnel": "connect",
"Micro-Tunnel-Token": t.token,
},
}); err != nil {
continue
}
// process incoming messages
go t.listen(c, false)
go t.listen(c)
// save the link
id := uuid.New().String()
@ -330,12 +367,36 @@ func (t *tun) connect() error {
return nil
}
// Connect the tunnel
func (t *tun) Connect() error {
t.Lock()
defer t.Unlock()
// already connected
if t.connected {
return nil
}
// send the connect message
if err := t.connect(); err != nil {
return err
}
// set as connected
t.connected = true
// create new close channel
t.closed = make(chan bool)
return nil
}
func (t *tun) close() error {
// close all the links
for id, link := range t.links {
link.Send(&transport.Message{
Header: map[string]string{
"Micro-Tunnel": "close",
"Micro-Tunnel": "close",
"Micro-Tunnel-Token": t.token,
},
})
link.Close()
@ -376,36 +437,6 @@ func (t *tun) Close() error {
return nil
}
// Connect the tunnel
func (t *tun) Connect() error {
t.Lock()
defer t.Unlock()
// already connected
if t.connected {
return nil
}
// send the connect message
if err := t.connect(); err != nil {
return err
}
// set as connected
t.connected = true
// create new close channel
t.closed = make(chan bool)
return nil
}
func (t *tun) Init(opts ...Option) error {
for _, o := range opts {
o(&t.options)
}
return nil
}
// Dial an address
func (t *tun) Dial(addr string) (Conn, error) {
log.Debugf("Tunnel dialing %s", addr)
@ -413,7 +444,6 @@ func (t *tun) Dial(addr string) (Conn, error) {
if !ok {
return nil, errors.New("error dialing " + addr)
}
// set remote
c.remote = addr
// set local

View File

@ -98,3 +98,35 @@ func TestTunnel(t *testing.T) {
// wait until done
wg.Wait()
}
func TestLoopbackTunnel(t *testing.T) {
// create a new tunnel client
tun := NewTunnel(
Address("127.0.0.1:9096"),
Nodes("127.0.0.1:9096"),
)
// start tunB
err := tun.Connect()
if err != nil {
t.Fatal(err)
}
defer tun.Close()
time.Sleep(time.Millisecond * 50)
var wg sync.WaitGroup
// start accepting connections
// on tunnel A
wg.Add(1)
go testAccept(t, tun, &wg)
time.Sleep(time.Millisecond * 50)
// dial and send via B
testSend(t, tun)
// wait until done
wg.Wait()
}