diff --git a/tunnel/link.go b/tunnel/link.go index 06871a13..036a4830 100644 --- a/tunnel/link.go +++ b/tunnel/link.go @@ -210,6 +210,7 @@ func (l *link) Close() error { case <-l.closed: return nil default: + l.Socket.Close() close(l.closed) } @@ -227,12 +228,19 @@ func (l *link) Send(m *transport.Message) error { // get time now now := time.Now() - // queue the message + // check if its closed first select { case <-l.closed: return io.EOF + default: + } + + // queue the message + select { case l.sendQueue <- p: // in the send queue + case <-l.closed: + return io.EOF } // error to use @@ -293,7 +301,17 @@ func (l *link) Send(m *transport.Message) error { func (l *link) Recv(m *transport.Message) error { select { case <-l.closed: - return io.EOF + // check if there's any messages left + select { + case pk := <-l.recvQueue: + // check the packet receive error + if pk.err != nil { + return pk.err + } + *m = *pk.message + default: + return io.EOF + } case pk := <-l.recvQueue: // check the packet receive error if pk.err != nil { diff --git a/tunnel/tunnel_test.go b/tunnel/tunnel_test.go index a06a1b01..e884e692 100644 --- a/tunnel/tunnel_test.go +++ b/tunnel/tunnel_test.go @@ -202,8 +202,8 @@ func testBrokenTunAccept(t *testing.T, tun Tunnel, wait chan bool, wg *sync.Wait t.Fatal(err) } - // notify sender we have received the message - <-wait + // notify the sender we have received + wait <- true } func testBrokenTunSend(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup) { @@ -234,7 +234,7 @@ func testBrokenTunSend(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGr <-wait // give it time to reconnect - time.Sleep(2 * ReconnectTime) + time.Sleep(5 * ReconnectTime) // send the message if err := c.Send(&m); err != nil { @@ -244,7 +244,7 @@ func testBrokenTunSend(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGr // wait for the listener to receive the message // c.Send merely enqueues the message to the link send queue and returns // in order to verify it was received we wait for the listener to tell us - wait <- true + <-wait } func TestReconnectTunnel(t *testing.T) {