diff --git a/network/node.go b/network/node.go index ec9b9962..159f15f9 100644 --- a/network/node.go +++ b/network/node.go @@ -38,6 +38,8 @@ type node struct { lastSeen time.Time // lastSync keeps track of node last sync request lastSync time.Time + // errCount tracks error count when communicating with peer + errCount int } // Id is node ide @@ -71,15 +73,17 @@ func (n *node) walk(until func(peer *node) bool, action func(parent, peer *node) for queue.Len() > 0 { // pop the node from the front of the queue qnode := queue.Front() + //fmt.Printf("qnodeValue: %v\n", qnode.Value.(*node)) if until(qnode.Value.(*node)) { return visited } // iterate through all of the node peers // mark the visited nodes; enqueue the non-visted for id, peer := range qnode.Value.(*node).peers { + action(qnode.Value.(*node), peer) if _, ok := visited[id]; !ok { visited[id] = peer - action(qnode.Value.(*node), peer) + //action(qnode.Value.(*node), peer) queue.PushBack(peer) } } @@ -229,7 +233,25 @@ func (n *node) DeletePeerNode(id string) error { return nil } -// PruneStalePeerNodes prune the peers that have not been seen for longer than given time +// PrunePeer prunes the peers with the given id +func (n *node) PrunePeer(id string) { + n.Lock() + defer n.Unlock() + + untilNoMorePeers := func(node *node) bool { + return node == nil + } + + prunePeer := func(parent, node *node) { + if node.id != n.id && node.id == id { + delete(parent.peers, node.id) + } + } + + n.walk(untilNoMorePeers, prunePeer) +} + +// PruneStalePeerNodes prunes the peers that have not been seen for longer than pruneTime // It returns a map of the the nodes that got pruned func (n *node) PruneStalePeers(pruneTime time.Duration) map[string]*node { n.Lock() @@ -252,6 +274,30 @@ func (n *node) PruneStalePeers(pruneTime time.Duration) map[string]*node { return pruned } +// IncErrCount increments node error count +func (n *node) IncErrCount() { + n.Lock() + defer n.Unlock() + + n.errCount++ +} + +// ResetErrCount reset node error count +func (n *node) ResetErrCount() { + n.Lock() + defer n.Unlock() + + n.errCount = 0 +} + +// ErrCount returns node error count +func (n *node) ErrCount() int { + n.RLock() + defer n.RUnlock() + + return n.errCount +} + // getTopology traverses node graph and builds node topology // NOTE: this function is not thread safe func (n *node) getTopology(depth uint) *node { diff --git a/network/node_test.go b/network/node_test.go index 51c4988f..7f2efe3b 100644 --- a/network/node_test.go +++ b/network/node_test.go @@ -215,7 +215,22 @@ func TestDeletePeerNode(t *testing.T) { } } -func TestPruneStalePeerNodes(t *testing.T) { +func TestPrunePeer(t *testing.T) { + // complicated node graph + node := testSetup() + + before := node.Nodes() + + node.PrunePeer("peer3") + + now := node.Nodes() + + if len(now) != len(before)-1 { + t.Errorf("Expected pruned node count: %d, got: %d", len(before)-1, len(now)) + } +} + +func TestPruneStalePeers(t *testing.T) { // complicated node graph node := testSetup() @@ -224,7 +239,7 @@ func TestPruneStalePeerNodes(t *testing.T) { pruneTime := 10 * time.Millisecond time.Sleep(pruneTime) - // should delete all nodes besides node + // should delete all nodes besides (root) node pruned := node.PruneStalePeers(pruneTime) if len(pruned) != len(nodes)-1 {