diff --git a/tunnel/default.go b/tunnel/default.go index 03d7c503..c819c7a7 100644 --- a/tunnel/default.go +++ b/tunnel/default.go @@ -5,12 +5,20 @@ import ( "errors" "fmt" "sync" + "time" "github.com/google/uuid" "github.com/micro/go-micro/transport" "github.com/micro/go-micro/util/log" ) +var ( + // KeepAliveTime defines time interval we send keepalive messages to outbound links + KeepAliveTime = 30 * time.Second + // ReconnectTime defines time interval we periodically attempt to reconnect dead links + ReconnectTime = 5 * time.Second +) + // tun represents a network tunnel type tun struct { options Options @@ -41,8 +49,9 @@ type tun struct { type link struct { transport.Socket - id string - loopback bool + id string + loopback bool + lastKeepAlive time.Time } // create new tunnel on top of a link @@ -118,6 +127,33 @@ func (t *tun) newSession() string { return uuid.New().String() } +// monitor monitors outbound links and attempts to reconnect to the failed ones +func (t *tun) monitor() { + reconnect := time.NewTicker(ReconnectTime) + defer reconnect.Stop() + + for { + select { + case <-t.closed: + return + case <-reconnect.C: + for _, node := range t.options.Nodes { + t.Lock() + if _, ok := t.links[node]; !ok { + link, err := t.setupLink(node) + if err != nil { + log.Debugf("Tunnel failed to setup node link to %s: %v", node, err) + t.Unlock() + continue + } + t.links[node] = link + } + t.Unlock() + } + } + } +} + // process outgoing messages sent by all local sockets func (t *tun) process() { // manage the send buffer @@ -144,19 +180,22 @@ func (t *tun) process() { newMsg.Header["Micro-Tunnel-Token"] = t.token // send the message via the interface - t.RLock() + t.Lock() if len(t.links) == 0 { - log.Debugf("Zero links to send to") + log.Debugf("No links to send to") } - for _, link := range t.links { - // TODO: error check and reconnect - log.Debugf("Sending %+v to %s", newMsg, link.Remote()) + for node, link := range t.links { if link.loopback && msg.outbound { continue } - link.Send(newMsg) + log.Debugf("Sending %+v to %s", newMsg, node) + if err := link.Send(newMsg); err != nil { + log.Debugf("Tunnel error sending %+v to %s: %v", newMsg, node, err) + delete(t.links, node) + continue + } } - t.RUnlock() + t.Unlock() case <-t.closed: return } @@ -165,20 +204,21 @@ func (t *tun) process() { // process incoming messages func (t *tun) listen(link *link) { - // loopback flag - var loopback bool - for { // process anything via the net interface msg := new(transport.Message) err := link.Recv(msg) if err != nil { - log.Debugf("Tunnel link %s receive error: %v", link.Remote(), err) + log.Debugf("Tunnel link %s receive error: %#v", link.Remote(), err) + t.Lock() + delete(t.links, link.Remote()) + t.Unlock() return } switch msg.Header["Micro-Tunnel"] { case "connect": + log.Debugf("Tunnel link %s received connect message", link.Remote()) // check the Micro-Tunnel-Token token, ok := msg.Header["Micro-Tunnel-Token"] if !ok { @@ -187,14 +227,18 @@ func (t *tun) listen(link *link) { // are we connecting to ourselves? if token == t.token { - loopback = true link.loopback = true } continue case "close": + log.Debugf("Tunnel link %s closing connection", link.Remote()) // TODO: handle the close message // maybe report io.EOF or kill the link continue + case "keepalive": + log.Debugf("Tunnel link %s received keepalive", link.Remote()) + link.lastKeepAlive = time.Now() + continue } // the tunnel id @@ -219,7 +263,7 @@ func (t *tun) listen(link *link) { log.Debugf("Received %+v from %s", msg, link.Remote()) switch { - case loopback: + case link.loopback: s, exists = t.getSocket(id, "listener") default: // get the socket based on the tunnel id and session @@ -286,6 +330,71 @@ func (t *tun) listen(link *link) { } } +// keepalive periodically sends keepalive messages to link +func (t *tun) keepalive(link *link) { + keepalive := time.NewTicker(KeepAliveTime) + defer keepalive.Stop() + + for { + select { + case <-t.closed: + return + case <-keepalive.C: + // send keepalive message + log.Debugf("Tunnel sending keepalive to link: %v", link.Remote()) + if err := link.Send(&transport.Message{ + Header: map[string]string{ + "Micro-Tunnel": "keepalive", + "Micro-Tunnel-Token": t.token, + }, + }); err != nil { + log.Debugf("Error sending keepalive to link %v: %v", link.Remote(), err) + t.Lock() + delete(t.links, link.Remote()) + t.Unlock() + return + } + } + } +} + +// setupLink connects to node and returns link if successful +// It returns error if the link failed to be established +func (t *tun) setupLink(node string) (*link, error) { + log.Debugf("Tunnel dialing %s", node) + c, err := t.options.Transport.Dial(node) + if err != nil { + log.Debugf("Tunnel failed to connect to %s: %v", node, err) + return nil, err + } + log.Debugf("Tunnel connected to %s", node) + + if err := c.Send(&transport.Message{ + Header: map[string]string{ + "Micro-Tunnel": "connect", + "Micro-Tunnel-Token": t.token, + }, + }); err != nil { + return nil, err + } + + // save the link + id := uuid.New().String() + link := &link{ + Socket: c, + id: id, + } + t.links[node] = link + + // process incoming messages + go t.listen(link) + + // start keepalive monitor + go t.keepalive(link) + + return link, nil +} + // connect the tunnel to all the nodes and listen for incoming tunnel connections func (t *tun) connect() error { l, err := t.options.Transport.Listen(t.options.Address) @@ -307,14 +416,14 @@ func (t *tun) connect() error { Socket: sock, id: id, } - t.links[id] = link + t.links[sock.Remote()] = link t.Unlock() // delete the link defer func() { - log.Debugf("Deleting connection from %s", sock.Remote()) + log.Debugf("Tunnel deleting connection from %s", sock.Remote()) t.Lock() - delete(t.links, id) + delete(t.links, sock.Remote()) t.Unlock() }() @@ -337,40 +446,23 @@ func (t *tun) connect() error { continue } - log.Debugf("Tunnel dialing %s", node) - // TODO: reconnection logic is required to keep the tunnel established - c, err := t.options.Transport.Dial(node) + // connect to node and return link + link, err := t.setupLink(node) if err != nil { - log.Debugf("Tunnel failed to connect to %s: %v", node, err) + log.Debugf("Tunnel failed to establish node link to %s: %v", node, err) continue } - log.Debugf("Tunnel connected to %s", node) - - if err := c.Send(&transport.Message{ - Header: map[string]string{ - "Micro-Tunnel": "connect", - "Micro-Tunnel-Token": t.token, - }, - }); err != nil { - continue - } - // save the link - id := uuid.New().String() - link := &link{ - Socket: c, - id: id, - } - t.links[id] = link - - // process incoming messages - go t.listen(link) + t.links[node] = link } // process outbound messages to be sent // process sends to all links go t.process() + // monitor links + go t.monitor() + return nil } @@ -399,7 +491,7 @@ func (t *tun) Connect() error { func (t *tun) close() error { // close all the links - for id, link := range t.links { + for node, link := range t.links { link.Send(&transport.Message{ Header: map[string]string{ "Micro-Tunnel": "close", @@ -407,7 +499,7 @@ func (t *tun) close() error { }, }) link.Close() - delete(t.links, id) + delete(t.links, node) } // close the listener @@ -428,8 +520,9 @@ func (t *tun) Close() error { return nil default: // close all the sockets - for _, s := range t.sockets { + for id, s := range t.sockets { s.Close() + delete(t.sockets, id) } // close the connection close(t.closed) diff --git a/tunnel/tunnel_test.go b/tunnel/tunnel_test.go index 79db046a..5b6cfbc0 100644 --- a/tunnel/tunnel_test.go +++ b/tunnel/tunnel_test.go @@ -9,13 +9,16 @@ import ( ) // testAccept will accept connections on the transport, create a new link and tunnel on top -func testAccept(t *testing.T, tun Tunnel, wg *sync.WaitGroup) { +func testAccept(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup) { // listen on some virtual address tl, err := tun.Listen("test-tunnel") if err != nil { t.Fatal(err) } + // receiver ready; notify sender + wait <- true + // accept a connection c, err := tl.Accept() if err != nil { @@ -46,7 +49,12 @@ func testAccept(t *testing.T, tun Tunnel, wg *sync.WaitGroup) { } // testSend will create a new link to an address and then a tunnel on top -func testSend(t *testing.T, tun Tunnel) { +func testSend(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup) { + defer wg.Done() + + // wait for the listener to get ready + <-wait + // dial a new session c, err := tun.Dial("test-tunnel") if err != nil { @@ -95,8 +103,6 @@ func TestTunnel(t *testing.T) { } defer tunB.Close() - time.Sleep(time.Millisecond * 50) - // start tunA err = tunA.Connect() if err != nil { @@ -104,51 +110,190 @@ func TestTunnel(t *testing.T) { } defer tunA.Close() - time.Sleep(time.Millisecond * 50) + wait := make(chan bool) var wg sync.WaitGroup - // start accepting connections - // on tunnel A wg.Add(1) - go testAccept(t, tunA, &wg) + // start the listener + go testAccept(t, tunB, wait, &wg) - time.Sleep(time.Millisecond * 50) - - // dial and send via B - testSend(t, tunB) + wg.Add(1) + // start the client + go testSend(t, tunA, wait, &wg) // wait until done wg.Wait() } func TestLoopbackTunnel(t *testing.T) { - // create a new tunnel client + // create a new tunnel tun := NewTunnel( Address("127.0.0.1:9096"), Nodes("127.0.0.1:9096"), ) - // start tunB + // start tunnel err := tun.Connect() if err != nil { t.Fatal(err) } defer tun.Close() - time.Sleep(time.Millisecond * 50) + wait := make(chan bool) var wg sync.WaitGroup - // start accepting connections - // on tunnel A wg.Add(1) - go testAccept(t, tun, &wg) + // start the listener + go testAccept(t, tun, wait, &wg) - time.Sleep(time.Millisecond * 50) - - // dial and send via B - testSend(t, tun) + wg.Add(1) + // start the client + go testSend(t, tun, wait, &wg) + + // wait until done + wg.Wait() +} + +func testBrokenTunAccept(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup) { + defer wg.Done() + + // listen on some virtual address + tl, err := tun.Listen("test-tunnel") + if err != nil { + t.Fatal(err) + } + + // receiver ready; notify sender + wait <- true + + // accept a connection + c, err := tl.Accept() + if err != nil { + t.Fatal(err) + } + + // accept the message and close the tunnel + // we do this to simulate loss of network connection + m := new(transport.Message) + if err := c.Recv(m); err != nil { + t.Fatal(err) + } + tun.Close() + + // re-start tunnel + err = tun.Connect() + if err != nil { + t.Fatal(err) + } + defer tun.Close() + + // listen on some virtual address + tl, err = tun.Listen("test-tunnel") + if err != nil { + t.Fatal(err) + } + + // receiver ready; notify sender + wait <- true + + // accept a connection + c, err = tl.Accept() + if err != nil { + t.Fatal(err) + } + + // accept the message + m = new(transport.Message) + if err := c.Recv(m); err != nil { + t.Fatal(err) + } + + // notify sender we have received the message + <-wait +} + +func testBrokenTunSend(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup) { + defer wg.Done() + + // wait for the listener to get ready + <-wait + + // dial a new session + c, err := tun.Dial("test-tunnel") + if err != nil { + t.Fatal(err) + } + defer c.Close() + + m := transport.Message{ + Header: map[string]string{ + "test": "send", + }, + } + + // send the message + if err := c.Send(&m); err != nil { + t.Fatal(err) + } + + // wait for the listener to get ready + <-wait + + // give it time to reconnect + time.Sleep(2 * ReconnectTime) + + // send the message + if err := c.Send(&m); err != nil { + t.Fatal(err) + } + + // wait for the listener to receive the message + // c.Send merely enqueues the message to the link send queue and returns + // in order to verify it was received we wait for the listener to tell us + wait <- true +} + +func TestReconnectTunnel(t *testing.T) { + // create a new tunnel client + tunA := NewTunnel( + Address("127.0.0.1:9096"), + Nodes("127.0.0.1:9097"), + ) + + // create a new tunnel server + tunB := NewTunnel( + Address("127.0.0.1:9097"), + ) + + // start tunnel + err := tunB.Connect() + if err != nil { + t.Fatal(err) + } + + // we manually override the tunnel.ReconnectTime value here + // this is so that we make the reconnects faster than the default 5s + ReconnectTime = 200 * time.Millisecond + + // start tunnel + err = tunA.Connect() + if err != nil { + t.Fatal(err) + } + + wait := make(chan bool) + + var wg sync.WaitGroup + + wg.Add(1) + // start tunnel listener + go testBrokenTunAccept(t, tunB, wait, &wg) + + wg.Add(1) + // start tunnel sender + go testBrokenTunSend(t, tunA, wait, &wg) // wait until done wg.Wait()