Skip to content

Commit

Permalink
perf: From string validation (rs#75) @smyrman
Browse files Browse the repository at this point in the history
  • Loading branch information
fufuok committed Mar 12, 2022
1 parent e02d2e4 commit e89aee3
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 13 deletions.
11 changes: 11 additions & 0 deletions error.go
@@ -0,0 +1,11 @@
package xid

const (
// ErrInvalidID is returned when trying to unmarshal an invalid ID.
ErrInvalidID strErr = "xid: invalid ID"
)

// strErr allows declaring errors as constants.
type strErr string

func (err strErr) Error() string { return string(err) }
4 changes: 2 additions & 2 deletions helper.go
@@ -1,11 +1,11 @@
package xid

// New().String()
// NewString New().String()
func NewString() string {
return New().String()
}

// New().Bytes()
// NewBytes New().Bytes()
func NewBytes() []byte {
return New().Bytes()
}
1 change: 1 addition & 0 deletions hostid_darwin.go
@@ -1,3 +1,4 @@
//go:build darwin
// +build darwin

package xid
Expand Down
1 change: 1 addition & 0 deletions hostid_fallback.go
@@ -1,3 +1,4 @@
//go:build !darwin && !linux && !freebsd && !windows
// +build !darwin,!linux,!freebsd,!windows

package xid
Expand Down
1 change: 1 addition & 0 deletions hostid_freebsd.go
@@ -1,3 +1,4 @@
//go:build freebsd
// +build freebsd

package xid
Expand Down
1 change: 1 addition & 0 deletions hostid_linux.go
@@ -1,3 +1,4 @@
//go:build linux
// +build linux

package xid
Expand Down
1 change: 1 addition & 0 deletions hostid_windows.go
@@ -1,3 +1,4 @@
//go:build windows
// +build windows

package xid
Expand Down
22 changes: 15 additions & 7 deletions id.go
Expand Up @@ -47,7 +47,6 @@ import (
"crypto/rand"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"hash/crc32"
"io/ioutil"
Expand All @@ -73,9 +72,6 @@ const (
)

var (
// ErrInvalidID is returned when trying to unmarshal an invalid ID
ErrInvalidID = errors.New("xid: invalid ID")

// objectIDCounter is atomically incremented when generating a new ObjectId
// using NewObjectId() function. It's used as a counter part of an id.
// This id is initialized with a random value.
Expand Down Expand Up @@ -242,7 +238,9 @@ func (id *ID) UnmarshalText(text []byte) error {
return ErrInvalidID
}
}
decode(id, text)
if !decode(id, text) {
return ErrInvalidID
}
return nil
}

Expand All @@ -260,8 +258,8 @@ func (id *ID) UnmarshalJSON(b []byte) error {
return id.UnmarshalText(b[1 : len(b)-1])
}

// decode by unrolling the stdlib base32 algorithm + removing all safe checks
func decode(id *ID, src []byte) {
// decode by unrolling the stdlib base32 algorithm + customized safe check.
func decode(id *ID, src []byte) bool {
_ = src[19]
_ = id[11]

Expand All @@ -277,6 +275,16 @@ func decode(id *ID, src []byte) {
id[2] = dec[src[3]]<<4 | dec[src[4]]>>1
id[1] = dec[src[1]]<<6 | dec[src[2]]<<1 | dec[src[3]]>>4
id[0] = dec[src[0]]<<3 | dec[src[1]]>>2

// Validate that there are no discarer bits (padding) in src that would
// cause the string-encoded id not to equal src.
var check [4]byte

check[3] = encoding[(id[11]<<4)&0x1F]
check[2] = encoding[(id[11]>>1)&0x1F]
check[1] = encoding[(id[11]>>6)&0x1F|(id[10]<<2)&0x1F]
check[0] = encoding[id[10]>>3]
return bytes.Equal([]byte(src[16:20]), check[:])
}

// Time returns the timestamp part of the id.
Expand Down
64 changes: 60 additions & 4 deletions id_test.go
Expand Up @@ -5,8 +5,10 @@ import (
"encoding/json"
"errors"
"fmt"
"math/rand"
"reflect"
"testing"
"testing/quick"
"time"
)

Expand All @@ -19,21 +21,21 @@ type IDParts struct {
}

var IDs = []IDParts{
IDParts{
{
ID{0x4d, 0x88, 0xe1, 0x5b, 0x60, 0xf4, 0x86, 0xe4, 0x28, 0x41, 0x2d, 0xc9},
1300816219,
[]byte{0x60, 0xf4, 0x86},
0xe428,
4271561,
},
IDParts{
{
ID{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
0,
[]byte{0x00, 0x00, 0x00},
0x0000,
0,
},
IDParts{
{
ID{0x00, 0x00, 0x00, 0x00, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0x00, 0x00, 0x01},
0,
[]byte{0xaa, 0xbb, 0xcc},
Expand Down Expand Up @@ -252,6 +254,60 @@ func BenchmarkFromString(b *testing.B) {
})
}

func TestFromStringQuick(t *testing.T) {
f := func(id1 ID, c byte) bool {
s1 := id1.String()
for i := range s1 {
s2 := []byte(s1)
s2[i] = c
id2, err := FromString(string(s2))
if id1 == id2 && err == nil && c != s1[i] {
t.Logf("comparing XIDs:\na: %q\nb: %q (index %d changed to %c)", s1, s2, i, c)
return false
}
}
return true
}
err := quick.Check(f, &quick.Config{
Values: func(args []reflect.Value, r *rand.Rand) {
i := r.Intn(len(encoding))
args[0] = reflect.ValueOf(New())
args[1] = reflect.ValueOf(byte(encoding[i]))
},
MaxCount: 1000,
})
if err != nil {
t.Error(err)
}
}

func TestFromStringQuickInvalidChars(t *testing.T) {
f := func(id1 ID, c byte) bool {
s1 := id1.String()
for i := range s1 {
s2 := []byte(s1)
s2[i] = c
id2, err := FromString(string(s2))
if id1 == id2 && err == nil && c != s1[i] {
t.Logf("comparing XIDs:\na: %q\nb: %q (index %d changed to %c)", s1, s2, i, c)
return false
}
}
return true
}
err := quick.Check(f, &quick.Config{
Values: func(args []reflect.Value, r *rand.Rand) {
i := r.Intn(0xFF)
args[0] = reflect.ValueOf(New())
args[1] = reflect.ValueOf(byte(i))
},
MaxCount: 2000,
})
if err != nil {
t.Error(err)
}
}

// func BenchmarkUUIDv1(b *testing.B) {
// b.RunParallel(func(pb *testing.PB) {
// for pb.Next() {
Expand Down Expand Up @@ -329,7 +385,7 @@ func TestFromBytes_InvalidBytes(t *testing.T) {
{13, true},
}
for _, c := range cases {
b := make([]byte, c.length, c.length)
b := make([]byte, c.length)
_, err := FromBytes(b)
if got, want := err != nil, c.shouldFail; got != want {
t.Errorf("FromBytes() error got %v, want %v", got, want)
Expand Down

0 comments on commit e89aee3

Please sign in to comment.