Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

From string validation #75

Merged
merged 4 commits into from Mar 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 11 additions & 0 deletions error.go
@@ -0,0 +1,11 @@
package xid

const (
// ErrInvalidID is returned when trying to unmarshal an invalid ID.
ErrInvalidID strErr = "xid: invalid ID"
)

// strErr allows declaring errors as constants.
type strErr string

func (err strErr) Error() string { return string(err) }
22 changes: 15 additions & 7 deletions id.go
Expand Up @@ -47,7 +47,6 @@ import (
"crypto/rand"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"hash/crc32"
"io/ioutil"
Expand All @@ -73,9 +72,6 @@ const (
)

var (
// ErrInvalidID is returned when trying to unmarshal an invalid ID
ErrInvalidID = errors.New("xid: invalid ID")

// objectIDCounter is atomically incremented when generating a new ObjectId
// using NewObjectId() function. It's used as a counter part of an id.
// This id is initialized with a random value.
Expand Down Expand Up @@ -242,7 +238,9 @@ func (id *ID) UnmarshalText(text []byte) error {
return ErrInvalidID
}
}
decode(id, text)
if !decode(id, text) {
return ErrInvalidID
}
return nil
}

Expand All @@ -260,8 +258,8 @@ func (id *ID) UnmarshalJSON(b []byte) error {
return id.UnmarshalText(b[1 : len(b)-1])
}

// decode by unrolling the stdlib base32 algorithm + removing all safe checks
func decode(id *ID, src []byte) {
// decode by unrolling the stdlib base32 algorithm + customized safe check.
func decode(id *ID, src []byte) bool {
_ = src[19]
_ = id[11]

Expand All @@ -277,6 +275,16 @@ func decode(id *ID, src []byte) {
id[2] = dec[src[3]]<<4 | dec[src[4]]>>1
id[1] = dec[src[1]]<<6 | dec[src[2]]<<1 | dec[src[3]]>>4
id[0] = dec[src[0]]<<3 | dec[src[1]]>>2

// Validate that there are no discarer bits (padding) in src that would
// cause the string-encoded id not to equal src.
var check [4]byte

check[3] = encoding[(id[11]<<4)&0x1F]
check[2] = encoding[(id[11]>>1)&0x1F]
check[1] = encoding[(id[11]>>6)&0x1F|(id[10]<<2)&0x1F]
check[0] = encoding[id[10]>>3]
return bytes.Equal([]byte(src[16:20]), check[:])
}

// Time returns the timestamp part of the id.
Expand Down
64 changes: 60 additions & 4 deletions id_test.go
Expand Up @@ -5,8 +5,10 @@ import (
"encoding/json"
"errors"
"fmt"
"math/rand"
"reflect"
"testing"
"testing/quick"
"time"
)

Expand All @@ -19,21 +21,21 @@ type IDParts struct {
}

var IDs = []IDParts{
IDParts{
{
ID{0x4d, 0x88, 0xe1, 0x5b, 0x60, 0xf4, 0x86, 0xe4, 0x28, 0x41, 0x2d, 0xc9},
1300816219,
[]byte{0x60, 0xf4, 0x86},
0xe428,
4271561,
},
IDParts{
{
ID{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
0,
[]byte{0x00, 0x00, 0x00},
0x0000,
0,
},
IDParts{
{
ID{0x00, 0x00, 0x00, 0x00, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0x00, 0x00, 0x01},
0,
[]byte{0xaa, 0xbb, 0xcc},
Expand Down Expand Up @@ -252,6 +254,60 @@ func BenchmarkFromString(b *testing.B) {
})
}

func TestFromStringQuick(t *testing.T) {
f := func(id1 ID, c byte) bool {
s1 := id1.String()
for i := range s1 {
s2 := []byte(s1)
s2[i] = c
id2, err := FromString(string(s2))
if id1 == id2 && err == nil && c != s1[i] {
t.Logf("comparing XIDs:\na: %q\nb: %q (index %d changed to %c)", s1, s2, i, c)
return false
}
}
return true
}
err := quick.Check(f, &quick.Config{
Values: func(args []reflect.Value, r *rand.Rand) {
i := r.Intn(len(encoding))
args[0] = reflect.ValueOf(New())
args[1] = reflect.ValueOf(byte(encoding[i]))
},
MaxCount: 1000,
})
if err != nil {
t.Error(err)
}
}

func TestFromStringQuickInvalidChars(t *testing.T) {
f := func(id1 ID, c byte) bool {
s1 := id1.String()
for i := range s1 {
s2 := []byte(s1)
s2[i] = c
id2, err := FromString(string(s2))
if id1 == id2 && err == nil && c != s1[i] {
t.Logf("comparing XIDs:\na: %q\nb: %q (index %d changed to %c)", s1, s2, i, c)
return false
}
}
return true
}
err := quick.Check(f, &quick.Config{
Values: func(args []reflect.Value, r *rand.Rand) {
i := r.Intn(0xFF)
args[0] = reflect.ValueOf(New())
args[1] = reflect.ValueOf(byte(i))
},
MaxCount: 2000,
})
if err != nil {
t.Error(err)
}
}

// func BenchmarkUUIDv1(b *testing.B) {
// b.RunParallel(func(pb *testing.PB) {
// for pb.Next() {
Expand Down Expand Up @@ -329,7 +385,7 @@ func TestFromBytes_InvalidBytes(t *testing.T) {
{13, true},
}
for _, c := range cases {
b := make([]byte, c.length, c.length)
b := make([]byte, c.length)
_, err := FromBytes(b)
if got, want := err != nil, c.shouldFail; got != want {
t.Errorf("FromBytes() error got %v, want %v", got, want)
Expand Down