// Package sync provides a wrapper around the standard library's
// sync.WaitGroup whose wait can be abandoned when a context is cancelled.
package sync

import (
	"context"
	"sync"
)

// WaitGroup wraps a *sync.WaitGroup and tracks the units added through it,
// so that WaitContext can release them if its context is cancelled.
//
// Use NewWaitGroup or WrapWaitGroup to create one; the zero value is not
// usable. A WaitGroup must not be copied after first use.
type WaitGroup struct {
	wg *sync.WaitGroup

	// mu serialises the bookkeeping below with the matching operations on wg,
	// so a Done can never race with a forced release.
	mu sync.Mutex
	// pending counts units added through Add that are still held on wg.
	pending int
	// orphans counts units WaitContext force-released on cancellation whose
	// Done calls have not arrived yet; those Done calls are absorbed.
	orphans int
}

// WrapWaitGroup returns a WaitGroup that drives wg. If wg is nil, a fresh
// sync.WaitGroup is used.
//
// Units added to wg directly (not through the returned WaitGroup) are not
// tracked and are never force-released by WaitContext.
func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup {
	if wg == nil {
		wg = new(sync.WaitGroup)
	}
	return &WaitGroup{wg: wg}
}

// NewWaitGroup returns a WaitGroup backed by a new sync.WaitGroup.
func NewWaitGroup() *WaitGroup {
	return WrapWaitGroup(new(sync.WaitGroup))
}

// Add adds n, which may be negative, to the counter, exactly like
// sync.WaitGroup.Add, and records it so WaitContext can release it.
func (g *WaitGroup) Add(n int) {
	g.mu.Lock()
	defer g.mu.Unlock()
	// Apply to wg first: if it panics (negative counter), pending stays
	// consistent with wg.
	g.wg.Add(n)
	g.pending += n
}

// Done decrements the counter by one.
//
// If WaitContext previously force-released outstanding units, the first
// Done calls that follow are absorbed, one per released unit. Those units
// were already removed from the counter, so decrementing again would drive
// it negative and panic.
func (g *WaitGroup) Done() {
	g.mu.Lock()
	defer g.mu.Unlock()
	if g.orphans > 0 {
		g.orphans--
		return
	}
	g.wg.Done()
	g.pending--
}

// Wait blocks until the underlying sync.WaitGroup counter is zero.
func (g *WaitGroup) Wait() {
	g.wg.Wait()
}

// WaitContext blocks until the counter reaches zero or ctx is done.
//
// If ctx is done first, every unit still outstanding through g is released
// from the underlying sync.WaitGroup, so Wait unblocks. The Done calls those
// units make later are absorbed instead of decrementing again.
//
// Because every Done call is absorbed or applied exactly once, a later Wait
// on a reused WaitGroup still waits for any abandoned worker that has not
// finished.
//
// Units added to a wrapped sync.WaitGroup directly are not released, so
// WaitContext keeps waiting for them even after cancellation. Callers can
// inspect ctx.Err() afterwards to learn whether the wait was cut short.
func (g *WaitGroup) WaitContext(ctx context.Context) {
	done := make(chan struct{})
	go func() {
		g.wg.Wait()
		close(done)
	}()

	select {
	case <-done:
		return
	case <-ctx.Done():
	}

	g.mu.Lock()
	defer g.mu.Unlock()
	if n := g.pending; n > 0 {
		// Release everything in one step while holding mu, so no concurrent
		// Done can also decrement these units.
		g.pending = 0
		g.orphans += n
		g.wg.Add(-n)
	}
	// Keep mu until the helper's Wait has returned: an Add slipping in while
	// that waiter is still waking up would make sync.WaitGroup panic
	// ("WaitGroup is reused before previous Wait has returned"). This also
	// guarantees the helper goroutine has exited.
	<-done
}