diff --git a/sync/waitgroup.go b/sync/waitgroup.go new file mode 100644 index 00000000..c417104e --- /dev/null +++ b/sync/waitgroup.go @@ -0,0 +1,73 @@ +package sync + +import ( + "context" + "sync" + "sync/atomic" +) + +type WaitGroup struct { + wg *sync.WaitGroup + drain *atomic.Bool + c *atomic.Int64 +} + +func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup { + g := &WaitGroup{ + wg: wg, + c: &(atomic.Int64{}), + drain: &(atomic.Bool{}), + } + return g +} + +func NewWaitGroup() *WaitGroup { + var wg sync.WaitGroup + g := &WaitGroup{ + wg: &wg, + c: &(atomic.Int64{}), + drain: &(atomic.Bool{}), + } + return g +} + +func (g *WaitGroup) Add(n int) { + g.c.Add(int64(n)) + g.wg.Add(n) +} + +func (g *WaitGroup) Done() { + if g.drain.Load() { + return + } + g.c.Add(int64(-1)) + g.wg.Done() +} + +func (g *WaitGroup) Wait() { + g.wg.Wait() +} + +func (g *WaitGroup) WaitContext(ctx context.Context) { + done := make(chan struct{}) + go func() { + g.wg.Wait() + close(done) + }() + + select { + case <-ctx.Done(): + g.drain.Swap(true) + for g.c.Load() > 0 { + select { + case <-done: + g.drain.Swap(false) + return + default: + g.wg.Done() + } + } + case <-done: + return + } +} diff --git a/sync/waitgroup_test.go b/sync/waitgroup_test.go new file mode 100644 index 00000000..67b1ca5b --- /dev/null +++ b/sync/waitgroup_test.go @@ -0,0 +1,16 @@ +package sync + +import ( + "context" + "testing" + "time" +) + +func TestWaitGroup(t *testing.T) { + wg := NewWaitGroup() + _ = t + wg.Add(1) + ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second) + defer cancel() + wg.WaitContext(ctx) +}