diff --git a/tunnel/default.go b/tunnel/default.go index 69a5d878..f03016ed 100644 --- a/tunnel/default.go +++ b/tunnel/default.go @@ -49,8 +49,20 @@ type tun struct { type link struct { transport.Socket - id string - loopback bool + // unique id of this link e.g uuid + // which we define for ourselves + id string + // whether its a loopback connection + // this flag is used by the transport listener + // which accepts inbound quic connections + loopback bool + // whether its actually connected + // dialled side sets it to connected + // after sending the message. the + // listener waits for the connect + connected bool + // the last time we received a keepalive + // on this link from the remote side lastKeepAlive time.Time } @@ -173,7 +185,7 @@ func (t *tun) process() { } // set message head - newMsg.Header["Micro-Tunnel"] = "message" + newMsg.Header["Micro-Tunnel"] = msg.typ // set the tunnel id on the outgoing message newMsg.Header["Micro-Tunnel-Id"] = msg.id @@ -186,13 +198,39 @@ func (t *tun) process() { // send the message via the interface t.Lock() + if len(t.links) == 0 { log.Debugf("No links to send to") } + for node, link := range t.links { + // if the link is not connected skip it + if !link.connected { + log.Debugf("Link for node %s not connected", node) + continue + } + + // if we're picking the link check the id + // this is where we explicitly set the link + // in a message received via the listen method + if len(msg.link) > 0 && link.id != msg.link { + continue + } + + // if the link was a loopback accepted connection + // and the message is being sent outbound via + // a dialled connection don't use this link if link.loopback && msg.outbound { continue } + + // if the message was being returned by the loopback listener + // send it back up the loopback link only + if msg.loopback && !link.loopback { + continue + } + + // send the message via the current link log.Debugf("Sending %+v to %s", newMsg, node) if err := link.Send(newMsg); err != nil { log.Debugf("Tunnel error sending %+v to %s: %v", newMsg, node, err) @@ -200,6 +238,7 @@ func (t *tun) process() { continue } } + t.Unlock() case <-t.closed: return @@ -209,15 +248,22 @@ func (t *tun) process() { // process incoming messages func (t *tun) listen(link *link) { + // remove the link on exit + defer func() { + log.Debugf("Tunnel deleting connection from %s", link.Remote()) + t.Lock() + delete(t.links, link.Remote()) + t.Unlock() + }() + + // let us know if its a loopback + var loopback bool + for { // process anything via the net interface msg := new(transport.Message) - err := link.Recv(msg) - if err != nil { + if err := link.Recv(msg); err != nil { log.Debugf("Tunnel link %s receive error: %#v", link.Remote(), err) - t.Lock() - delete(t.links, link.Remote()) - t.Unlock() return } @@ -232,21 +278,29 @@ func (t *tun) listen(link *link) { // are we connecting to ourselves? if token == t.token { - t.Lock() link.loopback = true - t.Unlock() + loopback = true } + // set as connected + link.connected = true + + // save the link once connected + t.Lock() + t.links[link.Remote()] = link + t.Unlock() + // nothing more to do continue case "close": log.Debugf("Tunnel link %s closing connection", link.Remote()) // TODO: handle the close message // maybe report io.EOF or kill the link - continue + return case "keepalive": log.Debugf("Tunnel link %s received keepalive", link.Remote()) t.Lock() + // save the keepalive link.lastKeepAlive = time.Now() t.Unlock() continue @@ -258,6 +312,12 @@ func (t *tun) listen(link *link) { continue } + // if its not connected throw away the link + if !link.connected { + log.Debugf("Tunnel link %s not connected", link.id) + return + } + // strip message header delete(msg.Header, "Micro-Tunnel") @@ -283,8 +343,10 @@ func (t *tun) listen(link *link) { var s *socket var exists bool + // If its a loopback connection then we've enabled link direction + // listening side is used for listening, the dialling side for dialling switch { - case link.loopback: + case loopback: s, exists = t.getSocket(id, "listener") default: // get the socket based on the tunnel id and session @@ -298,6 +360,7 @@ func (t *tun) listen(link *link) { s, exists = t.getSocket(id, "listener") } } + // bail if no socket has been found if !exists { log.Debugf("Tunnel skipping no socket exists") @@ -337,9 +400,11 @@ func (t *tun) listen(link *link) { // construct the internal message imsg := &message{ - id: id, - session: session, - data: tmsg, + id: id, + session: session, + data: tmsg, + link: link.id, + loopback: loopback, } // append to recv backlog @@ -399,13 +464,11 @@ func (t *tun) setupLink(node string) (*link, error) { return nil, err } - // save the link - id := uuid.New().String() - link := &link{ - Socket: c, - id: id, - } - t.links[node] = link + // create a new link + link := newLink(c) + link.connected = true + // we made the outbound connection + // and sent the connect message // process incoming messages go t.listen(link) @@ -430,30 +493,18 @@ func (t *tun) connect() error { // accept inbound connections err := l.Accept(func(sock transport.Socket) { log.Debugf("Tunnel accepted connection from %s", sock.Remote()) - // save the link - id := uuid.New().String() - t.Lock() - link := &link{ - Socket: sock, - id: id, - } - t.links[sock.Remote()] = link - t.Unlock() - // delete the link - defer func() { - log.Debugf("Tunnel deleting connection from %s", sock.Remote()) - t.Lock() - delete(t.links, sock.Remote()) - t.Unlock() - }() + // create a new link + link := newLink(sock) - // listen for inbound messages + // listen for inbound messages. + // only save the link once connected. + // we do this inside liste t.listen(link) }) - t.Lock() - defer t.Unlock() + t.RLock() + defer t.RUnlock() // still connected but the tunnel died if err != nil && t.connected { @@ -473,6 +524,7 @@ func (t *tun) connect() error { log.Debugf("Tunnel failed to establish node link to %s: %v", node, err) continue } + // save the link t.links[node] = link } diff --git a/tunnel/link.go b/tunnel/link.go new file mode 100644 index 00000000..6b8f30aa --- /dev/null +++ b/tunnel/link.go @@ -0,0 +1,13 @@ +package tunnel + +import ( + "github.com/google/uuid" + "github.com/micro/go-micro/transport" +) + +func newLink(s transport.Socket) *link { + return &link{ + Socket: s, + id: uuid.New().String(), + } +} diff --git a/tunnel/listener.go b/tunnel/listener.go index 3002e7b6..42aadfd1 100644 --- a/tunnel/listener.go +++ b/tunnel/listener.go @@ -41,6 +41,10 @@ func (t *tunListener) process() { id: m.id, // the session id session: m.session, + // is loopback conn + loopback: m.loopback, + // the link the message was received on + link: m.link, // close chan closed: make(chan bool), // recv called by the acceptor diff --git a/tunnel/socket.go b/tunnel/socket.go index 2590a48e..789d9555 100644 --- a/tunnel/socket.go +++ b/tunnel/socket.go @@ -25,18 +25,28 @@ type socket struct { recv chan *message // wait until we have a connection wait chan bool - // outbound marks the socket as outbound + // outbound marks the socket as outbound dialled connection outbound bool + // lookback marks the socket as a loopback on the inbound + loopback bool + // the link on which this message was received + link string } // message is sent over the send channel type message struct { + // type of message + typ string // tunnel id id string // the session id session string // outbound marks the message as outbound outbound bool + // loopback marks the message intended for loopback + loopback bool + // the link to send the message on + link string // transport data data *transport.Message } @@ -77,10 +87,15 @@ func (s *socket) Send(m *transport.Message) error { // append to backlog msg := &message{ + typ: "message", id: s.id, session: s.session, outbound: s.outbound, + loopback: s.loopback, data: data, + // specify the link on which to send this + // it will be blank for dialled sockets + link: s.link, } log.Debugf("Appending %+v to send backlog", msg) s.send <- msg