diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6790326..97c9180 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -4,15 +4,15 @@ jobs: test: strategy: matrix: - go-version: [1.16.x, 1.17.x, 1.18.x] + go-version: [1.16.x, 1.17.x, 1.18.x, 1.19.x] os: [ubuntu-latest, macos-latest, windows-latest] runs-on: ${{ matrix.os }} steps: - - name: Install Go - uses: actions/setup-go@v2 - with: - go-version: ${{ matrix.go-version }} - - name: Checkout code - uses: actions/checkout@v2 - - name: Test - run: go test ./... -count=100 + - name: Install Go + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go-version }} + - name: Checkout code + uses: actions/checkout@v2 + - name: Test + run: go test ./... -count=100 diff --git a/borsh_test.go b/borsh_test.go index 6d16cde..763affe 100644 --- a/borsh_test.go +++ b/borsh_test.go @@ -1076,13 +1076,6 @@ type StructWithOptionalFields struct { Hello string } -func concatByteSlices(slices ...[]byte) (out []byte) { - for i := range slices { - out = append(out, slices[i]...) - } - return -} - type Struct struct { Foo string Bar uint32 @@ -1486,8 +1479,9 @@ type S struct { } func TestSet(t *testing.T) { + emptyStruct := struct{}{} x := S{ - S: map[int64]struct{}{124: struct{}{}, 214: struct{}{}, 24: struct{}{}, 53: struct{}{}}, + S: map[int64]struct{}{124: emptyStruct, 214: emptyStruct, 24: emptyStruct, 53: emptyStruct}, } data, err := MarshalBorsh(x) require.NoError(t, err) @@ -1653,3 +1647,112 @@ func TestCustomType(t *testing.T) { require.Equal(t, x, *y) } + +func TestStringSlice(t *testing.T) { + { + // slice: + x := []string{"a", "b", "c"} + data, err := MarshalBorsh(x) + require.NoError(t, err) + + require.Equal(t, concatByteSlices( + []byte{0x3, 0x0, 0x0, 0x0}, // length + + []byte{0x1, 0x0, 0x0, 0x0}, // length of first string + []byte("a"), + + []byte{0x1, 0x0, 0x0, 0x0}, // length of second string + []byte("b"), + + []byte{0x1, 0x0, 0x0, 0x0}, // length of third string + []byte("c"), + ), data) + + y := new([]string) + err = UnmarshalBorsh(y, data) + require.NoError(t, err) + + require.Equal(t, x, *y) + } + { + // string slice as field: + type S struct { + A []string + } + x := S{ + A: []string{"a", "b", "c"}, + } + data, err := MarshalBorsh(x) + require.NoError(t, err) + + require.Equal(t, concatByteSlices( + []byte{0x3, 0x0, 0x0, 0x0}, // length of A + + []byte{0x1, 0x0, 0x0, 0x0}, // length of A[0] + []byte("a"), + + []byte{0x1, 0x0, 0x0, 0x0}, // length of A[1] + []byte("b"), + + []byte{0x1, 0x0, 0x0, 0x0}, // length of A[2] + []byte("c"), + ), data) + + y := new(S) + err = UnmarshalBorsh(y, data) + require.NoError(t, err) + + require.Equal(t, x, *y) + } + { + // string slice as optional field (present): + type S struct { + A *[]string `bin:"optional"` + } + slice := []string{"a", "b", "c"} + x := S{ + A: &slice, + } + data, err := MarshalBorsh(x) + require.NoError(t, err) + + require.Equal(t, concatByteSlices( + []byte{0x01}, // optionality + []byte{0x3, 0x0, 0x0, 0x0}, // slice length + + []byte{0x1, 0x0, 0x0, 0x0}, // slice item length (string) + []byte("a"), + + []byte{0x1, 0x0, 0x0, 0x0}, // slice item length (string) + []byte("b"), + + []byte{0x1, 0x0, 0x0, 0x0}, // slice item length (string) + []byte("c"), + ), data) + + y := new(S) + err = UnmarshalBorsh(y, data) + require.NoError(t, err) + + require.Equal(t, x, *y) + } + { + // string slice as optional field (absent): + type S struct { + A *[]string `bin:"optional"` + } + x := S{} + data, err := MarshalBorsh(x) + require.NoError(t, err) + + require.Equal(t, concatByteSlices( + []byte{0x0}, // optionality + ), data) + + y := new(S) + err = UnmarshalBorsh(y, data) + require.NoError(t, err) + + require.Equal(t, x, *y) + } +} diff --git a/decoder.go b/decoder.go index d341e12..e563027 100644 --- a/decoder.go +++ b/decoder.go @@ -141,7 +141,7 @@ func sizeof(t reflect.Type, v reflect.Value) int { } return n default: - panic(fmt.Sprintf("sizeof field ")) + panic(fmt.Sprintf("sizeof field not implemented for kind %s", t.Kind())) } } @@ -184,7 +184,6 @@ func (dec *Decoder) ReadVarint32() (out int32, err error) { } func (dec *Decoder) ReadUvarint32() (out uint32, err error) { - n, err := dec.ReadUvarint64() if err != nil { return out, err @@ -195,6 +194,7 @@ func (dec *Decoder) ReadUvarint32() (out uint32, err error) { } return } + func (dec *Decoder) ReadVarint16() (out int16, err error) { n, err := dec.ReadVarint64() if err != nil { @@ -208,7 +208,6 @@ func (dec *Decoder) ReadVarint16() (out int16, err error) { } func (dec *Decoder) ReadUvarint16() (out uint16, err error) { - n, err := dec.ReadUvarint64() if err != nil { return out, err @@ -276,6 +275,12 @@ type peekAbleByteReader interface { } func readNBytes(n int, reader peekAbleByteReader) ([]byte, error) { + if n == 0 { + return make([]byte, 0), nil + } + if n < 0 || n > 0x7FFF_FFFF { + return nil, fmt.Errorf("invalid length n: %v", n) + } buf := make([]byte, n) for i := 0; i < n; i++ { b, err := reader.ReadByte() @@ -284,14 +289,27 @@ func readNBytes(n int, reader peekAbleByteReader) ([]byte, error) { } buf[i] = b } - return buf, nil } +func discardNBytes(n int, reader *Decoder) error { + if n == 0 { + return nil + } + if n < 0 || n > 0x7FFF_FFFF { + return fmt.Errorf("invalid length n: %v", n) + } + return reader.SkipBytes(uint(n)) +} + func (dec *Decoder) ReadNBytes(n int) (out []byte, err error) { return readNBytes(n, dec) } +func (dec *Decoder) Discard(n int) (err error) { + return discardNBytes(n, dec) +} + func (dec *Decoder) ReadTypeID() (out TypeID, err error) { discriminator, err := dec.ReadNBytes(8) if err != nil { @@ -349,7 +367,6 @@ func (dec *Decoder) ReadBool() (out bool, err error) { zlog.Debug("decode: read bool", zap.Bool("val", out)) } return - } func (dec *Decoder) ReadUint8() (out uint8, err error) { @@ -443,7 +460,6 @@ func (dec *Decoder) ReadInt128(order binary.ByteOrder) (out Int128, err error) { if err != nil { return } - return Int128(v), nil } @@ -517,7 +533,6 @@ func (dec *Decoder) ReadFloat128(order binary.ByteOrder) (out Float128, err erro if err != nil { return out, fmt.Errorf("float128: %s", err) } - return Float128(value), nil } @@ -551,6 +566,9 @@ func (dec *Decoder) ReadRustString() (out string, err error) { if err != nil { return "", err } + if length > 0x7FFF_FFFF { + return "", io.ErrUnexpectedEOF + } bytes, err := dec.ReadNBytes(int(length)) if err != nil { return "", err @@ -673,3 +691,113 @@ func indirect(v reflect.Value, decodingNull bool) (BinaryUnmarshaler, reflect.Va } return nil, v } + +func reflect_readArrayOfBytes(d *Decoder, l int, rv reflect.Value) error { + buf, err := d.ReadNBytes(l) + if err != nil { + return err + } + switch rv.Kind() { + case reflect.Array: + reflect.Copy(rv, reflect.ValueOf(buf)) + case reflect.Slice: + rv.Set(reflect.ValueOf(buf)) + default: + return fmt.Errorf("unsupported kind: %s", rv.Kind()) + } + return nil +} + +func reflect_readArrayOfUint16(d *Decoder, l int, rv reflect.Value, order binary.ByteOrder) error { + buf := make([]uint16, l) + for i := 0; i < l; i++ { + n, err := d.ReadUint16(order) + if err != nil { + return err + } + buf[i] = n + } + switch rv.Kind() { + case reflect.Array: + reflect.Copy(rv, reflect.ValueOf(buf)) + case reflect.Slice: + rv.Set(reflect.ValueOf(buf)) + default: + return fmt.Errorf("unsupported kind: %s", rv.Kind()) + } + return nil +} + +func reflect_readArrayOfUint32(d *Decoder, l int, rv reflect.Value, order binary.ByteOrder) error { + buf := make([]uint32, l) + for i := 0; i < l; i++ { + n, err := d.ReadUint32(order) + if err != nil { + return err + } + buf[i] = n + } + switch rv.Kind() { + case reflect.Array: + reflect.Copy(rv, reflect.ValueOf(buf)) + case reflect.Slice: + rv.Set(reflect.ValueOf(buf)) + default: + return fmt.Errorf("unsupported kind: %s", rv.Kind()) + } + return nil +} + +func reflect_readArrayOfUint64(d *Decoder, l int, rv reflect.Value, order binary.ByteOrder) error { + buf := make([]uint64, l) + for i := 0; i < l; i++ { + n, err := d.ReadUint64(order) + if err != nil { + return err + } + buf[i] = n + } + switch rv.Kind() { + case reflect.Array: + reflect.Copy(rv, reflect.ValueOf(buf)) + case reflect.Slice: + rv.Set(reflect.ValueOf(buf)) + default: + return fmt.Errorf("unsupported kind: %s", rv.Kind()) + } + return nil +} + +// reflect_readArrayOfUint_ is used for reading arrays/slices of uints of any size. +func reflect_readArrayOfUint_(d *Decoder, l int, k reflect.Kind, rv reflect.Value, order binary.ByteOrder) error { + switch k { + // case reflect.Uint: + // // switch on system architecture (32 or 64 bit) + // if unsafe.Sizeof(uintptr(0)) == 4 { + // return reflect_readArrayOfUint32( d, l, rv, order) + // } + // return reflect_readArrayOfUint64( d, l, rv, order) + case reflect.Uint8: + if l > d.Remaining() { + return io.ErrUnexpectedEOF + } + return reflect_readArrayOfBytes(d, l, rv) + case reflect.Uint16: + if l*2 > d.Remaining() { + return io.ErrUnexpectedEOF + } + return reflect_readArrayOfUint16(d, l, rv, order) + case reflect.Uint32: + if l*4 > d.Remaining() { + return io.ErrUnexpectedEOF + } + return reflect_readArrayOfUint32(d, l, rv, order) + case reflect.Uint64: + if l*8 > d.Remaining() { + return io.ErrUnexpectedEOF + } + return reflect_readArrayOfUint64(d, l, rv, order) + default: + return fmt.Errorf("unsupported kind: %v", k) + } +} diff --git a/decoder_bench_test.go b/decoder_bench_test.go new file mode 100644 index 0000000..5e81637 --- /dev/null +++ b/decoder_bench_test.go @@ -0,0 +1,419 @@ +package bin + +import ( + "reflect" + "testing" +) + +func newUint64SliceEncoded(l int) []byte { + buf := make([]byte, 0) + for i := 0; i < l; i++ { + buf = append(buf, uint64ToBytes(uint64(i), LE)...) + } + return buf +} + +func Benchmark_uintSlice64_Decode_noMake(b *testing.B) { + l := 1024 + buf := concatByteSlices( + // length: + uint32ToBytes(uint32(l), LE), + // data: + newUint64SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var got []uint64 + + decoder := NewBorshDecoder(buf) + err := decoder.Decode(&got) + if err != nil { + b.Error(err) + } + if len(got) != l { + b.Errorf("got %d, want %d", len(got), l) + } + } +} +func Benchmark_uintSlice64_Decode_make(b *testing.B) { + l := 1024 + buf := concatByteSlices( + // length: + uint32ToBytes(uint32(l), LE), + // data: + newUint64SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + got := make([]uint64, 0) + + decoder := NewBorshDecoder(buf) + err := decoder.Decode(&got) + if err != nil { + b.Error(err) + } + if len(got) != l { + b.Errorf("got %d, want %d", len(got), l) + } + } +} + +func Benchmark_uintSlice64_Decode_field_noMake(b *testing.B) { + l := 1024 + buf := concatByteSlices( + // length: + uint32ToBytes(uint32(l), LE), + // data: + newUint64SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + type S struct { + Field []uint64 + } + for i := 0; i < b.N; i++ { + var got S + + decoder := NewBorshDecoder(buf) + err := decoder.Decode(&got) + if err != nil { + b.Error(err) + } + if len(got.Field) != l { + b.Errorf("got %d, want %d", len(got.Field), l) + } + } +} + +func Benchmark_uintSlice64_Decode_field_make(b *testing.B) { + l := 1024 + buf := concatByteSlices( + // length: + uint32ToBytes(uint32(l), LE), + // data: + newUint64SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + type S struct { + Field []uint64 + } + for i := 0; i < b.N; i++ { + var got S + got.Field = make([]uint64, 0) + + decoder := NewBorshDecoder(buf) + err := decoder.Decode(&got) + if err != nil { + b.Error(err) + } + if len(got.Field) != l { + b.Errorf("got %d, want %d", len(got.Field), l) + } + } +} + +func Benchmark_uintSlice64_readArray_noMake(b *testing.B) { + l := 1024 + buf := concatByteSlices( + newUint64SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var got []uint64 + + decoder := NewBorshDecoder(buf) + rv := reflect.ValueOf(&got).Elem() + k := rv.Type().Elem().Kind() + + err := reflect_readArrayOfUint_(decoder, len(buf)/8, k, rv, LE) + if err != nil { + b.Error(err) + } + if len(got) != l { + b.Errorf("got %d, want %d", len(got), l) + } + } +} + +func Benchmark_uintSlice64_readArray_make(b *testing.B) { + l := 1024 + buf := concatByteSlices( + newUint64SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + got := make([]uint64, 0) + + decoder := NewBorshDecoder(buf) + rv := reflect.ValueOf(&got).Elem() + k := rv.Type().Elem().Kind() + + err := reflect_readArrayOfUint_(decoder, len(buf)/8, k, rv, LE) + if err != nil { + b.Error(err) + } + if len(got) != l { + b.Errorf("got %d, want %d", len(got), l) + } + } +} + +type sliceUint64WithCustomDecoder []uint64 + +// UnmarshalWithDecoder +func (s *sliceUint64WithCustomDecoder) UnmarshalWithDecoder(decoder *Decoder) error { + // read length + l, err := decoder.ReadUint32(LE) + if err != nil { + return err + } + // read data + *s = make([]uint64, l) + for i := 0; i < int(l); i++ { + (*s)[i], err = decoder.ReadUint64(LE) + if err != nil { + return err + } + } + return nil +} + +func Benchmark_uintSlice64_Decode_field_withCustomDecoder(b *testing.B) { + l := 1024 + buf := concatByteSlices( + // length: + uint32ToBytes(uint32(l), LE), + // data: + newUint64SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var got sliceUint64WithCustomDecoder + + decoder := NewBorshDecoder(buf) + err := got.UnmarshalWithDecoder(decoder) + if err != nil { + b.Error(err) + } + if len(got) != l { + b.Errorf("got %d, want %d", len(got), l) + } + } +} + +func newUint32SliceEncoded(l int) []byte { + buf := make([]byte, 0) + for i := 0; i < l; i++ { + buf = append(buf, uint32ToBytes(uint32(i), LE)...) + } + return buf +} + +func Benchmark_uintSlice32_Decode_noMake(b *testing.B) { + l := 1024 + buf := concatByteSlices( + // length: + uint32ToBytes(uint32(l), LE), + // data: + newUint32SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var got []uint32 + + decoder := NewBorshDecoder(buf) + err := decoder.Decode(&got) + if err != nil { + b.Error(err) + } + if len(got) != l { + b.Errorf("got %d, want %d", len(got), l) + } + } +} +func Benchmark_uintSlice32_Decode_make(b *testing.B) { + l := 1024 + buf := concatByteSlices( + // length: + uint32ToBytes(uint32(l), LE), + // data: + newUint32SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + got := make([]uint32, 0) + + decoder := NewBorshDecoder(buf) + err := decoder.Decode(&got) + if err != nil { + b.Error(err) + } + if len(got) != l { + b.Errorf("got %d, want %d", len(got), l) + } + } +} + +func Benchmark_uintSlice32_Decode_field_noMake(b *testing.B) { + l := 1024 + buf := concatByteSlices( + // length: + uint32ToBytes(uint32(l), LE), + // data: + newUint32SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + type S struct { + Field []uint32 + } + for i := 0; i < b.N; i++ { + var got S + + decoder := NewBorshDecoder(buf) + err := decoder.Decode(&got) + if err != nil { + b.Error(err) + } + if len(got.Field) != l { + b.Errorf("got %d, want %d", len(got.Field), l) + } + } +} + +func Benchmark_uintSlice32_Decode_field_make(b *testing.B) { + l := 1024 + buf := concatByteSlices( + // length: + uint32ToBytes(uint32(l), LE), + // data: + newUint32SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + type S struct { + Field []uint32 + } + for i := 0; i < b.N; i++ { + var got S + got.Field = make([]uint32, 0) + + decoder := NewBorshDecoder(buf) + err := decoder.Decode(&got) + if err != nil { + b.Error(err) + } + if len(got.Field) != l { + b.Errorf("got %d, want %d", len(got.Field), l) + } + } +} + +func Benchmark_uintSlice32_readArray_noMake(b *testing.B) { + l := 1024 + buf := concatByteSlices( + newUint32SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var got []uint32 + + decoder := NewBorshDecoder(buf) + rv := reflect.ValueOf(&got).Elem() + k := rv.Type().Elem().Kind() + + err := reflect_readArrayOfUint_(decoder, len(buf)/4, k, rv, LE) + if err != nil { + b.Error(err) + } + if len(got) != l { + b.Errorf("got %d, want %d", len(got), l) + } + } +} + +func Benchmark_uintSlice32_readArray_make(b *testing.B) { + l := 1024 + buf := concatByteSlices( + newUint32SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + got := make([]uint32, 0) + + decoder := NewBorshDecoder(buf) + rv := reflect.ValueOf(&got).Elem() + k := rv.Type().Elem().Kind() + + err := reflect_readArrayOfUint_(decoder, len(buf)/4, k, rv, LE) + if err != nil { + b.Error(err) + } + if len(got) != l { + b.Errorf("got %d, want %d", len(got), l) + } + } +} + +type sliceUint32WithCustomDecoder []uint32 + +// UnmarshalWithDecoder +func (s *sliceUint32WithCustomDecoder) UnmarshalWithDecoder(decoder *Decoder) error { + // read length + l, err := decoder.ReadUint32(LE) + if err != nil { + return err + } + // read data + *s = make([]uint32, l) + for i := 0; i < int(l); i++ { + (*s)[i], err = decoder.ReadUint32(LE) + if err != nil { + return err + } + } + return nil +} +func Benchmark_uintSlice32_Decode_field_withCustomDecoder(b *testing.B) { + l := 1024 + buf := concatByteSlices( + // length: + uint32ToBytes(uint32(l), LE), + // data: + newUint32SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var got sliceUint32WithCustomDecoder + + decoder := NewBorshDecoder(buf) + err := got.UnmarshalWithDecoder(decoder) + if err != nil { + b.Error(err) + } + if len(got) != l { + b.Errorf("got %d, want %d", len(got), l) + } + } +} diff --git a/decoder_bin.go b/decoder_bin.go index d4b12d9..14e6108 100644 --- a/decoder_bin.go +++ b/decoder_bin.go @@ -155,13 +155,21 @@ func (dec *Decoder) decodeBin(rv reflect.Value, opt *option) (err error) { } switch rt.Kind() { case reflect.Array: - length := rt.Len() + l := rt.Len() if traceEnabled { - zlog.Debug("decoding: reading array", zap.Int("length", length)) + zlog.Debug("decoding: reading array", zap.Int("length", l)) } - for i := 0; i < length; i++ { - if err = dec.decodeBin(rv.Index(i), nil); err != nil { - return + + switch k := rv.Type().Elem().Kind(); k { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if err := reflect_readArrayOfUint_(dec, l, k, rv, LE); err != nil { + return err + } + default: + for i := 0; i < l; i++ { + if err = dec.decodeBin(rv.Index(i), nil); err != nil { + return + } } } return @@ -185,10 +193,22 @@ func (dec *Decoder) decodeBin(rv reflect.Value, opt *option) (err error) { return io.ErrUnexpectedEOF } - rv.Set(reflect.MakeSlice(rt, l, l)) - for i := 0; i < l; i++ { - if err = dec.decodeBin(rv.Index(i), nil); err != nil { - return + switch k := rv.Type().Elem().Kind(); k { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if err := reflect_readArrayOfUint_(dec, l, k, rv, LE); err != nil { + return err + } + default: + rv.Set(reflect.MakeSlice(rt, 0, 0)) + for i := 0; i < l; i++ { + // create new element of type rt: + element := reflect.New(rt.Elem()) + // decode into element: + if err = dec.decodeBin(element, nil); err != nil { + return + } + // append to slice: + rv.Set(reflect.Append(rv, element.Elem())) } } diff --git a/decoder_borsh.go b/decoder_borsh.go index 4da42b5..81927ab 100644 --- a/decoder_borsh.go +++ b/decoder_borsh.go @@ -171,13 +171,21 @@ func (dec *Decoder) decodeBorsh(rv reflect.Value, opt *option) (err error) { } switch rt.Kind() { case reflect.Array: - length := rt.Len() + l := rt.Len() if traceEnabled { - zlog.Debug("decoding: reading array", zap.Int("length", length)) + zlog.Debug("decoding: reading array", zap.Int("length", l)) } - for i := 0; i < length; i++ { - if err = dec.decodeBorsh(rv.Index(i), nil); err != nil { - return + + switch k := rv.Type().Elem().Kind(); k { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if err := reflect_readArrayOfUint_(dec, l, k, rv, LE); err != nil { + return err + } + default: + for i := 0; i < l; i++ { + if err = dec.decodeBorsh(rv.Index(i), nil); err != nil { + return + } } } return @@ -205,10 +213,22 @@ func (dec *Decoder) decodeBorsh(rv reflect.Value, opt *option) (err error) { return io.ErrUnexpectedEOF } - rv.Set(reflect.MakeSlice(rt, l, l)) - for i := 0; i < l; i++ { - if err = dec.decodeBorsh(rv.Index(i), nil); err != nil { - return + switch k := rv.Type().Elem().Kind(); k { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if err := reflect_readArrayOfUint_(dec, l, k, rv, LE); err != nil { + return err + } + default: + rv.Set(reflect.MakeSlice(rt, 0, 0)) + for i := 0; i < l; i++ { + // create new element of type rt: + element := reflect.New(rt.Elem()) + // decode into element: + if err = dec.decodeBorsh(element, nil); err != nil { + return + } + // append to slice: + rv.Set(reflect.Append(rv, element.Elem())) } } diff --git a/decoder_compact-u16.go b/decoder_compact-u16.go index a7d4519..a5f3d38 100644 --- a/decoder_compact-u16.go +++ b/decoder_compact-u16.go @@ -154,13 +154,21 @@ func (dec *Decoder) decodeCompactU16(rv reflect.Value, opt *option) (err error) } switch rt.Kind() { case reflect.Array: - length := rt.Len() + l := rt.Len() if traceEnabled { - zlog.Debug("decoding: reading array", zap.Int("length", length)) + zlog.Debug("decoding: reading array", zap.Int("length", l)) } - for i := 0; i < length; i++ { - if err = dec.decodeCompactU16(rv.Index(i), nil); err != nil { - return + + switch k := rv.Type().Elem().Kind(); k { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if err := reflect_readArrayOfUint_(dec, l, k, rv, LE); err != nil { + return err + } + default: + for i := 0; i < l; i++ { + if err = dec.decodeCompactU16(rv.Index(i), nil); err != nil { + return + } } } return @@ -184,10 +192,22 @@ func (dec *Decoder) decodeCompactU16(rv reflect.Value, opt *option) (err error) return io.ErrUnexpectedEOF } - rv.Set(reflect.MakeSlice(rt, l, l)) - for i := 0; i < l; i++ { - if err = dec.decodeCompactU16(rv.Index(i), nil); err != nil { - return + switch k := rv.Type().Elem().Kind(); k { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if err := reflect_readArrayOfUint_(dec, l, k, rv, LE); err != nil { + return err + } + default: + rv.Set(reflect.MakeSlice(rt, 0, 0)) + for i := 0; i < l; i++ { + // create new element of type rt: + element := reflect.New(rt.Elem()) + // decode into element: + if err = dec.decodeCompactU16(element, nil); err != nil { + return + } + // append to slice: + rv.Set(reflect.Append(rv, element.Elem())) } } diff --git a/decoder_test.go b/decoder_test.go index a148e98..86ccfac 100644 --- a/decoder_test.go +++ b/decoder_test.go @@ -21,6 +21,7 @@ import ( "encoding/binary" "encoding/hex" "math" + "reflect" "testing" "github.com/stretchr/testify/assert" @@ -763,5 +764,513 @@ func TestDecoder_SkipBytes(t *testing.T) { err = decoder.SkipBytes(5) require.NoError(t, err) require.Equal(t, 0, decoder.Remaining()) +} + +func Test_Discard(t *testing.T) { + buf := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + decoder := NewBinDecoder(buf) + err := decoder.Discard(5) + require.NoError(t, err) + require.Equal(t, 5, decoder.Remaining()) + remaining, err := decoder.Peek(decoder.Remaining()) + require.NoError(t, err) + require.Equal(t, []byte{5, 6, 7, 8, 9}, remaining) +} +func Test_reflect_readArrayOfBytes(t *testing.T) { + { + { + buf := []byte{0, 1, 2, 3, 4, 5, 6, 7} + decoder := NewBinDecoder(buf) + + got := make([]byte, 0) + err := reflect_readArrayOfBytes(decoder, len(buf), reflect.ValueOf(&got).Elem()) + require.NoError(t, err) + require.Equal(t, buf, got) + } + { + buf := []byte{0, 1, 2, 3, 4, 5, 6, 7} + decoder := NewBinDecoder(buf) + + got := [8]byte{0, 0, 0, 0, 0, 0, 0, 0} + err := reflect_readArrayOfBytes(decoder, len(buf), reflect.ValueOf(&got).Elem()) + require.NoError(t, err) + require.Equal(t, buf, got[:]) + } + } + { + { + buf := []byte{0, 1, 2, 3, 4, 5, 6, 7} + decoder := NewBorshDecoder(buf) + + got := make([]byte, 0) + err := reflect_readArrayOfBytes(decoder, len(buf), reflect.ValueOf(&got).Elem()) + require.NoError(t, err) + require.Equal(t, buf, got) + } + { + buf := []byte{0, 1, 2, 3, 4, 5, 6, 7} + decoder := NewBorshDecoder(buf) + + got := [8]byte{0, 0, 0, 0, 0, 0, 0, 0} + err := reflect_readArrayOfBytes(decoder, len(buf), reflect.ValueOf(&got).Elem()) + require.NoError(t, err) + require.Equal(t, buf, got[:]) + } + } +} + +func Test_reflect_readArrayOfUint16(t *testing.T) { + { + { + buf := concatByteSlices( + uint16ToBytes(0, LE), + uint16ToBytes(1, LE), + uint16ToBytes(2, LE), + uint16ToBytes(3, LE), + uint16ToBytes(4, LE), + uint16ToBytes(5, LE), + uint16ToBytes(6, LE), + uint16ToBytes(7, LE), + ) + decoder := NewBinDecoder(buf) + + got := make([]uint16, 0) + err := reflect_readArrayOfUint16(decoder, len(buf)/2, reflect.ValueOf(&got).Elem(), LE) + require.NoError(t, err) + require.Equal(t, []uint16{0, 1, 2, 3, 4, 5, 6, 7}, got) + } + { + buf := concatByteSlices( + uint16ToBytes(0, LE), + uint16ToBytes(1, LE), + uint16ToBytes(2, LE), + uint16ToBytes(3, LE), + uint16ToBytes(4, LE), + uint16ToBytes(5, LE), + uint16ToBytes(6, LE), + uint16ToBytes(7, LE), + ) + decoder := NewBinDecoder(buf) + + got := [8]uint16{0, 0, 0, 0, 0, 0, 0, 0} + err := reflect_readArrayOfUint16(decoder, len(buf)/2, reflect.ValueOf(&got).Elem(), LE) + require.NoError(t, err) + require.Equal(t, []uint16{0, 1, 2, 3, 4, 5, 6, 7}, got[:]) + } + } + { + { + buf := concatByteSlices( + uint16ToBytes(0, LE), + uint16ToBytes(1, LE), + uint16ToBytes(2, LE), + uint16ToBytes(3, LE), + uint16ToBytes(4, LE), + uint16ToBytes(5, LE), + uint16ToBytes(6, LE), + uint16ToBytes(7, LE), + ) + decoder := NewBorshDecoder(buf) + + got := make([]uint16, 0) + err := reflect_readArrayOfUint16(decoder, len(buf)/2, reflect.ValueOf(&got).Elem(), LE) + require.NoError(t, err) + require.Equal(t, []uint16{0, 1, 2, 3, 4, 5, 6, 7}, got) + } + { + buf := concatByteSlices( + uint16ToBytes(0, LE), + uint16ToBytes(1, LE), + uint16ToBytes(2, LE), + uint16ToBytes(3, LE), + uint16ToBytes(4, LE), + uint16ToBytes(5, LE), + uint16ToBytes(6, LE), + uint16ToBytes(7, LE), + ) + decoder := NewBorshDecoder(buf) + + got := [8]uint16{0, 0, 0, 0, 0, 0, 0, 0} + err := reflect_readArrayOfUint16(decoder, len(buf)/2, reflect.ValueOf(&got).Elem(), LE) + require.NoError(t, err) + require.Equal(t, []uint16{0, 1, 2, 3, 4, 5, 6, 7}, got[:]) + } + } +} + +func Test_reflect_readArrayOfUint32(t *testing.T) { + { + { + buf := concatByteSlices( + uint32ToBytes(0, LE), + uint32ToBytes(1, LE), + uint32ToBytes(2, LE), + uint32ToBytes(3, LE), + uint32ToBytes(4, LE), + uint32ToBytes(5, LE), + uint32ToBytes(6, LE), + uint32ToBytes(7, LE), + ) + decoder := NewBinDecoder(buf) + + got := make([]uint32, 0) + err := reflect_readArrayOfUint32(decoder, len(buf)/4, reflect.ValueOf(&got).Elem(), LE) + require.NoError(t, err) + require.Equal(t, []uint32{0, 1, 2, 3, 4, 5, 6, 7}, got) + } + { + buf := concatByteSlices( + uint32ToBytes(0, LE), + uint32ToBytes(1, LE), + uint32ToBytes(2, LE), + uint32ToBytes(3, LE), + uint32ToBytes(4, LE), + uint32ToBytes(5, LE), + uint32ToBytes(6, LE), + uint32ToBytes(7, LE), + ) + decoder := NewBinDecoder(buf) + + got := [8]uint32{0, 0, 0, 0, 0, 0, 0, 0} + err := reflect_readArrayOfUint32(decoder, len(buf)/4, reflect.ValueOf(&got).Elem(), LE) + require.NoError(t, err) + require.Equal(t, []uint32{0, 1, 2, 3, 4, 5, 6, 7}, got[:]) + } + } + { + { + buf := concatByteSlices( + uint32ToBytes(0, LE), + uint32ToBytes(1, LE), + uint32ToBytes(2, LE), + uint32ToBytes(3, LE), + uint32ToBytes(4, LE), + uint32ToBytes(5, LE), + uint32ToBytes(6, LE), + uint32ToBytes(7, LE), + ) + decoder := NewBorshDecoder(buf) + + got := make([]uint32, 0) + err := reflect_readArrayOfUint32(decoder, len(buf)/4, reflect.ValueOf(&got).Elem(), LE) + require.NoError(t, err) + require.Equal(t, []uint32{0, 1, 2, 3, 4, 5, 6, 7}, got) + } + { + buf := concatByteSlices( + uint32ToBytes(0, LE), + uint32ToBytes(1, LE), + uint32ToBytes(2, LE), + uint32ToBytes(3, LE), + uint32ToBytes(4, LE), + uint32ToBytes(5, LE), + uint32ToBytes(6, LE), + uint32ToBytes(7, LE), + ) + decoder := NewBorshDecoder(buf) + + got := [8]uint32{0, 0, 0, 0, 0, 0, 0, 0} + err := reflect_readArrayOfUint32(decoder, len(buf)/4, reflect.ValueOf(&got).Elem(), LE) + require.NoError(t, err) + require.Equal(t, []uint32{0, 1, 2, 3, 4, 5, 6, 7}, got[:]) + } + } +} + +func Test_reflect_readArrayOfUint64(t *testing.T) { + { + { + buf := concatByteSlices( + uint64ToBytes(0, LE), + uint64ToBytes(1, LE), + uint64ToBytes(2, LE), + uint64ToBytes(3, LE), + uint64ToBytes(4, LE), + uint64ToBytes(5, LE), + uint64ToBytes(6, LE), + uint64ToBytes(7, LE), + ) + decoder := NewBinDecoder(buf) + + got := make([]uint64, 0) + err := reflect_readArrayOfUint64(decoder, len(buf)/8, reflect.ValueOf(&got).Elem(), LE) + require.NoError(t, err) + require.Equal(t, []uint64{0, 1, 2, 3, 4, 5, 6, 7}, got) + } + { + buf := concatByteSlices( + uint64ToBytes(0, LE), + uint64ToBytes(1, LE), + uint64ToBytes(2, LE), + uint64ToBytes(3, LE), + uint64ToBytes(4, LE), + uint64ToBytes(5, LE), + uint64ToBytes(6, LE), + uint64ToBytes(7, LE), + ) + decoder := NewBinDecoder(buf) + got := [8]uint64{0, 0, 0, 0, 0, 0, 0, 0} + err := reflect_readArrayOfUint64(decoder, len(buf)/8, reflect.ValueOf(&got).Elem(), LE) + require.NoError(t, err) + require.Equal(t, []uint64{0, 1, 2, 3, 4, 5, 6, 7}, got[:]) + } + } + { + { + buf := concatByteSlices( + uint64ToBytes(0, LE), + uint64ToBytes(1, LE), + uint64ToBytes(2, LE), + uint64ToBytes(3, LE), + uint64ToBytes(4, LE), + uint64ToBytes(5, LE), + uint64ToBytes(6, LE), + uint64ToBytes(7, LE), + ) + decoder := NewBorshDecoder(buf) + + got := make([]uint64, 0) + err := reflect_readArrayOfUint64(decoder, len(buf)/8, reflect.ValueOf(&got).Elem(), LE) + require.NoError(t, err) + require.Equal(t, []uint64{0, 1, 2, 3, 4, 5, 6, 7}, got) + } + { + buf := concatByteSlices( + uint64ToBytes(0, LE), + uint64ToBytes(1, LE), + uint64ToBytes(2, LE), + uint64ToBytes(3, LE), + uint64ToBytes(4, LE), + uint64ToBytes(5, LE), + uint64ToBytes(6, LE), + uint64ToBytes(7, LE), + ) + decoder := NewBorshDecoder(buf) + got := [8]uint64{0, 0, 0, 0, 0, 0, 0, 0} + err := reflect_readArrayOfUint64(decoder, len(buf)/8, reflect.ValueOf(&got).Elem(), LE) + require.NoError(t, err) + require.Equal(t, []uint64{0, 1, 2, 3, 4, 5, 6, 7}, got[:]) + } + } +} + +func Test_reflect_readArrayOfUint(t *testing.T) { + { + { + buf := concatByteSlices( + uint32ToBytes(0, LE), + uint32ToBytes(1, LE), + uint32ToBytes(2, LE), + uint32ToBytes(3, LE), + uint32ToBytes(4, LE), + uint32ToBytes(5, LE), + uint32ToBytes(6, LE), + uint32ToBytes(7, LE), + ) + decoder := NewBinDecoder(buf) + + got := make([]uint32, 0) + rv := reflect.ValueOf(&got).Elem() + k := rv.Type().Elem().Kind() + err := reflect_readArrayOfUint_(decoder, len(buf)/4, k, rv, LE) + require.NoError(t, err) + require.Equal(t, []uint32{0, 1, 2, 3, 4, 5, 6, 7}, got) + } + { + buf := concatByteSlices( + uint32ToBytes(0, LE), + uint32ToBytes(1, LE), + uint32ToBytes(2, LE), + uint32ToBytes(3, LE), + uint32ToBytes(4, LE), + uint32ToBytes(5, LE), + uint32ToBytes(6, LE), + uint32ToBytes(7, LE), + ) + decoder := NewBinDecoder(buf) + got := [8]uint32{0, 0, 0, 0, 0, 0, 0, 0} + rv := reflect.ValueOf(&got).Elem() + k := rv.Type().Elem().Kind() + err := reflect_readArrayOfUint_(decoder, len(buf)/4, k, rv, LE) + require.NoError(t, err) + require.Equal(t, []uint32{0, 1, 2, 3, 4, 5, 6, 7}, got[:]) + } + } +} + +func Test_Decode_readArrayOfUint(t *testing.T) { + { + { + buf := concatByteSlices( + // length: + []byte{3}, + // data: + uint32ToBytes(0, LE), + uint32ToBytes(1, LE), + uint32ToBytes(2, LE), + ) + decoder := NewBinDecoder(buf) + + got := make([]uint32, 0) + err := decoder.Decode(&got) + require.NoError(t, err) + require.Equal(t, []uint32{0, 1, 2}, got) + } + { + buf := concatByteSlices( + uint32ToBytes(0, LE), + uint32ToBytes(1, LE), + uint32ToBytes(2, LE), + uint32ToBytes(3, LE), + uint32ToBytes(4, LE), + uint32ToBytes(5, LE), + uint32ToBytes(6, LE), + uint32ToBytes(7, LE), + ) + decoder := NewBinDecoder(buf) + got := [8]uint32{0, 0, 0, 0, 0, 0, 0, 0} + err := decoder.Decode(&got) + require.NoError(t, err) + require.Equal(t, []uint32{0, 1, 2, 3, 4, 5, 6, 7}, got[:]) + } + } + { + { + buf := concatByteSlices( + // length: + uint32ToBytes(8, LE), + // data: + uint32ToBytes(0, LE), + uint32ToBytes(1, LE), + uint32ToBytes(2, LE), + uint32ToBytes(3, LE), + uint32ToBytes(4, LE), + uint32ToBytes(5, LE), + uint32ToBytes(6, LE), + uint32ToBytes(7, LE), + ) + decoder := NewBorshDecoder(buf) + + got := make([]uint32, 0) + err := decoder.Decode(&got) + require.NoError(t, err) + require.Equal(t, []uint32{0, 1, 2, 3, 4, 5, 6, 7}, got) + } + { + buf := concatByteSlices( + uint32ToBytes(0, LE), + uint32ToBytes(1, LE), + uint32ToBytes(2, LE), + uint32ToBytes(3, LE), + uint32ToBytes(4, LE), + uint32ToBytes(5, LE), + uint32ToBytes(6, LE), + uint32ToBytes(7, LE), + ) + decoder := NewBorshDecoder(buf) + got := [8]uint32{0, 0, 0, 0, 0, 0, 0, 0} + err := decoder.Decode(&got) + require.NoError(t, err) + require.Equal(t, []uint32{0, 1, 2, 3, 4, 5, 6, 7}, got[:]) + } + } +} + +func Test_reflect_readArrayOfUint16_asField(t *testing.T) { + { + { + buf := concatByteSlices( + // length: + []byte{8}, + // data: + uint16ToBytes(0, LE), + uint16ToBytes(1, LE), + uint16ToBytes(2, LE), + uint16ToBytes(3, LE), + uint16ToBytes(4, LE), + uint16ToBytes(5, LE), + uint16ToBytes(6, LE), + uint16ToBytes(7, LE), + ) + decoder := NewBinDecoder(buf) + + type S struct { + Val []uint16 + } + var got S + err := decoder.Decode(&got) + require.NoError(t, err) + require.Equal(t, S{[]uint16{0, 1, 2, 3, 4, 5, 6, 7}}, got) + } + { + buf := concatByteSlices( + // data: + uint16ToBytes(0, LE), + uint16ToBytes(1, LE), + uint16ToBytes(2, LE), + uint16ToBytes(3, LE), + uint16ToBytes(4, LE), + uint16ToBytes(5, LE), + uint16ToBytes(6, LE), + uint16ToBytes(7, LE), + ) + decoder := NewBinDecoder(buf) + + type S struct { + Val [8]uint16 + } + var got S + err := decoder.Decode(&got) + require.NoError(t, err) + require.Equal(t, S{[8]uint16{0, 1, 2, 3, 4, 5, 6, 7}}, got) + } + } + { + { + buf := concatByteSlices( + // length: + uint32ToBytes(8, LE), + // data: + uint16ToBytes(0, LE), + uint16ToBytes(1, LE), + uint16ToBytes(2, LE), + uint16ToBytes(3, LE), + uint16ToBytes(4, LE), + uint16ToBytes(5, LE), + uint16ToBytes(6, LE), + uint16ToBytes(7, LE), + ) + decoder := NewBorshDecoder(buf) + + type S struct { + Val []uint16 + } + var got S + err := decoder.Decode(&got) + require.NoError(t, err) + require.Equal(t, S{[]uint16{0, 1, 2, 3, 4, 5, 6, 7}}, got) + } + { + buf := concatByteSlices( + uint16ToBytes(0, LE), + uint16ToBytes(1, LE), + uint16ToBytes(2, LE), + uint16ToBytes(3, LE), + uint16ToBytes(4, LE), + uint16ToBytes(5, LE), + uint16ToBytes(6, LE), + uint16ToBytes(7, LE), + ) + decoder := NewBorshDecoder(buf) + + type S struct { + Val [8]uint16 + } + var got S + err := decoder.Decode(&got) + require.NoError(t, err) + require.Equal(t, S{[8]uint16{0, 1, 2, 3, 4, 5, 6, 7}}, got) + } + } } diff --git a/encoder.go b/encoder.go index aa279b5..fa2f939 100644 --- a/encoder.go +++ b/encoder.go @@ -87,11 +87,9 @@ func (e *Encoder) Encode(v interface{}) (err error) { func (e *Encoder) toWriter(bytes []byte) (err error) { e.count += len(bytes) - if traceEnabled { zlog.Debug(" > encode: appending", zap.Stringer("hex", HexBytes(bytes)), zap.Int("pos", e.count)) } - _, err = e.output.Write(bytes) return } @@ -257,7 +255,7 @@ func (e *Encoder) WriteFloat32(f float32, order binary.ByteOrder) (err error) { } if e.IsBorsh() { - if float64(f) == math.NaN() { + if math.IsNaN(float64(f)) { return errors.New("NaN float value") } } @@ -268,13 +266,14 @@ func (e *Encoder) WriteFloat32(f float32, order binary.ByteOrder) (err error) { return e.toWriter(buf) } + func (e *Encoder) WriteFloat64(f float64, order binary.ByteOrder) (err error) { if traceEnabled { zlog.Debug("encode: write float64", zap.Float64("val", f)) } if e.IsBorsh() { - if float64(f) == math.NaN() { + if math.IsNaN(float64(f)) { return errors.New("NaN float value") } } @@ -312,14 +311,56 @@ func (e *Encoder) WriteCompactU16Length(ln int) (err error) { return e.toWriter(buf) } -// TODO: add rust string. -// https://github.com/bmresearch/Solnet/blob/7826cc93ec6c997fc997a7a3c6be0f3511ca0c63/src/Solnet.Programs/Utilities/Serialization.cs#L219 -// public static byte[] EncodeRustString(string data) -// { -// byte[] stringBytes = Encoding.ASCII.GetBytes(data); -// byte[] encoded = new byte[stringBytes.Length + sizeof(uint)]; +func reflect_writeArrayOfBytes(e *Encoder, l int, rv reflect.Value) error { + arr := make([]byte, l) + for i := 0; i < l; i++ { + arr[i] = byte(rv.Index(i).Uint()) + } + return e.WriteBytes(arr, false) +} -// encoded.WriteU32((uint) stringBytes.Length, 0); -// encoded.WriteSpan(stringBytes, sizeof(uint)); -// return encoded; -// } +func reflect_writeArrayOfUint16(e *Encoder, l int, rv reflect.Value, order binary.ByteOrder) error { + arr := make([]byte, l*2) + for i := 0; i < l; i++ { + order.PutUint16(arr[i*2:], uint16(rv.Index(i).Uint())) + } + return e.WriteBytes(arr, false) +} + +func reflect_writeArrayOfUint32(e *Encoder, l int, rv reflect.Value, order binary.ByteOrder) error { + arr := make([]byte, l*4) + for i := 0; i < l; i++ { + order.PutUint32(arr[i*4:], uint32(rv.Index(i).Uint())) + } + return e.WriteBytes(arr, false) +} + +func reflect_writeArrayOfUint64(e *Encoder, l int, rv reflect.Value, order binary.ByteOrder) error { + arr := make([]byte, l*8) + for i := 0; i < l; i++ { + order.PutUint64(arr[i*8:], uint64(rv.Index(i).Uint())) + } + return e.WriteBytes(arr, false) +} + +// reflect_writeArrayOfUint_ is used for writing arrays/slices of uints of any size. +func reflect_writeArrayOfUint_(e *Encoder, l int, k reflect.Kind, rv reflect.Value, order binary.ByteOrder) error { + switch k { + // case reflect.Uint: + // // switch on system architecture (32 or 64 bit) + // if unsafe.Sizeof(uintptr(0)) == 4 { + // return reflect_writeArrayOfUint32(e, l, rv, order) + // } + // return reflect_writeArrayOfUint64(e, l, rv, order) + case reflect.Uint8: + return reflect_writeArrayOfBytes(e, l, rv) + case reflect.Uint16: + return reflect_writeArrayOfUint16(e, l, rv, order) + case reflect.Uint32: + return reflect_writeArrayOfUint32(e, l, rv, order) + case reflect.Uint64: + return reflect_writeArrayOfUint64(e, l, rv, order) + default: + return fmt.Errorf("unsupported kind: %v", k) + } +} diff --git a/encoder_bin.go b/encoder_bin.go index 4d286a8..585b41e 100644 --- a/encoder_bin.go +++ b/encoder_bin.go @@ -107,22 +107,20 @@ func (e *Encoder) encodeBin(rv reflect.Value, opt *option) (err error) { zlog.Debug("encode: array", zap.Int("length", l), zap.Stringer("type", rv.Kind())) } - if rv.Type().Elem().Kind() == reflect.Uint8 { + switch k := rv.Type().Elem().Kind(); k { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: // if it's a [n]byte, accumulate and write in one command: - arr := make([]byte, l) - for i := 0; i < l; i++ { - arr[i] = byte(rv.Index(i).Uint()) - } - if err := e.WriteBytes(arr, false); err != nil { + if err := reflect_writeArrayOfUint_(e, l, k, rv, LE); err != nil { return err } - } else { + default: for i := 0; i < l; i++ { if err = e.encodeBin(rv.Index(i), nil); err != nil { return } } } + case reflect.Slice: var l int if opt.hasSizeOfSlice() { @@ -144,9 +142,17 @@ func (e *Encoder) encodeBin(rv reflect.Value, opt *option) (err error) { // we would want to skip to the correct head_offset - for i := 0; i < l; i++ { - if err = e.encodeBin(rv.Index(i), nil); err != nil { - return + switch k := rv.Type().Elem().Kind(); k { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + // if it's a [n]byte, accumulate and write in one command: + if err := reflect_writeArrayOfUint_(e, l, k, rv, LE); err != nil { + return err + } + default: + for i := 0; i < l; i++ { + if err = e.encodeBin(rv.Index(i), nil); err != nil { + return + } } } case reflect.Struct: diff --git a/encoder_borsh.go b/encoder_borsh.go index 728af55..75ebf59 100644 --- a/encoder_borsh.go +++ b/encoder_borsh.go @@ -139,16 +139,13 @@ func (e *Encoder) encodeBorsh(rv reflect.Value, opt *option) (err error) { zlog.Debug("encode: array", zap.Int("length", l), zap.Stringer("type", rv.Kind())) } - if rv.Type().Elem().Kind() == reflect.Uint8 { + switch k := rv.Type().Elem().Kind(); k { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: // if it's a [n]byte, accumulate and write in one command: - arr := make([]byte, l) - for i := 0; i < l; i++ { - arr[i] = byte(rv.Index(i).Uint()) - } - if err := e.WriteBytes(arr, false); err != nil { + if err := reflect_writeArrayOfUint_(e, l, k, rv, LE); err != nil { return err } - } else { + default: for i := 0; i < l; i++ { if err = e.encodeBorsh(rv.Index(i), nil); err != nil { return @@ -176,11 +173,20 @@ func (e *Encoder) encodeBorsh(rv reflect.Value, opt *option) (err error) { // we would want to skip to the correct head_offset - for i := 0; i < l; i++ { - if err = e.encodeBorsh(rv.Index(i), nil); err != nil { - return + switch k := rv.Type().Elem().Kind(); k { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + // if it's a [n]byte, accumulate and write in one command: + if err := reflect_writeArrayOfUint_(e, l, k, rv, LE); err != nil { + return err + } + default: + for i := 0; i < l; i++ { + if err = e.encodeBorsh(rv.Index(i), nil); err != nil { + return + } } } + case reflect.Struct: if err = e.encodeStructBorsh(rt, rv); err != nil { return diff --git a/encoder_compact-u16.go b/encoder_compact-u16.go index a9ab4fd..488dda6 100644 --- a/encoder_compact-u16.go +++ b/encoder_compact-u16.go @@ -106,16 +106,13 @@ func (e *Encoder) encodeCompactU16(rv reflect.Value, opt *option) (err error) { zlog.Debug("encode: array", zap.Int("length", l), zap.Stringer("type", rv.Kind())) } - if rv.Type().Elem().Kind() == reflect.Uint8 { + switch k := rv.Type().Elem().Kind(); k { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: // if it's a [n]byte, accumulate and write in one command: - arr := make([]byte, l) - for i := 0; i < l; i++ { - arr[i] = byte(rv.Index(i).Uint()) - } - if err := e.WriteBytes(arr, false); err != nil { + if err := reflect_writeArrayOfUint_(e, l, k, rv, LE); err != nil { return err } - } else { + default: for i := 0; i < l; i++ { if err = e.encodeCompactU16(rv.Index(i), nil); err != nil { return @@ -143,9 +140,17 @@ func (e *Encoder) encodeCompactU16(rv reflect.Value, opt *option) (err error) { // we would want to skip to the correct head_offset - for i := 0; i < l; i++ { - if err = e.encodeCompactU16(rv.Index(i), nil); err != nil { - return + switch k := rv.Type().Elem().Kind(); k { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + // if it's a [n]byte, accumulate and write in one command: + if err := reflect_writeArrayOfUint_(e, l, k, rv, LE); err != nil { + return err + } + default: + for i := 0; i < l; i++ { + if err = e.encodeCompactU16(rv.Index(i), nil); err != nil { + return + } } } case reflect.Struct: diff --git a/encoder_test.go b/encoder_test.go index 95440b3..4afb5f9 100644 --- a/encoder_test.go +++ b/encoder_test.go @@ -22,6 +22,7 @@ import ( "encoding/binary" "encoding/hex" "math" + "reflect" "testing" "github.com/stretchr/testify/assert" @@ -641,3 +642,636 @@ func TestEncoder_InterfaceNil(t *testing.T) { err := enc.Encode(foo) assert.NoError(t, err) } + +func TestByteArrays(t *testing.T) { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + err := enc.Encode([3]byte{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, []byte{1, 2, 3}, buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + err := enc.Encode([3]byte{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, []byte{1, 2, 3}, buf.Bytes()) + } +} + +func TestUintArrays(t *testing.T) { + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + err := enc.Encode([3]uint8{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, []byte{1, 2, 3}, buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + err := enc.Encode([3]uint8{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, []byte{1, 2, 3}, buf.Bytes()) + } + } + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + err := enc.Encode([3]uint16{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, concatByteSlices( + []byte{1, 0, 2, 0, 3, 0}, + ), buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + err := enc.Encode([3]uint16{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, concatByteSlices( + []byte{1, 0, 2, 0, 3, 0}, + ), buf.Bytes()) + } + } + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + err := enc.Encode([3]uint32{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0}, + []byte{2, 0, 0, 0}, + []byte{3, 0, 0, 0}, + ), buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + err := enc.Encode([3]uint32{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0}, + []byte{2, 0, 0, 0}, + []byte{3, 0, 0, 0}, + ), buf.Bytes()) + } + } + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + err := enc.Encode([3]uint64{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0, 0, 0, 0, 0}, + []byte{2, 0, 0, 0, 0, 0, 0, 0}, + []byte{3, 0, 0, 0, 0, 0, 0, 0}, + ), buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + err := enc.Encode([3]uint64{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0, 0, 0, 0, 0}, + []byte{2, 0, 0, 0, 0, 0, 0, 0}, + []byte{3, 0, 0, 0, 0, 0, 0, 0}, + ), buf.Bytes()) + } + } +} + +func TestUintSlices(t *testing.T) { + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + err := enc.Encode([]uint8{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, concatByteSlices( + // length: + []byte{3}, + // data: + []byte{1, 2, 3}, + ), buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + err := enc.Encode([]uint8{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + // length: + []byte{3, 0, 0, 0}, + // data: + []byte{1, 2, 3}, + ), buf.Bytes()) + } + } + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + err := enc.Encode([]uint16{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, concatByteSlices( + // length: + []byte{3}, + // data: + []byte{1, 0, 2, 0, 3, 0}, + ), buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + err := enc.Encode([]uint16{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, concatByteSlices( + // length: + []byte{3, 0, 0, 0}, + // data: + []byte{1, 0, 2, 0, 3, 0}, + ), buf.Bytes()) + } + } + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + err := enc.Encode([]uint32{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + // length: + []byte{3}, + // data: + []byte{1, 0, 0, 0}, + []byte{2, 0, 0, 0}, + []byte{3, 0, 0, 0}, + ), buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + err := enc.Encode([]uint32{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + // length: + []byte{3, 0, 0, 0}, + // data: + []byte{1, 0, 0, 0}, + []byte{2, 0, 0, 0}, + []byte{3, 0, 0, 0}, + ), buf.Bytes()) + } + } + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + err := enc.Encode([]uint64{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + // length: + []byte{3}, + // data: + []byte{1, 0, 0, 0, 0, 0, 0, 0}, + []byte{2, 0, 0, 0, 0, 0, 0, 0}, + []byte{3, 0, 0, 0, 0, 0, 0, 0}, + ), buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + err := enc.Encode([]uint64{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + // length: + []byte{3, 0, 0, 0}, + // data: + []byte{1, 0, 0, 0, 0, 0, 0, 0}, + []byte{2, 0, 0, 0, 0, 0, 0, 0}, + []byte{3, 0, 0, 0, 0, 0, 0, 0}, + ), buf.Bytes()) + } + } +} + +func Test_writeArrayOfBytes(t *testing.T) { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + + arr := [3]byte{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfBytes(enc, l, reflect.ValueOf(arr)) + assert.NoError(t, err) + assert.Equal(t, arr[:], buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + + arr := [10]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + l := len(arr) + + err := reflect_writeArrayOfBytes(enc, l, reflect.ValueOf(arr)) + assert.NoError(t, err) + assert.Equal(t, arr[:], buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + + arr := [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32} + l := len(arr) + + err := reflect_writeArrayOfBytes(enc, l, reflect.ValueOf(arr)) + assert.NoError(t, err) + assert.Equal(t, arr[:], buf.Bytes()) + } +} + +func Test_writeArrayOfUint16(t *testing.T) { + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + + arr := [3]uint16{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfUint16(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + assert.Equal(t, concatByteSlices( + []byte{1, 0, 2, 0, 3, 0}, + ), buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + + arr := [3]uint16{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfUint16(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + assert.Equal(t, concatByteSlices( + []byte{1, 0, 2, 0, 3, 0}, + ), buf.Bytes()) + } + } + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + + arr := []uint16{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfUint16(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + assert.Equal(t, concatByteSlices( + []byte{1, 0, 2, 0, 3, 0}, + ), buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + + arr := []uint16{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfUint16(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + assert.Equal(t, concatByteSlices( + []byte{1, 0, 2, 0, 3, 0}, + ), buf.Bytes()) + } + } +} + +func Test_writeArrayOfUint32(t *testing.T) { + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + + arr := [3]uint32{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfUint32(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0}, + []byte{2, 0, 0, 0}, + []byte{3, 0, 0, 0}, + ), + buf.Bytes(), + ) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + + arr := [3]uint32{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfUint32(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0}, + []byte{2, 0, 0, 0}, + []byte{3, 0, 0, 0}, + ), + buf.Bytes(), + ) + } + } + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + + arr := []uint32{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfUint32(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0}, + []byte{2, 0, 0, 0}, + []byte{3, 0, 0, 0}, + ), + buf.Bytes(), + ) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + + arr := []uint32{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfUint32(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0}, + []byte{2, 0, 0, 0}, + []byte{3, 0, 0, 0}, + ), + buf.Bytes(), + ) + } + } + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + + arr := [10]uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + l := len(arr) + + err := reflect_writeArrayOfUint32(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0}, + []byte{2, 0, 0, 0}, + []byte{3, 0, 0, 0}, + []byte{4, 0, 0, 0}, + []byte{5, 0, 0, 0}, + []byte{6, 0, 0, 0}, + []byte{7, 0, 0, 0}, + []byte{8, 0, 0, 0}, + []byte{9, 0, 0, 0}, + []byte{10, 0, 0, 0}, + ), + buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + + arr := [32]uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32} + l := len(arr) + + err := reflect_writeArrayOfUint32(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0}, + []byte{2, 0, 0, 0}, + []byte{3, 0, 0, 0}, + []byte{4, 0, 0, 0}, + []byte{5, 0, 0, 0}, + []byte{6, 0, 0, 0}, + []byte{7, 0, 0, 0}, + []byte{8, 0, 0, 0}, + []byte{9, 0, 0, 0}, + []byte{10, 0, 0, 0}, + []byte{11, 0, 0, 0}, + []byte{12, 0, 0, 0}, + []byte{13, 0, 0, 0}, + []byte{14, 0, 0, 0}, + []byte{15, 0, 0, 0}, + []byte{16, 0, 0, 0}, + []byte{17, 0, 0, 0}, + []byte{18, 0, 0, 0}, + []byte{19, 0, 0, 0}, + []byte{20, 0, 0, 0}, + []byte{21, 0, 0, 0}, + []byte{22, 0, 0, 0}, + []byte{23, 0, 0, 0}, + []byte{24, 0, 0, 0}, + []byte{25, 0, 0, 0}, + []byte{26, 0, 0, 0}, + []byte{27, 0, 0, 0}, + []byte{28, 0, 0, 0}, + []byte{29, 0, 0, 0}, + []byte{30, 0, 0, 0}, + []byte{31, 0, 0, 0}, + []byte{32, 0, 0, 0}, + ), + buf.Bytes()) + } + +} + +func Test_writeArrayOfUint64(t *testing.T) { + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + + arr := [3]uint64{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfUint64(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0, 0, 0, 0, 0}, + []byte{2, 0, 0, 0, 0, 0, 0, 0}, + []byte{3, 0, 0, 0, 0, 0, 0, 0}, + ), + buf.Bytes(), + ) + } + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + + arr := []uint64{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfUint64(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0, 0, 0, 0, 0}, + []byte{2, 0, 0, 0, 0, 0, 0, 0}, + []byte{3, 0, 0, 0, 0, 0, 0, 0}, + ), + buf.Bytes(), + ) + } + } + { + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + + arr := [3]uint64{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfUint64(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0, 0, 0, 0, 0}, + []byte{2, 0, 0, 0, 0, 0, 0, 0}, + []byte{3, 0, 0, 0, 0, 0, 0, 0}, + ), + buf.Bytes(), + ) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + + arr := []uint64{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfUint64(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0, 0, 0, 0, 0}, + []byte{2, 0, 0, 0, 0, 0, 0, 0}, + []byte{3, 0, 0, 0, 0, 0, 0, 0}, + ), + buf.Bytes(), + ) + } + } + { + var buf bytes.Buffer + + enc := NewBinEncoder(&buf) + arr := [64]uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64} + l := len(arr) + + err := reflect_writeArrayOfUint64(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0, 0, 0, 0, 0}, + []byte{2, 0, 0, 0, 0, 0, 0, 0}, + []byte{3, 0, 0, 0, 0, 0, 0, 0}, + []byte{4, 0, 0, 0, 0, 0, 0, 0}, + []byte{5, 0, 0, 0, 0, 0, 0, 0}, + []byte{6, 0, 0, 0, 0, 0, 0, 0}, + []byte{7, 0, 0, 0, 0, 0, 0, 0}, + []byte{8, 0, 0, 0, 0, 0, 0, 0}, + []byte{9, 0, 0, 0, 0, 0, 0, 0}, + []byte{10, 0, 0, 0, 0, 0, 0, 0}, + []byte{11, 0, 0, 0, 0, 0, 0, 0}, + []byte{12, 0, 0, 0, 0, 0, 0, 0}, + []byte{13, 0, 0, 0, 0, 0, 0, 0}, + []byte{14, 0, 0, 0, 0, 0, 0, 0}, + []byte{15, 0, 0, 0, 0, 0, 0, 0}, + []byte{16, 0, 0, 0, 0, 0, 0, 0}, + []byte{17, 0, 0, 0, 0, 0, 0, 0}, + []byte{18, 0, 0, 0, 0, 0, 0, 0}, + []byte{19, 0, 0, 0, 0, 0, 0, 0}, + []byte{20, 0, 0, 0, 0, 0, 0, 0}, + []byte{21, 0, 0, 0, 0, 0, 0, 0}, + []byte{22, 0, 0, 0, 0, 0, 0, 0}, + []byte{23, 0, 0, 0, 0, 0, 0, 0}, + []byte{24, 0, 0, 0, 0, 0, 0, 0}, + []byte{25, 0, 0, 0, 0, 0, 0, 0}, + []byte{26, 0, 0, 0, 0, 0, 0, 0}, + []byte{27, 0, 0, 0, 0, 0, 0, 0}, + []byte{28, 0, 0, 0, 0, 0, 0, 0}, + []byte{29, 0, 0, 0, 0, 0, 0, 0}, + []byte{30, 0, 0, 0, 0, 0, 0, 0}, + []byte{31, 0, 0, 0, 0, 0, 0, 0}, + []byte{32, 0, 0, 0, 0, 0, 0, 0}, + []byte{33, 0, 0, 0, 0, 0, 0, 0}, + []byte{34, 0, 0, 0, 0, 0, 0, 0}, + []byte{35, 0, 0, 0, 0, 0, 0, 0}, + []byte{36, 0, 0, 0, 0, 0, 0, 0}, + []byte{37, 0, 0, 0, 0, 0, 0, 0}, + []byte{38, 0, 0, 0, 0, 0, 0, 0}, + []byte{39, 0, 0, 0, 0, 0, 0, 0}, + []byte{40, 0, 0, 0, 0, 0, 0, 0}, + []byte{41, 0, 0, 0, 0, 0, 0, 0}, + []byte{42, 0, 0, 0, 0, 0, 0, 0}, + []byte{43, 0, 0, 0, 0, 0, 0, 0}, + []byte{44, 0, 0, 0, 0, 0, 0, 0}, + []byte{45, 0, 0, 0, 0, 0, 0, 0}, + []byte{46, 0, 0, 0, 0, 0, 0, 0}, + []byte{47, 0, 0, 0, 0, 0, 0, 0}, + []byte{48, 0, 0, 0, 0, 0, 0, 0}, + []byte{49, 0, 0, 0, 0, 0, 0, 0}, + []byte{50, 0, 0, 0, 0, 0, 0, 0}, + []byte{51, 0, 0, 0, 0, 0, 0, 0}, + []byte{52, 0, 0, 0, 0, 0, 0, 0}, + []byte{53, 0, 0, 0, 0, 0, 0, 0}, + []byte{54, 0, 0, 0, 0, 0, 0, 0}, + []byte{55, 0, 0, 0, 0, 0, 0, 0}, + []byte{56, 0, 0, 0, 0, 0, 0, 0}, + []byte{57, 0, 0, 0, 0, 0, 0, 0}, + []byte{58, 0, 0, 0, 0, 0, 0, 0}, + []byte{59, 0, 0, 0, 0, 0, 0, 0}, + []byte{60, 0, 0, 0, 0, 0, 0, 0}, + []byte{61, 0, 0, 0, 0, 0, 0, 0}, + []byte{62, 0, 0, 0, 0, 0, 0, 0}, + []byte{63, 0, 0, 0, 0, 0, 0, 0}, + []byte{64, 0, 0, 0, 0, 0, 0, 0}, + ), + buf.Bytes()) + } +} diff --git a/interface_test.go b/interface_test.go index ce3aaf4..f053330 100644 --- a/interface_test.go +++ b/interface_test.go @@ -36,33 +36,116 @@ func (e *Example) UnmarshalWithDecoder(decoder *Decoder) (err error) { return nil } -func (e *Example) MarshalWithEncoder(encoder *Encoder) error { +func (e Example) MarshalWithEncoder(encoder *Encoder) error { if err := encoder.WriteByte(e.Prefix); err != nil { return err } return encoder.WriteUint32(e.Value, BE) } +type testCustomCoder struct { + val string +} + +func (d *testCustomCoder) UnmarshalWithDecoder(decoder *Decoder) error { + d.val = "hello world" + return nil +} + +func (d testCustomCoder) MarshalWithEncoder(encoder *Encoder) error { + return encoder.WriteBytes([]byte("this is a test"), false) +} + func TestMarshalWithEncoder(t *testing.T) { - buf := new(bytes.Buffer) - e := &Example{Value: 72, Prefix: 0xaa} - enc := NewBinEncoder(buf) - enc.Encode(e) - - assert.Equal(t, []byte{ - 0xaa, 0x00, 0x00, 0x00, 0x48, - }, buf.Bytes()) + { + buf := new(bytes.Buffer) + e := &Example{Value: 72, Prefix: 0xaa} + enc := NewBinEncoder(buf) + enc.Encode(e) + + assert.Equal(t, []byte{ + 0xaa, 0x00, 0x00, 0x00, 0x48, + }, buf.Bytes()) + } + { + // on pointer: + { + buf := new(bytes.Buffer) + e := &testCustomCoder{} + enc := NewBinEncoder(buf) + err := enc.Encode(e) + assert.NoError(t, err) + + assert.Equal(t, []byte("this is a test"), buf.Bytes()) + } + { + buf := new(bytes.Buffer) + e := &testCustomCoder{} + enc := NewBorshEncoder(buf) + err := enc.Encode(e) + assert.NoError(t, err) + + assert.Equal(t, []byte("this is a test"), buf.Bytes()) + } + // on value: + { + buf := new(bytes.Buffer) + e := testCustomCoder{} + enc := NewBinEncoder(buf) + err := enc.Encode(e) + assert.NoError(t, err) + + assert.Equal(t, []byte("this is a test"), buf.Bytes()) + } + { + buf := new(bytes.Buffer) + e := testCustomCoder{} + enc := NewBorshEncoder(buf) + err := enc.Encode(e) + assert.NoError(t, err) + + assert.Equal(t, []byte("this is a test"), buf.Bytes()) + } + } } func TestUnmarshalWithDecoder(t *testing.T) { - buf := []byte{ - 0xaa, 0x00, 0x00, 0x00, 0x48, + { + buf := []byte{ + 0xaa, 0x00, 0x00, 0x00, 0x48, + } + + e := &Example{} + d := NewBinDecoder(buf) + err := d.Decode(e) + assert.NoError(t, err) + assert.Equal(t, e, &Example{Value: 72, Prefix: 0xaa}) + assert.Equal(t, 0, d.Remaining()) } + { + { + buf := []byte{ + 0xaa, 0x00, 0x00, 0x00, 0x48, + } - e := &Example{} - d := NewBinDecoder(buf) - err := d.Decode(e) - assert.NoError(t, err) - assert.Equal(t, e, &Example{Value: 72, Prefix: 0xaa}) - assert.Equal(t, 0, d.Remaining()) + e := &testCustomCoder{} + d := NewBinDecoder(buf) + err := d.Decode(e) + assert.NoError(t, err) + + assert.Equal(t, "hello world", e.val) + } + { + buf := []byte{ + 0xaa, 0x00, 0x00, 0x00, 0x48, + } + + e := &testCustomCoder{} + d := NewBorshDecoder(buf) + err := d.Decode(e) + assert.NoError(t, err) + + assert.Equal(t, "hello world", e.val) + } + } } diff --git a/tools_test.go b/tools_test.go new file mode 100644 index 0000000..24d4657 --- /dev/null +++ b/tools_test.go @@ -0,0 +1,27 @@ +package bin + +import "encoding/binary" + +func concatByteSlices(slices ...[]byte) (out []byte) { + for i := range slices { + out = append(out, slices[i]...) + } + return +} +func uint16ToBytes(i uint16, order binary.ByteOrder) []byte { + buf := make([]byte, 2) + order.PutUint16(buf, i) + return buf +} + +func uint32ToBytes(i uint32, order binary.ByteOrder) []byte { + buf := make([]byte, 4) + order.PutUint32(buf, i) + return buf +} + +func uint64ToBytes(i uint64, order binary.ByteOrder) []byte { + buf := make([]byte, 8) + order.PutUint64(buf, i) + return buf +}