From 42c2a2c1e075bd381469314978b27e3ba83ad5d2 Mon Sep 17 00:00:00 2001 From: Harmen Date: Sat, 23 Mar 2024 09:08:16 +0100 Subject: [PATCH] arg parse cleanup (#360) Splits out the opts parsing for some complex cases, to avoid duplicate error logic. And easier testing. --- cmd_generic.go | 232 +++++++++++++++++++++++--------------------- cmd_generic_test.go | 24 +++++ opts.go | 11 +++ 3 files changed, 155 insertions(+), 112 deletions(-) diff --git a/cmd_generic.go b/cmd_generic.go index f3fd604e..9d1adba0 100644 --- a/cmd_generic.go +++ b/cmd_generic.go @@ -3,6 +3,7 @@ package miniredis import ( + "errors" "fmt" "sort" "strconv" @@ -60,6 +61,47 @@ func commandsGeneric(m *Miniredis) { m.srv.Register("UNLINK", m.cmdDel) } +type expireOpts struct { + key string + value int + nx bool + xx bool + gt bool + lt bool +} + +func expireParse(cmd string, args []string) (*expireOpts, error) { + var opts expireOpts + + opts.key = args[0] + if err := optIntSimple(args[1], &opts.value); err != nil { + return nil, err + } + args = args[2:] + for len(args) > 0 { + switch strings.ToLower(args[0]) { + case "nx": + opts.nx = true + case "xx": + opts.xx = true + case "gt": + opts.gt = true + case "lt": + opts.lt = true + default: + return nil, fmt.Errorf("ERR Unsupported option %s", args[0]) + } + args = args[1:] + } + if opts.gt && opts.lt { + return nil, errors.New("ERR GT and LT options at the same time are not compatible") + } + if opts.nx && (opts.xx || opts.gt || opts.lt) { + return nil, errors.New("ERR NX and XX, GT or LT options at the same time are not compatible") + } + return &opts, nil +} + // generic expire command for EXPIRE, PEXPIRE, EXPIREAT, PEXPIREAT // d is the time unit. If unix is set it'll be seen as a unixtimestamp and // converted to a duration. @@ -77,44 +119,10 @@ func makeCmdExpire(m *Miniredis, unix bool, d time.Duration) func(*server.Peer, return } - var opts struct { - key string - value int - nx bool - xx bool - gt bool - lt bool - } - opts.key = args[0] - if ok := optInt(c, args[1], &opts.value); !ok { - return - } - args = args[2:] - for len(args) > 0 { - switch strings.ToLower(args[0]) { - case "nx": - opts.nx = true - case "xx": - opts.xx = true - case "gt": - opts.gt = true - case "lt": - opts.lt = true - default: - setDirty(c) - c.WriteError(fmt.Sprintf("ERR Unsupported option %s", args[0])) - return - } - args = args[1:] - } - if opts.gt && opts.lt { - setDirty(c) - c.WriteError("ERR GT and LT options at the same time are not compatible") - return - } - if opts.nx && (opts.xx || opts.gt || opts.lt) { + opts, err := expireParse(cmd, args) + if err != nil { setDirty(c) - c.WriteError("ERR NX and XX, GT or LT options at the same time are not compatible") + c.WriteError(err.Error()) return } @@ -597,31 +605,19 @@ func (m *Miniredis) cmdRenamenx(c *server.Peer, cmd string, args []string) { }) } -// SCAN -func (m *Miniredis) cmdScan(c *server.Peer, cmd string, args []string) { - if len(args) < 1 { - setDirty(c) - c.WriteError(errWrongNumber(cmd)) - return - } - if !m.handleAuth(c) { - return - } - if m.checkPubsub(c, cmd) { - return - } - - var opts struct { - cursor int - count int - withMatch bool - match string - withType bool - _type string - } +type scanOpts struct { + cursor int + count int + withMatch bool + match string + withType bool + _type string +} - if ok := optIntErr(c, args[0], &opts.cursor, msgInvalidCursor); !ok { - return +func scanParse(cmd string, args []string) (*scanOpts, error) { + var opts scanOpts + if err := optIntSimple(args[0], &opts.cursor); err != nil { + return nil, errors.New(msgInvalidCursor) } args = args[1:] @@ -629,20 +625,14 @@ func (m *Miniredis) cmdScan(c *server.Peer, cmd string, args []string) { for len(args) > 0 { if strings.ToLower(args[0]) == "count" { if len(args) < 2 { - setDirty(c) - c.WriteError(msgSyntaxError) - return + return nil, errors.New(msgSyntaxError) } count, err := strconv.Atoi(args[1]) if err != nil || count < 0 { - setDirty(c) - c.WriteError(msgInvalidInt) - return + return nil, errors.New(msgInvalidInt) } if count == 0 { - setDirty(c) - c.WriteError(msgSyntaxError) - return + return nil, errors.New(msgSyntaxError) } opts.count = count args = args[2:] @@ -650,9 +640,7 @@ func (m *Miniredis) cmdScan(c *server.Peer, cmd string, args []string) { } if strings.ToLower(args[0]) == "match" { if len(args) < 2 { - setDirty(c) - c.WriteError(msgSyntaxError) - return + return nil, errors.New(msgSyntaxError) } opts.withMatch = true opts.match, args = args[1], args[2:] @@ -660,16 +648,35 @@ func (m *Miniredis) cmdScan(c *server.Peer, cmd string, args []string) { } if strings.ToLower(args[0]) == "type" { if len(args) < 2 { - setDirty(c) - c.WriteError(msgSyntaxError) - return + return nil, errors.New(msgSyntaxError) } opts.withType = true opts._type, args = strings.ToLower(args[1]), args[2:] continue } + return nil, errors.New(msgSyntaxError) + } + return &opts, nil +} + +// SCAN +func (m *Miniredis) cmdScan(c *server.Peer, cmd string, args []string) { + if len(args) < 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + if m.checkPubsub(c, cmd) { + return + } + + opts, err := scanParse(cmd, args) + if err != nil { setDirty(c) - c.WriteError(msgSyntaxError) + c.WriteError(err.Error()) return } @@ -724,26 +731,15 @@ func (m *Miniredis) cmdScan(c *server.Peer, cmd string, args []string) { }) } -// COPY -func (m *Miniredis) cmdCopy(c *server.Peer, cmd string, args []string) { - if len(args) < 2 { - setDirty(c) - c.WriteError(errWrongNumber(cmd)) - return - } - if !m.handleAuth(c) { - return - } - if m.checkPubsub(c, cmd) { - return - } +type copyOpts struct { + from string + to string + destinationDB int + replace bool +} - var opts = struct { - from string - to string - destinationDB int - replace bool - }{ +func copyParse(cmd string, args []string) (*copyOpts, error) { + opts := copyOpts{ destinationDB: -1, } @@ -752,33 +748,45 @@ func (m *Miniredis) cmdCopy(c *server.Peer, cmd string, args []string) { switch strings.ToLower(args[0]) { case "db": if len(args) < 2 { - setDirty(c) - c.WriteError(msgSyntaxError) - return + return nil, errors.New(msgSyntaxError) } - db, err := strconv.Atoi(args[1]) - if err != nil { - setDirty(c) - c.WriteError(msgInvalidInt) - return + if err := optIntSimple(args[1], &opts.destinationDB); err != nil { + return nil, err } - if db < 0 { - setDirty(c) - c.WriteError(msgDBIndexOutOfRange) - return + if opts.destinationDB < 0 { + return nil, errors.New(msgDBIndexOutOfRange) } - opts.destinationDB = db args = args[2:] case "replace": opts.replace = true args = args[1:] default: - setDirty(c) - c.WriteError(msgSyntaxError) - return + return nil, errors.New(msgSyntaxError) } } + return &opts, nil +} +// COPY +func (m *Miniredis) cmdCopy(c *server.Peer, cmd string, args []string) { + if len(args) < 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + if m.checkPubsub(c, cmd) { + return + } + + opts, err := copyParse(cmd, args) + if err != nil { + setDirty(c) + c.WriteError(err.Error()) + return + } withTx(m, c, func(c *server.Peer, ctx *connCtx) { fromDB, toDB := ctx.selectedDB, opts.destinationDB if toDB == -1 { diff --git a/cmd_generic_test.go b/cmd_generic_test.go index 2f47b4d9..52f7df3c 100644 --- a/cmd_generic_test.go +++ b/cmd_generic_test.go @@ -17,6 +17,14 @@ func TestTTL(t *testing.T) { ok(t, err) defer c.Close() + t.Run("parse", func(t *testing.T) { + t.Run("basic", func(t *testing.T) { + v, err := expireParse("SCAN", []string{"foo", "200"}) + ok(t, err) + equals(t, expireOpts{key: "foo", value: 200}, *v) + }) + }) + // Not volatile yet { equals(t, time.Duration(0), s.TTL("foo")) @@ -681,6 +689,14 @@ func TestScan(t *testing.T) { ok(t, err) defer c.Close() + t.Run("parse", func(t *testing.T) { + t.Run("basic", func(t *testing.T) { + v, err := scanParse("SCAN", []string{"0", "COUNT", "200"}) + ok(t, err) + equals(t, scanOpts{count: 200}, *v) + }) + }) + // We cheat with scan. It always returns everything. s.Set("key", "value") @@ -896,6 +912,14 @@ func TestCopy(t *testing.T) { ok(t, err) defer c.Close() + t.Run("parse", func(t *testing.T) { + t.Run("basic", func(t *testing.T) { + v, err := copyParse("copy", []string{"key1", "key2"}) + ok(t, err) + equals(t, copyOpts{from: "key1", to: "key2", destinationDB: -1}, *v) + }) + }) + t.Run("basic", func(t *testing.T) { s.Set("key1", "value") // should return 1 after a successful copy operation: diff --git a/opts.go b/opts.go index 666ace7f..5b29c78c 100644 --- a/opts.go +++ b/opts.go @@ -1,6 +1,7 @@ package miniredis import ( + "errors" "math" "strconv" "time" @@ -26,6 +27,16 @@ func optIntErr(c *server.Peer, src string, dest *int, errMsg string) bool { return true } +// optIntSimple sets dest or returns an error +func optIntSimple(src string, dest *int) error { + n, err := strconv.Atoi(src) + if err != nil { + return errors.New(msgInvalidInt) + } + *dest = n + return nil +} + func optDuration(c *server.Peer, src string, dest *time.Duration) bool { n, err := strconv.ParseFloat(src, 64) if err != nil {