diff --git a/network/default.go b/network/default.go
index 831118a5..e80f5d8d 100644
--- a/network/default.go
+++ b/network/default.go
@@ -38,6 +38,8 @@ var (
 	DefaultLink = "network"
 	// MaxConnections is the max number of network client connections
 	MaxConnections = 3
+	// MaxPeerErrors is the max number of peer errors before the peer is removed from the network graph
+	MaxPeerErrors = 3
 )
 
 var (
@@ -45,6 +47,8 @@ var (
 	ErrClientNotFound = errors.New("client not found")
 	// ErrPeerLinkNotFound is returned when peer link could not be found in tunnel Links
 	ErrPeerLinkNotFound = errors.New("peer link not found")
+	// ErrPeerMaxExceeded is returned when peer has reached its max error count limit
+	ErrPeerMaxExceeded = errors.New("peer max errors exceeded")
 )
 
 // network implements Network interface
@@ -1327,6 +1331,11 @@ func (n *network) sendTo(method, channel string, peer *node, msg proto.Message)
 	// Create a unicast connection to the peer but don't do the open/accept flow
 	c, err := n.tunnel.Dial(channel, tunnel.DialWait(false), tunnel.DialLink(peer.link))
 	if err != nil {
+		// increment the peer error count; prune the peer once it has reached MaxPeerErrors
+		peer.err.Increment()
+		if count := peer.err.GetCount(); count >= MaxPeerErrors {
+			n.PrunePeer(peer.id)
+		}
 		return err
 	}
 	defer c.Close()
@@ -1351,7 +1360,16 @@ func (n *network) sendTo(method, channel string, peer *node, msg proto.Message)
 		tmsg.Header["Micro-Peer"] = peer.id
 	}
 
-	return c.Send(tmsg)
+	if err := c.Send(tmsg); err != nil {
+		// increment the peer error count; prune the peer once it has reached MaxPeerErrors
+		peer.err.Increment()
+		if count := peer.err.GetCount(); count >= MaxPeerErrors {
+			n.PrunePeer(peer.id)
+		}
+		return err
+	}
+
+	return nil
 }
 
 // sendMsg sends a message to the tunnel channel
diff --git a/network/node.go b/network/node.go
index 159f15f9..a5d3ad91 100644
--- a/network/node.go
+++ b/network/node.go
@@ -21,6 +21,36 @@ var (
 	ErrPeerNotFound = errors.New("peer not found")
 )
 
+// nodeError keeps a node error count; the embedded RWMutex guards count
+type nodeError struct {
+	sync.RWMutex
+	count int
+}
+
+// Increment increments node error count by one
+func (n *nodeError) Increment() {
+	n.Lock()
+	defer n.Unlock()
+
+	n.count++
+}
+
+// Reset resets node error count to zero
+func (n *nodeError) Reset() {
+	n.Lock()
+	defer n.Unlock()
+
+	n.count = 0
+}
+
+// GetCount returns the current node error count
+func (n *nodeError) GetCount() int {
+	n.RLock()
+	defer n.RUnlock()
+
+	return n.count
+}
+
 // node is network node
 type node struct {
 	sync.RWMutex
@@ -38,8 +68,8 @@ type node struct {
 	lastSeen time.Time
 	// lastSync keeps track of node last sync request
 	lastSync time.Time
-	// errCount tracks error count when communicating with peer
-	errCount int
+	// err tracks node errors
+	err nodeError
 }
 
 // Id is node ide
@@ -274,30 +304,6 @@ func (n *node) PruneStalePeers(pruneTime time.Duration) map[string]*node {
 	return pruned
 }
 
-// IncErrCount increments node error count
-func (n *node) IncErrCount() {
-	n.Lock()
-	defer n.Unlock()
-
-	n.errCount++
-}
-
-// ResetErrCount reset node error count
-func (n *node) ResetErrCount() {
-	n.Lock()
-	defer n.Unlock()
-
-	n.errCount = 0
-}
-
-// ErrCount returns node error count
-func (n *node) ErrCount() int {
-	n.RLock()
-	defer n.RUnlock()
-
-	return n.errCount
-}
-
 // getTopology traverses node graph and builds node topology
 // NOTE: this function is not thread safe
 func (n *node) getTopology(depth uint) *node {