diff --git a/sync/waitgroup.go b/sync/waitgroup.go index 5dbc1126..3124d948 100644 --- a/sync/waitgroup.go +++ b/sync/waitgroup.go @@ -3,19 +3,17 @@ package sync import ( "context" "sync" - "sync/atomic" ) type WaitGroup struct { wg *sync.WaitGroup - c *atomic.Int64 + c int mu sync.Mutex } func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup { g := &WaitGroup{ wg: wg, - c: &(atomic.Int64{}), } return g } @@ -27,14 +25,14 @@ func NewWaitGroup() *WaitGroup { func (g *WaitGroup) Add(n int) { g.mu.Lock() - g.c.Add(int64(n)) + g.c += n g.wg.Add(n) g.mu.Unlock() } func (g *WaitGroup) Done() { g.mu.Lock() - g.c.Add(int64(-1)) + g.c += -1 g.wg.Add(-1) g.mu.Unlock() } @@ -53,9 +51,9 @@ func (g *WaitGroup) WaitContext(ctx context.Context) { select { case <-ctx.Done(): g.mu.Lock() - g.wg.Add(-int(g.c.Load())) + g.wg.Add(-g.c) <-done - g.wg.Add(int(g.c.Load())) + g.wg.Add(g.c) g.mu.Unlock() return case <-done: @@ -64,5 +62,8 @@ func (g *WaitGroup) WaitContext(ctx context.Context) { } func (g *WaitGroup) Waiters() int { - return int(g.c.Load()) + g.mu.Lock() + c := g.c + g.mu.Unlock() + return c }