Use best link in tunnel, loop waiting for announce and accept messages, cleanup some code

This commit is contained in:
Asim Aslam 2019-10-25 14:16:22 +01:00
parent f26d470db1
commit 3831199600
4 changed files with 256 additions and 131 deletions

View File

@ -90,6 +90,7 @@ func (t *tun) getSession(channel, session string) (*session, bool) {
return s, ok
}
// delSession deletes a session if it exists
func (t *tun) delSession(channel, session string) {
t.Lock()
delete(t.sessions, channel+session)
@ -146,6 +147,9 @@ func (t *tun) newSessionId() string {
return uuid.New().String()
}
// announce will send a message to the link to tell the other side of a channel mapping we have.
// This usually happens if someone calls Dial and sends a discover message but otherwise we
// periodically send these messages to asynchronously manage channel mappings.
func (t *tun) announce(channel, session string, link *link) {
// create the "announce" response message for a discover request
msg := &transport.Message{
@ -206,7 +210,7 @@ func (t *tun) monitor() {
// check the link status and purge dead links
for node, link := range t.links {
// check link status
switch link.Status() {
switch link.State() {
case "closed":
delLinks = append(delLinks, node)
case "error":
@ -303,8 +307,16 @@ func (t *tun) process() {
// build the list of links to send to
for node, link := range t.links {
// get the values we need
link.RLock()
id := link.id
connected := link.connected
loopback := link.loopback
_, exists := link.channels[msg.channel]
link.RUnlock()
// if the link is not connected skip it
if !link.connected {
if !connected {
log.Debugf("Link for node %s not connected", node)
err = errors.New("link not connected")
continue
@ -313,32 +325,29 @@ func (t *tun) process() {
// if the link was a loopback accepted connection
// and the message is being sent outbound via
// a dialled connection don't use this link
if link.loopback && msg.outbound {
if loopback && msg.outbound {
err = errors.New("link is loopback")
continue
}
// if the message was being returned by the loopback listener
// send it back up the loopback link only
if msg.loopback && !link.loopback {
if msg.loopback && !loopback {
err = errors.New("link is not loopback")
continue
}
// check the multicast mappings
if msg.mode == Multicast {
link.RLock()
_, ok := link.channels[msg.channel]
link.RUnlock()
// channel mapping not found in link
if !ok {
if !exists {
continue
}
} else {
// if we're picking the link check the id
// this is where we explicitly set the link
// in a message received via the listen method
if len(msg.link) > 0 && link.id != msg.link {
if len(msg.link) > 0 && id != msg.link {
err = errors.New("link not found")
continue
}
@ -422,6 +431,12 @@ func (t *tun) listen(link *link) {
// let us know if its a loopback
var loopback bool
var connected bool
// set the connected value
link.RLock()
connected = link.connected
link.RUnlock()
for {
// process anything via the net interface
@ -451,7 +466,7 @@ func (t *tun) listen(link *link) {
// if its not connected throw away the link
// the first message we process needs to be connect
if !link.connected && mtype != "connect" {
if !connected && mtype != "connect" {
log.Debugf("Tunnel link %s not connected", link.id)
return
}
@ -461,7 +476,8 @@ func (t *tun) listen(link *link) {
log.Debugf("Tunnel link %s received connect message", link.Remote())
link.Lock()
// are we connecting to ourselves?
// check if we're connecting to ourselves?
if id == t.id {
link.loopback = true
loopback = true
@ -471,6 +487,8 @@ func (t *tun) listen(link *link) {
link.id = link.Remote()
// set as connected
link.connected = true
connected = true
link.Unlock()
// save the link once connected
@ -494,9 +512,7 @@ func (t *tun) listen(link *link) {
// the entire listener was closed so remove it from the mapping
if sessionId == "listener" {
link.Lock()
delete(link.channels, channel)
link.Unlock()
link.delChannel(channel)
continue
}
@ -510,10 +526,8 @@ func (t *tun) listen(link *link) {
// otherwise its a session mapping of sorts
case "keepalive":
log.Debugf("Tunnel link %s received keepalive", link.Remote())
link.Lock()
// save the keepalive
link.lastKeepAlive = time.Now()
link.Unlock()
link.keepalive()
continue
// a new connection dialled outbound
case "open":
@ -540,11 +554,7 @@ func (t *tun) listen(link *link) {
channels := strings.Split(channel, ",")
// update mapping in the link
link.Lock()
for _, channel := range channels {
link.channels[channel] = time.Now()
}
link.Unlock()
link.setChannel(channels...)
// this was an announcement not intended for anything
if sessionId == "listener" || sessionId == "" {
@ -904,6 +914,53 @@ func (t *tun) close() error {
return t.listener.Close()
}
// pickLink picks the best link from links based on connectivity, delay,
// rate and length.
//
// Only links whose State() is "connected" are candidates. Each candidate is
// scored as Delay() x Length() x Rate() and the candidate with the lowest
// score wins. If none of the links is connected a randomly selected link is
// returned instead. An empty links slice yields nil.
func (t *tun) pickLink(links []*link) *link {
	// nothing to choose from; rand.Intn(0) below would panic
	if len(links) == 0 {
		return nil
	}

	var metric float64
	var chosen *link

	// find the best link
	for _, link := range links {
		// don't use disconnected or errored links
		if link.State() != "connected" {
			continue
		}

		// get the link state info
		d := float64(link.Delay())
		l := float64(link.Length())
		r := link.Rate()

		// metric = delay x length x rate
		m := d * l * r

		// The first usable link becomes the baseline; after that only a
		// strictly lower metric replaces it. Keying on chosen rather than
		// the slice index keeps this correct when links[0] was skipped.
		if chosen == nil || m < metric {
			metric = m
			chosen = link
		}
	}

	// if there's no link we're just going to mess around
	if chosen == nil {
		return links[rand.Intn(len(links))]
	}

	// we chose the link with;
	// the lowest delay e.g least messages queued
	// the lowest rate e.g the least messages flowing
	// the lowest length e.g the smallest roundtrip time
	return chosen
}
func (t *tun) Address() string {
t.RLock()
defer t.RUnlock()
@ -967,42 +1024,32 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
c.mode = options.Mode
// set the dial timeout
c.timeout = options.Timeout
// get the current time
now := time.Now()
after := func() time.Duration {
d := time.Since(now)
// dial timeout minus time since
wait := options.Timeout - d
if wait < time.Duration(0) {
return time.Duration(0)
}
return wait
}
var links []string
var links []*link
// did we measure the rtt
var measured bool
// non multicast so we need to find the link
t.RLock()
// non multicast so we need to find the link
for _, link := range t.links {
// use the link specified if it's available
if id := options.Link; len(id) > 0 && link.id != id {
continue
}
link.RLock()
_, ok := link.channels[channel]
link.RUnlock()
// get the channel
lastMapped := link.getChannel(channel)
// we have at least one channel mapping
if ok {
if !lastMapped.IsZero() {
links = append(links, link)
c.discovered = true
links = append(links, link.id)
}
}
t.RUnlock()
// link not found
if len(links) == 0 && len(options.Link) > 0 {
// delete session and return error
@ -1015,9 +1062,9 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
// TODO: pick the link efficiently based
// on link status and saturation.
if c.discovered && c.mode == Unicast {
// set the link
i := rand.Intn(len(links))
c.link = links[i]
// pickLink will pick the best link
link := t.pickLink(links)
c.link = link.id
}
// shit fuck
@ -1025,57 +1072,8 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
// piggy back roundtrip
nowRTT := time.Now()
// create a new discovery message for this channel
msg := c.newMessage("discover")
msg.mode = Broadcast
msg.outbound = true
msg.link = ""
// send the discovery message
t.send <- msg
select {
case <-time.After(after()):
t.delSession(c.channel, c.session)
log.Debugf("Tunnel deleting session %s %s: %v", c.session, c.channel, ErrDialTimeout)
return nil, ErrDialTimeout
case err := <-c.errChan:
if err != nil {
t.delSession(c.channel, c.session)
log.Debugf("Tunnel deleting session %s %s: %v", c.session, c.channel, err)
return nil, err
}
}
var err error
// set a dialTimeout
dialTimeout := after()
// set a shorter delay for multicast
if c.mode != Unicast {
// shorten this
dialTimeout = time.Millisecond * 500
}
// wait for announce
select {
case msg := <-c.recv:
if msg.typ != "announce" {
err = ErrDiscoverChan
}
case <-time.After(dialTimeout):
err = ErrDialTimeout
}
// if its multicast just go ahead because this is best effort
if c.mode != Unicast {
c.discovered = true
c.accepted = true
return c, nil
}
// otherwise return an error
// attempt to discover the link
err := c.Discover()
if err != nil {
t.delSession(c.channel, c.session)
log.Debugf("Tunnel deleting session %s %s: %v", c.session, c.channel, err)
@ -1096,34 +1094,34 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
// set measured to true
measured = true
}
// set discovered to true
c.discovered = true
}
// a unicast session so we call "open" and wait for an "accept"
// reset now in case we use it
now = time.Now()
now := time.Now()
// try to open the session
err := c.Open()
if err != nil {
if err := c.Open(); err != nil {
// delete the session
t.delSession(c.channel, c.session)
log.Debugf("Tunnel deleting session %s %s: %v", c.session, c.channel, err)
return nil, err
}
// set time taken to open
d := time.Since(now)
// if we haven't measured the roundtrip do it now
if !measured && c.mode == Unicast {
// set the link time
t.RLock()
link, ok := t.links[c.link]
t.RUnlock()
if ok {
// set the roundtrip time
link.setRTT(time.Since(now))
link.setRTT(d)
}
}

View File

@ -17,9 +17,11 @@ type link struct {
sync.RWMutex
// stops the link
closed chan bool
// send queue
// link state channel for testing link
state chan *packet
// send queue for sending packets
sendQueue chan *packet
// receive queue
// receive queue for receiving packets
recvQueue chan *packet
// unique id of this link e.g uuid
// which we define for ourselves
@ -44,9 +46,6 @@ type link struct {
rate float64
// keep an error count on the link
errCount int
// link state channel
state chan *packet
}
// packet send over link
@ -73,9 +72,9 @@ func newLink(s transport.Socket) *link {
Socket: s,
id: uuid.New().String(),
lastKeepAlive: time.Now(),
channels: make(map[string]time.Time),
closed: make(chan bool),
state: make(chan *packet, 64),
channels: make(map[string]time.Time),
sendQueue: make(chan *packet, 128),
recvQueue: make(chan *packet, 128),
}
@ -119,6 +118,33 @@ func (l *link) setRTT(d time.Duration) {
l.length = int64(length)
}
// delChannel removes the entry for ch from the link's channel map while
// holding the link's write lock.
func (l *link) delChannel(ch string) {
	l.Lock()
	defer l.Unlock()

	delete(l.channels, ch)
}
// setChannel stores the current time against each of the given channel
// names in the link's channel map, under the link's write lock.
func (l *link) setChannel(channels ...string) {
	l.Lock()
	defer l.Unlock()

	for i := range channels {
		name := channels[i]
		l.channels[name] = time.Now()
	}
}
// getChannel returns the time stored for ch in the link's channel map,
// read under the link's read lock. A channel with no entry yields the
// zero time.Time.
func (l *link) getChannel(ch string) time.Time {
	l.RLock()
	mapped := l.channels[ch]
	l.RUnlock()

	return mapped
}
// keepalive records the current time as the link's lastKeepAlive,
// under the link's write lock.
func (l *link) keepalive() {
	l.Lock()
	defer l.Unlock()

	l.lastKeepAlive = time.Now()
}
// process deals with the send queue
func (l *link) process() {
// receive messages
@ -176,8 +202,8 @@ func (l *link) manage() {
defer t.Stop()
// used to send link state packets
send := func(b []byte) {
l.Send(&transport.Message{
send := func(b []byte) error {
return l.Send(&transport.Message{
Header: map[string]string{
"Micro-Method": "link",
}, Body: b,
@ -205,7 +231,11 @@ func (l *link) manage() {
case bytes.Compare(p.message.Body, linkRequest) == 0:
log.Tracef("Link %s received link request %v", l.id, p.message.Body)
// send response
send(linkResponse)
if err := send(linkResponse); err != nil {
l.Lock()
l.errCount++
l.Unlock()
}
case bytes.Compare(p.message.Body, linkResponse) == 0:
// set round trip time
d := time.Since(now)
@ -270,7 +300,6 @@ func (l *link) Delay() int64 {
func (l *link) Rate() float64 {
l.RLock()
defer l.RUnlock()
return l.rate
}
@ -279,7 +308,6 @@ func (l *link) Rate() float64 {
func (l *link) Length() int64 {
l.RLock()
defer l.RUnlock()
return l.length
}
@ -398,8 +426,8 @@ func (l *link) Recv(m *transport.Message) error {
return nil
}
// Status can return connected, closed, error
func (l *link) Status() string {
// State can return connected, closed, error
func (l *link) State() string {
select {
case <-l.closed:
return "closed"

View File

@ -106,6 +106,112 @@ func (s *session) newMessage(typ string) *message {
}
}
// waitFor blocks until a message whose typ equals msgType is received on
// s.recv. Messages of any other type are logged at debug level and dropped.
//
// It returns nil once a matching message arrives, ErrDialTimeout if timeout
// elapses first (the timeout covers the whole wait, not each message), or
// io.EOF if s.closed becomes readable while waiting.
func (s *session) waitFor(msgType string, timeout time.Duration) error {
	// One timer bounds the entire wait. Creating a fresh time.After on
	// every loop iteration would allocate a new timer for each ignored
	// message and leave it pending until it fires.
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	for {
		select {
		case msg := <-s.recv:
			// ignore what we don't want
			if msg.typ != msgType {
				log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType)
				continue
			}
			// got the message
			return nil
		case <-deadline.C:
			return ErrDialTimeout
		case <-s.closed:
			return io.EOF
		}
	}
}
// Discover attempts to discover the link for a specific channel.
//
// It sends a broadcast "discover" message on s.send, waits on s.errChan for
// the result of that send, and then waits for an "announce" message.
//
// For non-unicast sessions discovery is best effort: the announce wait is
// capped at 500ms and its outcome is ignored, and the session is marked both
// discovered and accepted. For unicast sessions a failed announce wait is
// returned as the error and the session is only marked discovered on success.
//
// It returns ErrDialTimeout if s.timeout elapses before the send result
// arrives, the error received from s.errChan if it is non-nil, or (unicast
// only) the error from waiting for the announce.
func (s *session) Discover() error {
	// create a new discovery message for this channel
	msg := s.newMessage("discover")
	msg.mode = Broadcast
	msg.outbound = true
	msg.link = ""

	// send the discovery message
	s.send <- msg

	// set time now
	now := time.Now()

	// after returns how much of s.timeout is left, never less than zero
	after := func() time.Duration {
		d := time.Since(now)
		// dial timeout minus time since
		wait := s.timeout - d
		if wait < time.Duration(0) {
			return time.Duration(0)
		}
		return wait
	}

	// wait to hear back about the sent message
	select {
	case <-time.After(after()):
		return ErrDialTimeout
	case err := <-s.errChan:
		if err != nil {
			return err
		}
	}

	// set a new dialTimeout from whatever is left of the budget
	dialTimeout := after()

	// set a shorter delay for multicast
	if s.mode != Unicast {
		// shorten this
		dialTimeout = time.Millisecond * 500
	}

	// wait for announce; hold on to the error instead of returning at once
	// so that the best-effort path below is still reached
	err := s.waitFor("announce", dialTimeout)

	// if its multicast just go ahead because this is best effort
	if s.mode != Unicast {
		s.discovered = true
		s.accepted = true
		return nil
	}

	// unicast needs the announce, so a failed wait is fatal
	if err != nil {
		return err
	}

	// set discovered
	s.discovered = true

	return nil
}
// Open will fire the open message for the session. This is called by the dialler.
func (s *session) Open() error {
// create a new message
@ -131,21 +237,14 @@ func (s *session) Open() error {
}
// now wait for the accept
select {
case msg = <-s.recv:
if msg.typ != "accept" {
log.Debugf("Received non accept message in Open %s", msg.typ)
return errors.New("failed to connect")
if err := s.waitFor("accept", s.timeout); err != nil {
return err
}
// set to accepted
s.accepted = true
// set link
s.link = msg.link
case <-time.After(s.timeout):
return ErrDialTimeout
case <-s.closed:
return io.EOF
}
return nil
}

View File

@ -63,8 +63,8 @@ type Link interface {
Length() int64
// Current transfer rate as bits per second (lower is better)
Rate() float64
// Status of the link e.g connected/closed
Status() string
// State of the link e.g connected/closed
State() string
// honours transport socket
transport.Socket
}