Don't process unless connected, and only fire loopback messages back up the loopback
This commit is contained in:
		| @@ -49,8 +49,18 @@ type tun struct { | ||||
|  | ||||
| type link struct { | ||||
| 	transport.Socket | ||||
| 	id            string | ||||
| 	loopback      bool | ||||
| 	// unique id of this link e.g uuid | ||||
| 	// which we define for ourselves | ||||
| 	id string | ||||
| 	// whether its a loopback connection | ||||
| 	loopback bool | ||||
| 	// whether its actually connected | ||||
| 	// dialled side sets it to connected | ||||
| 	// after sending the message. the | ||||
| 	// listener waits for the connect | ||||
| 	connected bool | ||||
| 	// the last time we received a keepalive | ||||
| 	// on this link from the remote side | ||||
| 	lastKeepAlive time.Time | ||||
| } | ||||
|  | ||||
| @@ -190,9 +200,25 @@ func (t *tun) process() { | ||||
| 				log.Debugf("No links to send to") | ||||
| 			} | ||||
| 			for node, link := range t.links { | ||||
| 				// if the link is not connected skip it | ||||
| 				if !link.connected { | ||||
| 					log.Debugf("Link for node %s not connected", node) | ||||
| 					continue | ||||
| 				} | ||||
|  | ||||
| 				// if the link was a loopback accepted connection | ||||
| 				// and the message is being sent outbound via | ||||
| 				// a dialled connection don't use this link | ||||
| 				if link.loopback && msg.outbound { | ||||
| 					continue | ||||
| 				} | ||||
|  | ||||
| 				// if the message was being returned by the loopback listener | ||||
| 				// send it back up the loopback link only | ||||
| 				if msg.loopback && !link.loopback { | ||||
| 					continue | ||||
| 				} | ||||
|  | ||||
| 				log.Debugf("Sending %+v to %s", newMsg, node) | ||||
| 				if err := link.Send(newMsg); err != nil { | ||||
| 					log.Debugf("Tunnel error sending %+v to %s: %v", newMsg, node, err) | ||||
| @@ -209,15 +235,22 @@ func (t *tun) process() { | ||||
|  | ||||
| // process incoming messages | ||||
| func (t *tun) listen(link *link) { | ||||
| 	// remove the link on exit | ||||
| 	defer func() { | ||||
| 		log.Debugf("Tunnel deleting connection from %s", link.Remote()) | ||||
| 		t.Lock() | ||||
| 		delete(t.links, link.Remote()) | ||||
| 		t.Unlock() | ||||
| 	}() | ||||
|  | ||||
| 	// let us know if its a loopback | ||||
| 	var loopback bool | ||||
|  | ||||
| 	for { | ||||
| 		// process anything via the net interface | ||||
| 		msg := new(transport.Message) | ||||
| 		err := link.Recv(msg) | ||||
| 		if err != nil { | ||||
| 		if err := link.Recv(msg); err != nil { | ||||
| 			log.Debugf("Tunnel link %s receive error: %#v", link.Remote(), err) | ||||
| 			t.Lock() | ||||
| 			delete(t.links, link.Remote()) | ||||
| 			t.Unlock() | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| @@ -232,11 +265,18 @@ func (t *tun) listen(link *link) { | ||||
|  | ||||
| 			// are we connecting to ourselves? | ||||
| 			if token == t.token { | ||||
| 				t.Lock() | ||||
| 				link.loopback = true | ||||
| 				t.Unlock() | ||||
| 				loopback = true | ||||
| 			} | ||||
|  | ||||
| 			// set as connected | ||||
| 			link.connected = true | ||||
|  | ||||
| 			// save the link once connected | ||||
| 			t.Lock() | ||||
| 			t.links[link.Remote()] = link | ||||
| 			t.Unlock() | ||||
|  | ||||
| 			// nothing more to do | ||||
| 			continue | ||||
| 		case "close": | ||||
| @@ -258,6 +298,11 @@ func (t *tun) listen(link *link) { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		// if its not connected throw away the link | ||||
| 		if !link.connected { | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		// strip message header | ||||
| 		delete(msg.Header, "Micro-Tunnel") | ||||
|  | ||||
| @@ -283,8 +328,10 @@ func (t *tun) listen(link *link) { | ||||
| 		var s *socket | ||||
| 		var exists bool | ||||
|  | ||||
| 		// If its a loopback connection then we've enabled link direction | ||||
| 		// listening side is used for listening, the dialling side for dialling | ||||
| 		switch { | ||||
| 		case link.loopback: | ||||
| 		case loopback: | ||||
| 			s, exists = t.getSocket(id, "listener") | ||||
| 		default: | ||||
| 			// get the socket based on the tunnel id and session | ||||
| @@ -298,6 +345,7 @@ func (t *tun) listen(link *link) { | ||||
| 				s, exists = t.getSocket(id, "listener") | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		// bail if no socket has been found | ||||
| 		if !exists { | ||||
| 			log.Debugf("Tunnel skipping no socket exists") | ||||
| @@ -337,9 +385,10 @@ func (t *tun) listen(link *link) { | ||||
|  | ||||
| 		// construct the internal message | ||||
| 		imsg := &message{ | ||||
| 			id:      id, | ||||
| 			session: session, | ||||
| 			data:    tmsg, | ||||
| 			id:       id, | ||||
| 			session:  session, | ||||
| 			data:     tmsg, | ||||
| 			loopback: loopback, | ||||
| 		} | ||||
|  | ||||
| 		// append to recv backlog | ||||
| @@ -399,13 +448,14 @@ func (t *tun) setupLink(node string) (*link, error) { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	// save the link | ||||
| 	id := uuid.New().String() | ||||
| 	// create a new link | ||||
| 	link := &link{ | ||||
| 		Socket: c, | ||||
| 		id:     id, | ||||
| 		id:     uuid.New().String(), | ||||
| 		// we made the outbound connection | ||||
| 		// and sent the connect message | ||||
| 		connected: true, | ||||
| 	} | ||||
| 	t.links[node] = link | ||||
|  | ||||
| 	// process incoming messages | ||||
| 	go t.listen(link) | ||||
| @@ -430,25 +480,16 @@ func (t *tun) connect() error { | ||||
| 		// accept inbound connections | ||||
| 		err := l.Accept(func(sock transport.Socket) { | ||||
| 			log.Debugf("Tunnel accepted connection from %s", sock.Remote()) | ||||
| 			// save the link | ||||
| 			id := uuid.New().String() | ||||
| 			t.Lock() | ||||
|  | ||||
| 			// create a new link | ||||
| 			link := &link{ | ||||
| 				Socket: sock, | ||||
| 				id:     id, | ||||
| 				id:     uuid.New().String(), | ||||
| 			} | ||||
| 			t.links[sock.Remote()] = link | ||||
| 			t.Unlock() | ||||
|  | ||||
| 			// delete the link | ||||
| 			defer func() { | ||||
| 				log.Debugf("Tunnel deleting connection from %s", sock.Remote()) | ||||
| 				t.Lock() | ||||
| 				delete(t.links, sock.Remote()) | ||||
| 				t.Unlock() | ||||
| 			}() | ||||
|  | ||||
| 			// listen for inbound messages | ||||
| 			// listen for inbound messages. | ||||
| 			// only save the link once connected. | ||||
| 			// we do this inside liste | ||||
| 			t.listen(link) | ||||
| 		}) | ||||
|  | ||||
| @@ -473,6 +514,7 @@ func (t *tun) connect() error { | ||||
| 			log.Debugf("Tunnel failed to establish node link to %s: %v", node, err) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		// save the link | ||||
| 		t.links[node] = link | ||||
| 	} | ||||
|   | ||||
| @@ -41,6 +41,8 @@ func (t *tunListener) process() { | ||||
| 					id: m.id, | ||||
| 					// the session id | ||||
| 					session: m.session, | ||||
| 					// is loopback conn | ||||
| 					loopback: m.loopback, | ||||
| 					// close chan | ||||
| 					closed: make(chan bool), | ||||
| 					// recv called by the acceptor | ||||
|   | ||||
| @@ -25,8 +25,10 @@ type socket struct { | ||||
| 	recv chan *message | ||||
| 	// wait until we have a connection | ||||
| 	wait chan bool | ||||
| 	// outbound marks the socket as outbound | ||||
| 	// outbound marks the socket as outbound dialled connection | ||||
| 	outbound bool | ||||
| 	// lookback marks the socket as a loopback on the inbound | ||||
| 	loopback bool | ||||
| } | ||||
|  | ||||
| // message is sent over the send channel | ||||
| @@ -37,6 +39,8 @@ type message struct { | ||||
| 	session string | ||||
| 	// outbound marks the message as outbound | ||||
| 	outbound bool | ||||
| 	// loopback marks the message intended for loopback | ||||
| 	loopback bool | ||||
| 	// transport data | ||||
| 	data *transport.Message | ||||
| } | ||||
| @@ -80,6 +84,7 @@ func (s *socket) Send(m *transport.Message) error { | ||||
| 		id:       s.id, | ||||
| 		session:  s.session, | ||||
| 		outbound: s.outbound, | ||||
| 		loopback: s.loopback, | ||||
| 		data:     data, | ||||
| 	} | ||||
| 	log.Debugf("Appending %+v to send backlog", msg) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user