From 5db7514a914cd3c01e18fbae833e8f4428065818 Mon Sep 17 00:00:00 2001 From: Milos Gajdos Date: Thu, 5 Dec 2019 15:50:32 +0000 Subject: [PATCH] This PR fixes various tunnel race conditions --- tunnel/default.go | 20 ++++++------ tunnel/link.go | 6 ++++ tunnel/tunnel.go | 13 ++++---- tunnel/tunnel_reconnect_test.go | 55 +++++++++++++++++++++++++++++++++ tunnel/tunnel_test.go | 46 --------------------------- 5 files changed, 78 insertions(+), 62 deletions(-) create mode 100644 tunnel/tunnel_reconnect_test.go diff --git a/tunnel/default.go b/tunnel/default.go index da96540c..1cb0dbfc 100644 --- a/tunnel/default.go +++ b/tunnel/default.go @@ -854,16 +854,6 @@ func (t *tun) connect() error { } }() - // setup links - t.setupLinks() - - // process outbound messages to be sent - // process sends to all links - go t.process() - - // monitor links - go t.monitor() - return nil } @@ -889,6 +879,16 @@ func (t *tun) Connect() error { // create new close channel t.closed = make(chan bool) + // setup links + t.setupLinks() + + // process outbound messages to be sent + // process sends to all links + go t.process() + + // monitor links + go t.monitor() + return nil } diff --git a/tunnel/link.go b/tunnel/link.go index 25083138..042f1cb9 100644 --- a/tunnel/link.go +++ b/tunnel/link.go @@ -229,7 +229,10 @@ func (l *link) manage() { // check the type of message switch { case bytes.Equal(p.message.Body, linkRequest): + l.RLock() log.Tracef("Link %s received link request %v", l.id, p.message.Body) + l.RUnlock() + // send response if err := send(linkResponse); err != nil { l.Lock() @@ -239,7 +242,10 @@ func (l *link) manage() { case bytes.Equal(p.message.Body, linkResponse): // set round trip time d := time.Since(now) + l.RLock() log.Tracef("Link %s received link response in %v", p.message.Body, d) + l.RUnlock() + // set the RTT l.setRTT(d) } case <-t.C: diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 212e07e4..56928d8d 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -36,26 +36,27 @@ type Mode uint8 // and Micro-Tunnel-Session header. The tunnel id is a hash of // the address being requested. type Tunnel interface { + // Init initializes tunnel with options Init(opts ...Option) error - // Address the tunnel is listening on + // Address returns the address the tunnel is listening on Address() string // Connect connects the tunnel Connect() error // Close closes the tunnel Close() error - // All the links the tunnel is connected to + // Links returns all the links the tunnel is connected to Links() []Link - // Connect to a channel + // Dial allows a client to connect to a channel Dial(channel string, opts ...DialOption) (Session, error) - // Accept connections on a channel + // Listen allows to accept connections on a channel Listen(channel string, opts ...ListenOption) (Listener, error) - // Name of the tunnel implementation + // String returns the name of the tunnel implementation String() string } // Link represents internal links to the tunnel type Link interface { - // The id of the link + // Id returns the link unique Id Id() string // Delay is the current load on the link (lower is better) Delay() int64 diff --git a/tunnel/tunnel_reconnect_test.go b/tunnel/tunnel_reconnect_test.go new file mode 100644 index 00000000..2c78b8b2 --- /dev/null +++ b/tunnel/tunnel_reconnect_test.go @@ -0,0 +1,55 @@ +// +build !race + +package tunnel + +import ( + "sync" + "testing" + "time" +) + +func TestReconnectTunnel(t *testing.T) { + // create a new tunnel client + tunA := NewTunnel( + Address("127.0.0.1:9096"), + Nodes("127.0.0.1:9097"), + ) + + // create a new tunnel server + tunB := NewTunnel( + Address("127.0.0.1:9097"), + ) + + // start tunnel + err := tunB.Connect() + if err != nil { + t.Fatal(err) + } + defer tunB.Close() + + // we manually override the tunnel.ReconnectTime value here + // this is so that we make the reconnects faster than the default 5s + ReconnectTime = 200 * time.Millisecond + + // start tunnel + err = tunA.Connect() + if err != nil { + t.Fatal(err) + } + defer tunA.Close() + + wait := make(chan bool) + + var wg sync.WaitGroup + + wg.Add(1) + // start tunnel listener + go testBrokenTunAccept(t, tunB, wait, &wg) + + wg.Add(1) + // start tunnel sender + go testBrokenTunSend(t, tunA, wait, &wg) + + // wait until done + wg.Wait() +} diff --git a/tunnel/tunnel_test.go b/tunnel/tunnel_test.go index e884e692..8c3119da 100644 --- a/tunnel/tunnel_test.go +++ b/tunnel/tunnel_test.go @@ -247,52 +247,6 @@ func testBrokenTunSend(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGr <-wait } -func TestReconnectTunnel(t *testing.T) { - // create a new tunnel client - tunA := NewTunnel( - Address("127.0.0.1:9096"), - Nodes("127.0.0.1:9097"), - ) - - // create a new tunnel server - tunB := NewTunnel( - Address("127.0.0.1:9097"), - ) - - // start tunnel - err := tunB.Connect() - if err != nil { - t.Fatal(err) - } - defer tunB.Close() - - // we manually override the tunnel.ReconnectTime value here - // this is so that we make the reconnects faster than the default 5s - ReconnectTime = 200 * time.Millisecond - - // start tunnel - err = tunA.Connect() - if err != nil { - t.Fatal(err) - } - defer tunA.Close() - - wait := make(chan bool) - - var wg sync.WaitGroup - - wg.Add(1) - // start tunnel listener - go testBrokenTunAccept(t, tunB, wait, &wg) - - wg.Add(1) - // start tunnel sender - go testBrokenTunSend(t, tunA, wait, &wg) - - // wait until done - wg.Wait() -} - func TestTunnelRTTRate(t *testing.T) { // create a new tunnel client tunA := NewTunnel(