Use best link in tunnel, loop waiting for announce and accept messages, cleanup some code
This commit is contained in:
parent
f26d470db1
commit
3831199600
@ -90,6 +90,7 @@ func (t *tun) getSession(channel, session string) (*session, bool) {
|
||||
return s, ok
|
||||
}
|
||||
|
||||
// delSession deletes a session if it exists
|
||||
func (t *tun) delSession(channel, session string) {
|
||||
t.Lock()
|
||||
delete(t.sessions, channel+session)
|
||||
@ -146,6 +147,9 @@ func (t *tun) newSessionId() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
// announce will send a message to the link to tell the other side of a channel mapping we have.
|
||||
// This usually happens if someone calls Dial and sends a discover message but otherwise we
|
||||
// periodically send these messages to asynchronously manage channel mappings.
|
||||
func (t *tun) announce(channel, session string, link *link) {
|
||||
// create the "announce" response message for a discover request
|
||||
msg := &transport.Message{
|
||||
@ -206,7 +210,7 @@ func (t *tun) monitor() {
|
||||
// check the link status and purge dead links
|
||||
for node, link := range t.links {
|
||||
// check link status
|
||||
switch link.Status() {
|
||||
switch link.State() {
|
||||
case "closed":
|
||||
delLinks = append(delLinks, node)
|
||||
case "error":
|
||||
@ -303,8 +307,16 @@ func (t *tun) process() {
|
||||
|
||||
// build the list of links ot send to
|
||||
for node, link := range t.links {
|
||||
// get the values we need
|
||||
link.RLock()
|
||||
id := link.id
|
||||
connected := link.connected
|
||||
loopback := link.loopback
|
||||
_, exists := link.channels[msg.channel]
|
||||
link.RUnlock()
|
||||
|
||||
// if the link is not connected skip it
|
||||
if !link.connected {
|
||||
if !connected {
|
||||
log.Debugf("Link for node %s not connected", node)
|
||||
err = errors.New("link not connected")
|
||||
continue
|
||||
@ -313,32 +325,29 @@ func (t *tun) process() {
|
||||
// if the link was a loopback accepted connection
|
||||
// and the message is being sent outbound via
|
||||
// a dialled connection don't use this link
|
||||
if link.loopback && msg.outbound {
|
||||
if loopback && msg.outbound {
|
||||
err = errors.New("link is loopback")
|
||||
continue
|
||||
}
|
||||
|
||||
// if the message was being returned by the loopback listener
|
||||
// send it back up the loopback link only
|
||||
if msg.loopback && !link.loopback {
|
||||
if msg.loopback && !loopback {
|
||||
err = errors.New("link is not loopback")
|
||||
continue
|
||||
}
|
||||
|
||||
// check the multicast mappings
|
||||
if msg.mode == Multicast {
|
||||
link.RLock()
|
||||
_, ok := link.channels[msg.channel]
|
||||
link.RUnlock()
|
||||
// channel mapping not found in link
|
||||
if !ok {
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
// if we're picking the link check the id
|
||||
// this is where we explicitly set the link
|
||||
// in a message received via the listen method
|
||||
if len(msg.link) > 0 && link.id != msg.link {
|
||||
if len(msg.link) > 0 && id != msg.link {
|
||||
err = errors.New("link not found")
|
||||
continue
|
||||
}
|
||||
@ -422,6 +431,12 @@ func (t *tun) listen(link *link) {
|
||||
|
||||
// let us know if its a loopback
|
||||
var loopback bool
|
||||
var connected bool
|
||||
|
||||
// set the connected value
|
||||
link.RLock()
|
||||
connected = link.connected
|
||||
link.RUnlock()
|
||||
|
||||
for {
|
||||
// process anything via the net interface
|
||||
@ -451,7 +466,7 @@ func (t *tun) listen(link *link) {
|
||||
|
||||
// if its not connected throw away the link
|
||||
// the first message we process needs to be connect
|
||||
if !link.connected && mtype != "connect" {
|
||||
if !connected && mtype != "connect" {
|
||||
log.Debugf("Tunnel link %s not connected", link.id)
|
||||
return
|
||||
}
|
||||
@ -461,7 +476,8 @@ func (t *tun) listen(link *link) {
|
||||
log.Debugf("Tunnel link %s received connect message", link.Remote())
|
||||
|
||||
link.Lock()
|
||||
// are we connecting to ourselves?
|
||||
|
||||
// check if we're connecting to ourselves?
|
||||
if id == t.id {
|
||||
link.loopback = true
|
||||
loopback = true
|
||||
@ -471,6 +487,8 @@ func (t *tun) listen(link *link) {
|
||||
link.id = link.Remote()
|
||||
// set as connected
|
||||
link.connected = true
|
||||
connected = true
|
||||
|
||||
link.Unlock()
|
||||
|
||||
// save the link once connected
|
||||
@ -494,9 +512,7 @@ func (t *tun) listen(link *link) {
|
||||
|
||||
// the entire listener was closed so remove it from the mapping
|
||||
if sessionId == "listener" {
|
||||
link.Lock()
|
||||
delete(link.channels, channel)
|
||||
link.Unlock()
|
||||
link.delChannel(channel)
|
||||
continue
|
||||
}
|
||||
|
||||
@ -510,10 +526,8 @@ func (t *tun) listen(link *link) {
|
||||
// otherwise its a session mapping of sorts
|
||||
case "keepalive":
|
||||
log.Debugf("Tunnel link %s received keepalive", link.Remote())
|
||||
link.Lock()
|
||||
// save the keepalive
|
||||
link.lastKeepAlive = time.Now()
|
||||
link.Unlock()
|
||||
link.keepalive()
|
||||
continue
|
||||
// a new connection dialled outbound
|
||||
case "open":
|
||||
@ -540,11 +554,7 @@ func (t *tun) listen(link *link) {
|
||||
channels := strings.Split(channel, ",")
|
||||
|
||||
// update mapping in the link
|
||||
link.Lock()
|
||||
for _, channel := range channels {
|
||||
link.channels[channel] = time.Now()
|
||||
}
|
||||
link.Unlock()
|
||||
link.setChannel(channels...)
|
||||
|
||||
// this was an announcement not intended for anything
|
||||
if sessionId == "listener" || sessionId == "" {
|
||||
@ -904,6 +914,53 @@ func (t *tun) close() error {
|
||||
return t.listener.Close()
|
||||
}
|
||||
|
||||
// pickLink will pick the best link based on connectivity, delay, rate and length
|
||||
func (t *tun) pickLink(links []*link) *link {
|
||||
var metric float64
|
||||
var chosen *link
|
||||
|
||||
// find the best link
|
||||
for i, link := range links {
|
||||
// don't use disconnected or errored links
|
||||
if link.State() != "connected" {
|
||||
continue
|
||||
}
|
||||
|
||||
// get the link state info
|
||||
d := float64(link.Delay())
|
||||
l := float64(link.Length())
|
||||
r := link.Rate()
|
||||
|
||||
// metric = delay x length x rate
|
||||
m := d * l * r
|
||||
|
||||
// first link so just and go
|
||||
if i == 0 {
|
||||
metric = m
|
||||
chosen = link
|
||||
continue
|
||||
}
|
||||
|
||||
// we found a better metric
|
||||
if m < metric {
|
||||
metric = m
|
||||
chosen = link
|
||||
}
|
||||
}
|
||||
|
||||
// if there's no link we're just going to mess around
|
||||
if chosen == nil {
|
||||
i := rand.Intn(len(links))
|
||||
return links[i]
|
||||
}
|
||||
|
||||
// we chose the link with;
|
||||
// the lowest delay e.g least messages queued
|
||||
// the lowest rate e.g the least messages flowing
|
||||
// the lowest length e.g the smallest roundtrip time
|
||||
return chosen
|
||||
}
|
||||
|
||||
func (t *tun) Address() string {
|
||||
t.RLock()
|
||||
defer t.RUnlock()
|
||||
@ -967,42 +1024,32 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
|
||||
c.mode = options.Mode
|
||||
// set the dial timeout
|
||||
c.timeout = options.Timeout
|
||||
// get the current time
|
||||
now := time.Now()
|
||||
|
||||
after := func() time.Duration {
|
||||
d := time.Since(now)
|
||||
// dial timeout minus time since
|
||||
wait := options.Timeout - d
|
||||
if wait < time.Duration(0) {
|
||||
return time.Duration(0)
|
||||
}
|
||||
return wait
|
||||
}
|
||||
|
||||
var links []string
|
||||
var links []*link
|
||||
// did we measure the rtt
|
||||
var measured bool
|
||||
|
||||
// non multicast so we need to find the link
|
||||
t.RLock()
|
||||
|
||||
// non multicast so we need to find the link
|
||||
for _, link := range t.links {
|
||||
// use the link specified it its available
|
||||
if id := options.Link; len(id) > 0 && link.id != id {
|
||||
continue
|
||||
}
|
||||
|
||||
link.RLock()
|
||||
_, ok := link.channels[channel]
|
||||
link.RUnlock()
|
||||
// get the channel
|
||||
lastMapped := link.getChannel(channel)
|
||||
|
||||
// we have at least one channel mapping
|
||||
if ok {
|
||||
if !lastMapped.IsZero() {
|
||||
links = append(links, link)
|
||||
c.discovered = true
|
||||
links = append(links, link.id)
|
||||
}
|
||||
}
|
||||
|
||||
t.RUnlock()
|
||||
|
||||
// link not found
|
||||
if len(links) == 0 && len(options.Link) > 0 {
|
||||
// delete session and return error
|
||||
@ -1015,9 +1062,9 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
|
||||
// TODO: pick the link efficiently based
|
||||
// on link status and saturation.
|
||||
if c.discovered && c.mode == Unicast {
|
||||
// set the link
|
||||
i := rand.Intn(len(links))
|
||||
c.link = links[i]
|
||||
// pickLink will pick the best link
|
||||
link := t.pickLink(links)
|
||||
c.link = link.id
|
||||
}
|
||||
|
||||
// shit fuck
|
||||
@ -1025,57 +1072,8 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
|
||||
// piggy back roundtrip
|
||||
nowRTT := time.Now()
|
||||
|
||||
// create a new discovery message for this channel
|
||||
msg := c.newMessage("discover")
|
||||
msg.mode = Broadcast
|
||||
msg.outbound = true
|
||||
msg.link = ""
|
||||
|
||||
// send the discovery message
|
||||
t.send <- msg
|
||||
|
||||
select {
|
||||
case <-time.After(after()):
|
||||
t.delSession(c.channel, c.session)
|
||||
log.Debugf("Tunnel deleting session %s %s: %v", c.session, c.channel, ErrDialTimeout)
|
||||
return nil, ErrDialTimeout
|
||||
case err := <-c.errChan:
|
||||
if err != nil {
|
||||
t.delSession(c.channel, c.session)
|
||||
log.Debugf("Tunnel deleting session %s %s: %v", c.session, c.channel, err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
// set a dialTimeout
|
||||
dialTimeout := after()
|
||||
|
||||
// set a shorter delay for multicast
|
||||
if c.mode != Unicast {
|
||||
// shorten this
|
||||
dialTimeout = time.Millisecond * 500
|
||||
}
|
||||
|
||||
// wait for announce
|
||||
select {
|
||||
case msg := <-c.recv:
|
||||
if msg.typ != "announce" {
|
||||
err = ErrDiscoverChan
|
||||
}
|
||||
case <-time.After(dialTimeout):
|
||||
err = ErrDialTimeout
|
||||
}
|
||||
|
||||
// if its multicast just go ahead because this is best effort
|
||||
if c.mode != Unicast {
|
||||
c.discovered = true
|
||||
c.accepted = true
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// otherwise return an error
|
||||
// attempt to discover the link
|
||||
err := c.Discover()
|
||||
if err != nil {
|
||||
t.delSession(c.channel, c.session)
|
||||
log.Debugf("Tunnel deleting session %s %s: %v", c.session, c.channel, err)
|
||||
@ -1096,34 +1094,34 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
|
||||
// set measured to true
|
||||
measured = true
|
||||
}
|
||||
|
||||
// set discovered to true
|
||||
c.discovered = true
|
||||
}
|
||||
|
||||
// a unicast session so we call "open" and wait for an "accept"
|
||||
|
||||
// reset now in case we use it
|
||||
now = time.Now()
|
||||
now := time.Now()
|
||||
|
||||
// try to open the session
|
||||
err := c.Open()
|
||||
if err != nil {
|
||||
if err := c.Open(); err != nil {
|
||||
// delete the session
|
||||
t.delSession(c.channel, c.session)
|
||||
log.Debugf("Tunnel deleting session %s %s: %v", c.session, c.channel, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// set time take to open
|
||||
d := time.Since(now)
|
||||
|
||||
// if we haven't measured the roundtrip do it now
|
||||
if !measured && c.mode == Unicast {
|
||||
// set the link time
|
||||
t.RLock()
|
||||
link, ok := t.links[c.link]
|
||||
t.RUnlock()
|
||||
|
||||
if ok {
|
||||
// set the rountrip time
|
||||
link.setRTT(time.Since(now))
|
||||
link.setRTT(d)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -17,9 +17,11 @@ type link struct {
|
||||
sync.RWMutex
|
||||
// stops the link
|
||||
closed chan bool
|
||||
// send queue
|
||||
// link state channel for testing link
|
||||
state chan *packet
|
||||
// send queue for sending packets
|
||||
sendQueue chan *packet
|
||||
// receive queue
|
||||
// receive queue for receiving packets
|
||||
recvQueue chan *packet
|
||||
// unique id of this link e.g uuid
|
||||
// which we define for ourselves
|
||||
@ -44,9 +46,6 @@ type link struct {
|
||||
rate float64
|
||||
// keep an error count on the link
|
||||
errCount int
|
||||
|
||||
// link state channel
|
||||
state chan *packet
|
||||
}
|
||||
|
||||
// packet send over link
|
||||
@ -73,9 +72,9 @@ func newLink(s transport.Socket) *link {
|
||||
Socket: s,
|
||||
id: uuid.New().String(),
|
||||
lastKeepAlive: time.Now(),
|
||||
channels: make(map[string]time.Time),
|
||||
closed: make(chan bool),
|
||||
state: make(chan *packet, 64),
|
||||
channels: make(map[string]time.Time),
|
||||
sendQueue: make(chan *packet, 128),
|
||||
recvQueue: make(chan *packet, 128),
|
||||
}
|
||||
@ -119,6 +118,33 @@ func (l *link) setRTT(d time.Duration) {
|
||||
l.length = int64(length)
|
||||
}
|
||||
|
||||
func (l *link) delChannel(ch string) {
|
||||
l.Lock()
|
||||
delete(l.channels, ch)
|
||||
l.Unlock()
|
||||
}
|
||||
|
||||
func (l *link) setChannel(channels ...string) {
|
||||
l.Lock()
|
||||
for _, ch := range channels {
|
||||
l.channels[ch] = time.Now()
|
||||
}
|
||||
l.Unlock()
|
||||
}
|
||||
|
||||
func (l *link) getChannel(ch string) time.Time {
|
||||
l.RLock()
|
||||
defer l.RUnlock()
|
||||
return l.channels[ch]
|
||||
}
|
||||
|
||||
// set the keepalive time
|
||||
func (l *link) keepalive() {
|
||||
l.Lock()
|
||||
l.lastKeepAlive = time.Now()
|
||||
l.Unlock()
|
||||
}
|
||||
|
||||
// process deals with the send queue
|
||||
func (l *link) process() {
|
||||
// receive messages
|
||||
@ -176,8 +202,8 @@ func (l *link) manage() {
|
||||
defer t.Stop()
|
||||
|
||||
// used to send link state packets
|
||||
send := func(b []byte) {
|
||||
l.Send(&transport.Message{
|
||||
send := func(b []byte) error {
|
||||
return l.Send(&transport.Message{
|
||||
Header: map[string]string{
|
||||
"Micro-Method": "link",
|
||||
}, Body: b,
|
||||
@ -205,7 +231,11 @@ func (l *link) manage() {
|
||||
case bytes.Compare(p.message.Body, linkRequest) == 0:
|
||||
log.Tracef("Link %s received link request %v", l.id, p.message.Body)
|
||||
// send response
|
||||
send(linkResponse)
|
||||
if err := send(linkResponse); err != nil {
|
||||
l.Lock()
|
||||
l.errCount++
|
||||
l.Unlock()
|
||||
}
|
||||
case bytes.Compare(p.message.Body, linkResponse) == 0:
|
||||
// set round trip time
|
||||
d := time.Since(now)
|
||||
@ -270,7 +300,6 @@ func (l *link) Delay() int64 {
|
||||
func (l *link) Rate() float64 {
|
||||
l.RLock()
|
||||
defer l.RUnlock()
|
||||
|
||||
return l.rate
|
||||
}
|
||||
|
||||
@ -279,7 +308,6 @@ func (l *link) Rate() float64 {
|
||||
func (l *link) Length() int64 {
|
||||
l.RLock()
|
||||
defer l.RUnlock()
|
||||
|
||||
return l.length
|
||||
}
|
||||
|
||||
@ -398,8 +426,8 @@ func (l *link) Recv(m *transport.Message) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Status can return connected, closed, error
|
||||
func (l *link) Status() string {
|
||||
// State can return connected, closed, error
|
||||
func (l *link) State() string {
|
||||
select {
|
||||
case <-l.closed:
|
||||
return "closed"
|
||||
|
@ -106,6 +106,112 @@ func (s *session) newMessage(typ string) *message {
|
||||
}
|
||||
}
|
||||
|
||||
// waitFor waits for the message type required until the timeout specified
|
||||
func (s *session) waitFor(msgType string, timeout time.Duration) error {
|
||||
now := time.Now()
|
||||
|
||||
after := func() time.Duration {
|
||||
d := time.Since(now)
|
||||
// dial timeout minus time since
|
||||
wait := timeout - d
|
||||
|
||||
if wait < time.Duration(0) {
|
||||
return time.Duration(0)
|
||||
}
|
||||
|
||||
return wait
|
||||
}
|
||||
|
||||
// wait for the message type
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case msg := <-s.recv:
|
||||
// ignore what we don't want
|
||||
if msg.typ != msgType {
|
||||
log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType)
|
||||
continue
|
||||
}
|
||||
|
||||
// got the message
|
||||
break loop
|
||||
case <-time.After(after()):
|
||||
return ErrDialTimeout
|
||||
case <-s.closed:
|
||||
return io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Discover attempts to discover the link for a specific channel
|
||||
func (s *session) Discover() error {
|
||||
// create a new discovery message for this channel
|
||||
msg := s.newMessage("discover")
|
||||
msg.mode = Broadcast
|
||||
msg.outbound = true
|
||||
msg.link = ""
|
||||
|
||||
// send the discovery message
|
||||
s.send <- msg
|
||||
|
||||
// set time now
|
||||
now := time.Now()
|
||||
|
||||
after := func() time.Duration {
|
||||
d := time.Since(now)
|
||||
// dial timeout minus time since
|
||||
wait := s.timeout - d
|
||||
if wait < time.Duration(0) {
|
||||
return time.Duration(0)
|
||||
}
|
||||
return wait
|
||||
}
|
||||
|
||||
// wait to hear back about the sent message
|
||||
select {
|
||||
case <-time.After(after()):
|
||||
return ErrDialTimeout
|
||||
case err := <-s.errChan:
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
// set a new dialTimeout
|
||||
dialTimeout := after()
|
||||
|
||||
// set a shorter delay for multicast
|
||||
if s.mode != Unicast {
|
||||
// shorten this
|
||||
dialTimeout = time.Millisecond * 500
|
||||
}
|
||||
|
||||
// wait for announce
|
||||
if err := s.waitFor("announce", dialTimeout); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if its multicast just go ahead because this is best effort
|
||||
if s.mode != Unicast {
|
||||
s.discovered = true
|
||||
s.accepted = true
|
||||
return nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// set discovered
|
||||
s.discovered = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Open will fire the open message for the session. This is called by the dialler.
|
||||
func (s *session) Open() error {
|
||||
// create a new message
|
||||
@ -131,22 +237,15 @@ func (s *session) Open() error {
|
||||
}
|
||||
|
||||
// now wait for the accept
|
||||
select {
|
||||
case msg = <-s.recv:
|
||||
if msg.typ != "accept" {
|
||||
log.Debugf("Received non accept message in Open %s", msg.typ)
|
||||
return errors.New("failed to connect")
|
||||
}
|
||||
// set to accepted
|
||||
s.accepted = true
|
||||
// set link
|
||||
s.link = msg.link
|
||||
case <-time.After(s.timeout):
|
||||
return ErrDialTimeout
|
||||
case <-s.closed:
|
||||
return io.EOF
|
||||
if err := s.waitFor("accept", s.timeout); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// set to accepted
|
||||
s.accepted = true
|
||||
// set link
|
||||
s.link = msg.link
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -63,8 +63,8 @@ type Link interface {
|
||||
Length() int64
|
||||
// Current transfer rate as bits per second (lower is better)
|
||||
Rate() float64
|
||||
// Status of the link e.g connected/closed
|
||||
Status() string
|
||||
// State of the link e.g connected/closed
|
||||
State() string
|
||||
// honours transport socket
|
||||
transport.Socket
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user