diff --git a/push.go b/push.go index c62c823..15b1058 100644 --- a/push.go +++ b/push.go @@ -31,6 +31,9 @@ type PushOptions struct { // // By default the compression is enabled. DisableCompression bool + + // Optional WaitGroup for waiting until all the push workers created with this WaitGroup are stopped. + WaitGroup *sync.WaitGroup } // InitPushWithOptions sets up periodic push for globally registered metrics to the given pushURL with the given interval. @@ -207,6 +210,13 @@ func InitPushExtWithOptions(ctx context.Context, pushURL string, interval time.D } pushMetricsSet.GetOrCreateFloatCounter(fmt.Sprintf(`metrics_push_interval_seconds{url=%q}`, pc.pushURLRedacted)).Set(interval.Seconds()) + var wg *sync.WaitGroup + if opts != nil { + wg = opts.WaitGroup + if wg != nil { + wg.Add(1) + } + } go func() { ticker := time.NewTicker(interval) defer ticker.Stop() @@ -221,6 +231,9 @@ func InitPushExtWithOptions(ctx context.Context, pushURL string, interval time.D log.Printf("ERROR: metrics.push: %s", err) } case <-stopCh: + if wg != nil { + wg.Done() + } return } } diff --git a/push_test.go b/push_test.go index dd5376a..45029d8 100644 --- a/push_test.go +++ b/push_test.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "net/http/httptest" + "sync" "testing" "time" ) @@ -91,6 +92,10 @@ func TestInitPushWithOptions(t *testing.T) { })) defer srv.Close() ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + if opts != nil { + opts.WaitGroup = &wg + } if err := s.InitPushWithOptions(ctx, srv.URL, time.Millisecond, opts); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -100,6 +105,7 @@ func TestInitPushWithOptions(t *testing.T) { case <-doneCh: // stop the periodic pusher cancel() + wg.Wait() } if reqErr != nil { t.Fatalf("unexpected error: %s", reqErr)