diff --git a/tunnel/default.go b/tunnel/default.go index 6ba7f878..26a4a4d0 100644 --- a/tunnel/default.go +++ b/tunnel/default.go @@ -17,6 +17,9 @@ type tun struct { sync.RWMutex + // tunnel token + token string + // to indicate if we're connected or not connected bool @@ -50,6 +53,7 @@ func newTunnel(opts ...Option) *tun { return &tun{ options: options, + token: uuid.New().String(), send: make(chan *message, 128), closed: make(chan bool), sockets: make(map[string]*socket), @@ -144,7 +148,7 @@ func (t *tun) process() { } // process incoming messages -func (t *tun) listen(link transport.Socket, listener bool) { +func (t *tun) listen(link transport.Socket) { for { // process anything via the net interface msg := new(transport.Message) @@ -154,11 +158,24 @@ func (t *tun) listen(link transport.Socket, listener bool) { return } + var loopback bool + switch msg.Header["Micro-Tunnel"] { - case "connect", "close": - // TODO: handle the connect/close message - // maybe used to create the dial/listen sockets - // or report io.EOF or maybe to kill the link + case "connect": + // TODO: handle the connect message + // check the Micro-Tunnel-Token + token, ok := msg.Header["Micro-Tunnel-Token"] + if !ok { + // no token found; bailing + continue + } + // are we connecting to ourselves? + if token == t.token { + loopback = true + } + case "close": + // TODO: handle the close message + // maybe report io.EOF or kill the link continue } @@ -182,18 +199,17 @@ func (t *tun) listen(link transport.Socket, listener bool) { var exists bool log.Debugf("Received %+v from %s", msg, link.Remote()) - // get the socket based on the tunnel id and session - // this could be something we dialed in which case - // we have a session for it otherwise its a listener - s, exists = t.getSocket(id, session) - if !exists { - // try get it based on just the tunnel id - // the assumption here is that a listener - // has no session but its set a listener session - s, exists = t.getSocket(id, "listener") - } - // no socket in existence + switch { + case loopback: + s, exists = t.getSocket(id, "listener") + default: + // get the socket based on the tunnel id and session + // this could be something we dialed in which case + // we have a session for it otherwise its a listener + s, exists = t.getSocket(id, session) + } + // bail if no socket has been found if !exists { log.Debugf("Tunnel skipping no socket exists") // drop it, we don't care about @@ -277,7 +293,7 @@ func (t *tun) connect() error { }() // listen for inbound messages - t.listen(sock, true) + t.listen(sock) }) t.Lock() @@ -306,14 +322,15 @@ func (t *tun) connect() error { if err := c.Send(&transport.Message{ Header: map[string]string{ - "Micro-Tunnel": "connect", + "Micro-Tunnel": "connect", + "Micro-Tunnel-Token": t.token, }, }); err != nil { continue } // process incoming messages - go t.listen(c, false) + go t.listen(c) // save the link id := uuid.New().String()