Merge pull request #668 from milosgajdos83/tun-token-loopback

[WIP] Tunnel loopback connections
This commit is contained in:
Asim Aslam 2019-08-14 14:32:18 +01:00 committed by GitHub
commit 15975d2903
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 113 additions and 51 deletions

View File

@ -17,6 +17,9 @@ type tun struct {
sync.RWMutex sync.RWMutex
// tunnel token
token string
// to indicate if we're connected or not // to indicate if we're connected or not
connected bool connected bool
@ -50,6 +53,7 @@ func newTunnel(opts ...Option) *tun {
return &tun{ return &tun{
options: options, options: options,
token: uuid.New().String(),
send: make(chan *message, 128), send: make(chan *message, 128),
closed: make(chan bool), closed: make(chan bool),
sockets: make(map[string]*socket), sockets: make(map[string]*socket),
@ -57,6 +61,14 @@ func newTunnel(opts ...Option) *tun {
} }
} }
// Init initializes tunnel options
func (t *tun) Init(opts ...Option) error {
for _, o := range opts {
o(&t.options)
}
return nil
}
// getSocket returns a socket from the internal socket map. // getSocket returns a socket from the internal socket map.
// It does this based on the Micro-Tunnel-Id and Micro-Tunnel-Session // It does this based on the Micro-Tunnel-Id and Micro-Tunnel-Session
func (t *tun) getSocket(id, session string) (*socket, bool) { func (t *tun) getSocket(id, session string) (*socket, bool) {
@ -92,6 +104,7 @@ func (t *tun) newSocket(id, session string) (*socket, bool) {
t.Unlock() t.Unlock()
return nil, false return nil, false
} }
t.sockets[id+session] = s t.sockets[id+session] = s
t.Unlock() t.Unlock()
@ -126,6 +139,9 @@ func (t *tun) process() {
// set the session id // set the session id
newMsg.Header["Micro-Tunnel-Session"] = msg.session newMsg.Header["Micro-Tunnel-Session"] = msg.session
// set the tunnel token
newMsg.Header["Micro-Tunnel-Token"] = t.token
// send the message via the interface // send the message via the interface
t.RLock() t.RLock()
if len(t.links) == 0 { if len(t.links) == 0 {
@ -144,7 +160,10 @@ func (t *tun) process() {
} }
// process incoming messages // process incoming messages
func (t *tun) listen(link transport.Socket, listener bool) { func (t *tun) listen(link transport.Socket) {
// loopback flag
var loopback bool
for { for {
// process anything via the net interface // process anything via the net interface
msg := new(transport.Message) msg := new(transport.Message)
@ -155,10 +174,21 @@ func (t *tun) listen(link transport.Socket, listener bool) {
} }
switch msg.Header["Micro-Tunnel"] { switch msg.Header["Micro-Tunnel"] {
case "connect", "close": case "connect":
// TODO: handle the connect/close message // check the Micro-Tunnel-Token
// maybe used to create the dial/listen sockets token, ok := msg.Header["Micro-Tunnel-Token"]
// or report io.EOF or maybe to kill the link if !ok {
continue
}
// are we connecting to ourselves?
if token == t.token {
loopback = true
}
continue
case "close":
// TODO: handle the close message
// maybe report io.EOF or kill the link
continue continue
} }
@ -182,6 +212,11 @@ func (t *tun) listen(link transport.Socket, listener bool) {
var exists bool var exists bool
log.Debugf("Received %+v from %s", msg, link.Remote()) log.Debugf("Received %+v from %s", msg, link.Remote())
switch {
case loopback:
s, exists = t.getSocket(id, "listener")
default:
// get the socket based on the tunnel id and session // get the socket based on the tunnel id and session
// this could be something we dialed in which case // this could be something we dialed in which case
// we have a session for it otherwise its a listener // we have a session for it otherwise its a listener
@ -192,8 +227,8 @@ func (t *tun) listen(link transport.Socket, listener bool) {
// has no session but its set a listener session // has no session but its set a listener session
s, exists = t.getSocket(id, "listener") s, exists = t.getSocket(id, "listener")
} }
}
// no socket in existence // bail if no socket has been found
if !exists { if !exists {
log.Debugf("Tunnel skipping no socket exists") log.Debugf("Tunnel skipping no socket exists")
// drop it, we don't care about // drop it, we don't care about
@ -246,6 +281,7 @@ func (t *tun) listen(link transport.Socket, listener bool) {
} }
} }
// connect the tunnel to all the nodes and listen for incoming tunnel connections
func (t *tun) connect() error { func (t *tun) connect() error {
l, err := t.options.Transport.Listen(t.options.Address) l, err := t.options.Transport.Listen(t.options.Address)
if err != nil { if err != nil {
@ -277,7 +313,7 @@ func (t *tun) connect() error {
}() }()
// listen for inbound messages // listen for inbound messages
t.listen(sock, true) t.listen(sock)
}) })
t.Lock() t.Lock()
@ -307,13 +343,14 @@ func (t *tun) connect() error {
if err := c.Send(&transport.Message{ if err := c.Send(&transport.Message{
Header: map[string]string{ Header: map[string]string{
"Micro-Tunnel": "connect", "Micro-Tunnel": "connect",
"Micro-Tunnel-Token": t.token,
}, },
}); err != nil { }); err != nil {
continue continue
} }
// process incoming messages // process incoming messages
go t.listen(c, false) go t.listen(c)
// save the link // save the link
id := uuid.New().String() id := uuid.New().String()
@ -330,12 +367,36 @@ func (t *tun) connect() error {
return nil return nil
} }
// Connect the tunnel
func (t *tun) Connect() error {
t.Lock()
defer t.Unlock()
// already connected
if t.connected {
return nil
}
// send the connect message
if err := t.connect(); err != nil {
return err
}
// set as connected
t.connected = true
// create new close channel
t.closed = make(chan bool)
return nil
}
func (t *tun) close() error { func (t *tun) close() error {
// close all the links // close all the links
for id, link := range t.links { for id, link := range t.links {
link.Send(&transport.Message{ link.Send(&transport.Message{
Header: map[string]string{ Header: map[string]string{
"Micro-Tunnel": "close", "Micro-Tunnel": "close",
"Micro-Tunnel-Token": t.token,
}, },
}) })
link.Close() link.Close()
@ -376,36 +437,6 @@ func (t *tun) Close() error {
return nil return nil
} }
// Connect the tunnel
func (t *tun) Connect() error {
t.Lock()
defer t.Unlock()
// already connected
if t.connected {
return nil
}
// send the connect message
if err := t.connect(); err != nil {
return err
}
// set as connected
t.connected = true
// create new close channel
t.closed = make(chan bool)
return nil
}
func (t *tun) Init(opts ...Option) error {
for _, o := range opts {
o(&t.options)
}
return nil
}
// Dial an address // Dial an address
func (t *tun) Dial(addr string) (Conn, error) { func (t *tun) Dial(addr string) (Conn, error) {
log.Debugf("Tunnel dialing %s", addr) log.Debugf("Tunnel dialing %s", addr)
@ -413,7 +444,6 @@ func (t *tun) Dial(addr string) (Conn, error) {
if !ok { if !ok {
return nil, errors.New("error dialing " + addr) return nil, errors.New("error dialing " + addr)
} }
// set remote // set remote
c.remote = addr c.remote = addr
// set local // set local

View File

@ -98,3 +98,35 @@ func TestTunnel(t *testing.T) {
// wait until done // wait until done
wg.Wait() wg.Wait()
} }
func TestLoopbackTunnel(t *testing.T) {
// create a new tunnel client
tun := NewTunnel(
Address("127.0.0.1:9096"),
Nodes("127.0.0.1:9096"),
)
// start tunB
err := tun.Connect()
if err != nil {
t.Fatal(err)
}
defer tun.Close()
time.Sleep(time.Millisecond * 50)
var wg sync.WaitGroup
// start accepting connections
// on tunnel A
wg.Add(1)
go testAccept(t, tun, &wg)
time.Sleep(time.Millisecond * 50)
// dial and send via B
testSend(t, tun)
// wait until done
wg.Wait()
}