Merge pull request #1038 from micro/tun

Next level tunnel optimisation
This commit is contained in:
Asim Aslam 2019-12-13 15:34:03 +00:00 committed by GitHub
commit 64e438a8d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 600 additions and 265 deletions

View File

@ -6,6 +6,7 @@ import (
"hash/fnv" "hash/fnv"
"io" "io"
"math" "math"
"math/rand"
"sort" "sort"
"sync" "sync"
"time" "time"
@ -88,6 +89,7 @@ type message struct {
// newNetwork returns a new network node // newNetwork returns a new network node
func newNetwork(opts ...Option) Network { func newNetwork(opts ...Option) Network {
rand.Seed(time.Now().UnixNano())
options := DefaultOptions() options := DefaultOptions()
for _, o := range opts { for _, o := range opts {
@ -168,7 +170,7 @@ func newNetwork(opts ...Option) Network {
tunClient: make(map[string]transport.Client), tunClient: make(map[string]transport.Client),
peerLinks: make(map[string]tunnel.Link), peerLinks: make(map[string]tunnel.Link),
discovered: make(chan bool, 1), discovered: make(chan bool, 1),
solicited: make(chan *node, 1), solicited: make(chan *node, 32),
} }
network.node.network = network network.node.network = network
@ -178,12 +180,11 @@ func newNetwork(opts ...Option) Network {
func (n *network) Init(opts ...Option) error { func (n *network) Init(opts ...Option) error {
n.Lock() n.Lock()
defer n.Unlock()
// TODO: maybe only allow reinit of certain opts // TODO: maybe only allow reinit of certain opts
for _, o := range opts { for _, o := range opts {
o(&n.options) o(&n.options)
} }
n.Unlock()
return nil return nil
} }
@ -191,10 +192,8 @@ func (n *network) Init(opts ...Option) error {
// Options returns network options // Options returns network options
func (n *network) Options() Options { func (n *network) Options() Options {
n.RLock() n.RLock()
defer n.RUnlock()
options := n.options options := n.options
n.RUnlock()
return options return options
} }
@ -332,8 +331,9 @@ func (n *network) advertise(advertChan <-chan *router.Advert) {
// someone requested the route // someone requested the route
n.sendTo("advert", ControlChannel, peer, msg) n.sendTo("advert", ControlChannel, peer, msg)
default: default:
// send to all since we can't get anything if err := n.sendMsg("advert", ControlChannel, msg); err != nil {
n.sendMsg("advert", ControlChannel, msg) log.Debugf("Network failed to advertise routes: %v", err)
}
} }
case <-n.closed: case <-n.closed:
return return
@ -498,12 +498,12 @@ func (n *network) getHopCount(rtr string) int {
} }
// the route origin is our peer // the route origin is our peer
if _, ok := n.peers[rtr]; ok { if _, ok := n.node.peers[rtr]; ok {
return 10 return 10
} }
// the route origin is the peer of our peer // the route origin is the peer of our peer
for _, peer := range n.peers { for _, peer := range n.node.peers {
for id := range peer.peers { for id := range peer.peers {
if rtr == id { if rtr == id {
return 100 return 100
@ -667,7 +667,7 @@ func (n *network) processCtrlChan(listener tunnel.Listener) {
log.Debugf("Network failed to process advert %s: %v", advert.Id, err) log.Debugf("Network failed to process advert %s: %v", advert.Id, err)
} }
case "solicit": case "solicit":
pbRtrSolicit := &pbRtr.Solicit{} pbRtrSolicit := new(pbRtr.Solicit)
if err := proto.Unmarshal(m.msg.Body, pbRtrSolicit); err != nil { if err := proto.Unmarshal(m.msg.Body, pbRtrSolicit); err != nil {
log.Debugf("Network fail to unmarshal solicit message: %v", err) log.Debugf("Network fail to unmarshal solicit message: %v", err)
continue continue
@ -682,11 +682,6 @@ func (n *network) processCtrlChan(listener tunnel.Listener) {
log.Tracef("Network router flushing routes for: %s", pbRtrSolicit.Id) log.Tracef("Network router flushing routes for: %s", pbRtrSolicit.Id)
// advertise all the routes when a new node has connected
if err := n.router.Solicit(); err != nil {
log.Debugf("Network failed to solicit routes: %s", err)
}
peer := &node{ peer := &node{
id: pbRtrSolicit.Id, id: pbRtrSolicit.Id,
link: m.msg.Header["Micro-Link"], link: m.msg.Header["Micro-Link"],
@ -698,6 +693,11 @@ func (n *network) processCtrlChan(listener tunnel.Listener) {
default: default:
// don't block // don't block
} }
// advertise all the routes when a new node has connected
if err := n.router.Solicit(); err != nil {
log.Debugf("Network failed to solicit routes: %s", err)
}
} }
case <-n.closed: case <-n.closed:
return return
@ -767,22 +767,31 @@ func (n *network) processNetChan(listener tunnel.Listener) {
// get node peers down to MaxDepth encoded in protobuf // get node peers down to MaxDepth encoded in protobuf
msg := PeersToProto(n.node, MaxDepth) msg := PeersToProto(n.node, MaxDepth)
// advertise yourself to the network go func() {
if err := n.sendTo("peer", NetworkChannel, peer, msg); err != nil { // advertise yourself to the new node
log.Debugf("Network failed to advertise peers: %v", err) if err := n.sendTo("peer", NetworkChannel, peer, msg); err != nil {
} log.Debugf("Network failed to advertise peers: %v", err)
}
// advertise all the routes when a new node has connected <-time.After(time.Millisecond * 100)
if err := n.router.Solicit(); err != nil {
log.Debugf("Network failed to solicit routes: %s", err)
}
// specify that we're soliciting // ask for the new nodes routes
select { if err := n.sendTo("solicit", ControlChannel, peer, msg); err != nil {
case n.solicited <- peer: log.Debugf("Network failed to send solicit message: %s", err)
default: }
// don't block
} // now advertise our own routes
select {
case n.solicited <- peer:
default:
// don't block
}
// advertise all the routes when a new node has connected
if err := n.router.Solicit(); err != nil {
log.Debugf("Network failed to solicit routes: %s", err)
}
}()
case "peer": case "peer":
// mark the time the message has been received // mark the time the message has been received
now := time.Now() now := time.Now()
@ -820,10 +829,32 @@ func (n *network) processNetChan(listener tunnel.Listener) {
Id: n.options.Id, Id: n.options.Id,
} }
// only solicit this peer go func() {
if err := n.sendTo("solicit", ControlChannel, peer, msg); err != nil { // advertise yourself to the peer
log.Debugf("Network failed to send solicit message: %s", err) if err := n.sendTo("peer", NetworkChannel, peer, msg); err != nil {
} log.Debugf("Network failed to advertise peers: %v", err)
}
// wait for 100ms
<-time.After(time.Millisecond * 100)
// then solicit this peer
if err := n.sendTo("solicit", ControlChannel, peer, msg); err != nil {
log.Debugf("Network failed to send solicit message: %s", err)
}
// now advertise our own routes
select {
case n.solicited <- peer:
default:
// don't block
}
// advertise all the routes when a new node has connected
if err := n.router.Solicit(); err != nil {
log.Debugf("Network failed to solicit routes: %s", err)
}
}()
continue continue
// we're expecting any error to be ErrPeerExists // we're expecting any error to be ErrPeerExists
@ -835,12 +866,15 @@ func (n *network) processNetChan(listener tunnel.Listener) {
log.Tracef("Network peer exists, refreshing: %s", pbNetPeer.Node.Id) log.Tracef("Network peer exists, refreshing: %s", pbNetPeer.Node.Id)
// update lastSeen time for the peer // update lastSeen time for the peer
if err := n.RefreshPeer(pbNetPeer.Node.Id, peer.link, now); err != nil { if err := n.RefreshPeer(peer.id, peer.link, now); err != nil {
log.Debugf("Network failed refreshing peer %s: %v", pbNetPeer.Node.Id, err) log.Debugf("Network failed refreshing peer %s: %v", pbNetPeer.Node.Id, err)
} }
// NOTE: we don't unpack MaxDepth topology // NOTE: we don't unpack MaxDepth topology
peer = UnpackPeerTopology(pbNetPeer, now, MaxDepth-1) peer = UnpackPeerTopology(pbNetPeer, now, MaxDepth-1)
// update the link
peer.link = m.msg.Header["Micro-Link"]
log.Tracef("Network updating topology of node: %s", n.node.id) log.Tracef("Network updating topology of node: %s", n.node.id)
if err := n.node.UpdatePeer(peer); err != nil { if err := n.node.UpdatePeer(peer); err != nil {
log.Debugf("Network failed to update peers: %v", err) log.Debugf("Network failed to update peers: %v", err)
@ -939,15 +973,109 @@ func (n *network) manage() {
resolve := time.NewTicker(ResolveTime) resolve := time.NewTicker(ResolveTime)
defer resolve.Stop() defer resolve.Stop()
// list of links we've sent to
links := make(map[string]time.Time)
for { for {
select { select {
case <-n.closed: case <-n.closed:
return return
case <-announce.C: case <-announce.C:
current := make(map[string]time.Time)
// build link map of current links
for _, link := range n.tunnel.Links() {
if n.isLoopback(link) {
continue
}
// get an existing timestamp if it exists
current[link.Id()] = links[link.Id()]
}
// replace link map
// we do this because a growing map is not
// garbage collected
links = current
n.RLock()
var i int
// create a list of peers to send to
var peers []*node
// check peers to see if they need to be sent to
for _, peer := range n.peers {
if i >= 3 {
break
}
// get last sent
lastSent := links[peer.link]
// check when we last sent to the peer
// and send a peer message if we havent
if lastSent.IsZero() || time.Since(lastSent) > KeepAliveTime {
link := peer.link
id := peer.id
// might not exist for some weird reason
if len(link) == 0 {
// set the link via peer links
l, ok := n.peerLinks[peer.address]
if ok {
log.Debugf("Network link not found for peer %s cannot announce", peer.id)
continue
}
link = l.Id()
}
// add to the list of peers we're going to send to
peers = append(peers, &node{
id: id,
link: link,
})
// increment our count
i++
}
}
n.RUnlock()
// peers to proto
msg := PeersToProto(n.node, MaxDepth) msg := PeersToProto(n.node, MaxDepth)
// advertise yourself to the network
if err := n.sendMsg("peer", NetworkChannel, msg); err != nil { // we're only going to send to max 3 peers at any given tick
log.Debugf("Network failed to advertise peers: %v", err) for _, peer := range peers {
// advertise yourself to the network
if err := n.sendTo("peer", NetworkChannel, peer, msg); err != nil {
log.Debugf("Network failed to advertise peer %s: %v", peer.id, err)
continue
}
// update last sent time
links[peer.link] = time.Now()
}
// now look at links we may not have sent to. this may occur
// where a connect message was lost
for link, lastSent := range links {
if !lastSent.IsZero() {
continue
}
peer := &node{
// unknown id of the peer
link: link,
}
// unknown link and peer so lets do the connect flow
if err := n.sendTo("connect", NetworkChannel, peer, msg); err != nil {
log.Debugf("Network failed to advertise peer %s: %v", peer.id, err)
continue
}
links[peer.link] = time.Now()
} }
case <-prune.C: case <-prune.C:
pruned := n.PruneStalePeers(PruneTime) pruned := n.PruneStalePeers(PruneTime)
@ -1023,21 +1151,34 @@ func (n *network) sendTo(method, channel string, peer *node, msg proto.Message)
if err != nil { if err != nil {
return err return err
} }
c, err := n.tunnel.Dial(channel, tunnel.DialMode(tunnel.Multicast), tunnel.DialLink(peer.link)) // Create a unicast connection to the peer but don't do the open/accept flow
c, err := n.tunnel.Dial(channel, tunnel.DialWait(false), tunnel.DialLink(peer.link))
if err != nil { if err != nil {
return err return err
} }
defer c.Close() defer c.Close()
log.Debugf("Network sending %s message from: %s to %s", method, n.options.Id, peer.id) id := peer.id
return c.Send(&transport.Message{ if len(id) == 0 {
id = peer.link
}
log.Debugf("Network sending %s message from: %s to %s", method, n.options.Id, id)
tmsg := &transport.Message{
Header: map[string]string{ Header: map[string]string{
"Micro-Method": method, "Micro-Method": method,
"Micro-Peer": peer.id,
}, },
Body: body, Body: body,
}) }
// setting the peer header
if len(peer.id) > 0 {
tmsg.Header["Micro-Peer"] = peer.id
}
return c.Send(tmsg)
} }
// sendMsg sends a message to the tunnel channel // sendMsg sends a message to the tunnel channel
@ -1105,6 +1246,27 @@ func (n *network) updatePeerLinks(peer *node) error {
return nil return nil
} }
// isLoopback reports whether the given link points back at this node.
// A link counts as loopback when it flags itself as such, or when its
// remote matches either the server's advertised address or the tunnel's
// own address
func (n *network) isLoopback(link tunnel.Link) bool {
	// address we advertise to others
	advertised := n.server.Options().Advertise
	// address the tunnel is actually bound to
	local := n.tunnel.Address()

	// the link already knows it is a loopback
	if link.Loopback() {
		return true
	}

	// otherwise compare the remote end against our own addresses
	remote := link.Remote()
	return remote == advertised || remote == local
}
// connect will wait for a link to be established and send the connect // connect will wait for a link to be established and send the connect
// message. We're trying to ensure convergence pretty quickly. So we want // message. We're trying to ensure convergence pretty quickly. So we want
// to hear back. In the case we become completely disconnected we'll // to hear back. In the case we become completely disconnected we'll
@ -1114,11 +1276,6 @@ func (n *network) connect() {
var discovered bool var discovered bool
var attempts int var attempts int
// our advertise address
loopback := n.server.Options().Advertise
// actual address
address := n.tunnel.Address()
for { for {
// connected is used to define if the link is connected // connected is used to define if the link is connected
var connected bool var connected bool
@ -1126,13 +1283,7 @@ func (n *network) connect() {
// check the links state // check the links state
for _, link := range n.tunnel.Links() { for _, link := range n.tunnel.Links() {
// skip loopback // skip loopback
if link.Loopback() { if n.isLoopback(link) {
continue
}
// if remote is ourselves
switch link.Remote() {
case loopback, address:
continue continue
} }
@ -1216,7 +1367,6 @@ func (n *network) Connect() error {
netListener, err := n.tunnel.Listen( netListener, err := n.tunnel.Listen(
NetworkChannel, NetworkChannel,
tunnel.ListenMode(tunnel.Multicast), tunnel.ListenMode(tunnel.Multicast),
tunnel.ListenTimeout(AnnounceTime*2),
) )
if err != nil { if err != nil {
return err return err
@ -1226,7 +1376,6 @@ func (n *network) Connect() error {
ctrlListener, err := n.tunnel.Listen( ctrlListener, err := n.tunnel.Listen(
ControlChannel, ControlChannel,
tunnel.ListenMode(tunnel.Multicast), tunnel.ListenMode(tunnel.Multicast),
tunnel.ListenTimeout(router.AdvertiseTableTick*2),
) )
if err != nil { if err != nil {
return err return err
@ -1353,6 +1502,7 @@ func (n *network) Close() error {
default: default:
// TODO: send close message to the network channel // TODO: send close message to the network channel
close(n.closed) close(n.closed)
// set connected to false // set connected to false
n.connected = false n.connected = false
@ -1369,6 +1519,7 @@ func (n *network) Close() error {
if err := n.sendMsg("close", NetworkChannel, msg); err != nil { if err := n.sendMsg("close", NetworkChannel, msg); err != nil {
log.Debugf("Network failed to send close message: %s", err) log.Debugf("Network failed to send close message: %s", err)
} }
<-time.After(time.Millisecond * 100)
} }
return n.close() return n.close()

View File

@ -16,7 +16,9 @@ var (
// ResolveTime defines time interval to periodically resolve network nodes // ResolveTime defines time interval to periodically resolve network nodes
ResolveTime = 1 * time.Minute ResolveTime = 1 * time.Minute
// AnnounceTime defines time interval to periodically announce node neighbours // AnnounceTime defines time interval to periodically announce node neighbours
AnnounceTime = 30 * time.Second AnnounceTime = 1 * time.Second
// KeepAliveTime is the time in which we want to have sent a message to a peer
KeepAliveTime = 30 * time.Second
// PruneTime defines time interval to periodically check nodes that need to be pruned // PruneTime defines time interval to periodically check nodes that need to be pruned
// due to their not announcing their presence within this time interval // due to their not announcing their presence within this time interval
PruneTime = 90 * time.Second PruneTime = 90 * time.Second

View File

@ -140,10 +140,8 @@ func (n *node) RefreshPeer(id, link string, now time.Time) error {
// set peer link // set peer link
peer.link = link peer.link = link
// set last seen
if peer.lastSeen.Before(now) { peer.lastSeen = now
peer.lastSeen = now
}
return nil return nil
} }

View File

@ -143,7 +143,7 @@ func (r *runtime) run(events <-chan Event) {
} }
} }
case <-r.closed: case <-r.closed:
log.Debugf("Runtime stopped.") log.Debugf("Runtime stopped")
return return
} }
} }

View File

@ -60,6 +60,11 @@ func Decrypt(data []byte, key string) ([]byte, error) {
} }
nonceSize := gcm.NonceSize() nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return nil, ErrDecryptingData
}
// NOTE: we need to parse out nonce from the payload // NOTE: we need to parse out nonce from the payload
// we prepend the nonce to every encrypted payload // we prepend the nonce to every encrypted payload
nonce, ciphertext := data[:nonceSize], data[nonceSize:] nonce, ciphertext := data[:nonceSize], data[nonceSize:]

View File

@ -14,7 +14,7 @@ import (
var ( var (
// DiscoverTime sets the time at which we fire discover messages // DiscoverTime sets the time at which we fire discover messages
DiscoverTime = 60 * time.Second DiscoverTime = 30 * time.Second
// KeepAliveTime defines time interval we send keepalive messages to outbound links // KeepAliveTime defines time interval we send keepalive messages to outbound links
KeepAliveTime = 30 * time.Second KeepAliveTime = 30 * time.Second
// ReconnectTime defines time interval we periodically attempt to reconnect dead links // ReconnectTime defines time interval we periodically attempt to reconnect dead links
@ -54,6 +54,7 @@ type tun struct {
// create new tunnel on top of a link // create new tunnel on top of a link
func newTunnel(opts ...Option) *tun { func newTunnel(opts ...Option) *tun {
rand.Seed(time.Now().UnixNano())
options := DefaultOptions() options := DefaultOptions()
for _, o := range opts { for _, o := range opts {
o(&options) o(&options)
@ -73,10 +74,10 @@ func newTunnel(opts ...Option) *tun {
// Init initializes tunnel options // Init initializes tunnel options
func (t *tun) Init(opts ...Option) error { func (t *tun) Init(opts ...Option) error {
t.Lock() t.Lock()
defer t.Unlock()
for _, o := range opts { for _, o := range opts {
o(&t.options) o(&t.options)
} }
t.Unlock()
return nil return nil
} }
@ -103,7 +104,6 @@ func (t *tun) delSession(channel, session string) {
// listChannels returns a list of listening channels // listChannels returns a list of listening channels
func (t *tun) listChannels() []string { func (t *tun) listChannels() []string {
t.RLock() t.RLock()
defer t.RUnlock()
//nolint:prealloc //nolint:prealloc
var channels []string var channels []string
@ -113,6 +113,9 @@ func (t *tun) listChannels() []string {
} }
channels = append(channels, session.channel) channels = append(channels, session.channel)
} }
t.RUnlock()
return channels return channels
} }
@ -220,6 +223,12 @@ func (t *tun) manageLink(link *link) {
discover := time.NewTicker(DiscoverTime) discover := time.NewTicker(DiscoverTime)
defer discover.Stop() defer discover.Stop()
wait := func(d time.Duration) {
// jitter
j := rand.Int63n(int64(d.Seconds() / 2.0))
time.Sleep(time.Duration(j) * time.Second)
}
for { for {
select { select {
case <-t.closed: case <-t.closed:
@ -227,11 +236,18 @@ func (t *tun) manageLink(link *link) {
case <-link.closed: case <-link.closed:
return return
case <-discover.C: case <-discover.C:
// send a discovery message to all links // wait half the discover time
wait(DiscoverTime)
// send a discovery message to the link
log.Debugf("Tunnel sending discover to link: %v", link.Remote())
if err := t.sendMsg("discover", link); err != nil { if err := t.sendMsg("discover", link); err != nil {
log.Debugf("Tunnel failed to send discover to link %s: %v", link.Remote(), err) log.Debugf("Tunnel failed to send discover to link %s: %v", link.Remote(), err)
} }
case <-keepalive.C: case <-keepalive.C:
// wait half the keepalive time
wait(KeepAliveTime)
// send keepalive message // send keepalive message
log.Debugf("Tunnel sending keepalive to link: %v", link.Remote()) log.Debugf("Tunnel sending keepalive to link: %v", link.Remote())
if err := t.sendMsg("keepalive", link); err != nil { if err := t.sendMsg("keepalive", link); err != nil {
@ -244,53 +260,70 @@ func (t *tun) manageLink(link *link) {
} }
// manageLinks is a function that can be called to immediately to link setup // manageLinks is a function that can be called to immediately to link setup
// it purges dead links while generating new links for any nodes not connected
func (t *tun) manageLinks() { func (t *tun) manageLinks() {
var delLinks []string delLinks := make(map[*link]string)
connected := make(map[string]bool)
t.RLock() t.RLock()
// get list of nodes from options
nodes := t.options.Nodes
// check the link status and purge dead links // check the link status and purge dead links
for node, link := range t.links { for node, link := range t.links {
// check link status // check link status
switch link.State() { switch link.State() {
case "closed": case "closed", "error":
delLinks = append(delLinks, node) delLinks[link] = node
case "error": default:
delLinks = append(delLinks, node) connected[node] = true
} }
} }
t.RUnlock() t.RUnlock()
// build a list of links to connect to
var connect []string
for _, node := range nodes {
// check if we're connected
if _, ok := connected[node]; ok {
continue
}
// add nodes to connect to
connect = append(connect, node)
}
// delete the dead links // delete the dead links
if len(delLinks) > 0 { if len(delLinks) > 0 {
t.Lock() t.Lock()
for _, node := range delLinks {
for link, node := range delLinks {
log.Debugf("Tunnel deleting dead link for %s", node) log.Debugf("Tunnel deleting dead link for %s", node)
if link, ok := t.links[node]; ok {
link.Close() // check if the link exists
l, ok := t.links[node]
if ok {
// close and delete
l.Close()
delete(t.links, node) delete(t.links, node)
} }
// if the link does not match our own
if l != link {
// close our link just in case
link.Close()
}
} }
t.Unlock() t.Unlock()
} }
// check current link status
var connect []string
// build list of unknown nodes to connect to
t.RLock()
for _, node := range t.options.Nodes {
if _, ok := t.links[node]; !ok {
connect = append(connect, node)
}
}
t.RUnlock()
var wg sync.WaitGroup var wg sync.WaitGroup
// establish new links
for _, node := range connect { for _, node := range connect {
wg.Add(1) wg.Add(1)
@ -298,23 +331,26 @@ func (t *tun) manageLinks() {
defer wg.Done() defer wg.Done()
// create new link // create new link
// if we're using quic it should be a max 10 second handshake period
link, err := t.setupLink(node) link, err := t.setupLink(node)
if err != nil { if err != nil {
log.Debugf("Tunnel failed to setup node link to %s: %v", node, err) log.Debugf("Tunnel failed to setup node link to %s: %v", node, err)
return return
} }
// save the link
t.Lock() t.Lock()
defer t.Unlock()
// just check nothing else was setup in the interim // just check nothing else was setup in the interim
if _, ok := t.links[node]; ok { if _, ok := t.links[node]; ok {
link.Close() link.Close()
t.Unlock()
return return
} }
// save the link // save the link
t.links[node] = link t.links[node] = link
t.Unlock()
}(node) }(node)
} }
@ -329,43 +365,14 @@ func (t *tun) process() {
for { for {
select { select {
case msg := <-t.send: case msg := <-t.send:
newMsg := &transport.Message{ // build a list of links to send to
Header: make(map[string]string), var sendTo []*link
} var err error
// set the data
if msg.data != nil {
for k, v := range msg.data.Header {
newMsg.Header[k] = v
}
newMsg.Body = msg.data.Body
}
// set message head
newMsg.Header["Micro-Tunnel"] = msg.typ
// set the tunnel id on the outgoing message
newMsg.Header["Micro-Tunnel-Id"] = msg.tunnel
// set the tunnel channel on the outgoing message
newMsg.Header["Micro-Tunnel-Channel"] = msg.channel
// set the session id
newMsg.Header["Micro-Tunnel-Session"] = msg.session
// send the message via the interface
t.RLock() t.RLock()
if len(t.links) == 0 {
log.Debugf("No links to send message type: %s channel: %s", msg.typ, msg.channel)
}
var sent bool
var err error
var sendTo []*link
// build the list of links ot send to // build the list of links ot send to
for node, link := range t.links { for _, link := range t.links {
// get the values we need // get the values we need
link.RLock() link.RLock()
id := link.id id := link.id
@ -376,7 +383,7 @@ func (t *tun) process() {
// if the link is not connected skip it // if the link is not connected skip it
if !connected { if !connected {
log.Debugf("Link for node %s not connected", node) log.Debugf("Link for node %s not connected", id)
err = errors.New("link not connected") err = errors.New("link not connected")
continue continue
} }
@ -385,7 +392,6 @@ func (t *tun) process() {
// and the message is being sent outbound via // and the message is being sent outbound via
// a dialled connection don't use this link // a dialled connection don't use this link
if loopback && msg.outbound { if loopback && msg.outbound {
log.Tracef("Link for node %s is loopback", node)
err = errors.New("link is loopback") err = errors.New("link is loopback")
continue continue
} }
@ -393,7 +399,6 @@ func (t *tun) process() {
// if the message was being returned by the loopback listener // if the message was being returned by the loopback listener
// send it back up the loopback link only // send it back up the loopback link only
if msg.loopback && !loopback { if msg.loopback && !loopback {
log.Tracef("Link for message %s is loopback", node)
err = errors.New("link is not loopback") err = errors.New("link is not loopback")
continue continue
} }
@ -420,56 +425,125 @@ func (t *tun) process() {
t.RUnlock() t.RUnlock()
// send the message // no links to send to
for _, link := range sendTo { if len(sendTo) == 0 {
// send the message via the current link log.Debugf("No links to send message type: %s channel: %s", msg.typ, msg.channel)
log.Tracef("Tunnel sending %+v to %s", newMsg.Header, link.Remote()) t.respond(msg, err)
if errr := link.Send(newMsg); errr != nil {
log.Debugf("Tunnel error sending %+v to %s: %v", newMsg.Header, link.Remote(), errr)
err = errors.New(errr.Error())
t.delLink(link.Remote())
continue
}
// is sent
sent = true
// keep sending broadcast messages
if msg.mode > Unicast {
continue
}
// break on unicast
break
}
var gerr error
// set the error if not sent
if !sent {
gerr = err
}
// skip if its not been set
if msg.errChan == nil {
continue continue
} }
// return error non blocking // send the message
select { t.sendTo(sendTo, msg)
case msg.errChan <- gerr:
default:
}
case <-t.closed: case <-t.closed:
return return
} }
} }
} }
// respond hands err back to the caller that queued msg via msg.errChan.
// The send never blocks: if the channel cannot accept the value right
// now the result is dropped
func (t *tun) respond(msg *message, err error) {
	reply := msg.errChan

	select {
	case reply <- err:
		// delivered to the caller
	default:
		// channel not ready, don't block
	}
}
// sendTo delivers msg over the supplied links and reports the outcome
// to the caller through respond.
//
// Unicast messages are tried on each link in order and the first
// successful send ends the call. Multicast/broadcast messages
// (msg.mode > Unicast) are sent on every link concurrently, each with
// its own copy of the message. A link that fails to send is removed
// via delLink. The result is nil as soon as one send succeeds,
// otherwise the last error read from the result channel
func (t *tun) sendTo(links []*link, msg *message) error {
	// assemble the outgoing transport message
	out := &transport.Message{
		Header: make(map[string]string),
	}

	// carry over the caller's headers and body, if any
	if data := msg.data; data != nil {
		for key, val := range data.Header {
			out.Header[key] = val
		}
		out.Body = data.Body
	}

	// message type
	out.Header["Micro-Tunnel"] = msg.typ
	// tunnel id
	out.Header["Micro-Tunnel-Id"] = msg.tunnel
	// tunnel channel
	out.Header["Micro-Tunnel-Channel"] = msg.channel
	// session id
	out.Header["Micro-Tunnel-Session"] = msg.session

	// deliver writes to a single link and drops the link on failure
	deliver := func(l *link, m *transport.Message) error {
		err := l.Send(m)
		if err == nil {
			return nil
		}
		log.Debugf("Tunnel error sending %+v to %s: %v", m.Header, l.Remote(), err)
		t.delLink(l.Remote())
		return err
	}

	// one slot per link so no sender ever blocks
	results := make(chan error, len(links))

	for _, l := range links {
		log.Tracef("Tunnel sending %+v to %s", out.Header, l.Remote())

		// unicast: send inline and stop at the first success
		if msg.mode <= Unicast {
			err := deliver(l, out)
			if err == nil {
				t.respond(msg, nil)
				return nil
			}
			// remember the failure and try the next link
			results <- err
			continue
		}

		// multicast/broadcast: every goroutine gets its own copy
		dup := &transport.Message{
			Header: make(map[string]string),
			Body:   make([]byte, len(out.Body)),
		}
		copy(dup.Body, out.Body)
		for key, val := range out.Header {
			dup.Header[key] = val
		}

		go func(l *link, m *transport.Message) {
			results <- deliver(l, m)
		}(l, dup)
	}

	// at this point either every unicast attempt failed or we are
	// collecting the multicast/broadcast results
	var err error
	for range links {
		// a single success is enough
		if err = <-results; err == nil {
			break
		}
	}

	// report back without blocking
	t.respond(msg, err)
	return err
}
func (t *tun) delLink(remote string) { func (t *tun) delLink(remote string) {
t.Lock() t.Lock()
defer t.Unlock()
// get the link // get the link
for id, link := range t.links { for id, link := range t.links {
@ -481,6 +555,8 @@ func (t *tun) delLink(remote string) {
link.Close() link.Close()
delete(t.links, id) delete(t.links, id)
} }
t.Unlock()
} }
// process incoming messages // process incoming messages
@ -564,7 +640,6 @@ func (t *tun) listen(link *link) {
// nothing more to do // nothing more to do
continue continue
case "close": case "close":
log.Debugf("Tunnel link %s received close message", link.Remote())
// if there is no channel then we close the link // if there is no channel then we close the link
// as its a signal from the other side to close the connection // as its a signal from the other side to close the connection
if len(channel) == 0 { if len(channel) == 0 {
@ -572,10 +647,13 @@ func (t *tun) listen(link *link) {
return return
} }
log.Debugf("Tunnel link %s received close message for %s", link.Remote(), channel)
// the entire listener was closed by the remote side so we need to // the entire listener was closed by the remote side so we need to
// remove the channel mapping for it. should we also close sessions? // remove the channel mapping for it. should we also close sessions?
if sessionId == "listener" { if sessionId == "listener" {
link.delChannel(channel) link.delChannel(channel)
// TODO: find all the non listener unicast sessions
// and close them. think about edge cases first
continue continue
} }
@ -947,6 +1025,11 @@ func (t *tun) pickLink(links []*link) *link {
continue continue
} }
// skip the loopback
if link.Loopback() {
continue
}
// get the link state info // get the link state info
d := float64(link.Delay()) d := float64(link.Delay())
l := float64(link.Length()) l := float64(link.Length())
@ -1020,33 +1103,38 @@ func (t *tun) Close() error {
// Dial an address // Dial an address
func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
log.Debugf("Tunnel dialing %s", channel) // get the options
c, ok := t.newSession(channel, t.newSessionId())
if !ok {
return nil, errors.New("error dialing " + channel)
}
// set remote
c.remote = channel
// set local
c.local = "local"
// outbound session
c.outbound = true
// get opts
options := DialOptions{ options := DialOptions{
Timeout: DefaultDialTimeout, Timeout: DefaultDialTimeout,
Wait: true,
} }
for _, o := range opts { for _, o := range opts {
o(&options) o(&options)
} }
// set the multicast option log.Debugf("Tunnel dialing %s", channel)
// create a new session
c, ok := t.newSession(channel, t.newSessionId())
if !ok {
return nil, errors.New("error dialing " + channel)
}
// set remote
c.remote = channel
// set local
c.local = "local"
// outbound session
c.outbound = true
// set the mode of connection unicast/multicast/broadcast
c.mode = options.Mode c.mode = options.Mode
// set the dial timeout // set the dial timeout
c.dialTimeout = options.Timeout c.dialTimeout = options.Timeout
// set read timeout set to never // set read timeout set to never
c.readTimeout = time.Duration(-1) c.readTimeout = time.Duration(-1)
// set the link
c.link = options.Link
var links []*link var links []*link
// did we measure the rtt // did we measure the rtt
@ -1057,7 +1145,7 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
// non multicast so we need to find the link // non multicast so we need to find the link
for _, link := range t.links { for _, link := range t.links {
// use the link specified it its available // use the link specified it its available
if id := options.Link; len(id) > 0 && link.id != id { if len(c.link) > 0 && link.id != c.link {
continue continue
} }
@ -1073,20 +1161,36 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
t.RUnlock() t.RUnlock()
// link not found and one was specified so error out // link option was specified to pick the link
if len(links) == 0 && len(options.Link) > 0 { if len(options.Link) > 0 {
// delete session and return error // link not found and one was specified so error out
t.delSession(c.channel, c.session) if len(links) == 0 {
log.Debugf("Tunnel deleting session %s %s: %v", c.session, c.channel, ErrLinkNotFound) // delete session and return error
return nil, ErrLinkNotFound t.delSession(c.channel, c.session)
log.Debugf("Tunnel deleting session %s %s: %v", c.session, c.channel, ErrLinkNotFound)
return nil, ErrLinkNotFound
}
// assume discovered because we picked
c.discovered = true
// link asked for and found and now
// we've been asked not to wait so return
if !options.Wait {
c.accepted = true
return c, nil
}
} }
// discovered so set the link if not multicast // discovered so set the link if not multicast
if c.discovered && c.mode == Unicast { if c.discovered && c.mode == Unicast {
// pickLink will pick the best link // pick a link if not specified
link := t.pickLink(links) if len(c.link) == 0 {
// set the link // pickLink will pick the best link
c.link = link.id link := t.pickLink(links)
// set the link
c.link = link.id
}
} }
// if its not already discovered we need to attempt to do so // if its not already discovered we need to attempt to do so
@ -1119,8 +1223,9 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
} }
// return early if its not unicast // return early if its not unicast
// we will not call "open" for multicast // we will not wait for "open" for multicast
if c.mode != Unicast { // and we will not wait it told not to
if c.mode != Unicast || !options.Wait {
return c, nil return c, nil
} }
@ -1213,23 +1318,20 @@ func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) {
// to the existign sessions // to the existign sessions
go tl.process() go tl.process()
// announces the listener channel to others
go tl.announce()
// return the listener // return the listener
return tl, nil return tl, nil
} }
func (t *tun) Links() []Link { func (t *tun) Links() []Link {
t.RLock() t.RLock()
defer t.RUnlock()
links := make([]Link, 0, len(t.links)) links := make([]Link, 0, len(t.links))
for _, link := range t.links { for _, link := range t.links {
links = append(links, link) links = append(links, link)
} }
t.RUnlock()
return links return links
} }

View File

@ -22,6 +22,8 @@ type link struct {
// stops the link // stops the link
closed chan bool closed chan bool
// metric used to track metrics
metric chan *metric
// link state channel for testing link // link state channel for testing link
state chan *packet state chan *packet
// send queue for sending packets // send queue for sending packets
@ -65,6 +67,16 @@ type packet struct {
err error err error
} }
// metric is used to record link rate. It is a single sample that is
// queued on the link's metric channel and later folded into the link
// state by record.
type metric struct {
	// amount of data sent; Send sets this to the body length plus the
	// length of every header key and value
	data int
	// time taken to send
	duration time.Duration
	// if an error occurred; nil means the operation succeeded
	status error
}
var ( var (
// the 4 byte 0 packet sent to determine the link state // the 4 byte 0 packet sent to determine the link state
linkRequest = []byte{0, 0, 0, 0} linkRequest = []byte{0, 0, 0, 0}
@ -84,6 +96,7 @@ func newLink(s transport.Socket) *link {
state: make(chan *packet, 64), state: make(chan *packet, 64),
sendQueue: make(chan *packet, 128), sendQueue: make(chan *packet, 128),
recvQueue: make(chan *packet, 128), recvQueue: make(chan *packet, 128),
metric: make(chan *metric, 128),
} }
// process inbound/outbound packets // process inbound/outbound packets
@ -138,10 +151,10 @@ func (l *link) setRate(bits int64, delta time.Duration) {
// setRTT sets a nanosecond based moving average roundtrip time for the link // setRTT sets a nanosecond based moving average roundtrip time for the link
func (l *link) setRTT(d time.Duration) { func (l *link) setRTT(d time.Duration) {
l.Lock() l.Lock()
defer l.Unlock()
if l.length <= 0 { if l.length <= 0 {
l.length = d.Nanoseconds() l.length = d.Nanoseconds()
l.Unlock()
return return
} }
@ -149,6 +162,8 @@ func (l *link) setRTT(d time.Duration) {
length := 0.8*float64(l.length) + 0.2*float64(d.Nanoseconds()) length := 0.8*float64(l.length) + 0.2*float64(d.Nanoseconds())
// set new length // set new length
l.length = int64(length) l.length = int64(length)
l.Unlock()
} }
func (l *link) delChannel(ch string) { func (l *link) delChannel(ch string) {
@ -159,8 +174,9 @@ func (l *link) delChannel(ch string) {
func (l *link) getChannel(ch string) time.Time { func (l *link) getChannel(ch string) time.Time {
l.RLock() l.RLock()
defer l.RUnlock() t := l.channels[ch]
return l.channels[ch] l.RUnlock()
return t
} }
func (l *link) setChannel(channels ...string) { func (l *link) setChannel(channels ...string) {
@ -186,9 +202,11 @@ func (l *link) process() {
m := new(transport.Message) m := new(transport.Message)
err := l.recv(m) err := l.recv(m)
if err != nil { if err != nil {
l.Lock() // record the metric
l.errCount++ select {
l.Unlock() case l.metric <- &metric{status: err}:
default:
}
} }
// process new received message // process new received message
@ -237,8 +255,12 @@ func (l *link) process() {
// manage manages the link state including rtt packets and channel mapping expiry // manage manages the link state including rtt packets and channel mapping expiry
func (l *link) manage() { func (l *link) manage() {
// tick over every minute to expire and fire rtt packets // tick over every minute to expire and fire rtt packets
t := time.NewTicker(time.Minute) t1 := time.NewTicker(time.Minute)
defer t.Stop() defer t1.Stop()
// used to batch update link metrics
t2 := time.NewTicker(time.Second * 5)
defer t2.Stop()
// get link id // get link id
linkId := l.Id() linkId := l.Id()
@ -287,7 +309,7 @@ func (l *link) manage() {
// set the RTT // set the RTT
l.setRTT(d) l.setRTT(d)
} }
case <-t.C: case <-t1.C:
// drop any channel mappings older than 2 minutes // drop any channel mappings older than 2 minutes
var kill []string var kill []string
killTime := time.Minute * 2 killTime := time.Minute * 2
@ -315,10 +337,60 @@ func (l *link) manage() {
// fire off a link state rtt packet // fire off a link state rtt packet
now = time.Now() now = time.Now()
send(linkRequest) send(linkRequest)
case <-t2.C:
// get a batch of metrics
batch := l.batch()
// skip if there's no metrics
if len(batch) == 0 {
continue
}
// lock once to record a batch
l.Lock()
for _, metric := range batch {
l.record(metric)
}
l.Unlock()
} }
} }
} }
// batch drains every metric currently queued on the link's metric
// channel and returns them in the order they were received. It never
// blocks: as soon as the channel has nothing ready it returns what was
// collected so far, which is a nil slice when nothing was queued.
func (l *link) batch() []*metric {
	var drained []*metric

	for {
		select {
		case next := <-l.metric:
			// keep pulling while samples are immediately available
			drained = append(drained, next)
			continue
		default:
		}

		// nothing ready on the channel, hand back the batch
		return drained
	}
}
// record folds a single metric into the link state. A metric carrying
// an error bumps the error counter and nothing else; a successful one
// clears the counter and, when data was sent, updates the link rate.
//
// record does not lock the link itself; manage calls it while holding
// the link's write lock.
func (l *link) record(m *metric) {
	// failed operation: count it and stop here
	if failed := m.status != nil; failed {
		l.errCount++
		return
	}

	// a success resets the error counter
	l.errCount = 0

	// nothing was sent so there is no rate to compute
	if m.data <= 0 {
		return
	}

	// NOTE(review): Send fills data with byte lengths, so multiplying by
	// 1024 does not produce bits (that would be 8). The factor is kept
	// as-is to preserve the existing rate scale - confirm the intent.
	l.setRate(int64(m.data*1024), m.duration)
}
func (l *link) send(m *transport.Message) error { func (l *link) send(m *transport.Message) error {
if m.Header == nil { if m.Header == nil {
m.Header = make(map[string]string) m.Header = make(map[string]string)
@ -344,28 +416,32 @@ func (l *link) Delay() int64 {
// Current transfer rate as bits per second (lower is better) // Current transfer rate as bits per second (lower is better)
func (l *link) Rate() float64 { func (l *link) Rate() float64 {
l.RLock() l.RLock()
defer l.RUnlock() r := l.rate
return l.rate l.RUnlock()
return r
} }
func (l *link) Loopback() bool { func (l *link) Loopback() bool {
l.RLock() l.RLock()
defer l.RUnlock() lo := l.loopback
return l.loopback l.RUnlock()
return lo
} }
// Length returns the roundtrip time as nanoseconds (lower is better). // Length returns the roundtrip time as nanoseconds (lower is better).
// Returns 0 where no measurement has been taken. // Returns 0 where no measurement has been taken.
func (l *link) Length() int64 { func (l *link) Length() int64 {
l.RLock() l.RLock()
defer l.RUnlock() length := l.length
return l.length l.RUnlock()
return length
} }
func (l *link) Id() string { func (l *link) Id() string {
l.RLock() l.RLock()
defer l.RUnlock() id := l.id
return l.id l.RUnlock()
return id
} }
func (l *link) Close() error { func (l *link) Close() error {
@ -391,6 +467,14 @@ func (l *link) Send(m *transport.Message) error {
status: make(chan error, 1), status: make(chan error, 1),
} }
// calculate the data sent
dataSent := len(m.Body)
// set header length
for k, v := range m.Header {
dataSent += (len(k) + len(v))
}
// get time now // get time now
now := time.Now() now := time.Now()
@ -412,33 +496,19 @@ func (l *link) Send(m *transport.Message) error {
case err = <-p.status: case err = <-p.status:
} }
l.Lock() // create a metric with
defer l.Unlock() // time taken, size of package, error status
mt := &metric{
// there's an error increment the counter and bail data: dataSent,
if err != nil { duration: time.Since(now),
l.errCount++ status: err,
return err
} }
// reset the counter // pass back a metric
l.errCount = 0 // do not block
select {
// calculate the data sent case l.metric <- mt:
dataSent := len(m.Body) default:
// set header length
for k, v := range m.Header {
dataSent += (len(k) + len(v))
}
// calculate based on data
if dataSent > 0 {
// bit sent
bits := dataSent * 1024
// set the rate
l.setRate(int64(bits), time.Since(now))
} }
return nil return nil
@ -476,10 +546,13 @@ func (l *link) State() string {
return "closed" return "closed"
default: default:
l.RLock() l.RLock()
defer l.RUnlock() errCount := l.errCount
if l.errCount > 3 { l.RUnlock()
if errCount > 3 {
return "error" return "error"
} }
return "connected" return "connected"
} }
} }

View File

@ -2,7 +2,7 @@ package tunnel
import ( import (
"io" "io"
"time" "sync"
"github.com/micro/go-micro/util/log" "github.com/micro/go-micro/util/log"
) )
@ -14,32 +14,16 @@ type tunListener struct {
token string token string
// the accept channel // the accept channel
accept chan *session accept chan *session
// the channel to close
closed chan bool
// the tunnel closed channel // the tunnel closed channel
tunClosed chan bool tunClosed chan bool
// the listener session // the listener session
session *session session *session
// del func to kill listener // del func to kill listener
delFunc func() delFunc func()
}
// periodically announce self the channel being listened on sync.RWMutex
func (t *tunListener) announce() { // the channel to close
tick := time.NewTicker(time.Second * 30) closed chan bool
defer tick.Stop()
// first announcement
t.session.Announce()
for {
select {
case <-tick.C:
t.session.Announce()
case <-t.closed:
return
}
}
} }
func (t *tunListener) process() { func (t *tunListener) process() {
@ -68,7 +52,7 @@ func (t *tunListener) process() {
var sessionId string var sessionId string
var linkId string var linkId string
switch m.mode { switch t.session.mode {
case Multicast: case Multicast:
sessionId = "multicast" sessionId = "multicast"
linkId = "multicast" linkId = "multicast"
@ -106,7 +90,7 @@ func (t *tunListener) process() {
// the link the message was received on // the link the message was received on
link: linkId, link: linkId,
// set the connection mode // set the connection mode
mode: m.mode, mode: t.session.mode,
// close chan // close chan
closed: make(chan bool), closed: make(chan bool),
// recv called by the acceptor // recv called by the acceptor
@ -134,6 +118,11 @@ func (t *tunListener) process() {
switch m.typ { switch m.typ {
case "close": case "close":
// don't close multicast sessions
if sess.mode > Unicast {
continue
}
// received a close message // received a close message
select { select {
// check if the session is closed // check if the session is closed
@ -173,6 +162,9 @@ func (t *tunListener) Channel() string {
// Close closes tunnel listener // Close closes tunnel listener
func (t *tunListener) Close() error { func (t *tunListener) Close() error {
t.Lock()
defer t.Unlock()
select { select {
case <-t.closed: case <-t.closed:
return nil return nil

View File

@ -38,6 +38,8 @@ type DialOptions struct {
Link string Link string
// specify mode of the session // specify mode of the session
Mode Mode Mode Mode
// Wait for connection to be accepted
Wait bool
// the dial timeout // the dial timeout
Timeout time.Duration Timeout time.Duration
} }
@ -124,6 +126,14 @@ func DialLink(id string) DialOption {
} }
} }
// DialWait returns a DialOption that sets the Wait field of the dial
// options to b, i.e. whether to wait for the connection to be accepted
// before the session is returned.
func DialWait(b bool) DialOption {
	setWait := func(opts *DialOptions) {
		opts.Wait = b
	}
	return setWait
}
// DefaultOptions returns router default options // DefaultOptions returns router default options
func DefaultOptions() Options { func DefaultOptions() Options {
return Options{ return Options{

View File

@ -28,6 +28,8 @@ var (
ErrLinkNotFound = errors.New("link not found") ErrLinkNotFound = errors.New("link not found")
// ErrReadTimeout is a timeout on session.Recv // ErrReadTimeout is a timeout on session.Recv
ErrReadTimeout = errors.New("read timeout") ErrReadTimeout = errors.New("read timeout")
// ErrDecryptingData is for when theres a nonce error
ErrDecryptingData = errors.New("error decrypting data")
) )
// Mode of the session // Mode of the session