sync/waitgroup: initial sync.WaitGroup wrapper with context support #319

Merged
vtolstov merged 7 commits from waitgroup into master 2024-03-09 23:35:14 +03:00
2 changed files with 42 additions and 16 deletions
Showing only changes of commit 5df9e59f28 - Show all commits

View File

@ -7,16 +7,15 @@ import (
) )
type WaitGroup struct { type WaitGroup struct {
wg *sync.WaitGroup mu sync.Mutex
drain *atomic.Bool wg *sync.WaitGroup
c *atomic.Int64 c *atomic.Int64
} }
func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup { func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup {
g := &WaitGroup{ g := &WaitGroup{
wg: wg, wg: wg,
c: &(atomic.Int64{}), c: &(atomic.Int64{}),
drain: &(atomic.Bool{}),
} }
return g return g
} }
@ -24,24 +23,24 @@ func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup {
func NewWaitGroup() *WaitGroup { func NewWaitGroup() *WaitGroup {
var wg sync.WaitGroup var wg sync.WaitGroup
g := &WaitGroup{ g := &WaitGroup{
wg: &wg, wg: &wg,
c: &(atomic.Int64{}), c: &(atomic.Int64{}),
drain: &(atomic.Bool{}),
} }
return g return g
} }
func (g *WaitGroup) Add(n int) { func (g *WaitGroup) Add(n int) {
g.mu.Lock()
g.c.Add(int64(n)) g.c.Add(int64(n))
g.wg.Add(n) g.wg.Add(n)
g.mu.Unlock()
} }
func (g *WaitGroup) Done() { func (g *WaitGroup) Done() {
if g.drain.Load() { g.mu.Lock()
return
}
g.c.Add(int64(-1)) g.c.Add(int64(-1))
g.wg.Done() g.wg.Add(-1)
g.mu.Unlock()
} }
func (g *WaitGroup) Wait() { func (g *WaitGroup) Wait() {
@ -57,11 +56,17 @@ func (g *WaitGroup) WaitContext(ctx context.Context) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
g.drain.Store(true) g.mu.Lock()
g.wg.Add(-int(g.c.Load())) g.wg.Add(-int(g.c.Load()))
g.drain.Store(false) <-done
g.wg.Add(int(g.c.Load()))
g.mu.Unlock()
return return
case <-done: case <-done:
return return
} }
} }
func (g *WaitGroup) Waiters() int {
return int(g.c.Load())
}

View File

@ -6,7 +6,7 @@ import (
"time" "time"
) )
func TestWaitGroup(t *testing.T) { func TestWaitGroupContext(t *testing.T) {
wg := NewWaitGroup() wg := NewWaitGroup()
_ = t _ = t
wg.Add(1) wg.Add(1)
@ -14,3 +14,24 @@ func TestWaitGroup(t *testing.T) {
defer cancel() defer cancel()
wg.WaitContext(ctx) wg.WaitContext(ctx)
} }
func TestWaitGroupReuse(t *testing.T) {
wg := NewWaitGroup()
defer func() {
if wg.Waiters() != 0 {
t.Fatal("lost goroutines")
}
}()
wg.Add(1)
defer wg.Done()
ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second)
defer cancel()
wg.WaitContext(ctx)
wg.Add(1)
defer wg.Done()
ctx, cancel = context.WithTimeout(context.TODO(), 1*time.Second)
defer cancel()
wg.WaitContext(ctx)
}