sync/waitgroup: initial sync.WaitGroup wrapper with context support #319
							
								
								
									
										73
									
								
								sync/waitgroup.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										73
									
								
								sync/waitgroup.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,73 @@
 | 
				
			|||||||
 | 
					package sync
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"sync/atomic"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type WaitGroup struct {
 | 
				
			||||||
 | 
						wg    *sync.WaitGroup
 | 
				
			||||||
 | 
						drain *atomic.Bool
 | 
				
			||||||
 | 
						c     *atomic.Int64
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup {
 | 
				
			||||||
 | 
						g := &WaitGroup{
 | 
				
			||||||
 | 
							wg:    wg,
 | 
				
			||||||
 | 
							c:     &(atomic.Int64{}),
 | 
				
			||||||
 | 
							drain: &(atomic.Bool{}),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return g
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewWaitGroup() *WaitGroup {
 | 
				
			||||||
 | 
						var wg sync.WaitGroup
 | 
				
			||||||
 | 
						g := &WaitGroup{
 | 
				
			||||||
 | 
							wg:    &wg,
 | 
				
			||||||
 | 
							c:     &(atomic.Int64{}),
 | 
				
			||||||
 | 
							drain: &(atomic.Bool{}),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return g
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (g *WaitGroup) Add(n int) {
 | 
				
			||||||
 | 
						g.c.Add(int64(n))
 | 
				
			||||||
 | 
						g.wg.Add(n)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (g *WaitGroup) Done() {
 | 
				
			||||||
 | 
						if g.drain.Load() {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						g.c.Add(int64(-1))
 | 
				
			||||||
 | 
						g.wg.Done()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (g *WaitGroup) Wait() {
 | 
				
			||||||
 | 
						g.wg.Wait()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (g *WaitGroup) WaitContext(ctx context.Context) {
 | 
				
			||||||
 | 
						done := make(chan struct{})
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							g.wg.Wait()
 | 
				
			||||||
 | 
							close(done)
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						select {
 | 
				
			||||||
 | 
						case <-ctx.Done():
 | 
				
			||||||
 | 
							g.drain.Swap(true)
 | 
				
			||||||
 | 
							for g.c.Load() > 0 {
 | 
				
			||||||
 | 
								select {
 | 
				
			||||||
 | 
								case <-done:
 | 
				
			||||||
 | 
									g.drain.Swap(false)
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								default:
 | 
				
			||||||
 | 
									g.wg.Done()
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						case <-done:
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										16
									
								
								sync/waitgroup_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								sync/waitgroup_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,16 @@
 | 
				
			|||||||
 | 
					package sync
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestWaitGroup(t *testing.T) {
 | 
				
			||||||
 | 
						wg := NewWaitGroup()
 | 
				
			||||||
 | 
						_ = t
 | 
				
			||||||
 | 
						wg.Add(1)
 | 
				
			||||||
 | 
						ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second)
 | 
				
			||||||
 | 
						defer cancel()
 | 
				
			||||||
 | 
						wg.WaitContext(ctx)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
		Reference in New Issue
	
	Block a user