Add some fixes

This commit is contained in:
Asim Aslam 2019-09-04 18:46:20 +01:00
parent 46a9767648
commit b9a2f719a0
2 changed files with 49 additions and 21 deletions

View File

@ -220,14 +220,6 @@ func (t *tun) process() {
continue
}
// if we're picking the link check the id
// this is where we explicitly set the link
// in a message received via the listen method
if len(msg.link) > 0 && link.id != msg.link {
err = errors.New("link not found")
continue
}
// if the link was a loopback accepted connection
// and the message is being sent outbound via
// a dialled connection don't use this link
@ -252,6 +244,14 @@ func (t *tun) process() {
if !ok {
continue
}
} else {
// if we're picking the link check the id
// this is where we explicitly set the link
// in a message received via the listen method
if len(msg.link) > 0 && link.id != msg.link {
err = errors.New("link not found")
continue
}
}
// send the message via the current link
@ -364,6 +364,7 @@ func (t *tun) listen(link *link) {
case "connect":
log.Debugf("Tunnel link %s received connect message", link.Remote())
link.Lock()
// are we connecting to ourselves?
if id == t.id {
link.loopback = true
@ -374,6 +375,7 @@ func (t *tun) listen(link *link) {
link.id = id
// set as connected
link.connected = true
link.Unlock()
// save the link once connected
t.Lock()
@ -417,10 +419,19 @@ func (t *tun) listen(link *link) {
continue
// a new connection dialled outbound
case "open":
log.Debugf("Tunnel link %s received open %s %s", link.id, channel, sessionId)
// we just let it pass through to be processed
// an accept returned by the listener
case "accept":
s, exists := t.getSession(channel, sessionId)
// we don't need this
if exists && s.multicast {
s.accepted = true
continue
}
if exists && s.accepted {
continue
}
// a continued session
case "session":
// process message
@ -725,6 +736,12 @@ func (t *tun) Connect() error {
}
func (t *tun) close() error {
// close all the sessions
for id, s := range t.sessions {
s.Close()
delete(t.sessions, id)
}
// close all the links
for node, link := range t.links {
link.Send(&transport.Message{
@ -768,12 +785,6 @@ func (t *tun) Close() error {
case <-t.closed:
return nil
default:
// close all the sessions
for id, s := range t.sessions {
s.Close()
delete(t.sessions, id)
}
// close the connection
close(t.closed)
t.connected = false
@ -814,11 +825,16 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
// set the dial timeout
c.timeout = options.Timeout
// don't bother with the song and dance below
// we're just going to assume things come online
// as and when.
if c.multicast {
return c, nil
now := time.Now()
after := func() time.Duration {
d := time.Since(now)
// dial timeout minus time since
wait := options.Timeout - d
if wait < time.Duration(0) {
return time.Duration(0)
}
return wait
}
// non multicast so we need to find the link
@ -846,7 +862,17 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
// send the discovery message
t.send <- msg
// don't bother waiting around
// we're just going to assume things come online
if c.multicast {
c.discovered = true
c.accepted = true
return c, nil
}
select {
case <-time.After(after()):
return nil, ErrDialTimeout
case err := <-c.errChan:
if err != nil {
return nil, err
@ -859,6 +885,8 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
if msg.typ != "announce" {
return nil, errors.New("failed to discover channel")
}
case <-time.After(after()):
return nil, ErrDialTimeout
}
}

View File

@ -97,7 +97,7 @@ func (l *link) Close() error {
return nil
default:
close(l.closed)
return l.Socket.Close()
return nil
}
return nil