Add some fixes
This commit is contained in:
		| @@ -220,14 +220,6 @@ func (t *tun) process() { | |||||||
| 					continue | 					continue | ||||||
| 				} | 				} | ||||||
|  |  | ||||||
| 				// if we're picking the link check the id |  | ||||||
| 				// this is where we explicitly set the link |  | ||||||
| 				// in a message received via the listen method |  | ||||||
| 				if len(msg.link) > 0 && link.id != msg.link { |  | ||||||
| 					err = errors.New("link not found") |  | ||||||
| 					continue |  | ||||||
| 				} |  | ||||||
|  |  | ||||||
| 				// if the link was a loopback accepted connection | 				// if the link was a loopback accepted connection | ||||||
| 				// and the message is being sent outbound via | 				// and the message is being sent outbound via | ||||||
| 				// a dialled connection don't use this link | 				// a dialled connection don't use this link | ||||||
| @@ -252,6 +244,14 @@ func (t *tun) process() { | |||||||
| 					if !ok { | 					if !ok { | ||||||
| 						continue | 						continue | ||||||
| 					} | 					} | ||||||
|  | 				} else { | ||||||
|  | 					// if we're picking the link check the id | ||||||
|  | 					// this is where we explicitly set the link | ||||||
|  | 					// in a message received via the listen method | ||||||
|  | 					if len(msg.link) > 0 && link.id != msg.link { | ||||||
|  | 						err = errors.New("link not found") | ||||||
|  | 						continue | ||||||
|  | 					} | ||||||
| 				} | 				} | ||||||
|  |  | ||||||
| 				// send the message via the current link | 				// send the message via the current link | ||||||
| @@ -364,6 +364,7 @@ func (t *tun) listen(link *link) { | |||||||
| 		case "connect": | 		case "connect": | ||||||
| 			log.Debugf("Tunnel link %s received connect message", link.Remote()) | 			log.Debugf("Tunnel link %s received connect message", link.Remote()) | ||||||
|  |  | ||||||
|  | 			link.Lock() | ||||||
| 			// are we connecting to ourselves? | 			// are we connecting to ourselves? | ||||||
| 			if id == t.id { | 			if id == t.id { | ||||||
| 				link.loopback = true | 				link.loopback = true | ||||||
| @@ -374,6 +375,7 @@ func (t *tun) listen(link *link) { | |||||||
| 			link.id = id | 			link.id = id | ||||||
| 			// set as connected | 			// set as connected | ||||||
| 			link.connected = true | 			link.connected = true | ||||||
|  | 			link.Unlock() | ||||||
|  |  | ||||||
| 			// save the link once connected | 			// save the link once connected | ||||||
| 			t.Lock() | 			t.Lock() | ||||||
| @@ -417,10 +419,19 @@ func (t *tun) listen(link *link) { | |||||||
| 			continue | 			continue | ||||||
| 		// a new connection dialled outbound | 		// a new connection dialled outbound | ||||||
| 		case "open": | 		case "open": | ||||||
|  | 			log.Debugf("Tunnel link %s received open %s %s", link.id, channel, sessionId) | ||||||
| 			// we just let it pass through to be processed | 			// we just let it pass through to be processed | ||||||
| 		// an accept returned by the listener | 		// an accept returned by the listener | ||||||
| 		case "accept": | 		case "accept": | ||||||
|  | 			s, exists := t.getSession(channel, sessionId) | ||||||
|  | 			// we don't need this | ||||||
|  | 			if exists && s.multicast { | ||||||
|  | 				s.accepted = true | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			if exists && s.accepted { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
| 		// a continued session | 		// a continued session | ||||||
| 		case "session": | 		case "session": | ||||||
| 			// process message | 			// process message | ||||||
| @@ -725,6 +736,12 @@ func (t *tun) Connect() error { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (t *tun) close() error { | func (t *tun) close() error { | ||||||
|  | 	// close all the sessions | ||||||
|  | 	for id, s := range t.sessions { | ||||||
|  | 		s.Close() | ||||||
|  | 		delete(t.sessions, id) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// close all the links | 	// close all the links | ||||||
| 	for node, link := range t.links { | 	for node, link := range t.links { | ||||||
| 		link.Send(&transport.Message{ | 		link.Send(&transport.Message{ | ||||||
| @@ -768,12 +785,6 @@ func (t *tun) Close() error { | |||||||
| 	case <-t.closed: | 	case <-t.closed: | ||||||
| 		return nil | 		return nil | ||||||
| 	default: | 	default: | ||||||
| 		// close all the sessions |  | ||||||
| 		for id, s := range t.sessions { |  | ||||||
| 			s.Close() |  | ||||||
| 			delete(t.sessions, id) |  | ||||||
| 		} |  | ||||||
| 		// close the connection |  | ||||||
| 		close(t.closed) | 		close(t.closed) | ||||||
| 		t.connected = false | 		t.connected = false | ||||||
|  |  | ||||||
| @@ -814,11 +825,16 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { | |||||||
| 	// set the dial timeout | 	// set the dial timeout | ||||||
| 	c.timeout = options.Timeout | 	c.timeout = options.Timeout | ||||||
|  |  | ||||||
| 	// don't bother with the song and dance below | 	now := time.Now() | ||||||
| 	// we're just going to assume things come online |  | ||||||
| 	// as and when. | 	after := func() time.Duration { | ||||||
| 	if c.multicast { | 		d := time.Since(now) | ||||||
| 		return c, nil | 		// dial timeout minus time since | ||||||
|  | 		wait := options.Timeout - d | ||||||
|  | 		if wait < time.Duration(0) { | ||||||
|  | 			return time.Duration(0) | ||||||
|  | 		} | ||||||
|  | 		return wait | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// non multicast so we need to find the link | 	// non multicast so we need to find the link | ||||||
| @@ -846,7 +862,17 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { | |||||||
| 		// send the discovery message | 		// send the discovery message | ||||||
| 		t.send <- msg | 		t.send <- msg | ||||||
|  |  | ||||||
|  | 		// don't bother waiting around | ||||||
|  | 		// we're just going to assume things come online | ||||||
|  | 		if c.multicast { | ||||||
|  | 			c.discovered = true | ||||||
|  | 			c.accepted = true | ||||||
|  | 			return c, nil | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		select { | 		select { | ||||||
|  | 		case <-time.After(after()): | ||||||
|  | 			return nil, ErrDialTimeout | ||||||
| 		case err := <-c.errChan: | 		case err := <-c.errChan: | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return nil, err | 				return nil, err | ||||||
| @@ -859,6 +885,8 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { | |||||||
| 			if msg.typ != "announce" { | 			if msg.typ != "announce" { | ||||||
| 				return nil, errors.New("failed to discover channel") | 				return nil, errors.New("failed to discover channel") | ||||||
| 			} | 			} | ||||||
|  | 		case <-time.After(after()): | ||||||
|  | 			return nil, ErrDialTimeout | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|   | |||||||
| @@ -97,7 +97,7 @@ func (l *link) Close() error { | |||||||
| 		return nil | 		return nil | ||||||
| 	default: | 	default: | ||||||
| 		close(l.closed) | 		close(l.closed) | ||||||
| 		return l.Socket.Close() | 		return nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return nil | 	return nil | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user