From 66f8c42da230c3323ed4e29805e73eefbad41fc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sindre=20R=C3=B8kenes=20Myren?= Date: Thu, 10 Mar 2022 23:51:57 +0100 Subject: [PATCH] From string validation (#75) * chore: address linter issues Perform changes suggested by linter; no functional changes. * error.go: make ErrInvalidID a const Make ErrInvalidID a constant instead of a variable. This prevent it from being changed by external packages; a behavior that although allowed by the compiler, should probably be considered an invalid operation. * add benchmark and new failing test for FromString * fix: let decode look for additional base32 padding Update FromString and XID.TextUnmarshal so that it looks for discarded bits in the final source character. This ensures that XIDs that have been manually tampered with in a way that's ignored by base32 decode, will not pass as valid. --- error.go | 11 ++++++++++ id.go | 22 +++++++++++++------ id_test.go | 64 ++++++++++++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 86 insertions(+), 11 deletions(-) create mode 100644 error.go diff --git a/error.go b/error.go new file mode 100644 index 0000000..ea25374 --- /dev/null +++ b/error.go @@ -0,0 +1,11 @@ +package xid + +const ( + // ErrInvalidID is returned when trying to unmarshal an invalid ID. + ErrInvalidID strErr = "xid: invalid ID" +) + +// strErr allows declaring errors as constants. +type strErr string + +func (err strErr) Error() string { return string(err) } diff --git a/id.go b/id.go index eaa6767..1f536b4 100644 --- a/id.go +++ b/id.go @@ -47,7 +47,6 @@ import ( "crypto/rand" "database/sql/driver" "encoding/binary" - "errors" "fmt" "hash/crc32" "io/ioutil" @@ -73,9 +72,6 @@ const ( ) var ( - // ErrInvalidID is returned when trying to unmarshal an invalid ID - ErrInvalidID = errors.New("xid: invalid ID") - // objectIDCounter is atomically incremented when generating a new ObjectId // using NewObjectId() function. It's used as a counter part of an id. // This id is initialized with a random value. @@ -242,7 +238,9 @@ func (id *ID) UnmarshalText(text []byte) error { return ErrInvalidID } } - decode(id, text) + if !decode(id, text) { + return ErrInvalidID + } return nil } @@ -260,8 +258,8 @@ func (id *ID) UnmarshalJSON(b []byte) error { return id.UnmarshalText(b[1 : len(b)-1]) } -// decode by unrolling the stdlib base32 algorithm + removing all safe checks -func decode(id *ID, src []byte) { +// decode by unrolling the stdlib base32 algorithm + customized safe check. +func decode(id *ID, src []byte) bool { _ = src[19] _ = id[11] @@ -277,6 +275,16 @@ func decode(id *ID, src []byte) { id[2] = dec[src[3]]<<4 | dec[src[4]]>>1 id[1] = dec[src[1]]<<6 | dec[src[2]]<<1 | dec[src[3]]>>4 id[0] = dec[src[0]]<<3 | dec[src[1]]>>2 + + // Validate that there are no discarer bits (padding) in src that would + // cause the string-encoded id not to equal src. + var check [4]byte + + check[3] = encoding[(id[11]<<4)&0x1F] + check[2] = encoding[(id[11]>>1)&0x1F] + check[1] = encoding[(id[11]>>6)&0x1F|(id[10]<<2)&0x1F] + check[0] = encoding[id[10]>>3] + return bytes.Equal([]byte(src[16:20]), check[:]) } // Time returns the timestamp part of the id. diff --git a/id_test.go b/id_test.go index 7c26817..c9ca659 100644 --- a/id_test.go +++ b/id_test.go @@ -5,8 +5,10 @@ import ( "encoding/json" "errors" "fmt" + "math/rand" "reflect" "testing" + "testing/quick" "time" ) @@ -19,21 +21,21 @@ type IDParts struct { } var IDs = []IDParts{ - IDParts{ + { ID{0x4d, 0x88, 0xe1, 0x5b, 0x60, 0xf4, 0x86, 0xe4, 0x28, 0x41, 0x2d, 0xc9}, 1300816219, []byte{0x60, 0xf4, 0x86}, 0xe428, 4271561, }, - IDParts{ + { ID{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, 0, []byte{0x00, 0x00, 0x00}, 0x0000, 0, }, - IDParts{ + { ID{0x00, 0x00, 0x00, 0x00, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0x00, 0x00, 0x01}, 0, []byte{0xaa, 0xbb, 0xcc}, @@ -252,6 +254,60 @@ func BenchmarkFromString(b *testing.B) { }) } +func TestFromStringQuick(t *testing.T) { + f := func(id1 ID, c byte) bool { + s1 := id1.String() + for i := range s1 { + s2 := []byte(s1) + s2[i] = c + id2, err := FromString(string(s2)) + if id1 == id2 && err == nil && c != s1[i] { + t.Logf("comparing XIDs:\na: %q\nb: %q (index %d changed to %c)", s1, s2, i, c) + return false + } + } + return true + } + err := quick.Check(f, &quick.Config{ + Values: func(args []reflect.Value, r *rand.Rand) { + i := r.Intn(len(encoding)) + args[0] = reflect.ValueOf(New()) + args[1] = reflect.ValueOf(byte(encoding[i])) + }, + MaxCount: 1000, + }) + if err != nil { + t.Error(err) + } +} + +func TestFromStringQuickInvalidChars(t *testing.T) { + f := func(id1 ID, c byte) bool { + s1 := id1.String() + for i := range s1 { + s2 := []byte(s1) + s2[i] = c + id2, err := FromString(string(s2)) + if id1 == id2 && err == nil && c != s1[i] { + t.Logf("comparing XIDs:\na: %q\nb: %q (index %d changed to %c)", s1, s2, i, c) + return false + } + } + return true + } + err := quick.Check(f, &quick.Config{ + Values: func(args []reflect.Value, r *rand.Rand) { + i := r.Intn(0xFF) + args[0] = reflect.ValueOf(New()) + args[1] = reflect.ValueOf(byte(i)) + }, + MaxCount: 2000, + }) + if err != nil { + t.Error(err) + } +} + // func BenchmarkUUIDv1(b *testing.B) { // b.RunParallel(func(pb *testing.PB) { // for pb.Next() { @@ -329,7 +385,7 @@ func TestFromBytes_InvalidBytes(t *testing.T) { {13, true}, } for _, c := range cases { - b := make([]byte, c.length, c.length) + b := make([]byte, c.length) _, err := FromBytes(b) if got, want := err != nil, c.shouldFail; got != want { t.Errorf("FromBytes() error got %v, want %v", got, want)