Add some fixes

This commit is contained in:
Asim Aslam 2019-09-04 18:46:20 +01:00
parent 46a9767648
commit b9a2f719a0
2 changed files with 49 additions and 21 deletions

View File

@ -220,14 +220,6 @@ func (t *tun) process() {
continue continue
} }
// if we're picking the link check the id
// this is where we explicitly set the link
// in a message received via the listen method
if len(msg.link) > 0 && link.id != msg.link {
err = errors.New("link not found")
continue
}
// if the link was a loopback accepted connection // if the link was a loopback accepted connection
// and the message is being sent outbound via // and the message is being sent outbound via
// a dialled connection don't use this link // a dialled connection don't use this link
@ -252,6 +244,14 @@ func (t *tun) process() {
if !ok { if !ok {
continue continue
} }
} else {
// if we're picking the link check the id
// this is where we explicitly set the link
// in a message received via the listen method
if len(msg.link) > 0 && link.id != msg.link {
err = errors.New("link not found")
continue
}
} }
// send the message via the current link // send the message via the current link
@ -364,6 +364,7 @@ func (t *tun) listen(link *link) {
case "connect": case "connect":
log.Debugf("Tunnel link %s received connect message", link.Remote()) log.Debugf("Tunnel link %s received connect message", link.Remote())
link.Lock()
// are we connecting to ourselves? // are we connecting to ourselves?
if id == t.id { if id == t.id {
link.loopback = true link.loopback = true
@ -374,6 +375,7 @@ func (t *tun) listen(link *link) {
link.id = id link.id = id
// set as connected // set as connected
link.connected = true link.connected = true
link.Unlock()
// save the link once connected // save the link once connected
t.Lock() t.Lock()
@ -417,10 +419,19 @@ func (t *tun) listen(link *link) {
continue continue
// a new connection dialled outbound // a new connection dialled outbound
case "open": case "open":
log.Debugf("Tunnel link %s received open %s %s", link.id, channel, sessionId)
// we just let it pass through to be processed // we just let it pass through to be processed
// an accept returned by the listener // an accept returned by the listener
case "accept": case "accept":
s, exists := t.getSession(channel, sessionId)
// we don't need this
if exists && s.multicast {
s.accepted = true
continue
}
if exists && s.accepted {
continue
}
// a continued session // a continued session
case "session": case "session":
// process message // process message
@ -725,6 +736,12 @@ func (t *tun) Connect() error {
} }
func (t *tun) close() error { func (t *tun) close() error {
// close all the sessions
for id, s := range t.sessions {
s.Close()
delete(t.sessions, id)
}
// close all the links // close all the links
for node, link := range t.links { for node, link := range t.links {
link.Send(&transport.Message{ link.Send(&transport.Message{
@ -768,12 +785,6 @@ func (t *tun) Close() error {
case <-t.closed: case <-t.closed:
return nil return nil
default: default:
// close all the sessions
for id, s := range t.sessions {
s.Close()
delete(t.sessions, id)
}
// close the connection
close(t.closed) close(t.closed)
t.connected = false t.connected = false
@ -814,11 +825,16 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
// set the dial timeout // set the dial timeout
c.timeout = options.Timeout c.timeout = options.Timeout
// don't bother with the song and dance below now := time.Now()
// we're just going to assume things come online
// as and when. after := func() time.Duration {
if c.multicast { d := time.Since(now)
return c, nil // dial timeout minus time since
wait := options.Timeout - d
if wait < time.Duration(0) {
return time.Duration(0)
}
return wait
} }
// non multicast so we need to find the link // non multicast so we need to find the link
@ -846,7 +862,17 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
// send the discovery message // send the discovery message
t.send <- msg t.send <- msg
// don't bother waiting around
// we're just going to assume things come online
if c.multicast {
c.discovered = true
c.accepted = true
return c, nil
}
select { select {
case <-time.After(after()):
return nil, ErrDialTimeout
case err := <-c.errChan: case err := <-c.errChan:
if err != nil { if err != nil {
return nil, err return nil, err
@ -859,6 +885,8 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
if msg.typ != "announce" { if msg.typ != "announce" {
return nil, errors.New("failed to discover channel") return nil, errors.New("failed to discover channel")
} }
case <-time.After(after()):
return nil, ErrDialTimeout
} }
} }

View File

@ -97,7 +97,7 @@ func (l *link) Close() error {
return nil return nil
default: default:
close(l.closed) close(l.closed)
return l.Socket.Close() return nil
} }
return nil return nil