diff --git a/sync/waitgroup.go b/sync/waitgroup.go index be6872f2..e7c09add 100644 --- a/sync/waitgroup.go +++ b/sync/waitgroup.go @@ -7,16 +7,15 @@ import ( ) type WaitGroup struct { - wg *sync.WaitGroup - drain *atomic.Bool - c *atomic.Int64 + mu sync.Mutex + wg *sync.WaitGroup + c *atomic.Int64 } func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup { g := &WaitGroup{ - wg: wg, - c: &(atomic.Int64{}), - drain: &(atomic.Bool{}), + wg: wg, + c: &(atomic.Int64{}), } return g } @@ -24,24 +23,24 @@ func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup { func NewWaitGroup() *WaitGroup { var wg sync.WaitGroup g := &WaitGroup{ - wg: &wg, - c: &(atomic.Int64{}), - drain: &(atomic.Bool{}), + wg: &wg, + c: &(atomic.Int64{}), } return g } func (g *WaitGroup) Add(n int) { + g.mu.Lock() g.c.Add(int64(n)) g.wg.Add(n) + g.mu.Unlock() } func (g *WaitGroup) Done() { - if g.drain.Load() { - return - } + g.mu.Lock() g.c.Add(int64(-1)) - g.wg.Done() + g.wg.Add(-1) + g.mu.Unlock() } func (g *WaitGroup) Wait() { @@ -57,11 +56,17 @@ func (g *WaitGroup) WaitContext(ctx context.Context) { select { case <-ctx.Done(): - g.drain.Store(true) + g.mu.Lock() g.wg.Add(-int(g.c.Load())) - g.drain.Store(false) + <-done + g.wg.Add(int(g.c.Load())) + g.mu.Unlock() return case <-done: return } } + +func (g *WaitGroup) Waiters() int { + return int(g.c.Load()) +} diff --git a/sync/waitgroup_test.go b/sync/waitgroup_test.go index 67b1ca5b..c3f6f1b7 100644 --- a/sync/waitgroup_test.go +++ b/sync/waitgroup_test.go @@ -6,7 +6,7 @@ import ( "time" ) -func TestWaitGroup(t *testing.T) { +func TestWaitGroupContext(t *testing.T) { wg := NewWaitGroup() _ = t wg.Add(1) @@ -14,3 +14,24 @@ func TestWaitGroup(t *testing.T) { defer cancel() wg.WaitContext(ctx) } + +func TestWaitGroupReuse(t *testing.T) { + wg := NewWaitGroup() + defer func() { + if wg.Waiters() != 0 { + t.Fatal("lost goroutines") + } + }() + + wg.Add(1) + defer wg.Done() + ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second) + defer cancel() + wg.WaitContext(ctx) + + wg.Add(1) + defer wg.Done() + ctx, cancel = context.WithTimeout(context.TODO(), 1*time.Second) + defer cancel() + wg.WaitContext(ctx) +}