diff --git a/hooks/sql/cluster.go b/hooks/sql/cluster.go index f30a1828..e731e2c1 100644 --- a/hooks/sql/cluster.go +++ b/hooks/sql/cluster.go @@ -84,6 +84,14 @@ func NewCluster[T Querier](opts ...ClusterOption) (ClusterQuerier, error) { return nil, ErrClusterPicker } + if options.Retries < 1 { + options.Retries = 1 + } + + if options.NodeStateCriterion == 0 { + options.NodeStateCriterion = hasql.Primary + } + options.Options = append(options.Options, hasql.WithNodePicker(options.NodePicker)) if p, ok := options.NodePicker.(*CustomPicker[Querier]); ok { p.opts.Priority = options.NodePriority @@ -111,18 +119,11 @@ func NodeStateCriterion(ctx context.Context, c hasql.NodeStateCriterion) context return context.WithValue(ctx, nodeStateCriterionKey{}, c) } -func getNodeStateCriterion(ctx context.Context) hasql.NodeStateCriterion { - if v, ok := ctx.Value(nodeStateCriterionKey{}).(hasql.NodeStateCriterion); ok { - return v - } - return hasql.PreferPrimary -} - // CustomPickerOptions holds options to pick nodes type CustomPickerOptions struct { - MaxLag int - Priority map[string]int32 - RetryOnError bool + MaxLag int + Priority map[string]int32 + Retries int } // CustomPickerOption func apply option to CustomPickerOptions @@ -228,13 +229,14 @@ func (p *CustomPicker[T]) CompareNodes(a, b hasql.CheckedNode[T]) int { // ClusterOptions contains cluster specific options type ClusterOptions struct { - NodeChecker hasql.NodeChecker - NodePicker hasql.NodePicker[Querier] - NodeDiscoverer hasql.NodeDiscoverer[Querier] - Options []hasql.ClusterOpt[Querier] - Context context.Context - RetryOnError bool - NodePriority map[string]int32 + NodeChecker hasql.NodeChecker + NodePicker hasql.NodePicker[Querier] + NodeDiscoverer hasql.NodeDiscoverer[Querier] + Options []hasql.ClusterOpt[Querier] + Context context.Context + Retries int + NodePriority map[string]int32 + NodeStateCriterion hasql.NodeStateCriterion } // ClusterOption apply cluster options to ClusterOptions @@ -261,10 +263,10 @@ func WithClusterNodeDiscoverer(d hasql.NodeDiscoverer[Querier]) ClusterOption { } } -// WithRetryOnError retry on other nodes on error -func WithRetryOnError(b bool) ClusterOption { +// WithRetries retry count on other nodes in case of error +func WithRetries(n int) ClusterOption { return func(o *ClusterOptions) { - o.RetryOnError = b + o.Retries = n } } @@ -282,6 +284,13 @@ func WithClusterOptions(opts ...hasql.ClusterOpt[Querier]) ClusterOption { } } +// WithClusterNodeStateCriterion pass default hasql.NodeStateCriterion +func WithClusterNodeStateCriterion(c hasql.NodeStateCriterion) ClusterOption { + return func(o *ClusterOptions) { + o.NodeStateCriterion = c + } +} + type ClusterNode struct { Name string DB Querier @@ -310,9 +319,12 @@ func (c *Cluster) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, er var tx *sql.Tx var err error - c.hasql.NodesIter(getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { - if tx, err = n.DB().BeginTx(ctx, opts); err != nil && !c.options.RetryOnError { - return true + retries := 0 + c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { + for ; retries < c.options.Retries; retries++ { + if tx, err = n.DB().BeginTx(ctx, opts); err != nil && retries >= c.options.Retries { + return true + } } return false }) @@ -332,9 +344,12 @@ func (c *Cluster) Conn(ctx context.Context) (*sql.Conn, error) { var conn *sql.Conn var err error - c.hasql.NodesIter(getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { - if conn, err = n.DB().Conn(ctx); err != nil && !c.options.RetryOnError { - return true + retries := 0 + c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { + for ; retries < c.options.Retries; retries++ { + if conn, err = n.DB().Conn(ctx); err != nil && retries >= c.options.Retries { + return true + } } return false }) @@ -350,9 +365,12 @@ func (c *Cluster) ExecContext(ctx context.Context, query string, args ...interfa var res sql.Result var err error - c.hasql.NodesIter(getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { - if res, err = n.DB().ExecContext(ctx, query, args...); err != nil && !c.options.RetryOnError { - return true + retries := 0 + c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { + for ; retries < c.options.Retries; retries++ { + if res, err = n.DB().ExecContext(ctx, query, args...); err != nil && retries >= c.options.Retries { + return true + } } return false }) @@ -368,9 +386,12 @@ func (c *Cluster) PrepareContext(ctx context.Context, query string) (*sql.Stmt, var res *sql.Stmt var err error - c.hasql.NodesIter(getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { - if res, err = n.DB().PrepareContext(ctx, query); err != nil && !c.options.RetryOnError { - return true + retries := 0 + c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { + for ; retries < c.options.Retries; retries++ { + if res, err = n.DB().PrepareContext(ctx, query); err != nil && retries >= c.options.Retries { + return true + } } return false }) @@ -386,9 +407,12 @@ func (c *Cluster) QueryContext(ctx context.Context, query string, args ...interf var res *sql.Rows var err error - c.hasql.NodesIter(getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { - if res, err = n.DB().QueryContext(ctx, query); err != nil && err != sql.ErrNoRows && !c.options.RetryOnError { - return true + retries := 0 + c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { + for ; retries < c.options.Retries; retries++ { + if res, err = n.DB().QueryContext(ctx, query); err != nil && err != sql.ErrNoRows && retries >= c.options.Retries { + return true + } } return false }) @@ -402,12 +426,16 @@ func (c *Cluster) QueryContext(ctx context.Context, query string, args ...interf func (c *Cluster) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { var res *sql.Row - c.hasql.NodesIter(getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { - res = n.DB().QueryRowContext(ctx, query, args...) - if res.Err() == nil { - return false - } else if res.Err() != nil && !c.options.RetryOnError { - return false + + retries := 0 + c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { + for ; retries < c.options.Retries; retries++ { + res = n.DB().QueryRowContext(ctx, query, args...) + if res.Err() == nil { + return false + } else if res.Err() != nil && retries >= c.options.Retries { + return false + } } return true }) @@ -423,10 +451,13 @@ func (c *Cluster) PingContext(ctx context.Context) error { var err error var ok bool - c.hasql.NodesIter(getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { + retries := 0 + c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { ok = true - if err = n.DB().PingContext(ctx); err != nil && !c.options.RetryOnError { - return true + for ; retries < c.options.Retries; retries++ { + if err = n.DB().PingContext(ctx); err != nil && retries >= c.options.Retries { + return true + } } return false }) @@ -491,3 +522,10 @@ func (c *Cluster) Stats() sql.DBStats { }) return s } + +func (c *Cluster) getNodeStateCriterion(ctx context.Context) hasql.NodeStateCriterion { + if v, ok := ctx.Value(nodeStateCriterionKey{}).(hasql.NodeStateCriterion); ok { + return v + } + return c.options.NodeStateCriterion +}