Skip to content

Commit

Permalink
fix not encoding nil to null
Browse files Browse the repository at this point in the history
  • Loading branch information
tiltwind committed Nov 27, 2023
1 parent 6c40744 commit 3cf574c
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 35 deletions.
33 changes: 16 additions & 17 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,19 +166,22 @@ func (e *Encoder) Encode(v interface{}) error {
}

default:
t := UnpackPtrType(reflect.TypeOf(v))
vv := reflect.ValueOf(v)

if vv.Kind() == reflect.Ptr && (vv.IsNil() || !vv.IsValid()) {
e.buffer = EncNull(e.buffer)
return nil
}

t := UnpackPtrType(vv.Type())
switch t.Kind() {
case reflect.Struct:
vv := reflect.ValueOf(v)
if vv.Kind() != reflect.Ptr {
v = PackPtrInterface(v, vv)
} else {
vv = UnpackPtr(vv)
}
if !vv.IsValid() {
e.buffer = EncNull(e.buffer)
return nil
}

if vv.Type().String() == "time.Time" {
e.buffer = encDateInMs(e.buffer, v)
return nil
Expand All @@ -193,16 +196,13 @@ func (e *Encoder) Encode(v interface{}) error {
}
return e.encObject(vv.Interface())
case reflect.Slice, reflect.Array:
return e.encList(v)
return e.encList(vv)
case reflect.Map: // the type must be map[string]int
return e.encMap(v)
return e.encMap(v, vv)
case reflect.Bool:
vv := v.(*bool)
if vv != nil {
e.buffer = encBool(e.buffer, *vv)
} else {
e.buffer = encBool(e.buffer, false)
}
vv = UnpackPtr(vv)
e.buffer = encBool(e.buffer, vv.Interface().(bool))
return nil
case reflect.Int32:
if t == _typeOfRune {
e.buffer = encString(e.buffer, string(*v.(*Rune)))
Expand All @@ -218,9 +218,8 @@ func (e *Encoder) Encode(v interface{}) error {
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64: // resolve base type
vVal := reflect.ValueOf(v)
if reflect.Ptr == vVal.Kind() && !vVal.IsNil() {
return e.Encode(vVal.Elem().Interface())
if reflect.Ptr == vv.Kind() {
return e.Encode(vv.Elem().Interface())
}
default:
return perrors.Errorf("type not supported! %s", t.Kind().String())
Expand Down
36 changes: 35 additions & 1 deletion java_lang_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ func TestDecodeJavaLangObjectHolder(t *testing.T) {
FieldCharacter: &h,
}

RegisterPOJO(obj)
doJavaLangObjectHolderTest(t, obj)

got, err := decodeJavaResponse(`customReplyJavaLangObjectHolder`, ``, false)
assert.NoError(t, err)
Expand All @@ -225,3 +225,37 @@ func TestDecodeJavaLangObjectHolder(t *testing.T) {
t.Logf("customReplyJavaLangObjectHolderForNull: %T %+v", got, got)
assert.Equal(t, &JavaLangObjectHolder{}, got)
}

func TestNilJavaLangObject(t *testing.T) {
obj := &JavaLangObjectHolder{
FieldInteger: nil,
FieldLong: nil,
FieldBoolean: nil,
FieldShort: nil,
FieldByte: nil,
FieldFloat: nil,
FieldDouble: nil,
FieldCharacter: nil,
}

doJavaLangObjectHolderTest(t, obj)
}

func doJavaLangObjectHolderTest(t *testing.T, holder *JavaLangObjectHolder) {
RegisterPOJO(holder)

e := NewEncoder()
err := e.Encode(holder)
if err != nil {
t.Errorf("encode error: %v", err)
t.FailNow()
}
buf := e.Buffer()
decoder := NewDecoder(buf)
des, derr := decoder.Decode()
if derr != nil {
t.Errorf("dencode error: %v", derr)
t.FailNow()
}
assert.Equal(t, des, holder)
}
12 changes: 5 additions & 7 deletions list.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,9 @@ type Object interface{}
/////////////////////////////////////////

// encList write list
func (e *Encoder) encList(v interface{}) error {
if !strings.Contains(reflect.TypeOf(v).String(), "interface {}") {
// The `v` should not be nil, should check nil in advance.
func (e *Encoder) encList(v reflect.Value) error {
if !strings.Contains(v.Type().String(), "interface {}") {
return e.writeTypedList(v)
}
return e.writeUntypedList(v)
Expand All @@ -156,10 +157,9 @@ func (e *Encoder) encList(v interface{}) error {
// ::= x55 type value* 'Z' # variable-length list
// ::= 'V' type int value* # fixed-length list
// ::= [x70-77] type value* # fixed-length typed list
func (e *Encoder) writeTypedList(v interface{}) error {
func (e *Encoder) writeTypedList(value reflect.Value) error {
var err error

value := reflect.ValueOf(v)
// https://github.com/apache/dubbo-go-hessian2/issues/317
// if list is null, just return 'N'
if value.IsNil() {
Expand Down Expand Up @@ -202,11 +202,9 @@ func (e *Encoder) writeTypedList(v interface{}) error {
// ::= x57 value* 'Z' # variable-length untyped list
// ::= x58 int value* # fixed-length untyped list
// ::= [x78-7f] value* # fixed-length untyped list
func (e *Encoder) writeUntypedList(v interface{}) error {
func (e *Encoder) writeUntypedList(value reflect.Value) error {
var err error

value := reflect.ValueOf(v)

// check ref
if n, ok := e.checkRefMap(value); ok {
e.buffer = encRef(e.buffer, n)
Expand Down
10 changes: 5 additions & 5 deletions map.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,17 @@ func getMapKey(key reflect.Value, t reflect.Type) (interface{}, error) {
return nil, perrors.Errorf("unsupported map key kind %s", t.Kind().String())
}

func (e *Encoder) encMap(m interface{}) error {
// encMap encode map object.
// - `m` is the map
// - `value` is the reflect value of the map
func (e *Encoder) encMap(m interface{}, value reflect.Value) error {
var (
err error
k interface{}
typ reflect.Type
value reflect.Value
keys []reflect.Value
)

value = reflect.ValueOf(m)

// check ref
if n, ok := e.checkRefMap(value); ok {
e.buffer = encRef(e.buffer, n)
Expand All @@ -116,7 +116,7 @@ func (e *Encoder) encMap(m interface{}) error {
}

value = UnpackPtrValue(value)
// check nil map
// check nil map.
if value.IsNil() || (value.Kind() == reflect.Ptr && !value.Elem().IsValid()) {
e.buffer = EncNull(e.buffer)
return nil
Expand Down
7 changes: 2 additions & 5 deletions object_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -681,11 +681,8 @@ func TestBasePointer(t *testing.T) {
base = BasePointer{
A: nil,
}
expectedF := false
expectedBase := BasePointer{
A: &expectedF,
}
doTestBasePointer(t, &base, &expectedBase)

doTestBasePointer(t, &base, &base)
}

func doTestBasePointer(t *testing.T, base *BasePointer, expected *BasePointer) {
Expand Down

0 comments on commit 3cf574c

Please sign in to comment.