| @@ -7,16 +7,15 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| type WaitGroup struct { | type WaitGroup struct { | ||||||
| 	wg    *sync.WaitGroup | 	mu sync.Mutex | ||||||
| 	drain *atomic.Bool | 	wg *sync.WaitGroup | ||||||
| 	c     *atomic.Int64 | 	c  *atomic.Int64 | ||||||
| } | } | ||||||
|  |  | ||||||
| func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup { | func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup { | ||||||
| 	g := &WaitGroup{ | 	g := &WaitGroup{ | ||||||
| 		wg:    wg, | 		wg: wg, | ||||||
| 		c:     &(atomic.Int64{}), | 		c:  &(atomic.Int64{}), | ||||||
| 		drain: &(atomic.Bool{}), |  | ||||||
| 	} | 	} | ||||||
| 	return g | 	return g | ||||||
| } | } | ||||||
| @@ -24,24 +23,24 @@ func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup { | |||||||
| func NewWaitGroup() *WaitGroup { | func NewWaitGroup() *WaitGroup { | ||||||
| 	var wg sync.WaitGroup | 	var wg sync.WaitGroup | ||||||
| 	g := &WaitGroup{ | 	g := &WaitGroup{ | ||||||
| 		wg:    &wg, | 		wg: &wg, | ||||||
| 		c:     &(atomic.Int64{}), | 		c:  &(atomic.Int64{}), | ||||||
| 		drain: &(atomic.Bool{}), |  | ||||||
| 	} | 	} | ||||||
| 	return g | 	return g | ||||||
| } | } | ||||||
|  |  | ||||||
| func (g *WaitGroup) Add(n int) { | func (g *WaitGroup) Add(n int) { | ||||||
|  | 	g.mu.Lock() | ||||||
| 	g.c.Add(int64(n)) | 	g.c.Add(int64(n)) | ||||||
| 	g.wg.Add(n) | 	g.wg.Add(n) | ||||||
|  | 	g.mu.Unlock() | ||||||
| } | } | ||||||
|  |  | ||||||
| func (g *WaitGroup) Done() { | func (g *WaitGroup) Done() { | ||||||
| 	if g.drain.Load() { | 	g.mu.Lock() | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	g.c.Add(int64(-1)) | 	g.c.Add(int64(-1)) | ||||||
| 	g.wg.Done() | 	g.wg.Add(-1) | ||||||
|  | 	g.mu.Unlock() | ||||||
| } | } | ||||||
|  |  | ||||||
| func (g *WaitGroup) Wait() { | func (g *WaitGroup) Wait() { | ||||||
| @@ -57,11 +56,17 @@ func (g *WaitGroup) WaitContext(ctx context.Context) { | |||||||
|  |  | ||||||
| 	select { | 	select { | ||||||
| 	case <-ctx.Done(): | 	case <-ctx.Done(): | ||||||
| 		g.drain.Store(true) | 		g.mu.Lock() | ||||||
| 		g.wg.Add(-int(g.c.Load())) | 		g.wg.Add(-int(g.c.Load())) | ||||||
| 		g.drain.Store(false) | 		<-done | ||||||
|  | 		g.wg.Add(int(g.c.Load())) | ||||||
|  | 		g.mu.Unlock() | ||||||
| 		return | 		return | ||||||
| 	case <-done: | 	case <-done: | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (g *WaitGroup) Waiters() int { | ||||||
|  | 	return int(g.c.Load()) | ||||||
|  | } | ||||||
|   | |||||||
| @@ -6,7 +6,7 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func TestWaitGroup(t *testing.T) { | func TestWaitGroupContext(t *testing.T) { | ||||||
| 	wg := NewWaitGroup() | 	wg := NewWaitGroup() | ||||||
| 	_ = t | 	_ = t | ||||||
| 	wg.Add(1) | 	wg.Add(1) | ||||||
| @@ -14,3 +14,24 @@ func TestWaitGroup(t *testing.T) { | |||||||
| 	defer cancel() | 	defer cancel() | ||||||
| 	wg.WaitContext(ctx) | 	wg.WaitContext(ctx) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestWaitGroupReuse(t *testing.T) { | ||||||
|  | 	wg := NewWaitGroup() | ||||||
|  | 	defer func() { | ||||||
|  | 		if wg.Waiters() != 0 { | ||||||
|  | 			t.Fatal("lost goroutines") | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  |  | ||||||
|  | 	wg.Add(1) | ||||||
|  | 	defer wg.Done() | ||||||
|  | 	ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second) | ||||||
|  | 	defer cancel() | ||||||
|  | 	wg.WaitContext(ctx) | ||||||
|  |  | ||||||
|  | 	wg.Add(1) | ||||||
|  | 	defer wg.Done() | ||||||
|  | 	ctx, cancel = context.WithTimeout(context.TODO(), 1*time.Second) | ||||||
|  | 	defer cancel() | ||||||
|  | 	wg.WaitContext(ctx) | ||||||
|  | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user