| @@ -3,19 +3,17 @@ package sync | |||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"sync/atomic" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type WaitGroup struct { | type WaitGroup struct { | ||||||
| 	wg *sync.WaitGroup | 	wg *sync.WaitGroup | ||||||
| 	c  *atomic.Int64 | 	c  int | ||||||
| 	mu sync.Mutex | 	mu sync.Mutex | ||||||
| } | } | ||||||
|  |  | ||||||
| func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup { | func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup { | ||||||
| 	g := &WaitGroup{ | 	g := &WaitGroup{ | ||||||
| 		wg: wg, | 		wg: wg, | ||||||
| 		c:  &(atomic.Int64{}), |  | ||||||
| 	} | 	} | ||||||
| 	return g | 	return g | ||||||
| } | } | ||||||
| @@ -27,14 +25,14 @@ func NewWaitGroup() *WaitGroup { | |||||||
|  |  | ||||||
| func (g *WaitGroup) Add(n int) { | func (g *WaitGroup) Add(n int) { | ||||||
| 	g.mu.Lock() | 	g.mu.Lock() | ||||||
| 	g.c.Add(int64(n)) | 	g.c += n | ||||||
| 	g.wg.Add(n) | 	g.wg.Add(n) | ||||||
| 	g.mu.Unlock() | 	g.mu.Unlock() | ||||||
| } | } | ||||||
|  |  | ||||||
| func (g *WaitGroup) Done() { | func (g *WaitGroup) Done() { | ||||||
| 	g.mu.Lock() | 	g.mu.Lock() | ||||||
| 	g.c.Add(int64(-1)) | 	g.c += -1 | ||||||
| 	g.wg.Add(-1) | 	g.wg.Add(-1) | ||||||
| 	g.mu.Unlock() | 	g.mu.Unlock() | ||||||
| } | } | ||||||
| @@ -53,9 +51,9 @@ func (g *WaitGroup) WaitContext(ctx context.Context) { | |||||||
| 	select { | 	select { | ||||||
| 	case <-ctx.Done(): | 	case <-ctx.Done(): | ||||||
| 		g.mu.Lock() | 		g.mu.Lock() | ||||||
| 		g.wg.Add(-int(g.c.Load())) | 		g.wg.Add(-g.c) | ||||||
| 		<-done | 		<-done | ||||||
| 		g.wg.Add(int(g.c.Load())) | 		g.wg.Add(g.c) | ||||||
| 		g.mu.Unlock() | 		g.mu.Unlock() | ||||||
| 		return | 		return | ||||||
| 	case <-done: | 	case <-done: | ||||||
| @@ -64,5 +62,8 @@ func (g *WaitGroup) WaitContext(ctx context.Context) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (g *WaitGroup) Waiters() int { | func (g *WaitGroup) Waiters() int { | ||||||
| 	return int(g.c.Load()) | 	g.mu.Lock() | ||||||
|  | 	c := g.c | ||||||
|  | 	g.mu.Unlock() | ||||||
|  | 	return c | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user