diff --git a/jsonpb/decode.go b/jsonpb/decode.go index 7c6c5a5244..60e82caa9a 100644 --- a/jsonpb/decode.go +++ b/jsonpb/decode.go @@ -135,14 +135,14 @@ func (u *Unmarshaler) unmarshalMessage(m protoreflect.Message, in []byte) error md := m.Descriptor() fds := md.Fields() - if string(in) == "null" && md.FullName() != "google.protobuf.Value" { - return nil - } - if jsu, ok := proto.MessageV1(m.Interface()).(JSONPBUnmarshaler); ok { return jsu.UnmarshalJSONPB(u, in) } + if string(in) == "null" && md.FullName() != "google.protobuf.Value" { + return nil + } + switch wellKnownType(md.FullName()) { case "Any": var jsonObject map[string]json.RawMessage @@ -332,11 +332,12 @@ func (u *Unmarshaler) unmarshalMessage(m protoreflect.Message, in []byte) error raw = v } + field := m.NewField(fd) // Unmarshal the field value. - if raw == nil || (string(raw) == "null" && !isSingularWellKnownValue(fd)) { + if raw == nil || (string(raw) == "null" && !isSingularWellKnownValue(fd) && !isSingularJSONPBUnmarshaler(field, fd)) { continue } - v, err := u.unmarshalValue(m.NewField(fd), raw, fd) + v, err := u.unmarshalValue(field, raw, fd) if err != nil { return err } @@ -364,11 +365,12 @@ func (u *Unmarshaler) unmarshalMessage(m protoreflect.Message, in []byte) error return fmt.Errorf("extension field %q does not extend message %q", xname, m.Descriptor().FullName()) } + field := m.NewField(fd) // Unmarshal the field value. - if raw == nil || (string(raw) == "null" && !isSingularWellKnownValue(fd)) { + if raw == nil || (string(raw) == "null" && !isSingularWellKnownValue(fd) && !isSingularJSONPBUnmarshaler(field, fd)) { continue } - v, err := u.unmarshalValue(m.NewField(fd), raw, fd) + v, err := u.unmarshalValue(field, raw, fd) if err != nil { return err } @@ -390,6 +392,14 @@ func isSingularWellKnownValue(fd protoreflect.FieldDescriptor) bool { return false } +func isSingularJSONPBUnmarshaler(v protoreflect.Value, fd protoreflect.FieldDescriptor) bool { + if fd.Message() != nil && fd.Cardinality() != protoreflect.Repeated { + _, ok := proto.MessageV1(v.Interface()).(JSONPBUnmarshaler) + return ok + } + return false +} + func (u *Unmarshaler) unmarshalValue(v protoreflect.Value, in []byte, fd protoreflect.FieldDescriptor) (protoreflect.Value, error) { switch { case fd.IsList(): diff --git a/jsonpb/json_test.go b/jsonpb/json_test.go index 0ef23f2d30..a98ad169f2 100644 --- a/jsonpb/json_test.go +++ b/jsonpb/json_test.go @@ -1009,7 +1009,7 @@ func TestUnmarshalNullWithJSONPBUnmarshaler(t *testing.T) { t.Errorf("unmarshal error: %v", err) } - want := ptrFieldMessage{} + want := ptrFieldMessage{StringField: &stringField{IsSet: true, StringValue: "null"}} if !proto.Equal(&ptrFieldMsg, &want) { t.Errorf("unmarshal result StringField: got %v, want %v", ptrFieldMsg, want) }