Rename Tunnel ID to Channel
This commit is contained in:
		| @@ -179,10 +179,10 @@ func (n *network) resolve() { | |||||||
| } | } | ||||||
|  |  | ||||||
| // handleNetConn handles network announcement messages | // handleNetConn handles network announcement messages | ||||||
| func (n *network) handleNetConn(conn tunnel.Conn, msg chan *transport.Message) { | func (n *network) handleNetConn(sess tunnel.Session, msg chan *transport.Message) { | ||||||
| 	for { | 	for { | ||||||
| 		m := new(transport.Message) | 		m := new(transport.Message) | ||||||
| 		if err := conn.Recv(m); err != nil { | 		if err := sess.Recv(m); err != nil { | ||||||
| 			// TODO: should we bail here? | 			// TODO: should we bail here? | ||||||
| 			log.Debugf("Network tunnel [%s] receive error: %v", NetworkChannel, err) | 			log.Debugf("Network tunnel [%s] receive error: %v", NetworkChannel, err) | ||||||
| 			return | 			return | ||||||
| @@ -349,10 +349,10 @@ func (n *network) announce(client transport.Client) { | |||||||
| } | } | ||||||
|  |  | ||||||
| // handleCtrlConn handles ControlChannel connections | // handleCtrlConn handles ControlChannel connections | ||||||
| func (n *network) handleCtrlConn(conn tunnel.Conn, msg chan *transport.Message) { | func (n *network) handleCtrlConn(sess tunnel.Session, msg chan *transport.Message) { | ||||||
| 	for { | 	for { | ||||||
| 		m := new(transport.Message) | 		m := new(transport.Message) | ||||||
| 		if err := conn.Recv(m); err != nil { | 		if err := sess.Recv(m); err != nil { | ||||||
| 			// TODO: should we bail here? | 			// TODO: should we bail here? | ||||||
| 			log.Debugf("Network tunnel advert receive error: %v", err) | 			log.Debugf("Network tunnel advert receive error: %v", err) | ||||||
| 			return | 			return | ||||||
|   | |||||||
| @@ -1,9 +1,8 @@ | |||||||
| package tunnel | package tunnel | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"crypto/sha256" |  | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"strings" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| @@ -25,7 +24,10 @@ type tun struct { | |||||||
|  |  | ||||||
| 	sync.RWMutex | 	sync.RWMutex | ||||||
|  |  | ||||||
| 	// tunnel token | 	// the unique id for this tunnel | ||||||
|  | 	id string | ||||||
|  |  | ||||||
|  | 	// tunnel token for authentication | ||||||
| 	token string | 	token string | ||||||
|  |  | ||||||
| 	// to indicate if we're connected or not | 	// to indicate if we're connected or not | ||||||
| @@ -37,8 +39,8 @@ type tun struct { | |||||||
| 	// close channel | 	// close channel | ||||||
| 	closed chan bool | 	closed chan bool | ||||||
|  |  | ||||||
| 	// a map of sockets based on Micro-Tunnel-Id | 	// a map of sessions based on Micro-Tunnel-Channel | ||||||
| 	sockets map[string]*socket | 	sessions map[string]*session | ||||||
|  |  | ||||||
| 	// outbound links | 	// outbound links | ||||||
| 	links map[string]*link | 	links map[string]*link | ||||||
| @@ -47,25 +49,6 @@ type tun struct { | |||||||
| 	listener transport.Listener | 	listener transport.Listener | ||||||
| } | } | ||||||
|  |  | ||||||
| type link struct { |  | ||||||
| 	transport.Socket |  | ||||||
| 	// unique id of this link e.g uuid |  | ||||||
| 	// which we define for ourselves |  | ||||||
| 	id string |  | ||||||
| 	// whether its a loopback connection |  | ||||||
| 	// this flag is used by the transport listener |  | ||||||
| 	// which accepts inbound quic connections |  | ||||||
| 	loopback bool |  | ||||||
| 	// whether its actually connected |  | ||||||
| 	// dialled side sets it to connected |  | ||||||
| 	// after sending the message. the |  | ||||||
| 	// listener waits for the connect |  | ||||||
| 	connected bool |  | ||||||
| 	// the last time we received a keepalive |  | ||||||
| 	// on this link from the remote side |  | ||||||
| 	lastKeepAlive time.Time |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // create new tunnel on top of a link | // create new tunnel on top of a link | ||||||
| func newTunnel(opts ...Option) *tun { | func newTunnel(opts ...Option) *tun { | ||||||
| 	options := DefaultOptions() | 	options := DefaultOptions() | ||||||
| @@ -74,12 +57,13 @@ func newTunnel(opts ...Option) *tun { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return &tun{ | 	return &tun{ | ||||||
| 		options: options, | 		options:  options, | ||||||
| 		token:   uuid.New().String(), | 		id:       options.Id, | ||||||
| 		send:    make(chan *message, 128), | 		token:    options.Token, | ||||||
| 		closed:  make(chan bool), | 		send:     make(chan *message, 128), | ||||||
| 		sockets: make(map[string]*socket), | 		closed:   make(chan bool), | ||||||
| 		links:   make(map[string]*link), | 		sessions: make(map[string]*session), | ||||||
|  | 		links:    make(map[string]*link), | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -93,51 +77,48 @@ func (t *tun) Init(opts ...Option) error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // getSocket returns a socket from the internal socket map. | // getSession returns a session from the internal session map. | ||||||
| // It does this based on the Micro-Tunnel-Id and Micro-Tunnel-Session | // It does this based on the Micro-Tunnel-Channel and Micro-Tunnel-Session | ||||||
| func (t *tun) getSocket(id, session string) (*socket, bool) { | func (t *tun) getSession(channel, session string) (*session, bool) { | ||||||
| 	// get the socket | 	// get the session | ||||||
| 	t.RLock() | 	t.RLock() | ||||||
| 	s, ok := t.sockets[id+session] | 	s, ok := t.sessions[channel+session] | ||||||
| 	t.RUnlock() | 	t.RUnlock() | ||||||
| 	return s, ok | 	return s, ok | ||||||
| } | } | ||||||
|  |  | ||||||
| // newSocket creates a new socket and saves it | // newSession creates a new session and saves it | ||||||
| func (t *tun) newSocket(id, session string) (*socket, bool) { | func (t *tun) newSession(channel, sessionId string) (*session, bool) { | ||||||
| 	// hash the id | 	// new session | ||||||
| 	h := sha256.New() | 	s := &session{ | ||||||
| 	h.Write([]byte(id)) | 		id:      t.id, | ||||||
| 	id = fmt.Sprintf("%x", h.Sum(nil)) | 		channel: channel, | ||||||
|  | 		session: sessionId, | ||||||
| 	// new socket |  | ||||||
| 	s := &socket{ |  | ||||||
| 		id:      id, |  | ||||||
| 		session: session, |  | ||||||
| 		closed:  make(chan bool), | 		closed:  make(chan bool), | ||||||
| 		recv:    make(chan *message, 128), | 		recv:    make(chan *message, 128), | ||||||
| 		send:    t.send, | 		send:    t.send, | ||||||
| 		wait:    make(chan bool), | 		wait:    make(chan bool), | ||||||
|  | 		errChan: make(chan error, 1), | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// save socket | 	// save session | ||||||
| 	t.Lock() | 	t.Lock() | ||||||
| 	_, ok := t.sockets[id+session] | 	_, ok := t.sessions[channel+sessionId] | ||||||
| 	if ok { | 	if ok { | ||||||
| 		// socket already exists | 		// session already exists | ||||||
| 		t.Unlock() | 		t.Unlock() | ||||||
| 		return nil, false | 		return nil, false | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	t.sockets[id+session] = s | 	t.sessions[channel+sessionId] = s | ||||||
| 	t.Unlock() | 	t.Unlock() | ||||||
|  |  | ||||||
| 	// return socket | 	// return session | ||||||
| 	return s, true | 	return s, true | ||||||
| } | } | ||||||
|  |  | ||||||
| // TODO: use tunnel id as part of the session | // TODO: use tunnel id as part of the session | ||||||
| func (t *tun) newSession() string { | func (t *tun) newSessionId() string { | ||||||
| 	return uuid.New().String() | 	return uuid.New().String() | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -168,10 +149,10 @@ func (t *tun) monitor() { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // process outgoing messages sent by all local sockets | // process outgoing messages sent by all local sessions | ||||||
| func (t *tun) process() { | func (t *tun) process() { | ||||||
| 	// manage the send buffer | 	// manage the send buffer | ||||||
| 	// all pseudo sockets throw everything down this | 	// all pseudo sessions throw everything down this | ||||||
| 	for { | 	for { | ||||||
| 		select { | 		select { | ||||||
| 		case msg := <-t.send: | 		case msg := <-t.send: | ||||||
| @@ -190,6 +171,9 @@ func (t *tun) process() { | |||||||
| 			// set the tunnel id on the outgoing message | 			// set the tunnel id on the outgoing message | ||||||
| 			newMsg.Header["Micro-Tunnel-Id"] = msg.id | 			newMsg.Header["Micro-Tunnel-Id"] = msg.id | ||||||
|  |  | ||||||
|  | 			// set the tunnel channel on the outgoing message | ||||||
|  | 			newMsg.Header["Micro-Tunnel-Channel"] = msg.channel | ||||||
|  |  | ||||||
| 			// set the session id | 			// set the session id | ||||||
| 			newMsg.Header["Micro-Tunnel-Session"] = msg.session | 			newMsg.Header["Micro-Tunnel-Session"] = msg.session | ||||||
|  |  | ||||||
| @@ -203,10 +187,14 @@ func (t *tun) process() { | |||||||
| 				log.Debugf("No links to send to") | 				log.Debugf("No links to send to") | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
|  | 			var sent bool | ||||||
|  | 			var err error | ||||||
|  |  | ||||||
| 			for node, link := range t.links { | 			for node, link := range t.links { | ||||||
| 				// if the link is not connected skip it | 				// if the link is not connected skip it | ||||||
| 				if !link.connected { | 				if !link.connected { | ||||||
| 					log.Debugf("Link for node %s not connected", node) | 					log.Debugf("Link for node %s not connected", node) | ||||||
|  | 					err = errors.New("link not connected") | ||||||
| 					continue | 					continue | ||||||
| 				} | 				} | ||||||
|  |  | ||||||
| @@ -214,6 +202,7 @@ func (t *tun) process() { | |||||||
| 				// this is where we explicitly set the link | 				// this is where we explicitly set the link | ||||||
| 				// in a message received via the listen method | 				// in a message received via the listen method | ||||||
| 				if len(msg.link) > 0 && link.id != msg.link { | 				if len(msg.link) > 0 && link.id != msg.link { | ||||||
|  | 					err = errors.New("link not found") | ||||||
| 					continue | 					continue | ||||||
| 				} | 				} | ||||||
|  |  | ||||||
| @@ -221,25 +210,41 @@ func (t *tun) process() { | |||||||
| 				// and the message is being sent outbound via | 				// and the message is being sent outbound via | ||||||
| 				// a dialled connection don't use this link | 				// a dialled connection don't use this link | ||||||
| 				if link.loopback && msg.outbound { | 				if link.loopback && msg.outbound { | ||||||
|  | 					err = errors.New("link is loopback") | ||||||
| 					continue | 					continue | ||||||
| 				} | 				} | ||||||
|  |  | ||||||
| 				// if the message was being returned by the loopback listener | 				// if the message was being returned by the loopback listener | ||||||
| 				// send it back up the loopback link only | 				// send it back up the loopback link only | ||||||
| 				if msg.loopback && !link.loopback { | 				if msg.loopback && !link.loopback { | ||||||
|  | 					err = errors.New("link is not loopback") | ||||||
| 					continue | 					continue | ||||||
| 				} | 				} | ||||||
|  |  | ||||||
| 				// send the message via the current link | 				// send the message via the current link | ||||||
| 				log.Debugf("Sending %+v to %s", newMsg, node) | 				log.Debugf("Sending %+v to %s", newMsg, node) | ||||||
| 				if err := link.Send(newMsg); err != nil { | 				if errr := link.Send(newMsg); errr != nil { | ||||||
| 					log.Debugf("Tunnel error sending %+v to %s: %v", newMsg, node, err) | 					log.Debugf("Tunnel error sending %+v to %s: %v", newMsg, node, errr) | ||||||
|  | 					err = errors.New(errr.Error()) | ||||||
| 					delete(t.links, node) | 					delete(t.links, node) | ||||||
| 					continue | 					continue | ||||||
| 				} | 				} | ||||||
|  | 				// is sent | ||||||
|  | 				sent = true | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			t.Unlock() | 			t.Unlock() | ||||||
|  |  | ||||||
|  | 			var gerr error | ||||||
|  | 			if !sent { | ||||||
|  | 				gerr = err | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// return error non blocking | ||||||
|  | 			select { | ||||||
|  | 			case msg.errChan <- gerr: | ||||||
|  | 			default: | ||||||
|  | 			} | ||||||
| 		case <-t.closed: | 		case <-t.closed: | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| @@ -267,17 +272,23 @@ func (t *tun) listen(link *link) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// always ensure we have the correct auth token | ||||||
|  | 		// TODO: segment the tunnel based on token | ||||||
|  | 		// e.g use it as the basis | ||||||
|  | 		token := msg.Header["Micro-Tunnel-Token"] | ||||||
|  | 		if token != t.token { | ||||||
|  | 			log.Debugf("Tunnel link %s received invalid token %s", token) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		switch msg.Header["Micro-Tunnel"] { | 		switch msg.Header["Micro-Tunnel"] { | ||||||
| 		case "connect": | 		case "connect": | ||||||
| 			log.Debugf("Tunnel link %s received connect message", link.Remote()) | 			log.Debugf("Tunnel link %s received connect message", link.Remote()) | ||||||
| 			// check the Micro-Tunnel-Token |  | ||||||
| 			token, ok := msg.Header["Micro-Tunnel-Token"] | 			id := msg.Header["Micro-Tunnel-Id"] | ||||||
| 			if !ok { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			// are we connecting to ourselves? | 			// are we connecting to ourselves? | ||||||
| 			if token == t.token { | 			if id == t.id { | ||||||
| 				link.loopback = true | 				link.loopback = true | ||||||
| 				loopback = true | 				loopback = true | ||||||
| 			} | 			} | ||||||
| @@ -318,76 +329,77 @@ func (t *tun) listen(link *link) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// strip message header |  | ||||||
| 		delete(msg.Header, "Micro-Tunnel") |  | ||||||
|  |  | ||||||
| 		// the tunnel id | 		// the tunnel id | ||||||
| 		id := msg.Header["Micro-Tunnel-Id"] | 		id := msg.Header["Micro-Tunnel-Id"] | ||||||
| 		delete(msg.Header, "Micro-Tunnel-Id") | 		// the tunnel channel | ||||||
|  | 		channel := msg.Header["Micro-Tunnel-Channel"] | ||||||
| 		// the session id | 		// the session id | ||||||
| 		session := msg.Header["Micro-Tunnel-Session"] | 		sessionId := msg.Header["Micro-Tunnel-Session"] | ||||||
| 		delete(msg.Header, "Micro-Tunnel-Session") |  | ||||||
|  |  | ||||||
| 		// strip token header | 		// strip tunnel message header | ||||||
| 		delete(msg.Header, "Micro-Tunnel-Token") | 		for k, _ := range msg.Header { | ||||||
|  | 			if strings.HasPrefix(k, "Micro-Tunnel") { | ||||||
|  | 				delete(msg.Header, k) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		// if the session id is blank there's nothing we can do | 		// if the session id is blank there's nothing we can do | ||||||
| 		// TODO: check this is the case, is there any reason | 		// TODO: check this is the case, is there any reason | ||||||
| 		// why we'd have a blank session? Is the tunnel | 		// why we'd have a blank session? Is the tunnel | ||||||
| 		// used for some other purpose? | 		// used for some other purpose? | ||||||
| 		if len(id) == 0 || len(session) == 0 { | 		if len(channel) == 0 || len(sessionId) == 0 { | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		var s *socket | 		var s *session | ||||||
| 		var exists bool | 		var exists bool | ||||||
|  |  | ||||||
| 		// If its a loopback connection then we've enabled link direction | 		// If its a loopback connection then we've enabled link direction | ||||||
| 		// listening side is used for listening, the dialling side for dialling | 		// listening side is used for listening, the dialling side for dialling | ||||||
| 		switch { | 		switch { | ||||||
| 		case loopback: | 		case loopback: | ||||||
| 			s, exists = t.getSocket(id, "listener") | 			s, exists = t.getSession(channel, "listener") | ||||||
| 		default: | 		default: | ||||||
| 			// get the socket based on the tunnel id and session | 			// get the session based on the tunnel id and session | ||||||
| 			// this could be something we dialed in which case | 			// this could be something we dialed in which case | ||||||
| 			// we have a session for it otherwise its a listener | 			// we have a session for it otherwise its a listener | ||||||
| 			s, exists = t.getSocket(id, session) | 			s, exists = t.getSession(channel, sessionId) | ||||||
| 			if !exists { | 			if !exists { | ||||||
| 				// try get it based on just the tunnel id | 				// try get it based on just the tunnel id | ||||||
| 				// the assumption here is that a listener | 				// the assumption here is that a listener | ||||||
| 				// has no session but its set a listener session | 				// has no session but its set a listener session | ||||||
| 				s, exists = t.getSocket(id, "listener") | 				s, exists = t.getSession(channel, "listener") | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// bail if no socket has been found | 		// bail if no session has been found | ||||||
| 		if !exists { | 		if !exists { | ||||||
| 			log.Debugf("Tunnel skipping no socket exists") | 			log.Debugf("Tunnel skipping no session exists") | ||||||
| 			// drop it, we don't care about | 			// drop it, we don't care about | ||||||
| 			// messages we don't know about | 			// messages we don't know about | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 		log.Debugf("Tunnel using socket %s %s", s.id, s.session) |  | ||||||
|  |  | ||||||
| 		// is the socket closed? | 		log.Debugf("Tunnel using session %s %s", s.channel, s.session) | ||||||
|  |  | ||||||
|  | 		// is the session closed? | ||||||
| 		select { | 		select { | ||||||
| 		case <-s.closed: | 		case <-s.closed: | ||||||
| 			// closed | 			// closed | ||||||
| 			delete(t.sockets, id) | 			delete(t.sessions, channel) | ||||||
| 			continue | 			continue | ||||||
| 		default: | 		default: | ||||||
| 			// process | 			// process | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// is the socket new? | 		// is the session new? | ||||||
| 		select { | 		select { | ||||||
| 		// if its new the socket is actually blocked waiting | 		// if its new the session is actually blocked waiting | ||||||
| 		// for a connection. so we check if its waiting. | 		// for a connection. so we check if its waiting. | ||||||
| 		case <-s.wait: | 		case <-s.wait: | ||||||
| 		// if its waiting e.g its new then we close it | 		// if its waiting e.g its new then we close it | ||||||
| 		default: | 		default: | ||||||
| 			// set remote address of the socket | 			// set remote address of the session | ||||||
| 			s.remote = msg.Header["Remote"] | 			s.remote = msg.Header["Remote"] | ||||||
| 			close(s.wait) | 			close(s.wait) | ||||||
| 		} | 		} | ||||||
| @@ -401,10 +413,12 @@ func (t *tun) listen(link *link) { | |||||||
| 		// construct the internal message | 		// construct the internal message | ||||||
| 		imsg := &message{ | 		imsg := &message{ | ||||||
| 			id:       id, | 			id:       id, | ||||||
| 			session:  session, | 			channel:  channel, | ||||||
|  | 			session:  sessionId, | ||||||
| 			data:     tmsg, | 			data:     tmsg, | ||||||
| 			link:     link.id, | 			link:     link.id, | ||||||
| 			loopback: loopback, | 			loopback: loopback, | ||||||
|  | 			errChan:  make(chan error, 1), | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// append to recv backlog | 		// append to recv backlog | ||||||
| @@ -431,6 +445,7 @@ func (t *tun) keepalive(link *link) { | |||||||
| 			if err := link.Send(&transport.Message{ | 			if err := link.Send(&transport.Message{ | ||||||
| 				Header: map[string]string{ | 				Header: map[string]string{ | ||||||
| 					"Micro-Tunnel":       "keepalive", | 					"Micro-Tunnel":       "keepalive", | ||||||
|  | 					"Micro-Tunnel-Id":    t.id, | ||||||
| 					"Micro-Tunnel-Token": t.token, | 					"Micro-Tunnel-Token": t.token, | ||||||
| 				}, | 				}, | ||||||
| 			}); err != nil { | 			}); err != nil { | ||||||
| @@ -458,6 +473,7 @@ func (t *tun) setupLink(node string) (*link, error) { | |||||||
| 	if err := c.Send(&transport.Message{ | 	if err := c.Send(&transport.Message{ | ||||||
| 		Header: map[string]string{ | 		Header: map[string]string{ | ||||||
| 			"Micro-Tunnel":       "connect", | 			"Micro-Tunnel":       "connect", | ||||||
|  | 			"Micro-Tunnel-Id":    t.id, | ||||||
| 			"Micro-Tunnel-Token": t.token, | 			"Micro-Tunnel-Token": t.token, | ||||||
| 		}, | 		}, | ||||||
| 	}); err != nil { | 	}); err != nil { | ||||||
| @@ -568,6 +584,7 @@ func (t *tun) close() error { | |||||||
| 		link.Send(&transport.Message{ | 		link.Send(&transport.Message{ | ||||||
| 			Header: map[string]string{ | 			Header: map[string]string{ | ||||||
| 				"Micro-Tunnel":       "close", | 				"Micro-Tunnel":       "close", | ||||||
|  | 				"Micro-Tunnel-Id":    t.id, | ||||||
| 				"Micro-Tunnel-Token": t.token, | 				"Micro-Tunnel-Token": t.token, | ||||||
| 			}, | 			}, | ||||||
| 		}) | 		}) | ||||||
| @@ -603,10 +620,10 @@ func (t *tun) Close() error { | |||||||
| 	case <-t.closed: | 	case <-t.closed: | ||||||
| 		return nil | 		return nil | ||||||
| 	default: | 	default: | ||||||
| 		// close all the sockets | 		// close all the sessions | ||||||
| 		for id, s := range t.sockets { | 		for id, s := range t.sessions { | ||||||
| 			s.Close() | 			s.Close() | ||||||
| 			delete(t.sockets, id) | 			delete(t.sessions, id) | ||||||
| 		} | 		} | ||||||
| 		// close the connection | 		// close the connection | ||||||
| 		close(t.closed) | 		close(t.closed) | ||||||
| @@ -622,52 +639,50 @@ func (t *tun) Close() error { | |||||||
| } | } | ||||||
|  |  | ||||||
| // Dial an address | // Dial an address | ||||||
| func (t *tun) Dial(addr string) (Conn, error) { | func (t *tun) Dial(channel string) (Session, error) { | ||||||
| 	log.Debugf("Tunnel dialing %s", addr) | 	log.Debugf("Tunnel dialing %s", channel) | ||||||
| 	c, ok := t.newSocket(addr, t.newSession()) | 	c, ok := t.newSession(channel, t.newSessionId()) | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return nil, errors.New("error dialing " + addr) | 		return nil, errors.New("error dialing " + channel) | ||||||
| 	} | 	} | ||||||
| 	// set remote | 	// set remote | ||||||
| 	c.remote = addr | 	c.remote = channel | ||||||
| 	// set local | 	// set local | ||||||
| 	c.local = "local" | 	c.local = "local" | ||||||
| 	// outbound socket | 	// outbound session | ||||||
| 	c.outbound = true | 	c.outbound = true | ||||||
|  |  | ||||||
| 	return c, nil | 	return c, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // Accept a connection on the address | // Accept a connection on the address | ||||||
| func (t *tun) Listen(addr string) (Listener, error) { | func (t *tun) Listen(channel string) (Listener, error) { | ||||||
| 	log.Debugf("Tunnel listening on %s", addr) | 	log.Debugf("Tunnel listening on %s", channel) | ||||||
| 	// create a new socket by hashing the address | 	// create a new session by hashing the address | ||||||
| 	c, ok := t.newSocket(addr, "listener") | 	c, ok := t.newSession(channel, "listener") | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return nil, errors.New("already listening on " + addr) | 		return nil, errors.New("already listening on " + channel) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// set remote. it will be replaced by the first message received | 	// set remote. it will be replaced by the first message received | ||||||
| 	c.remote = "remote" | 	c.remote = "remote" | ||||||
| 	// set local | 	// set local | ||||||
| 	c.local = addr | 	c.local = channel | ||||||
|  |  | ||||||
| 	tl := &tunListener{ | 	tl := &tunListener{ | ||||||
| 		addr: addr, | 		channel: channel, | ||||||
| 		// the accept channel | 		// the accept channel | ||||||
| 		accept: make(chan *socket, 128), | 		accept: make(chan *session, 128), | ||||||
| 		// the channel to close | 		// the channel to close | ||||||
| 		closed: make(chan bool), | 		closed: make(chan bool), | ||||||
| 		// tunnel closed channel | 		// tunnel closed channel | ||||||
| 		tunClosed: t.closed, | 		tunClosed: t.closed, | ||||||
| 		// the connection | 		// the listener session | ||||||
| 		conn: c, | 		session: c, | ||||||
| 		// the listener socket |  | ||||||
| 		socket: c, |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// this kicks off the internal message processor | 	// this kicks off the internal message processor | ||||||
| 	// for the listener so it can create pseudo sockets | 	// for the listener so it can create pseudo sessions | ||||||
| 	// per session if they do not exist or pass messages | 	// per session if they do not exist or pass messages | ||||||
| 	// to the existign sessions | 	// to the existign sessions | ||||||
| 	go tl.process() | 	go tl.process() | ||||||
|   | |||||||
| @@ -1,10 +1,34 @@ | |||||||
| package tunnel | package tunnel | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"sync" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/google/uuid" | 	"github.com/google/uuid" | ||||||
| 	"github.com/micro/go-micro/transport" | 	"github.com/micro/go-micro/transport" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | type link struct { | ||||||
|  | 	sync.RWMutex | ||||||
|  |  | ||||||
|  | 	transport.Socket | ||||||
|  | 	// unique id of this link e.g uuid | ||||||
|  | 	// which we define for ourselves | ||||||
|  | 	id string | ||||||
|  | 	// whether its a loopback connection | ||||||
|  | 	// this flag is used by the transport listener | ||||||
|  | 	// which accepts inbound quic connections | ||||||
|  | 	loopback bool | ||||||
|  | 	// whether its actually connected | ||||||
|  | 	// dialled side sets it to connected | ||||||
|  | 	// after sending the message. the | ||||||
|  | 	// listener waits for the connect | ||||||
|  | 	connected bool | ||||||
|  | 	// the last time we received a keepalive | ||||||
|  | 	// on this link from the remote side | ||||||
|  | 	lastKeepAlive time.Time | ||||||
|  | } | ||||||
|  |  | ||||||
| func newLink(s transport.Socket) *link { | func newLink(s transport.Socket) *link { | ||||||
| 	return &link{ | 	return &link{ | ||||||
| 		Socket: s, | 		Socket: s, | ||||||
|   | |||||||
| @@ -8,37 +8,37 @@ import ( | |||||||
|  |  | ||||||
| type tunListener struct { | type tunListener struct { | ||||||
| 	// address of the listener | 	// address of the listener | ||||||
| 	addr string | 	channel string | ||||||
| 	// the accept channel | 	// the accept channel | ||||||
| 	accept chan *socket | 	accept chan *session | ||||||
| 	// the channel to close | 	// the channel to close | ||||||
| 	closed chan bool | 	closed chan bool | ||||||
| 	// the tunnel closed channel | 	// the tunnel closed channel | ||||||
| 	tunClosed chan bool | 	tunClosed chan bool | ||||||
| 	// the connection | 	// the listener session | ||||||
| 	conn Conn | 	session *session | ||||||
| 	// the listener socket |  | ||||||
| 	socket *socket |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (t *tunListener) process() { | func (t *tunListener) process() { | ||||||
| 	// our connection map for session | 	// our connection map for session | ||||||
| 	conns := make(map[string]*socket) | 	conns := make(map[string]*session) | ||||||
|  |  | ||||||
| 	for { | 	for { | ||||||
| 		select { | 		select { | ||||||
| 		case <-t.closed: | 		case <-t.closed: | ||||||
| 			return | 			return | ||||||
| 		// receive a new message | 		// receive a new message | ||||||
| 		case m := <-t.socket.recv: | 		case m := <-t.session.recv: | ||||||
| 			// get a socket | 			// get a session | ||||||
| 			sock, ok := conns[m.session] | 			sess, ok := conns[m.session] | ||||||
| 			log.Debugf("Tunnel listener received id %s session %s exists: %t", m.id, m.session, ok) | 			log.Debugf("Tunnel listener received id %s session %s exists: %t", m.id, m.session, ok) | ||||||
| 			if !ok { | 			if !ok { | ||||||
| 				// create a new socket session | 				// create a new session session | ||||||
| 				sock = &socket{ | 				sess = &session{ | ||||||
| 					// our tunnel id | 					// the id of the remote side | ||||||
| 					id: m.id, | 					id: m.id, | ||||||
|  | 					// the channel | ||||||
|  | 					channel: m.channel, | ||||||
| 					// the session id | 					// the session id | ||||||
| 					session: m.session, | 					session: m.session, | ||||||
| 					// is loopback conn | 					// is loopback conn | ||||||
| @@ -50,35 +50,37 @@ func (t *tunListener) process() { | |||||||
| 					// recv called by the acceptor | 					// recv called by the acceptor | ||||||
| 					recv: make(chan *message, 128), | 					recv: make(chan *message, 128), | ||||||
| 					// use the internal send buffer | 					// use the internal send buffer | ||||||
| 					send: t.socket.send, | 					send: t.session.send, | ||||||
| 					// wait | 					// wait | ||||||
| 					wait: make(chan bool), | 					wait: make(chan bool), | ||||||
|  | 					// error channel | ||||||
|  | 					errChan: make(chan error, 1), | ||||||
| 				} | 				} | ||||||
|  |  | ||||||
| 				// save the socket | 				// save the session | ||||||
| 				conns[m.session] = sock | 				conns[m.session] = sess | ||||||
|  |  | ||||||
| 				// send to accept chan | 				// send to accept chan | ||||||
| 				select { | 				select { | ||||||
| 				case <-t.closed: | 				case <-t.closed: | ||||||
| 					return | 					return | ||||||
| 				case t.accept <- sock: | 				case t.accept <- sess: | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			// send this to the accept chan | 			// send this to the accept chan | ||||||
| 			select { | 			select { | ||||||
| 			case <-sock.closed: | 			case <-sess.closed: | ||||||
| 				delete(conns, m.session) | 				delete(conns, m.session) | ||||||
| 			case sock.recv <- m: | 			case sess.recv <- m: | ||||||
| 				log.Debugf("Tunnel listener sent to recv chan id %s session %s", m.id, m.session) | 				log.Debugf("Tunnel listener sent to recv chan id %s session %s", m.id, m.session) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (t *tunListener) Addr() string { | func (t *tunListener) Channel() string { | ||||||
| 	return t.addr | 	return t.channel | ||||||
| } | } | ||||||
|  |  | ||||||
| // Close closes tunnel listener | // Close closes tunnel listener | ||||||
| @@ -93,9 +95,9 @@ func (t *tunListener) Close() error { | |||||||
| } | } | ||||||
|  |  | ||||||
| // Everytime accept is called we essentially block till we get a new connection | // Everytime accept is called we essentially block till we get a new connection | ||||||
| func (t *tunListener) Accept() (Conn, error) { | func (t *tunListener) Accept() (Session, error) { | ||||||
| 	select { | 	select { | ||||||
| 	// if the socket is closed return | 	// if the session is closed return | ||||||
| 	case <-t.closed: | 	case <-t.closed: | ||||||
| 		return nil, io.EOF | 		return nil, io.EOF | ||||||
| 	case <-t.tunClosed: | 	case <-t.tunClosed: | ||||||
|   | |||||||
| @@ -9,6 +9,8 @@ import ( | |||||||
| var ( | var ( | ||||||
| 	// DefaultAddress is default tunnel bind address | 	// DefaultAddress is default tunnel bind address | ||||||
| 	DefaultAddress = ":0" | 	DefaultAddress = ":0" | ||||||
|  | 	// The shared default token | ||||||
|  | 	DefaultToken = "micro" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type Option func(*Options) | type Option func(*Options) | ||||||
| @@ -21,6 +23,8 @@ type Options struct { | |||||||
| 	Address string | 	Address string | ||||||
| 	// Nodes are remote nodes | 	// Nodes are remote nodes | ||||||
| 	Nodes []string | 	Nodes []string | ||||||
|  | 	// The shared auth token | ||||||
|  | 	Token string | ||||||
| 	// Transport listens to incoming connections | 	// Transport listens to incoming connections | ||||||
| 	Transport transport.Transport | 	Transport transport.Transport | ||||||
| } | } | ||||||
| @@ -46,6 +50,13 @@ func Nodes(n ...string) Option { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Token sets the shared token for auth | ||||||
|  | func Token(t string) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.Token = t | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| // Transport listens for incoming connections | // Transport listens for incoming connections | ||||||
| func Transport(t transport.Transport) Option { | func Transport(t transport.Transport) Option { | ||||||
| 	return func(o *Options) { | 	return func(o *Options) { | ||||||
| @@ -58,6 +69,7 @@ func DefaultOptions() Options { | |||||||
| 	return Options{ | 	return Options{ | ||||||
| 		Id:        uuid.New().String(), | 		Id:        uuid.New().String(), | ||||||
| 		Address:   DefaultAddress, | 		Address:   DefaultAddress, | ||||||
|  | 		Token:     DefaultToken, | ||||||
| 		Transport: quic.NewTransport(), | 		Transport: quic.NewTransport(), | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -2,15 +2,18 @@ package tunnel | |||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
|  | 	"io" | ||||||
| 
 | 
 | ||||||
| 	"github.com/micro/go-micro/transport" | 	"github.com/micro/go-micro/transport" | ||||||
| 	"github.com/micro/go-micro/util/log" | 	"github.com/micro/go-micro/util/log" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // socket is our pseudo socket for transport.Socket | // session is our pseudo session for transport.Socket | ||||||
| type socket struct { | type session struct { | ||||||
| 	// socket id based on Micro-Tunnel | 	// unique id based on the remote tunnel id | ||||||
| 	id string | 	id string | ||||||
|  | 	// the channel name | ||||||
|  | 	channel string | ||||||
| 	// the session id based on Micro.Tunnel-Session | 	// the session id based on Micro.Tunnel-Session | ||||||
| 	session string | 	session string | ||||||
| 	// closed | 	// closed | ||||||
| @@ -25,12 +28,14 @@ type socket struct { | |||||||
| 	recv chan *message | 	recv chan *message | ||||||
| 	// wait until we have a connection | 	// wait until we have a connection | ||||||
| 	wait chan bool | 	wait chan bool | ||||||
| 	// outbound marks the socket as outbound dialled connection | 	// outbound marks the session as outbound dialled connection | ||||||
| 	outbound bool | 	outbound bool | ||||||
| 	// lookback marks the socket as a loopback on the inbound | 	// lookback marks the session as a loopback on the inbound | ||||||
| 	loopback bool | 	loopback bool | ||||||
| 	// the link on which this message was received | 	// the link on which this message was received | ||||||
| 	link string | 	link string | ||||||
|  | 	// the error response | ||||||
|  | 	errChan chan error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // message is sent over the send channel | // message is sent over the send channel | ||||||
| @@ -39,6 +44,8 @@ type message struct { | |||||||
| 	typ string | 	typ string | ||||||
| 	// tunnel id | 	// tunnel id | ||||||
| 	id string | 	id string | ||||||
|  | 	// channel name | ||||||
|  | 	channel string | ||||||
| 	// the session id | 	// the session id | ||||||
| 	session string | 	session string | ||||||
| 	// outbound marks the message as outbound | 	// outbound marks the message as outbound | ||||||
| @@ -49,28 +56,30 @@ type message struct { | |||||||
| 	link string | 	link string | ||||||
| 	// transport data | 	// transport data | ||||||
| 	data *transport.Message | 	data *transport.Message | ||||||
|  | 	// the error channel | ||||||
|  | 	errChan chan error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *socket) Remote() string { | func (s *session) Remote() string { | ||||||
| 	return s.remote | 	return s.remote | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *socket) Local() string { | func (s *session) Local() string { | ||||||
| 	return s.local | 	return s.local | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *socket) Id() string { | func (s *session) Id() string { | ||||||
| 	return s.id |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (s *socket) Session() string { |  | ||||||
| 	return s.session | 	return s.session | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *socket) Send(m *transport.Message) error { | func (s *session) Channel() string { | ||||||
|  | 	return s.channel | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *session) Send(m *transport.Message) error { | ||||||
| 	select { | 	select { | ||||||
| 	case <-s.closed: | 	case <-s.closed: | ||||||
| 		return errors.New("socket is closed") | 		return errors.New("session is closed") | ||||||
| 	default: | 	default: | ||||||
| 		// no op | 		// no op | ||||||
| 	} | 	} | ||||||
| @@ -89,28 +98,48 @@ func (s *socket) Send(m *transport.Message) error { | |||||||
| 	msg := &message{ | 	msg := &message{ | ||||||
| 		typ:      "message", | 		typ:      "message", | ||||||
| 		id:       s.id, | 		id:       s.id, | ||||||
|  | 		channel:  s.channel, | ||||||
| 		session:  s.session, | 		session:  s.session, | ||||||
| 		outbound: s.outbound, | 		outbound: s.outbound, | ||||||
| 		loopback: s.loopback, | 		loopback: s.loopback, | ||||||
| 		data:     data, | 		data:     data, | ||||||
| 		// specify the link on which to send this | 		// specify the link on which to send this | ||||||
| 		// it will be blank for dialled sockets | 		// it will be blank for dialled sessions | ||||||
| 		link: s.link, | 		link: s.link, | ||||||
|  | 		// error chan | ||||||
|  | 		errChan: s.errChan, | ||||||
| 	} | 	} | ||||||
| 	log.Debugf("Appending %+v to send backlog", msg) | 	log.Debugf("Appending %+v to send backlog", msg) | ||||||
| 	s.send <- msg | 	s.send <- msg | ||||||
|  | 
 | ||||||
|  | 	// wait for an error response | ||||||
|  | 	select { | ||||||
|  | 	case err := <-msg.errChan: | ||||||
|  | 		return err | ||||||
|  | 	case <-s.closed: | ||||||
|  | 		return io.EOF | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *socket) Recv(m *transport.Message) error { | func (s *session) Recv(m *transport.Message) error { | ||||||
| 	select { | 	select { | ||||||
| 	case <-s.closed: | 	case <-s.closed: | ||||||
| 		return errors.New("socket is closed") | 		return errors.New("session is closed") | ||||||
| 	default: | 	default: | ||||||
| 		// no op | 		// no op | ||||||
| 	} | 	} | ||||||
| 	// recv from backlog | 	// recv from backlog | ||||||
| 	msg := <-s.recv | 	msg := <-s.recv | ||||||
|  | 
 | ||||||
|  | 	// check the error if one exists | ||||||
|  | 	select { | ||||||
|  | 	case err := <-msg.errChan: | ||||||
|  | 		return err | ||||||
|  | 	default: | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	log.Debugf("Received %+v from recv backlog", msg) | 	log.Debugf("Received %+v from recv backlog", msg) | ||||||
| 	// set message | 	// set message | ||||||
| 	*m = *msg.data | 	*m = *msg.data | ||||||
| @@ -118,8 +147,8 @@ func (s *socket) Recv(m *transport.Message) error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Close closes the socket | // Close closes the session | ||||||
| func (s *socket) Close() error { | func (s *session) Close() error { | ||||||
| 	select { | 	select { | ||||||
| 	case <-s.closed: | 	case <-s.closed: | ||||||
| 		// no op | 		// no op | ||||||
| @@ -10,7 +10,7 @@ type tunListener struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (t *tunListener) Addr() string { | func (t *tunListener) Addr() string { | ||||||
| 	return t.l.Addr() | 	return t.l.Channel() | ||||||
| } | } | ||||||
|  |  | ||||||
| func (t *tunListener) Close() error { | func (t *tunListener) Close() error { | ||||||
|   | |||||||
| @@ -17,27 +17,27 @@ type Tunnel interface { | |||||||
| 	Connect() error | 	Connect() error | ||||||
| 	// Close closes the tunnel | 	// Close closes the tunnel | ||||||
| 	Close() error | 	Close() error | ||||||
| 	// Dial an endpoint | 	// Connect to a channel | ||||||
| 	Dial(addr string) (Conn, error) | 	Dial(channel string) (Session, error) | ||||||
| 	// Accept connections | 	// Accept connections on a channel | ||||||
| 	Listen(addr string) (Listener, error) | 	Listen(channel string) (Listener, error) | ||||||
| 	// Name of the tunnel implementation | 	// Name of the tunnel implementation | ||||||
| 	String() string | 	String() string | ||||||
| } | } | ||||||
|  |  | ||||||
| // The listener provides similar constructs to the transport.Listener | // The listener provides similar constructs to the transport.Listener | ||||||
| type Listener interface { | type Listener interface { | ||||||
| 	Addr() string | 	Channel() string | ||||||
| 	Close() error | 	Close() error | ||||||
| 	Accept() (Conn, error) | 	Accept() (Session, error) | ||||||
| } | } | ||||||
|  |  | ||||||
| // Conn is a connection dialed or accepted which includes the tunnel id and session | // Session is a unique session created when dialling or accepting connections on the tunnel | ||||||
| type Conn interface { | type Session interface { | ||||||
| 	// Specifies the tunnel id | 	// Specifies the tunnel id | ||||||
| 	Id() string | 	Id() string | ||||||
| 	// The session | 	// The session | ||||||
| 	Session() string | 	Channel() string | ||||||
| 	// a transport socket | 	// a transport socket | ||||||
| 	transport.Socket | 	transport.Socket | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user