Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>
This commit is contained in:
Василий Толстов 2024-03-09 22:52:24 +03:00
parent e430c15ae8
commit 5df9e59f28
2 changed files with 42 additions and 16 deletions

View File

@ -7,8 +7,8 @@ import (
)
type WaitGroup struct {
mu sync.Mutex
wg *sync.WaitGroup
drain *atomic.Bool
c *atomic.Int64
}
@ -16,7 +16,6 @@ func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup {
g := &WaitGroup{
wg: wg,
c: &(atomic.Int64{}),
drain: &(atomic.Bool{}),
}
return g
}
@ -26,22 +25,22 @@ func NewWaitGroup() *WaitGroup {
g := &WaitGroup{
wg: &wg,
c: &(atomic.Int64{}),
drain: &(atomic.Bool{}),
}
return g
}
func (g *WaitGroup) Add(n int) {
g.mu.Lock()
g.c.Add(int64(n))
g.wg.Add(n)
g.mu.Unlock()
}
func (g *WaitGroup) Done() {
if g.drain.Load() {
return
}
g.mu.Lock()
g.c.Add(int64(-1))
g.wg.Done()
g.wg.Add(-1)
g.mu.Unlock()
}
func (g *WaitGroup) Wait() {
@ -57,11 +56,17 @@ func (g *WaitGroup) WaitContext(ctx context.Context) {
select {
case <-ctx.Done():
g.drain.Store(true)
g.mu.Lock()
g.wg.Add(-int(g.c.Load()))
g.drain.Store(false)
<-done
g.wg.Add(int(g.c.Load()))
g.mu.Unlock()
return
case <-done:
return
}
}
func (g *WaitGroup) Waiters() int {
return int(g.c.Load())
}

View File

@ -6,7 +6,7 @@ import (
"time"
)
func TestWaitGroup(t *testing.T) {
func TestWaitGroupContext(t *testing.T) {
wg := NewWaitGroup()
_ = t
wg.Add(1)
@ -14,3 +14,24 @@ func TestWaitGroup(t *testing.T) {
defer cancel()
wg.WaitContext(ctx)
}
func TestWaitGroupReuse(t *testing.T) {
wg := NewWaitGroup()
defer func() {
if wg.Waiters() != 0 {
t.Fatal("lost goroutines")
}
}()
wg.Add(1)
defer wg.Done()
ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second)
defer cancel()
wg.WaitContext(ctx)
wg.Add(1)
defer wg.Done()
ctx, cancel = context.WithTimeout(context.TODO(), 1*time.Second)
defer cancel()
wg.WaitContext(ctx)
}