package kgo

import (
	"context"
	"fmt"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"github.com/twmb/franz-go/pkg/kadm"
	"github.com/twmb/franz-go/pkg/kgo"
	"github.com/twmb/franz-go/pkg/kmsg"
	"go.unistack.org/micro/v3/broker"
	"go.unistack.org/micro/v3/logger"
	"go.unistack.org/micro/v3/metadata"
	"go.unistack.org/micro/v3/meter"
	"go.unistack.org/micro/v3/semconv"
	"go.unistack.org/micro/v3/tracer"
)

// tp identifies a single topic partition and is used as the key of
// Subscriber.consumers.
type tp struct {
	t string
	p int32
}

// consumer processes the records of one topic partition in its own goroutine
// (see consume). It receives fetched partitions through recs, stops when quit
// is closed and closes done once its goroutine has returned.
type consumer struct {
	topic     string
	c         *kgo.Client
	htracer   *hookTracer
	connected *atomic.Uint32
	quit      chan struct{}
	done      chan struct{}
	recs      chan kgo.FetchTopicPartition
	handler   broker.Handler
	kopts     broker.Options
	opts      broker.SubscribeOptions
	partition int32
}

// Subscriber polls a kgo client and dispatches fetched partitions to
// per-partition consumers. The embedded RWMutex guards the consumers map; the
// embedded WaitGroup tracks the pollLag goroutine started by poll.
type Subscriber struct {
	topic        string
	consumers    map[tp]*consumer
	c            *kgo.Client
	htracer      *hookTracer
	connected    *atomic.Uint32
	handler      broker.Handler
	done         chan struct{}
	kopts        broker.Options
	opts         broker.SubscribeOptions
	fatalOnError bool
	closed       atomic.Bool
	sync.RWMutex
	sync.WaitGroup
}

// Client returns the underlying kgo client.
func (s *Subscriber) Client() *kgo.Client {
	return s.c
}

// Options returns the subscribe options the subscriber was created with.
func (s *Subscriber) Options() broker.SubscribeOptions {
	return s.opts
}

// Topic returns the topic the subscriber is bound to.
func (s *Subscriber) Topic() string {
	return s.topic
}

// Unsubscribe stops all partition consumers, closes the done channel and
// closes the client. It is a no-op when the subscriber is already marked
// closed and always returns nil.
func (s *Subscriber) Unsubscribe(ctx context.Context) error {
	if s.closed.Load() {
		return nil
	}

	// NOTE(review): this waits for pollLag, which only returns once the context
	// passed to poll is cancelled — confirm callers cancel it before calling
	// Unsubscribe, otherwise this blocks.
	s.Wait()

	s.c.PauseFetchTopics(s.topic)
	s.c.CloseAllowingRebalance()

	// Snapshot the owned partitions under the read lock: the map is mutated
	// under s.Lock() by assigned and killConsumers. The lock is released before
	// killConsumers, which takes the write lock itself.
	s.RLock()
	kc := make(map[string][]int32)
	for ctp := range s.consumers {
		kc[ctp.t] = append(kc[ctp.t], ctp.p)
	}
	s.RUnlock()

	s.killConsumers(ctx, kc)

	close(s.done)
	s.closed.Store(true)

	s.c.ResumeFetchTopics(s.topic)
	s.c.Close()

	return nil
}

// poll starts the lag reporter and then loops, fetching records from the
// client and handing each fetched partition to its consumer. It returns when
// ctx is cancelled, when done is closed, or when the client reports that it
// has been closed.
func (s *Subscriber) poll(ctx context.Context) {
	maxInflight := DefaultSubscribeMaxInflight
	if s.opts.Context != nil {
		if n, ok := s.opts.Context.Value(subscribeMaxInflightKey{}).(int); n > 0 && ok {
			maxInflight = n
		}
	}

	s.Add(1)
	go s.pollLag(ctx)

	for {
		select {
		case <-ctx.Done():
			s.c.CloseAllowingRebalance()
			return
		case <-s.done:
			return
		default:
			fetches := s.c.PollRecords(ctx, maxInflight)
			if !s.closed.Load() && fetches.IsClientClosed() {
				s.closed.Store(true)
				return
			}

			fetches.EachError(func(t string, p int32, err error) {
				s.kopts.Logger.Fatal(ctx, fmt.Sprintf("[kgo] fetch topic %s partition %d error", t, p), err)
			})

			fetches.EachPartition(func(p kgo.FetchTopicPartition) {
				// The consumers map is written under s.Lock() elsewhere, so the
				// lookup must hold the read lock to avoid a concurrent map
				// read/write.
				s.RLock()
				pc, ok := s.consumers[tp{p.Topic, p.Partition}]
				s.RUnlock()
				if !ok {
					return
				}

				// Non-blocking hand-off.
				// NOTE(review): when the consumer channel is full the fetched
				// batch is only logged and then dropped — confirm this is
				// intended.
				select {
				case pc.recs <- p:
				default:
					if s.kopts.Logger.V(logger.WarnLevel) {
						s.kopts.Logger.Warn(ctx, fmt.Sprintf("[kgo] consumer channel full topic %s partition %d", p.Topic, p.Partition))
					}
				}
			})

			s.c.AllowRebalance()
		}
	}
}

// pollLag periodically (every DefaultStatsInterval) queries the consumer group
// lag for the subscriber topic and publishes it per partition through the
// semconv.BrokerGroupLag counter. It returns when ctx is cancelled and marks
// the subscriber WaitGroup as done on exit.
func (s *Subscriber) pollLag(ctx context.Context) {
	ac := kadm.NewClient(s.c)

	ticker := time.NewTicker(DefaultStatsInterval)
	defer func() {
		s.Done()
		ticker.Stop()
	}()

	// Cache of lag metrics keyed by partition, so a counter is created once per
	// partition and only updated when the lag value changes.
	type lagMetric struct {
		counter meter.Counter
		lastLag int64
	}
	lagCache := make(map[int32]*lagMetric)

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			dgls, err := ac.Lag(ctx, s.opts.Group)
			if err != nil || !dgls.Ok() {
				continue
			}

			dgl, ok := dgls[s.opts.Group]
			if !ok {
				continue
			}

			lmap, ok := dgl.Lag[s.topic]
			if !ok {
				continue
			}

			s.Lock()
			for p, l := range lmap {
				lagVal := l.Lag
				if lagVal < 0 {
					// A negative lag cannot be represented by the unsigned
					// counter: uint64(lagVal) would publish a huge bogus value.
					// Skip the sample and keep the last known one.
					continue
				}

				if metric, exists := lagCache[p]; exists {
					if metric.lastLag != lagVal {
						metric.counter.Set(uint64(lagVal))
						metric.lastLag = lagVal
					}
					continue
				}

				counter := s.kopts.Meter.Counter(
					semconv.BrokerGroupLag,
					"topic", s.topic,
					"group", s.opts.Group,
					"partition", strconv.Itoa(int(p)),
				)
				counter.Set(uint64(lagVal))
				lagCache[p] = &lagMetric{
					counter: counter,
					lastLag: lagVal,
				}
			}
			s.Unlock()
		}
	}
}

// killConsumers removes the consumers of the given topic partitions, signals
// them to quit and blocks until each of their goroutines has finished.
// Partitions without a registered consumer are ignored.
func (s *Subscriber) killConsumers(ctx context.Context, lost map[string][]int32) {
	var wg sync.WaitGroup

	s.Lock()
	for topic, partitions := range lost {
		for _, partition := range partitions {
			tps := tp{topic, partition}
			pc, ok := s.consumers[tps]
			if !ok {
				continue
			}
			delete(s.consumers, tps)
			close(pc.quit)

			if s.kopts.Logger.V(logger.DebugLevel) {
				s.kopts.Logger.Debug(ctx, fmt.Sprintf("[kgo] waiting for work to finish topic %s partition %d", topic, partition))
			}

			wg.Add(1)
			go func(pc *consumer) {
				defer wg.Done()
				<-pc.done
			}(pc)
		}
	}
	s.Unlock()

	// Wait outside the lock so that other users of the mutex are not blocked
	// while consumers drain.
	wg.Wait()
}

// autocommit has the shape of a kgo auto-commit callback. A commit error is
// logged at fatal level when fatalOnError is set and ignored otherwise.
func (s *Subscriber) autocommit(_ *kgo.Client, _ *kmsg.OffsetCommitRequest, _ *kmsg.OffsetCommitResponse, err error) {
	if err != nil {
		// s.connected.Store(0)
		if s.fatalOnError {
			s.kopts.Logger.Fatal(context.TODO(), "kgo.AutoCommitCallback error", err)
		}
	}
}

// lost logs the lost partitions and stops their consumers.
func (s *Subscriber) lost(ctx context.Context, _ *kgo.Client, lost map[string][]int32) {
	if s.kopts.Logger.V(logger.ErrorLevel) {
		s.kopts.Logger.Error(ctx, fmt.Sprintf("[kgo] lost %#+v", lost))
	}
	s.killConsumers(ctx, lost)
	// s.connected.Store(0)
}

// revoked stops the consumers of the revoked partitions and then commits the
// offsets that were marked so far. A commit failure is only logged.
func (s *Subscriber) revoked(ctx context.Context, c *kgo.Client, revoked map[string][]int32) {
	if s.kopts.Logger.V(logger.DebugLevel) {
		s.kopts.Logger.Debug(ctx, fmt.Sprintf("[kgo] revoked %#+v", revoked))
	}
	s.killConsumers(ctx, revoked)
	if err := c.CommitMarkedOffsets(ctx); err != nil {
		s.kopts.Logger.Error(ctx, "[kgo] revoked CommitMarkedOffsets error", err)
		// s.connected.Store(0)
	}
}

// assigned registers a consumer for every assigned topic partition and starts
// its processing goroutine.
func (s *Subscriber) assigned(_ context.Context, c *kgo.Client, assigned map[string][]int32) {
	for topic, partitions := range assigned {
		for _, partition := range partitions {
			pc := &consumer{
				c:         c,
				topic:     topic,
				partition: partition,
				htracer:   s.htracer,
				quit:      make(chan struct{}),
				done:      make(chan struct{}),
				recs:      make(chan kgo.FetchTopicPartition, 100),
				handler:   s.handler,
				kopts:     s.kopts,
				opts:      s.opts,
				connected: s.connected,
			}
			s.Lock()
			s.consumers[tp{topic, partition}] = pc
			s.Unlock()
			go pc.consume()
		}
	}
}

// consume is the consumer goroutine body: it processes batches received on
// recs until quit is closed, and closes done when it returns.
func (pc *consumer) consume() {
	defer close(pc.done)
	if pc.kopts.Logger.V(logger.DebugLevel) {
		pc.kopts.Logger.Debug(pc.kopts.Context, fmt.Sprintf("starting, topic %s partition %d", pc.topic, pc.partition))
		defer pc.kopts.Logger.Debug(pc.kopts.Context, fmt.Sprintf("killing, topic %s partition %d", pc.topic, pc.partition))
	}

	// The subscribe-level error handler, when set, overrides the broker-level
	// one.
	eh := pc.kopts.ErrorHandler
	if pc.opts.ErrorHandler != nil {
		eh = pc.opts.ErrorHandler
	}

	for {
		select {
		case <-pc.quit:
			return
		case p := <-pc.recs:
			pc.processBatch(p, eh)
		}
	}
}

// processBatch handles every record of a fetched partition in order, updating
// the inflight, latency and duration metrics per record and the success and
// failure totals once per batch.
func (pc *consumer) processBatch(p kgo.FetchTopicPartition, eh broker.Handler) {
	var successCount, failureCount int
	topic := pc.topic

	for _, record := range p.Records {
		ts := time.Now()
		pc.kopts.Meter.Counter(semconv.SubscribeMessageInflight, "endpoint", topic, "topic", topic).Inc()

		err := pc.handleRecord(record, eh)

		pc.kopts.Meter.Counter(semconv.SubscribeMessageInflight, "endpoint", topic, "topic", topic).Dec()
		te := time.Since(ts)
		// NOTE(review): the summary is named "...Microseconds" but is fed
		// seconds — confirm this matches the convention used by other micro
		// metrics before changing it.
		pc.kopts.Meter.Summary(semconv.SubscribeMessageLatencyMicroseconds, "endpoint", topic, "topic", topic).Update(te.Seconds())
		pc.kopts.Meter.Histogram(semconv.SubscribeMessageDurationSeconds, "endpoint", topic, "topic", topic).Update(te.Seconds())

		if err == nil {
			successCount++
		} else {
			failureCount++
		}
	}

	if successCount > 0 {
		pc.kopts.Meter.Counter(semconv.SubscribeMessageTotal, "status", "success", "topic", topic).Add(successCount)
	}
	if failureCount > 0 {
		pc.kopts.Meter.Counter(semconv.SubscribeMessageTotal, "status", "failure", "topic", topic).Add(failureCount)
	}
}

// handleRecord converts a Kafka record into a pooled event, decodes its body
// (unless the codec is "noop" or BodyOnly is set), invokes the handler and
// marks the record for commit when the event ends up acknowledged.
//
// It returns the codec error when decoding fails and the handler error when
// the handler fails, so that the caller can account the record as a failure;
// it returns nil when the handler succeeds.
func (pc *consumer) handleRecord(record *kgo.Record, eh broker.Handler) error {
	ctx, sp := pc.htracer.WithProcessSpan(record)

	p := eventPool.Get().(*event)
	p.reset()
	defer func() {
		eventPool.Put(p)
		if sp != nil {
			sp.Finish()
		}
	}()

	p.topic = record.Topic
	p.ctx = ctx
	p.msg.Header = metadata.New(len(record.Headers))
	for _, hdr := range record.Headers {
		p.msg.Header.Set(hdr.Key, string(hdr.Value))
	}

	if pc.kopts.Codec.String() == "noop" || pc.opts.BodyOnly {
		p.msg.Body = record.Value
	} else {
		if sp != nil {
			sp.AddEvent("codec unmarshal start")
		}
		err := pc.kopts.Codec.Unmarshal(record.Value, p.msg)
		if sp != nil {
			sp.AddEvent("codec unmarshal stop")
		}
		if err != nil {
			if sp != nil {
				sp.SetStatus(tracer.SpanStatusError, err.Error())
			}
			p.err = err
			// Hand the raw payload to the error handler since decoding failed.
			p.msg.Body = record.Value
			if eh != nil {
				_ = eh(p)
				if p.ack {
					pc.c.MarkCommitRecords(record)
				} else {
					pc.kopts.Logger.Fatal(pc.kopts.Context, "[kgo] ErrLostMessage wtf?")
				}
				return err
			}
			pc.kopts.Logger.Error(pc.kopts.Context, "[kgo]: unmarshal error", err)
			return err
		}
	}

	if sp != nil {
		sp.AddEvent("handler start")
	}
	err := pc.handler(p)
	if sp != nil {
		sp.AddEvent("handler stop")
	}

	if err == nil {
		if pc.opts.AutoAck {
			p.ack = true
		}
	} else {
		if sp != nil {
			sp.SetStatus(tracer.SpanStatusError, err.Error())
		}
		p.err = err
		if eh != nil {
			_ = eh(p)
		} else if pc.kopts.Logger.V(logger.ErrorLevel) {
			pc.kopts.Logger.Error(pc.kopts.Context, "[kgo]: subscriber error", err)
		}
	}

	if p.ack {
		pc.c.MarkCommitRecords(record)
	} else {
		pc.kopts.Logger.Fatal(pc.kopts.Context, "[kgo] ErrLostMessage wtf?")
	}

	// Propagate the handler error (nil on success) so processBatch counts
	// failed records under status=failure instead of status=success.
	return err
}