This PR fixes various tunnel race conditions
This commit is contained in:
		| @@ -854,16 +854,6 @@ func (t *tun) connect() error { | |||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
| 	// setup links |  | ||||||
| 	t.setupLinks() |  | ||||||
|  |  | ||||||
| 	// process outbound messages to be sent |  | ||||||
| 	// process sends to all links |  | ||||||
| 	go t.process() |  | ||||||
|  |  | ||||||
| 	// monitor links |  | ||||||
| 	go t.monitor() |  | ||||||
|  |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -889,6 +879,16 @@ func (t *tun) Connect() error { | |||||||
| 	// create new close channel | 	// create new close channel | ||||||
| 	t.closed = make(chan bool) | 	t.closed = make(chan bool) | ||||||
|  |  | ||||||
|  | 	// setup links | ||||||
|  | 	t.setupLinks() | ||||||
|  |  | ||||||
|  | 	// process outbound messages to be sent | ||||||
|  | 	// process sends to all links | ||||||
|  | 	go t.process() | ||||||
|  |  | ||||||
|  | 	// monitor links | ||||||
|  | 	go t.monitor() | ||||||
|  |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -229,7 +229,10 @@ func (l *link) manage() { | |||||||
| 			// check the type of message | 			// check the type of message | ||||||
| 			switch { | 			switch { | ||||||
| 			case bytes.Equal(p.message.Body, linkRequest): | 			case bytes.Equal(p.message.Body, linkRequest): | ||||||
|  | 				l.RLock() | ||||||
| 				log.Tracef("Link %s received link request %v", l.id, p.message.Body) | 				log.Tracef("Link %s received link request %v", l.id, p.message.Body) | ||||||
|  | 				l.RUnlock() | ||||||
|  |  | ||||||
| 				// send response | 				// send response | ||||||
| 				if err := send(linkResponse); err != nil { | 				if err := send(linkResponse); err != nil { | ||||||
| 					l.Lock() | 					l.Lock() | ||||||
| @@ -239,7 +242,10 @@ func (l *link) manage() { | |||||||
| 			case bytes.Equal(p.message.Body, linkResponse): | 			case bytes.Equal(p.message.Body, linkResponse): | ||||||
| 				// set round trip time | 				// set round trip time | ||||||
| 				d := time.Since(now) | 				d := time.Since(now) | ||||||
|  | 				l.RLock() | ||||||
| 				log.Tracef("Link %s received link response in %v", p.message.Body, d) | 				log.Tracef("Link %s received link response in %v", p.message.Body, d) | ||||||
|  | 				l.RUnlock() | ||||||
|  | 				// set the RTT | ||||||
| 				l.setRTT(d) | 				l.setRTT(d) | ||||||
| 			} | 			} | ||||||
| 		case <-t.C: | 		case <-t.C: | ||||||
|   | |||||||
| @@ -36,26 +36,27 @@ type Mode uint8 | |||||||
| // and Micro-Tunnel-Session header. The tunnel id is a hash of | // and Micro-Tunnel-Session header. The tunnel id is a hash of | ||||||
| // the address being requested. | // the address being requested. | ||||||
| type Tunnel interface { | type Tunnel interface { | ||||||
|  | 	// Init initializes tunnel with options | ||||||
| 	Init(opts ...Option) error | 	Init(opts ...Option) error | ||||||
| 	// Address the tunnel is listening on | 	// Address returns the address the tunnel is listening on | ||||||
| 	Address() string | 	Address() string | ||||||
| 	// Connect connects the tunnel | 	// Connect connects the tunnel | ||||||
| 	Connect() error | 	Connect() error | ||||||
| 	// Close closes the tunnel | 	// Close closes the tunnel | ||||||
| 	Close() error | 	Close() error | ||||||
| 	// All the links the tunnel is connected to | 	// Links returns all the links the tunnel is connected to | ||||||
| 	Links() []Link | 	Links() []Link | ||||||
| 	// Connect to a channel | 	// Dial allows a client to connect to a channel | ||||||
| 	Dial(channel string, opts ...DialOption) (Session, error) | 	Dial(channel string, opts ...DialOption) (Session, error) | ||||||
| 	// Accept connections on a channel | 	// Listen allows to accept connections on a channel | ||||||
| 	Listen(channel string, opts ...ListenOption) (Listener, error) | 	Listen(channel string, opts ...ListenOption) (Listener, error) | ||||||
| 	// Name of the tunnel implementation | 	// String returns the name of the tunnel implementation | ||||||
| 	String() string | 	String() string | ||||||
| } | } | ||||||
|  |  | ||||||
| // Link represents internal links to the tunnel | // Link represents internal links to the tunnel | ||||||
| type Link interface { | type Link interface { | ||||||
| 	// The id of the link | 	// Id returns the link unique Id | ||||||
| 	Id() string | 	Id() string | ||||||
| 	// Delay is the current load on the link (lower is better) | 	// Delay is the current load on the link (lower is better) | ||||||
| 	Delay() int64 | 	Delay() int64 | ||||||
|   | |||||||
							
								
								
									
										55
									
								
								tunnel/tunnel_reconnect_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								tunnel/tunnel_reconnect_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,55 @@ | |||||||
|  | // +build !race | ||||||
|  |  | ||||||
|  | package tunnel | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"sync" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestReconnectTunnel(t *testing.T) { | ||||||
|  | 	// create a new tunnel client | ||||||
|  | 	tunA := NewTunnel( | ||||||
|  | 		Address("127.0.0.1:9096"), | ||||||
|  | 		Nodes("127.0.0.1:9097"), | ||||||
|  | 	) | ||||||
|  |  | ||||||
|  | 	// create a new tunnel server | ||||||
|  | 	tunB := NewTunnel( | ||||||
|  | 		Address("127.0.0.1:9097"), | ||||||
|  | 	) | ||||||
|  |  | ||||||
|  | 	// start tunnel | ||||||
|  | 	err := tunB.Connect() | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer tunB.Close() | ||||||
|  |  | ||||||
|  | 	// we manually override the tunnel.ReconnectTime value here | ||||||
|  | 	// this is so that we make the reconnects faster than the default 5s | ||||||
|  | 	ReconnectTime = 200 * time.Millisecond | ||||||
|  |  | ||||||
|  | 	// start tunnel | ||||||
|  | 	err = tunA.Connect() | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer tunA.Close() | ||||||
|  |  | ||||||
|  | 	wait := make(chan bool) | ||||||
|  |  | ||||||
|  | 	var wg sync.WaitGroup | ||||||
|  |  | ||||||
|  | 	wg.Add(1) | ||||||
|  | 	// start tunnel listener | ||||||
|  | 	go testBrokenTunAccept(t, tunB, wait, &wg) | ||||||
|  |  | ||||||
|  | 	wg.Add(1) | ||||||
|  | 	// start tunnel sender | ||||||
|  | 	go testBrokenTunSend(t, tunA, wait, &wg) | ||||||
|  |  | ||||||
|  | 	// wait until done | ||||||
|  | 	wg.Wait() | ||||||
|  | } | ||||||
| @@ -247,52 +247,6 @@ func testBrokenTunSend(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGr | |||||||
| 	<-wait | 	<-wait | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestReconnectTunnel(t *testing.T) { |  | ||||||
| 	// create a new tunnel client |  | ||||||
| 	tunA := NewTunnel( |  | ||||||
| 		Address("127.0.0.1:9096"), |  | ||||||
| 		Nodes("127.0.0.1:9097"), |  | ||||||
| 	) |  | ||||||
|  |  | ||||||
| 	// create a new tunnel server |  | ||||||
| 	tunB := NewTunnel( |  | ||||||
| 		Address("127.0.0.1:9097"), |  | ||||||
| 	) |  | ||||||
|  |  | ||||||
| 	// start tunnel |  | ||||||
| 	err := tunB.Connect() |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	defer tunB.Close() |  | ||||||
|  |  | ||||||
| 	// we manually override the tunnel.ReconnectTime value here |  | ||||||
| 	// this is so that we make the reconnects faster than the default 5s |  | ||||||
| 	ReconnectTime = 200 * time.Millisecond |  | ||||||
|  |  | ||||||
| 	// start tunnel |  | ||||||
| 	err = tunA.Connect() |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	defer tunA.Close() |  | ||||||
|  |  | ||||||
| 	wait := make(chan bool) |  | ||||||
|  |  | ||||||
| 	var wg sync.WaitGroup |  | ||||||
|  |  | ||||||
| 	wg.Add(1) |  | ||||||
| 	// start tunnel listener |  | ||||||
| 	go testBrokenTunAccept(t, tunB, wait, &wg) |  | ||||||
|  |  | ||||||
| 	wg.Add(1) |  | ||||||
| 	// start tunnel sender |  | ||||||
| 	go testBrokenTunSend(t, tunA, wait, &wg) |  | ||||||
|  |  | ||||||
| 	// wait until done |  | ||||||
| 	wg.Wait() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestTunnelRTTRate(t *testing.T) { | func TestTunnelRTTRate(t *testing.T) { | ||||||
| 	// create a new tunnel client | 	// create a new tunnel client | ||||||
| 	tunA := NewTunnel( | 	tunA := NewTunnel( | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user