package sync

import (
	"context"
	"sync"
)

// WaitGroup wraps a *sync.WaitGroup and additionally records the net delta
// that has been applied to it through this wrapper's Add and Done methods.
type WaitGroup struct {
	wg *sync.WaitGroup // underlying group that every operation is forwarded to
	c  int             // net delta applied through this wrapper (Add and Done)
	mu sync.Mutex      // guards c and the updates made to wg through the wrapper
}

// WrapWaitGroup returns a WaitGroup that forwards to wg.
//
// The wrapper's own counter starts at zero regardless of wg's current state,
// so deltas applied to wg before wrapping (or applied to wg directly
// afterwards) are not reflected in it.
//
// A nil wg is replaced by a fresh sync.WaitGroup so that the returned wrapper
// is always usable instead of panicking on first use.
func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup {
	if wg == nil {
		wg = new(sync.WaitGroup)
	}
	return &WaitGroup{wg: wg}
}

// NewWaitGroup allocates a new sync.WaitGroup and returns it wrapped via
// WrapWaitGroup.
func NewWaitGroup() *WaitGroup {
	return WrapWaitGroup(new(sync.WaitGroup))
}

// Add applies delta n, which may be negative, to the wrapped sync.WaitGroup
// and to the wrapper's own counter.
//
// The wrapped group is updated first: sync.WaitGroup.Add panics if its
// counter would become negative, and in that case the wrapper's counter must
// stay untouched. The unlock is deferred so that such a panic does not leave
// the mutex held.
func (g *WaitGroup) Add(n int) {
	g.mu.Lock()
	defer g.mu.Unlock()

	g.wg.Add(n)
	g.c += n
}

// Done decrements the wrapped sync.WaitGroup and the wrapper's own counter by
// one.
//
// The wrapped group is decremented first: sync.WaitGroup.Add panics if its
// counter would become negative, and in that case the wrapper's counter must
// stay untouched. The unlock is deferred so that such a panic does not leave
// the mutex held.
func (g *WaitGroup) Done() {
	g.mu.Lock()
	defer g.mu.Unlock()

	g.wg.Add(-1)
	g.c--
}

// Wait blocks until the wrapped sync.WaitGroup's Wait returns.
func (g *WaitGroup) Wait() {
	g.wg.Wait()
}

// WaitContext blocks until the wrapped sync.WaitGroup's Wait returns or ctx is
// done, whichever happens first. It does not report which of the two occurred;
// callers that need to know can inspect ctx.Err() afterwards.
//
// Cancellation never modifies the group's counter. Temporarily driving the
// shared counter to zero in order to unblock the helper goroutine would also
// release every other goroutine blocked in Wait, and re-adding the count
// afterwards can race with those waiters, which sync.WaitGroup reports by
// panicking. The cost of leaving the counter alone is that, after a
// cancellation, the helper goroutine stays parked in Wait until the group
// eventually reaches zero.
func (g *WaitGroup) WaitContext(ctx context.Context) {
	done := make(chan struct{})
	go func() {
		g.wg.Wait()
		close(done)
	}()

	select {
	case <-done:
	case <-ctx.Done():
	}
}

// Waiters reports the wrapper's current counter, i.e. the net delta applied
// through Add and Done. Note that this is the tracked count, not the number of
// goroutines currently blocked in Wait.
func (g *WaitGroup) Waiters() int {
	g.mu.Lock()
	defer g.mu.Unlock()
	return g.c
}