sync/waitgroup: initial sync.WaitGroup wrapper with context support #319
							
								
								
									
										73
									
								
								sync/waitgroup.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										73
									
								
								sync/waitgroup.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,73 @@
 | 
			
		||||
package sync
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type WaitGroup struct {
 | 
			
		||||
	wg    *sync.WaitGroup
 | 
			
		||||
	drain *atomic.Bool
 | 
			
		||||
	c     *atomic.Int64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup {
 | 
			
		||||
	g := &WaitGroup{
 | 
			
		||||
		wg:    wg,
 | 
			
		||||
		c:     &(atomic.Int64{}),
 | 
			
		||||
		drain: &(atomic.Bool{}),
 | 
			
		||||
	}
 | 
			
		||||
	return g
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewWaitGroup() *WaitGroup {
 | 
			
		||||
	var wg sync.WaitGroup
 | 
			
		||||
	g := &WaitGroup{
 | 
			
		||||
		wg:    &wg,
 | 
			
		||||
		c:     &(atomic.Int64{}),
 | 
			
		||||
		drain: &(atomic.Bool{}),
 | 
			
		||||
	}
 | 
			
		||||
	return g
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (g *WaitGroup) Add(n int) {
 | 
			
		||||
	g.c.Add(int64(n))
 | 
			
		||||
	g.wg.Add(n)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (g *WaitGroup) Done() {
 | 
			
		||||
	if g.drain.Load() {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	g.c.Add(int64(-1))
 | 
			
		||||
	g.wg.Done()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (g *WaitGroup) Wait() {
 | 
			
		||||
	g.wg.Wait()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (g *WaitGroup) WaitContext(ctx context.Context) {
 | 
			
		||||
	done := make(chan struct{})
 | 
			
		||||
	go func() {
 | 
			
		||||
		g.wg.Wait()
 | 
			
		||||
		close(done)
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	select {
 | 
			
		||||
	case <-ctx.Done():
 | 
			
		||||
		g.drain.Swap(true)
 | 
			
		||||
		for g.c.Load() > 0 {
 | 
			
		||||
			select {
 | 
			
		||||
			case <-done:
 | 
			
		||||
				g.drain.Swap(false)
 | 
			
		||||
				return
 | 
			
		||||
			default:
 | 
			
		||||
				g.wg.Done()
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	case <-done:
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										16
									
								
								sync/waitgroup_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								sync/waitgroup_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,16 @@
 | 
			
		||||
package sync
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestWaitGroup(t *testing.T) {
 | 
			
		||||
	wg := NewWaitGroup()
 | 
			
		||||
	_ = t
 | 
			
		||||
	wg.Add(1)
 | 
			
		||||
	ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second)
 | 
			
		||||
	defer cancel()
 | 
			
		||||
	wg.WaitContext(ctx)
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user