diff --git a/network/tunnel/default.go b/network/tunnel/default.go index 89e04dd3..a810783e 100644 --- a/network/tunnel/default.go +++ b/network/tunnel/default.go @@ -132,9 +132,11 @@ func (t *tun) listen() { case "connect": // assuming new connection // TODO: do something with this + continue case "close": // assuming connection closed // TODO: do something with this + continue } // the tunnel id diff --git a/network/tunnel/listener.go b/network/tunnel/listener.go index dac315ce..6c803eb9 100644 --- a/network/tunnel/listener.go +++ b/network/tunnel/listener.go @@ -46,15 +46,25 @@ func (t *tunListener) process() { wait: make(chan bool), } + // first message + sock.recv <- m + // save the socket conns[m.session] = sock + + // send to accept chan + select { + case <-t.closed: + return + case t.accept <- sock: + } } // send this to the accept chan select { - case <-t.closed: - return - case t.accept <- sock: + case <-sock.closed: + delete(conns, m.session) + case sock.recv <- m: } } } diff --git a/network/tunnel/tunnel_test.go b/network/tunnel/tunnel_test.go new file mode 100644 index 00000000..d8b84739 --- /dev/null +++ b/network/tunnel/tunnel_test.go @@ -0,0 +1,111 @@ +package tunnel + +import ( + "testing" + + "github.com/micro/go-micro/network/link" + "github.com/micro/go-micro/transport" +) + +func testAccept(t *testing.T, l transport.Listener, wait chan bool) error { + // accept new connections on the transport + // establish a link and tunnel + return l.Accept(func(s transport.Socket) { + // convert the socket into a link + li := link.NewLink( + link.Socket(s), + ) + + // connect the link e.g start internal buffers + if err := li.Connect(); err != nil { + t.Fatal(err) + } + + // create a new tunnel + tun := NewTunnel(li) + + // connect the tunnel + if err := tun.Connect(); err != nil { + t.Fatal(err) + } + + // listen on some virtual address + tl, err := tun.Listen("test-tunnel") + if err != nil { + t.Fatal(err) + return + } + + // accept a connection + c, err := tl.Accept() + if err != nil { + t.Fatal(err) + } + + // get a message + for { + m := new(transport.Message) + if err := c.Recv(m); err != nil { + t.Fatal(err) + } + close(wait) + return + } + }) +} + +func testSend(t *testing.T, addr string) { + // create a new link + l := link.NewLink( + link.Address(addr), + ) + + // connect the link, this includes dialing + if err := l.Connect(); err != nil { + t.Fatal(err) + } + + // create a tunnel on the link + tun := NewTunnel(l) + + // connect the tunnel with the remote side + if err := tun.Connect(); err != nil { + t.Fatal(err) + } + + // dial a new session + c, err := tun.Dial("test-tunnel") + if err != nil { + t.Fatal(err) + } + + m := transport.Message{ + Header: map[string]string{ + "test": "header", + }, + } + if err := c.Send(&m); err != nil { + t.Fatal(err) + } +} + +func TestTunnel(t *testing.T) { + // create a new listener + tr := transport.NewTransport() + l, err := tr.Listen(":0") + if err != nil { + t.Fatal(err) + } + defer l.Close() + + wait := make(chan bool) + + // start accepting connections + go testAccept(t, l, wait) + + // send a message + testSend(t, l.Addr()) + + // wait until message is received + <-wait +}