diff --git a/segmentio.go b/segmentio.go
index f736020..2add790 100644
--- a/segmentio.go
+++ b/segmentio.go
@@ -5,6 +5,7 @@ import (
 	"context"
 	"fmt"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/google/uuid"
@@ -31,7 +32,6 @@ type subscriber struct {
 	topic      string
 	opts       broker.SubscribeOptions
 	closed     bool
-	done       chan struct{}
 	group      *kafka.ConsumerGroup
 	cgcfg      kafka.ConsumerGroupConfig
 	brokerOpts broker.Options
@@ -39,14 +39,14 @@ type subscriber struct {
 }
 
 type publication struct {
-	topic     string
-	partition int
-	offset    int64
-	err       error
-	ackErr    *error
-	msg       *broker.Message
-	ackCh     chan map[string]map[int]int64
-	sync.Mutex
+	topic      string
+	partition  int
+	offset     int64
+	err        error
+	ackErr     atomic.Value
+	msg        *broker.Message
+	ackCh      chan map[string]map[int]int64
+	readerDone *int32
 }
 
 func (p *publication) Topic() string {
@@ -58,8 +58,14 @@ func (p *publication) Message() *broker.Message {
 }
 
 func (p *publication) Ack() error {
+	if atomic.LoadInt32(p.readerDone) == 1 {
+		return kafka.ErrGroupClosed
+	}
 	p.ackCh <- map[string]map[int]int64{p.topic: {p.partition: p.offset}}
-	return *p.ackErr
+	if cerr := p.ackErr.Load(); cerr != nil {
+		return cerr.(error)
+	}
+	return nil
 }
 
 func (p *publication) Error() error {
@@ -79,7 +85,6 @@ func (s *subscriber) Unsubscribe(ctx context.Context) error {
 	s.Lock()
 	s.closed = true
 	group := s.group
-	close(s.done)
 	s.Unlock()
 
 	if group != nil {
@@ -266,26 +271,42 @@ func (k *kBroker) Subscribe(ctx context.Context, topic string, handler broker.Ha
 		StartOffset: k.readerConfig.StartOffset,
 		Logger:      k.readerConfig.Logger,
 		ErrorLogger: k.readerConfig.ErrorLogger,
+		Dialer:      k.readerConfig.Dialer,
 	}
 	if err := cgcfg.Validate(); err != nil {
 		return nil, err
 	}
-
-	cgroup, err := kafka.NewConsumerGroup(cgcfg)
-	if err != nil {
-		return nil, err
+	gCtx := k.opts.Context
+	if ctx != nil {
+		gCtx = ctx
 	}
-	sub := &subscriber{brokerOpts: k.opts, opts: opt, topic: topic, group: cgroup, cgcfg: cgcfg, done: make(chan struct{})}
+	sub := &subscriber{brokerOpts: k.opts, opts: opt, topic: topic, cgcfg: cgcfg}
+	sub.createGroup(gCtx)
+
 	go func() {
+		defer func() {
+			sub.RLock()
+			closed := sub.closed
+			sub.RUnlock()
+			if !closed {
+				if err := sub.group.Close(); err != nil {
+					k.opts.Logger.Errorf(k.opts.Context, "[segmentio] consumer group close error %v", err)
+				}
+			}
+		}()
+
 		for {
 			select {
-			case <-sub.done:
-				return
 			case <-ctx.Done():
-				// unexpected context closed
+				sub.RLock()
+				closed := sub.closed
+				sub.RUnlock()
+				if closed {
+					return
+				}
 				if k.opts.Context.Err() != nil && k.opts.Logger.V(logger.ErrorLevel) {
-					k.opts.Logger.Errorf(k.opts.Context, "[segmentio] context closed unexpected %v", k.opts.Context.Err())
+					k.opts.Logger.Errorf(k.opts.Context, "[segmentio] subscribe context closed %v", k.opts.Context.Err())
 				}
 				return
 			case <-k.opts.Context.Done():
@@ -293,77 +314,67 @@ func (k *kBroker) Subscribe(ctx context.Context, topic string, handler broker.Ha
 				closed := sub.closed
 				sub.RUnlock()
 				if closed {
-					// unsubcribed and closed
 					return
 				}
-				// unexpected context closed
 				if k.opts.Context.Err() != nil && k.opts.Logger.V(logger.ErrorLevel) {
-					k.opts.Logger.Errorf(k.opts.Context, "[segmentio] context closed unexpected %v", k.opts.Context.Err())
+					k.opts.Logger.Errorf(k.opts.Context, "[segmentio] broker context closed error %v", k.opts.Context.Err())
 				}
 				return
 			default:
 				sub.RLock()
-				group := sub.group
 				closed := sub.closed
 				sub.RUnlock()
 				if closed {
 					return
 				}
-				gCtx := k.opts.Context
-				if ctx != nil {
-					gCtx = ctx
-				}
-				generation, err := group.Next(gCtx)
+				generation, err := sub.group.Next(gCtx)
 				switch err {
 				case nil:
 					// normal execution
 				case kafka.ErrGroupClosed:
+					k.opts.Logger.Tracef(k.opts.Context, "group closed %v", err)
 					sub.RLock()
 					closed := sub.closed
 					sub.RUnlock()
-					if !closed {
-						if k.opts.Logger.V(logger.ErrorLevel) {
-							k.opts.Logger.Errorf(k.opts.Context, "[segmentio] recreate consumer group, as it closed by kafka %v", k.opts.Context.Err())
-						}
-						if err = group.Close(); err != nil && k.opts.Logger.V(logger.ErrorLevel) {
-							k.opts.Logger.Errorf(k.opts.Context, "[segmentio] consumer group close error %v", err)
-							continue
-						}
-						sub.createGroup(gCtx)
-						continue
+					if closed {
+						return
 					}
-					return
+					if k.opts.Logger.V(logger.ErrorLevel) {
+						k.opts.Logger.Errorf(k.opts.Context, "[segmentio] recreate consumer group, as it closed by kafka %v", k.opts.Context.Err())
+					}
+					sub.createGroup(gCtx)
+					continue
 				default:
+					k.opts.Logger.Tracef(k.opts.Context, "some error: %v", err)
 					sub.RLock()
 					closed := sub.closed
 					sub.RUnlock()
-					if !closed {
-						if k.opts.Logger.V(logger.TraceLevel) {
-							k.opts.Logger.Tracef(k.opts.Context, "[segmentio] recreate consumer group, as unexpected consumer error %T %v", err, err)
-						}
+					if closed {
+						return
 					}
-					if err = group.Close(); err != nil && k.opts.Logger.V(logger.ErrorLevel) {
-						k.opts.Logger.Errorf(k.opts.Context, "[segmentio] consumer group close error %v", err)
+					if k.opts.Logger.V(logger.TraceLevel) {
+						k.opts.Logger.Tracef(k.opts.Context, "[segmentio] recreate consumer group, as unexpected consumer error %T %v", err, err)
 					}
 					sub.createGroup(gCtx)
 					continue
 				}
 
-				var wg sync.WaitGroup
-
 				ackCh := make(chan map[string]map[int]int64, DefaultCommitQueueSize)
 				errChLen := 0
 				for _, assignments := range generation.Assignments {
 					errChLen += len(assignments)
 				}
-				errChs := make([]chan error, errChLen)
+				errChs := make([]chan error, 0, errChLen)
+
+				commitDoneCh := make(chan bool)
+				readerDone := int32(0)
+				cntWait := int32(0)
 
 				for topic, assignments := range generation.Assignments {
 					if k.opts.Logger.V(logger.TraceLevel) {
 						k.opts.Logger.Tracef(k.opts.Context, "topic: %s assignments: %v", topic, assignments)
 					}
 					for _, assignment := range assignments {
-						errCh := make(chan error)
 						cfg := k.readerConfig
 						cfg.Topic = topic
 						cfg.Partition = assignment.ID
@@ -374,9 +385,27 @@ func (k *kBroker) Subscribe(ctx context.Context, topic string, handler broker.Ha
 							if k.opts.Logger.V(logger.ErrorLevel) {
 								k.opts.Logger.Errorf(k.opts.Context, "assignments offset %d can be set by reader: %v", assignment.Offset, err)
 							}
+							if err = reader.Close(); err != nil {
+								if k.opts.Logger.V(logger.ErrorLevel) {
+									k.opts.Logger.Errorf(k.opts.Context, "reader close err: %v", err)
+								}
+							}
 							continue
 						}
-						cgh := &cgHandler{brokerOpts: k.opts, subOpts: opt, reader: reader, handler: handler, ackCh: ackCh, errCh: errCh, wg: &wg}
+						errCh := make(chan error)
+						errChs = append(errChs, errCh)
+						cgh := &cgHandler{
+							brokerOpts:   k.opts,
+							subOpts:      opt,
+							reader:       reader,
+							handler:      handler,
+							ackCh:        ackCh,
+							errCh:        errCh,
+							cntWait:      &cntWait,
+							readerDone:   &readerDone,
+							commitDoneCh: commitDoneCh,
+						}
+						atomic.AddInt32(cgh.cntWait, 1)
 						generation.Start(cgh.run)
 					}
 				}
@@ -384,7 +413,7 @@ func (k *kBroker) Subscribe(ctx context.Context, topic string, handler broker.Ha
 					k.opts.Logger.Trace(k.opts.Context, "start async commit loop")
 				}
 				// run async commit loop
-				go k.commitLoop(generation, k.readerConfig.CommitInterval, ackCh, errChs, &wg)
+				go k.commitLoop(generation, k.readerConfig.CommitInterval, ackCh, errChs, &readerDone, commitDoneCh, &cntWait)
 			}
 		}
 	}()
@@ -393,16 +422,22 @@ func (k *kBroker) Subscribe(ctx context.Context, topic string, handler broker.Ha
 }
 
 type cgHandler struct {
-	brokerOpts broker.Options
-	subOpts    broker.SubscribeOptions
-	reader     *kafka.Reader
-	handler    broker.Handler
-	ackCh      chan map[string]map[int]int64
-	errCh      chan error
-	wg         *sync.WaitGroup
+	brokerOpts   broker.Options
+	subOpts      broker.SubscribeOptions
+	reader       *kafka.Reader
+	handler      broker.Handler
+	ackCh        chan map[string]map[int]int64
+	errCh        chan error
+	readerDone   *int32
+	commitDoneCh chan bool
+	cntWait      *int32
 }
 
-func (k *kBroker) commitLoop(generation *kafka.Generation, commitInterval time.Duration, ackCh chan map[string]map[int]int64, errChs []chan error, wg *sync.WaitGroup) {
+func (k *kBroker) commitLoop(generation *kafka.Generation, commitInterval time.Duration, ackCh chan map[string]map[int]int64, errChs []chan error, readerDone *int32, commitDoneCh chan bool, cntWait *int32) {
+
+	if k.opts.Logger.V(logger.TraceLevel) {
+		k.opts.Logger.Trace(k.opts.Context, "start commit loop")
+	}
 
 	td := DefaultCommitInterval
 
@@ -414,22 +449,101 @@ func (k *kBroker) commitLoop(generation *kafka.Generation, commitInterval time.D
 		td = v
 	}
 
-	// async commit loop
-	if td > 0 {
-		ticker := time.NewTicker(td)
-		defer ticker.Stop()
+	var mapMu sync.Mutex
+	offsets := make(map[string]map[int]int64, 4)
 
-		var mapMu sync.Mutex
-		offsets := make(map[string]map[int]int64, 4)
+	go func() {
+		defer func() {
+			close(commitDoneCh)
+		}()
+
+		checkTicker := time.NewTicker(300 * time.Millisecond)
+		defer checkTicker.Stop()
 
 		for {
 			select {
-			default:
-				wg.Wait()
-				if k.opts.Logger.V(logger.TraceLevel) {
-					k.opts.Logger.Trace(k.opts.Context, "all readers are done, return from commit loop")
+			case <-checkTicker.C:
+				if atomic.LoadInt32(cntWait) == 0 {
+					mapMu.Lock()
+					if len(offsets) > 0 {
+						if err := generation.CommitOffsets(offsets); err != nil {
+							for _, errCh := range errChs {
+								errCh <- err
+							}
+							return
+						}
+					}
+					mapMu.Unlock()
+					if k.opts.Logger.V(logger.TraceLevel) {
+						k.opts.Logger.Trace(k.opts.Context, "stop commit loop")
+					}
+					return
+				}
+			case ack := <-ackCh:
+				if k.opts.Logger.V(logger.TraceLevel) {
+					k.opts.Logger.Tracef(k.opts.Context, "new commit offsets: %v", ack)
+				}
+				switch td {
+				case 0: // sync commits as CommitInterval == 0
+					if len(ack) > 0 {
+						err := generation.CommitOffsets(ack)
+						if err != nil {
+							for _, errCh := range errChs {
+								errCh <- err
+							}
+							return
+						}
+					}
+				default: // async commits as CommitInterval > 0
+					mapMu.Lock()
+					for t, p := range ack {
+						if _, ok := offsets[t]; !ok {
+							offsets[t] = make(map[int]int64, 4)
+						}
+						for k, v := range p {
+							offsets[t][k] = v
+						}
+					}
+					mapMu.Unlock()
+				}
+				// check for readers done and commit offsets
+				if atomic.LoadInt32(cntWait) == 0 {
+					mapMu.Lock()
+					if len(offsets) > 0 {
+						if err := generation.CommitOffsets(offsets); err != nil {
+							for _, errCh := range errChs {
+								errCh <- err
+							}
+							return
+						}
+					}
+					mapMu.Unlock()
+					if k.opts.Logger.V(logger.TraceLevel) {
+						k.opts.Logger.Trace(k.opts.Context, "stop commit loop")
+					}
+					return
+				}
+			}
+		}
+	}()
+
+	// async commit loop
+	if td > 0 {
+		ticker := time.NewTicker(td)
+		doneTicker := time.NewTicker(300 * time.Millisecond)
+		defer doneTicker.Stop()
+
+		for {
+			select {
+			case <-doneTicker.C:
+				if atomic.LoadInt32(readerDone) == 1 {
+					mapMu.Lock()
+					if len(offsets) == 0 {
+						defer ticker.Stop()
+						return
+					}
+					ticker.Stop()
 				}
-				return
 			case <-ticker.C:
 				mapMu.Lock()
 				if len(offsets) == 0 {
@@ -443,74 +557,56 @@ func (k *kBroker) commitLoop(generation *kafka.Generation, commitInterval time.D
 				if err != nil {
 					for _, errCh := range errChs {
 						errCh <- err
-						close(errCh)
 					}
 					mapMu.Unlock()
 					return
 				}
-				mapMu.Unlock()
 				offsets = make(map[string]map[int]int64, 4)
-			}
-		}
-	}
-
-	// sync commit loop
-	for {
-		select {
-		default:
-			wg.Wait()
-			if k.opts.Logger.V(logger.TraceLevel) {
-				k.opts.Logger.Trace(k.opts.Context, "all readers are done, return from commit loop")
-			}
-			return
-		case ack := <-ackCh:
-			if k.opts.Logger.V(logger.TraceLevel) {
-				k.opts.Logger.Tracef(k.opts.Context, "sync commit offsets: %v", ack)
-			}
-			err := generation.CommitOffsets(ack)
-			if err != nil {
-				for _, errCh := range errChs {
-					errCh <- err
-					close(errCh)
+				mapMu.Unlock()
+				if atomic.LoadInt32(readerDone) == 1 && atomic.LoadInt32(cntWait) == 0 {
+					return
 				}
-				return
 			}
 		}
 	}
-
 }
 
 func (h *cgHandler) run(ctx context.Context) {
+	if h.brokerOpts.Logger.V(logger.TraceLevel) {
+		h.brokerOpts.Logger.Trace(ctx, "start partition reader")
+	}
+
 	td := DefaultStatsInterval
 	if v, ok := h.brokerOpts.Context.Value(statsIntervalKey{}).(time.Duration); ok && td > 0 {
 		td = v
 	}
 
+	// start stats loop
 	go readerStats(ctx, h.reader, td, h.brokerOpts.Meter)
 
-	commitDuration := DefaultCommitInterval
-	if v, ok := h.brokerOpts.Context.Value(commitIntervalKey{}).(time.Duration); ok && td > 0 {
-		commitDuration = v
-	}
-
-	var commitErr error
-
-	h.wg.Add(1)
+	var commitErr atomic.Value
 	defer func() {
-		h.wg.Done()
+		atomic.AddInt32(h.cntWait, -1)
+
+		atomic.CompareAndSwapInt32(h.readerDone, 0, 1)
 		if err := h.reader.Close(); err != nil && h.brokerOpts.Logger.V(logger.ErrorLevel) {
 			h.brokerOpts.Logger.Errorf(h.brokerOpts.Context, "[segmentio] reader close error: %v", err)
 		}
+		<-h.commitDoneCh
+		if h.brokerOpts.Logger.V(logger.TraceLevel) {
+			h.brokerOpts.Logger.Trace(ctx, "stop partition reader")
+		}
 	}()
 
 	go func() {
 		for {
			select {
 			case err := <-h.errCh:
-				commitErr = err
+				if err != nil {
+					commitErr.Store(err)
+				}
 			case <-ctx.Done():
-				time.Sleep(commitDuration)
 				return
 			}
 		}
@@ -527,16 +623,23 @@ func (h *cgHandler) run(ctx context.Context) {
 		case kafka.ErrGenerationEnded:
 			// generation has ended
 			if h.brokerOpts.Logger.V(logger.TraceLevel) {
-				h.brokerOpts.Logger.Trace(h.brokerOpts.Context, "[segmentio] generation ended, rebalance")
+				h.brokerOpts.Logger.Trace(h.brokerOpts.Context, "[segmentio] generation ended, rebalance or close")
 			}
 			return
 		case nil:
+			if cerr := commitErr.Load(); cerr != nil {
+				if h.brokerOpts.Logger.V(logger.ErrorLevel) {
+					h.brokerOpts.Logger.Errorf(h.brokerOpts.Context, "[segmentio] commit error: %v", cerr)
+				}
+				return
+			}
+
 			eh := h.brokerOpts.ErrorHandler
 			if h.subOpts.ErrorHandler != nil {
 				eh = h.subOpts.ErrorHandler
 			}
-			p := &publication{ackCh: h.ackCh, partition: msg.Partition, offset: msg.Offset, topic: msg.Topic, msg: &broker.Message{}}
+			p := &publication{ackCh: h.ackCh, partition: msg.Partition, offset: msg.Offset + 1, topic: msg.Topic, msg: &broker.Message{}, readerDone: h.readerDone}
 			if h.subOpts.BodyOnly {
 				p.msg.Body = msg.Value
@@ -554,14 +657,14 @@ func (h *cgHandler) run(ctx context.Context) {
 					continue
 				}
 			}
-			p.Lock()
-			p.ackErr = &commitErr
-			p.Unlock()
+			if cerr := commitErr.Load(); cerr != nil {
+				p.ackErr.Store(cerr)
+			}
 			err = h.handler(p)
 			if err == nil && h.subOpts.AutoAck {
 				if err = p.Ack(); err != nil {
 					if h.brokerOpts.Logger.V(logger.ErrorLevel) {
-						h.brokerOpts.Logger.Errorf(h.brokerOpts.Context, "[segmentio]: unable to commit msg: %v", err)
+						h.brokerOpts.Logger.Errorf(h.brokerOpts.Context, "[segmentio]: message ack error: %v", err)
 						return
 					}
 				}
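
For readers outside the diff context, below is a minimal, self-contained Go sketch of the readerDone/ackErr pattern this change introduces: an atomic flag lets Ack() fail fast once the partition reader has shut down instead of blocking on ackCh, and an atomic.Value carries the last commit error back to the caller. The types and channel shapes here are simplified stand-ins (errGroupClosed replaces kafka.ErrGroupClosed, ackCh carries a bare offset), not the plugin's actual API.

// Sketch only: simplified stand-ins for the broker's publication/Ack flow.
package main

import (
	"errors"
	"fmt"
	"sync/atomic"
)

// errGroupClosed stands in for kafka.ErrGroupClosed in the real code.
var errGroupClosed = errors.New("consumer group is closed")

type publication struct {
	ackCh      chan int64   // simplified: the real code sends map[string]map[int]int64
	offset     int64
	ackErr     atomic.Value // holds the last commit error, if any
	readerDone *int32       // set to 1 when the partition reader has stopped
}

func (p *publication) Ack() error {
	// If the reader is already gone, fail fast instead of blocking on ackCh.
	if atomic.LoadInt32(p.readerDone) == 1 {
		return errGroupClosed
	}
	p.ackCh <- p.offset
	// Surface the most recent commit error reported by the commit loop, if any.
	if cerr := p.ackErr.Load(); cerr != nil {
		return cerr.(error)
	}
	return nil
}

func main() {
	readerDone := int32(1) // pretend the reader already closed
	p := &publication{ackCh: make(chan int64, 1), offset: 42, readerDone: &readerDone}
	fmt.Println(p.Ack()) // prints: consumer group is closed
}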