Merge pull request 'sync/waitgroup: backport from master' (#320) from waitgroup into v3
Reviewed-on: #320
This commit is contained in:
commit
ed7972a1fa
69
sync/waitgroup.go
Normal file
69
sync/waitgroup.go
Normal file
@ -0,0 +1,69 @@
|
||||
package sync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type WaitGroup struct {
|
||||
wg *sync.WaitGroup
|
||||
c int
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup {
|
||||
g := &WaitGroup{
|
||||
wg: wg,
|
||||
}
|
||||
return g
|
||||
}
|
||||
|
||||
func NewWaitGroup() *WaitGroup {
|
||||
var wg sync.WaitGroup
|
||||
return WrapWaitGroup(&wg)
|
||||
}
|
||||
|
||||
func (g *WaitGroup) Add(n int) {
|
||||
g.mu.Lock()
|
||||
g.c += n
|
||||
g.wg.Add(n)
|
||||
g.mu.Unlock()
|
||||
}
|
||||
|
||||
func (g *WaitGroup) Done() {
|
||||
g.mu.Lock()
|
||||
g.c += -1
|
||||
g.wg.Add(-1)
|
||||
g.mu.Unlock()
|
||||
}
|
||||
|
||||
func (g *WaitGroup) Wait() {
|
||||
g.wg.Wait()
|
||||
}
|
||||
|
||||
func (g *WaitGroup) WaitContext(ctx context.Context) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
g.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
g.mu.Lock()
|
||||
g.wg.Add(-g.c)
|
||||
<-done
|
||||
g.wg.Add(g.c)
|
||||
g.mu.Unlock()
|
||||
return
|
||||
case <-done:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (g *WaitGroup) Waiters() int {
|
||||
g.mu.Lock()
|
||||
c := g.c
|
||||
g.mu.Unlock()
|
||||
return c
|
||||
}
|
37
sync/waitgroup_test.go
Normal file
37
sync/waitgroup_test.go
Normal file
@ -0,0 +1,37 @@
|
||||
package sync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestWaitGroupContext(t *testing.T) {
|
||||
wg := NewWaitGroup()
|
||||
_ = t
|
||||
wg.Add(1)
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second)
|
||||
defer cancel()
|
||||
wg.WaitContext(ctx)
|
||||
}
|
||||
|
||||
func TestWaitGroupReuse(t *testing.T) {
|
||||
wg := NewWaitGroup()
|
||||
defer func() {
|
||||
if wg.Waiters() != 0 {
|
||||
t.Fatal("lost goroutines")
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
defer wg.Done()
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second)
|
||||
defer cancel()
|
||||
wg.WaitContext(ctx)
|
||||
|
||||
wg.Add(1)
|
||||
defer wg.Done()
|
||||
ctx, cancel = context.WithTimeout(context.TODO(), 1*time.Second)
|
||||
defer cancel()
|
||||
wg.WaitContext(ctx)
|
||||
}
|
Loading…
Reference in New Issue
Block a user