From 508eabfa166278eb264b02aa96e64cfdc406e463 Mon Sep 17 00:00:00 2001 From: ichigozero <61618751+ichigozero@users.noreply.github.com> Date: Fri, 10 Nov 2023 05:54:52 +0900 Subject: [PATCH] feat(random): use crypto/rand for random string generator (#55) --- random/random.go | 52 ++++++++++++++++++++++++++++++++++++------ random/random_test.go | 53 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 7 deletions(-) diff --git a/random/random.go b/random/random.go index 482d052..16deda2 100644 --- a/random/random.go +++ b/random/random.go @@ -1,13 +1,16 @@ package random import ( - "math/rand" + "bufio" + "crypto/rand" + "io" "strings" - "time" + "sync" ) type ( Random struct { + readerPool sync.Pool } ) @@ -27,8 +30,11 @@ var ( ) func New() *Random { - rand.Seed(time.Now().UnixNano()) - return new(Random) + // https://tip.golang.org/doc/go1.19#:~:text=Read%20no%20longer%20buffers%20random%20data%20obtained%20from%20the%20operating%20system%20between%20calls + p := sync.Pool{New: func() interface{} { + return bufio.NewReader(rand.Reader) + }} + return &Random{readerPool: p} } func (r *Random) String(length uint8, charsets ...string) string { @@ -36,11 +42,43 @@ func (r *Random) String(length uint8, charsets ...string) string { if charset == "" { charset = Alphanumeric } + + charsetLen := len(charset) + if charsetLen > 255 { + charsetLen = 255 + } + maxByte := 255 - (256 % charsetLen) + + reader := r.readerPool.Get().(*bufio.Reader) + defer r.readerPool.Put(reader) + b := make([]byte, length) - for i := range b { - b[i] = charset[rand.Int63()%int64(len(charset))] + rs := make([]byte, length+(length/4)) // perf: avoid read from rand.Reader many times + var i uint8 = 0 + + // security note: + // we can't just simply do b[i]=charset[rb%byte(charsetLen)], + // for example, when charsetLen is 52, and rb is [0, 255], 256 = 52 * 4 + 48. + // this will make the first 48 characters more possibly to be generated then others. + // so we have to skip bytes when rb > maxByte + + for { + _, err := io.ReadFull(reader, rs) + if err != nil { + panic("unexpected error happened when reading from bufio.NewReader(crypto/rand.Reader)") + } + for _, rb := range rs { + if rb > byte(maxByte) { + // Skip this number to avoid bias. + continue + } + b[i] = charset[rb%byte(charsetLen)] + i++ + if i == length { + return string(b) + } + } } - return string(b) } func String(length uint8, charsets ...string) string { diff --git a/random/random_test.go b/random/random_test.go index 6009b89..5927dcf 100644 --- a/random/random_test.go +++ b/random/random_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test(t *testing.T) { @@ -12,3 +13,55 @@ func Test(t *testing.T) { r := New() assert.Regexp(t, regexp.MustCompile("[0-9]+$"), r.String(8, Numeric)) } + +func TestRandomString(t *testing.T) { + var testCases = []struct { + name string + whenLength uint8 + expect string + }{ + { + name: "ok, 16", + whenLength: 16, + }, + { + name: "ok, 32", + whenLength: 32, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + uid := String(tc.whenLength, Alphabetic) + assert.Len(t, uid, int(tc.whenLength)) + }) + } +} + +func TestRandomStringBias(t *testing.T) { + t.Parallel() + const slen = 33 + const loop = 100000 + + counts := make(map[rune]int) + var count int64 + + for i := 0; i < loop; i++ { + s := String(slen, Alphabetic) + require.Equal(t, slen, len(s)) + for _, b := range s { + counts[b]++ + count++ + } + } + + require.Equal(t, len(Alphabetic), len(counts)) + + avg := float64(count) / float64(len(counts)) + for k, n := range counts { + diff := float64(n) / avg + if diff < 0.95 || diff > 1.05 { + t.Errorf("Bias on '%c': expected average %f, got %d", k, avg, n) + } + } +}