68 lines
947 B
Go
68 lines
947 B
Go
package sync
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
"sync/atomic"
|
|
)
|
|
|
|
// WaitGroup couples a *sync.WaitGroup with a count of the units registered
// through this wrapper, so that WaitContext can release whatever is still
// outstanding when its context ends instead of blocking forever.
//
// Units are interchangeable. Once WaitContext has force-released units, Done
// calls that arrive late from the abandoned work are ignored instead of
// driving the underlying counter negative.
type WaitGroup struct {
	// wg is the underlying wait group that Wait and WaitContext block on.
	wg *sync.WaitGroup
	// drain is raised the first time WaitContext force-releases units and is
	// never lowered again. While it is down, Done calls without a tracked
	// unit are forwarded to wg (plain sync.WaitGroup behaviour); once it is
	// up, such calls are treated as late arrivals and ignored.
	drain *atomic.Bool
	// c counts units added through Add that have been neither completed by
	// Done nor force-released by WaitContext. Whoever removes units from c
	// is responsible for removing the same number from wg.
	c *atomic.Int64
}

// WrapWaitGroup returns a WaitGroup backed by the caller's wg. Only units
// registered through the wrapper's Add are tracked; a count that is already
// on wg, or that is added to wg directly, is not released by WaitContext.
func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup {
	return &WaitGroup{
		wg:    wg,
		drain: new(atomic.Bool),
		c:     new(atomic.Int64),
	}
}

// NewWaitGroup returns a WaitGroup backed by a fresh sync.WaitGroup.
func NewWaitGroup() *WaitGroup {
	return WrapWaitGroup(new(sync.WaitGroup))
}

// Add registers n units of work. A negative n gives units back with the same
// rules as Done.
func (g *WaitGroup) Add(n int) {
	if n < 0 {
		g.release(-int64(n))
		return
	}
	// Raise wg before publishing the units in c: a concurrent forced release
	// subtracts c from wg, so c must never be ahead of what wg already holds.
	g.wg.Add(n)
	g.c.Add(int64(n))
}

// Done marks one unit of work as finished. After a forced release by
// WaitContext, a Done for a unit that is no longer tracked is a no-op.
func (g *WaitGroup) Done() {
	g.release(1)
}

// Wait blocks until the underlying counter reaches zero.
func (g *WaitGroup) Wait() {
	g.wg.Wait()
}

// WaitContext blocks until the underlying counter reaches zero or ctx is
// done, whichever happens first. When ctx ends first, every unit still
// tracked by the wrapper is released from the underlying wait group, which
// also lets the internal waiter goroutine finish. Units that were put on the
// underlying wait group without going through Add are not released, so in
// that case the waiter goroutine stays parked until they complete.
func (g *WaitGroup) WaitContext(ctx context.Context) {
	done := make(chan struct{})
	go func() {
		g.wg.Wait()
		close(done)
	}()

	select {
	case <-done:
	case <-ctx.Done():
		// Raise the flag before emptying c, so a Done that finds c empty
		// because of this release also observes the flag and backs off.
		g.drain.Store(true)
		// Swap resets the ledger atomically: every unit is released exactly
		// once, and a later WaitContext cannot subtract a stale count again.
		if n := g.c.Swap(0); n > 0 {
			g.wg.Add(-int(n))
		}
	}
}

// release gives back k units: tracked units are claimed from c first, and any
// surplus is forwarded to wg only while no forced release has happened.
func (g *WaitGroup) release(k int64) {
	var tracked int64
	for {
		have := g.c.Load()
		tracked = k
		if have < tracked {
			tracked = have
		}
		if tracked <= 0 {
			tracked = 0
			break
		}
		// Claim the units before touching wg; losing the race to another
		// Done or to a forced release just means re-reading c.
		if g.c.CompareAndSwap(have, have-tracked) {
			break
		}
	}

	total := tracked
	if surplus := k - tracked; surplus > 0 && !g.drain.Load() {
		// No forced release so far: behave like a plain sync.WaitGroup,
		// including its panic when the counter would go negative.
		total += surplus
	}
	if total > 0 {
		g.wg.Add(-int(total))
	}
}
|