From 1cd8fc30a154acc5937cf28cb4544bc45dbcc5ea Mon Sep 17 00:00:00 2001 From: Matias Insaurralde Date: Sat, 8 Jan 2022 21:14:31 -0300 Subject: [PATCH] Implement Copy command --- cmd_generic.go | 34 ++++++++++++++++++++++++++++++++++ cmd_generic_test.go | 30 ++++++++++++++++++++++++++++++ db.go | 31 +++++++++++++++++++++++++++++++ direct.go | 16 ++++++++++++++++ integration/generic_test.go | 13 +++++++++++++ 5 files changed, 124 insertions(+) diff --git a/cmd_generic.go b/cmd_generic.go index 3859838c..7027f031 100644 --- a/cmd_generic.go +++ b/cmd_generic.go @@ -35,6 +35,8 @@ func commandsGeneric(m *Miniredis) { m.srv.Register("TTL", m.cmdTTL) m.srv.Register("TYPE", m.cmdType) m.srv.Register("SCAN", m.cmdScan) + // COPY + m.srv.Register("COPY", m.cmdCopy) } // generic expire command for EXPIRE, PEXPIRE, EXPIREAT, PEXPIREAT @@ -541,3 +543,35 @@ func (m *Miniredis) cmdScan(c *server.Peer, cmd string, args []string) { } }) } + +// COPY +func (m *Miniredis) cmdCopy(c *server.Peer, cmd string, args []string) { + if len(args) != 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + if m.checkPubsub(c, cmd) { + return + } + + from, to := args[0], args[1] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(from) { + c.WriteInt(0) + return + } + + if !db.copy(from, to) { + c.WriteInt(0) + return + } + c.WriteInt(1) + }) +} diff --git a/cmd_generic_test.go b/cmd_generic_test.go index b062135e..1188dda9 100644 --- a/cmd_generic_test.go +++ b/cmd_generic_test.go @@ -762,3 +762,33 @@ func TestRenamenx(t *testing.T) { ) }) } + +func TestCopy(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := proto.Dial(s.Addr()) + ok(t, err) + defer c.Close() + + s.Set("key1", "value") + s.CheckGet(t, "key1", "value") + s.Copy("key1", "key2") + // should return 1 after a successful copy operation: + must1(t, c, "COPY", "key1", "key2") + s.CheckGet(t, "key2", "value") + + // should return 0 when trying to copy a nonexistent key: + t.Run("nonexistent key", func(t *testing.T) { + must0(t, c, "COPY", "nosuch", "to") + }) + + // should return 0 when trying to overwrite an existing key: + t.Run("existing key", func(t *testing.T) { + s.Set("existingkey", "value") + s.Set("newkey", "newvalue") + must0(t, c, "COPY", "newkey", "existingkey") + // existing key value should remain unchanged: + s.CheckGet(t, "existingkey", "value") + }) +} diff --git a/db.go b/db.go index 5c2b1aaf..95b57bae 100644 --- a/db.go +++ b/db.go @@ -112,6 +112,37 @@ func (db *RedisDB) rename(from, to string) { db.del(from, true) } +func (db *RedisDB) copy(from, to string) bool { + if _, ok := db.keys[to]; ok { + return false + } + db.keys[to] = from + switch db.t(from) { + case "string": + db.stringKeys[to] = db.stringKeys[from] + case "hash": + db.hashKeys[to] = db.hashKeys[from] + case "list": + db.listKeys[to] = db.listKeys[from] + case "set": + db.setKeys[to] = db.setKeys[from] + case "zset": + db.sortedsetKeys[to] = db.sortedsetKeys[from] + case "stream": + db.streamKeys[to] = db.streamKeys[from] + case "hll": + db.hllKeys[to] = db.hllKeys[from] + default: + panic("missing case") + } + db.keys[to] = db.keys[from] + db.keyVersion[to]++ + if v, ok := db.ttl[from]; ok { + db.ttl[to] = v + } + return true +} + func (db *RedisDB) del(k string, delTTL bool) { if !db.exists(k) { return diff --git a/direct.go b/direct.go index 23b6703a..db6e3f88 100644 --- a/direct.go +++ b/direct.go @@ -793,3 +793,19 @@ func (db *RedisDB) HllMerge(destKey string, sourceKeys ...string) error { return db.hllMerge(append([]string{destKey}, sourceKeys...)) } + +func (m *Miniredis) Copy(src, dest string) (bool, error) { + return m.DB(m.selectedDB).Copy(src, dest) +} + +func (db *RedisDB) Copy(src, dest string) (bool, error) { + db.master.Lock() + defer db.master.Unlock() + defer db.master.signal.Broadcast() + + if !db.exists(src) { + return false, ErrKeyNotFound + } + // return db.copy(src, dest), nil + return true, nil +} diff --git a/integration/generic_test.go b/integration/generic_test.go index d9f5702c..ba793c4a 100644 --- a/integration/generic_test.go +++ b/integration/generic_test.go @@ -299,3 +299,16 @@ func TestPersist(t *testing.T) { c.Do("TTL", "foo") }) } + +func TestCopy(t *testing.T) { + testRaw(t, func(c *client) { + c.Error("wrong number", "COPY") + c.Error("wrong number", "COPY", "a") + + c.Do("SET", "a", "1") + c.Do("COPY", "a", "b") // returns 1 - successfully copied + c.Do("EXISTS", "b") + c.Do("COPY", "nonexistent", "c") // returns 1 - not successfully copied + c.Do("RENAME", "b", "c") // rename the copied key + }) +}