sync/waitgroup: initial sync.WaitGroup wrapper with context support #319
73
sync/waitgroup.go
Normal file
73
sync/waitgroup.go
Normal file
@ -0,0 +1,73 @@
|
||||
package sync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type WaitGroup struct {
|
||||
wg *sync.WaitGroup
|
||||
drain *atomic.Bool
|
||||
c *atomic.Int64
|
||||
}
|
||||
|
||||
func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup {
|
||||
g := &WaitGroup{
|
||||
wg: wg,
|
||||
c: &(atomic.Int64{}),
|
||||
drain: &(atomic.Bool{}),
|
||||
}
|
||||
return g
|
||||
}
|
||||
|
||||
func NewWaitGroup() *WaitGroup {
|
||||
var wg sync.WaitGroup
|
||||
g := &WaitGroup{
|
||||
wg: &wg,
|
||||
c: &(atomic.Int64{}),
|
||||
drain: &(atomic.Bool{}),
|
||||
}
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *WaitGroup) Add(n int) {
|
||||
g.c.Add(int64(n))
|
||||
g.wg.Add(n)
|
||||
}
|
||||
|
||||
func (g *WaitGroup) Done() {
|
||||
if g.drain.Load() {
|
||||
return
|
||||
}
|
||||
g.c.Add(int64(-1))
|
||||
g.wg.Done()
|
||||
}
|
||||
|
||||
func (g *WaitGroup) Wait() {
|
||||
g.wg.Wait()
|
||||
}
|
||||
|
||||
func (g *WaitGroup) WaitContext(ctx context.Context) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
g.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
g.drain.Swap(true)
|
||||
for g.c.Load() > 0 {
|
||||
select {
|
||||
case <-done:
|
||||
g.drain.Swap(false)
|
||||
return
|
||||
default:
|
||||
g.wg.Done()
|
||||
}
|
||||
}
|
||||
case <-done:
|
||||
return
|
||||
}
|
||||
}
|
16
sync/waitgroup_test.go
Normal file
16
sync/waitgroup_test.go
Normal file
@ -0,0 +1,16 @@
|
||||
package sync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestWaitGroup(t *testing.T) {
|
||||
wg := NewWaitGroup()
|
||||
_ = t
|
||||
wg.Add(1)
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second)
|
||||
defer cancel()
|
||||
wg.WaitContext(ctx)
|
||||
}
|
Loading…
Reference in New Issue
Block a user