commit
f19308f1e6
@ -17,6 +17,9 @@ type tun struct {
|
|||||||
|
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
|
|
||||||
|
// tunnel token
|
||||||
|
token string
|
||||||
|
|
||||||
// to indicate if we're connected or not
|
// to indicate if we're connected or not
|
||||||
connected bool
|
connected bool
|
||||||
|
|
||||||
@ -50,6 +53,7 @@ func newTunnel(opts ...Option) *tun {
|
|||||||
|
|
||||||
return &tun{
|
return &tun{
|
||||||
options: options,
|
options: options,
|
||||||
|
token: uuid.New().String(),
|
||||||
send: make(chan *message, 128),
|
send: make(chan *message, 128),
|
||||||
closed: make(chan bool),
|
closed: make(chan bool),
|
||||||
sockets: make(map[string]*socket),
|
sockets: make(map[string]*socket),
|
||||||
@ -57,6 +61,14 @@ func newTunnel(opts ...Option) *tun {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Init initializes tunnel options
|
||||||
|
func (t *tun) Init(opts ...Option) error {
|
||||||
|
for _, o := range opts {
|
||||||
|
o(&t.options)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// getSocket returns a socket from the internal socket map.
|
// getSocket returns a socket from the internal socket map.
|
||||||
// It does this based on the Micro-Tunnel-Id and Micro-Tunnel-Session
|
// It does this based on the Micro-Tunnel-Id and Micro-Tunnel-Session
|
||||||
func (t *tun) getSocket(id, session string) (*socket, bool) {
|
func (t *tun) getSocket(id, session string) (*socket, bool) {
|
||||||
@ -92,6 +104,7 @@ func (t *tun) newSocket(id, session string) (*socket, bool) {
|
|||||||
t.Unlock()
|
t.Unlock()
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
t.sockets[id+session] = s
|
t.sockets[id+session] = s
|
||||||
t.Unlock()
|
t.Unlock()
|
||||||
|
|
||||||
@ -126,6 +139,9 @@ func (t *tun) process() {
|
|||||||
// set the session id
|
// set the session id
|
||||||
newMsg.Header["Micro-Tunnel-Session"] = msg.session
|
newMsg.Header["Micro-Tunnel-Session"] = msg.session
|
||||||
|
|
||||||
|
// set the tunnel token
|
||||||
|
newMsg.Header["Micro-Tunnel-Token"] = t.token
|
||||||
|
|
||||||
// send the message via the interface
|
// send the message via the interface
|
||||||
t.RLock()
|
t.RLock()
|
||||||
if len(t.links) == 0 {
|
if len(t.links) == 0 {
|
||||||
@ -144,7 +160,10 @@ func (t *tun) process() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// process incoming messages
|
// process incoming messages
|
||||||
func (t *tun) listen(link transport.Socket, listener bool) {
|
func (t *tun) listen(link transport.Socket) {
|
||||||
|
// loopback flag
|
||||||
|
var loopback bool
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// process anything via the net interface
|
// process anything via the net interface
|
||||||
msg := new(transport.Message)
|
msg := new(transport.Message)
|
||||||
@ -155,10 +174,21 @@ func (t *tun) listen(link transport.Socket, listener bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch msg.Header["Micro-Tunnel"] {
|
switch msg.Header["Micro-Tunnel"] {
|
||||||
case "connect", "close":
|
case "connect":
|
||||||
// TODO: handle the connect/close message
|
// check the Micro-Tunnel-Token
|
||||||
// maybe used to create the dial/listen sockets
|
token, ok := msg.Header["Micro-Tunnel-Token"]
|
||||||
// or report io.EOF or maybe to kill the link
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// are we connecting to ourselves?
|
||||||
|
if token == t.token {
|
||||||
|
loopback = true
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
case "close":
|
||||||
|
// TODO: handle the close message
|
||||||
|
// maybe report io.EOF or kill the link
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -182,18 +212,23 @@ func (t *tun) listen(link transport.Socket, listener bool) {
|
|||||||
var exists bool
|
var exists bool
|
||||||
|
|
||||||
log.Debugf("Received %+v from %s", msg, link.Remote())
|
log.Debugf("Received %+v from %s", msg, link.Remote())
|
||||||
// get the socket based on the tunnel id and session
|
|
||||||
// this could be something we dialed in which case
|
|
||||||
// we have a session for it otherwise its a listener
|
|
||||||
s, exists = t.getSocket(id, session)
|
|
||||||
if !exists {
|
|
||||||
// try get it based on just the tunnel id
|
|
||||||
// the assumption here is that a listener
|
|
||||||
// has no session but its set a listener session
|
|
||||||
s, exists = t.getSocket(id, "listener")
|
|
||||||
}
|
|
||||||
|
|
||||||
// no socket in existence
|
switch {
|
||||||
|
case loopback:
|
||||||
|
s, exists = t.getSocket(id, "listener")
|
||||||
|
default:
|
||||||
|
// get the socket based on the tunnel id and session
|
||||||
|
// this could be something we dialed in which case
|
||||||
|
// we have a session for it otherwise its a listener
|
||||||
|
s, exists = t.getSocket(id, session)
|
||||||
|
if !exists {
|
||||||
|
// try get it based on just the tunnel id
|
||||||
|
// the assumption here is that a listener
|
||||||
|
// has no session but its set a listener session
|
||||||
|
s, exists = t.getSocket(id, "listener")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// bail if no socket has been found
|
||||||
if !exists {
|
if !exists {
|
||||||
log.Debugf("Tunnel skipping no socket exists")
|
log.Debugf("Tunnel skipping no socket exists")
|
||||||
// drop it, we don't care about
|
// drop it, we don't care about
|
||||||
@ -246,6 +281,7 @@ func (t *tun) listen(link transport.Socket, listener bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// connect the tunnel to all the nodes and listen for incoming tunnel connections
|
||||||
func (t *tun) connect() error {
|
func (t *tun) connect() error {
|
||||||
l, err := t.options.Transport.Listen(t.options.Address)
|
l, err := t.options.Transport.Listen(t.options.Address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -277,7 +313,7 @@ func (t *tun) connect() error {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// listen for inbound messages
|
// listen for inbound messages
|
||||||
t.listen(sock, true)
|
t.listen(sock)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Lock()
|
t.Lock()
|
||||||
@ -306,14 +342,15 @@ func (t *tun) connect() error {
|
|||||||
|
|
||||||
if err := c.Send(&transport.Message{
|
if err := c.Send(&transport.Message{
|
||||||
Header: map[string]string{
|
Header: map[string]string{
|
||||||
"Micro-Tunnel": "connect",
|
"Micro-Tunnel": "connect",
|
||||||
|
"Micro-Tunnel-Token": t.token,
|
||||||
},
|
},
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// process incoming messages
|
// process incoming messages
|
||||||
go t.listen(c, false)
|
go t.listen(c)
|
||||||
|
|
||||||
// save the link
|
// save the link
|
||||||
id := uuid.New().String()
|
id := uuid.New().String()
|
||||||
@ -330,12 +367,36 @@ func (t *tun) connect() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Connect the tunnel
|
||||||
|
func (t *tun) Connect() error {
|
||||||
|
t.Lock()
|
||||||
|
defer t.Unlock()
|
||||||
|
|
||||||
|
// already connected
|
||||||
|
if t.connected {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// send the connect message
|
||||||
|
if err := t.connect(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// set as connected
|
||||||
|
t.connected = true
|
||||||
|
// create new close channel
|
||||||
|
t.closed = make(chan bool)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) close() error {
|
func (t *tun) close() error {
|
||||||
// close all the links
|
// close all the links
|
||||||
for id, link := range t.links {
|
for id, link := range t.links {
|
||||||
link.Send(&transport.Message{
|
link.Send(&transport.Message{
|
||||||
Header: map[string]string{
|
Header: map[string]string{
|
||||||
"Micro-Tunnel": "close",
|
"Micro-Tunnel": "close",
|
||||||
|
"Micro-Tunnel-Token": t.token,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
link.Close()
|
link.Close()
|
||||||
@ -376,36 +437,6 @@ func (t *tun) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect the tunnel
|
|
||||||
func (t *tun) Connect() error {
|
|
||||||
t.Lock()
|
|
||||||
defer t.Unlock()
|
|
||||||
|
|
||||||
// already connected
|
|
||||||
if t.connected {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// send the connect message
|
|
||||||
if err := t.connect(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// set as connected
|
|
||||||
t.connected = true
|
|
||||||
// create new close channel
|
|
||||||
t.closed = make(chan bool)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) Init(opts ...Option) error {
|
|
||||||
for _, o := range opts {
|
|
||||||
o(&t.options)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dial an address
|
// Dial an address
|
||||||
func (t *tun) Dial(addr string) (Conn, error) {
|
func (t *tun) Dial(addr string) (Conn, error) {
|
||||||
log.Debugf("Tunnel dialing %s", addr)
|
log.Debugf("Tunnel dialing %s", addr)
|
||||||
@ -413,7 +444,6 @@ func (t *tun) Dial(addr string) (Conn, error) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("error dialing " + addr)
|
return nil, errors.New("error dialing " + addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// set remote
|
// set remote
|
||||||
c.remote = addr
|
c.remote = addr
|
||||||
// set local
|
// set local
|
||||||
|
@ -24,10 +24,22 @@ func testAccept(t *testing.T, tun Tunnel, wg *sync.WaitGroup) {
|
|||||||
|
|
||||||
// get a message
|
// get a message
|
||||||
for {
|
for {
|
||||||
|
// accept the message
|
||||||
m := new(transport.Message)
|
m := new(transport.Message)
|
||||||
if err := c.Recv(m); err != nil {
|
if err := c.Recv(m); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if v := m.Header["test"]; v != "send" {
|
||||||
|
t.Fatalf("Accept side expected test:send header. Received: %s", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// now respond
|
||||||
|
m.Header["test"] = "accept"
|
||||||
|
if err := c.Send(m); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
wg.Done()
|
wg.Done()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -44,13 +56,24 @@ func testSend(t *testing.T, tun Tunnel) {
|
|||||||
|
|
||||||
m := transport.Message{
|
m := transport.Message{
|
||||||
Header: map[string]string{
|
Header: map[string]string{
|
||||||
"test": "header",
|
"test": "send",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// send the message
|
||||||
if err := c.Send(&m); err != nil {
|
if err := c.Send(&m); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// now wait for the response
|
||||||
|
mr := new(transport.Message)
|
||||||
|
if err := c.Recv(mr); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := mr.Header["test"]; v != "accept" {
|
||||||
|
t.Fatalf("Message not received from accepted side. Received: %s", v)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTunnel(t *testing.T) {
|
func TestTunnel(t *testing.T) {
|
||||||
@ -98,3 +121,35 @@ func TestTunnel(t *testing.T) {
|
|||||||
// wait until done
|
// wait until done
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLoopbackTunnel(t *testing.T) {
|
||||||
|
// create a new tunnel client
|
||||||
|
tun := NewTunnel(
|
||||||
|
Address("127.0.0.1:9096"),
|
||||||
|
Nodes("127.0.0.1:9096"),
|
||||||
|
)
|
||||||
|
|
||||||
|
// start tunB
|
||||||
|
err := tun.Connect()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer tun.Close()
|
||||||
|
|
||||||
|
time.Sleep(time.Millisecond * 50)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
// start accepting connections
|
||||||
|
// on tunnel A
|
||||||
|
wg.Add(1)
|
||||||
|
go testAccept(t, tun, &wg)
|
||||||
|
|
||||||
|
time.Sleep(time.Millisecond * 50)
|
||||||
|
|
||||||
|
// dial and send via B
|
||||||
|
testSend(t, tun)
|
||||||
|
|
||||||
|
// wait until done
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user