diff --git a/tunnel/default.go b/tunnel/default.go index 1cb0dbfc..e64359a6 100644 --- a/tunnel/default.go +++ b/tunnel/default.go @@ -127,7 +127,6 @@ func (t *tun) newSession(channel, sessionId string) (*session, bool) { closed: make(chan bool), recv: make(chan *message, 128), send: t.send, - wait: make(chan bool), errChan: make(chan error, 1), } @@ -470,6 +469,9 @@ func (t *tun) listen(link *link) { return } + // this state machine block handles the only message types + // that we know or care about; connect, close, open, accept, + // discover, announce, session, keepalive switch mtype { case "connect": log.Debugf("Tunnel link %s received connect message", link.Remote()) @@ -500,9 +502,6 @@ func (t *tun) listen(link *link) { // nothing more to do continue case "close": - // TODO: handle the close message - // maybe report io.EOF or kill the link - // if there is no channel then we close the link // as its a signal from the other side to close the connection if len(channel) == 0 { @@ -521,6 +520,8 @@ func (t *tun) listen(link *link) { // try get the dialing socket s, exists := t.getSession(channel, sessionId) if exists && !loopback { + // only delete the session if its unicast + // otherwise ignore close on the multicast if s.mode == Unicast { // only delete this if its unicast // but not if its a loopback conn @@ -541,14 +542,16 @@ func (t *tun) listen(link *link) { // an accept returned by the listener case "accept": s, exists := t.getSession(channel, sessionId) - // we don't need this + // just set accepted on anything not unicast if exists && s.mode > Unicast { s.accepted = true continue } + // if its already accepted move on if exists && s.accepted { continue } + // otherwise we're going to process to accept // a continued session case "session": // process message @@ -562,7 +565,10 @@ func (t *tun) listen(link *link) { link.setChannel(channels...) // this was an announcement not intended for anything - if sessionId == "listener" || sessionId == "" { + // if the dialing side sent "discover" then a session + // id would be present. We skip in case of multicast. + switch sessionId { + case "listener", "multicast", "": continue } @@ -574,14 +580,19 @@ func (t *tun) listen(link *link) { continue } - // send the announce back to the caller - s.recv <- &message{ + msg := &message{ typ: "announce", tunnel: id, channel: channel, session: sessionId, link: link.id, } + + // send the announce back to the caller + select { + case <-s.closed: + case s.recv <- msg: + } } continue case "discover": @@ -651,22 +662,10 @@ func (t *tun) listen(link *link) { delete(t.sessions, channel) continue default: - // process + // otherwise process } - log.Debugf("Tunnel using channel %s session %s", s.channel, s.session) - - // is the session new? - select { - // if its new the session is actually blocked waiting - // for a connection. so we check if its waiting. - case <-s.wait: - // if its waiting e.g its new then we close it - default: - // set remote address of the session - s.remote = msg.Header["Remote"] - close(s.wait) - } + log.Debugf("Tunnel using channel %s session %s type %s", s.channel, s.session, mtype) // construct a new transport message tmsg := &transport.Message{ @@ -1052,7 +1051,7 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { t.RUnlock() - // link not found + // link not found and one was specified so error out if len(links) == 0 && len(options.Link) > 0 { // delete session and return error t.delSession(c.channel, c.session) @@ -1061,15 +1060,14 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { } // discovered so set the link if not multicast - // TODO: pick the link efficiently based - // on link status and saturation. if c.discovered && c.mode == Unicast { // pickLink will pick the best link link := t.pickLink(links) + // set the link c.link = link.id } - // shit fuck + // if its not already discovered we need to attempt to do so if !c.discovered { // piggy back roundtrip nowRTT := time.Now() @@ -1098,7 +1096,15 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { } } - // a unicast session so we call "open" and wait for an "accept" + // return early if its not unicast + // we will not call "open" for multicast + if c.mode != Unicast { + return c, nil + } + + // Note: we go no further for multicast or broadcast. + // This is a unicast session so we call "open" and wait + // for an "accept" // reset now in case we use it now := time.Now() @@ -1115,7 +1121,7 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { d := time.Since(now) // if we haven't measured the roundtrip do it now - if !measured && c.mode == Unicast { + if !measured { // set the link time t.RLock() link, ok := t.links[c.link] @@ -1145,6 +1151,7 @@ func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) { return nil, errors.New("already listening on " + channel) } + // delete function removes the session when closed delFunc := func() { t.delSession(channel, "listener") } diff --git a/tunnel/listener.go b/tunnel/listener.go index 7ba6074e..36d70e96 100644 --- a/tunnel/listener.go +++ b/tunnel/listener.go @@ -24,7 +24,7 @@ type tunListener struct { delFunc func() } -// periodically announce self +// periodically announce self the channel being listened on func (t *tunListener) announce() { tick := time.NewTicker(time.Second * 30) defer tick.Stop() @@ -48,9 +48,12 @@ func (t *tunListener) process() { defer func() { // close the sessions - for _, conn := range conns { + for id, conn := range conns { conn.Close() + delete(conns, id) } + // unassign + conns = nil }() for { @@ -62,9 +65,22 @@ func (t *tunListener) process() { return // receive a new message case m := <-t.session.recv: + var sessionId string + + // get the session id + switch m.mode { + case Multicast, Broadcast: + // use channel name if multicast/broadcast + sessionId = m.channel + log.Tracef("Tunnel listener using session %s for real session %s", sessionId, m.session) + default: + // use session id if unicast + sessionId = m.session + } + // get a session - sess, ok := conns[m.session] - log.Debugf("Tunnel listener received channel %s session %s exists: %t", m.channel, m.session, ok) + sess, ok := conns[sessionId] + log.Debugf("Tunnel listener received channel %s session %s type %s exists: %t", m.channel, sessionId, m.typ, ok) if !ok { // we only process open and session types switch m.typ { @@ -80,7 +96,7 @@ func (t *tunListener) process() { // the channel channel: m.channel, // the session id - session: m.session, + session: sessionId, // tunnel token token: t.token, // is loopback conn @@ -95,14 +111,12 @@ func (t *tunListener) process() { recv: make(chan *message, 128), // use the internal send buffer send: t.session.send, - // wait - wait: make(chan bool), // error channel errChan: make(chan error, 1), } // save the session - conns[m.session] = sess + conns[sessionId] = sess select { case <-t.closed: @@ -114,17 +128,21 @@ func (t *tunListener) process() { // an existing session was found - // received a close message switch m.typ { case "close": + // received a close message select { + // check if the session is closed case <-sess.closed: // no op - delete(conns, m.session) + delete(conns, sessionId) default: - // close and delete session - close(sess.closed) - delete(conns, m.session) + if sess.mode == Unicast { + // only close if unicast session + // close and delete session + close(sess.closed) + delete(conns, sessionId) + } } // continue @@ -139,9 +157,9 @@ func (t *tunListener) process() { // send this to the accept chan select { case <-sess.closed: - delete(conns, m.session) + delete(conns, sessionId) case sess.recv <- m: - log.Debugf("Tunnel listener sent to recv chan channel %s session %s", m.channel, m.session) + log.Debugf("Tunnel listener sent to recv chan channel %s session %s type %s", m.channel, sessionId, m.typ) } } } diff --git a/tunnel/tunnel_reconnect_test.go b/tunnel/reconnect_test.go similarity index 100% rename from tunnel/tunnel_reconnect_test.go rename to tunnel/reconnect_test.go diff --git a/tunnel/session.go b/tunnel/session.go index b5b724b0..6c4ad69a 100644 --- a/tunnel/session.go +++ b/tunnel/session.go @@ -30,8 +30,6 @@ type session struct { send chan *message // recv chan recv chan *message - // wait until we have a connection - wait chan bool // if the discovery worked discovered bool // if the session was accepted @@ -109,6 +107,29 @@ func (s *session) newMessage(typ string) *message { } } +func (s *session) sendMsg(msg *message) error { + select { + case <-s.closed: + return io.EOF + case s.send <- msg: + return nil + } +} + +func (s *session) wait(msg *message) error { + // wait for an error response + select { + case err := <-msg.errChan: + if err != nil { + return err + } + case <-s.closed: + return io.EOF + } + + return nil +} + // waitFor waits for the message type required until the timeout specified func (s *session) waitFor(msgType string, timeout time.Duration) (*message, error) { now := time.Now() @@ -144,20 +165,32 @@ func (s *session) waitFor(msgType string, timeout time.Duration) (*message, erro } } -// Discover attempts to discover the link for a specific channel +// Discover attempts to discover the link for a specific channel. +// This is only used by the tunnel.Dial when first connecting. func (s *session) Discover() error { // create a new discovery message for this channel msg := s.newMessage("discover") + // broadcast the message to all links msg.mode = Broadcast + // its an outbound connection since we're dialling msg.outbound = true + // don't set the link since we don't know where it is msg.link = "" - // send the discovery message - s.send <- msg + // if multicast then set that as session + if s.mode == Multicast { + msg.session = "multicast" + } + + // send discover message + if err := s.sendMsg(msg); err != nil { + return err + } // set time now now := time.Now() + // after strips down the dial timeout after := func() time.Duration { d := time.Since(now) // dial timeout minus time since @@ -168,6 +201,7 @@ func (s *session) Discover() error { return wait } + // the discover message is sent out, now // wait to hear back about the sent message select { case <-time.After(after()): @@ -178,27 +212,16 @@ func (s *session) Discover() error { } } - var err error - - // set a new dialTimeout - dialTimeout := after() - - // set a shorter delay for multicast - if s.mode != Unicast { - // shorten this - dialTimeout = time.Millisecond * 500 - } - - // wait for announce - _, err = s.waitFor("announce", dialTimeout) - - // if its multicast just go ahead because this is best effort + // bail early if its not unicast + // we don't need to wait for the announce if s.mode != Unicast { s.discovered = true s.accepted = true return nil } + // wait for announce + _, err := s.waitFor("announce", after()) if err != nil { return err } @@ -210,30 +233,22 @@ func (s *session) Discover() error { } // Open will fire the open message for the session. This is called by the dialler. +// This is to indicate that we want to create a new session. func (s *session) Open() error { // create a new message msg := s.newMessage("open") // send open message - s.send <- msg + if err := s.sendMsg(msg); err != nil { + return err + } // wait for an error response for send - select { - case err := <-msg.errChan: - if err != nil { - return err - } - case <-s.closed: - return io.EOF + if err := s.wait(msg); err != nil { + return err } - // don't wait on multicast/broadcast - if s.mode == Multicast { - s.accepted = true - return nil - } - - // now wait for the accept + // now wait for the accept message to be returned msg, err := s.waitFor("accept", s.timeout) if err != nil { return err @@ -252,32 +267,16 @@ func (s *session) Accept() error { msg := s.newMessage("accept") // send the accept message - select { - case <-s.closed: - return io.EOF - case s.send <- msg: - // no op here - } - - // don't wait on multicast/broadcast - if s.mode == Multicast { - return nil + if err := s.sendMsg(msg); err != nil { + return err } // wait for send response - select { - case err := <-s.errChan: - if err != nil { - return err - } - case <-s.closed: - return io.EOF - } - - return nil + return s.wait(msg) } -// Announce sends an announcement to notify that this session exists. This is primarily used by the listener. +// Announce sends an announcement to notify that this session exists. +// This is primarily used by the listener. func (s *session) Announce() error { msg := s.newMessage("announce") // we don't need an error back @@ -287,23 +286,12 @@ func (s *session) Announce() error { // we don't need the link msg.link = "" - select { - case s.send <- msg: - return nil - case <-s.closed: - return io.EOF - } + // send announce message + return s.sendMsg(msg) } // Send is used to send a message func (s *session) Send(m *transport.Message) error { - select { - case <-s.closed: - return io.EOF - default: - // no op - } - // encrypt the transport message payload body, err := Encrypt(m.Body, s.token+s.channel+s.session) if err != nil { @@ -335,21 +323,19 @@ func (s *session) Send(m *transport.Message) error { msg.data = data // if multicast don't set the link - if s.mode == Multicast { + if s.mode != Unicast { msg.link = "" } log.Tracef("Appending %+v to send backlog", msg) + // send the actual message - s.send <- msg + if err := s.sendMsg(msg); err != nil { + return err + } // wait for an error response - select { - case err := <-msg.errChan: - return err - case <-s.closed: - return io.EOF - } + return s.wait(msg) } // Recv is used to receive a message @@ -413,6 +399,11 @@ func (s *session) Close() error { default: close(s.closed) + // don't send close on multicast + if s.mode != Unicast { + return nil + } + // append to backlog msg := s.newMessage("close") // no error response on close @@ -421,7 +412,7 @@ func (s *session) Close() error { // send the close message select { case s.send <- msg: - default: + case <-time.After(time.Millisecond * 10): } }