diff --git a/error.go b/error.go new file mode 100644 index 0000000..ea25374 --- /dev/null +++ b/error.go @@ -0,0 +1,11 @@ +package xid + +const ( + // ErrInvalidID is returned when trying to unmarshal an invalid ID. + ErrInvalidID strErr = "xid: invalid ID" +) + +// strErr allows declaring errors as constants. +type strErr string + +func (err strErr) Error() string { return string(err) } diff --git a/helper.go b/helper.go index ecf650c..d4905b7 100644 --- a/helper.go +++ b/helper.go @@ -1,11 +1,11 @@ package xid -// New().String() +// NewString New().String() func NewString() string { return New().String() } -// New().Bytes() +// NewBytes New().Bytes() func NewBytes() []byte { return New().Bytes() } diff --git a/hostid_darwin.go b/hostid_darwin.go index 08351ff..b18d52f 100644 --- a/hostid_darwin.go +++ b/hostid_darwin.go @@ -1,3 +1,4 @@ +//go:build darwin // +build darwin package xid diff --git a/hostid_fallback.go b/hostid_fallback.go index 7fbd3c0..f13e314 100644 --- a/hostid_fallback.go +++ b/hostid_fallback.go @@ -1,3 +1,4 @@ +//go:build !darwin && !linux && !freebsd && !windows // +build !darwin,!linux,!freebsd,!windows package xid diff --git a/hostid_freebsd.go b/hostid_freebsd.go index be25a03..2414598 100644 --- a/hostid_freebsd.go +++ b/hostid_freebsd.go @@ -1,3 +1,4 @@ +//go:build freebsd // +build freebsd package xid diff --git a/hostid_linux.go b/hostid_linux.go index 837b204..4319208 100644 --- a/hostid_linux.go +++ b/hostid_linux.go @@ -1,3 +1,4 @@ +//go:build linux // +build linux package xid diff --git a/hostid_windows.go b/hostid_windows.go index ec2593e..2af811d 100644 --- a/hostid_windows.go +++ b/hostid_windows.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package xid diff --git a/id.go b/id.go index eaa6767..1f536b4 100644 --- a/id.go +++ b/id.go @@ -47,7 +47,6 @@ import ( "crypto/rand" "database/sql/driver" "encoding/binary" - "errors" "fmt" "hash/crc32" "io/ioutil" @@ -73,9 +72,6 @@ const ( ) var ( - // ErrInvalidID is returned when trying to unmarshal an invalid ID - ErrInvalidID = errors.New("xid: invalid ID") - // objectIDCounter is atomically incremented when generating a new ObjectId // using NewObjectId() function. It's used as a counter part of an id. // This id is initialized with a random value. @@ -242,7 +238,9 @@ func (id *ID) UnmarshalText(text []byte) error { return ErrInvalidID } } - decode(id, text) + if !decode(id, text) { + return ErrInvalidID + } return nil } @@ -260,8 +258,8 @@ func (id *ID) UnmarshalJSON(b []byte) error { return id.UnmarshalText(b[1 : len(b)-1]) } -// decode by unrolling the stdlib base32 algorithm + removing all safe checks -func decode(id *ID, src []byte) { +// decode by unrolling the stdlib base32 algorithm + customized safe check. +func decode(id *ID, src []byte) bool { _ = src[19] _ = id[11] @@ -277,6 +275,16 @@ func decode(id *ID, src []byte) { id[2] = dec[src[3]]<<4 | dec[src[4]]>>1 id[1] = dec[src[1]]<<6 | dec[src[2]]<<1 | dec[src[3]]>>4 id[0] = dec[src[0]]<<3 | dec[src[1]]>>2 + + // Validate that there are no discarer bits (padding) in src that would + // cause the string-encoded id not to equal src. + var check [4]byte + + check[3] = encoding[(id[11]<<4)&0x1F] + check[2] = encoding[(id[11]>>1)&0x1F] + check[1] = encoding[(id[11]>>6)&0x1F|(id[10]<<2)&0x1F] + check[0] = encoding[id[10]>>3] + return bytes.Equal([]byte(src[16:20]), check[:]) } // Time returns the timestamp part of the id. diff --git a/id_test.go b/id_test.go index 7c26817..c9ca659 100644 --- a/id_test.go +++ b/id_test.go @@ -5,8 +5,10 @@ import ( "encoding/json" "errors" "fmt" + "math/rand" "reflect" "testing" + "testing/quick" "time" ) @@ -19,21 +21,21 @@ type IDParts struct { } var IDs = []IDParts{ - IDParts{ + { ID{0x4d, 0x88, 0xe1, 0x5b, 0x60, 0xf4, 0x86, 0xe4, 0x28, 0x41, 0x2d, 0xc9}, 1300816219, []byte{0x60, 0xf4, 0x86}, 0xe428, 4271561, }, - IDParts{ + { ID{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, 0, []byte{0x00, 0x00, 0x00}, 0x0000, 0, }, - IDParts{ + { ID{0x00, 0x00, 0x00, 0x00, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0x00, 0x00, 0x01}, 0, []byte{0xaa, 0xbb, 0xcc}, @@ -252,6 +254,60 @@ func BenchmarkFromString(b *testing.B) { }) } +func TestFromStringQuick(t *testing.T) { + f := func(id1 ID, c byte) bool { + s1 := id1.String() + for i := range s1 { + s2 := []byte(s1) + s2[i] = c + id2, err := FromString(string(s2)) + if id1 == id2 && err == nil && c != s1[i] { + t.Logf("comparing XIDs:\na: %q\nb: %q (index %d changed to %c)", s1, s2, i, c) + return false + } + } + return true + } + err := quick.Check(f, &quick.Config{ + Values: func(args []reflect.Value, r *rand.Rand) { + i := r.Intn(len(encoding)) + args[0] = reflect.ValueOf(New()) + args[1] = reflect.ValueOf(byte(encoding[i])) + }, + MaxCount: 1000, + }) + if err != nil { + t.Error(err) + } +} + +func TestFromStringQuickInvalidChars(t *testing.T) { + f := func(id1 ID, c byte) bool { + s1 := id1.String() + for i := range s1 { + s2 := []byte(s1) + s2[i] = c + id2, err := FromString(string(s2)) + if id1 == id2 && err == nil && c != s1[i] { + t.Logf("comparing XIDs:\na: %q\nb: %q (index %d changed to %c)", s1, s2, i, c) + return false + } + } + return true + } + err := quick.Check(f, &quick.Config{ + Values: func(args []reflect.Value, r *rand.Rand) { + i := r.Intn(0xFF) + args[0] = reflect.ValueOf(New()) + args[1] = reflect.ValueOf(byte(i)) + }, + MaxCount: 2000, + }) + if err != nil { + t.Error(err) + } +} + // func BenchmarkUUIDv1(b *testing.B) { // b.RunParallel(func(pb *testing.PB) { // for pb.Next() { @@ -329,7 +385,7 @@ func TestFromBytes_InvalidBytes(t *testing.T) { {13, true}, } for _, c := range cases { - b := make([]byte, c.length, c.length) + b := make([]byte, c.length) _, err := FromBytes(b) if got, want := err != nil, c.shouldFail; got != want { t.Errorf("FromBytes() error got %v, want %v", got, want)