From 3831199600ad3f889aa536f8240e61b5992892d7 Mon Sep 17 00:00:00 2001 From: Asim Aslam Date: Fri, 25 Oct 2019 14:16:22 +0100 Subject: [PATCH] Use best link in tunnel, loop waiting for announce and accept messages, cleanup some code --- tunnel/default.go | 202 +++++++++++++++++++++++----------------------- tunnel/link.go | 54 ++++++++++--- tunnel/session.go | 127 +++++++++++++++++++++++++---- tunnel/tunnel.go | 4 +- 4 files changed, 256 insertions(+), 131 deletions(-) diff --git a/tunnel/default.go b/tunnel/default.go index a7555eb1..0b064c36 100644 --- a/tunnel/default.go +++ b/tunnel/default.go @@ -90,6 +90,7 @@ func (t *tun) getSession(channel, session string) (*session, bool) { return s, ok } +// delSession deletes a session if it exists func (t *tun) delSession(channel, session string) { t.Lock() delete(t.sessions, channel+session) @@ -146,6 +147,9 @@ func (t *tun) newSessionId() string { return uuid.New().String() } +// announce will send a message to the link to tell the other side of a channel mapping we have. +// This usually happens if someone calls Dial and sends a discover message but otherwise we +// periodically send these messages to asynchronously manage channel mappings. func (t *tun) announce(channel, session string, link *link) { // create the "announce" response message for a discover request msg := &transport.Message{ @@ -206,7 +210,7 @@ func (t *tun) monitor() { // check the link status and purge dead links for node, link := range t.links { // check link status - switch link.Status() { + switch link.State() { case "closed": delLinks = append(delLinks, node) case "error": @@ -303,8 +307,16 @@ func (t *tun) process() { // build the list of links ot send to for node, link := range t.links { + // get the values we need + link.RLock() + id := link.id + connected := link.connected + loopback := link.loopback + _, exists := link.channels[msg.channel] + link.RUnlock() + // if the link is not connected skip it - if !link.connected { + if !connected { log.Debugf("Link for node %s not connected", node) err = errors.New("link not connected") continue @@ -313,32 +325,29 @@ func (t *tun) process() { // if the link was a loopback accepted connection // and the message is being sent outbound via // a dialled connection don't use this link - if link.loopback && msg.outbound { + if loopback && msg.outbound { err = errors.New("link is loopback") continue } // if the message was being returned by the loopback listener // send it back up the loopback link only - if msg.loopback && !link.loopback { + if msg.loopback && !loopback { err = errors.New("link is not loopback") continue } // check the multicast mappings if msg.mode == Multicast { - link.RLock() - _, ok := link.channels[msg.channel] - link.RUnlock() // channel mapping not found in link - if !ok { + if !exists { continue } } else { // if we're picking the link check the id // this is where we explicitly set the link // in a message received via the listen method - if len(msg.link) > 0 && link.id != msg.link { + if len(msg.link) > 0 && id != msg.link { err = errors.New("link not found") continue } @@ -422,6 +431,12 @@ func (t *tun) listen(link *link) { // let us know if its a loopback var loopback bool + var connected bool + + // set the connected value + link.RLock() + connected = link.connected + link.RUnlock() for { // process anything via the net interface @@ -451,7 +466,7 @@ func (t *tun) listen(link *link) { // if its not connected throw away the link // the first message we process needs to be connect - if !link.connected && mtype != "connect" { + if !connected && mtype != "connect" { log.Debugf("Tunnel link %s not connected", link.id) return } @@ -461,7 +476,8 @@ func (t *tun) listen(link *link) { log.Debugf("Tunnel link %s received connect message", link.Remote()) link.Lock() - // are we connecting to ourselves? + + // check if we're connecting to ourselves? if id == t.id { link.loopback = true loopback = true @@ -471,6 +487,8 @@ func (t *tun) listen(link *link) { link.id = link.Remote() // set as connected link.connected = true + connected = true + link.Unlock() // save the link once connected @@ -494,9 +512,7 @@ func (t *tun) listen(link *link) { // the entire listener was closed so remove it from the mapping if sessionId == "listener" { - link.Lock() - delete(link.channels, channel) - link.Unlock() + link.delChannel(channel) continue } @@ -510,10 +526,8 @@ func (t *tun) listen(link *link) { // otherwise its a session mapping of sorts case "keepalive": log.Debugf("Tunnel link %s received keepalive", link.Remote()) - link.Lock() // save the keepalive - link.lastKeepAlive = time.Now() - link.Unlock() + link.keepalive() continue // a new connection dialled outbound case "open": @@ -540,11 +554,7 @@ func (t *tun) listen(link *link) { channels := strings.Split(channel, ",") // update mapping in the link - link.Lock() - for _, channel := range channels { - link.channels[channel] = time.Now() - } - link.Unlock() + link.setChannel(channels...) // this was an announcement not intended for anything if sessionId == "listener" || sessionId == "" { @@ -904,6 +914,53 @@ func (t *tun) close() error { return t.listener.Close() } +// pickLink will pick the best link based on connectivity, delay, rate and length +func (t *tun) pickLink(links []*link) *link { + var metric float64 + var chosen *link + + // find the best link + for i, link := range links { + // don't use disconnected or errored links + if link.State() != "connected" { + continue + } + + // get the link state info + d := float64(link.Delay()) + l := float64(link.Length()) + r := link.Rate() + + // metric = delay x length x rate + m := d * l * r + + // first link so just and go + if i == 0 { + metric = m + chosen = link + continue + } + + // we found a better metric + if m < metric { + metric = m + chosen = link + } + } + + // if there's no link we're just going to mess around + if chosen == nil { + i := rand.Intn(len(links)) + return links[i] + } + + // we chose the link with; + // the lowest delay e.g least messages queued + // the lowest rate e.g the least messages flowing + // the lowest length e.g the smallest roundtrip time + return chosen +} + func (t *tun) Address() string { t.RLock() defer t.RUnlock() @@ -967,42 +1024,32 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { c.mode = options.Mode // set the dial timeout c.timeout = options.Timeout - // get the current time - now := time.Now() - after := func() time.Duration { - d := time.Since(now) - // dial timeout minus time since - wait := options.Timeout - d - if wait < time.Duration(0) { - return time.Duration(0) - } - return wait - } - - var links []string + var links []*link // did we measure the rtt var measured bool - // non multicast so we need to find the link t.RLock() + + // non multicast so we need to find the link for _, link := range t.links { // use the link specified it its available if id := options.Link; len(id) > 0 && link.id != id { continue } - link.RLock() - _, ok := link.channels[channel] - link.RUnlock() + // get the channel + lastMapped := link.getChannel(channel) // we have at least one channel mapping - if ok { + if !lastMapped.IsZero() { + links = append(links, link) c.discovered = true - links = append(links, link.id) } } + t.RUnlock() + // link not found if len(links) == 0 && len(options.Link) > 0 { // delete session and return error @@ -1015,9 +1062,9 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { // TODO: pick the link efficiently based // on link status and saturation. if c.discovered && c.mode == Unicast { - // set the link - i := rand.Intn(len(links)) - c.link = links[i] + // pickLink will pick the best link + link := t.pickLink(links) + c.link = link.id } // shit fuck @@ -1025,57 +1072,8 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { // piggy back roundtrip nowRTT := time.Now() - // create a new discovery message for this channel - msg := c.newMessage("discover") - msg.mode = Broadcast - msg.outbound = true - msg.link = "" - - // send the discovery message - t.send <- msg - - select { - case <-time.After(after()): - t.delSession(c.channel, c.session) - log.Debugf("Tunnel deleting session %s %s: %v", c.session, c.channel, ErrDialTimeout) - return nil, ErrDialTimeout - case err := <-c.errChan: - if err != nil { - t.delSession(c.channel, c.session) - log.Debugf("Tunnel deleting session %s %s: %v", c.session, c.channel, err) - return nil, err - } - } - - var err error - - // set a dialTimeout - dialTimeout := after() - - // set a shorter delay for multicast - if c.mode != Unicast { - // shorten this - dialTimeout = time.Millisecond * 500 - } - - // wait for announce - select { - case msg := <-c.recv: - if msg.typ != "announce" { - err = ErrDiscoverChan - } - case <-time.After(dialTimeout): - err = ErrDialTimeout - } - - // if its multicast just go ahead because this is best effort - if c.mode != Unicast { - c.discovered = true - c.accepted = true - return c, nil - } - - // otherwise return an error + // attempt to discover the link + err := c.Discover() if err != nil { t.delSession(c.channel, c.session) log.Debugf("Tunnel deleting session %s %s: %v", c.session, c.channel, err) @@ -1096,34 +1094,34 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { // set measured to true measured = true } - - // set discovered to true - c.discovered = true } // a unicast session so we call "open" and wait for an "accept" // reset now in case we use it - now = time.Now() + now := time.Now() // try to open the session - err := c.Open() - if err != nil { + if err := c.Open(); err != nil { // delete the session t.delSession(c.channel, c.session) log.Debugf("Tunnel deleting session %s %s: %v", c.session, c.channel, err) return nil, err } + // set time take to open + d := time.Since(now) + // if we haven't measured the roundtrip do it now if !measured && c.mode == Unicast { // set the link time t.RLock() link, ok := t.links[c.link] t.RUnlock() + if ok { // set the rountrip time - link.setRTT(time.Since(now)) + link.setRTT(d) } } diff --git a/tunnel/link.go b/tunnel/link.go index 5c53c91a..a319c214 100644 --- a/tunnel/link.go +++ b/tunnel/link.go @@ -17,9 +17,11 @@ type link struct { sync.RWMutex // stops the link closed chan bool - // send queue + // link state channel for testing link + state chan *packet + // send queue for sending packets sendQueue chan *packet - // receive queue + // receive queue for receiving packets recvQueue chan *packet // unique id of this link e.g uuid // which we define for ourselves @@ -44,9 +46,6 @@ type link struct { rate float64 // keep an error count on the link errCount int - - // link state channel - state chan *packet } // packet send over link @@ -73,9 +72,9 @@ func newLink(s transport.Socket) *link { Socket: s, id: uuid.New().String(), lastKeepAlive: time.Now(), + channels: make(map[string]time.Time), closed: make(chan bool), state: make(chan *packet, 64), - channels: make(map[string]time.Time), sendQueue: make(chan *packet, 128), recvQueue: make(chan *packet, 128), } @@ -119,6 +118,33 @@ func (l *link) setRTT(d time.Duration) { l.length = int64(length) } +func (l *link) delChannel(ch string) { + l.Lock() + delete(l.channels, ch) + l.Unlock() +} + +func (l *link) setChannel(channels ...string) { + l.Lock() + for _, ch := range channels { + l.channels[ch] = time.Now() + } + l.Unlock() +} + +func (l *link) getChannel(ch string) time.Time { + l.RLock() + defer l.RUnlock() + return l.channels[ch] +} + +// set the keepalive time +func (l *link) keepalive() { + l.Lock() + l.lastKeepAlive = time.Now() + l.Unlock() +} + // process deals with the send queue func (l *link) process() { // receive messages @@ -176,8 +202,8 @@ func (l *link) manage() { defer t.Stop() // used to send link state packets - send := func(b []byte) { - l.Send(&transport.Message{ + send := func(b []byte) error { + return l.Send(&transport.Message{ Header: map[string]string{ "Micro-Method": "link", }, Body: b, @@ -205,7 +231,11 @@ func (l *link) manage() { case bytes.Compare(p.message.Body, linkRequest) == 0: log.Tracef("Link %s received link request %v", l.id, p.message.Body) // send response - send(linkResponse) + if err := send(linkResponse); err != nil { + l.Lock() + l.errCount++ + l.Unlock() + } case bytes.Compare(p.message.Body, linkResponse) == 0: // set round trip time d := time.Since(now) @@ -270,7 +300,6 @@ func (l *link) Delay() int64 { func (l *link) Rate() float64 { l.RLock() defer l.RUnlock() - return l.rate } @@ -279,7 +308,6 @@ func (l *link) Rate() float64 { func (l *link) Length() int64 { l.RLock() defer l.RUnlock() - return l.length } @@ -398,8 +426,8 @@ func (l *link) Recv(m *transport.Message) error { return nil } -// Status can return connected, closed, error -func (l *link) Status() string { +// State can return connected, closed, error +func (l *link) State() string { select { case <-l.closed: return "closed" diff --git a/tunnel/session.go b/tunnel/session.go index a185c3ce..f1e65184 100644 --- a/tunnel/session.go +++ b/tunnel/session.go @@ -106,6 +106,112 @@ func (s *session) newMessage(typ string) *message { } } +// waitFor waits for the message type required until the timeout specified +func (s *session) waitFor(msgType string, timeout time.Duration) error { + now := time.Now() + + after := func() time.Duration { + d := time.Since(now) + // dial timeout minus time since + wait := timeout - d + + if wait < time.Duration(0) { + return time.Duration(0) + } + + return wait + } + + // wait for the message type +loop: + for { + select { + case msg := <-s.recv: + // ignore what we don't want + if msg.typ != msgType { + log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType) + continue + } + + // got the message + break loop + case <-time.After(after()): + return ErrDialTimeout + case <-s.closed: + return io.EOF + } + } + + return nil +} + +// Discover attempts to discover the link for a specific channel +func (s *session) Discover() error { + // create a new discovery message for this channel + msg := s.newMessage("discover") + msg.mode = Broadcast + msg.outbound = true + msg.link = "" + + // send the discovery message + s.send <- msg + + // set time now + now := time.Now() + + after := func() time.Duration { + d := time.Since(now) + // dial timeout minus time since + wait := s.timeout - d + if wait < time.Duration(0) { + return time.Duration(0) + } + return wait + } + + // wait to hear back about the sent message + select { + case <-time.After(after()): + return ErrDialTimeout + case err := <-s.errChan: + if err != nil { + return err + } + } + + var err error + + // set a new dialTimeout + dialTimeout := after() + + // set a shorter delay for multicast + if s.mode != Unicast { + // shorten this + dialTimeout = time.Millisecond * 500 + } + + // wait for announce + if err := s.waitFor("announce", dialTimeout); err != nil { + return err + } + + // if its multicast just go ahead because this is best effort + if s.mode != Unicast { + s.discovered = true + s.accepted = true + return nil + } + + if err != nil { + return err + } + + // set discovered + s.discovered = true + + return nil +} + // Open will fire the open message for the session. This is called by the dialler. func (s *session) Open() error { // create a new message @@ -131,22 +237,15 @@ func (s *session) Open() error { } // now wait for the accept - select { - case msg = <-s.recv: - if msg.typ != "accept" { - log.Debugf("Received non accept message in Open %s", msg.typ) - return errors.New("failed to connect") - } - // set to accepted - s.accepted = true - // set link - s.link = msg.link - case <-time.After(s.timeout): - return ErrDialTimeout - case <-s.closed: - return io.EOF + if err := s.waitFor("accept", s.timeout); err != nil { + return err } + // set to accepted + s.accepted = true + // set link + s.link = msg.link + return nil } diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 73c937f3..212e07e4 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -63,8 +63,8 @@ type Link interface { Length() int64 // Current transfer rate as bits per second (lower is better) Rate() float64 - // Status of the link e.g connected/closed - Status() string + // State of the link e.g connected/closed + State() string // honours transport socket transport.Socket }