Skip to content

Commit

Permalink
arg parse cleanup (#360)
Browse files Browse the repository at this point in the history
Splits out the opts parsing for some complex cases, to avoid duplicate error logic. And easier testing.
  • Loading branch information
alicebob committed Mar 23, 2024
1 parent 6b28640 commit 42c2a2c
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 112 deletions.
232 changes: 120 additions & 112 deletions cmd_generic.go
Expand Up @@ -3,6 +3,7 @@
package miniredis

import (
"errors"
"fmt"
"sort"
"strconv"
Expand Down Expand Up @@ -60,6 +61,47 @@ func commandsGeneric(m *Miniredis) {
m.srv.Register("UNLINK", m.cmdDel)
}

type expireOpts struct {
key string
value int
nx bool
xx bool
gt bool
lt bool
}

func expireParse(cmd string, args []string) (*expireOpts, error) {
var opts expireOpts

opts.key = args[0]
if err := optIntSimple(args[1], &opts.value); err != nil {
return nil, err
}
args = args[2:]
for len(args) > 0 {
switch strings.ToLower(args[0]) {
case "nx":
opts.nx = true
case "xx":
opts.xx = true
case "gt":
opts.gt = true
case "lt":
opts.lt = true
default:
return nil, fmt.Errorf("ERR Unsupported option %s", args[0])
}
args = args[1:]
}
if opts.gt && opts.lt {
return nil, errors.New("ERR GT and LT options at the same time are not compatible")
}
if opts.nx && (opts.xx || opts.gt || opts.lt) {
return nil, errors.New("ERR NX and XX, GT or LT options at the same time are not compatible")
}
return &opts, nil
}

// generic expire command for EXPIRE, PEXPIRE, EXPIREAT, PEXPIREAT
// d is the time unit. If unix is set it'll be seen as a unixtimestamp and
// converted to a duration.
Expand All @@ -77,44 +119,10 @@ func makeCmdExpire(m *Miniredis, unix bool, d time.Duration) func(*server.Peer,
return
}

var opts struct {
key string
value int
nx bool
xx bool
gt bool
lt bool
}
opts.key = args[0]
if ok := optInt(c, args[1], &opts.value); !ok {
return
}
args = args[2:]
for len(args) > 0 {
switch strings.ToLower(args[0]) {
case "nx":
opts.nx = true
case "xx":
opts.xx = true
case "gt":
opts.gt = true
case "lt":
opts.lt = true
default:
setDirty(c)
c.WriteError(fmt.Sprintf("ERR Unsupported option %s", args[0]))
return
}
args = args[1:]
}
if opts.gt && opts.lt {
setDirty(c)
c.WriteError("ERR GT and LT options at the same time are not compatible")
return
}
if opts.nx && (opts.xx || opts.gt || opts.lt) {
opts, err := expireParse(cmd, args)
if err != nil {
setDirty(c)
c.WriteError("ERR NX and XX, GT or LT options at the same time are not compatible")
c.WriteError(err.Error())
return
}

Expand Down Expand Up @@ -597,79 +605,78 @@ func (m *Miniredis) cmdRenamenx(c *server.Peer, cmd string, args []string) {
})
}

// SCAN
func (m *Miniredis) cmdScan(c *server.Peer, cmd string, args []string) {
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}

var opts struct {
cursor int
count int
withMatch bool
match string
withType bool
_type string
}
type scanOpts struct {
cursor int
count int
withMatch bool
match string
withType bool
_type string
}

if ok := optIntErr(c, args[0], &opts.cursor, msgInvalidCursor); !ok {
return
func scanParse(cmd string, args []string) (*scanOpts, error) {
var opts scanOpts
if err := optIntSimple(args[0], &opts.cursor); err != nil {
return nil, errors.New(msgInvalidCursor)
}
args = args[1:]

// MATCH, COUNT and TYPE options
for len(args) > 0 {
if strings.ToLower(args[0]) == "count" {
if len(args) < 2 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
return nil, errors.New(msgSyntaxError)
}
count, err := strconv.Atoi(args[1])
if err != nil || count < 0 {
setDirty(c)
c.WriteError(msgInvalidInt)
return
return nil, errors.New(msgInvalidInt)
}
if count == 0 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
return nil, errors.New(msgSyntaxError)
}
opts.count = count
args = args[2:]
continue
}
if strings.ToLower(args[0]) == "match" {
if len(args) < 2 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
return nil, errors.New(msgSyntaxError)
}
opts.withMatch = true
opts.match, args = args[1], args[2:]
continue
}
if strings.ToLower(args[0]) == "type" {
if len(args) < 2 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
return nil, errors.New(msgSyntaxError)
}
opts.withType = true
opts._type, args = strings.ToLower(args[1]), args[2:]
continue
}
return nil, errors.New(msgSyntaxError)
}
return &opts, nil
}

// SCAN
func (m *Miniredis) cmdScan(c *server.Peer, cmd string, args []string) {
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}

opts, err := scanParse(cmd, args)
if err != nil {
setDirty(c)
c.WriteError(msgSyntaxError)
c.WriteError(err.Error())
return
}

Expand Down Expand Up @@ -724,26 +731,15 @@ func (m *Miniredis) cmdScan(c *server.Peer, cmd string, args []string) {
})
}

// COPY
func (m *Miniredis) cmdCopy(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
type copyOpts struct {
from string
to string
destinationDB int
replace bool
}

var opts = struct {
from string
to string
destinationDB int
replace bool
}{
func copyParse(cmd string, args []string) (*copyOpts, error) {
opts := copyOpts{
destinationDB: -1,
}

Expand All @@ -752,33 +748,45 @@ func (m *Miniredis) cmdCopy(c *server.Peer, cmd string, args []string) {
switch strings.ToLower(args[0]) {
case "db":
if len(args) < 2 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
return nil, errors.New(msgSyntaxError)
}
db, err := strconv.Atoi(args[1])
if err != nil {
setDirty(c)
c.WriteError(msgInvalidInt)
return
if err := optIntSimple(args[1], &opts.destinationDB); err != nil {
return nil, err
}
if db < 0 {
setDirty(c)
c.WriteError(msgDBIndexOutOfRange)
return
if opts.destinationDB < 0 {
return nil, errors.New(msgDBIndexOutOfRange)
}
opts.destinationDB = db
args = args[2:]
case "replace":
opts.replace = true
args = args[1:]
default:
setDirty(c)
c.WriteError(msgSyntaxError)
return
return nil, errors.New(msgSyntaxError)
}
}
return &opts, nil
}

// COPY
func (m *Miniredis) cmdCopy(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}

opts, err := copyParse(cmd, args)
if err != nil {
setDirty(c)
c.WriteError(err.Error())
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
fromDB, toDB := ctx.selectedDB, opts.destinationDB
if toDB == -1 {
Expand Down
24 changes: 24 additions & 0 deletions cmd_generic_test.go
Expand Up @@ -17,6 +17,14 @@ func TestTTL(t *testing.T) {
ok(t, err)
defer c.Close()

t.Run("parse", func(t *testing.T) {
t.Run("basic", func(t *testing.T) {
v, err := expireParse("SCAN", []string{"foo", "200"})
ok(t, err)
equals(t, expireOpts{key: "foo", value: 200}, *v)
})
})

// Not volatile yet
{
equals(t, time.Duration(0), s.TTL("foo"))
Expand Down Expand Up @@ -681,6 +689,14 @@ func TestScan(t *testing.T) {
ok(t, err)
defer c.Close()

t.Run("parse", func(t *testing.T) {
t.Run("basic", func(t *testing.T) {
v, err := scanParse("SCAN", []string{"0", "COUNT", "200"})
ok(t, err)
equals(t, scanOpts{count: 200}, *v)
})
})

// We cheat with scan. It always returns everything.

s.Set("key", "value")
Expand Down Expand Up @@ -896,6 +912,14 @@ func TestCopy(t *testing.T) {
ok(t, err)
defer c.Close()

t.Run("parse", func(t *testing.T) {
t.Run("basic", func(t *testing.T) {
v, err := copyParse("copy", []string{"key1", "key2"})
ok(t, err)
equals(t, copyOpts{from: "key1", to: "key2", destinationDB: -1}, *v)
})
})

t.Run("basic", func(t *testing.T) {
s.Set("key1", "value")
// should return 1 after a successful copy operation:
Expand Down
11 changes: 11 additions & 0 deletions opts.go
@@ -1,6 +1,7 @@
package miniredis

import (
"errors"
"math"
"strconv"
"time"
Expand All @@ -26,6 +27,16 @@ func optIntErr(c *server.Peer, src string, dest *int, errMsg string) bool {
return true
}

// optIntSimple sets dest or returns an error
func optIntSimple(src string, dest *int) error {
n, err := strconv.Atoi(src)
if err != nil {
return errors.New(msgInvalidInt)
}
*dest = n
return nil
}

func optDuration(c *server.Peer, src string, dest *time.Duration) bool {
n, err := strconv.ParseFloat(src, 64)
if err != nil {
Expand Down

0 comments on commit 42c2a2c

Please sign in to comment.