Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enum primitives and empty enum variants #2

Merged
merged 8 commits into from Apr 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
73 changes: 73 additions & 0 deletions borsh_test.go
Expand Up @@ -846,6 +846,20 @@ func TestBorsh_Encode(t *testing.T) {
BarB: "this is bar from pointer",
},
},
ComplexEmpty: ComplexEnumEmpty{
Enum: 0,
Foo: EmptyVariant{},
},

ComplexPrimitives1: ComplexEnumPrimitives{
Enum: 0,
Foo: 20,
},

ComplexPrimitives2: ComplexEnumPrimitives{
Enum: 1,
Bar: 11,
},

Complex2: ComplexEnumPointers{
Enum: 1,
Expand Down Expand Up @@ -939,6 +953,17 @@ func TestBorsh_Encode(t *testing.T) {
[]byte{0, 0, 0, 0},
[]byte{0, 0, 0, 0},

// .ComplexEmpty
[]byte{0},

// .ComplexPrimitives1
[]byte{0},
[]byte{20, 0, 0, 0},

// .ComplexPrimitives2
[]byte{1},
[]byte{11, 0},

// .Complex2
[]byte{1},
[]byte{62, 0, 0, 0, 0, 0, 0, 0},
Expand Down Expand Up @@ -1025,6 +1050,10 @@ type StructWithEnum struct {
ComplexPtr *ComplexEnum
ComplexPtrNotSet *ComplexEnum

ComplexEmpty ComplexEnumEmpty
ComplexPrimitives1 ComplexEnumPrimitives
ComplexPrimitives2 ComplexEnumPrimitives

Complex2 ComplexEnumPointers
Complex2Ptr *ComplexEnumPointers

Expand Down Expand Up @@ -1361,6 +1390,19 @@ type ComplexEnumPointers struct {
Foo *Foo
Bar *Bar
}

type ComplexEnumEmpty struct {
Enum BorshEnum `borsh_enum:"true"`
Foo EmptyVariant
Bar Bar
}

type ComplexEnumPrimitives struct {
Enum BorshEnum `borsh_enum:"true"`
Foo uint32
Bar int16
}

type Foo struct {
FooA int32
FooB string
Expand Down Expand Up @@ -1404,6 +1446,37 @@ func TestComplexEnum(t *testing.T) {
err = UnmarshalBorsh(y, data)
require.NoError(t, err)

require.Equal(t, x, *y)
}
{
x := ComplexEnumEmpty{
Enum: 1,
Bar: Bar{
BarA: 23,
BarB: "baz",
},
}
data, err := MarshalBorsh(x)
require.NoError(t, err)

y := new(ComplexEnumEmpty)
err = UnmarshalBorsh(y, data)
require.NoError(t, err)

require.Equal(t, x, *y)
}
{
x := ComplexEnumPrimitives{
Enum: 1,
Bar: 22,
}
data, err := MarshalBorsh(x)
require.NoError(t, err)

y := new(ComplexEnumPrimitives)
err = UnmarshalBorsh(y, data)
require.NoError(t, err)

require.Equal(t, x, *y)
}
}
Expand Down
89 changes: 60 additions & 29 deletions encoder_borsh.go
Expand Up @@ -26,6 +26,43 @@ import (
"go.uber.org/zap"
)

func (e *Encoder) encodePrimitive(rv reflect.Value, opt *option) (isPrimitive bool, err error) {
isPrimitive = true
switch rv.Kind() {
// case reflect.Int:
// err = e.WriteInt64(rv.Int(), LE)
// case reflect.Uint:
// err = e.WriteUint64(rv.Uint(), LE)
case reflect.String:
err = e.WriteString(rv.String())
case reflect.Uint8:
err = e.WriteByte(byte(rv.Uint()))
case reflect.Int8:
err = e.WriteByte(byte(rv.Int()))
case reflect.Int16:
err = e.WriteInt16(int16(rv.Int()), LE)
case reflect.Uint16:
err = e.WriteUint16(uint16(rv.Uint()), LE)
case reflect.Int32:
err = e.WriteInt32(int32(rv.Int()), LE)
case reflect.Uint32:
err = e.WriteUint32(uint32(rv.Uint()), LE)
case reflect.Uint64:
err = e.WriteUint64(rv.Uint(), LE)
case reflect.Int64:
err = e.WriteInt64(rv.Int(), LE)
case reflect.Float32:
err = e.WriteFloat32(float32(rv.Float()), LE)
case reflect.Float64:
err = e.WriteFloat64(rv.Float(), LE)
case reflect.Bool:
err = e.WriteBool(rv.Bool())
default:
isPrimitive = false
}
return
}

func (e *Encoder) encodeBorsh(rv reflect.Value, opt *option) (err error) {
if opt == nil {
opt = newDefaultOption()
Expand Down Expand Up @@ -70,35 +107,13 @@ func (e *Encoder) encodeBorsh(rv reflect.Value, opt *option) (err error) {
return marshaler.MarshalWithEncoder(e)
}

// Encode the value if it's a primitive type
isPrimitive, err := e.encodePrimitive(rv, nil)
if isPrimitive {
return err
}

switch rv.Kind() {
// case reflect.Int:
// return e.WriteInt64(rv.Int(), LE)
// case reflect.Uint:
// return e.WriteUint64(rv.Uint(), LE)
case reflect.String:
return e.WriteString(rv.String())
case reflect.Uint8:
return e.WriteByte(byte(rv.Uint()))
case reflect.Int8:
return e.WriteByte(byte(rv.Int()))
case reflect.Int16:
return e.WriteInt16(int16(rv.Int()), LE)
case reflect.Uint16:
return e.WriteUint16(uint16(rv.Uint()), LE)
case reflect.Int32:
return e.WriteInt32(int32(rv.Int()), LE)
case reflect.Uint32:
return e.WriteUint32(uint32(rv.Uint()), LE)
case reflect.Uint64:
return e.WriteUint64(rv.Uint(), LE)
case reflect.Int64:
return e.WriteInt64(rv.Int(), LE)
case reflect.Float32:
return e.WriteFloat32(float32(rv.Float()), LE)
case reflect.Float64:
return e.WriteFloat64(rv.Float(), LE)
case reflect.Bool:
return e.WriteBool(rv.Bool())
case reflect.Ptr:
if rv.IsNil() {
el := reflect.New(rv.Type().Elem()).Elem()
Expand Down Expand Up @@ -222,19 +237,35 @@ func (enc *Encoder) encodeComplexEnumBorsh(rv reflect.Value) error {
if int(enum)+1 >= t.NumField() {
return errors.New("complex enum too large")
}
// Enum is empty
field := rv.Field(int(enum) + 1)

if field.Kind() == reflect.Ptr {
field = field.Elem()
}
if field.Kind() == reflect.Struct {
return enc.encodeStructBorsh(field.Type(), field)
}
// Encode the value if it's a primitive type
isPrimitive, err := enc.encodePrimitive(field, nil)
if isPrimitive {
return err
}
return nil
}

type BorshEnum uint8

// EmptyVariant is an empty borsh enum variant.
type EmptyVariant struct{}

func (_ *EmptyVariant) MarshalWithEncoder(_ *Encoder) error {
return nil
}

func (_ *EmptyVariant) UnmarshalWithEncoder(_ *Encoder) error {
return nil
}

func (e *Encoder) encodeStructBorsh(rt reflect.Type, rv reflect.Value) (err error) {
l := rv.NumField()

Expand Down