From 1cd8fc30a154acc5937cf28cb4544bc45dbcc5ea Mon Sep 17 00:00:00 2001 From: Matias Insaurralde Date: Sat, 8 Jan 2022 21:14:31 -0300 Subject: [PATCH 1/2] Implement Copy command --- cmd_generic.go | 34 ++++++++++++++++++++++++++++++++++ cmd_generic_test.go | 30 ++++++++++++++++++++++++++++++ db.go | 31 +++++++++++++++++++++++++++++++ direct.go | 16 ++++++++++++++++ integration/generic_test.go | 13 +++++++++++++ 5 files changed, 124 insertions(+) diff --git a/cmd_generic.go b/cmd_generic.go index 3859838c..7027f031 100644 --- a/cmd_generic.go +++ b/cmd_generic.go @@ -35,6 +35,8 @@ func commandsGeneric(m *Miniredis) { m.srv.Register("TTL", m.cmdTTL) m.srv.Register("TYPE", m.cmdType) m.srv.Register("SCAN", m.cmdScan) + // COPY + m.srv.Register("COPY", m.cmdCopy) } // generic expire command for EXPIRE, PEXPIRE, EXPIREAT, PEXPIREAT @@ -541,3 +543,35 @@ func (m *Miniredis) cmdScan(c *server.Peer, cmd string, args []string) { } }) } + +// COPY +func (m *Miniredis) cmdCopy(c *server.Peer, cmd string, args []string) { + if len(args) != 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + if m.checkPubsub(c, cmd) { + return + } + + from, to := args[0], args[1] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(from) { + c.WriteInt(0) + return + } + + if !db.copy(from, to) { + c.WriteInt(0) + return + } + c.WriteInt(1) + }) +} diff --git a/cmd_generic_test.go b/cmd_generic_test.go index b062135e..1188dda9 100644 --- a/cmd_generic_test.go +++ b/cmd_generic_test.go @@ -762,3 +762,33 @@ func TestRenamenx(t *testing.T) { ) }) } + +func TestCopy(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := proto.Dial(s.Addr()) + ok(t, err) + defer c.Close() + + s.Set("key1", "value") + s.CheckGet(t, "key1", "value") + s.Copy("key1", "key2") + // should return 1 after a successful copy operation: + must1(t, c, "COPY", "key1", "key2") + s.CheckGet(t, "key2", "value") + + // should return 0 when trying to copy a nonexistent key: + t.Run("nonexistent key", func(t *testing.T) { + must0(t, c, "COPY", "nosuch", "to") + }) + + // should return 0 when trying to overwrite an existing key: + t.Run("existing key", func(t *testing.T) { + s.Set("existingkey", "value") + s.Set("newkey", "newvalue") + must0(t, c, "COPY", "newkey", "existingkey") + // existing key value should remain unchanged: + s.CheckGet(t, "existingkey", "value") + }) +} diff --git a/db.go b/db.go index 5c2b1aaf..95b57bae 100644 --- a/db.go +++ b/db.go @@ -112,6 +112,37 @@ func (db *RedisDB) rename(from, to string) { db.del(from, true) } +func (db *RedisDB) copy(from, to string) bool { + if _, ok := db.keys[to]; ok { + return false + } + db.keys[to] = from + switch db.t(from) { + case "string": + db.stringKeys[to] = db.stringKeys[from] + case "hash": + db.hashKeys[to] = db.hashKeys[from] + case "list": + db.listKeys[to] = db.listKeys[from] + case "set": + db.setKeys[to] = db.setKeys[from] + case "zset": + db.sortedsetKeys[to] = db.sortedsetKeys[from] + case "stream": + db.streamKeys[to] = db.streamKeys[from] + case "hll": + db.hllKeys[to] = db.hllKeys[from] + default: + panic("missing case") + } + db.keys[to] = db.keys[from] + db.keyVersion[to]++ + if v, ok := db.ttl[from]; ok { + db.ttl[to] = v + } + return true +} + func (db *RedisDB) del(k string, delTTL bool) { if !db.exists(k) { return diff --git a/direct.go b/direct.go index 23b6703a..db6e3f88 100644 --- a/direct.go +++ b/direct.go @@ -793,3 +793,19 @@ func (db *RedisDB) HllMerge(destKey string, sourceKeys ...string) error { return db.hllMerge(append([]string{destKey}, sourceKeys...)) } + +func (m *Miniredis) Copy(src, dest string) (bool, error) { + return m.DB(m.selectedDB).Copy(src, dest) +} + +func (db *RedisDB) Copy(src, dest string) (bool, error) { + db.master.Lock() + defer db.master.Unlock() + defer db.master.signal.Broadcast() + + if !db.exists(src) { + return false, ErrKeyNotFound + } + // return db.copy(src, dest), nil + return true, nil +} diff --git a/integration/generic_test.go b/integration/generic_test.go index d9f5702c..ba793c4a 100644 --- a/integration/generic_test.go +++ b/integration/generic_test.go @@ -299,3 +299,16 @@ func TestPersist(t *testing.T) { c.Do("TTL", "foo") }) } + +func TestCopy(t *testing.T) { + testRaw(t, func(c *client) { + c.Error("wrong number", "COPY") + c.Error("wrong number", "COPY", "a") + + c.Do("SET", "a", "1") + c.Do("COPY", "a", "b") // returns 1 - successfully copied + c.Do("EXISTS", "b") + c.Do("COPY", "nonexistent", "c") // returns 1 - not successfully copied + c.Do("RENAME", "b", "c") // rename the copied key + }) +} From 48bccccf6294677379d22e0bbaa96ce6458bec6f Mon Sep 17 00:00:00 2001 From: Harmen Date: Wed, 12 Jan 2022 09:38:59 +0100 Subject: [PATCH 2/2] implement all COPY options --- README.md | 1 + cmd_connection.go | 4 +- cmd_generic.go | 64 ++++++++++++++++++--- cmd_generic_test.go | 49 ++++++++++++++-- db.go | 31 ---------- direct.go | 20 ++----- hll.go | 6 ++ integration/generic_test.go | 112 ++++++++++++++++++++++++++++++++++++ miniredis.go | 59 +++++++++++++++++++ redis.go | 1 + stream.go | 27 +++++++++ 11 files changed, 314 insertions(+), 60 deletions(-) diff --git a/README.md b/README.md index 8e067b03..2a065248 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,7 @@ Implemented commands: - SWAPDB - QUIT - Key + - COPY - DEL - EXISTS - EXPIRE diff --git a/cmd_connection.go b/cmd_connection.go index 1bf98012..defbbcca 100644 --- a/cmd_connection.go +++ b/cmd_connection.go @@ -227,7 +227,7 @@ func (m *Miniredis) cmdSelect(c *server.Peer, cmd string, args []string) { return } if id < 0 { - c.WriteError("ERR DB index is out of range") + c.WriteError(msgDBIndexOutOfRange) setDirty(c) return } @@ -262,7 +262,7 @@ func (m *Miniredis) cmdSwapdb(c *server.Peer, cmd string, args []string) { return } if id1 < 0 || id2 < 0 { - c.WriteError("ERR DB index is out of range") + c.WriteError(msgDBIndexOutOfRange) setDirty(c) return } diff --git a/cmd_generic.go b/cmd_generic.go index 7027f031..d9bac197 100644 --- a/cmd_generic.go +++ b/cmd_generic.go @@ -35,7 +35,6 @@ func commandsGeneric(m *Miniredis) { m.srv.Register("TTL", m.cmdTTL) m.srv.Register("TYPE", m.cmdType) m.srv.Register("SCAN", m.cmdScan) - // COPY m.srv.Register("COPY", m.cmdCopy) } @@ -546,7 +545,7 @@ func (m *Miniredis) cmdScan(c *server.Peer, cmd string, args []string) { // COPY func (m *Miniredis) cmdCopy(c *server.Peer, cmd string, args []string) { - if len(args) != 2 { + if len(args) < 2 { setDirty(c) c.WriteError(errWrongNumber(cmd)) return @@ -558,20 +557,71 @@ func (m *Miniredis) cmdCopy(c *server.Peer, cmd string, args []string) { return } - from, to := args[0], args[1] + var opts = struct { + from string + to string + destinationDB int + replace bool + }{ + destinationDB: -1, + } + + opts.from, opts.to, args = args[0], args[1], args[2:] + for len(args) > 0 { + switch strings.ToLower(args[0]) { + case "db": + if len(args) < 2 { + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + db, err := strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + if db < 0 { + setDirty(c) + c.WriteError(msgDBIndexOutOfRange) + return + } + opts.destinationDB = db + args = args[2:] + case "replace": + opts.replace = true + args = args[1:] + default: + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + } withTx(m, c, func(c *server.Peer, ctx *connCtx) { - db := m.db(ctx.selectedDB) + fromDB, toDB := ctx.selectedDB, opts.destinationDB + if toDB == -1 { + toDB = fromDB + } - if !db.exists(from) { - c.WriteInt(0) + if fromDB == toDB && opts.from == opts.to { + c.WriteError("ERR source and destination objects are the same") return } - if !db.copy(from, to) { + if !m.db(fromDB).exists(opts.from) { c.WriteInt(0) return } + + if !opts.replace { + if m.db(toDB).exists(opts.to) { + c.WriteInt(0) + return + } + } + + m.copy(m.db(fromDB), opts.from, m.db(toDB), opts.to) c.WriteInt(1) }) } diff --git a/cmd_generic_test.go b/cmd_generic_test.go index 1188dda9..7934126f 100644 --- a/cmd_generic_test.go +++ b/cmd_generic_test.go @@ -771,12 +771,13 @@ func TestCopy(t *testing.T) { ok(t, err) defer c.Close() - s.Set("key1", "value") - s.CheckGet(t, "key1", "value") - s.Copy("key1", "key2") - // should return 1 after a successful copy operation: - must1(t, c, "COPY", "key1", "key2") - s.CheckGet(t, "key2", "value") + t.Run("basic", func(t *testing.T) { + s.Set("key1", "value") + // should return 1 after a successful copy operation: + must1(t, c, "COPY", "key1", "key2") + s.CheckGet(t, "key2", "value") + equals(t, "string", s.Type("key2")) + }) // should return 0 when trying to copy a nonexistent key: t.Run("nonexistent key", func(t *testing.T) { @@ -791,4 +792,40 @@ func TestCopy(t *testing.T) { // existing key value should remain unchanged: s.CheckGet(t, "existingkey", "value") }) + + t.Run("destination db", func(t *testing.T) { + s.Set("akey1", "value") + must1(t, c, "COPY", "akey1", "akey2", "DB", "2") + s.Select(2) + s.CheckGet(t, "akey2", "value") + equals(t, "string", s.Type("akey2")) + }) + s.Select(0) + + t.Run("replace", func(t *testing.T) { + s.Set("rkey1", "value") + s.Set("rkey2", "another") + must1(t, c, "COPY", "rkey1", "rkey2", "REPLACE") + s.CheckGet(t, "rkey2", "value") + equals(t, "string", s.Type("rkey2")) + }) + + t.Run("direct", func(t *testing.T) { + s.Set("d1", "value") + ok(t, s.Copy(0, "d1", 0, "d2")) + equals(t, "string", s.Type("d2")) + s.CheckGet(t, "d2", "value") + }) + + t.Run("errors", func(t *testing.T) { + mustDo(t, c, "COPY", + proto.Error(errWrongNumber("copy")), + ) + mustDo(t, c, "COPY", "foo", + proto.Error(errWrongNumber("copy")), + ) + mustDo(t, c, "COPY", "foo", "bar", "baz", + proto.Error(msgSyntaxError), + ) + }) } diff --git a/db.go b/db.go index 95b57bae..5c2b1aaf 100644 --- a/db.go +++ b/db.go @@ -112,37 +112,6 @@ func (db *RedisDB) rename(from, to string) { db.del(from, true) } -func (db *RedisDB) copy(from, to string) bool { - if _, ok := db.keys[to]; ok { - return false - } - db.keys[to] = from - switch db.t(from) { - case "string": - db.stringKeys[to] = db.stringKeys[from] - case "hash": - db.hashKeys[to] = db.hashKeys[from] - case "list": - db.listKeys[to] = db.listKeys[from] - case "set": - db.setKeys[to] = db.setKeys[from] - case "zset": - db.sortedsetKeys[to] = db.sortedsetKeys[from] - case "stream": - db.streamKeys[to] = db.streamKeys[from] - case "hll": - db.hllKeys[to] = db.hllKeys[from] - default: - panic("missing case") - } - db.keys[to] = db.keys[from] - db.keyVersion[to]++ - if v, ok := db.ttl[from]; ok { - db.ttl[to] = v - } - return true -} - func (db *RedisDB) del(k string, delTTL bool) { if !db.exists(k) { return diff --git a/direct.go b/direct.go index db6e3f88..cd2323c6 100644 --- a/direct.go +++ b/direct.go @@ -794,18 +794,10 @@ func (db *RedisDB) HllMerge(destKey string, sourceKeys ...string) error { return db.hllMerge(append([]string{destKey}, sourceKeys...)) } -func (m *Miniredis) Copy(src, dest string) (bool, error) { - return m.DB(m.selectedDB).Copy(src, dest) -} - -func (db *RedisDB) Copy(src, dest string) (bool, error) { - db.master.Lock() - defer db.master.Unlock() - defer db.master.signal.Broadcast() - - if !db.exists(src) { - return false, ErrKeyNotFound - } - // return db.copy(src, dest), nil - return true, nil +// Copy a value. +// Needs the IDs of both the source and dest DBs (which can differ). +// Returns ErrKeyNotFound if src does not exist. +// Overwrites dest if it already exists (unlike the redis command, which needs a flag to allow that). +func (m *Miniredis) Copy(srcDB int, src string, destDB int, dest string) error { + return m.copy(m.DB(srcDB), src, m.DB(destDB), dest) } diff --git a/hll.go b/hll.go index 2f55fac9..d00ad78a 100644 --- a/hll.go +++ b/hll.go @@ -34,3 +34,9 @@ func (h *hll) Bytes() []byte { dataBytes, _ := h.inner.MarshalBinary() return dataBytes } + +func (h *hll) copy() *hll { + return &hll{ + inner: h.inner.Clone(), + } +} diff --git a/integration/generic_test.go b/integration/generic_test.go index ba793c4a..91f1ddac 100644 --- a/integration/generic_test.go +++ b/integration/generic_test.go @@ -304,11 +304,123 @@ func TestCopy(t *testing.T) { testRaw(t, func(c *client) { c.Error("wrong number", "COPY") c.Error("wrong number", "COPY", "a") + c.Error("syntax", "COPY", "a", "b", "c") + c.Error("syntax", "COPY", "a", "b", "DB") + c.Error("range", "COPY", "a", "b", "DB", "-1") + c.Error("integer", "COPY", "a", "b", "DB", "foo") + c.Error("syntax", "COPY", "a", "b", "DB", "1", "REPLACE", "foo") c.Do("SET", "a", "1") c.Do("COPY", "a", "b") // returns 1 - successfully copied c.Do("EXISTS", "b") + c.Do("GET", "b") + c.Do("TYPE", "b") + c.Do("COPY", "nonexistent", "c") // returns 1 - not successfully copied c.Do("RENAME", "b", "c") // rename the copied key + + t.Run("replace option", func(t *testing.T) { + c.Do("SET", "fromme", "1") + c.Do("HSET", "replaceme", "foo", "bar") + c.Do("COPY", "fromme", "replaceme", "REPLACE") + c.Do("TYPE", "replaceme") + c.Do("GET", "replaceme") + }) + + t.Run("different DB", func(t *testing.T) { + c.Do("SELECT", "2") + c.Do("SET", "fromme", "1") + c.Do("COPY", "fromme", "replaceme", "DB", "3") + c.Do("EXISTS", "replaceme") // your value is in another db + c.Do("SELECT", "3") + c.Do("EXISTS", "replaceme") + c.Do("TYPE", "replaceme") + c.Do("GET", "replaceme") + }) + c.Do("SELECT", "0") + + t.Run("copy to self", func(t *testing.T) { + // copy to self is never allowed + c.Do("SET", "double", "1") + c.Error("the same", "COPY", "double", "double") + c.Error("the same", "COPY", "double", "double", "REPLACE") + c.Do("COPY", "double", "double", "DB", "2") // different DB is fine + c.Do("SELECT", "2") + c.Do("TYPE", "double") + + c.Error("the same", "COPY", "noexisting", "noexisting") // "copy to self?" check comes before key check + }) + c.Do("SELECT", "0") + + // deep copies? + t.Run("hash", func(t *testing.T) { + c.Do("HSET", "temp", "paris", "12") + c.Do("HSET", "temp", "oslo", "-5") + c.Do("COPY", "temp", "temp2") + c.Do("TYPE", "temp2") + c.Do("HGET", "temp2", "oslo") + c.Do("HSET", "temp2", "oslo", "-7") + c.Do("HGET", "temp", "oslo") + c.Do("HGET", "temp2", "oslo") + }) + + t.Run("list", func(t *testing.T) { + c.Do("LPUSH", "list", "aap", "noot", "mies") + c.Do("COPY", "list", "list2") + c.Do("TYPE", "list2") + c.Do("LPUSH", "list", "vuur") + c.Do("LRANGE", "list", "0", "-1") + c.Do("LRANGE", "list2", "0", "-1") + }) + + t.Run("set", func(t *testing.T) { + c.Do("SADD", "set", "aap", "noot", "mies") + c.Do("COPY", "set", "set2") + c.Do("TYPE", "set2") + c.DoSorted("SMEMBERS", "set2") + c.Do("SADD", "set", "vuur") + c.DoSorted("SMEMBERS", "set") + c.DoSorted("SMEMBERS", "set2") + }) + + t.Run("sorted set", func(t *testing.T) { + c.Do("ZADD", "zset", "1", "aap", "2", "noot", "3", "mies") + c.Do("COPY", "zset", "zset2") + c.Do("TYPE", "zset2") + c.Do("ZCARD", "zset") + c.Do("ZCARD", "zset2") + c.Do("ZADD", "zset", "4", "vuur") + c.Do("ZCARD", "zset") + c.Do("ZCARD", "zset2") + }) + + t.Run("stream", func(t *testing.T) { + c.Do("XADD", + "planets", + "0-1", + "name", "Mercury", + ) + c.Do("COPY", "planets", "planets2") + c.Do("XLEN", "planets2") + c.Do("TYPE", "planets2") + + c.Do("XADD", + "planets", + "18446744073709551000-0", + "name", "Earth", + ) + c.Do("XLEN", "planets") + c.Do("XLEN", "planets2") + }) + + t.Run("stream", func(t *testing.T) { + c.Do("PFADD", "hlog", "42") + c.DoApprox(2, "PFCOUNT", "hlog") + c.Do("COPY", "hlog", "hlog2") + // c.Do("TYPE", "hlog2") broken + c.Do("PFADD", "hlog", "44") + c.Do("PFCOUNT", "hlog") + c.Do("PFCOUNT", "hlog2") + }) }) } diff --git a/miniredis.go b/miniredis.go index 8f78dea2..171678d9 100644 --- a/miniredis.go +++ b/miniredis.go @@ -637,3 +637,62 @@ func (m *Miniredis) at(i int, d time.Duration) time.Duration { now := m.effectiveNow() return ts.Sub(now) } + +// copy does not mind if dst already exists. +func (m *Miniredis) copy( + srcDB *RedisDB, src string, + destDB *RedisDB, dst string, +) error { + if !srcDB.exists(src) { + return ErrKeyNotFound + } + + switch srcDB.t(src) { + case "string": + destDB.stringKeys[dst] = srcDB.stringKeys[src] + case "hash": + destDB.hashKeys[dst] = copyHashKey(srcDB.hashKeys[src]) + case "list": + destDB.listKeys[dst] = srcDB.listKeys[src] + case "set": + destDB.setKeys[dst] = copySetKey(srcDB.setKeys[src]) + case "zset": + destDB.sortedsetKeys[dst] = copySortedSet(srcDB.sortedsetKeys[src]) + case "stream": + destDB.streamKeys[dst] = srcDB.streamKeys[src].copy() + case "hll": + destDB.hllKeys[dst] = srcDB.hllKeys[src].copy() + default: + panic("missing case") + } + destDB.keys[dst] = srcDB.keys[src] + destDB.keyVersion[dst]++ + if v, ok := srcDB.ttl[src]; ok { + destDB.ttl[dst] = v + } + return nil +} + +func copyHashKey(orig hashKey) hashKey { + cpy := hashKey{} + for k, v := range orig { + cpy[k] = v + } + return cpy +} + +func copySetKey(orig setKey) setKey { + cpy := setKey{} + for k, v := range orig { + cpy[k] = v + } + return cpy +} + +func copySortedSet(orig sortedSet) sortedSet { + cpy := sortedSet{} + for k, v := range orig { + cpy[k] = v + } + return cpy +} diff --git a/redis.go b/redis.go index e870f2c3..ff869d8f 100644 --- a/redis.go +++ b/redis.go @@ -45,6 +45,7 @@ const ( msgXtrimInvalidStrategy = "ERR unsupported XTRIM strategy. Please use MAXLEN, MINID" msgXtrimInvalidMaxLen = "ERR value is not an integer or out of range" msgXtrimInvalidLimit = "ERR syntax error, LIMIT cannot be used without the special ~ option" + msgDBIndexOutOfRange = "ERR DB index is out of range" ) func errWrongNumber(cmd string) string { diff --git a/stream.go b/stream.go index 0c75f8ab..d748cfa6 100644 --- a/stream.go +++ b/stream.go @@ -71,6 +71,20 @@ func (s *streamKey) lastID() string { return s.entries[len(s.entries)-1].ID } +func (s *streamKey) copy() *streamKey { + cpy := &streamKey{ + entries: s.entries, + } + groups := map[string]*streamGroup{} + for k, v := range s.groups { + gr := v.copy() + gr.stream = cpy + groups[k] = gr + } + cpy.groups = groups + return cpy +} + func parseStreamID(id string) ([2]uint64, error) { var ( res [2]uint64 @@ -347,3 +361,16 @@ func (g *streamGroup) pendingCount(consumer string) int { } return n } + +func (g *streamGroup) copy() *streamGroup { + cns := map[string]consumer{} + for k, v := range g.consumers { + cns[k] = v + } + return &streamGroup{ + // don't copy stream + lastID: g.lastID, + pending: g.pending, + consumers: cns, + } +}