commit
f19308f1e6
@ -17,6 +17,9 @@ type tun struct {
|
||||
|
||||
sync.RWMutex
|
||||
|
||||
// tunnel token
|
||||
token string
|
||||
|
||||
// to indicate if we're connected or not
|
||||
connected bool
|
||||
|
||||
@ -50,6 +53,7 @@ func newTunnel(opts ...Option) *tun {
|
||||
|
||||
return &tun{
|
||||
options: options,
|
||||
token: uuid.New().String(),
|
||||
send: make(chan *message, 128),
|
||||
closed: make(chan bool),
|
||||
sockets: make(map[string]*socket),
|
||||
@ -57,6 +61,14 @@ func newTunnel(opts ...Option) *tun {
|
||||
}
|
||||
}
|
||||
|
||||
// Init initializes tunnel options
|
||||
func (t *tun) Init(opts ...Option) error {
|
||||
for _, o := range opts {
|
||||
o(&t.options)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getSocket returns a socket from the internal socket map.
|
||||
// It does this based on the Micro-Tunnel-Id and Micro-Tunnel-Session
|
||||
func (t *tun) getSocket(id, session string) (*socket, bool) {
|
||||
@ -92,6 +104,7 @@ func (t *tun) newSocket(id, session string) (*socket, bool) {
|
||||
t.Unlock()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
t.sockets[id+session] = s
|
||||
t.Unlock()
|
||||
|
||||
@ -126,6 +139,9 @@ func (t *tun) process() {
|
||||
// set the session id
|
||||
newMsg.Header["Micro-Tunnel-Session"] = msg.session
|
||||
|
||||
// set the tunnel token
|
||||
newMsg.Header["Micro-Tunnel-Token"] = t.token
|
||||
|
||||
// send the message via the interface
|
||||
t.RLock()
|
||||
if len(t.links) == 0 {
|
||||
@ -144,7 +160,10 @@ func (t *tun) process() {
|
||||
}
|
||||
|
||||
// process incoming messages
|
||||
func (t *tun) listen(link transport.Socket, listener bool) {
|
||||
func (t *tun) listen(link transport.Socket) {
|
||||
// loopback flag
|
||||
var loopback bool
|
||||
|
||||
for {
|
||||
// process anything via the net interface
|
||||
msg := new(transport.Message)
|
||||
@ -155,10 +174,21 @@ func (t *tun) listen(link transport.Socket, listener bool) {
|
||||
}
|
||||
|
||||
switch msg.Header["Micro-Tunnel"] {
|
||||
case "connect", "close":
|
||||
// TODO: handle the connect/close message
|
||||
// maybe used to create the dial/listen sockets
|
||||
// or report io.EOF or maybe to kill the link
|
||||
case "connect":
|
||||
// check the Micro-Tunnel-Token
|
||||
token, ok := msg.Header["Micro-Tunnel-Token"]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// are we connecting to ourselves?
|
||||
if token == t.token {
|
||||
loopback = true
|
||||
}
|
||||
continue
|
||||
case "close":
|
||||
// TODO: handle the close message
|
||||
// maybe report io.EOF or kill the link
|
||||
continue
|
||||
}
|
||||
|
||||
@ -182,18 +212,23 @@ func (t *tun) listen(link transport.Socket, listener bool) {
|
||||
var exists bool
|
||||
|
||||
log.Debugf("Received %+v from %s", msg, link.Remote())
|
||||
// get the socket based on the tunnel id and session
|
||||
// this could be something we dialed in which case
|
||||
// we have a session for it otherwise its a listener
|
||||
s, exists = t.getSocket(id, session)
|
||||
if !exists {
|
||||
// try get it based on just the tunnel id
|
||||
// the assumption here is that a listener
|
||||
// has no session but its set a listener session
|
||||
s, exists = t.getSocket(id, "listener")
|
||||
}
|
||||
|
||||
// no socket in existence
|
||||
switch {
|
||||
case loopback:
|
||||
s, exists = t.getSocket(id, "listener")
|
||||
default:
|
||||
// get the socket based on the tunnel id and session
|
||||
// this could be something we dialed in which case
|
||||
// we have a session for it otherwise its a listener
|
||||
s, exists = t.getSocket(id, session)
|
||||
if !exists {
|
||||
// try get it based on just the tunnel id
|
||||
// the assumption here is that a listener
|
||||
// has no session but its set a listener session
|
||||
s, exists = t.getSocket(id, "listener")
|
||||
}
|
||||
}
|
||||
// bail if no socket has been found
|
||||
if !exists {
|
||||
log.Debugf("Tunnel skipping no socket exists")
|
||||
// drop it, we don't care about
|
||||
@ -246,6 +281,7 @@ func (t *tun) listen(link transport.Socket, listener bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// connect the tunnel to all the nodes and listen for incoming tunnel connections
|
||||
func (t *tun) connect() error {
|
||||
l, err := t.options.Transport.Listen(t.options.Address)
|
||||
if err != nil {
|
||||
@ -277,7 +313,7 @@ func (t *tun) connect() error {
|
||||
}()
|
||||
|
||||
// listen for inbound messages
|
||||
t.listen(sock, true)
|
||||
t.listen(sock)
|
||||
})
|
||||
|
||||
t.Lock()
|
||||
@ -306,14 +342,15 @@ func (t *tun) connect() error {
|
||||
|
||||
if err := c.Send(&transport.Message{
|
||||
Header: map[string]string{
|
||||
"Micro-Tunnel": "connect",
|
||||
"Micro-Tunnel": "connect",
|
||||
"Micro-Tunnel-Token": t.token,
|
||||
},
|
||||
}); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// process incoming messages
|
||||
go t.listen(c, false)
|
||||
go t.listen(c)
|
||||
|
||||
// save the link
|
||||
id := uuid.New().String()
|
||||
@ -330,12 +367,36 @@ func (t *tun) connect() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Connect the tunnel
|
||||
func (t *tun) Connect() error {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
|
||||
// already connected
|
||||
if t.connected {
|
||||
return nil
|
||||
}
|
||||
|
||||
// send the connect message
|
||||
if err := t.connect(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// set as connected
|
||||
t.connected = true
|
||||
// create new close channel
|
||||
t.closed = make(chan bool)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) close() error {
|
||||
// close all the links
|
||||
for id, link := range t.links {
|
||||
link.Send(&transport.Message{
|
||||
Header: map[string]string{
|
||||
"Micro-Tunnel": "close",
|
||||
"Micro-Tunnel": "close",
|
||||
"Micro-Tunnel-Token": t.token,
|
||||
},
|
||||
})
|
||||
link.Close()
|
||||
@ -376,36 +437,6 @@ func (t *tun) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Connect the tunnel
|
||||
func (t *tun) Connect() error {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
|
||||
// already connected
|
||||
if t.connected {
|
||||
return nil
|
||||
}
|
||||
|
||||
// send the connect message
|
||||
if err := t.connect(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// set as connected
|
||||
t.connected = true
|
||||
// create new close channel
|
||||
t.closed = make(chan bool)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Init(opts ...Option) error {
|
||||
for _, o := range opts {
|
||||
o(&t.options)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Dial an address
|
||||
func (t *tun) Dial(addr string) (Conn, error) {
|
||||
log.Debugf("Tunnel dialing %s", addr)
|
||||
@ -413,7 +444,6 @@ func (t *tun) Dial(addr string) (Conn, error) {
|
||||
if !ok {
|
||||
return nil, errors.New("error dialing " + addr)
|
||||
}
|
||||
|
||||
// set remote
|
||||
c.remote = addr
|
||||
// set local
|
||||
|
@ -24,10 +24,22 @@ func testAccept(t *testing.T, tun Tunnel, wg *sync.WaitGroup) {
|
||||
|
||||
// get a message
|
||||
for {
|
||||
// accept the message
|
||||
m := new(transport.Message)
|
||||
if err := c.Recv(m); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if v := m.Header["test"]; v != "send" {
|
||||
t.Fatalf("Accept side expected test:send header. Received: %s", v)
|
||||
}
|
||||
|
||||
// now respond
|
||||
m.Header["test"] = "accept"
|
||||
if err := c.Send(m); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
return
|
||||
}
|
||||
@ -44,13 +56,24 @@ func testSend(t *testing.T, tun Tunnel) {
|
||||
|
||||
m := transport.Message{
|
||||
Header: map[string]string{
|
||||
"test": "header",
|
||||
"test": "send",
|
||||
},
|
||||
}
|
||||
|
||||
// send the message
|
||||
if err := c.Send(&m); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// now wait for the response
|
||||
mr := new(transport.Message)
|
||||
if err := c.Recv(mr); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if v := mr.Header["test"]; v != "accept" {
|
||||
t.Fatalf("Message not received from accepted side. Received: %s", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTunnel(t *testing.T) {
|
||||
@ -98,3 +121,35 @@ func TestTunnel(t *testing.T) {
|
||||
// wait until done
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestLoopbackTunnel(t *testing.T) {
|
||||
// create a new tunnel client
|
||||
tun := NewTunnel(
|
||||
Address("127.0.0.1:9096"),
|
||||
Nodes("127.0.0.1:9096"),
|
||||
)
|
||||
|
||||
// start tunB
|
||||
err := tun.Connect()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tun.Close()
|
||||
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// start accepting connections
|
||||
// on tunnel A
|
||||
wg.Add(1)
|
||||
go testAccept(t, tun, &wg)
|
||||
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
|
||||
// dial and send via B
|
||||
testSend(t, tun)
|
||||
|
||||
// wait until done
|
||||
wg.Wait()
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user