sync/waitgroup: backport from master #320
							
								
								
									
										69
									
								
								sync/waitgroup.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								sync/waitgroup.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,69 @@ | |||||||
|  | package sync | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"sync" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type WaitGroup struct { | ||||||
|  | 	wg *sync.WaitGroup | ||||||
|  | 	c  int | ||||||
|  | 	mu sync.Mutex | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup { | ||||||
|  | 	g := &WaitGroup{ | ||||||
|  | 		wg: wg, | ||||||
|  | 	} | ||||||
|  | 	return g | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewWaitGroup() *WaitGroup { | ||||||
|  | 	var wg sync.WaitGroup | ||||||
|  | 	return WrapWaitGroup(&wg) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (g *WaitGroup) Add(n int) { | ||||||
|  | 	g.mu.Lock() | ||||||
|  | 	g.c += n | ||||||
|  | 	g.wg.Add(n) | ||||||
|  | 	g.mu.Unlock() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (g *WaitGroup) Done() { | ||||||
|  | 	g.mu.Lock() | ||||||
|  | 	g.c += -1 | ||||||
|  | 	g.wg.Add(-1) | ||||||
|  | 	g.mu.Unlock() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (g *WaitGroup) Wait() { | ||||||
|  | 	g.wg.Wait() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (g *WaitGroup) WaitContext(ctx context.Context) { | ||||||
|  | 	done := make(chan struct{}) | ||||||
|  | 	go func() { | ||||||
|  | 		g.wg.Wait() | ||||||
|  | 		close(done) | ||||||
|  | 	}() | ||||||
|  |  | ||||||
|  | 	select { | ||||||
|  | 	case <-ctx.Done(): | ||||||
|  | 		g.mu.Lock() | ||||||
|  | 		g.wg.Add(-g.c) | ||||||
|  | 		<-done | ||||||
|  | 		g.wg.Add(g.c) | ||||||
|  | 		g.mu.Unlock() | ||||||
|  | 		return | ||||||
|  | 	case <-done: | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (g *WaitGroup) Waiters() int { | ||||||
|  | 	g.mu.Lock() | ||||||
|  | 	c := g.c | ||||||
|  | 	g.mu.Unlock() | ||||||
|  | 	return c | ||||||
|  | } | ||||||
							
								
								
									
										37
									
								
								sync/waitgroup_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								sync/waitgroup_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,37 @@ | |||||||
|  | package sync | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestWaitGroupContext(t *testing.T) { | ||||||
|  | 	wg := NewWaitGroup() | ||||||
|  | 	_ = t | ||||||
|  | 	wg.Add(1) | ||||||
|  | 	ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second) | ||||||
|  | 	defer cancel() | ||||||
|  | 	wg.WaitContext(ctx) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestWaitGroupReuse(t *testing.T) { | ||||||
|  | 	wg := NewWaitGroup() | ||||||
|  | 	defer func() { | ||||||
|  | 		if wg.Waiters() != 0 { | ||||||
|  | 			t.Fatal("lost goroutines") | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  |  | ||||||
|  | 	wg.Add(1) | ||||||
|  | 	defer wg.Done() | ||||||
|  | 	ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second) | ||||||
|  | 	defer cancel() | ||||||
|  | 	wg.WaitContext(ctx) | ||||||
|  |  | ||||||
|  | 	wg.Add(1) | ||||||
|  | 	defer wg.Done() | ||||||
|  | 	ctx, cancel = context.WithTimeout(context.TODO(), 1*time.Second) | ||||||
|  | 	defer cancel() | ||||||
|  | 	wg.WaitContext(ctx) | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user