diff --git a/network/node_test.go b/network/node_test.go index b99a284c..df5a5e4a 100644 --- a/network/node_test.go +++ b/network/node_test.go @@ -157,6 +157,22 @@ func TestPeers(t *testing.T) { } } +func collectTopologyIds(peers map[string]*node, ids map[string]bool) map[string]bool { + if len(peers) == 0 { + return ids + } + + // iterate through the whole graph + for id, peer := range peers { + ids = collectTopologyIds(peer.peers, ids) + if _, ok := ids[id]; !ok { + ids[id] = true + } + } + + return ids +} + func TestTopology(t *testing.T) { // single node single := &node{ @@ -200,17 +216,16 @@ func TestTopology(t *testing.T) { } topology = node.Topology(2) - // iterate through the whole graph - // NOTE: this is a manual iteration as we know the size of the graph - for id, peer := range topology.peers { - if _, ok := peerIds[id]; !ok { - t.Errorf("Expected to find %s peer", peer.Id()) - } - // peers of peers - for id := range peer.peers { - if _, ok := peerIds[id]; !ok { - t.Errorf("Expected to find %s peer", peer.Id()) - } + topIds := make(map[string]bool) + topIds = collectTopologyIds(topology.peers, topIds) + + if len(topIds) != len(peerIds) { + t.Errorf("Expected to find %d nodes, found: %d", len(peerIds), len(topIds)) + } + + for id := range topIds { + if _, ok := topIds[id]; !ok { + t.Errorf("Expected to find %s peer", id) } } }