diff --git a/tunnel/default.go b/tunnel/default.go index 5d27809a..add48eb8 100644 --- a/tunnel/default.go +++ b/tunnel/default.go @@ -220,14 +220,6 @@ func (t *tun) process() { continue } - // if we're picking the link check the id - // this is where we explicitly set the link - // in a message received via the listen method - if len(msg.link) > 0 && link.id != msg.link { - err = errors.New("link not found") - continue - } - // if the link was a loopback accepted connection // and the message is being sent outbound via // a dialled connection don't use this link @@ -252,6 +244,14 @@ func (t *tun) process() { if !ok { continue } + } else { + // if we're picking the link check the id + // this is where we explicitly set the link + // in a message received via the listen method + if len(msg.link) > 0 && link.id != msg.link { + err = errors.New("link not found") + continue + } } // send the message via the current link @@ -364,6 +364,7 @@ func (t *tun) listen(link *link) { case "connect": log.Debugf("Tunnel link %s received connect message", link.Remote()) + link.Lock() // are we connecting to ourselves? if id == t.id { link.loopback = true @@ -374,6 +375,7 @@ func (t *tun) listen(link *link) { link.id = id // set as connected link.connected = true + link.Unlock() // save the link once connected t.Lock() @@ -417,10 +419,19 @@ func (t *tun) listen(link *link) { continue // a new connection dialled outbound case "open": + log.Debugf("Tunnel link %s received open %s %s", link.id, channel, sessionId) // we just let it pass through to be processed // an accept returned by the listener case "accept": - + s, exists := t.getSession(channel, sessionId) + // we don't need this + if exists && s.multicast { + s.accepted = true + continue + } + if exists && s.accepted { + continue + } // a continued session case "session": // process message @@ -725,6 +736,12 @@ func (t *tun) Connect() error { } func (t *tun) close() error { + // close all the sessions + for id, s := range t.sessions { + s.Close() + delete(t.sessions, id) + } + // close all the links for node, link := range t.links { link.Send(&transport.Message{ @@ -768,12 +785,6 @@ func (t *tun) Close() error { case <-t.closed: return nil default: - // close all the sessions - for id, s := range t.sessions { - s.Close() - delete(t.sessions, id) - } - // close the connection close(t.closed) t.connected = false @@ -814,11 +825,16 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { // set the dial timeout c.timeout = options.Timeout - // don't bother with the song and dance below - // we're just going to assume things come online - // as and when. - if c.multicast { - return c, nil + now := time.Now() + + after := func() time.Duration { + d := time.Since(now) + // dial timeout minus time since + wait := options.Timeout - d + if wait < time.Duration(0) { + return time.Duration(0) + } + return wait } // non multicast so we need to find the link @@ -846,7 +862,17 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { // send the discovery message t.send <- msg + // don't bother waiting around + // we're just going to assume things come online + if c.multicast { + c.discovered = true + c.accepted = true + return c, nil + } + select { + case <-time.After(after()): + return nil, ErrDialTimeout case err := <-c.errChan: if err != nil { return nil, err @@ -859,6 +885,8 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { if msg.typ != "announce" { return nil, errors.New("failed to discover channel") } + case <-time.After(after()): + return nil, ErrDialTimeout } } diff --git a/tunnel/link.go b/tunnel/link.go index 9470ef6a..38c80825 100644 --- a/tunnel/link.go +++ b/tunnel/link.go @@ -97,7 +97,7 @@ func (l *link) Close() error { return nil default: close(l.closed) - return l.Socket.Close() + return nil } return nil