sync/waitgroup: initial sync.WaitGroup wrapper with context support #319
@ -7,8 +7,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type WaitGroup struct {
|
type WaitGroup struct {
|
||||||
|
mu sync.Mutex
|
||||||
wg *sync.WaitGroup
|
wg *sync.WaitGroup
|
||||||
drain *atomic.Bool
|
|
||||||
c *atomic.Int64
|
c *atomic.Int64
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -16,7 +16,6 @@ func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup {
|
|||||||
g := &WaitGroup{
|
g := &WaitGroup{
|
||||||
wg: wg,
|
wg: wg,
|
||||||
c: &(atomic.Int64{}),
|
c: &(atomic.Int64{}),
|
||||||
drain: &(atomic.Bool{}),
|
|
||||||
}
|
}
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
@ -26,22 +25,22 @@ func NewWaitGroup() *WaitGroup {
|
|||||||
g := &WaitGroup{
|
g := &WaitGroup{
|
||||||
wg: &wg,
|
wg: &wg,
|
||||||
c: &(atomic.Int64{}),
|
c: &(atomic.Int64{}),
|
||||||
drain: &(atomic.Bool{}),
|
|
||||||
}
|
}
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *WaitGroup) Add(n int) {
|
func (g *WaitGroup) Add(n int) {
|
||||||
|
g.mu.Lock()
|
||||||
g.c.Add(int64(n))
|
g.c.Add(int64(n))
|
||||||
g.wg.Add(n)
|
g.wg.Add(n)
|
||||||
|
g.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *WaitGroup) Done() {
|
func (g *WaitGroup) Done() {
|
||||||
if g.drain.Load() {
|
g.mu.Lock()
|
||||||
return
|
|
||||||
}
|
|
||||||
g.c.Add(int64(-1))
|
g.c.Add(int64(-1))
|
||||||
g.wg.Done()
|
g.wg.Add(-1)
|
||||||
|
g.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *WaitGroup) Wait() {
|
func (g *WaitGroup) Wait() {
|
||||||
@ -57,11 +56,17 @@ func (g *WaitGroup) WaitContext(ctx context.Context) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
g.drain.Store(true)
|
g.mu.Lock()
|
||||||
g.wg.Add(-int(g.c.Load()))
|
g.wg.Add(-int(g.c.Load()))
|
||||||
g.drain.Store(false)
|
<-done
|
||||||
|
g.wg.Add(int(g.c.Load()))
|
||||||
|
g.mu.Unlock()
|
||||||
return
|
return
|
||||||
case <-done:
|
case <-done:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *WaitGroup) Waiters() int {
|
||||||
|
return int(g.c.Load())
|
||||||
|
}
|
||||||
|
@ -6,7 +6,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestWaitGroup(t *testing.T) {
|
func TestWaitGroupContext(t *testing.T) {
|
||||||
wg := NewWaitGroup()
|
wg := NewWaitGroup()
|
||||||
_ = t
|
_ = t
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
@ -14,3 +14,24 @@ func TestWaitGroup(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
wg.WaitContext(ctx)
|
wg.WaitContext(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWaitGroupReuse(t *testing.T) {
|
||||||
|
wg := NewWaitGroup()
|
||||||
|
defer func() {
|
||||||
|
if wg.Waiters() != 0 {
|
||||||
|
t.Fatal("lost goroutines")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
defer wg.Done()
|
||||||
|
ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
wg.WaitContext(ctx)
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
defer wg.Done()
|
||||||
|
ctx, cancel = context.WithTimeout(context.TODO(), 1*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
wg.WaitContext(ctx)
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user