Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enum primitives and empty enum variants #2

Merged
merged 8 commits into from Apr 6, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
73 changes: 73 additions & 0 deletions borsh_test.go
Expand Up @@ -846,6 +846,20 @@ func TestBorsh_Encode(t *testing.T) {
BarB: "this is bar from pointer",
},
},
ComplexEmpty: ComplexEnumEmpty{
Enum: 0,
Foo: BorshEnumEmpty{},
ozgb marked this conversation as resolved.
Show resolved Hide resolved
},

ComplexPrimitives1: ComplexEnumPrimitives{
Enum: 0,
Foo: 20,
},

ComplexPrimitives2: ComplexEnumPrimitives{
Enum: 1,
Bar: 11,
},

Complex2: ComplexEnumPointers{
Enum: 1,
Expand Down Expand Up @@ -939,6 +953,17 @@ func TestBorsh_Encode(t *testing.T) {
[]byte{0, 0, 0, 0},
[]byte{0, 0, 0, 0},

// .ComplexEmpty
[]byte{0},

// .ComplexPrimitives1
[]byte{0},
[]byte{20, 0, 0, 0},

// .ComplexPrimitives2
[]byte{1},
[]byte{11, 0},

// .Complex2
[]byte{1},
[]byte{62, 0, 0, 0, 0, 0, 0, 0},
Expand Down Expand Up @@ -1025,6 +1050,10 @@ type StructWithEnum struct {
ComplexPtr *ComplexEnum
ComplexPtrNotSet *ComplexEnum

ComplexEmpty ComplexEnumEmpty
ComplexPrimitives1 ComplexEnumPrimitives
ComplexPrimitives2 ComplexEnumPrimitives

Complex2 ComplexEnumPointers
Complex2Ptr *ComplexEnumPointers

Expand Down Expand Up @@ -1361,6 +1390,19 @@ type ComplexEnumPointers struct {
Foo *Foo
Bar *Bar
}

type ComplexEnumEmpty struct {
Enum BorshEnum `borsh_enum:"true"`
Foo BorshEnumEmpty
ozgb marked this conversation as resolved.
Show resolved Hide resolved
Bar Bar
}

type ComplexEnumPrimitives struct {
Enum BorshEnum `borsh_enum:"true"`
Foo uint32
Bar int16
}

type Foo struct {
FooA int32
FooB string
Expand Down Expand Up @@ -1404,6 +1446,37 @@ func TestComplexEnum(t *testing.T) {
err = UnmarshalBorsh(y, data)
require.NoError(t, err)

require.Equal(t, x, *y)
}
{
x := ComplexEnumEmpty{
Enum: 1,
Bar: Bar{
BarA: 23,
BarB: "baz",
},
}
data, err := MarshalBorsh(x)
require.NoError(t, err)

y := new(ComplexEnumEmpty)
err = UnmarshalBorsh(y, data)
require.NoError(t, err)

require.Equal(t, x, *y)
}
{
x := ComplexEnumPrimitives{
Enum: 1,
Bar: 22,
}
data, err := MarshalBorsh(x)
require.NoError(t, err)

y := new(ComplexEnumPrimitives)
err = UnmarshalBorsh(y, data)
require.NoError(t, err)

require.Equal(t, x, *y)
}
}
Expand Down
8 changes: 8 additions & 0 deletions decoder_borsh.go
Expand Up @@ -86,6 +86,10 @@ func (dec *Decoder) decodeBorsh(rv reflect.Value, opt *option) (err error) {
}

rt := rv.Type()
if isTypeBorshEnumEmpty(rt) {
// Empty enum type, nothing to deserialize
return
}
ozgb marked this conversation as resolved.
Show resolved Hide resolved
switch rv.Kind() {
// case reflect.Int:
// // TODO: check if is x32 or x64
Expand Down Expand Up @@ -269,6 +273,10 @@ func isTypeBorshEnum(typ reflect.Type) bool {
return typ.Kind() == reflect.Uint8 && typ == borshEnumType
}

func isTypeBorshEnumEmpty(typ reflect.Type) bool {
return typ.Kind() == reflect.Uint8 && typ == borshEnumType
}

ozgb marked this conversation as resolved.
Show resolved Hide resolved
func (dec *Decoder) decodeStructBorsh(rt reflect.Type, rv reflect.Value) (err error) {
l := rv.NumField()

Expand Down
98 changes: 69 additions & 29 deletions encoder_borsh.go
Expand Up @@ -26,6 +26,55 @@ import (
"go.uber.org/zap"
)

func (e *Encoder) encodePrimitive(rv reflect.Value, opt *option) (isPrimitive bool, err error) {
isPrimitive = true
switch rv.Kind() {
// case reflect.Int:
// err = e.WriteInt64(rv.Int(), LE)
// case reflect.Uint:
// err = e.WriteUint64(rv.Uint(), LE)
case reflect.String:
err = e.WriteString(rv.String())
break
case reflect.Uint8:
err = e.WriteByte(byte(rv.Uint()))
break
case reflect.Int8:
err = e.WriteByte(byte(rv.Int()))
break
case reflect.Int16:
err = e.WriteInt16(int16(rv.Int()), LE)
break
case reflect.Uint16:
err = e.WriteUint16(uint16(rv.Uint()), LE)
break
case reflect.Int32:
err = e.WriteInt32(int32(rv.Int()), LE)
break
case reflect.Uint32:
err = e.WriteUint32(uint32(rv.Uint()), LE)
break
case reflect.Uint64:
err = e.WriteUint64(rv.Uint(), LE)
break
case reflect.Int64:
err = e.WriteInt64(rv.Int(), LE)
break
case reflect.Float32:
err = e.WriteFloat32(float32(rv.Float()), LE)
break
case reflect.Float64:
err = e.WriteFloat64(rv.Float(), LE)
break
case reflect.Bool:
err = e.WriteBool(rv.Bool())
break
default:
isPrimitive = false
}
return
}

func (e *Encoder) encodeBorsh(rv reflect.Value, opt *option) (err error) {
if opt == nil {
opt = newDefaultOption()
Expand Down Expand Up @@ -70,35 +119,15 @@ func (e *Encoder) encodeBorsh(rv reflect.Value, opt *option) (err error) {
return marshaler.MarshalWithEncoder(e)
}

// Encode the value if it's a primitive type
isPrimitive, err := e.encodePrimitive(rv, nil)
if isPrimitive {
return
} else if err != nil {
return err
}

ozgb marked this conversation as resolved.
Show resolved Hide resolved
switch rv.Kind() {
// case reflect.Int:
// return e.WriteInt64(rv.Int(), LE)
// case reflect.Uint:
// return e.WriteUint64(rv.Uint(), LE)
case reflect.String:
return e.WriteString(rv.String())
case reflect.Uint8:
return e.WriteByte(byte(rv.Uint()))
case reflect.Int8:
return e.WriteByte(byte(rv.Int()))
case reflect.Int16:
return e.WriteInt16(int16(rv.Int()), LE)
case reflect.Uint16:
return e.WriteUint16(uint16(rv.Uint()), LE)
case reflect.Int32:
return e.WriteInt32(int32(rv.Int()), LE)
case reflect.Uint32:
return e.WriteUint32(uint32(rv.Uint()), LE)
case reflect.Uint64:
return e.WriteUint64(rv.Uint(), LE)
case reflect.Int64:
return e.WriteInt64(rv.Int(), LE)
case reflect.Float32:
return e.WriteFloat32(float32(rv.Float()), LE)
case reflect.Float64:
return e.WriteFloat64(rv.Float(), LE)
case reflect.Bool:
return e.WriteBool(rv.Bool())
case reflect.Ptr:
if rv.IsNil() {
el := reflect.New(rv.Type().Elem()).Elem()
Expand Down Expand Up @@ -222,18 +251,29 @@ func (enc *Encoder) encodeComplexEnumBorsh(rv reflect.Value) error {
if int(enum)+1 >= t.NumField() {
return errors.New("complex enum too large")
}
// Enum is empty
field := rv.Field(int(enum) + 1)

if isTypeBorshEnumEmpty(field.Type()) {
return nil
}
ozgb marked this conversation as resolved.
Show resolved Hide resolved
if field.Kind() == reflect.Ptr {
field = field.Elem()
}
if field.Kind() == reflect.Struct {
return enc.encodeStructBorsh(field.Type(), field)
}
// Encode the value if it's a primitive type
isPrimitive, err := enc.encodePrimitive(field, nil)
if isPrimitive {
return nil
} else if err != nil {
return err
}
ozgb marked this conversation as resolved.
Show resolved Hide resolved
return nil
}

type BorshEnum uint8
type BorshEnumEmpty struct{}
ozgb marked this conversation as resolved.
Show resolved Hide resolved

func (e *Encoder) encodeStructBorsh(rt reflect.Type, rv reflect.Value) (err error) {
l := rv.NumField()
Expand Down