diff --git a/options.go b/options.go index a52bb81..2982ee0 100644 --- a/options.go +++ b/options.go @@ -2,6 +2,7 @@ package stan import ( "context" + "time" "github.com/micro/go-micro/broker" stan "github.com/nats-io/go-nats-streaming" @@ -40,3 +41,17 @@ type ackSuccessKey struct{} func AckOnSuccess() broker.SubscribeOption { return setSubscribeOption(ackSuccessKey{}, true) } + +type timeoutKey struct{} + +// Timeout for connecting to broker -1 infinitive or time.Duration value +func Timeout(td time.Duration) broker.Option { + return setBrokerOption(timeoutKey{}, td) +} + +type reconnectKey struct{} + +// Reconnect to broker in case of errors +func Reconnect(v bool) broker.Option { + return setBrokerOption(reconnectKey{}, v) +} diff --git a/stan.go b/stan.go index 9b36c3b..f34744d 100644 --- a/stan.go +++ b/stan.go @@ -4,10 +4,13 @@ package stan import ( "context" "errors" + "fmt" "strings" "sync" + "time" "github.com/google/uuid" + log "github.com/micro/go-log" "github.com/micro/go-micro/broker" "github.com/micro/go-micro/cmd" "github.com/micro/go-micro/codec/json" @@ -16,10 +19,16 @@ import ( type stanBroker struct { sync.RWMutex - addrs []string - conn stan.Conn - opts broker.Options - nopts stan.Options + addrs []string + conn stan.Conn + opts broker.Options + sopts stan.Options + nopts []stan.Option + clusterID string + timeout time.Duration + reconnect bool + done chan struct{} + ctx context.Context } type subscriber struct { @@ -108,6 +117,66 @@ func setAddrs(addrs []string) []string { return cAddrs } +func (n *stanBroker) reconnectCB(c stan.Conn, err error) { + if n.reconnect { + if err := n.connect(); err != nil { + log.Log(err.Error()) + } + } +} + +func (n *stanBroker) connect() error { + timeout := make(<-chan time.Time) + + if n.timeout > 0 { + timeout = time.After(n.timeout) + } + + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + fn := func() error { + clientID := uuid.New().String() + c, err := stan.Connect(n.clusterID, clientID, n.nopts...) + if err == nil { + n.Lock() + n.conn = c + n.Unlock() + } + return err + } + + // don't wait for first try + if err := fn(); err == nil { + return nil + } + + // wait loop + for { + select { + // context closed + case <-n.opts.Context.Done(): + return nil + // call close, don't wait anymore + case <-n.done: + return nil + // in case of timeout fail with a timeout error + case <-timeout: + return fmt.Errorf("timeout connect to %v", n.addrs) + // got a tick, try to connect + case <-ticker.C: + err := fn() + if err == nil { + log.Logf("successeful connected to %v", n.addrs) + return nil + } + log.Logf("failed to connect %v: %v\n", n.addrs, err) + } + } + + return nil +} + func (n *stanBroker) Connect() error { n.RLock() if n.conn != nil { @@ -116,41 +185,66 @@ func (n *stanBroker) Connect() error { } n.RUnlock() - opts := n.nopts - opts.NatsURL = strings.Join(n.addrs, ",") - clusterID, ok := n.opts.Context.Value(clusterIDKey{}).(string) if !ok || len(clusterID) == 0 { return errors.New("must specify ClusterID Option") } - clientID := uuid.New().String() + var reconnect bool + if val, ok := n.opts.Context.Value(reconnectKey{}).(bool); ok && val { + reconnect = val + } + + var timeout time.Duration + if td, ok := n.opts.Context.Value(timeoutKey{}).(time.Duration); ok { + timeout = td + } else { + timeout = 5 * time.Second + } + + if n.sopts.ConnectionLostCB != nil && reconnect { + return errors.New("impossible to use custom ConnectionLostCB and Reconnect(true)") + } + + if reconnect { + n.sopts.ConnectionLostCB = n.reconnectCB + } nopts := []stan.Option{ - stan.NatsURL(opts.NatsURL), - stan.NatsConn(opts.NatsConn), - stan.ConnectWait(opts.ConnectTimeout), - stan.PubAckWait(opts.AckTimeout), - stan.MaxPubAcksInflight(opts.MaxPubAcksInflight), - stan.Pings(opts.PingIterval, opts.PingMaxOut), - stan.SetConnectionLostHandler(opts.ConnectionLostCB), + stan.NatsURL(n.sopts.NatsURL), + stan.NatsConn(n.sopts.NatsConn), + stan.ConnectWait(n.sopts.ConnectTimeout), + stan.PubAckWait(n.sopts.AckTimeout), + stan.MaxPubAcksInflight(n.sopts.MaxPubAcksInflight), + stan.Pings(n.sopts.PingIterval, n.sopts.PingMaxOut), + stan.SetConnectionLostHandler(n.sopts.ConnectionLostCB), } + nopts = append(nopts, stan.NatsURL(strings.Join(n.addrs, ","))) - c, err := stan.Connect(clusterID, clientID, nopts...) - if err != nil { - return err - } n.Lock() - n.conn = c + n.nopts = nopts + n.clusterID = clusterID + n.timeout = timeout + n.reconnect = reconnect n.Unlock() - return nil + + return n.connect() } func (n *stanBroker) Disconnect() error { - n.RLock() - n.conn.Close() - n.RUnlock() - return nil + var err error + + n.Lock() + defer n.Unlock() + + if n.done != nil { + close(n.done) + n.done = nil + } + if n.conn != nil { + err = n.conn.Close() + } + return err } func (n *stanBroker) Init(opts ...broker.Option) error { @@ -279,8 +373,9 @@ func NewBroker(opts ...broker.Option) broker.Broker { } nb := &stanBroker{ + done: make(chan struct{}), opts: options, - nopts: stanOpts, + sopts: stanOpts, addrs: setAddrs(options.Addrs), }