From 2cc004b01c164ef26aec161a9bc168ddf7a333bc Mon Sep 17 00:00:00 2001 From: Vasiliy Tolstov Date: Sat, 9 Mar 2024 23:36:39 +0300 Subject: [PATCH] sync/waitgroup: backport from master Signed-off-by: Vasiliy Tolstov --- sync/waitgroup.go | 69 ++++++++++++++++++++++++++++++++++++++++++ sync/waitgroup_test.go | 37 ++++++++++++++++++++++ 2 files changed, 106 insertions(+) create mode 100644 sync/waitgroup.go create mode 100644 sync/waitgroup_test.go diff --git a/sync/waitgroup.go b/sync/waitgroup.go new file mode 100644 index 00000000..3124d948 --- /dev/null +++ b/sync/waitgroup.go @@ -0,0 +1,69 @@ +package sync + +import ( + "context" + "sync" +) + +type WaitGroup struct { + wg *sync.WaitGroup + c int + mu sync.Mutex +} + +func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup { + g := &WaitGroup{ + wg: wg, + } + return g +} + +func NewWaitGroup() *WaitGroup { + var wg sync.WaitGroup + return WrapWaitGroup(&wg) +} + +func (g *WaitGroup) Add(n int) { + g.mu.Lock() + g.c += n + g.wg.Add(n) + g.mu.Unlock() +} + +func (g *WaitGroup) Done() { + g.mu.Lock() + g.c += -1 + g.wg.Add(-1) + g.mu.Unlock() +} + +func (g *WaitGroup) Wait() { + g.wg.Wait() +} + +func (g *WaitGroup) WaitContext(ctx context.Context) { + done := make(chan struct{}) + go func() { + g.wg.Wait() + close(done) + }() + + select { + case <-ctx.Done(): + g.mu.Lock() + g.wg.Add(-g.c) + <-done + g.wg.Add(g.c) + g.mu.Unlock() + return + case <-done: + return + } +} + +func (g *WaitGroup) Waiters() int { + g.mu.Lock() + c := g.c + g.mu.Unlock() + return c +} diff --git a/sync/waitgroup_test.go b/sync/waitgroup_test.go new file mode 100644 index 00000000..c3f6f1b7 --- /dev/null +++ b/sync/waitgroup_test.go @@ -0,0 +1,37 @@ +package sync + +import ( + "context" + "testing" + "time" +) + +func TestWaitGroupContext(t *testing.T) { + wg := NewWaitGroup() + _ = t + wg.Add(1) + ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second) + defer cancel() + wg.WaitContext(ctx) +} + +func TestWaitGroupReuse(t *testing.T) { + wg := NewWaitGroup() + defer func() { + if wg.Waiters() != 0 { + t.Fatal("lost goroutines") + } + }() + + wg.Add(1) + defer wg.Done() + ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second) + defer cancel() + wg.WaitContext(ctx) + + wg.Add(1) + defer wg.Done() + ctx, cancel = context.WithTimeout(context.TODO(), 1*time.Second) + defer cancel() + wg.WaitContext(ctx) +}