diff --git a/codec.go b/codec.go index ad9b964c..466edb45 100644 --- a/codec.go +++ b/codec.go @@ -172,6 +172,16 @@ func UnpackPtrType(typ reflect.Type) reflect.Type { return typ } +// UnpackType unpack pointer type to original type and return the pointer depth. +func UnpackType(typ reflect.Type) (reflect.Type, int) { + depth := 0 + for typ.Kind() == reflect.Ptr { + typ = typ.Elem() + depth++ + } + return typ, depth +} + // UnpackPtrValue unpack pointer value to original value // return the pointer if its elem is zero value, because lots of operations on zero value is invalid func UnpackPtrValue(v reflect.Value) reflect.Value { @@ -181,6 +191,14 @@ func UnpackPtrValue(v reflect.Value) reflect.Value { return v } +// UnpackToRootAddressableValue unpack pointer value to the root addressable value. +func UnpackToRootAddressableValue(v reflect.Value) reflect.Value { + for v.Kind() == reflect.Ptr && v.Elem().CanAddr() { + v = v.Elem() + } + return v +} + // SprintHex converts the []byte to a Hex string. func SprintHex(b []byte) (rs string) { rs = fmt.Sprintf("[]byte{") @@ -253,69 +271,56 @@ func EnsureRawAny(in interface{}) interface{} { // SetValue set the value to dest. // It will auto check the Ptr pack level and unpack/pack to the right level. -// It make sure success to set value +// It makes sure success to set value func SetValue(dest, v reflect.Value) { - // check whether the v is a ref holder - if v.IsValid() { - if h, ok := v.Interface().(*_refHolder); ok { - h.add(dest) - return - } + // zero value not need to set + if !v.IsValid() { + return } - // temporary process, only handle the same type of situation - if v.IsValid() && UnpackPtrType(dest.Type()) == UnpackPtrType(v.Type()) && dest.Kind() == reflect.Ptr && dest.CanSet() { - for dest.Type() != v.Type() { - v = PackPtr(v) - } + + vType := v.Type() + destType := dest.Type() + + // for most cases, the types are the same and can set the value directly. + if dest.CanSet() && destType == vType { dest.Set(v) return } - // if the kind of dest is Ptr, the original value will be zero value - // set value on zero value is not allowed - // unpack to one-level pointer - for dest.Kind() == reflect.Ptr && dest.Elem().Kind() == reflect.Ptr { - dest = dest.Elem() + // check whether the v is a ref holder + if vType == _refHolderPtrType { + h := v.Interface().(*_refHolder) + h.add(dest) + return } - // if the kind of dest is Ptr, change the v to a Ptr value too. - if dest.Kind() == reflect.Ptr { + vRawType, vPtrDepth := UnpackType(vType) - // unpack to one-level pointer - for v.IsValid() && v.Kind() == reflect.Ptr && v.Elem().Kind() == reflect.Ptr { - v = v.Elem() - } - // zero value not need to set - if !v.IsValid() { - return - } + // unpack to the root addressable value, so that to set the value. + dest = UnpackToRootAddressableValue(dest) + destType = dest.Type() + destRawType, destPtrDepth := UnpackType(destType) - if v.Kind() != reflect.Ptr { - // change the v to a Ptr value - v = PackPtr(v) + // it can set the value directly if the raw types are of the same type. + if destRawType == vRawType { + if destPtrDepth > vPtrDepth { + // pack to the same level of dest + for i := 0; i < destPtrDepth-vPtrDepth; i++ { + v = PackPtr(v) + } + } else if destPtrDepth < vPtrDepth { + // unpack to the same level of dest + for i := 0; i < vPtrDepth-destPtrDepth; i++ { + v = v.Elem() + } } - } else { - v = UnpackPtrValue(v) - } - // zero value not need to set - if !v.IsValid() { - return - } - // set value as required type - if dest.Type() == v.Type() && dest.CanSet() { dest.Set(v) - return - } - // unpack ptr so that to special check for float,int,uint kind - if dest.Kind() == reflect.Ptr { - dest = UnpackPtrValue(dest) - v = UnpackPtrValue(v) + return } - kind := dest.Kind() - switch kind { + switch destType.Kind() { case reflect.Float32, reflect.Float64: dest.SetFloat(v.Float()) return @@ -323,39 +328,22 @@ func SetValue(dest, v reflect.Value) { dest.SetInt(v.Int()) return case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - dest.SetUint(v.Uint()) + // hessian only support 64-bit signed long integer. + dest.SetUint(uint64(v.Int())) return case reflect.Ptr: - setRawValueToPointer(dest, v) + SetValueToPtrDest(dest, v) return - } - - dest.Set(v) -} - -// setRawValueToDest set the raw value to dest. -func setRawValueToDest(dest reflect.Value, v reflect.Value) { - if dest.Type() == v.Type() { + default: + // It's ok when the dest is an interface{}, while the v is a pointer. dest.Set(v) - return - } - - if dest.Type().Kind() == reflect.Ptr { - setRawValueToPointer(dest, v) - return } - - dest.Set(v) } -// setRawValueToPointer set the raw value to dest. -func setRawValueToPointer(dest reflect.Value, v reflect.Value) { - pv := PackPtr(v) - if dest.Type() == pv.Type() { - dest.Set(pv) - return - } - +// SetValueToPtrDest set the raw value to a pointer dest. +func SetValueToPtrDest(dest reflect.Value, v reflect.Value) { + // for number, the type of value may be different with the dest, + // must convert it to the correct type of value then set. switch dest.Type() { case _typeOfIntPtr: vv := v.Int() @@ -410,8 +398,18 @@ func setRawValueToPointer(dest reflect.Value, v reflect.Value) { vv := v.Float() dest.Set(reflect.ValueOf(&vv)) return + case _typeOfRunePtr: + if v.Kind() == reflect.String { + vv := Rune(v.String()[0]) + dest.Set(reflect.ValueOf(&vv)) + return + } + + vv := Rune(v.Int()) + dest.Set(reflect.ValueOf(&vv)) + return default: - dest.Set(pv) + dest.Set(v) } } diff --git a/date_test.go b/date_test.go index 2954ffc6..80f21bfa 100644 --- a/date_test.go +++ b/date_test.go @@ -152,8 +152,8 @@ func TestEncDateNull(t *testing.T) { assert.Equal(t, ZeroDate, res.(*DateDemo).Date) assert.Equal(t, 2, len(res.(*DateDemo).Dates)) assert.Equal(t, tz.Local().String(), (*res.(*DateDemo).Dates[0]).String()) - assert.Equal(t, &ZeroDate, res.(*DateDemo).NilDate) - assert.Equal(t, ZeroDate, *res.(*DateDemo).Date1) + assert.Nil(t, res.(*DateDemo).NilDate) + assert.Nil(t, res.(*DateDemo).Date1) assert.Equal(t, tz.Local().String(), (*res.(*DateDemo).Date2).String()) assert.Equal(t, tz.Local().String(), (*(*res.(*DateDemo).Date3)).String()) } @@ -174,6 +174,6 @@ func doTestDateNull(t *testing.T, method string) { testDecodeFrameworkFunc(t, method, func(r interface{}) { t.Logf("%#v", r) assert.Equal(t, ZeroDate, r.(*DateDemo).Date) - assert.Equal(t, &ZeroDate, r.(*DateDemo).Date1) + assert.Nil(t, r.(*DateDemo).Date1) }) } diff --git a/decode.go b/decode.go index 13d2a3c0..be519bc1 100644 --- a/decode.go +++ b/decode.go @@ -324,6 +324,53 @@ func (d *Decoder) DecodeValue() (interface{}, error) { } } +// decToDest decode data to dest value. +// Before and includes the version v1.12.1, it checks all possible types of the destination, +// and then decode the data according to the type. +// But there are too many cases, and it's impossible to handle all of them. +// After v1.12.1, it decodes the data first, and then set the value to the destination. +// If the destination is map, slice, array, it decodes separately. +func (d *Decoder) decToDest(dest reflect.Value) error { + destType := dest.Type() + destRawType := UnpackPtrType(destType) + + // decode for special type, include map, slice, array. + switch destRawType.Kind() { + case reflect.Map: + return d.decMapByValue(dest) + case reflect.Slice, reflect.Array: + m, err := d.decList(TAG_READ) + if err != nil { + if perrors.Is(err, io.EOF) { + return nil + } + return perrors.WithStack(err) + } + + return SetSlice(UnpackPtrValue(dest), m) + } + + dec, err := d.DecodeValue() + if err != nil { + return perrors.Wrapf(err, "decToDest: %s", dest.Type().Name()) + } + + // if dec is nil, then return directly. + if dec == nil { + return nil + } + + if ref, ok := dec.(*_refHolder); ok { + return unpackRefHolder(UnpackPtrValue(dest), destRawType, ref) + } + + decValue := EnsurePackValue(dec) + + SetValue(dest, decValue) + + return nil +} + // /////////////////////////////////////// // typeRefs // /////////////////////////////////////// diff --git a/java_lang.go b/java_lang.go index 82dc723e..6ff715d3 100644 --- a/java_lang.go +++ b/java_lang.go @@ -50,5 +50,7 @@ func init() { type Rune rune var ( - _typeOfRune = reflect.TypeOf(Rune(0)) + _varRune = Rune(0) + _typeOfRune = reflect.TypeOf(_varRune) + _typeOfRunePtr = reflect.TypeOf(&_varRune) ) diff --git a/java_lang_test.go b/java_lang_test.go index dd863ee8..6cc7fa01 100644 --- a/java_lang_test.go +++ b/java_lang_test.go @@ -176,3 +176,52 @@ func TestDecodeJavaCharacterArray(t *testing.T) { t.Logf("%T %+v", got, got) assert.Equal(t, arr, got) } + +type JavaLangObjectHolder struct { + FieldInteger *int32 `json:"fieldInteger"` + FieldLong *int64 `json:"fieldLong"` + FieldBoolean *bool `json:"fieldBoolean"` + FieldShort *int16 `json:"fieldShort"` + FieldByte *int8 `json:"fieldByte"` + FieldFloat *float32 `json:"fieldFloat"` + FieldDouble *float64 `json:"fieldDouble"` + FieldCharacter *Rune `json:"fieldCharacter"` +} + +func (h JavaLangObjectHolder) JavaClassName() string { + return "test.model.JavaLangObjectHolder" +} + +func TestDecodeJavaLangObjectHolder(t *testing.T) { + var a int32 = 123 + var b int64 = 456 + var c = true + var d int16 = 789 + var e int8 = 12 + var f float32 = 3.45 + var g = 6.78 + var h Rune = 'A' + + obj := &JavaLangObjectHolder{ + FieldInteger: &a, + FieldLong: &b, + FieldBoolean: &c, + FieldShort: &d, + FieldByte: &e, + FieldFloat: &f, + FieldDouble: &g, + FieldCharacter: &h, + } + + RegisterPOJO(obj) + + got, err := decodeJavaResponse(`customReplyJavaLangObjectHolder`, ``, false) + assert.NoError(t, err) + t.Logf("customReplyJavaLangObjectHolder: %T %+v", got, got) + assert.Equal(t, obj, got) + + got, err = decodeJavaResponse(`customReplyJavaLangObjectHolderForNull`, ``, false) + assert.NoError(t, err) + t.Logf("customReplyJavaLangObjectHolderForNull: %T %+v", got, got) + assert.Equal(t, &JavaLangObjectHolder{}, got) +} diff --git a/list.go b/list.go index d5b23b4b..7280e090 100644 --- a/list.go +++ b/list.go @@ -393,7 +393,7 @@ func (d *Decoder) readTypedListValue(length int, listTyp string, isVariableArr b } else { if it != nil { //aryValue.Index(j).Set(EnsureRawValue(it)) - setRawValueToDest(aryValue.Index(j), EnsureRawValue(it)) + SetValue(aryValue.Index(j), EnsureRawValue(it)) } } } diff --git a/object.go b/object.go index f5898cf7..6ddbbc52 100644 --- a/object.go +++ b/object.go @@ -18,7 +18,6 @@ package hessian import ( - "io" "reflect" "strings" "sync" @@ -484,158 +483,11 @@ func (d *Decoder) decInstance(typ reflect.Type, cls *ClassInfo) (interface{}, er } field := vv.FieldByIndex(index) - if !field.CanSet() { - return nil, perrors.Errorf("decInstance CanSet false for field %s", fieldName) - } - - // get field type from type object, not do that from value - fldTyp := UnpackPtrType(field.Type()) - - // unpack pointer to enable value setting - fldRawValue := UnpackPtrValue(field) - kind := fldTyp.Kind() - - switch kind { - case reflect.String: - str, err := d.decString(TAG_READ) - if err != nil { - return nil, perrors.Wrapf(err, "decInstance->ReadString: %s", fieldName) - } - fldRawValue.SetString(str) - - case reflect.Int32, reflect.Int16, reflect.Int8: - num, err := d.decInt32(TAG_READ) - if err != nil { - // java enum - if fldRawValue.Type().Implements(javaEnumType) { - _ = d.unreadByte() // Enum parsing, decInt64 above has read a byte, so you need to return a byte here - enumVal, decErr := d.DecodeValue() - if decErr != nil { - return nil, perrors.Wrapf(decErr, "decInstance->decObject field name:%s", fieldName) - } - - SetValue(fldRawValue, reflect.ValueOf(enumVal)) - - continue - } - - return nil, perrors.Wrapf(err, "decInstance->decInt32, field name:%s", fieldName) - } - - fldRawValue.SetInt(int64(num)) - case reflect.Uint16, reflect.Uint8: - num, err := d.decInt32(TAG_READ) - if err != nil { - return nil, perrors.Wrapf(err, "decInstance->decInt32, field name:%s", fieldName) - } - fldRawValue.SetUint(uint64(num)) - case reflect.Uint, reflect.Int, reflect.Int64: - num, err := d.decInt64(TAG_READ) - if err != nil { - if fldTyp.Implements(javaEnumType) { - d.unreadByte() // Enum parsing, decInt64 above has read a byte, so you need to return a byte here - s, decErr := d.Decode() - if decErr != nil { - return nil, perrors.Wrapf(decErr, "decInstance->decObject field name:%s", fieldName) - } - enumValue, _ := s.(JavaEnum) - num = int64(enumValue) - } else { - return nil, perrors.Wrapf(err, "decInstance->decInt64 field name:%s", fieldName) - } - } - - fldRawValue.SetInt(num) - case reflect.Uint32, reflect.Uint64: - num, err := d.decInt64(TAG_READ) - if err != nil { - return nil, perrors.Wrapf(err, "decInstance->decInt64, field name:%s", fieldName) - } - fldRawValue.SetUint(uint64(num)) - case reflect.Bool: - b, err := d.Decode() - if err != nil { - return nil, perrors.Wrapf(err, "decInstance->Decode field name:%s", fieldName) - } - v, ok := b.(bool) - if !ok { - return nil, perrors.Errorf("value convert to bool failed, field name:%s", fieldName) - } - - if fldRawValue.Kind() == reflect.Ptr && fldRawValue.CanSet() { - if b != nil { - field.Set(reflect.ValueOf(&v)) - } - } else if fldRawValue.Kind() != reflect.Ptr { - fldRawValue.SetBool(v) - } - case reflect.Float32, reflect.Float64: - num, err := d.decDouble(TAG_READ) - if err != nil { - return nil, perrors.Wrapf(err, "decInstance->decDouble field name:%s", fieldName) - } - fldRawValue.SetFloat(num.(float64)) - - case reflect.Map: - // decode map should use the original field value for correct value setting - err := d.decMapByValue(field) - if err != nil { - return nil, perrors.Wrapf(err, "decInstance->decMapByValue field name: %s", fieldName) - } - - case reflect.Slice, reflect.Array: - m, err := d.decList(TAG_READ) - if err != nil { - if perrors.Is(err, io.EOF) { - break - } - return nil, perrors.WithStack(err) - } - - // set slice separately - err = SetSlice(fldRawValue, m) - if err != nil { - return nil, err - } - case reflect.Struct: - var ( - err error - s interface{} - ) - fldType := UnpackPtrType(fldRawValue.Type()) - if fldType.String() == "time.Time" { - s, err = d.decDate(TAG_READ) - if err != nil { - return nil, perrors.WithStack(err) - } - SetValue(fldRawValue, EnsurePackValue(s)) - } else { - s, err = d.decObject(TAG_READ) - if err != nil { - return nil, perrors.WithStack(err) - } - if s != nil { - // set value which accepting pointers - SetValue(fldRawValue, EnsurePackValue(s)) - } - } - case reflect.Interface: - s, err := d.DecodeValue() - if err != nil { - return nil, perrors.WithStack(err) - } - if s != nil { - if ref, ok := s.(*_refHolder); ok { - _ = unpackRefHolder(fldRawValue, fldTyp, ref) - } else { - // set value which accepting pointers - SetValue(fldRawValue, EnsurePackValue(s)) - } - } - default: - return nil, perrors.Errorf("unknown struct member type: %v %v", kind, typ.Name()+"."+fieldStruct.Name) + if err = d.decToDest(field); err != nil { + return nil, perrors.Wrapf(err, "decInstance->DecodeValue: %s", fieldName) } + } // end for return vRef.Interface(), nil diff --git a/object_test.go b/object_test.go index 41e6a0f9..a47bd1bd 100644 --- a/object_test.go +++ b/object_test.go @@ -953,7 +953,7 @@ func TestCustomReplyGenericResponseBusinessData(t *testing.T) { } res := &GenericResponse{ Code: 201, - Data: data, + Data: &data, } RegisterPOJO(data) RegisterPOJO(res) @@ -1057,7 +1057,7 @@ func TestWrapperClassArray(t *testing.T) { } type User struct { - Id int32 + Id *int32 List []int32 } @@ -1067,12 +1067,13 @@ func (u *User) JavaClassName() string { func TestDecodeIntegerHasNull(t *testing.T) { RegisterPOJO(&User{}) - testDecodeFramework(t, "customReplyTypedIntegerHasNull", &User{Id: 0}) + + testDecodeFramework(t, "customReplyTypedIntegerHasNull", &User{Id: nil}) } func TestDecodeSliceIntegerHasNull(t *testing.T) { RegisterPOJO(&User{}) - testDecodeFramework(t, "customReplyTypedListIntegerHasNull", &User{Id: 0, List: []int32{1, 0}}) + testDecodeFramework(t, "customReplyTypedListIntegerHasNull", &User{Id: nil, List: []int32{1, 0}}) } func TestDecodeCustomReplyEnumVariableList(t *testing.T) { diff --git a/ref.go b/ref.go index 0767ca07..80678173 100644 --- a/ref.go +++ b/ref.go @@ -32,6 +32,9 @@ var _emptySliceAddr = unsafe.Pointer(reflect.ValueOf([]interface{}{}).Pointer()) // The addresses of all nil map are the same. var _nilMapAddr = unsafe.Pointer(reflect.ValueOf(map[interface{}]interface{}(nil)).Pointer()) +// the ref holder pointer type. +var _refHolderPtrType = reflect.TypeOf(&_refHolder{}) + // used to ref object,list,map type _refElem struct { // record the kind of target, objects are the same only if the address and kind are the same diff --git a/test_hessian/src/main/java/test/TestCustomReply.java b/test_hessian/src/main/java/test/TestCustomReply.java index f396d4f1..4dbb2fc6 100644 --- a/test_hessian/src/main/java/test/TestCustomReply.java +++ b/test_hessian/src/main/java/test/TestCustomReply.java @@ -26,6 +26,7 @@ import test.generic.Response; import test.model.CustomMap; import test.model.DateDemo; +import test.model.JavaLangObjectHolder; import test.model.User; import java.io.OutputStream; @@ -797,6 +798,29 @@ public void customReplyEnumVariableList() throws Exception { output.writeObject(enumList.toArray(new Locale.Category[enumList.size()])); output.flush(); } + + public void customReplyJavaLangObjectHolder() throws Exception { + JavaLangObjectHolder holder = new JavaLangObjectHolder(); + + holder.setFieldInteger(123); + holder.setFieldLong(456L); + holder.setFieldBoolean(true); + holder.setFieldShort((short) 789); + holder.setFieldByte((byte) 12); + holder.setFieldFloat(3.45f); + holder.setFieldDouble(6.78); + holder.setFieldCharacter('A'); + + output.writeObject(holder); + output.flush(); + } + + public void customReplyJavaLangObjectHolderForNull() throws Exception { + // all fields are default null. + JavaLangObjectHolder holder = new JavaLangObjectHolder(); + output.writeObject(holder); + output.flush(); + } } interface Leg { diff --git a/test_hessian/src/main/java/test/model/JavaLangObjectHolder.java b/test_hessian/src/main/java/test/model/JavaLangObjectHolder.java new file mode 100644 index 00000000..3b6f2bcc --- /dev/null +++ b/test_hessian/src/main/java/test/model/JavaLangObjectHolder.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.model; + +import java.io.Serializable; + +/** + * @author tiltwind + */ +public class JavaLangObjectHolder implements Serializable { + private Integer fieldInteger; + private Long fieldLong; + private Boolean fieldBoolean; + private Double fieldDouble; + private Float fieldFloat; + private Short fieldShort; + private Byte fieldByte; + private Character fieldCharacter; + + public Integer getFieldInteger() { + return fieldInteger; + } + + public void setFieldInteger(Integer fieldInteger) { + this.fieldInteger = fieldInteger; + } + + public Long getFieldLong() { + return fieldLong; + } + + public void setFieldLong(Long fieldLong) { + this.fieldLong = fieldLong; + } + + public Boolean getFieldBoolean() { + return fieldBoolean; + } + + public void setFieldBoolean(Boolean fieldBoolean) { + this.fieldBoolean = fieldBoolean; + } + + public Double getFieldDouble() { + return fieldDouble; + } + + public void setFieldDouble(Double fieldDouble) { + this.fieldDouble = fieldDouble; + } + + public Float getFieldFloat() { + return fieldFloat; + } + + public void setFieldFloat(Float fieldFloat) { + this.fieldFloat = fieldFloat; + } + + public Short getFieldShort() { + return fieldShort; + } + + public void setFieldShort(Short fieldShort) { + this.fieldShort = fieldShort; + } + + public Byte getFieldByte() { + return fieldByte; + } + + public void setFieldByte(Byte fieldByte) { + this.fieldByte = fieldByte; + } + + public Character getFieldCharacter() { + return fieldCharacter; + } + + public void setFieldCharacter(Character fieldCharacter) { + this.fieldCharacter = fieldCharacter; + } +} diff --git a/testcases/user/user_test.go b/testcases/user/user_test.go index 3c7cf9df..d5e721c3 100644 --- a/testcases/user/user_test.go +++ b/testcases/user/user_test.go @@ -53,5 +53,6 @@ func TestUserEncodeDecode(t *testing.T) { decoder := hessian.NewDecoder(buf) dec, err := decoder.Decode() assert.Nil(t, err) + assert.Equal(t, u1, dec) }