diff --git a/tunnel/default.go b/tunnel/default.go index 6ba7f878..90a641f0 100644 --- a/tunnel/default.go +++ b/tunnel/default.go @@ -17,6 +17,9 @@ type tun struct { sync.RWMutex + // tunnel token + token string + // to indicate if we're connected or not connected bool @@ -50,6 +53,7 @@ func newTunnel(opts ...Option) *tun { return &tun{ options: options, + token: uuid.New().String(), send: make(chan *message, 128), closed: make(chan bool), sockets: make(map[string]*socket), @@ -57,6 +61,14 @@ func newTunnel(opts ...Option) *tun { } } +// Init initializes tunnel options +func (t *tun) Init(opts ...Option) error { + for _, o := range opts { + o(&t.options) + } + return nil +} + // getSocket returns a socket from the internal socket map. // It does this based on the Micro-Tunnel-Id and Micro-Tunnel-Session func (t *tun) getSocket(id, session string) (*socket, bool) { @@ -92,6 +104,7 @@ func (t *tun) newSocket(id, session string) (*socket, bool) { t.Unlock() return nil, false } + t.sockets[id+session] = s t.Unlock() @@ -126,6 +139,9 @@ func (t *tun) process() { // set the session id newMsg.Header["Micro-Tunnel-Session"] = msg.session + // set the tunnel token + newMsg.Header["Micro-Tunnel-Token"] = t.token + // send the message via the interface t.RLock() if len(t.links) == 0 { @@ -144,7 +160,10 @@ func (t *tun) process() { } // process incoming messages -func (t *tun) listen(link transport.Socket, listener bool) { +func (t *tun) listen(link transport.Socket) { + // loopback flag + var loopback bool + for { // process anything via the net interface msg := new(transport.Message) @@ -155,10 +174,21 @@ func (t *tun) listen(link transport.Socket, listener bool) { } switch msg.Header["Micro-Tunnel"] { - case "connect", "close": - // TODO: handle the connect/close message - // maybe used to create the dial/listen sockets - // or report io.EOF or maybe to kill the link + case "connect": + // check the Micro-Tunnel-Token + token, ok := msg.Header["Micro-Tunnel-Token"] + if !ok { + continue + } + + // are we connecting to ourselves? + if token == t.token { + loopback = true + } + continue + case "close": + // TODO: handle the close message + // maybe report io.EOF or kill the link continue } @@ -182,18 +212,23 @@ func (t *tun) listen(link transport.Socket, listener bool) { var exists bool log.Debugf("Received %+v from %s", msg, link.Remote()) - // get the socket based on the tunnel id and session - // this could be something we dialed in which case - // we have a session for it otherwise its a listener - s, exists = t.getSocket(id, session) - if !exists { - // try get it based on just the tunnel id - // the assumption here is that a listener - // has no session but its set a listener session - s, exists = t.getSocket(id, "listener") - } - // no socket in existence + switch { + case loopback: + s, exists = t.getSocket(id, "listener") + default: + // get the socket based on the tunnel id and session + // this could be something we dialed in which case + // we have a session for it otherwise its a listener + s, exists = t.getSocket(id, session) + if !exists { + // try get it based on just the tunnel id + // the assumption here is that a listener + // has no session but its set a listener session + s, exists = t.getSocket(id, "listener") + } + } + // bail if no socket has been found if !exists { log.Debugf("Tunnel skipping no socket exists") // drop it, we don't care about @@ -246,6 +281,7 @@ func (t *tun) listen(link transport.Socket, listener bool) { } } +// connect the tunnel to all the nodes and listen for incoming tunnel connections func (t *tun) connect() error { l, err := t.options.Transport.Listen(t.options.Address) if err != nil { @@ -277,7 +313,7 @@ func (t *tun) connect() error { }() // listen for inbound messages - t.listen(sock, true) + t.listen(sock) }) t.Lock() @@ -306,14 +342,15 @@ func (t *tun) connect() error { if err := c.Send(&transport.Message{ Header: map[string]string{ - "Micro-Tunnel": "connect", + "Micro-Tunnel": "connect", + "Micro-Tunnel-Token": t.token, }, }); err != nil { continue } // process incoming messages - go t.listen(c, false) + go t.listen(c) // save the link id := uuid.New().String() @@ -330,12 +367,36 @@ func (t *tun) connect() error { return nil } +// Connect the tunnel +func (t *tun) Connect() error { + t.Lock() + defer t.Unlock() + + // already connected + if t.connected { + return nil + } + + // send the connect message + if err := t.connect(); err != nil { + return err + } + + // set as connected + t.connected = true + // create new close channel + t.closed = make(chan bool) + + return nil +} + func (t *tun) close() error { // close all the links for id, link := range t.links { link.Send(&transport.Message{ Header: map[string]string{ - "Micro-Tunnel": "close", + "Micro-Tunnel": "close", + "Micro-Tunnel-Token": t.token, }, }) link.Close() @@ -376,36 +437,6 @@ func (t *tun) Close() error { return nil } -// Connect the tunnel -func (t *tun) Connect() error { - t.Lock() - defer t.Unlock() - - // already connected - if t.connected { - return nil - } - - // send the connect message - if err := t.connect(); err != nil { - return err - } - - // set as connected - t.connected = true - // create new close channel - t.closed = make(chan bool) - - return nil -} - -func (t *tun) Init(opts ...Option) error { - for _, o := range opts { - o(&t.options) - } - return nil -} - // Dial an address func (t *tun) Dial(addr string) (Conn, error) { log.Debugf("Tunnel dialing %s", addr) @@ -413,7 +444,6 @@ func (t *tun) Dial(addr string) (Conn, error) { if !ok { return nil, errors.New("error dialing " + addr) } - // set remote c.remote = addr // set local diff --git a/tunnel/tunnel_test.go b/tunnel/tunnel_test.go index 721479bb..79db046a 100644 --- a/tunnel/tunnel_test.go +++ b/tunnel/tunnel_test.go @@ -24,10 +24,22 @@ func testAccept(t *testing.T, tun Tunnel, wg *sync.WaitGroup) { // get a message for { + // accept the message m := new(transport.Message) if err := c.Recv(m); err != nil { t.Fatal(err) } + + if v := m.Header["test"]; v != "send" { + t.Fatalf("Accept side expected test:send header. Received: %s", v) + } + + // now respond + m.Header["test"] = "accept" + if err := c.Send(m); err != nil { + t.Fatal(err) + } + wg.Done() return } @@ -44,13 +56,24 @@ func testSend(t *testing.T, tun Tunnel) { m := transport.Message{ Header: map[string]string{ - "test": "header", + "test": "send", }, } + // send the message if err := c.Send(&m); err != nil { t.Fatal(err) } + + // now wait for the response + mr := new(transport.Message) + if err := c.Recv(mr); err != nil { + t.Fatal(err) + } + + if v := mr.Header["test"]; v != "accept" { + t.Fatalf("Message not received from accepted side. Received: %s", v) + } } func TestTunnel(t *testing.T) { @@ -98,3 +121,35 @@ func TestTunnel(t *testing.T) { // wait until done wg.Wait() } + +func TestLoopbackTunnel(t *testing.T) { + // create a new tunnel client + tun := NewTunnel( + Address("127.0.0.1:9096"), + Nodes("127.0.0.1:9096"), + ) + + // start tunB + err := tun.Connect() + if err != nil { + t.Fatal(err) + } + defer tun.Close() + + time.Sleep(time.Millisecond * 50) + + var wg sync.WaitGroup + + // start accepting connections + // on tunnel A + wg.Add(1) + go testAccept(t, tun, &wg) + + time.Sleep(time.Millisecond * 50) + + // dial and send via B + testSend(t, tun) + + // wait until done + wg.Wait() +}