diff --git a/redis.go b/redis.go index cccbf6d..3190219 100755 --- a/redis.go +++ b/redis.go @@ -45,26 +45,20 @@ var ( type Store struct { opts store.Options - cli redisClient + cli *wrappedClient pool pool.Pool[*strings.Builder] } -type redisClient interface { - Get(ctx context.Context, key string) *redis.StringCmd - Del(ctx context.Context, keys ...string) *redis.IntCmd - Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.StatusCmd - Keys(ctx context.Context, pattern string) *redis.StringSliceCmd - MGet(ctx context.Context, keys ...string) *redis.SliceCmd - MSet(ctx context.Context, kv ...interface{}) *redis.StatusCmd - Exists(ctx context.Context, keys ...string) *redis.IntCmd - Ping(ctx context.Context) *redis.StatusCmd - Pipeline() redis.Pipeliner - Pipelined(ctx context.Context, fn func(redis.Pipeliner) error) ([]redis.Cmder, error) - Close() error +type wrappedClient struct { + *redis.Client + *redis.ClusterClient } func (r *Store) Connect(ctx context.Context) error { - return r.cli.Ping(ctx).Err() + if r.cli.Client != nil { + return r.cli.Client.Ping(ctx).Err() + } + return r.cli.ClusterClient.Ping(ctx).Err() } func (r *Store) Init(opts ...store.Option) error { @@ -75,12 +69,25 @@ func (r *Store) Init(opts ...store.Option) error { return r.configure() } -func (r *Store) Redis() *redis.Client { - return r.cli.(*redis.Client) +func (r *Store) Client() *redis.Client { + if r.cli.Client != nil { + return r.cli.Client + } + return nil +} + +func (r *Store) ClusterClient() *redis.ClusterClient { + if r.cli.ClusterClient != nil { + return r.cli.ClusterClient + } + return nil } func (r *Store) Disconnect(ctx context.Context) error { - return r.cli.Close() + if r.cli.Client != nil { + return r.cli.Client.Close() + } + return r.cli.ClusterClient.Close() } func (r *Store) Exists(ctx context.Context, key string, opts ...store.ExistsOption) error { @@ -103,7 +110,13 @@ func (r *Store) Exists(ctx context.Context, key string, opts ...store.ExistsOpti r.opts.Meter.Counter(semconv.CacheRequestInflight, "name", options.Name).Inc() ts := time.Now() - val, err := r.cli.Exists(ctx, rkey).Result() + var err error + var val int64 + if r.cli.Client != nil { + val, err = r.cli.Client.Exists(ctx, rkey).Result() + } else { + val, err = r.cli.ClusterClient.Exists(ctx, rkey).Result() + } te := time.Since(ts) r.opts.Meter.Counter(semconv.CacheRequestInflight, "name", options.Name).Dec() r.opts.Meter.Summary(semconv.CacheRequestLatencyMicroseconds, "name", options.Name).Update(te.Seconds()) @@ -142,7 +155,13 @@ func (r *Store) Read(ctx context.Context, key string, val interface{}, opts ...s r.opts.Meter.Counter(semconv.CacheRequestInflight, "name", options.Name).Inc() ts := time.Now() - buf, err := r.cli.Get(ctx, rkey).Bytes() + var buf []byte + var err error + if r.cli.Client != nil { + buf, err = r.cli.Client.Get(ctx, rkey).Bytes() + } else { + buf, err = r.cli.ClusterClient.Get(ctx, rkey).Bytes() + } te := time.Since(ts) r.opts.Meter.Counter(semconv.CacheRequestInflight, "name", options.Name).Dec() r.opts.Meter.Summary(semconv.CacheRequestLatencyMicroseconds, "name", options.Name).Update(te.Seconds()) @@ -197,7 +216,13 @@ func (r *Store) MRead(ctx context.Context, keys []string, vals interface{}, opts r.opts.Meter.Counter(semconv.CacheRequestInflight, "name", options.Name).Inc() ts := time.Now() - rvals, err := r.cli.MGet(ctx, keys...).Result() + var rvals []interface{} + var err error + if r.cli.Client != nil { + rvals, err = r.cli.Client.MGet(ctx, keys...).Result() + } else { + rvals, err = r.cli.ClusterClient.MGet(ctx, keys...).Result() + } te := time.Since(ts) r.opts.Meter.Counter(semconv.CacheRequestInflight, "name", options.Name).Dec() r.opts.Meter.Summary(semconv.CacheRequestLatencyMicroseconds, "name", options.Name).Update(te.Seconds()) @@ -284,7 +309,12 @@ func (r *Store) MDelete(ctx context.Context, keys []string, opts ...store.Delete r.opts.Meter.Counter(semconv.CacheRequestInflight, "name", options.Name).Inc() ts := time.Now() - err := r.cli.Del(ctx, keys...).Err() + var err error + if r.cli.Client != nil { + err = r.cli.Client.Del(ctx, keys...).Err() + } else { + err = r.cli.ClusterClient.Del(ctx, keys...).Err() + } te := time.Since(ts) r.opts.Meter.Counter(semconv.CacheRequestInflight, "name", options.Name).Dec() r.opts.Meter.Summary(semconv.CacheRequestLatencyMicroseconds, "name", options.Name).Update(te.Seconds()) @@ -322,7 +352,12 @@ func (r *Store) Delete(ctx context.Context, key string, opts ...store.DeleteOpti r.opts.Meter.Counter(semconv.CacheRequestInflight, "name", options.Name).Inc() ts := time.Now() - err := r.cli.Del(ctx, r.getKey(r.opts.Namespace, options.Namespace, key)).Err() + var err error + if r.cli.Client != nil { + err = r.cli.Client.Del(ctx, r.getKey(r.opts.Namespace, options.Namespace, key)).Err() + } else { + err = r.cli.ClusterClient.Del(ctx, r.getKey(r.opts.Namespace, options.Namespace, key)).Err() + } te := time.Since(ts) r.opts.Meter.Counter(semconv.CacheRequestInflight, "name", options.Name).Dec() r.opts.Meter.Summary(semconv.CacheRequestLatencyMicroseconds, "name", options.Name).Update(te.Seconds()) @@ -381,14 +416,23 @@ func (r *Store) MWrite(ctx context.Context, keys []string, vals []interface{}, o r.opts.Meter.Counter(semconv.CacheRequestInflight, "name", options.Name).Inc() ts := time.Now() - cmds, err := r.cli.Pipelined(ctx, func(pipe redis.Pipeliner) error { + pipeliner := func(pipe redis.Pipeliner) error { for idx := 0; idx < len(kvs); idx += 2 { if _, err := pipe.Set(ctx, kvs[idx], kvs[idx+1], options.TTL).Result(); err != nil { return err } } return nil - }) + } + + var err error + var cmds []redis.Cmder + + if r.cli.Client != nil { + cmds, err = r.cli.Client.Pipelined(ctx, pipeliner) + } else { + cmds, err = r.cli.ClusterClient.Pipelined(ctx, pipeliner) + } te := time.Since(ts) r.opts.Meter.Counter(semconv.CacheRequestInflight, "name", options.Name).Dec() @@ -455,7 +499,12 @@ func (r *Store) Write(ctx context.Context, key string, val interface{}, opts ... r.opts.Meter.Counter(semconv.CacheRequestInflight, "name", options.Name).Inc() ts := time.Now() - err := r.cli.Set(ctx, rkey, buf, options.TTL).Err() + var err error + if r.cli.Client != nil { + err = r.cli.Client.Set(ctx, rkey, buf, options.TTL).Err() + } else { + err = r.cli.ClusterClient.Set(ctx, rkey, buf, options.TTL).Err() + } te := time.Since(ts) r.opts.Meter.Counter(semconv.CacheRequestInflight, "name", options.Name).Dec() r.opts.Meter.Summary(semconv.CacheRequestLatencyMicroseconds, "name", options.Name).Update(te.Seconds()) @@ -502,7 +551,21 @@ func (r *Store) List(ctx context.Context, opts ...store.ListOption) ([]string, e // TODO: add support for prefix/suffix/limit r.opts.Meter.Counter(semconv.CacheRequestInflight, "name", options.Name).Inc() ts := time.Now() - keys, err := r.cli.Keys(ctx, rkey).Result() + var keys []string + var err error + + if r.cli.Client != nil { + keys, err = r.cli.Client.Keys(ctx, rkey).Result() + } else { + err = r.cli.ClusterClient.ForEachMaster(ctx, func(nctx context.Context, cli *redis.Client) error { + nkeys, nerr := cli.Keys(nctx, rkey).Result() + if nerr != nil { + return nerr + } + keys = append(keys, nkeys...) + return nil + }) + } te := time.Since(ts) r.opts.Meter.Counter(semconv.CacheRequestInflight, "name", options.Name).Dec() r.opts.Meter.Summary(semconv.CacheRequestLatencyMicroseconds, "name", options.Name).Update(te.Seconds()) @@ -602,9 +665,9 @@ func (r *Store) configure() error { } if redisOptions != nil { - r.cli = redis.NewClient(redisOptions) + r.cli = &wrappedClient{Client: redis.NewClient(redisOptions)} } else if redisClusterOptions != nil { - r.cli = redis.NewClusterClient(redisClusterOptions) + r.cli = &wrappedClient{ClusterClient: redis.NewClusterClient(redisClusterOptions)} } r.pool = pool.NewPool(func() *strings.Builder { return &strings.Builder{} }) diff --git a/redis_test.go b/redis_test.go index 13fbd6c..c95e507 100755 --- a/redis_test.go +++ b/redis_test.go @@ -7,14 +7,13 @@ import ( "testing" "time" - "github.com/redis/go-redis/v9" "go.unistack.org/micro/v3/store" ) func Test_rkv_configure(t *testing.T) { type fields struct { options store.Options - Client *redis.Client + Client *wrappedClient } type wantValues struct { username string