diff --git a/jsonpb/jsonpb.go b/jsonpb/jsonpb.go index 1f64aabb87..e8134ec8ba 100644 --- a/jsonpb/jsonpb.go +++ b/jsonpb/jsonpb.go @@ -164,6 +164,11 @@ type isWkt interface { XXX_WellKnownType() string } +var ( + wktType = reflect.TypeOf((*isWkt)(nil)).Elem() + messageType = reflect.TypeOf((*proto.Message)(nil)).Elem() +) + // marshalObject writes a struct to the Writer. func (m *Marshaler) marshalObject(out *errWriter, v proto.Message, indent, typeURL string) error { if jsm, ok := v.(JSONPBMarshaler); ok { @@ -551,7 +556,8 @@ func (m *Marshaler) marshalValue(out *errWriter, prop *proto.Properties, v refle // Handle well-known types. // Most are handled up in marshalObject (because 99% are messages). - if wkt, ok := v.Interface().(isWkt); ok { + if v.Type().Implements(wktType) { + wkt := v.Interface().(isWkt) switch wkt.XXX_WellKnownType() { case "NullValue": out.write("null") @@ -1422,8 +1428,8 @@ func checkRequiredFields(pb proto.Message) error { } func checkRequiredFieldsInValue(v reflect.Value) error { - if pm, ok := v.Interface().(proto.Message); ok { - return checkRequiredFields(pm) + if v.Type().Implements(messageType) { + return checkRequiredFields(v.Interface().(proto.Message)) } return nil } diff --git a/proto/all_test.go b/proto/all_test.go index f391af74c3..966ceb567d 100644 --- a/proto/all_test.go +++ b/proto/all_test.go @@ -45,9 +45,11 @@ import ( "testing" "time" + "github.com/gogo/protobuf/jsonpb" . "github.com/gogo/protobuf/proto" pb3 "github.com/gogo/protobuf/proto/proto3_proto" . "github.com/gogo/protobuf/proto/test_proto" + descriptorpb "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" ) var globalO *Buffer @@ -2509,3 +2511,33 @@ func BenchmarkUnmarshalUnrecognizedFields(b *testing.B) { p2.Unmarshal(pbd) } } + +// TestRace tests whether there are races among the different marshalers. +func TestRace(t *testing.T) { + m := &descriptorpb.FileDescriptorProto{ + Options: &descriptorpb.FileOptions{ + GoPackage: String("path/to/my/package"), + }, + } + + wg := &sync.WaitGroup{} + defer wg.Wait() + + wg.Add(1) + go func() { + defer wg.Done() + Marshal(m) + }() + + wg.Add(1) + go func() { + defer wg.Done() + (&jsonpb.Marshaler{}).MarshalToString(m) + }() + + wg.Add(1) + go func() { + defer wg.Done() + _ = m.String() + }() +} diff --git a/proto/text.go b/proto/text.go index 0407ba85d0..87416afe95 100644 --- a/proto/text.go +++ b/proto/text.go @@ -476,6 +476,8 @@ func (tm *TextMarshaler) writeStruct(w *textWriter, sv reflect.Value) error { return nil } +var textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() + // writeAny writes an arbitrary field. func (tm *TextMarshaler) writeAny(w *textWriter, v reflect.Value, props *Properties) error { v = reflect.Indirect(v) @@ -589,8 +591,8 @@ func (tm *TextMarshaler) writeAny(w *textWriter, v reflect.Value, props *Propert // mutating this value. v = v.Addr() } - if etm, ok := v.Interface().(encoding.TextMarshaler); ok { - text, err := etm.MarshalText() + if v.Type().Implements(textMarshalerType) { + text, err := v.Interface().(encoding.TextMarshaler).MarshalText() if err != nil { return err }