sync/waitgroup: backport from master #320
							
								
								
									
										69
									
								
								sync/waitgroup.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								sync/waitgroup.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,69 @@ | ||||
| package sync | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"sync" | ||||
| ) | ||||
|  | ||||
| type WaitGroup struct { | ||||
| 	wg *sync.WaitGroup | ||||
| 	c  int | ||||
| 	mu sync.Mutex | ||||
| } | ||||
|  | ||||
| func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup { | ||||
| 	g := &WaitGroup{ | ||||
| 		wg: wg, | ||||
| 	} | ||||
| 	return g | ||||
| } | ||||
|  | ||||
| func NewWaitGroup() *WaitGroup { | ||||
| 	var wg sync.WaitGroup | ||||
| 	return WrapWaitGroup(&wg) | ||||
| } | ||||
|  | ||||
| func (g *WaitGroup) Add(n int) { | ||||
| 	g.mu.Lock() | ||||
| 	g.c += n | ||||
| 	g.wg.Add(n) | ||||
| 	g.mu.Unlock() | ||||
| } | ||||
|  | ||||
| func (g *WaitGroup) Done() { | ||||
| 	g.mu.Lock() | ||||
| 	g.c += -1 | ||||
| 	g.wg.Add(-1) | ||||
| 	g.mu.Unlock() | ||||
| } | ||||
|  | ||||
| func (g *WaitGroup) Wait() { | ||||
| 	g.wg.Wait() | ||||
| } | ||||
|  | ||||
| func (g *WaitGroup) WaitContext(ctx context.Context) { | ||||
| 	done := make(chan struct{}) | ||||
| 	go func() { | ||||
| 		g.wg.Wait() | ||||
| 		close(done) | ||||
| 	}() | ||||
|  | ||||
| 	select { | ||||
| 	case <-ctx.Done(): | ||||
| 		g.mu.Lock() | ||||
| 		g.wg.Add(-g.c) | ||||
| 		<-done | ||||
| 		g.wg.Add(g.c) | ||||
| 		g.mu.Unlock() | ||||
| 		return | ||||
| 	case <-done: | ||||
| 		return | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (g *WaitGroup) Waiters() int { | ||||
| 	g.mu.Lock() | ||||
| 	c := g.c | ||||
| 	g.mu.Unlock() | ||||
| 	return c | ||||
| } | ||||
							
								
								
									
										37
									
								
								sync/waitgroup_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								sync/waitgroup_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,37 @@ | ||||
| package sync | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func TestWaitGroupContext(t *testing.T) { | ||||
| 	wg := NewWaitGroup() | ||||
| 	_ = t | ||||
| 	wg.Add(1) | ||||
| 	ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second) | ||||
| 	defer cancel() | ||||
| 	wg.WaitContext(ctx) | ||||
| } | ||||
|  | ||||
| func TestWaitGroupReuse(t *testing.T) { | ||||
| 	wg := NewWaitGroup() | ||||
| 	defer func() { | ||||
| 		if wg.Waiters() != 0 { | ||||
| 			t.Fatal("lost goroutines") | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	wg.Add(1) | ||||
| 	defer wg.Done() | ||||
| 	ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second) | ||||
| 	defer cancel() | ||||
| 	wg.WaitContext(ctx) | ||||
|  | ||||
| 	wg.Add(1) | ||||
| 	defer wg.Done() | ||||
| 	ctx, cancel = context.WithTimeout(context.TODO(), 1*time.Second) | ||||
| 	defer cancel() | ||||
| 	wg.WaitContext(ctx) | ||||
| } | ||||
		Reference in New Issue
	
	Block a user