sync/waitgroup: initial sync.WaitGroup wrapper with context support #319

Merged
vtolstov merged 7 commits from waitgroup into master 2024-03-09 23:35:14 +03:00
2 changed files with 89 additions and 0 deletions
Showing only changes of commit d112f7148e - Show all commits

73
sync/waitgroup.go Normal file
View File

@ -0,0 +1,73 @@
package sync
import (
"context"
"sync"
"sync/atomic"
)
type WaitGroup struct {
wg *sync.WaitGroup
drain *atomic.Bool
c *atomic.Int64
}
func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup {
g := &WaitGroup{
wg: wg,
c: &(atomic.Int64{}),
drain: &(atomic.Bool{}),
}
return g
}
func NewWaitGroup() *WaitGroup {
var wg sync.WaitGroup
g := &WaitGroup{
wg: &wg,
c: &(atomic.Int64{}),
drain: &(atomic.Bool{}),
}
return g
}
func (g *WaitGroup) Add(n int) {
g.c.Add(int64(n))
g.wg.Add(n)
}
func (g *WaitGroup) Done() {
if g.drain.Load() {
return
}
g.c.Add(int64(-1))
g.wg.Done()
}
func (g *WaitGroup) Wait() {
g.wg.Wait()
}
func (g *WaitGroup) WaitContext(ctx context.Context) {
done := make(chan struct{})
go func() {
g.wg.Wait()
close(done)
}()
select {
case <-ctx.Done():
g.drain.Swap(true)
for g.c.Load() > 0 {
select {
case <-done:
g.drain.Swap(false)
return
default:
g.wg.Done()
}
}
case <-done:
return
}
}

16
sync/waitgroup_test.go Normal file
View File

@ -0,0 +1,16 @@
package sync
import (
"context"
"testing"
"time"
)
func TestWaitGroup(t *testing.T) {
wg := NewWaitGroup()
_ = t
wg.Add(1)
ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second)
defer cancel()
wg.WaitContext(ctx)
}