From f26d470db1d7b3be4bb6e3d8b8223eb96999afdf Mon Sep 17 00:00:00 2001 From: Asim Aslam Date: Thu, 24 Oct 2019 17:51:41 +0100 Subject: [PATCH 1/4] A few changes for the network / tunnel link state --- network/default.go | 28 ++++++++---- proxy/mucp/mucp.go | 3 ++ tunnel/default.go | 11 ++--- tunnel/link.go | 109 +++++++++++++++++++++++++++++++++++++-------- 4 files changed, 119 insertions(+), 32 deletions(-) diff --git a/network/default.go b/network/default.go index 741c4fb8..57fcbc26 100644 --- a/network/default.go +++ b/network/default.go @@ -677,19 +677,19 @@ func (n *network) getHopCount(rtr string) int { // the route origin is our peer if _, ok := n.peers[rtr]; ok { - return 2 + return 10 } // the route origin is the peer of our peer for _, peer := range n.peers { for id := range peer.peers { if rtr == id { - return 3 + return 100 } } } // otherwise we are three hops away - return 4 + return 1000 } // getRouteMetric calculates router metric and returns it @@ -721,11 +721,15 @@ func (n *network) getRouteMetric(router string, gateway string, link string) int // make sure length is non-zero length := link.Length() if length == 0 { - length = 10e10 + log.Debugf("Link length is 0 %v %v", link, link.Length()) + length = 10e9 } - return (delay * length * int64(hops)) / 10e9 + log.Debugf("Network calculated metric %v delay %v length %v distance %v", (delay*length*int64(hops))/10e6, delay, length, hops) + return (delay * length * int64(hops)) / 10e6 } + log.Debugf("Network failed to find a link to gateway: %s", gateway) + return math.MaxInt64 } @@ -783,12 +787,18 @@ func (n *network) processCtrlChan(listener tunnel.Listener) { } // calculate route metric and add to the advertised metric // we need to make sure we do not overflow math.MaxInt64 - log.Debugf("Network metric for router %s and gateway %s", event.Route.Router, event.Route.Gateway) - if metric := n.getRouteMetric(event.Route.Router, event.Route.Gateway, event.Route.Link); metric != math.MaxInt64 { - route.Metric += metric + metric := n.getRouteMetric(event.Route.Router, event.Route.Gateway, event.Route.Link) + log.Debugf("Network metric for router %s and gateway %s: %v", event.Route.Router, event.Route.Gateway, metric) + + // check we don't overflow max int 64 + if d := route.Metric + metric; d > math.MaxInt64 || d <= 0 { + // set to max int64 if we overflow + route.Metric = math.MaxInt64 } else { - route.Metric = metric + // set the combined value of metrics otherwise + route.Metric = d } + // create router event e := &router.Event{ Type: router.EventType(event.Type), diff --git a/proxy/mucp/mucp.go b/proxy/mucp/mucp.go index b336e5f3..06ab676f 100644 --- a/proxy/mucp/mucp.go +++ b/proxy/mucp/mucp.go @@ -19,6 +19,7 @@ import ( "github.com/micro/go-micro/proxy" "github.com/micro/go-micro/router" "github.com/micro/go-micro/server" + "github.com/micro/go-micro/util/log" ) // Proxy will transparently proxy requests to an endpoint. @@ -294,6 +295,8 @@ func (p *Proxy) ServeRequest(ctx context.Context, req server.Request, rsp server continue } + log.Debugf("Proxy using route %+v\n", route) + // set the address to call addresses := toNodes([]router.Route{route}) opts = append(opts, client.WithAddress(addresses...)) diff --git a/tunnel/default.go b/tunnel/default.go index 0d16ae90..a7555eb1 100644 --- a/tunnel/default.go +++ b/tunnel/default.go @@ -746,8 +746,13 @@ func (t *tun) setupLink(node string) (*link, error) { } log.Debugf("Tunnel connected to %s", node) + // create a new link + link := newLink(c) + // set link id to remote side + link.id = c.Remote() + // send the first connect message - if err := c.Send(&transport.Message{ + if err := link.Send(&transport.Message{ Header: map[string]string{ "Micro-Tunnel": "connect", "Micro-Tunnel-Id": t.id, @@ -757,10 +762,6 @@ func (t *tun) setupLink(node string) (*link, error) { return nil, err } - // create a new link - link := newLink(c) - // set link id to remote side - link.id = c.Remote() // we made the outbound connection // and sent the connect message link.connected = true diff --git a/tunnel/link.go b/tunnel/link.go index 26b568fc..5c53c91a 100644 --- a/tunnel/link.go +++ b/tunnel/link.go @@ -1,12 +1,14 @@ package tunnel import ( + "bytes" "io" "sync" "time" "github.com/google/uuid" "github.com/micro/go-micro/transport" + "github.com/micro/go-micro/util/log" ) type link struct { @@ -42,6 +44,9 @@ type link struct { rate float64 // keep an error count on the link errCount int + + // link state channel + state chan *packet } // packet send over link @@ -56,21 +61,49 @@ type packet struct { err error } +var ( + // the 4 byte 0 packet sent to determine the link state + linkRequest = []byte{0, 0, 0, 0} + // the 4 byte 1 filled packet sent to determine link state + linkResponse = []byte{1, 1, 1, 1} +) + func newLink(s transport.Socket) *link { l := &link{ Socket: s, id: uuid.New().String(), lastKeepAlive: time.Now(), closed: make(chan bool), + state: make(chan *packet, 64), channels: make(map[string]time.Time), sendQueue: make(chan *packet, 128), recvQueue: make(chan *packet, 128), } + + // process inbound/outbound packets go l.process() - go l.expiry() + // manage the link state + go l.manage() + return l } +// setRate sets the bits per second rate as a float64 +func (l *link) setRate(bits int64, delta time.Duration) { + // rate of send in bits per nanosecond + rate := float64(bits) / float64(delta.Nanoseconds()) + + // default the rate if its zero + if l.rate == 0 { + // rate per second + l.rate = rate * 1e9 + } else { + // set new rate per second + l.rate = 0.8*l.rate + 0.2*(rate*1e9) + } +} + +// setRTT sets a nanosecond based moving average roundtrip time for the link func (l *link) setRTT(d time.Duration) { l.Lock() defer l.Unlock() @@ -101,8 +134,22 @@ func (l *link) process() { // process new received message + pk := &packet{message: m, err: err} + + // this is our link state packet + if m.Header["Micro-Method"] == "link" { + // process link state message + select { + case l.state <- pk: + default: + } + continue + } + + // process all messages as is + select { - case l.recvQueue <- &packet{message: m, err: err}: + case l.recvQueue <- pk: case <-l.closed: return } @@ -122,15 +169,49 @@ func (l *link) process() { } } -// watches the channel expiry -func (l *link) expiry() { +// manage manages the link state including rtt packets and channel mapping expiry +func (l *link) manage() { + // tick over every minute to expire and fire rtt packets t := time.NewTicker(time.Minute) defer t.Stop() + // used to send link state packets + send := func(b []byte) { + l.Send(&transport.Message{ + Header: map[string]string{ + "Micro-Method": "link", + }, Body: b, + }) + } + + // set time now + now := time.Now() + + // send the initial rtt request packet + send(linkRequest) + for { select { + // exit if closed case <-l.closed: return + // process link state rtt packets + case p := <-l.state: + if p.err != nil { + continue + } + // check the type of message + switch { + case bytes.Compare(p.message.Body, linkRequest) == 0: + log.Tracef("Link %s received link request %v", l.id, p.message.Body) + // send response + send(linkResponse) + case bytes.Compare(p.message.Body, linkResponse) == 0: + // set round trip time + d := time.Since(now) + log.Tracef("Link %s received link response in %v", p.message.Body, d) + l.setRTT(d) + } case <-t.C: // drop any channel mappings older than 2 minutes var kill []string @@ -155,6 +236,10 @@ func (l *link) expiry() { delete(l.channels, ch) } l.Unlock() + + // fire off a link state rtt packet + now = time.Now() + send(linkRequest) } } } @@ -278,23 +363,11 @@ func (l *link) Send(m *transport.Message) error { // calculate based on data if dataSent > 0 { - // measure time taken - delta := time.Since(now) - // bit sent bits := dataSent * 1024 - // rate of send in bits per nanosecond - rate := float64(bits) / float64(delta.Nanoseconds()) - - // default the rate if its zero - if l.rate == 0 { - // rate per second - l.rate = rate * 1e9 - } else { - // set new rate per second - l.rate = 0.8*l.rate + 0.2*(rate*1e9) - } + // set the rate + l.setRate(int64(bits), time.Since(now)) } return nil From 3831199600ad3f889aa536f8240e61b5992892d7 Mon Sep 17 00:00:00 2001 From: Asim Aslam Date: Fri, 25 Oct 2019 14:16:22 +0100 Subject: [PATCH 2/4] Use best link in tunnel, loop waiting for announce and accept messages, cleanup some code --- tunnel/default.go | 202 +++++++++++++++++++++++----------------------- tunnel/link.go | 54 ++++++++++--- tunnel/session.go | 127 +++++++++++++++++++++++++---- tunnel/tunnel.go | 4 +- 4 files changed, 256 insertions(+), 131 deletions(-) diff --git a/tunnel/default.go b/tunnel/default.go index a7555eb1..0b064c36 100644 --- a/tunnel/default.go +++ b/tunnel/default.go @@ -90,6 +90,7 @@ func (t *tun) getSession(channel, session string) (*session, bool) { return s, ok } +// delSession deletes a session if it exists func (t *tun) delSession(channel, session string) { t.Lock() delete(t.sessions, channel+session) @@ -146,6 +147,9 @@ func (t *tun) newSessionId() string { return uuid.New().String() } +// announce will send a message to the link to tell the other side of a channel mapping we have. +// This usually happens if someone calls Dial and sends a discover message but otherwise we +// periodically send these messages to asynchronously manage channel mappings. func (t *tun) announce(channel, session string, link *link) { // create the "announce" response message for a discover request msg := &transport.Message{ @@ -206,7 +210,7 @@ func (t *tun) monitor() { // check the link status and purge dead links for node, link := range t.links { // check link status - switch link.Status() { + switch link.State() { case "closed": delLinks = append(delLinks, node) case "error": @@ -303,8 +307,16 @@ func (t *tun) process() { // build the list of links ot send to for node, link := range t.links { + // get the values we need + link.RLock() + id := link.id + connected := link.connected + loopback := link.loopback + _, exists := link.channels[msg.channel] + link.RUnlock() + // if the link is not connected skip it - if !link.connected { + if !connected { log.Debugf("Link for node %s not connected", node) err = errors.New("link not connected") continue @@ -313,32 +325,29 @@ func (t *tun) process() { // if the link was a loopback accepted connection // and the message is being sent outbound via // a dialled connection don't use this link - if link.loopback && msg.outbound { + if loopback && msg.outbound { err = errors.New("link is loopback") continue } // if the message was being returned by the loopback listener // send it back up the loopback link only - if msg.loopback && !link.loopback { + if msg.loopback && !loopback { err = errors.New("link is not loopback") continue } // check the multicast mappings if msg.mode == Multicast { - link.RLock() - _, ok := link.channels[msg.channel] - link.RUnlock() // channel mapping not found in link - if !ok { + if !exists { continue } } else { // if we're picking the link check the id // this is where we explicitly set the link // in a message received via the listen method - if len(msg.link) > 0 && link.id != msg.link { + if len(msg.link) > 0 && id != msg.link { err = errors.New("link not found") continue } @@ -422,6 +431,12 @@ func (t *tun) listen(link *link) { // let us know if its a loopback var loopback bool + var connected bool + + // set the connected value + link.RLock() + connected = link.connected + link.RUnlock() for { // process anything via the net interface @@ -451,7 +466,7 @@ func (t *tun) listen(link *link) { // if its not connected throw away the link // the first message we process needs to be connect - if !link.connected && mtype != "connect" { + if !connected && mtype != "connect" { log.Debugf("Tunnel link %s not connected", link.id) return } @@ -461,7 +476,8 @@ func (t *tun) listen(link *link) { log.Debugf("Tunnel link %s received connect message", link.Remote()) link.Lock() - // are we connecting to ourselves? + + // check if we're connecting to ourselves? if id == t.id { link.loopback = true loopback = true @@ -471,6 +487,8 @@ func (t *tun) listen(link *link) { link.id = link.Remote() // set as connected link.connected = true + connected = true + link.Unlock() // save the link once connected @@ -494,9 +512,7 @@ func (t *tun) listen(link *link) { // the entire listener was closed so remove it from the mapping if sessionId == "listener" { - link.Lock() - delete(link.channels, channel) - link.Unlock() + link.delChannel(channel) continue } @@ -510,10 +526,8 @@ func (t *tun) listen(link *link) { // otherwise its a session mapping of sorts case "keepalive": log.Debugf("Tunnel link %s received keepalive", link.Remote()) - link.Lock() // save the keepalive - link.lastKeepAlive = time.Now() - link.Unlock() + link.keepalive() continue // a new connection dialled outbound case "open": @@ -540,11 +554,7 @@ func (t *tun) listen(link *link) { channels := strings.Split(channel, ",") // update mapping in the link - link.Lock() - for _, channel := range channels { - link.channels[channel] = time.Now() - } - link.Unlock() + link.setChannel(channels...) // this was an announcement not intended for anything if sessionId == "listener" || sessionId == "" { @@ -904,6 +914,53 @@ func (t *tun) close() error { return t.listener.Close() } +// pickLink will pick the best link based on connectivity, delay, rate and length +func (t *tun) pickLink(links []*link) *link { + var metric float64 + var chosen *link + + // find the best link + for i, link := range links { + // don't use disconnected or errored links + if link.State() != "connected" { + continue + } + + // get the link state info + d := float64(link.Delay()) + l := float64(link.Length()) + r := link.Rate() + + // metric = delay x length x rate + m := d * l * r + + // first link so just and go + if i == 0 { + metric = m + chosen = link + continue + } + + // we found a better metric + if m < metric { + metric = m + chosen = link + } + } + + // if there's no link we're just going to mess around + if chosen == nil { + i := rand.Intn(len(links)) + return links[i] + } + + // we chose the link with; + // the lowest delay e.g least messages queued + // the lowest rate e.g the least messages flowing + // the lowest length e.g the smallest roundtrip time + return chosen +} + func (t *tun) Address() string { t.RLock() defer t.RUnlock() @@ -967,42 +1024,32 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { c.mode = options.Mode // set the dial timeout c.timeout = options.Timeout - // get the current time - now := time.Now() - after := func() time.Duration { - d := time.Since(now) - // dial timeout minus time since - wait := options.Timeout - d - if wait < time.Duration(0) { - return time.Duration(0) - } - return wait - } - - var links []string + var links []*link // did we measure the rtt var measured bool - // non multicast so we need to find the link t.RLock() + + // non multicast so we need to find the link for _, link := range t.links { // use the link specified it its available if id := options.Link; len(id) > 0 && link.id != id { continue } - link.RLock() - _, ok := link.channels[channel] - link.RUnlock() + // get the channel + lastMapped := link.getChannel(channel) // we have at least one channel mapping - if ok { + if !lastMapped.IsZero() { + links = append(links, link) c.discovered = true - links = append(links, link.id) } } + t.RUnlock() + // link not found if len(links) == 0 && len(options.Link) > 0 { // delete session and return error @@ -1015,9 +1062,9 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { // TODO: pick the link efficiently based // on link status and saturation. if c.discovered && c.mode == Unicast { - // set the link - i := rand.Intn(len(links)) - c.link = links[i] + // pickLink will pick the best link + link := t.pickLink(links) + c.link = link.id } // shit fuck @@ -1025,57 +1072,8 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { // piggy back roundtrip nowRTT := time.Now() - // create a new discovery message for this channel - msg := c.newMessage("discover") - msg.mode = Broadcast - msg.outbound = true - msg.link = "" - - // send the discovery message - t.send <- msg - - select { - case <-time.After(after()): - t.delSession(c.channel, c.session) - log.Debugf("Tunnel deleting session %s %s: %v", c.session, c.channel, ErrDialTimeout) - return nil, ErrDialTimeout - case err := <-c.errChan: - if err != nil { - t.delSession(c.channel, c.session) - log.Debugf("Tunnel deleting session %s %s: %v", c.session, c.channel, err) - return nil, err - } - } - - var err error - - // set a dialTimeout - dialTimeout := after() - - // set a shorter delay for multicast - if c.mode != Unicast { - // shorten this - dialTimeout = time.Millisecond * 500 - } - - // wait for announce - select { - case msg := <-c.recv: - if msg.typ != "announce" { - err = ErrDiscoverChan - } - case <-time.After(dialTimeout): - err = ErrDialTimeout - } - - // if its multicast just go ahead because this is best effort - if c.mode != Unicast { - c.discovered = true - c.accepted = true - return c, nil - } - - // otherwise return an error + // attempt to discover the link + err := c.Discover() if err != nil { t.delSession(c.channel, c.session) log.Debugf("Tunnel deleting session %s %s: %v", c.session, c.channel, err) @@ -1096,34 +1094,34 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { // set measured to true measured = true } - - // set discovered to true - c.discovered = true } // a unicast session so we call "open" and wait for an "accept" // reset now in case we use it - now = time.Now() + now := time.Now() // try to open the session - err := c.Open() - if err != nil { + if err := c.Open(); err != nil { // delete the session t.delSession(c.channel, c.session) log.Debugf("Tunnel deleting session %s %s: %v", c.session, c.channel, err) return nil, err } + // set time take to open + d := time.Since(now) + // if we haven't measured the roundtrip do it now if !measured && c.mode == Unicast { // set the link time t.RLock() link, ok := t.links[c.link] t.RUnlock() + if ok { // set the rountrip time - link.setRTT(time.Since(now)) + link.setRTT(d) } } diff --git a/tunnel/link.go b/tunnel/link.go index 5c53c91a..a319c214 100644 --- a/tunnel/link.go +++ b/tunnel/link.go @@ -17,9 +17,11 @@ type link struct { sync.RWMutex // stops the link closed chan bool - // send queue + // link state channel for testing link + state chan *packet + // send queue for sending packets sendQueue chan *packet - // receive queue + // receive queue for receiving packets recvQueue chan *packet // unique id of this link e.g uuid // which we define for ourselves @@ -44,9 +46,6 @@ type link struct { rate float64 // keep an error count on the link errCount int - - // link state channel - state chan *packet } // packet send over link @@ -73,9 +72,9 @@ func newLink(s transport.Socket) *link { Socket: s, id: uuid.New().String(), lastKeepAlive: time.Now(), + channels: make(map[string]time.Time), closed: make(chan bool), state: make(chan *packet, 64), - channels: make(map[string]time.Time), sendQueue: make(chan *packet, 128), recvQueue: make(chan *packet, 128), } @@ -119,6 +118,33 @@ func (l *link) setRTT(d time.Duration) { l.length = int64(length) } +func (l *link) delChannel(ch string) { + l.Lock() + delete(l.channels, ch) + l.Unlock() +} + +func (l *link) setChannel(channels ...string) { + l.Lock() + for _, ch := range channels { + l.channels[ch] = time.Now() + } + l.Unlock() +} + +func (l *link) getChannel(ch string) time.Time { + l.RLock() + defer l.RUnlock() + return l.channels[ch] +} + +// set the keepalive time +func (l *link) keepalive() { + l.Lock() + l.lastKeepAlive = time.Now() + l.Unlock() +} + // process deals with the send queue func (l *link) process() { // receive messages @@ -176,8 +202,8 @@ func (l *link) manage() { defer t.Stop() // used to send link state packets - send := func(b []byte) { - l.Send(&transport.Message{ + send := func(b []byte) error { + return l.Send(&transport.Message{ Header: map[string]string{ "Micro-Method": "link", }, Body: b, @@ -205,7 +231,11 @@ func (l *link) manage() { case bytes.Compare(p.message.Body, linkRequest) == 0: log.Tracef("Link %s received link request %v", l.id, p.message.Body) // send response - send(linkResponse) + if err := send(linkResponse); err != nil { + l.Lock() + l.errCount++ + l.Unlock() + } case bytes.Compare(p.message.Body, linkResponse) == 0: // set round trip time d := time.Since(now) @@ -270,7 +300,6 @@ func (l *link) Delay() int64 { func (l *link) Rate() float64 { l.RLock() defer l.RUnlock() - return l.rate } @@ -279,7 +308,6 @@ func (l *link) Rate() float64 { func (l *link) Length() int64 { l.RLock() defer l.RUnlock() - return l.length } @@ -398,8 +426,8 @@ func (l *link) Recv(m *transport.Message) error { return nil } -// Status can return connected, closed, error -func (l *link) Status() string { +// State can return connected, closed, error +func (l *link) State() string { select { case <-l.closed: return "closed" diff --git a/tunnel/session.go b/tunnel/session.go index a185c3ce..f1e65184 100644 --- a/tunnel/session.go +++ b/tunnel/session.go @@ -106,6 +106,112 @@ func (s *session) newMessage(typ string) *message { } } +// waitFor waits for the message type required until the timeout specified +func (s *session) waitFor(msgType string, timeout time.Duration) error { + now := time.Now() + + after := func() time.Duration { + d := time.Since(now) + // dial timeout minus time since + wait := timeout - d + + if wait < time.Duration(0) { + return time.Duration(0) + } + + return wait + } + + // wait for the message type +loop: + for { + select { + case msg := <-s.recv: + // ignore what we don't want + if msg.typ != msgType { + log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType) + continue + } + + // got the message + break loop + case <-time.After(after()): + return ErrDialTimeout + case <-s.closed: + return io.EOF + } + } + + return nil +} + +// Discover attempts to discover the link for a specific channel +func (s *session) Discover() error { + // create a new discovery message for this channel + msg := s.newMessage("discover") + msg.mode = Broadcast + msg.outbound = true + msg.link = "" + + // send the discovery message + s.send <- msg + + // set time now + now := time.Now() + + after := func() time.Duration { + d := time.Since(now) + // dial timeout minus time since + wait := s.timeout - d + if wait < time.Duration(0) { + return time.Duration(0) + } + return wait + } + + // wait to hear back about the sent message + select { + case <-time.After(after()): + return ErrDialTimeout + case err := <-s.errChan: + if err != nil { + return err + } + } + + var err error + + // set a new dialTimeout + dialTimeout := after() + + // set a shorter delay for multicast + if s.mode != Unicast { + // shorten this + dialTimeout = time.Millisecond * 500 + } + + // wait for announce + if err := s.waitFor("announce", dialTimeout); err != nil { + return err + } + + // if its multicast just go ahead because this is best effort + if s.mode != Unicast { + s.discovered = true + s.accepted = true + return nil + } + + if err != nil { + return err + } + + // set discovered + s.discovered = true + + return nil +} + // Open will fire the open message for the session. This is called by the dialler. func (s *session) Open() error { // create a new message @@ -131,22 +237,15 @@ func (s *session) Open() error { } // now wait for the accept - select { - case msg = <-s.recv: - if msg.typ != "accept" { - log.Debugf("Received non accept message in Open %s", msg.typ) - return errors.New("failed to connect") - } - // set to accepted - s.accepted = true - // set link - s.link = msg.link - case <-time.After(s.timeout): - return ErrDialTimeout - case <-s.closed: - return io.EOF + if err := s.waitFor("accept", s.timeout); err != nil { + return err } + // set to accepted + s.accepted = true + // set link + s.link = msg.link + return nil } diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 73c937f3..212e07e4 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -63,8 +63,8 @@ type Link interface { Length() int64 // Current transfer rate as bits per second (lower is better) Rate() float64 - // Status of the link e.g connected/closed - Status() string + // State of the link e.g connected/closed + State() string // honours transport socket transport.Socket } From c170189efbcd7f776237f7d2b211f7af16380e0f Mon Sep 17 00:00:00 2001 From: Asim Aslam Date: Fri, 25 Oct 2019 14:22:38 +0100 Subject: [PATCH 3/4] We need the message back to set the link --- tunnel/session.go | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/tunnel/session.go b/tunnel/session.go index f1e65184..1dfe992c 100644 --- a/tunnel/session.go +++ b/tunnel/session.go @@ -107,7 +107,7 @@ func (s *session) newMessage(typ string) *message { } // waitFor waits for the message type required until the timeout specified -func (s *session) waitFor(msgType string, timeout time.Duration) error { +func (s *session) waitFor(msgType string, timeout time.Duration) (*message, error) { now := time.Now() after := func() time.Duration { @@ -123,7 +123,6 @@ func (s *session) waitFor(msgType string, timeout time.Duration) error { } // wait for the message type -loop: for { select { case msg := <-s.recv: @@ -132,17 +131,14 @@ loop: log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType) continue } - // got the message - break loop + return msg, nil case <-time.After(after()): - return ErrDialTimeout + return nil, ErrDialTimeout case <-s.closed: - return io.EOF + return nil, io.EOF } } - - return nil } // Discover attempts to discover the link for a specific channel @@ -191,7 +187,8 @@ func (s *session) Discover() error { } // wait for announce - if err := s.waitFor("announce", dialTimeout); err != nil { + _, err = s.waitFor("announce", dialTimeout) + if err != nil { return err } @@ -237,7 +234,8 @@ func (s *session) Open() error { } // now wait for the accept - if err := s.waitFor("accept", s.timeout); err != nil { + msg, err := s.waitFor("accept", s.timeout) + if err != nil { return err } From 1c9ada6413eb3564110964d15e4239df6f75d9bc Mon Sep 17 00:00:00 2001 From: Asim Aslam Date: Fri, 25 Oct 2019 14:24:37 +0100 Subject: [PATCH 4/4] Reorder setChannel method --- tunnel/link.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tunnel/link.go b/tunnel/link.go index a319c214..d135c152 100644 --- a/tunnel/link.go +++ b/tunnel/link.go @@ -124,6 +124,12 @@ func (l *link) delChannel(ch string) { l.Unlock() } +func (l *link) getChannel(ch string) time.Time { + l.RLock() + defer l.RUnlock() + return l.channels[ch] +} + func (l *link) setChannel(channels ...string) { l.Lock() for _, ch := range channels { @@ -132,12 +138,6 @@ func (l *link) setChannel(channels ...string) { l.Unlock() } -func (l *link) getChannel(ch string) time.Time { - l.RLock() - defer l.RUnlock() - return l.channels[ch] -} - // set the keepalive time func (l *link) keepalive() { l.Lock()