diff --git a/flow/dag_test.go b/flow/dag_test.go index 6563200c..175098ae 100644 --- a/flow/dag_test.go +++ b/flow/dag_test.go @@ -7,6 +7,55 @@ import ( "github.com/silas/dag" ) +func TestDeps(t *testing.T) { + d := &dag.AcyclicGraph{} + + v0 := d.Add(&node{"v0"}) + v1 := d.Add(&node{"v1"}) + v2 := d.Add(&node{"v2"}) + v3 := d.Add(&node{"v3"}) + v4 := d.Add(&node{"v4"}) + + d.Connect(dag.BasicEdge(v0, v1)) + d.Connect(dag.BasicEdge(v1, v2)) + d.Connect(dag.BasicEdge(v2, v4)) + d.Connect(dag.BasicEdge(v0, v3)) + d.Connect(dag.BasicEdge(v3, v4)) + + if err := d.Validate(); err != nil { + t.Fatal(err) + } + + d.TransitiveReduction() + + var steps [][]string + fn := func(n dag.Vertex, idx int) error { + if idx == 0 { + steps = make([][]string, 1) + steps[0] = make([]string, 0, 1) + } else if idx >= len(steps) { + tsteps := make([][]string, idx+1) + copy(tsteps, steps) + steps = tsteps + steps[idx] = make([]string, 0, 1) + } + steps[idx] = append(steps[idx], fmt.Sprintf("%s", n)) + return nil + } + + start := &node{"v0"} + err := d.SortedDepthFirstWalk([]dag.Vertex{start}, fn) + checkErr(t, err) + + for idx, steps := range steps { + fmt.Printf("level %d steps %#+v\n", idx, steps) + } + + if len(steps[2]) != 1 { + t.Fatalf("invalid steps %#+v", steps[2]) + } +} + func checkErr(t *testing.T, err error) { if err != nil { t.Fatal(err)