Skip to content

Commit

Permalink
Enum primitives and empty enum variants (#2)
Browse files Browse the repository at this point in the history
* Implement empty enum - working with tests

TODO: Ensure primitive types are working with enums

* Add further test for enums with primitives

* go fmt

* Rename BorshEnumEmpty -> EmptyVariant

Resolves PR Comment

* Remove isEnumEmpty check -> use marshal interface

Resolving PR comments

* go fmt

* Refactor error checking on encodePrimitive

Resolving PR comments

* Remove redundant breaks
  • Loading branch information
ozgb committed Apr 6, 2022
1 parent 59fa436 commit 65cc9d2
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 29 deletions.
73 changes: 73 additions & 0 deletions borsh_test.go
Expand Up @@ -846,6 +846,20 @@ func TestBorsh_Encode(t *testing.T) {
BarB: "this is bar from pointer",
},
},
ComplexEmpty: ComplexEnumEmpty{
Enum: 0,
Foo: EmptyVariant{},
},

ComplexPrimitives1: ComplexEnumPrimitives{
Enum: 0,
Foo: 20,
},

ComplexPrimitives2: ComplexEnumPrimitives{
Enum: 1,
Bar: 11,
},

Complex2: ComplexEnumPointers{
Enum: 1,
Expand Down Expand Up @@ -939,6 +953,17 @@ func TestBorsh_Encode(t *testing.T) {
[]byte{0, 0, 0, 0},
[]byte{0, 0, 0, 0},

// .ComplexEmpty
[]byte{0},

// .ComplexPrimitives1
[]byte{0},
[]byte{20, 0, 0, 0},

// .ComplexPrimitives2
[]byte{1},
[]byte{11, 0},

// .Complex2
[]byte{1},
[]byte{62, 0, 0, 0, 0, 0, 0, 0},
Expand Down Expand Up @@ -1025,6 +1050,10 @@ type StructWithEnum struct {
ComplexPtr *ComplexEnum
ComplexPtrNotSet *ComplexEnum

ComplexEmpty ComplexEnumEmpty
ComplexPrimitives1 ComplexEnumPrimitives
ComplexPrimitives2 ComplexEnumPrimitives

Complex2 ComplexEnumPointers
Complex2Ptr *ComplexEnumPointers

Expand Down Expand Up @@ -1361,6 +1390,19 @@ type ComplexEnumPointers struct {
Foo *Foo
Bar *Bar
}

type ComplexEnumEmpty struct {
Enum BorshEnum `borsh_enum:"true"`
Foo EmptyVariant
Bar Bar
}

type ComplexEnumPrimitives struct {
Enum BorshEnum `borsh_enum:"true"`
Foo uint32
Bar int16
}

type Foo struct {
FooA int32
FooB string
Expand Down Expand Up @@ -1404,6 +1446,37 @@ func TestComplexEnum(t *testing.T) {
err = UnmarshalBorsh(y, data)
require.NoError(t, err)

require.Equal(t, x, *y)
}
{
x := ComplexEnumEmpty{
Enum: 1,
Bar: Bar{
BarA: 23,
BarB: "baz",
},
}
data, err := MarshalBorsh(x)
require.NoError(t, err)

y := new(ComplexEnumEmpty)
err = UnmarshalBorsh(y, data)
require.NoError(t, err)

require.Equal(t, x, *y)
}
{
x := ComplexEnumPrimitives{
Enum: 1,
Bar: 22,
}
data, err := MarshalBorsh(x)
require.NoError(t, err)

y := new(ComplexEnumPrimitives)
err = UnmarshalBorsh(y, data)
require.NoError(t, err)

require.Equal(t, x, *y)
}
}
Expand Down
89 changes: 60 additions & 29 deletions encoder_borsh.go
Expand Up @@ -26,6 +26,43 @@ import (
"go.uber.org/zap"
)

func (e *Encoder) encodePrimitive(rv reflect.Value, opt *option) (isPrimitive bool, err error) {
isPrimitive = true
switch rv.Kind() {
// case reflect.Int:
// err = e.WriteInt64(rv.Int(), LE)
// case reflect.Uint:
// err = e.WriteUint64(rv.Uint(), LE)
case reflect.String:
err = e.WriteString(rv.String())
case reflect.Uint8:
err = e.WriteByte(byte(rv.Uint()))
case reflect.Int8:
err = e.WriteByte(byte(rv.Int()))
case reflect.Int16:
err = e.WriteInt16(int16(rv.Int()), LE)
case reflect.Uint16:
err = e.WriteUint16(uint16(rv.Uint()), LE)
case reflect.Int32:
err = e.WriteInt32(int32(rv.Int()), LE)
case reflect.Uint32:
err = e.WriteUint32(uint32(rv.Uint()), LE)
case reflect.Uint64:
err = e.WriteUint64(rv.Uint(), LE)
case reflect.Int64:
err = e.WriteInt64(rv.Int(), LE)
case reflect.Float32:
err = e.WriteFloat32(float32(rv.Float()), LE)
case reflect.Float64:
err = e.WriteFloat64(rv.Float(), LE)
case reflect.Bool:
err = e.WriteBool(rv.Bool())
default:
isPrimitive = false
}
return
}

func (e *Encoder) encodeBorsh(rv reflect.Value, opt *option) (err error) {
if opt == nil {
opt = newDefaultOption()
Expand Down Expand Up @@ -70,35 +107,13 @@ func (e *Encoder) encodeBorsh(rv reflect.Value, opt *option) (err error) {
return marshaler.MarshalWithEncoder(e)
}

// Encode the value if it's a primitive type
isPrimitive, err := e.encodePrimitive(rv, nil)
if isPrimitive {
return err
}

switch rv.Kind() {
// case reflect.Int:
// return e.WriteInt64(rv.Int(), LE)
// case reflect.Uint:
// return e.WriteUint64(rv.Uint(), LE)
case reflect.String:
return e.WriteString(rv.String())
case reflect.Uint8:
return e.WriteByte(byte(rv.Uint()))
case reflect.Int8:
return e.WriteByte(byte(rv.Int()))
case reflect.Int16:
return e.WriteInt16(int16(rv.Int()), LE)
case reflect.Uint16:
return e.WriteUint16(uint16(rv.Uint()), LE)
case reflect.Int32:
return e.WriteInt32(int32(rv.Int()), LE)
case reflect.Uint32:
return e.WriteUint32(uint32(rv.Uint()), LE)
case reflect.Uint64:
return e.WriteUint64(rv.Uint(), LE)
case reflect.Int64:
return e.WriteInt64(rv.Int(), LE)
case reflect.Float32:
return e.WriteFloat32(float32(rv.Float()), LE)
case reflect.Float64:
return e.WriteFloat64(rv.Float(), LE)
case reflect.Bool:
return e.WriteBool(rv.Bool())
case reflect.Ptr:
if rv.IsNil() {
el := reflect.New(rv.Type().Elem()).Elem()
Expand Down Expand Up @@ -222,19 +237,35 @@ func (enc *Encoder) encodeComplexEnumBorsh(rv reflect.Value) error {
if int(enum)+1 >= t.NumField() {
return errors.New("complex enum too large")
}
// Enum is empty
field := rv.Field(int(enum) + 1)

if field.Kind() == reflect.Ptr {
field = field.Elem()
}
if field.Kind() == reflect.Struct {
return enc.encodeStructBorsh(field.Type(), field)
}
// Encode the value if it's a primitive type
isPrimitive, err := enc.encodePrimitive(field, nil)
if isPrimitive {
return err
}
return nil
}

type BorshEnum uint8

// EmptyVariant is an empty borsh enum variant.
type EmptyVariant struct{}

func (_ *EmptyVariant) MarshalWithEncoder(_ *Encoder) error {
return nil
}

func (_ *EmptyVariant) UnmarshalWithEncoder(_ *Encoder) error {
return nil
}

func (e *Encoder) encodeStructBorsh(rt reflect.Type, rv reflect.Value) (err error) {
l := rv.NumField()

Expand Down

0 comments on commit 65cc9d2

Please sign in to comment.