Merge pull request #1022 from milosgajdos83/tunnel-races

This PR fixes various tunnel race conditions
This commit is contained in:
Asim Aslam 2019-12-05 15:59:29 +00:00 committed by GitHub
commit 3a10b1cdde
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 78 additions and 62 deletions

View File

@ -854,16 +854,6 @@ func (t *tun) connect() error {
} }
}() }()
// setup links
t.setupLinks()
// process outbound messages to be sent
// process sends to all links
go t.process()
// monitor links
go t.monitor()
return nil return nil
} }
@ -889,6 +879,16 @@ func (t *tun) Connect() error {
// create new close channel // create new close channel
t.closed = make(chan bool) t.closed = make(chan bool)
// setup links
t.setupLinks()
// process outbound messages to be sent
// process sends to all links
go t.process()
// monitor links
go t.monitor()
return nil return nil
} }

View File

@ -229,7 +229,10 @@ func (l *link) manage() {
// check the type of message // check the type of message
switch { switch {
case bytes.Equal(p.message.Body, linkRequest): case bytes.Equal(p.message.Body, linkRequest):
l.RLock()
log.Tracef("Link %s received link request %v", l.id, p.message.Body) log.Tracef("Link %s received link request %v", l.id, p.message.Body)
l.RUnlock()
// send response // send response
if err := send(linkResponse); err != nil { if err := send(linkResponse); err != nil {
l.Lock() l.Lock()
@ -239,7 +242,10 @@ func (l *link) manage() {
case bytes.Equal(p.message.Body, linkResponse): case bytes.Equal(p.message.Body, linkResponse):
// set round trip time // set round trip time
d := time.Since(now) d := time.Since(now)
l.RLock()
log.Tracef("Link %s received link response in %v", p.message.Body, d) log.Tracef("Link %s received link response in %v", p.message.Body, d)
l.RUnlock()
// set the RTT
l.setRTT(d) l.setRTT(d)
} }
case <-t.C: case <-t.C:

View File

@ -36,26 +36,27 @@ type Mode uint8
// and Micro-Tunnel-Session header. The tunnel id is a hash of // and Micro-Tunnel-Session header. The tunnel id is a hash of
// the address being requested. // the address being requested.
type Tunnel interface { type Tunnel interface {
// Init initializes tunnel with options
Init(opts ...Option) error Init(opts ...Option) error
// Address the tunnel is listening on // Address returns the address the tunnel is listening on
Address() string Address() string
// Connect connects the tunnel // Connect connects the tunnel
Connect() error Connect() error
// Close closes the tunnel // Close closes the tunnel
Close() error Close() error
// All the links the tunnel is connected to // Links returns all the links the tunnel is connected to
Links() []Link Links() []Link
// Connect to a channel // Dial allows a client to connect to a channel
Dial(channel string, opts ...DialOption) (Session, error) Dial(channel string, opts ...DialOption) (Session, error)
// Accept connections on a channel // Listen allows to accept connections on a channel
Listen(channel string, opts ...ListenOption) (Listener, error) Listen(channel string, opts ...ListenOption) (Listener, error)
// Name of the tunnel implementation // String returns the name of the tunnel implementation
String() string String() string
} }
// Link represents internal links to the tunnel // Link represents internal links to the tunnel
type Link interface { type Link interface {
// The id of the link // Id returns the link unique Id
Id() string Id() string
// Delay is the current load on the link (lower is better) // Delay is the current load on the link (lower is better)
Delay() int64 Delay() int64

View File

@ -0,0 +1,55 @@
// +build !race
package tunnel
import (
"sync"
"testing"
"time"
)
func TestReconnectTunnel(t *testing.T) {
// create a new tunnel client
tunA := NewTunnel(
Address("127.0.0.1:9096"),
Nodes("127.0.0.1:9097"),
)
// create a new tunnel server
tunB := NewTunnel(
Address("127.0.0.1:9097"),
)
// start tunnel
err := tunB.Connect()
if err != nil {
t.Fatal(err)
}
defer tunB.Close()
// we manually override the tunnel.ReconnectTime value here
// this is so that we make the reconnects faster than the default 5s
ReconnectTime = 200 * time.Millisecond
// start tunnel
err = tunA.Connect()
if err != nil {
t.Fatal(err)
}
defer tunA.Close()
wait := make(chan bool)
var wg sync.WaitGroup
wg.Add(1)
// start tunnel listener
go testBrokenTunAccept(t, tunB, wait, &wg)
wg.Add(1)
// start tunnel sender
go testBrokenTunSend(t, tunA, wait, &wg)
// wait until done
wg.Wait()
}

View File

@ -247,52 +247,6 @@ func testBrokenTunSend(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGr
<-wait <-wait
} }
func TestReconnectTunnel(t *testing.T) {
// create a new tunnel client
tunA := NewTunnel(
Address("127.0.0.1:9096"),
Nodes("127.0.0.1:9097"),
)
// create a new tunnel server
tunB := NewTunnel(
Address("127.0.0.1:9097"),
)
// start tunnel
err := tunB.Connect()
if err != nil {
t.Fatal(err)
}
defer tunB.Close()
// we manually override the tunnel.ReconnectTime value here
// this is so that we make the reconnects faster than the default 5s
ReconnectTime = 200 * time.Millisecond
// start tunnel
err = tunA.Connect()
if err != nil {
t.Fatal(err)
}
defer tunA.Close()
wait := make(chan bool)
var wg sync.WaitGroup
wg.Add(1)
// start tunnel listener
go testBrokenTunAccept(t, tunB, wait, &wg)
wg.Add(1)
// start tunnel sender
go testBrokenTunSend(t, tunA, wait, &wg)
// wait until done
wg.Wait()
}
func TestTunnelRTTRate(t *testing.T) { func TestTunnelRTTRate(t *testing.T) {
// create a new tunnel client // create a new tunnel client
tunA := NewTunnel( tunA := NewTunnel(