diff --git a/network/tunnel/default.go b/network/tunnel/default.go new file mode 100644 index 00000000..41c9778b --- /dev/null +++ b/network/tunnel/default.go @@ -0,0 +1,229 @@ +package tunnel + +import ( + "crypto/sha256" + "errors" + "fmt" + "sync" + + "github.com/google/uuid" + "github.com/micro/go-micro/transport" +) + +// tun represents a network tunnel +type tun struct { + // interface to use + net Interface + + // connect + mtx sync.RWMutex + connected bool + + // the send channel + send chan *message + // close channel + closed chan bool + + // sockets + sockets map[string]*socket +} + +// create new tunnel +func newTunnel(net Interface) *tun { + return &tun{ + net: net, + send: make(chan *message, 128), + closed: make(chan bool), + sockets: make(map[string]*socket), + } +} + +func (t *tun) getSocket(id string) (*socket, bool) { + // get the socket + t.mtx.RLock() + s, ok := t.sockets[id] + t.mtx.RUnlock() + return s, ok +} + +func (t *tun) newSocket(id string) *socket { + // new id if it doesn't exist + if len(id) == 0 { + id = uuid.New().String() + } + + // hash the id + h := sha256.New() + h.Write([]byte(id)) + id = fmt.Sprintf("%x", h.Sum(nil)) + + // new socket + s := &socket{ + id: id, + closed: make(chan bool), + recv: make(chan *message, 128), + send: t.send, + } + + // save socket + t.mtx.Lock() + t.sockets[id] = s + t.mtx.Unlock() + + // return socket + return s +} + +// process outgoing messages +func (t *tun) process() { + // manage the send buffer + // all pseudo sockets throw everything down this + for { + select { + case msg := <-t.send: + nmsg := &Message{ + Header: msg.data.Header, + Body: msg.data.Body, + } + + // set the stream id on the outgoing message + nmsg.Header["Micro-Stream"] = msg.id + + // send the message via the interface + if err := t.net.Send(nmsg); err != nil { + // no op + // TODO: do something + } + case <-t.closed: + return + } + } +} + +// process incoming messages +func (t *tun) listen() { + for { + // process anything via the net interface + msg, err := t.net.Recv() + if err != nil { + return + } + + // a stream id + id := msg.Header["Micro-Stream"] + + // get the socket + s, exists := t.getSocket(id) + if !exists { + // no op + continue + } + + // is the socket closed? + select { + case <-s.closed: + // closed + delete(t.sockets, id) + continue + default: + // process + } + + // is the socket new? + select { + // if its new it will block here + case <-s.wait: + // its not new + default: + // its new + // set remote address of the socket + s.remote = msg.Header["Remote"] + close(s.wait) + } + + tmsg := &transport.Message{ + Header: msg.Header, + Body: msg.Body, + } + + // TODO: don't block on queuing + // append to recv backlog + s.recv <- &message{id: id, data: tmsg} + } +} + +// Close the tunnel +func (t *tun) Close() error { + t.mtx.Lock() + defer t.mtx.Unlock() + + if !t.connected { + return nil + } + + select { + case <-t.closed: + return nil + default: + // close all the sockets + for _, s := range t.sockets { + s.Close() + } + // close the connection + close(t.closed) + t.connected = false + } + + return nil +} + +// Connect the tunnel +func (t *tun) Connect() error { + t.mtx.Lock() + defer t.mtx.Unlock() + + // already connected + if t.connected { + return nil + } + + // set as connected + t.connected = true + // create new close channel + t.closed = make(chan bool) + + // process messages to be sent + go t.process() + // process incoming messages + go t.listen() + + return nil +} + +// Dial an address +func (t *tun) Dial(addr string) (Conn, error) { + c := t.newSocket(addr) + // set remote + c.remote = addr + // set local + c.local = t.net.Addr() + return c, nil +} + +func (t *tun) Accept(addr string) (Conn, error) { + c := t.newSocket(addr) + // set remote + c.remote = t.net.Addr() + // set local + c.local = addr + + select { + case <-c.closed: + return nil, errors.New("error creating socket") + // wait for the first message + case <-c.wait: + } + + // return socket + return c, nil +} diff --git a/network/tunnel/socket.go b/network/tunnel/socket.go new file mode 100644 index 00000000..e0ce1350 --- /dev/null +++ b/network/tunnel/socket.go @@ -0,0 +1,82 @@ +package tunnel + +import ( + "errors" + + "github.com/micro/go-micro/transport" +) + +// socket is our pseudo socket for transport.Socket +type socket struct { + // socket id based on Micro-Stream + id string + // closed + closed chan bool + // remote addr + remote string + // local addr + local string + // send chan + send chan *message + // recv chan + recv chan *message + // wait until we have a connection + wait chan bool +} + +// message is sent over the send channel +type message struct { + // socket id + id string + // transport data + data *transport.Message +} + +func (s *socket) Remote() string { + return s.remote +} + +func (s *socket) Local() string { + return s.local +} + +func (s *socket) Id() string { + return s.id +} + +func (s *socket) Send(m *transport.Message) error { + select { + case <-s.closed: + return errors.New("socket is closed") + default: + // no op + } + // append to backlog + s.send <- &message{id: s.id, data: m} + return nil +} + +func (s *socket) Recv(m *transport.Message) error { + select { + case <-s.closed: + return errors.New("socket is closed") + default: + // no op + } + // recv from backlog + msg := <-s.recv + // set message + *m = *msg.data + // return nil + return nil +} + +func (s *socket) Close() error { + select { + case <-s.closed: + // no op + default: + close(s.closed) + } + return nil +} diff --git a/network/tunnel/socket_test.go b/network/tunnel/socket_test.go new file mode 100644 index 00000000..37c1aeaa --- /dev/null +++ b/network/tunnel/socket_test.go @@ -0,0 +1,62 @@ +package tunnel + +import ( + "testing" + + "github.com/micro/go-micro/transport" +) + +func TestTunnelSocket(t *testing.T) { + s := &socket{ + id: "1", + closed: make(chan bool), + remote: "remote", + local: "local", + send: make(chan *message, 1), + recv: make(chan *message, 1), + wait: make(chan bool), + } + + // check addresses local and remote + if s.Local() != s.local { + t.Fatalf("Expected s.Local %s got %s", s.local, s.Local()) + } + if s.Remote() != s.remote { + t.Fatalf("Expected s.Remote %s got %s", s.remote, s.Remote()) + } + + // send a message + s.Send(&transport.Message{Header: map[string]string{}}) + + // get sent message + msg := <-s.send + + if msg.id != s.id { + t.Fatalf("Expected sent message id %s got %s", s.id, msg.id) + } + + // recv a message + msg.data.Header["Foo"] = "bar" + s.recv <- msg + + m := new(transport.Message) + s.Recv(m) + + // check header + if m.Header["Foo"] != "bar" { + t.Fatalf("Did not receive correct message %+v", m) + } + + // close the connection + s.Close() + + // check connection + err := s.Send(m) + if err == nil { + t.Fatal("Expected closed connection") + } + err = s.Recv(m) + if err == nil { + t.Fatal("Expected closed connection") + } +} diff --git a/network/tunnel/tunnel.go b/network/tunnel/tunnel.go new file mode 100644 index 00000000..1f7cd2b7 --- /dev/null +++ b/network/tunnel/tunnel.go @@ -0,0 +1,64 @@ +// Package tunnel provides a network tunnel +package tunnel + +import ( + "github.com/micro/go-micro/config/options" + "github.com/micro/go-micro/transport" +) + +// Tunnel creates a network tunnel +type Tunnel interface { + // Connect connects the tunnel + Connect() error + // Close closes the tunnel + Close() error + // Dial an endpoint + Dial(addr string) (Conn, error) + // Accept connections + Accept(addr string) (Conn, error) +} + +// Conn return a transport socket with a unique id. +// This means Conn can be used as a transport.Socket +type Conn interface { + // Unique id of the connection + Id() string + // Underlying socket + transport.Socket +} + +// A network interface to use for sending/receiving. +// When Tunnel.Connect is called it starts processing +// messages over the interface. +type Interface interface { + // Address of the interface + Addr() string + // Receive new messages + Recv() (*Message, error) + // Send messages + Send(*Message) error +} + +// Messages received over the interface +type Message struct { + Header map[string]string + Body []byte +} + +// NewTunnel creates a new tunnel +func NewTunnel(opts ...options.Option) Tunnel { + options := options.NewOptions(opts...) + + i, ok := options.Values().Get("tunnel.net") + if !ok { + // wtf + return nil + } + + return newTunnel(i.(Interface)) +} + +// WithInterface passes in the interface +func WithInterface(net Interface) options.Option { + return options.WithValue("tunnel.net", net) +}