diff --git a/core/stores/kv/store.go b/core/stores/kv/store.go index 603f14a6e8c5..2d6659707eea 100644 --- a/core/stores/kv/store.go +++ b/core/stores/kv/store.go @@ -54,6 +54,7 @@ type ( Setex(key, value string, seconds int) error Setnx(key, value string) (bool, error) SetnxEx(key, value string, seconds int) (bool, error) + Getset(key, value string) (string, error) Sismember(key string, value interface{}) (bool, error) Smembers(key string) ([]string, error) Spop(key string) (string, error) @@ -459,6 +460,15 @@ func (cs clusterStore) SetnxEx(key, value string, seconds int) (bool, error) { return node.SetnxEx(key, value, seconds) } +func (cs clusterStore) Getset(key, value string) (string, error) { + node, err := cs.getRedis(key) + if err != nil { + return "", err + } + + return node.GetSet(key, value) +} + func (cs clusterStore) Sismember(key string, value interface{}) (bool, error) { node, err := cs.getRedis(key) if err != nil { diff --git a/core/stores/kv/store_test.go b/core/stores/kv/store_test.go index 0a8f44d2f54c..cb49aec6d881 100644 --- a/core/stores/kv/store_test.go +++ b/core/stores/kv/store_test.go @@ -490,6 +490,29 @@ func TestRedis_SetExNx(t *testing.T) { }) } +func TestRedis_Getset(t *testing.T) { + store := clusterStore{dispatcher: hash.NewConsistentHash()} + _, err := store.Getset("hello", "world") + assert.NotNil(t, err) + + runOnCluster(t, func(client Store) { + val, err := client.Getset("hello", "world") + assert.Nil(t, err) + assert.Equal(t, "", val) + val, err = client.Get("hello") + assert.Nil(t, err) + assert.Equal(t, "world", val) + val, err = client.Getset("hello", "newworld") + assert.Nil(t, err) + assert.Equal(t, "world", val) + val, err = client.Get("hello") + assert.Nil(t, err) + assert.Equal(t, "newworld", val) + _, err = client.Del("hello") + assert.Nil(t, err) + }) +} + func TestRedis_SetGetDelHashField(t *testing.T) { store := clusterStore{dispatcher: hash.NewConsistentHash()} err := store.Hset("key", "field", "value") diff --git a/core/stores/redis/redis.go b/core/stores/redis/redis.go index aff715915674..570bcacaa331 100644 --- a/core/stores/redis/redis.go +++ b/core/stores/redis/redis.go @@ -615,6 +615,31 @@ func (s *Redis) GetCtx(ctx context.Context, key string) (val string, err error) return } +// GetSet is the implementation of redis getset command. +func (s *Redis) GetSet(key, value string) (string, error) { + return s.GetSetCtx(context.Background(), key, value) +} + +// GetSetCtx is the implementation of redis getset command. +func (s *Redis) GetSetCtx(ctx context.Context, key, value string) (val string, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if val, err = conn.GetSet(ctx, key, value).Result(); err == red.Nil { + return nil + } else if err != nil { + return err + } else { + return nil + } + }, acceptable) + + return +} + // GetBit is the implementation of redis getbit command. func (s *Redis) GetBit(key string, offset int64) (int, error) { return s.GetBitCtx(context.Background(), key, offset) diff --git a/core/stores/redis/redis_test.go b/core/stores/redis/redis_test.go index 1c196e85b1d4..59ab3e7e4c5c 100644 --- a/core/stores/redis/redis_test.go +++ b/core/stores/redis/redis_test.go @@ -701,6 +701,28 @@ func TestRedis_Set(t *testing.T) { }) } +func TestRedis_GetSet(t *testing.T) { + runOnRedis(t, func(client *Redis) { + val, err := New(client.Addr, badType()).GetSet("hello", "world") + assert.NotNil(t, err) + val, err = client.GetSet("hello", "world") + assert.Nil(t, err) + assert.Equal(t, "", val) + val, err = client.Get("hello") + assert.Nil(t, err) + assert.Equal(t, "world", val) + val, err = client.GetSet("hello", "newworld") + assert.Nil(t, err) + assert.Equal(t, "world", val) + val, err = client.Get("hello") + assert.Nil(t, err) + assert.Equal(t, "newworld", val) + ret, err := client.Del("hello") + assert.Nil(t, err) + assert.Equal(t, 1, ret) + }) +} + func TestRedis_SetGetDel(t *testing.T) { runOnRedis(t, func(client *Redis) { err := New(client.Addr, badType()).Set("hello", "world")