diff --git a/broker/nats/nats.go b/broker/nats/nats.go index 027f1dac..eecb16ec 100644 --- a/broker/nats/nats.go +++ b/broker/nats/nats.go @@ -13,18 +13,19 @@ import ( ) type natsBroker struct { + sync.Once sync.RWMutex - addrs []string - conn *nats.Conn - opts broker.Options - nopts nats.Options - drain bool + addrs []string + conn *nats.Conn + opts broker.Options + nopts nats.Options + drain bool + closeCh chan (error) } type subscriber struct { - s *nats.Subscription - opts broker.SubscribeOptions - drain bool + s *nats.Subscription + opts broker.SubscribeOptions } type publication struct { @@ -54,9 +55,6 @@ func (s *subscriber) Topic() string { } func (s *subscriber) Unsubscribe() error { - if s.drain { - return s.s.Drain() - } return s.s.Unsubscribe() } @@ -122,20 +120,17 @@ func (n *natsBroker) Connect() error { func (n *natsBroker) Disconnect() error { n.RLock() + defer n.RUnlock() if n.drain { n.conn.Drain() - } else { - n.conn.Close() + return <-n.closeCh } - n.RUnlock() + n.conn.Close() return nil } func (n *natsBroker) Init(opts ...broker.Option) error { - for _, o := range opts { - o(&n.opts) - } - n.addrs = setAddrs(n.opts.Addrs) + n.setOption(opts...) return nil } @@ -167,11 +162,6 @@ func (n *natsBroker) Subscribe(topic string, handler broker.Handler, opts ...bro o(&opt) } - var drain bool - if _, ok := opt.Context.Value(drainSubscriptionKey{}).(bool); ok { - drain = true - } - fn := func(msg *nats.Msg) { var m broker.Message if err := n.opts.Codec.Unmarshal(msg.Data, &m); err != nil { @@ -193,7 +183,7 @@ func (n *natsBroker) Subscribe(topic string, handler broker.Handler, opts ...bro if err != nil { return nil, err } - return &subscriber{s: sub, opts: opt, drain: drain}, nil + return &subscriber{s: sub, opts: opt}, nil } func (n *natsBroker) String() string { @@ -207,39 +197,59 @@ func NewBroker(opts ...broker.Option) broker.Broker { Context: context.Background(), } + n := &natsBroker{ + opts: options, + } + n.setOption(opts...) + + return n +} + +func (n *natsBroker) setOption(opts ...broker.Option) { for _, o := range opts { - o(&options) + o(&n.opts) } - natsOpts := nats.GetDefaultOptions() - if n, ok := options.Context.Value(optionsKey{}).(nats.Options); ok { - natsOpts = n - } + n.Once.Do(func() { + n.nopts = nats.GetDefaultOptions() + }) - var drain bool - if _, ok := options.Context.Value(drainSubscriptionKey{}).(bool); ok { - drain = true + if nopts, ok := n.opts.Context.Value(optionsKey{}).(nats.Options); ok { + n.nopts = nopts } // broker.Options have higher priority than nats.Options // only if Addrs, Secure or TLSConfig were not set through a broker.Option // we read them from nats.Option - if len(options.Addrs) == 0 { - options.Addrs = natsOpts.Servers + if len(n.opts.Addrs) == 0 { + n.opts.Addrs = n.nopts.Servers } - if !options.Secure { - options.Secure = natsOpts.Secure + if !n.opts.Secure { + n.opts.Secure = n.nopts.Secure } - if options.TLSConfig == nil { - options.TLSConfig = natsOpts.TLSConfig + if n.opts.TLSConfig == nil { + n.opts.TLSConfig = n.nopts.TLSConfig } + n.addrs = setAddrs(n.opts.Addrs) - return &natsBroker{ - opts: options, - nopts: natsOpts, - addrs: setAddrs(options.Addrs), - drain: drain, + if n.opts.Context.Value(drainConnectionKey{}) != nil { + n.drain = true + n.closeCh = make(chan error) + n.nopts.ClosedCB = n.onClose + n.nopts.AsyncErrorCB = n.onAsyncError + } +} + +func (n *natsBroker) onClose(conn *nats.Conn) { + n.closeCh <- nil +} + +func (n *natsBroker) onAsyncError(conn *nats.Conn, sub *nats.Subscription, err error) { + // There are kinds of different async error nats might callback, but we are interested + // in ErrDrainTimeout only here. + if err == nats.ErrDrainTimeout { + n.closeCh <- err } } diff --git a/broker/nats/options.go b/broker/nats/options.go index 47431606..b5b106c0 100644 --- a/broker/nats/options.go +++ b/broker/nats/options.go @@ -7,7 +7,6 @@ import ( type optionsKey struct{} type drainConnectionKey struct{} -type drainSubscriptionKey struct{} // Options accepts nats.Options func Options(opts nats.Options) broker.Option { @@ -16,10 +15,5 @@ func Options(opts nats.Options) broker.Option { // DrainConnection will drain subscription on close func DrainConnection() broker.Option { - return setBrokerOption(drainConnectionKey{}, true) -} - -// DrainSubscription will drain pending messages when unsubscribe -func DrainSubscription() broker.SubscribeOption { - return setSubscribeOption(drainSubscriptionKey{}, true) + return setBrokerOption(drainConnectionKey{}, struct{}{}) }