Use best link in tunnel, loop waiting for announce and accept messages, cleanup some code

This commit is contained in:
Asim Aslam 2019-10-25 14:16:22 +01:00
parent f26d470db1
commit 3831199600
4 changed files with 256 additions and 131 deletions

View File

@ -90,6 +90,7 @@ func (t *tun) getSession(channel, session string) (*session, bool) {
return s, ok return s, ok
} }
// delSession deletes a session if it exists
func (t *tun) delSession(channel, session string) { func (t *tun) delSession(channel, session string) {
t.Lock() t.Lock()
delete(t.sessions, channel+session) delete(t.sessions, channel+session)
@ -146,6 +147,9 @@ func (t *tun) newSessionId() string {
return uuid.New().String() return uuid.New().String()
} }
// announce will send a message to the link to tell the other side of a channel mapping we have.
// This usually happens if someone calls Dial and sends a discover message but otherwise we
// periodically send these messages to asynchronously manage channel mappings.
func (t *tun) announce(channel, session string, link *link) { func (t *tun) announce(channel, session string, link *link) {
// create the "announce" response message for a discover request // create the "announce" response message for a discover request
msg := &transport.Message{ msg := &transport.Message{
@ -206,7 +210,7 @@ func (t *tun) monitor() {
// check the link status and purge dead links // check the link status and purge dead links
for node, link := range t.links { for node, link := range t.links {
// check link status // check link status
switch link.Status() { switch link.State() {
case "closed": case "closed":
delLinks = append(delLinks, node) delLinks = append(delLinks, node)
case "error": case "error":
@ -303,8 +307,16 @@ func (t *tun) process() {
// build the list of links ot send to // build the list of links ot send to
for node, link := range t.links { for node, link := range t.links {
// get the values we need
link.RLock()
id := link.id
connected := link.connected
loopback := link.loopback
_, exists := link.channels[msg.channel]
link.RUnlock()
// if the link is not connected skip it // if the link is not connected skip it
if !link.connected { if !connected {
log.Debugf("Link for node %s not connected", node) log.Debugf("Link for node %s not connected", node)
err = errors.New("link not connected") err = errors.New("link not connected")
continue continue
@ -313,32 +325,29 @@ func (t *tun) process() {
// if the link was a loopback accepted connection // if the link was a loopback accepted connection
// and the message is being sent outbound via // and the message is being sent outbound via
// a dialled connection don't use this link // a dialled connection don't use this link
if link.loopback && msg.outbound { if loopback && msg.outbound {
err = errors.New("link is loopback") err = errors.New("link is loopback")
continue continue
} }
// if the message was being returned by the loopback listener // if the message was being returned by the loopback listener
// send it back up the loopback link only // send it back up the loopback link only
if msg.loopback && !link.loopback { if msg.loopback && !loopback {
err = errors.New("link is not loopback") err = errors.New("link is not loopback")
continue continue
} }
// check the multicast mappings // check the multicast mappings
if msg.mode == Multicast { if msg.mode == Multicast {
link.RLock()
_, ok := link.channels[msg.channel]
link.RUnlock()
// channel mapping not found in link // channel mapping not found in link
if !ok { if !exists {
continue continue
} }
} else { } else {
// if we're picking the link check the id // if we're picking the link check the id
// this is where we explicitly set the link // this is where we explicitly set the link
// in a message received via the listen method // in a message received via the listen method
if len(msg.link) > 0 && link.id != msg.link { if len(msg.link) > 0 && id != msg.link {
err = errors.New("link not found") err = errors.New("link not found")
continue continue
} }
@ -422,6 +431,12 @@ func (t *tun) listen(link *link) {
// let us know if its a loopback // let us know if its a loopback
var loopback bool var loopback bool
var connected bool
// set the connected value
link.RLock()
connected = link.connected
link.RUnlock()
for { for {
// process anything via the net interface // process anything via the net interface
@ -451,7 +466,7 @@ func (t *tun) listen(link *link) {
// if its not connected throw away the link // if its not connected throw away the link
// the first message we process needs to be connect // the first message we process needs to be connect
if !link.connected && mtype != "connect" { if !connected && mtype != "connect" {
log.Debugf("Tunnel link %s not connected", link.id) log.Debugf("Tunnel link %s not connected", link.id)
return return
} }
@ -461,7 +476,8 @@ func (t *tun) listen(link *link) {
log.Debugf("Tunnel link %s received connect message", link.Remote()) log.Debugf("Tunnel link %s received connect message", link.Remote())
link.Lock() link.Lock()
// are we connecting to ourselves?
// check if we're connecting to ourselves?
if id == t.id { if id == t.id {
link.loopback = true link.loopback = true
loopback = true loopback = true
@ -471,6 +487,8 @@ func (t *tun) listen(link *link) {
link.id = link.Remote() link.id = link.Remote()
// set as connected // set as connected
link.connected = true link.connected = true
connected = true
link.Unlock() link.Unlock()
// save the link once connected // save the link once connected
@ -494,9 +512,7 @@ func (t *tun) listen(link *link) {
// the entire listener was closed so remove it from the mapping // the entire listener was closed so remove it from the mapping
if sessionId == "listener" { if sessionId == "listener" {
link.Lock() link.delChannel(channel)
delete(link.channels, channel)
link.Unlock()
continue continue
} }
@ -510,10 +526,8 @@ func (t *tun) listen(link *link) {
// otherwise its a session mapping of sorts // otherwise its a session mapping of sorts
case "keepalive": case "keepalive":
log.Debugf("Tunnel link %s received keepalive", link.Remote()) log.Debugf("Tunnel link %s received keepalive", link.Remote())
link.Lock()
// save the keepalive // save the keepalive
link.lastKeepAlive = time.Now() link.keepalive()
link.Unlock()
continue continue
// a new connection dialled outbound // a new connection dialled outbound
case "open": case "open":
@ -540,11 +554,7 @@ func (t *tun) listen(link *link) {
channels := strings.Split(channel, ",") channels := strings.Split(channel, ",")
// update mapping in the link // update mapping in the link
link.Lock() link.setChannel(channels...)
for _, channel := range channels {
link.channels[channel] = time.Now()
}
link.Unlock()
// this was an announcement not intended for anything // this was an announcement not intended for anything
if sessionId == "listener" || sessionId == "" { if sessionId == "listener" || sessionId == "" {
@ -904,6 +914,53 @@ func (t *tun) close() error {
return t.listener.Close() return t.listener.Close()
} }
// pickLink will pick the best link based on connectivity, delay, rate and length
func (t *tun) pickLink(links []*link) *link {
var metric float64
var chosen *link
// find the best link
for i, link := range links {
// don't use disconnected or errored links
if link.State() != "connected" {
continue
}
// get the link state info
d := float64(link.Delay())
l := float64(link.Length())
r := link.Rate()
// metric = delay x length x rate
m := d * l * r
// first link so just and go
if i == 0 {
metric = m
chosen = link
continue
}
// we found a better metric
if m < metric {
metric = m
chosen = link
}
}
// if there's no link we're just going to mess around
if chosen == nil {
i := rand.Intn(len(links))
return links[i]
}
// we chose the link with;
// the lowest delay e.g least messages queued
// the lowest rate e.g the least messages flowing
// the lowest length e.g the smallest roundtrip time
return chosen
}
func (t *tun) Address() string { func (t *tun) Address() string {
t.RLock() t.RLock()
defer t.RUnlock() defer t.RUnlock()
@ -967,42 +1024,32 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
c.mode = options.Mode c.mode = options.Mode
// set the dial timeout // set the dial timeout
c.timeout = options.Timeout c.timeout = options.Timeout
// get the current time
now := time.Now()
after := func() time.Duration { var links []*link
d := time.Since(now)
// dial timeout minus time since
wait := options.Timeout - d
if wait < time.Duration(0) {
return time.Duration(0)
}
return wait
}
var links []string
// did we measure the rtt // did we measure the rtt
var measured bool var measured bool
// non multicast so we need to find the link
t.RLock() t.RLock()
// non multicast so we need to find the link
for _, link := range t.links { for _, link := range t.links {
// use the link specified it its available // use the link specified it its available
if id := options.Link; len(id) > 0 && link.id != id { if id := options.Link; len(id) > 0 && link.id != id {
continue continue
} }
link.RLock() // get the channel
_, ok := link.channels[channel] lastMapped := link.getChannel(channel)
link.RUnlock()
// we have at least one channel mapping // we have at least one channel mapping
if ok { if !lastMapped.IsZero() {
links = append(links, link)
c.discovered = true c.discovered = true
links = append(links, link.id)
} }
} }
t.RUnlock() t.RUnlock()
// link not found // link not found
if len(links) == 0 && len(options.Link) > 0 { if len(links) == 0 && len(options.Link) > 0 {
// delete session and return error // delete session and return error
@ -1015,9 +1062,9 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
// TODO: pick the link efficiently based // TODO: pick the link efficiently based
// on link status and saturation. // on link status and saturation.
if c.discovered && c.mode == Unicast { if c.discovered && c.mode == Unicast {
// set the link // pickLink will pick the best link
i := rand.Intn(len(links)) link := t.pickLink(links)
c.link = links[i] c.link = link.id
} }
// shit fuck // shit fuck
@ -1025,57 +1072,8 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
// piggy back roundtrip // piggy back roundtrip
nowRTT := time.Now() nowRTT := time.Now()
// create a new discovery message for this channel // attempt to discover the link
msg := c.newMessage("discover") err := c.Discover()
msg.mode = Broadcast
msg.outbound = true
msg.link = ""
// send the discovery message
t.send <- msg
select {
case <-time.After(after()):
t.delSession(c.channel, c.session)
log.Debugf("Tunnel deleting session %s %s: %v", c.session, c.channel, ErrDialTimeout)
return nil, ErrDialTimeout
case err := <-c.errChan:
if err != nil {
t.delSession(c.channel, c.session)
log.Debugf("Tunnel deleting session %s %s: %v", c.session, c.channel, err)
return nil, err
}
}
var err error
// set a dialTimeout
dialTimeout := after()
// set a shorter delay for multicast
if c.mode != Unicast {
// shorten this
dialTimeout = time.Millisecond * 500
}
// wait for announce
select {
case msg := <-c.recv:
if msg.typ != "announce" {
err = ErrDiscoverChan
}
case <-time.After(dialTimeout):
err = ErrDialTimeout
}
// if its multicast just go ahead because this is best effort
if c.mode != Unicast {
c.discovered = true
c.accepted = true
return c, nil
}
// otherwise return an error
if err != nil { if err != nil {
t.delSession(c.channel, c.session) t.delSession(c.channel, c.session)
log.Debugf("Tunnel deleting session %s %s: %v", c.session, c.channel, err) log.Debugf("Tunnel deleting session %s %s: %v", c.session, c.channel, err)
@ -1096,34 +1094,34 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
// set measured to true // set measured to true
measured = true measured = true
} }
// set discovered to true
c.discovered = true
} }
// a unicast session so we call "open" and wait for an "accept" // a unicast session so we call "open" and wait for an "accept"
// reset now in case we use it // reset now in case we use it
now = time.Now() now := time.Now()
// try to open the session // try to open the session
err := c.Open() if err := c.Open(); err != nil {
if err != nil {
// delete the session // delete the session
t.delSession(c.channel, c.session) t.delSession(c.channel, c.session)
log.Debugf("Tunnel deleting session %s %s: %v", c.session, c.channel, err) log.Debugf("Tunnel deleting session %s %s: %v", c.session, c.channel, err)
return nil, err return nil, err
} }
// set time take to open
d := time.Since(now)
// if we haven't measured the roundtrip do it now // if we haven't measured the roundtrip do it now
if !measured && c.mode == Unicast { if !measured && c.mode == Unicast {
// set the link time // set the link time
t.RLock() t.RLock()
link, ok := t.links[c.link] link, ok := t.links[c.link]
t.RUnlock() t.RUnlock()
if ok { if ok {
// set the rountrip time // set the rountrip time
link.setRTT(time.Since(now)) link.setRTT(d)
} }
} }

View File

@ -17,9 +17,11 @@ type link struct {
sync.RWMutex sync.RWMutex
// stops the link // stops the link
closed chan bool closed chan bool
// send queue // link state channel for testing link
state chan *packet
// send queue for sending packets
sendQueue chan *packet sendQueue chan *packet
// receive queue // receive queue for receiving packets
recvQueue chan *packet recvQueue chan *packet
// unique id of this link e.g uuid // unique id of this link e.g uuid
// which we define for ourselves // which we define for ourselves
@ -44,9 +46,6 @@ type link struct {
rate float64 rate float64
// keep an error count on the link // keep an error count on the link
errCount int errCount int
// link state channel
state chan *packet
} }
// packet send over link // packet send over link
@ -73,9 +72,9 @@ func newLink(s transport.Socket) *link {
Socket: s, Socket: s,
id: uuid.New().String(), id: uuid.New().String(),
lastKeepAlive: time.Now(), lastKeepAlive: time.Now(),
channels: make(map[string]time.Time),
closed: make(chan bool), closed: make(chan bool),
state: make(chan *packet, 64), state: make(chan *packet, 64),
channels: make(map[string]time.Time),
sendQueue: make(chan *packet, 128), sendQueue: make(chan *packet, 128),
recvQueue: make(chan *packet, 128), recvQueue: make(chan *packet, 128),
} }
@ -119,6 +118,33 @@ func (l *link) setRTT(d time.Duration) {
l.length = int64(length) l.length = int64(length)
} }
func (l *link) delChannel(ch string) {
l.Lock()
delete(l.channels, ch)
l.Unlock()
}
func (l *link) setChannel(channels ...string) {
l.Lock()
for _, ch := range channels {
l.channels[ch] = time.Now()
}
l.Unlock()
}
func (l *link) getChannel(ch string) time.Time {
l.RLock()
defer l.RUnlock()
return l.channels[ch]
}
// set the keepalive time
func (l *link) keepalive() {
l.Lock()
l.lastKeepAlive = time.Now()
l.Unlock()
}
// process deals with the send queue // process deals with the send queue
func (l *link) process() { func (l *link) process() {
// receive messages // receive messages
@ -176,8 +202,8 @@ func (l *link) manage() {
defer t.Stop() defer t.Stop()
// used to send link state packets // used to send link state packets
send := func(b []byte) { send := func(b []byte) error {
l.Send(&transport.Message{ return l.Send(&transport.Message{
Header: map[string]string{ Header: map[string]string{
"Micro-Method": "link", "Micro-Method": "link",
}, Body: b, }, Body: b,
@ -205,7 +231,11 @@ func (l *link) manage() {
case bytes.Compare(p.message.Body, linkRequest) == 0: case bytes.Compare(p.message.Body, linkRequest) == 0:
log.Tracef("Link %s received link request %v", l.id, p.message.Body) log.Tracef("Link %s received link request %v", l.id, p.message.Body)
// send response // send response
send(linkResponse) if err := send(linkResponse); err != nil {
l.Lock()
l.errCount++
l.Unlock()
}
case bytes.Compare(p.message.Body, linkResponse) == 0: case bytes.Compare(p.message.Body, linkResponse) == 0:
// set round trip time // set round trip time
d := time.Since(now) d := time.Since(now)
@ -270,7 +300,6 @@ func (l *link) Delay() int64 {
func (l *link) Rate() float64 { func (l *link) Rate() float64 {
l.RLock() l.RLock()
defer l.RUnlock() defer l.RUnlock()
return l.rate return l.rate
} }
@ -279,7 +308,6 @@ func (l *link) Rate() float64 {
func (l *link) Length() int64 { func (l *link) Length() int64 {
l.RLock() l.RLock()
defer l.RUnlock() defer l.RUnlock()
return l.length return l.length
} }
@ -398,8 +426,8 @@ func (l *link) Recv(m *transport.Message) error {
return nil return nil
} }
// Status can return connected, closed, error // State can return connected, closed, error
func (l *link) Status() string { func (l *link) State() string {
select { select {
case <-l.closed: case <-l.closed:
return "closed" return "closed"

View File

@ -106,6 +106,112 @@ func (s *session) newMessage(typ string) *message {
} }
} }
// waitFor waits for the message type required until the timeout specified
func (s *session) waitFor(msgType string, timeout time.Duration) error {
now := time.Now()
after := func() time.Duration {
d := time.Since(now)
// dial timeout minus time since
wait := timeout - d
if wait < time.Duration(0) {
return time.Duration(0)
}
return wait
}
// wait for the message type
loop:
for {
select {
case msg := <-s.recv:
// ignore what we don't want
if msg.typ != msgType {
log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType)
continue
}
// got the message
break loop
case <-time.After(after()):
return ErrDialTimeout
case <-s.closed:
return io.EOF
}
}
return nil
}
// Discover attempts to discover the link for a specific channel
func (s *session) Discover() error {
// create a new discovery message for this channel
msg := s.newMessage("discover")
msg.mode = Broadcast
msg.outbound = true
msg.link = ""
// send the discovery message
s.send <- msg
// set time now
now := time.Now()
after := func() time.Duration {
d := time.Since(now)
// dial timeout minus time since
wait := s.timeout - d
if wait < time.Duration(0) {
return time.Duration(0)
}
return wait
}
// wait to hear back about the sent message
select {
case <-time.After(after()):
return ErrDialTimeout
case err := <-s.errChan:
if err != nil {
return err
}
}
var err error
// set a new dialTimeout
dialTimeout := after()
// set a shorter delay for multicast
if s.mode != Unicast {
// shorten this
dialTimeout = time.Millisecond * 500
}
// wait for announce
if err := s.waitFor("announce", dialTimeout); err != nil {
return err
}
// if its multicast just go ahead because this is best effort
if s.mode != Unicast {
s.discovered = true
s.accepted = true
return nil
}
if err != nil {
return err
}
// set discovered
s.discovered = true
return nil
}
// Open will fire the open message for the session. This is called by the dialler. // Open will fire the open message for the session. This is called by the dialler.
func (s *session) Open() error { func (s *session) Open() error {
// create a new message // create a new message
@ -131,22 +237,15 @@ func (s *session) Open() error {
} }
// now wait for the accept // now wait for the accept
select { if err := s.waitFor("accept", s.timeout); err != nil {
case msg = <-s.recv: return err
if msg.typ != "accept" {
log.Debugf("Received non accept message in Open %s", msg.typ)
return errors.New("failed to connect")
}
// set to accepted
s.accepted = true
// set link
s.link = msg.link
case <-time.After(s.timeout):
return ErrDialTimeout
case <-s.closed:
return io.EOF
} }
// set to accepted
s.accepted = true
// set link
s.link = msg.link
return nil return nil
} }

View File

@ -63,8 +63,8 @@ type Link interface {
Length() int64 Length() int64
// Current transfer rate as bits per second (lower is better) // Current transfer rate as bits per second (lower is better)
Rate() float64 Rate() float64
// Status of the link e.g connected/closed // State of the link e.g connected/closed
Status() string State() string
// honours transport socket // honours transport socket
transport.Socket transport.Socket
} }