diff --git a/network/default.go b/network/default.go index 202e18f1..4076f0e8 100644 --- a/network/default.go +++ b/network/default.go @@ -718,7 +718,7 @@ func (n *network) Connect() error { ) // dial into ControlChannel to send route adverts - ctrlClient, err := n.Tunnel.Dial(ControlChannel) + ctrlClient, err := n.Tunnel.Dial(ControlChannel, tunnel.DialMulticast()) if err != nil { return err } @@ -732,7 +732,7 @@ func (n *network) Connect() error { } // dial into NetworkChannel to send network messages - netClient, err := n.Tunnel.Dial(NetworkChannel) + netClient, err := n.Tunnel.Dial(NetworkChannel, tunnel.DialMulticast()) if err != nil { return err } diff --git a/tunnel/broker/broker.go b/tunnel/broker/broker.go index 0a33c610..6778dfaa 100644 --- a/tunnel/broker/broker.go +++ b/tunnel/broker/broker.go @@ -58,7 +58,7 @@ func (t *tunBroker) Disconnect() error { func (t *tunBroker) Publish(topic string, m *broker.Message, opts ...broker.PublishOption) error { // TODO: this is probably inefficient, we might want to just maintain an open connection // it may be easier to add broadcast to the tunnel - c, err := t.tunnel.Dial(topic) + c, err := t.tunnel.Dial(topic, tunnel.DialMulticast()) if err != nil { return err } diff --git a/tunnel/default.go b/tunnel/default.go index 0b9b9136..5d27809a 100644 --- a/tunnel/default.go +++ b/tunnel/default.go @@ -87,11 +87,17 @@ func (t *tun) getSession(channel, session string) (*session, bool) { return s, ok } +func (t *tun) delSession(channel, session string) { + t.Lock() + delete(t.sessions, channel+session) + t.Unlock() +} + // newSession creates a new session and saves it func (t *tun) newSession(channel, sessionId string) (*session, bool) { // new session s := &session{ - id: t.id, + tunnel: t.id, channel: channel, session: sessionId, closed: make(chan bool), @@ -150,7 +156,9 @@ func (t *tun) monitor() { log.Debugf("Tunnel failed to setup node link to %s: %v", node, err) continue } - + // set the link id to the node + // TODO: hash it + link.id = node // save the link t.Lock() t.links[node] = link @@ -169,18 +177,21 @@ func (t *tun) process() { case msg := <-t.send: newMsg := &transport.Message{ Header: make(map[string]string), - Body: msg.data.Body, } - for k, v := range msg.data.Header { - newMsg.Header[k] = v + // set the data + if msg.data != nil { + for k, v := range msg.data.Header { + newMsg.Header[k] = v + } + newMsg.Body = msg.data.Body } // set message head newMsg.Header["Micro-Tunnel"] = msg.typ // set the tunnel id on the outgoing message - newMsg.Header["Micro-Tunnel-Id"] = msg.id + newMsg.Header["Micro-Tunnel-Id"] = msg.tunnel // set the tunnel channel on the outgoing message newMsg.Header["Micro-Tunnel-Channel"] = msg.channel @@ -195,7 +206,7 @@ func (t *tun) process() { t.Lock() if len(t.links) == 0 { - log.Debugf("No links to send to") + log.Debugf("No links to send message type: %s channel: %s", msg.typ, msg.channel) } var sent bool @@ -232,25 +243,55 @@ func (t *tun) process() { continue } + // check the multicast mappings + if msg.multicast { + link.RLock() + _, ok := link.channels[msg.channel] + link.RUnlock() + // channel mapping not found in link + if !ok { + continue + } + } + // send the message via the current link log.Debugf("Sending %+v to %s", newMsg, node) + if errr := link.Send(newMsg); errr != nil { log.Debugf("Tunnel error sending %+v to %s: %v", newMsg, node, errr) err = errors.New(errr.Error()) + // kill the link + link.Close() + // delete the link delete(t.links, node) continue } + // is sent sent = true + + // keep sending broadcast messages + if msg.broadcast || msg.multicast { + continue + } + + // break on unicast + break } t.Unlock() + // set the error if not sent var gerr error if !sent { gerr = err } + // skip if its not been set + if msg.errChan == nil { + continue + } + // return error non blocking select { case msg.errChan <- gerr: @@ -262,14 +303,25 @@ func (t *tun) process() { } } +func (t *tun) delLink(id string) { + t.Lock() + defer t.Unlock() + // get the link + link, ok := t.links[id] + if !ok { + return + } + // close and delete + link.Close() + delete(t.links, id) +} + // process incoming messages func (t *tun) listen(link *link) { // remove the link on exit defer func() { log.Debugf("Tunnel deleting connection from %s", link.Remote()) - t.Lock() - delete(t.links, link.Remote()) - t.Unlock() + t.delLink(link.Remote()) }() // let us know if its a loopback @@ -292,18 +344,34 @@ func (t *tun) listen(link *link) { return } - switch msg.Header["Micro-Tunnel"] { + // message type + mtype := msg.Header["Micro-Tunnel"] + // the tunnel id + id := msg.Header["Micro-Tunnel-Id"] + // the tunnel channel + channel := msg.Header["Micro-Tunnel-Channel"] + // the session id + sessionId := msg.Header["Micro-Tunnel-Session"] + + // if its not connected throw away the link + // the first message we process needs to be connect + if !link.connected && mtype != "connect" { + log.Debugf("Tunnel link %s not connected", link.id) + return + } + + switch mtype { case "connect": log.Debugf("Tunnel link %s received connect message", link.Remote()) - id := msg.Header["Micro-Tunnel-Id"] - // are we connecting to ourselves? if id == t.id { link.loopback = true loopback = true } + // set to remote node + link.id = id // set as connected link.connected = true @@ -315,10 +383,31 @@ func (t *tun) listen(link *link) { // nothing more to do continue case "close": - log.Debugf("Tunnel link %s closing connection", link.Remote()) // TODO: handle the close message // maybe report io.EOF or kill the link - return + + // close the link entirely + if len(channel) == 0 { + log.Debugf("Tunnel link %s received close message", link.Remote()) + return + } + + // the entire listener was closed so remove it from the mapping + if sessionId == "listener" { + link.Lock() + delete(link.channels, channel) + link.Unlock() + continue + } + + // try get the dialing socket + s, exists := t.getSession(channel, sessionId) + if exists { + // close and continue + s.Close() + continue + } + // otherwise its a session mapping of sorts case "keepalive": log.Debugf("Tunnel link %s received keepalive", link.Remote()) t.Lock() @@ -326,27 +415,64 @@ func (t *tun) listen(link *link) { link.lastKeepAlive = time.Now() t.Unlock() continue - case "message": + // a new connection dialled outbound + case "open": + // we just let it pass through to be processed + // an accept returned by the listener + case "accept": + + // a continued session + case "session": // process message log.Debugf("Received %+v from %s", msg, link.Remote()) + // an announcement of a channel listener + case "announce": + // update mapping in the link + link.Lock() + link.channels[channel] = time.Now() + link.Unlock() + + // get the session that asked for the discovery + s, exists := t.getSession(channel, sessionId) + if exists { + // don't bother it's already discovered + if s.discovered { + continue + } + + // send the announce back to the caller + s.recv <- &message{ + typ: "announce", + tunnel: id, + channel: channel, + session: sessionId, + link: link.id, + } + } + continue + case "discover": + // looking for existing mapping + _, exists := t.getSession(channel, "listener") + if exists { + log.Debugf("Tunnel sending announce for discovery of channel %s", channel) + // send back the announcement + link.Send(&transport.Message{ + Header: map[string]string{ + "Micro-Tunnel": "announce", + "Micro-Tunnel-Id": t.id, + "Micro-Tunnel-Channel": channel, + "Micro-Tunnel-Session": sessionId, + "Micro-Tunnel-Link": link.id, + "Micro-Tunnel-Token": t.token, + }, + }) + } + continue default: // blackhole it continue } - // if its not connected throw away the link - if !link.connected { - log.Debugf("Tunnel link %s not connected", link.id) - return - } - - // the tunnel id - id := msg.Header["Micro-Tunnel-Id"] - // the tunnel channel - channel := msg.Header["Micro-Tunnel-Channel"] - // the session id - sessionId := msg.Header["Micro-Tunnel-Session"] - // strip tunnel message header for k, _ := range msg.Header { if strings.HasPrefix(k, "Micro-Tunnel") { @@ -368,8 +494,15 @@ func (t *tun) listen(link *link) { // If its a loopback connection then we've enabled link direction // listening side is used for listening, the dialling side for dialling switch { - case loopback: + case loopback, mtype == "open": s, exists = t.getSession(channel, "listener") + // only return accept to the session + case mtype == "accept": + log.Debugf("Received accept message for %s %s", channel, sessionId) + s, exists = t.getSession(channel, sessionId) + if exists && s.accepted { + continue + } default: // get the session based on the tunnel id and session // this could be something we dialed in which case @@ -383,7 +516,7 @@ func (t *tun) listen(link *link) { } } - // bail if no session has been found + // bail if no session or listener has been found if !exists { log.Debugf("Tunnel skipping no session exists") // drop it, we don't care about @@ -391,8 +524,6 @@ func (t *tun) listen(link *link) { continue } - log.Debugf("Tunnel using session %s %s", s.channel, s.session) - // is the session closed? select { case <-s.closed: @@ -403,6 +534,8 @@ func (t *tun) listen(link *link) { // process } + log.Debugf("Tunnel using channel %s session %s", s.channel, s.session) + // is the session new? select { // if its new the session is actually blocked waiting @@ -423,7 +556,8 @@ func (t *tun) listen(link *link) { // construct the internal message imsg := &message{ - id: id, + tunnel: id, + typ: mtype, channel: channel, session: sessionId, data: tmsg, @@ -461,9 +595,7 @@ func (t *tun) keepalive(link *link) { }, }); err != nil { log.Debugf("Error sending keepalive to link %v: %v", link.Remote(), err) - t.Lock() - delete(t.links, link.Remote()) - t.Unlock() + t.delLink(link.Remote()) return } } @@ -481,6 +613,7 @@ func (t *tun) setupLink(node string) (*link, error) { } log.Debugf("Tunnel connected to %s", node) + // send the first connect message if err := c.Send(&transport.Message{ Header: map[string]string{ "Micro-Tunnel": "connect", @@ -493,9 +626,11 @@ func (t *tun) setupLink(node string) (*link, error) { // create a new link link := newLink(c) - link.connected = true + // set link id to remote side + link.id = c.Remote() // we made the outbound connection // and sent the connect message + link.connected = true // process incoming messages go t.listen(link) @@ -553,7 +688,7 @@ func (t *tun) connect() error { } // save the link - t.links[node] = link + t.links[link.Remote()] = link } // process outbound messages to be sent @@ -627,6 +762,8 @@ func (t *tun) Close() error { return nil } + log.Debug("Tunnel closing") + select { case <-t.closed: return nil @@ -650,7 +787,7 @@ func (t *tun) Close() error { } // Dial an address -func (t *tun) Dial(channel string) (Session, error) { +func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { log.Debugf("Tunnel dialing %s", channel) c, ok := t.newSession(channel, t.newSessionId()) if !ok { @@ -663,18 +800,93 @@ func (t *tun) Dial(channel string) (Session, error) { // outbound session c.outbound = true + // get opts + options := DialOptions{ + Timeout: DefaultDialTimeout, + } + + for _, o := range opts { + o(&options) + } + + // set the multicast option + c.multicast = options.Multicast + // set the dial timeout + c.timeout = options.Timeout + + // don't bother with the song and dance below + // we're just going to assume things come online + // as and when. + if c.multicast { + return c, nil + } + + // non multicast so we need to find the link + t.RLock() + for _, link := range t.links { + link.RLock() + _, ok := link.channels[channel] + link.RUnlock() + + // we have at least one channel mapping + if ok { + c.discovered = true + break + } + } + t.RUnlock() + + // shit fuck + if !c.discovered { + msg := c.newMessage("discover") + msg.broadcast = true + msg.outbound = true + msg.link = "" + + // send the discovery message + t.send <- msg + + select { + case err := <-c.errChan: + if err != nil { + return nil, err + } + } + + // wait for announce + select { + case msg := <-c.recv: + if msg.typ != "announce" { + return nil, errors.New("failed to discover channel") + } + } + } + + // try to open the session + err := c.Open() + if err != nil { + // delete the session + t.delSession(c.channel, c.session) + return nil, err + } + return c, nil } // Accept a connection on the address func (t *tun) Listen(channel string) (Listener, error) { log.Debugf("Tunnel listening on %s", channel) + // create a new session by hashing the address c, ok := t.newSession(channel, "listener") if !ok { return nil, errors.New("already listening on " + channel) } + delFunc := func() { + t.delSession(channel, "listener") + } + // set remote. it will be replaced by the first message received c.remote = "remote" // set local @@ -690,6 +902,8 @@ func (t *tun) Listen(channel string) (Listener, error) { tunClosed: t.closed, // the listener session session: c, + // delete session + delFunc: delFunc, } // this kicks off the internal message processor @@ -698,10 +912,26 @@ func (t *tun) Listen(channel string) (Listener, error) { // to the existign sessions go tl.process() + // announces the listener channel to others + go tl.announce() + // return the listener return tl, nil } +func (t *tun) Links() []Link { + t.RLock() + defer t.RUnlock() + + var links []Link + + for _, link := range t.links { + links = append(links, link) + } + + return links +} + func (t *tun) String() string { return "mucp" } diff --git a/tunnel/link.go b/tunnel/link.go index fbec2e6a..9470ef6a 100644 --- a/tunnel/link.go +++ b/tunnel/link.go @@ -9,9 +9,10 @@ import ( ) type link struct { + transport.Socket + sync.RWMutex - transport.Socket // unique id of this link e.g uuid // which we define for ourselves id string @@ -27,11 +28,77 @@ type link struct { // the last time we received a keepalive // on this link from the remote side lastKeepAlive time.Time + // channels keeps a mapping of channels and last seen + channels map[string]time.Time + // stop the link + closed chan bool } func newLink(s transport.Socket) *link { - return &link{ - Socket: s, - id: uuid.New().String(), + l := &link{ + Socket: s, + id: uuid.New().String(), + channels: make(map[string]time.Time), + closed: make(chan bool), + } + go l.run() + return l +} + +func (l *link) run() { + t := time.NewTicker(time.Minute) + defer t.Stop() + + for { + select { + case <-l.closed: + return + case <-t.C: + // drop any channel mappings older than 2 minutes + var kill []string + killTime := time.Minute * 2 + + l.RLock() + for ch, t := range l.channels { + if d := time.Since(t); d > killTime { + kill = append(kill, ch) + } + } + l.RUnlock() + + // if nothing to kill don't bother with a wasted lock + if len(kill) == 0 { + continue + } + + // kill the channels! + l.Lock() + for _, ch := range kill { + delete(l.channels, ch) + } + l.Unlock() + } } } + +func (l *link) Id() string { + l.RLock() + defer l.RUnlock() + + return l.id +} + +func (l *link) Close() error { + l.Lock() + defer l.Unlock() + + select { + case <-l.closed: + return nil + default: + close(l.closed) + return l.Socket.Close() + } + + return nil +} diff --git a/tunnel/listener.go b/tunnel/listener.go index d62b58de..e60ff396 100644 --- a/tunnel/listener.go +++ b/tunnel/listener.go @@ -2,6 +2,7 @@ package tunnel import ( "io" + "time" "github.com/micro/go-micro/util/log" ) @@ -17,26 +18,62 @@ type tunListener struct { tunClosed chan bool // the listener session session *session + // del func to kill listener + delFunc func() +} + +// periodically announce self +func (t *tunListener) announce() { + tick := time.NewTicker(time.Minute) + defer tick.Stop() + + // first announcement + t.session.Announce() + + for { + select { + case <-tick.C: + t.session.Announce() + case <-t.closed: + return + } + } } func (t *tunListener) process() { // our connection map for session conns := make(map[string]*session) + defer func() { + // close the sessions + for _, conn := range conns { + conn.Close() + } + }() + for { select { case <-t.closed: return + case <-t.tunClosed: + t.Close() + return // receive a new message case m := <-t.session.recv: // get a session sess, ok := conns[m.session] - log.Debugf("Tunnel listener received id %s session %s exists: %t", m.id, m.session, ok) + log.Debugf("Tunnel listener received channel %s session %s exists: %t", m.channel, m.session, ok) if !ok { + switch m.typ { + case "open", "session": + default: + continue + } + // create a new session session sess = &session{ // the id of the remote side - id: m.id, + tunnel: m.tunnel, // the channel channel: m.channel, // the session id @@ -45,6 +82,8 @@ func (t *tunListener) process() { loopback: m.loopback, // the link the message was received on link: m.link, + // set multicast + multicast: m.multicast, // close chan closed: make(chan bool), // recv called by the acceptor @@ -60,20 +99,44 @@ func (t *tunListener) process() { // save the session conns[m.session] = sess - // send to accept chan select { case <-t.closed: return + // send to accept chan case t.accept <- sess: } } + // an existing session was found + + // received a close message + switch m.typ { + case "close": + select { + case <-sess.closed: + // no op + delete(conns, m.session) + default: + // close and delete session + close(sess.closed) + delete(conns, m.session) + } + + // continue + continue + case "session": + // operate on this + default: + // non operational type + continue + } + // send this to the accept chan select { case <-sess.closed: delete(conns, m.session) case sess.recv <- m: - log.Debugf("Tunnel listener sent to recv chan id %s session %s", m.id, m.session) + log.Debugf("Tunnel listener sent to recv chan channel %s session %s", m.channel, m.session) } } } @@ -89,6 +152,9 @@ func (t *tunListener) Close() error { case <-t.closed: return nil default: + // close and delete + t.delFunc() + t.session.Close() close(t.closed) } return nil @@ -102,13 +168,17 @@ func (t *tunListener) Accept() (Session, error) { return nil, io.EOF case <-t.tunClosed: // close the listener when the tunnel closes - t.Close() return nil, io.EOF // wait for a new connection case c, ok := <-t.accept: + // check if the accept chan is closed if !ok { return nil, io.EOF } + // send back the accept + if err := c.Accept(); err != nil { + return nil, err + } return c, nil } return nil, nil diff --git a/tunnel/options.go b/tunnel/options.go index 9d612173..39795671 100644 --- a/tunnel/options.go +++ b/tunnel/options.go @@ -1,6 +1,8 @@ package tunnel import ( + "time" + "github.com/google/uuid" "github.com/micro/go-micro/transport" "github.com/micro/go-micro/transport/quic" @@ -29,6 +31,15 @@ type Options struct { Transport transport.Transport } +type DialOption func(*DialOptions) + +type DialOptions struct { + // specify a multicast connection + Multicast bool + // the dial timeout + Timeout time.Duration +} + // The tunnel id func Id(id string) Option { return func(o *Options) { @@ -73,3 +84,18 @@ func DefaultOptions() Options { Transport: quic.NewTransport(), } } + +// Dial options + +// Dial multicast sets the multicast option to send only to those mapped +func DialMulticast() DialOption { + return func(o *DialOptions) { + o.Multicast = true + } +} + +func DialTimeout(t time.Duration) DialOption { + return func(o *DialOptions) { + o.Timeout = t + } +} diff --git a/tunnel/session.go b/tunnel/session.go index a4c779f3..58246fdd 100644 --- a/tunnel/session.go +++ b/tunnel/session.go @@ -3,6 +3,7 @@ package tunnel import ( "errors" "io" + "time" "github.com/micro/go-micro/transport" "github.com/micro/go-micro/util/log" @@ -10,8 +11,8 @@ import ( // session is our pseudo session for transport.Socket type session struct { - // unique id based on the remote tunnel id - id string + // the tunnel id + tunnel string // the channel name channel string // the session id based on Micro.Tunnel-Session @@ -28,10 +29,20 @@ type session struct { recv chan *message // wait until we have a connection wait chan bool + // if the discovery worked + discovered bool + // if the session was accepted + accepted bool // outbound marks the session as outbound dialled connection outbound bool // lookback marks the session as a loopback on the inbound loopback bool + // if the session is multicast + multicast bool + // if the session is broadcast + broadcast bool + // the timeout + timeout time.Duration // the link on which this message was received link string // the error response @@ -43,7 +54,7 @@ type message struct { // type of message typ string // tunnel id - id string + tunnel string // channel name channel string // the session id @@ -52,6 +63,10 @@ type message struct { outbound bool // loopback marks the message intended for loopback loopback bool + // whether to send as multicast + multicast bool + // broadcast sets the broadcast type + broadcast bool // the link to send the message on link string // transport data @@ -76,10 +91,111 @@ func (s *session) Channel() string { return s.channel } +// newMessage creates a new message based on the session +func (s *session) newMessage(typ string) *message { + return &message{ + typ: typ, + tunnel: s.tunnel, + channel: s.channel, + session: s.session, + outbound: s.outbound, + loopback: s.loopback, + multicast: s.multicast, + link: s.link, + errChan: s.errChan, + } +} + +// Open will fire the open message for the session. This is called by the dialler. +func (s *session) Open() error { + // create a new message + msg := s.newMessage("open") + + // send open message + s.send <- msg + + // wait for an error response for send + select { + case err := <-msg.errChan: + if err != nil { + return err + } + case <-s.closed: + return io.EOF + } + + // we don't wait on multicast + if s.multicast { + s.accepted = true + return nil + } + + // now wait for the accept + select { + case msg = <-s.recv: + if msg.typ != "accept" { + log.Debugf("Received non accept message in Open %s", msg.typ) + return errors.New("failed to connect") + } + // set to accepted + s.accepted = true + // set link + s.link = msg.link + case <-time.After(s.timeout): + return ErrDialTimeout + case <-s.closed: + return io.EOF + } + + return nil +} + +// Accept sends the accept response to an open message from a dialled connection +func (s *session) Accept() error { + msg := s.newMessage("accept") + + // send the accept message + select { + case <-s.closed: + return io.EOF + case s.send <- msg: + return nil + } + + // wait for send response + select { + case err := <-s.errChan: + if err != nil { + return err + } + case <-s.closed: + return io.EOF + } + + return nil +} + +// Announce sends an announcement to notify that this session exists. This is primarily used by the listener. +func (s *session) Announce() error { + msg := s.newMessage("announce") + // we don't need an error back + msg.errChan = nil + // we don't need the link + msg.link = "" + + select { + case s.send <- msg: + return nil + case <-s.closed: + return io.EOF + } +} + +// Send is used to send a message func (s *session) Send(m *transport.Message) error { select { case <-s.closed: - return errors.New("session is closed") + return io.EOF default: // no op } @@ -94,22 +210,18 @@ func (s *session) Send(m *transport.Message) error { data.Header[k] = v } - // append to backlog - msg := &message{ - typ: "message", - id: s.id, - channel: s.channel, - session: s.session, - outbound: s.outbound, - loopback: s.loopback, - data: data, - // specify the link on which to send this - // it will be blank for dialled sessions - link: s.link, - // error chan - errChan: s.errChan, + // create a new message + msg := s.newMessage("session") + // set the data + msg.data = data + + // if multicast don't set the link + if s.multicast { + msg.link = "" } + log.Debugf("Appending %+v to send backlog", msg) + // send the actual message s.send <- msg // wait for an error response @@ -123,6 +235,7 @@ func (s *session) Send(m *transport.Message) error { return nil } +// Recv is used to receive a message func (s *session) Recv(m *transport.Message) error { select { case <-s.closed: @@ -147,13 +260,25 @@ func (s *session) Recv(m *transport.Message) error { return nil } -// Close closes the session +// Close closes the session by sending a close message func (s *session) Close() error { select { case <-s.closed: // no op default: close(s.closed) + + // append to backlog + msg := s.newMessage("close") + // no error response on close + msg.errChan = nil + + // send the close message + select { + case s.send <- msg: + default: + } } + return nil } diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 58c0ac27..7c9e2afc 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -2,9 +2,19 @@ package tunnel import ( + "errors" + "time" + "github.com/micro/go-micro/transport" ) +var ( + // ErrDialTimeout is returned by a call to Dial where the timeout occurs + ErrDialTimeout = errors.New("dial timeout") + // DefaultDialTimeout is the dial timeout if none is specified + DefaultDialTimeout = time.Second * 5 +) + // Tunnel creates a gre tunnel on top of the go-micro/transport. // It establishes multiple streams using the Micro-Tunnel-Channel header // and Micro-Tunnel-Session header. The tunnel id is a hash of @@ -18,13 +28,23 @@ type Tunnel interface { // Close closes the tunnel Close() error // Connect to a channel - Dial(channel string) (Session, error) + Dial(channel string, opts ...DialOption) (Session, error) // Accept connections on a channel Listen(channel string) (Listener, error) + // All the links the tunnel is connected to + Links() []Link // Name of the tunnel implementation String() string } +// Link represents internal links to the tunnel +type Link interface { + // The id of the link + Id() string + // honours transport socket + transport.Socket +} + // The listener provides similar constructs to the transport.Listener type Listener interface { Accept() (Session, error) diff --git a/tunnel/tunnel_test.go b/tunnel/tunnel_test.go index 3fc84b6d..fc76421b 100644 --- a/tunnel/tunnel_test.go +++ b/tunnel/tunnel_test.go @@ -187,30 +187,15 @@ func testBrokenTunAccept(t *testing.T, tun Tunnel, wait chan bool, wg *sync.Wait if err := c.Recv(m); err != nil { t.Fatal(err) } - tun.Close() - // re-start tunnel - err = tun.Connect() - if err != nil { - t.Fatal(err) - } - defer tun.Close() - - // listen on some virtual address - tl, err = tun.Listen("test-tunnel") - if err != nil { - t.Fatal(err) + // close all the links + for _, link := range tun.Links() { + link.Close() } // receiver ready; notify sender wait <- true - // accept a connection - c, err = tl.Accept() - if err != nil { - t.Fatal(err) - } - // accept the message m = new(transport.Message) if err := c.Recv(m); err != nil { @@ -279,6 +264,7 @@ func TestReconnectTunnel(t *testing.T) { if err != nil { t.Fatal(err) } + defer tunB.Close() // we manually override the tunnel.ReconnectTime value here // this is so that we make the reconnects faster than the default 5s @@ -289,6 +275,7 @@ func TestReconnectTunnel(t *testing.T) { if err != nil { t.Fatal(err) } + defer tunA.Close() wait := make(chan bool)