| @@ -17,6 +17,9 @@ type tun struct { | |||||||
|  |  | ||||||
| 	sync.RWMutex | 	sync.RWMutex | ||||||
|  |  | ||||||
|  | 	// tunnel token | ||||||
|  | 	token string | ||||||
|  |  | ||||||
| 	// to indicate if we're connected or not | 	// to indicate if we're connected or not | ||||||
| 	connected bool | 	connected bool | ||||||
|  |  | ||||||
| @@ -50,6 +53,7 @@ func newTunnel(opts ...Option) *tun { | |||||||
|  |  | ||||||
| 	return &tun{ | 	return &tun{ | ||||||
| 		options: options, | 		options: options, | ||||||
|  | 		token:   uuid.New().String(), | ||||||
| 		send:    make(chan *message, 128), | 		send:    make(chan *message, 128), | ||||||
| 		closed:  make(chan bool), | 		closed:  make(chan bool), | ||||||
| 		sockets: make(map[string]*socket), | 		sockets: make(map[string]*socket), | ||||||
| @@ -57,6 +61,14 @@ func newTunnel(opts ...Option) *tun { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Init initializes tunnel options | ||||||
|  | func (t *tun) Init(opts ...Option) error { | ||||||
|  | 	for _, o := range opts { | ||||||
|  | 		o(&t.options) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
| // getSocket returns a socket from the internal socket map. | // getSocket returns a socket from the internal socket map. | ||||||
| // It does this based on the Micro-Tunnel-Id and Micro-Tunnel-Session | // It does this based on the Micro-Tunnel-Id and Micro-Tunnel-Session | ||||||
| func (t *tun) getSocket(id, session string) (*socket, bool) { | func (t *tun) getSocket(id, session string) (*socket, bool) { | ||||||
| @@ -92,6 +104,7 @@ func (t *tun) newSocket(id, session string) (*socket, bool) { | |||||||
| 		t.Unlock() | 		t.Unlock() | ||||||
| 		return nil, false | 		return nil, false | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	t.sockets[id+session] = s | 	t.sockets[id+session] = s | ||||||
| 	t.Unlock() | 	t.Unlock() | ||||||
|  |  | ||||||
| @@ -126,6 +139,9 @@ func (t *tun) process() { | |||||||
| 			// set the session id | 			// set the session id | ||||||
| 			newMsg.Header["Micro-Tunnel-Session"] = msg.session | 			newMsg.Header["Micro-Tunnel-Session"] = msg.session | ||||||
|  |  | ||||||
|  | 			// set the tunnel token | ||||||
|  | 			newMsg.Header["Micro-Tunnel-Token"] = t.token | ||||||
|  |  | ||||||
| 			// send the message via the interface | 			// send the message via the interface | ||||||
| 			t.RLock() | 			t.RLock() | ||||||
| 			if len(t.links) == 0 { | 			if len(t.links) == 0 { | ||||||
| @@ -144,7 +160,10 @@ func (t *tun) process() { | |||||||
| } | } | ||||||
|  |  | ||||||
| // process incoming messages | // process incoming messages | ||||||
| func (t *tun) listen(link transport.Socket, listener bool) { | func (t *tun) listen(link transport.Socket) { | ||||||
|  | 	// loopback flag | ||||||
|  | 	var loopback bool | ||||||
|  |  | ||||||
| 	for { | 	for { | ||||||
| 		// process anything via the net interface | 		// process anything via the net interface | ||||||
| 		msg := new(transport.Message) | 		msg := new(transport.Message) | ||||||
| @@ -155,10 +174,21 @@ func (t *tun) listen(link transport.Socket, listener bool) { | |||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		switch msg.Header["Micro-Tunnel"] { | 		switch msg.Header["Micro-Tunnel"] { | ||||||
| 		case "connect", "close": | 		case "connect": | ||||||
| 			// TODO: handle the connect/close message | 			// check the Micro-Tunnel-Token | ||||||
| 			// maybe used to create the dial/listen sockets | 			token, ok := msg.Header["Micro-Tunnel-Token"] | ||||||
| 			// or report io.EOF or maybe to kill the link | 			if !ok { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// are we connecting to ourselves? | ||||||
|  | 			if token == t.token { | ||||||
|  | 				loopback = true | ||||||
|  | 			} | ||||||
|  | 			continue | ||||||
|  | 		case "close": | ||||||
|  | 			// TODO: handle the close message | ||||||
|  | 			// maybe report io.EOF or kill the link | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| @@ -182,18 +212,23 @@ func (t *tun) listen(link transport.Socket, listener bool) { | |||||||
| 		var exists bool | 		var exists bool | ||||||
|  |  | ||||||
| 		log.Debugf("Received %+v from %s", msg, link.Remote()) | 		log.Debugf("Received %+v from %s", msg, link.Remote()) | ||||||
| 		// get the socket based on the tunnel id and session |  | ||||||
| 		// this could be something we dialed in which case |  | ||||||
| 		// we have a session for it otherwise its a listener |  | ||||||
| 		s, exists = t.getSocket(id, session) |  | ||||||
| 		if !exists { |  | ||||||
| 			// try get it based on just the tunnel id |  | ||||||
| 			// the assumption here is that a listener |  | ||||||
| 			// has no session but its set a listener session |  | ||||||
| 			s, exists = t.getSocket(id, "listener") |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// no socket in existence | 		switch { | ||||||
|  | 		case loopback: | ||||||
|  | 			s, exists = t.getSocket(id, "listener") | ||||||
|  | 		default: | ||||||
|  | 			// get the socket based on the tunnel id and session | ||||||
|  | 			// this could be something we dialed in which case | ||||||
|  | 			// we have a session for it otherwise its a listener | ||||||
|  | 			s, exists = t.getSocket(id, session) | ||||||
|  | 			if !exists { | ||||||
|  | 				// try get it based on just the tunnel id | ||||||
|  | 				// the assumption here is that a listener | ||||||
|  | 				// has no session but its set a listener session | ||||||
|  | 				s, exists = t.getSocket(id, "listener") | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		// bail if no socket has been found | ||||||
| 		if !exists { | 		if !exists { | ||||||
| 			log.Debugf("Tunnel skipping no socket exists") | 			log.Debugf("Tunnel skipping no socket exists") | ||||||
| 			// drop it, we don't care about | 			// drop it, we don't care about | ||||||
| @@ -246,6 +281,7 @@ func (t *tun) listen(link transport.Socket, listener bool) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // connect the tunnel to all the nodes and listen for incoming tunnel connections | ||||||
| func (t *tun) connect() error { | func (t *tun) connect() error { | ||||||
| 	l, err := t.options.Transport.Listen(t.options.Address) | 	l, err := t.options.Transport.Listen(t.options.Address) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -277,7 +313,7 @@ func (t *tun) connect() error { | |||||||
| 			}() | 			}() | ||||||
|  |  | ||||||
| 			// listen for inbound messages | 			// listen for inbound messages | ||||||
| 			t.listen(sock, true) | 			t.listen(sock) | ||||||
| 		}) | 		}) | ||||||
|  |  | ||||||
| 		t.Lock() | 		t.Lock() | ||||||
| @@ -306,14 +342,15 @@ func (t *tun) connect() error { | |||||||
|  |  | ||||||
| 		if err := c.Send(&transport.Message{ | 		if err := c.Send(&transport.Message{ | ||||||
| 			Header: map[string]string{ | 			Header: map[string]string{ | ||||||
| 				"Micro-Tunnel": "connect", | 				"Micro-Tunnel":       "connect", | ||||||
|  | 				"Micro-Tunnel-Token": t.token, | ||||||
| 			}, | 			}, | ||||||
| 		}); err != nil { | 		}); err != nil { | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// process incoming messages | 		// process incoming messages | ||||||
| 		go t.listen(c, false) | 		go t.listen(c) | ||||||
|  |  | ||||||
| 		// save the link | 		// save the link | ||||||
| 		id := uuid.New().String() | 		id := uuid.New().String() | ||||||
| @@ -330,12 +367,36 @@ func (t *tun) connect() error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Connect the tunnel | ||||||
|  | func (t *tun) Connect() error { | ||||||
|  | 	t.Lock() | ||||||
|  | 	defer t.Unlock() | ||||||
|  |  | ||||||
|  | 	// already connected | ||||||
|  | 	if t.connected { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// send the connect message | ||||||
|  | 	if err := t.connect(); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// set as connected | ||||||
|  | 	t.connected = true | ||||||
|  | 	// create new close channel | ||||||
|  | 	t.closed = make(chan bool) | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
| func (t *tun) close() error { | func (t *tun) close() error { | ||||||
| 	// close all the links | 	// close all the links | ||||||
| 	for id, link := range t.links { | 	for id, link := range t.links { | ||||||
| 		link.Send(&transport.Message{ | 		link.Send(&transport.Message{ | ||||||
| 			Header: map[string]string{ | 			Header: map[string]string{ | ||||||
| 				"Micro-Tunnel": "close", | 				"Micro-Tunnel":       "close", | ||||||
|  | 				"Micro-Tunnel-Token": t.token, | ||||||
| 			}, | 			}, | ||||||
| 		}) | 		}) | ||||||
| 		link.Close() | 		link.Close() | ||||||
| @@ -376,36 +437,6 @@ func (t *tun) Close() error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // Connect the tunnel |  | ||||||
| func (t *tun) Connect() error { |  | ||||||
| 	t.Lock() |  | ||||||
| 	defer t.Unlock() |  | ||||||
|  |  | ||||||
| 	// already connected |  | ||||||
| 	if t.connected { |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// send the connect message |  | ||||||
| 	if err := t.connect(); err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// set as connected |  | ||||||
| 	t.connected = true |  | ||||||
| 	// create new close channel |  | ||||||
| 	t.closed = make(chan bool) |  | ||||||
|  |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (t *tun) Init(opts ...Option) error { |  | ||||||
| 	for _, o := range opts { |  | ||||||
| 		o(&t.options) |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Dial an address | // Dial an address | ||||||
| func (t *tun) Dial(addr string) (Conn, error) { | func (t *tun) Dial(addr string) (Conn, error) { | ||||||
| 	log.Debugf("Tunnel dialing %s", addr) | 	log.Debugf("Tunnel dialing %s", addr) | ||||||
| @@ -413,7 +444,6 @@ func (t *tun) Dial(addr string) (Conn, error) { | |||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return nil, errors.New("error dialing " + addr) | 		return nil, errors.New("error dialing " + addr) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// set remote | 	// set remote | ||||||
| 	c.remote = addr | 	c.remote = addr | ||||||
| 	// set local | 	// set local | ||||||
|   | |||||||
| @@ -24,10 +24,22 @@ func testAccept(t *testing.T, tun Tunnel, wg *sync.WaitGroup) { | |||||||
|  |  | ||||||
| 	// get a message | 	// get a message | ||||||
| 	for { | 	for { | ||||||
|  | 		// accept the message | ||||||
| 		m := new(transport.Message) | 		m := new(transport.Message) | ||||||
| 		if err := c.Recv(m); err != nil { | 		if err := c.Recv(m); err != nil { | ||||||
| 			t.Fatal(err) | 			t.Fatal(err) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		if v := m.Header["test"]; v != "send" { | ||||||
|  | 			t.Fatalf("Accept side expected test:send header. Received: %s", v) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// now respond | ||||||
|  | 		m.Header["test"] = "accept" | ||||||
|  | 		if err := c.Send(m); err != nil { | ||||||
|  | 			t.Fatal(err) | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		wg.Done() | 		wg.Done() | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| @@ -44,13 +56,24 @@ func testSend(t *testing.T, tun Tunnel) { | |||||||
|  |  | ||||||
| 	m := transport.Message{ | 	m := transport.Message{ | ||||||
| 		Header: map[string]string{ | 		Header: map[string]string{ | ||||||
| 			"test": "header", | 			"test": "send", | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// send the message | ||||||
| 	if err := c.Send(&m); err != nil { | 	if err := c.Send(&m); err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// now wait for the response | ||||||
|  | 	mr := new(transport.Message) | ||||||
|  | 	if err := c.Recv(mr); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if v := mr.Header["test"]; v != "accept" { | ||||||
|  | 		t.Fatalf("Message not received from accepted side. Received: %s", v) | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestTunnel(t *testing.T) { | func TestTunnel(t *testing.T) { | ||||||
| @@ -98,3 +121,35 @@ func TestTunnel(t *testing.T) { | |||||||
| 	// wait until done | 	// wait until done | ||||||
| 	wg.Wait() | 	wg.Wait() | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestLoopbackTunnel(t *testing.T) { | ||||||
|  | 	// create a new tunnel client | ||||||
|  | 	tun := NewTunnel( | ||||||
|  | 		Address("127.0.0.1:9096"), | ||||||
|  | 		Nodes("127.0.0.1:9096"), | ||||||
|  | 	) | ||||||
|  |  | ||||||
|  | 	// start tunB | ||||||
|  | 	err := tun.Connect() | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer tun.Close() | ||||||
|  |  | ||||||
|  | 	time.Sleep(time.Millisecond * 50) | ||||||
|  |  | ||||||
|  | 	var wg sync.WaitGroup | ||||||
|  |  | ||||||
|  | 	// start accepting connections | ||||||
|  | 	// on tunnel A | ||||||
|  | 	wg.Add(1) | ||||||
|  | 	go testAccept(t, tun, &wg) | ||||||
|  |  | ||||||
|  | 	time.Sleep(time.Millisecond * 50) | ||||||
|  |  | ||||||
|  | 	// dial and send via B | ||||||
|  | 	testSend(t, tun) | ||||||
|  |  | ||||||
|  | 	// wait until done | ||||||
|  | 	wg.Wait() | ||||||
|  | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user