diff --git a/network/default.go b/network/default.go index 186513db..1b5669c6 100644 --- a/network/default.go +++ b/network/default.go @@ -179,10 +179,10 @@ func (n *network) resolve() { } // handleNetConn handles network announcement messages -func (n *network) handleNetConn(conn tunnel.Conn, msg chan *transport.Message) { +func (n *network) handleNetConn(sess tunnel.Session, msg chan *transport.Message) { for { m := new(transport.Message) - if err := conn.Recv(m); err != nil { + if err := sess.Recv(m); err != nil { // TODO: should we bail here? log.Debugf("Network tunnel [%s] receive error: %v", NetworkChannel, err) return @@ -349,10 +349,10 @@ func (n *network) announce(client transport.Client) { } // handleCtrlConn handles ControlChannel connections -func (n *network) handleCtrlConn(conn tunnel.Conn, msg chan *transport.Message) { +func (n *network) handleCtrlConn(sess tunnel.Session, msg chan *transport.Message) { for { m := new(transport.Message) - if err := conn.Recv(m); err != nil { + if err := sess.Recv(m); err != nil { // TODO: should we bail here? log.Debugf("Network tunnel advert receive error: %v", err) return diff --git a/tunnel/default.go b/tunnel/default.go index fc94c5a5..51468164 100644 --- a/tunnel/default.go +++ b/tunnel/default.go @@ -1,9 +1,8 @@ package tunnel import ( - "crypto/sha256" "errors" - "fmt" + "strings" "sync" "time" @@ -25,7 +24,10 @@ type tun struct { sync.RWMutex - // tunnel token + // the unique id for this tunnel + id string + + // tunnel token for authentication token string // to indicate if we're connected or not @@ -37,8 +39,8 @@ type tun struct { // close channel closed chan bool - // a map of sockets based on Micro-Tunnel-Id - sockets map[string]*socket + // a map of sessions based on Micro-Tunnel-Channel + sessions map[string]*session // outbound links links map[string]*link @@ -47,25 +49,6 @@ type tun struct { listener transport.Listener } -type link struct { - transport.Socket - // unique id of this link e.g uuid - // which we define for ourselves - id string - // whether its a loopback connection - // this flag is used by the transport listener - // which accepts inbound quic connections - loopback bool - // whether its actually connected - // dialled side sets it to connected - // after sending the message. the - // listener waits for the connect - connected bool - // the last time we received a keepalive - // on this link from the remote side - lastKeepAlive time.Time -} - // create new tunnel on top of a link func newTunnel(opts ...Option) *tun { options := DefaultOptions() @@ -74,12 +57,13 @@ func newTunnel(opts ...Option) *tun { } return &tun{ - options: options, - token: uuid.New().String(), - send: make(chan *message, 128), - closed: make(chan bool), - sockets: make(map[string]*socket), - links: make(map[string]*link), + options: options, + id: options.Id, + token: options.Token, + send: make(chan *message, 128), + closed: make(chan bool), + sessions: make(map[string]*session), + links: make(map[string]*link), } } @@ -93,51 +77,48 @@ func (t *tun) Init(opts ...Option) error { return nil } -// getSocket returns a socket from the internal socket map. -// It does this based on the Micro-Tunnel-Id and Micro-Tunnel-Session -func (t *tun) getSocket(id, session string) (*socket, bool) { - // get the socket +// getSession returns a session from the internal session map. +// It does this based on the Micro-Tunnel-Channel and Micro-Tunnel-Session +func (t *tun) getSession(channel, session string) (*session, bool) { + // get the session t.RLock() - s, ok := t.sockets[id+session] + s, ok := t.sessions[channel+session] t.RUnlock() return s, ok } -// newSocket creates a new socket and saves it -func (t *tun) newSocket(id, session string) (*socket, bool) { - // hash the id - h := sha256.New() - h.Write([]byte(id)) - id = fmt.Sprintf("%x", h.Sum(nil)) - - // new socket - s := &socket{ - id: id, - session: session, +// newSession creates a new session and saves it +func (t *tun) newSession(channel, sessionId string) (*session, bool) { + // new session + s := &session{ + id: t.id, + channel: channel, + session: sessionId, closed: make(chan bool), recv: make(chan *message, 128), send: t.send, wait: make(chan bool), + errChan: make(chan error, 1), } - // save socket + // save session t.Lock() - _, ok := t.sockets[id+session] + _, ok := t.sessions[channel+sessionId] if ok { - // socket already exists + // session already exists t.Unlock() return nil, false } - t.sockets[id+session] = s + t.sessions[channel+sessionId] = s t.Unlock() - // return socket + // return session return s, true } // TODO: use tunnel id as part of the session -func (t *tun) newSession() string { +func (t *tun) newSessionId() string { return uuid.New().String() } @@ -168,10 +149,10 @@ func (t *tun) monitor() { } } -// process outgoing messages sent by all local sockets +// process outgoing messages sent by all local sessions func (t *tun) process() { // manage the send buffer - // all pseudo sockets throw everything down this + // all pseudo sessions throw everything down this for { select { case msg := <-t.send: @@ -190,6 +171,9 @@ func (t *tun) process() { // set the tunnel id on the outgoing message newMsg.Header["Micro-Tunnel-Id"] = msg.id + // set the tunnel channel on the outgoing message + newMsg.Header["Micro-Tunnel-Channel"] = msg.channel + // set the session id newMsg.Header["Micro-Tunnel-Session"] = msg.session @@ -203,10 +187,14 @@ func (t *tun) process() { log.Debugf("No links to send to") } + var sent bool + var err error + for node, link := range t.links { // if the link is not connected skip it if !link.connected { log.Debugf("Link for node %s not connected", node) + err = errors.New("link not connected") continue } @@ -214,6 +202,7 @@ func (t *tun) process() { // this is where we explicitly set the link // in a message received via the listen method if len(msg.link) > 0 && link.id != msg.link { + err = errors.New("link not found") continue } @@ -221,25 +210,41 @@ func (t *tun) process() { // and the message is being sent outbound via // a dialled connection don't use this link if link.loopback && msg.outbound { + err = errors.New("link is loopback") continue } // if the message was being returned by the loopback listener // send it back up the loopback link only if msg.loopback && !link.loopback { + err = errors.New("link is not loopback") continue } // send the message via the current link log.Debugf("Sending %+v to %s", newMsg, node) - if err := link.Send(newMsg); err != nil { - log.Debugf("Tunnel error sending %+v to %s: %v", newMsg, node, err) + if errr := link.Send(newMsg); errr != nil { + log.Debugf("Tunnel error sending %+v to %s: %v", newMsg, node, errr) + err = errors.New(errr.Error()) delete(t.links, node) continue } + // is sent + sent = true } t.Unlock() + + var gerr error + if !sent { + gerr = err + } + + // return error non blocking + select { + case msg.errChan <- gerr: + default: + } case <-t.closed: return } @@ -267,17 +272,23 @@ func (t *tun) listen(link *link) { return } + // always ensure we have the correct auth token + // TODO: segment the tunnel based on token + // e.g use it as the basis + token := msg.Header["Micro-Tunnel-Token"] + if token != t.token { + log.Debugf("Tunnel link %s received invalid token %s", token) + return + } + switch msg.Header["Micro-Tunnel"] { case "connect": log.Debugf("Tunnel link %s received connect message", link.Remote()) - // check the Micro-Tunnel-Token - token, ok := msg.Header["Micro-Tunnel-Token"] - if !ok { - continue - } + + id := msg.Header["Micro-Tunnel-Id"] // are we connecting to ourselves? - if token == t.token { + if id == t.id { link.loopback = true loopback = true } @@ -318,76 +329,77 @@ func (t *tun) listen(link *link) { return } - // strip message header - delete(msg.Header, "Micro-Tunnel") - // the tunnel id id := msg.Header["Micro-Tunnel-Id"] - delete(msg.Header, "Micro-Tunnel-Id") - + // the tunnel channel + channel := msg.Header["Micro-Tunnel-Channel"] // the session id - session := msg.Header["Micro-Tunnel-Session"] - delete(msg.Header, "Micro-Tunnel-Session") + sessionId := msg.Header["Micro-Tunnel-Session"] - // strip token header - delete(msg.Header, "Micro-Tunnel-Token") + // strip tunnel message header + for k, _ := range msg.Header { + if strings.HasPrefix(k, "Micro-Tunnel") { + delete(msg.Header, k) + } + } // if the session id is blank there's nothing we can do // TODO: check this is the case, is there any reason // why we'd have a blank session? Is the tunnel // used for some other purpose? - if len(id) == 0 || len(session) == 0 { + if len(channel) == 0 || len(sessionId) == 0 { continue } - var s *socket + var s *session var exists bool // If its a loopback connection then we've enabled link direction // listening side is used for listening, the dialling side for dialling switch { case loopback: - s, exists = t.getSocket(id, "listener") + s, exists = t.getSession(channel, "listener") default: - // get the socket based on the tunnel id and session + // get the session based on the tunnel id and session // this could be something we dialed in which case // we have a session for it otherwise its a listener - s, exists = t.getSocket(id, session) + s, exists = t.getSession(channel, sessionId) if !exists { // try get it based on just the tunnel id // the assumption here is that a listener // has no session but its set a listener session - s, exists = t.getSocket(id, "listener") + s, exists = t.getSession(channel, "listener") } } - // bail if no socket has been found + // bail if no session has been found if !exists { - log.Debugf("Tunnel skipping no socket exists") + log.Debugf("Tunnel skipping no session exists") // drop it, we don't care about // messages we don't know about continue } - log.Debugf("Tunnel using socket %s %s", s.id, s.session) - // is the socket closed? + log.Debugf("Tunnel using session %s %s", s.channel, s.session) + + // is the session closed? select { case <-s.closed: // closed - delete(t.sockets, id) + delete(t.sessions, channel) continue default: // process } - // is the socket new? + // is the session new? select { - // if its new the socket is actually blocked waiting + // if its new the session is actually blocked waiting // for a connection. so we check if its waiting. case <-s.wait: // if its waiting e.g its new then we close it default: - // set remote address of the socket + // set remote address of the session s.remote = msg.Header["Remote"] close(s.wait) } @@ -401,10 +413,12 @@ func (t *tun) listen(link *link) { // construct the internal message imsg := &message{ id: id, - session: session, + channel: channel, + session: sessionId, data: tmsg, link: link.id, loopback: loopback, + errChan: make(chan error, 1), } // append to recv backlog @@ -431,6 +445,7 @@ func (t *tun) keepalive(link *link) { if err := link.Send(&transport.Message{ Header: map[string]string{ "Micro-Tunnel": "keepalive", + "Micro-Tunnel-Id": t.id, "Micro-Tunnel-Token": t.token, }, }); err != nil { @@ -458,6 +473,7 @@ func (t *tun) setupLink(node string) (*link, error) { if err := c.Send(&transport.Message{ Header: map[string]string{ "Micro-Tunnel": "connect", + "Micro-Tunnel-Id": t.id, "Micro-Tunnel-Token": t.token, }, }); err != nil { @@ -568,6 +584,7 @@ func (t *tun) close() error { link.Send(&transport.Message{ Header: map[string]string{ "Micro-Tunnel": "close", + "Micro-Tunnel-Id": t.id, "Micro-Tunnel-Token": t.token, }, }) @@ -603,10 +620,10 @@ func (t *tun) Close() error { case <-t.closed: return nil default: - // close all the sockets - for id, s := range t.sockets { + // close all the sessions + for id, s := range t.sessions { s.Close() - delete(t.sockets, id) + delete(t.sessions, id) } // close the connection close(t.closed) @@ -622,52 +639,50 @@ func (t *tun) Close() error { } // Dial an address -func (t *tun) Dial(addr string) (Conn, error) { - log.Debugf("Tunnel dialing %s", addr) - c, ok := t.newSocket(addr, t.newSession()) +func (t *tun) Dial(channel string) (Session, error) { + log.Debugf("Tunnel dialing %s", channel) + c, ok := t.newSession(channel, t.newSessionId()) if !ok { - return nil, errors.New("error dialing " + addr) + return nil, errors.New("error dialing " + channel) } // set remote - c.remote = addr + c.remote = channel // set local c.local = "local" - // outbound socket + // outbound session c.outbound = true return c, nil } // Accept a connection on the address -func (t *tun) Listen(addr string) (Listener, error) { - log.Debugf("Tunnel listening on %s", addr) - // create a new socket by hashing the address - c, ok := t.newSocket(addr, "listener") +func (t *tun) Listen(channel string) (Listener, error) { + log.Debugf("Tunnel listening on %s", channel) + // create a new session by hashing the address + c, ok := t.newSession(channel, "listener") if !ok { - return nil, errors.New("already listening on " + addr) + return nil, errors.New("already listening on " + channel) } // set remote. it will be replaced by the first message received c.remote = "remote" // set local - c.local = addr + c.local = channel tl := &tunListener{ - addr: addr, + channel: channel, // the accept channel - accept: make(chan *socket, 128), + accept: make(chan *session, 128), // the channel to close closed: make(chan bool), // tunnel closed channel tunClosed: t.closed, - // the connection - conn: c, - // the listener socket - socket: c, + // the listener session + session: c, } // this kicks off the internal message processor - // for the listener so it can create pseudo sockets + // for the listener so it can create pseudo sessions // per session if they do not exist or pass messages // to the existign sessions go tl.process() diff --git a/tunnel/link.go b/tunnel/link.go index 6b8f30aa..fbec2e6a 100644 --- a/tunnel/link.go +++ b/tunnel/link.go @@ -1,10 +1,34 @@ package tunnel import ( + "sync" + "time" + "github.com/google/uuid" "github.com/micro/go-micro/transport" ) +type link struct { + sync.RWMutex + + transport.Socket + // unique id of this link e.g uuid + // which we define for ourselves + id string + // whether its a loopback connection + // this flag is used by the transport listener + // which accepts inbound quic connections + loopback bool + // whether its actually connected + // dialled side sets it to connected + // after sending the message. the + // listener waits for the connect + connected bool + // the last time we received a keepalive + // on this link from the remote side + lastKeepAlive time.Time +} + func newLink(s transport.Socket) *link { return &link{ Socket: s, diff --git a/tunnel/listener.go b/tunnel/listener.go index 42aadfd1..d62b58de 100644 --- a/tunnel/listener.go +++ b/tunnel/listener.go @@ -8,37 +8,37 @@ import ( type tunListener struct { // address of the listener - addr string + channel string // the accept channel - accept chan *socket + accept chan *session // the channel to close closed chan bool // the tunnel closed channel tunClosed chan bool - // the connection - conn Conn - // the listener socket - socket *socket + // the listener session + session *session } func (t *tunListener) process() { // our connection map for session - conns := make(map[string]*socket) + conns := make(map[string]*session) for { select { case <-t.closed: return // receive a new message - case m := <-t.socket.recv: - // get a socket - sock, ok := conns[m.session] + case m := <-t.session.recv: + // get a session + sess, ok := conns[m.session] log.Debugf("Tunnel listener received id %s session %s exists: %t", m.id, m.session, ok) if !ok { - // create a new socket session - sock = &socket{ - // our tunnel id + // create a new session session + sess = &session{ + // the id of the remote side id: m.id, + // the channel + channel: m.channel, // the session id session: m.session, // is loopback conn @@ -50,35 +50,37 @@ func (t *tunListener) process() { // recv called by the acceptor recv: make(chan *message, 128), // use the internal send buffer - send: t.socket.send, + send: t.session.send, // wait wait: make(chan bool), + // error channel + errChan: make(chan error, 1), } - // save the socket - conns[m.session] = sock + // save the session + conns[m.session] = sess // send to accept chan select { case <-t.closed: return - case t.accept <- sock: + case t.accept <- sess: } } // send this to the accept chan select { - case <-sock.closed: + case <-sess.closed: delete(conns, m.session) - case sock.recv <- m: + case sess.recv <- m: log.Debugf("Tunnel listener sent to recv chan id %s session %s", m.id, m.session) } } } } -func (t *tunListener) Addr() string { - return t.addr +func (t *tunListener) Channel() string { + return t.channel } // Close closes tunnel listener @@ -93,9 +95,9 @@ func (t *tunListener) Close() error { } // Everytime accept is called we essentially block till we get a new connection -func (t *tunListener) Accept() (Conn, error) { +func (t *tunListener) Accept() (Session, error) { select { - // if the socket is closed return + // if the session is closed return case <-t.closed: return nil, io.EOF case <-t.tunClosed: diff --git a/tunnel/options.go b/tunnel/options.go index 99406b05..9d612173 100644 --- a/tunnel/options.go +++ b/tunnel/options.go @@ -9,6 +9,8 @@ import ( var ( // DefaultAddress is default tunnel bind address DefaultAddress = ":0" + // The shared default token + DefaultToken = "micro" ) type Option func(*Options) @@ -21,6 +23,8 @@ type Options struct { Address string // Nodes are remote nodes Nodes []string + // The shared auth token + Token string // Transport listens to incoming connections Transport transport.Transport } @@ -46,6 +50,13 @@ func Nodes(n ...string) Option { } } +// Token sets the shared token for auth +func Token(t string) Option { + return func(o *Options) { + o.Token = t + } +} + // Transport listens for incoming connections func Transport(t transport.Transport) Option { return func(o *Options) { @@ -58,6 +69,7 @@ func DefaultOptions() Options { return Options{ Id: uuid.New().String(), Address: DefaultAddress, + Token: DefaultToken, Transport: quic.NewTransport(), } } diff --git a/tunnel/socket.go b/tunnel/session.go similarity index 60% rename from tunnel/socket.go rename to tunnel/session.go index 789d9555..a4c779f3 100644 --- a/tunnel/socket.go +++ b/tunnel/session.go @@ -2,15 +2,18 @@ package tunnel import ( "errors" + "io" "github.com/micro/go-micro/transport" "github.com/micro/go-micro/util/log" ) -// socket is our pseudo socket for transport.Socket -type socket struct { - // socket id based on Micro-Tunnel +// session is our pseudo session for transport.Socket +type session struct { + // unique id based on the remote tunnel id id string + // the channel name + channel string // the session id based on Micro.Tunnel-Session session string // closed @@ -25,12 +28,14 @@ type socket struct { recv chan *message // wait until we have a connection wait chan bool - // outbound marks the socket as outbound dialled connection + // outbound marks the session as outbound dialled connection outbound bool - // lookback marks the socket as a loopback on the inbound + // lookback marks the session as a loopback on the inbound loopback bool // the link on which this message was received link string + // the error response + errChan chan error } // message is sent over the send channel @@ -39,6 +44,8 @@ type message struct { typ string // tunnel id id string + // channel name + channel string // the session id session string // outbound marks the message as outbound @@ -49,28 +56,30 @@ type message struct { link string // transport data data *transport.Message + // the error channel + errChan chan error } -func (s *socket) Remote() string { +func (s *session) Remote() string { return s.remote } -func (s *socket) Local() string { +func (s *session) Local() string { return s.local } -func (s *socket) Id() string { - return s.id -} - -func (s *socket) Session() string { +func (s *session) Id() string { return s.session } -func (s *socket) Send(m *transport.Message) error { +func (s *session) Channel() string { + return s.channel +} + +func (s *session) Send(m *transport.Message) error { select { case <-s.closed: - return errors.New("socket is closed") + return errors.New("session is closed") default: // no op } @@ -89,28 +98,48 @@ func (s *socket) Send(m *transport.Message) error { msg := &message{ typ: "message", id: s.id, + channel: s.channel, session: s.session, outbound: s.outbound, loopback: s.loopback, data: data, // specify the link on which to send this - // it will be blank for dialled sockets + // it will be blank for dialled sessions link: s.link, + // error chan + errChan: s.errChan, } log.Debugf("Appending %+v to send backlog", msg) s.send <- msg + + // wait for an error response + select { + case err := <-msg.errChan: + return err + case <-s.closed: + return io.EOF + } + return nil } -func (s *socket) Recv(m *transport.Message) error { +func (s *session) Recv(m *transport.Message) error { select { case <-s.closed: - return errors.New("socket is closed") + return errors.New("session is closed") default: // no op } // recv from backlog msg := <-s.recv + + // check the error if one exists + select { + case err := <-msg.errChan: + return err + default: + } + log.Debugf("Received %+v from recv backlog", msg) // set message *m = *msg.data @@ -118,8 +147,8 @@ func (s *socket) Recv(m *transport.Message) error { return nil } -// Close closes the socket -func (s *socket) Close() error { +// Close closes the session +func (s *session) Close() error { select { case <-s.closed: // no op diff --git a/tunnel/transport/listener.go b/tunnel/transport/listener.go index b7a7280c..075f12cf 100644 --- a/tunnel/transport/listener.go +++ b/tunnel/transport/listener.go @@ -10,7 +10,7 @@ type tunListener struct { } func (t *tunListener) Addr() string { - return t.l.Addr() + return t.l.Channel() } func (t *tunListener) Close() error { diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index f7bc91cb..0349293a 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -17,27 +17,27 @@ type Tunnel interface { Connect() error // Close closes the tunnel Close() error - // Dial an endpoint - Dial(addr string) (Conn, error) - // Accept connections - Listen(addr string) (Listener, error) + // Connect to a channel + Dial(channel string) (Session, error) + // Accept connections on a channel + Listen(channel string) (Listener, error) // Name of the tunnel implementation String() string } // The listener provides similar constructs to the transport.Listener type Listener interface { - Addr() string + Channel() string Close() error - Accept() (Conn, error) + Accept() (Session, error) } -// Conn is a connection dialed or accepted which includes the tunnel id and session -type Conn interface { +// Session is a unique session created when dialling or accepting connections on the tunnel +type Session interface { // Specifies the tunnel id Id() string // The session - Session() string + Channel() string // a transport socket transport.Socket }