From dad2dedc9037e78707d666be2095bf552e5ea228 Mon Sep 17 00:00:00 2001 From: Braydon Kains Date: Wed, 9 Feb 2022 23:06:32 +0000 Subject: [PATCH 1/2] yaml: typeError implements PrettyPrinter interface When unmarshalling YAML and encountering an error in decoded that results in a `yaml.typeError`, the error can't have the same nice formatting because it doesn't implement the `errors.PrettyPrinter` interface. This adds a field to the `yaml.typeError` struct that specifies a token which is used in the implementation of the PrettyPrinter interface. The tests were updated to instead match on expected message substrings instead of full equality. --- decode.go | 45 ++++++++++++++++++++++++++++--------- decode_test.go | 48 ++++++++++++++++++++-------------------- internal/errors/error.go | 24 ++++++++++---------- 3 files changed, 71 insertions(+), 46 deletions(-) diff --git a/decode.go b/decode.go index 5d1f63a..7ad5661 100644 --- a/decode.go +++ b/decode.go @@ -18,6 +18,7 @@ import ( "github.com/goccy/go-yaml/ast" "github.com/goccy/go-yaml/internal/errors" "github.com/goccy/go-yaml/parser" + "github.com/goccy/go-yaml/printer" "github.com/goccy/go-yaml/token" "golang.org/x/xerrors" ) @@ -367,10 +368,10 @@ func (d *Decoder) fileToNode(f *ast.File) ast.Node { return nil } -func (d *Decoder) convertValue(v reflect.Value, typ reflect.Type) (reflect.Value, error) { +func (d *Decoder) convertValue(v reflect.Value, typ reflect.Type, src ast.Node) (reflect.Value, error) { if typ.Kind() != reflect.String { if !v.Type().ConvertibleTo(typ) { - return reflect.Zero(typ), errTypeMismatch(typ, v.Type()) + return reflect.Zero(typ), errTypeMismatch(typ, v.Type(), src.GetToken()) } return v.Convert(typ), nil } @@ -386,7 +387,7 @@ func (d *Decoder) convertValue(v reflect.Value, typ reflect.Type) (reflect.Value return reflect.ValueOf(fmt.Sprint(v.Bool())), nil } if !v.Type().ConvertibleTo(typ) { - return reflect.Zero(typ), errTypeMismatch(typ, v.Type()) + return reflect.Zero(typ), errTypeMismatch(typ, v.Type(), src.GetToken()) } return v.Convert(typ), nil } @@ -408,6 +409,7 @@ type typeError struct { dstType reflect.Type srcType reflect.Type structFieldName *string + token *token.Token } func (e *typeError) Error() string { @@ -417,8 +419,31 @@ func (e *typeError) Error() string { return fmt.Sprintf("cannot unmarshal %s into Go value of type %s", e.srcType, e.dstType) } -func errTypeMismatch(dstType, srcType reflect.Type) *typeError { - return &typeError{dstType: dstType, srcType: srcType} +func (e *typeError) PrettyPrint(p xerrors.Printer, colored, inclSource bool) error { + return e.FormatError(&errors.FormatErrorPrinter{Printer: p, Colored: colored, InclSource: inclSource}) +} + +func (e *typeError) FormatError(p xerrors.Printer) error { + var pp printer.Printer + + var colored, inclSource bool + if fep, ok := p.(*errors.FormatErrorPrinter); ok { + colored = fep.Colored + inclSource = fep.InclSource + } + + pos := fmt.Sprintf("[%d:%d] ", e.token.Position.Line, e.token.Position.Column) + msg := pp.PrintErrorMessage(fmt.Sprintf("%s%s", pos, e.Error()), colored) + if inclSource { + msg += "\n" + pp.PrintErrorToken(e.token, colored) + } + p.Print(msg) + + return nil +} + +func errTypeMismatch(dstType, srcType reflect.Type, token *token.Token) *typeError { + return &typeError{dstType: dstType, srcType: srcType, token: token} } type unknownFieldError struct { @@ -709,7 +734,7 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No return nil } default: - return errTypeMismatch(valueType, reflect.TypeOf(v)) + return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken()) } return errOverflow(valueType, fmt.Sprint(v)) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: @@ -731,13 +756,13 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No return nil } default: - return errTypeMismatch(valueType, reflect.TypeOf(v)) + return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken()) } return errOverflow(valueType, fmt.Sprint(v)) } v := reflect.ValueOf(d.nodeToValue(src)) if v.IsValid() { - convertedValue, err := d.convertValue(v, dst.Type()) + convertedValue, err := d.convertValue(v, dst.Type(), src) if err != nil { return errors.Wrapf(err, "failed to convert value") } @@ -905,7 +930,7 @@ func (d *Decoder) castToTime(src ast.Node) (time.Time, error) { } s, ok := v.(string) if !ok { - return time.Time{}, errTypeMismatch(reflect.TypeOf(time.Time{}), reflect.TypeOf(v)) + return time.Time{}, errTypeMismatch(reflect.TypeOf(time.Time{}), reflect.TypeOf(v), src.GetToken()) } for _, format := range allowedTimestampFormats { t, err := time.Parse(format, s) @@ -937,7 +962,7 @@ func (d *Decoder) castToDuration(src ast.Node) (time.Duration, error) { } s, ok := v.(string) if !ok { - return 0, errTypeMismatch(reflect.TypeOf(time.Duration(0)), reflect.TypeOf(v)) + return 0, errTypeMismatch(reflect.TypeOf(time.Duration(0)), reflect.TypeOf(v), src.GetToken()) } t, err := time.ParseDuration(s) if err != nil { diff --git a/decode_test.go b/decode_test.go index fb13b9a..4d93f1f 100644 --- a/decode_test.go +++ b/decode_test.go @@ -1074,8 +1074,8 @@ func TestDecoder_TypeConversionError(t *testing.T) { t.Fatal("expected to error") } msg := "cannot unmarshal string into Go struct field T.A of type int" - if err.Error() != msg { - t.Fatalf("unexpected error message: %s. expect: %s", err.Error(), msg) + if !strings.Contains(err.Error(), msg) { + t.Fatalf("expected error message: %s to contain: %s", err.Error(), msg) } }) t.Run("string to bool", func(t *testing.T) { @@ -1085,8 +1085,8 @@ func TestDecoder_TypeConversionError(t *testing.T) { t.Fatal("expected to error") } msg := "cannot unmarshal string into Go struct field T.D of type bool" - if err.Error() != msg { - t.Fatalf("unexpected error message: %s. expect: %s", err.Error(), msg) + if !strings.Contains(err.Error(), msg) { + t.Fatalf("expected error message: %s to contain: %s", err.Error(), msg) } }) t.Run("string to int at inline", func(t *testing.T) { @@ -1096,8 +1096,8 @@ func TestDecoder_TypeConversionError(t *testing.T) { t.Fatal("expected to error") } msg := "cannot unmarshal string into Go struct field U.T.A of type int" - if err.Error() != msg { - t.Fatalf("unexpected error message: %s. expect: %s", err.Error(), msg) + if !strings.Contains(err.Error(), msg) { + t.Fatalf("expected error message: %s to contain: %s", err.Error(), msg) } }) }) @@ -1109,8 +1109,8 @@ func TestDecoder_TypeConversionError(t *testing.T) { t.Fatal("expected to error") } msg := "cannot unmarshal string into Go value of type int" - if err.Error() != msg { - t.Fatalf("unexpected error message: %s. expect: %s", err.Error(), msg) + if !strings.Contains(err.Error(), msg) { + t.Fatalf("expected error message: %s to contain: %s", err.Error(), msg) } if len(v) == 0 || len(v["v"]) == 0 { t.Fatal("failed to decode value") @@ -1126,8 +1126,8 @@ func TestDecoder_TypeConversionError(t *testing.T) { t.Fatal("expected to error") } msg := "cannot unmarshal string into Go value of type int" - if err.Error() != msg { - t.Fatalf("unexpected error message: %s. expect: %s", err.Error(), msg) + if !strings.Contains(err.Error(), msg) { + t.Fatalf("expected error message: %s to contain: %s", err.Error(), msg) } if len(v) == 0 || len(v["v"]) == 0 { t.Fatal("failed to decode value") @@ -1145,8 +1145,8 @@ func TestDecoder_TypeConversionError(t *testing.T) { t.Fatal("expected to error") } msg := "cannot unmarshal -42 into Go value of type uint ( overflow )" - if err.Error() != msg { - t.Fatalf("unexpected error message: %s. expect: %s", err.Error(), msg) + if !strings.Contains(err.Error(), msg) { + t.Fatalf("expected error message: %s to contain: %s", err.Error(), msg) } if v["v"] != 0 { t.Fatal("failed to decode value") @@ -1159,8 +1159,8 @@ func TestDecoder_TypeConversionError(t *testing.T) { t.Fatal("expected to error") } msg := "cannot unmarshal -4294967296 into Go value of type uint64 ( overflow )" - if err.Error() != msg { - t.Fatalf("unexpected error message: %s. expect: %s", err.Error(), msg) + if !strings.Contains(err.Error(), msg) { + t.Fatalf("expected error message: %s to contain: %s", err.Error(), msg) } if v["v"] != 0 { t.Fatal("failed to decode value") @@ -1173,8 +1173,8 @@ func TestDecoder_TypeConversionError(t *testing.T) { t.Fatal("expected to error") } msg := "cannot unmarshal 4294967297 into Go value of type int32 ( overflow )" - if err.Error() != msg { - t.Fatalf("unexpected error message: %s. expect: %s", err.Error(), msg) + if !strings.Contains(err.Error(), msg) { + t.Fatalf("expected error message: %s to contain: %s", err.Error(), msg) } if v["v"] != 0 { t.Fatal("failed to decode value") @@ -1187,8 +1187,8 @@ func TestDecoder_TypeConversionError(t *testing.T) { t.Fatal("expected to error") } msg := "cannot unmarshal 128 into Go value of type int8 ( overflow )" - if err.Error() != msg { - t.Fatalf("unexpected error message: %s. expect: %s", err.Error(), msg) + if !strings.Contains(err.Error(), msg) { + t.Fatalf("expected error message: %s to contain: %s", err.Error(), msg) } if v["v"] != 0 { t.Fatal("failed to decode value") @@ -1207,8 +1207,8 @@ func TestDecoder_TypeConversionError(t *testing.T) { t.Fatal("expected to error") } msg := "cannot unmarshal uint64 into Go struct field T.A of type time.Time" - if err.Error() != msg { - t.Fatalf("unexpected error message: %s. expect: %s", err.Error(), msg) + if !strings.Contains(err.Error(), msg) { + t.Fatalf("expected error message: %s to contain: %s", err.Error(), msg) } }) t.Run("string to duration", func(t *testing.T) { @@ -1218,8 +1218,8 @@ func TestDecoder_TypeConversionError(t *testing.T) { t.Fatal("expected to error") } msg := `time: invalid duration "str"` - if err.Error() != msg { - t.Fatalf("unexpected error message: %s. expect: %s", err.Error(), msg) + if !strings.Contains(err.Error(), msg) { + t.Fatalf("expected error message: %s to contain: %s", err.Error(), msg) } }) t.Run("int to duration", func(t *testing.T) { @@ -1229,8 +1229,8 @@ func TestDecoder_TypeConversionError(t *testing.T) { t.Fatal("expected to error") } msg := "cannot unmarshal uint64 into Go struct field T.B of type time.Duration" - if err.Error() != msg { - t.Fatalf("unexpected error message: %s. expect: %s", err.Error(), msg) + if !strings.Contains(err.Error(), msg) { + t.Fatalf("expected error message: %s to contain: %s", err.Error(), msg) } }) }) diff --git a/internal/errors/error.go b/internal/errors/error.go index c7b1103..6b22460 100644 --- a/internal/errors/error.go +++ b/internal/errors/error.go @@ -67,10 +67,10 @@ type wrapError struct { frame xerrors.Frame } -type myprinter struct { +type FormatErrorPrinter struct { xerrors.Printer - colored bool - inclSource bool + Colored bool + InclSource bool } func (e *wrapError) As(target interface{}) bool { @@ -90,15 +90,15 @@ func (e *wrapError) Unwrap() error { } func (e *wrapError) PrettyPrint(p xerrors.Printer, colored, inclSource bool) error { - return e.FormatError(&myprinter{Printer: p, colored: colored, inclSource: inclSource}) + return e.FormatError(&FormatErrorPrinter{Printer: p, Colored: colored, InclSource: inclSource}) } func (e *wrapError) FormatError(p xerrors.Printer) error { - if _, ok := p.(*myprinter); !ok { - p = &myprinter{ + if _, ok := p.(*FormatErrorPrinter); !ok { + p = &FormatErrorPrinter{ Printer: p, - colored: defaultColorize, - inclSource: defaultIncludeSource, + Colored: defaultColorize, + InclSource: defaultIncludeSource, } } if e.verb == 'v' && e.state.Flag('+') { @@ -171,16 +171,16 @@ type syntaxError struct { } func (e *syntaxError) PrettyPrint(p xerrors.Printer, colored, inclSource bool) error { - return e.FormatError(&myprinter{Printer: p, colored: colored, inclSource: inclSource}) + return e.FormatError(&FormatErrorPrinter{Printer: p, Colored: colored, InclSource: inclSource}) } func (e *syntaxError) FormatError(p xerrors.Printer) error { var pp printer.Printer var colored, inclSource bool - if mp, ok := p.(*myprinter); ok { - colored = mp.colored - inclSource = mp.inclSource + if fep, ok := p.(*FormatErrorPrinter); ok { + colored = fep.Colored + inclSource = fep.InclSource } pos := fmt.Sprintf("[%d:%d] ", e.token.Position.Line, e.token.Position.Column) From bf2da56730636857ee05e890c598140a9e27eb87 Mon Sep 17 00:00:00 2001 From: Braydon Kains Date: Mon, 2 May 2022 18:00:24 +0000 Subject: [PATCH 2/2] internal/errors: moved typeError to errors package The typeError previously located in decode.go now implements PrettyPrinter, so it's been moved to internal/errors to be colocated with the other PrettyPrinter errors. --- decode.go | 56 +++++++--------------------------------- internal/errors/error.go | 38 +++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 47 deletions(-) diff --git a/decode.go b/decode.go index 7ad5661..d519b78 100644 --- a/decode.go +++ b/decode.go @@ -18,7 +18,6 @@ import ( "github.com/goccy/go-yaml/ast" "github.com/goccy/go-yaml/internal/errors" "github.com/goccy/go-yaml/parser" - "github.com/goccy/go-yaml/printer" "github.com/goccy/go-yaml/token" "golang.org/x/xerrors" ) @@ -405,45 +404,8 @@ func errOverflow(dstType reflect.Type, num string) *overflowError { return &overflowError{dstType: dstType, srcNum: num} } -type typeError struct { - dstType reflect.Type - srcType reflect.Type - structFieldName *string - token *token.Token -} - -func (e *typeError) Error() string { - if e.structFieldName != nil { - return fmt.Sprintf("cannot unmarshal %s into Go struct field %s of type %s", e.srcType, *e.structFieldName, e.dstType) - } - return fmt.Sprintf("cannot unmarshal %s into Go value of type %s", e.srcType, e.dstType) -} - -func (e *typeError) PrettyPrint(p xerrors.Printer, colored, inclSource bool) error { - return e.FormatError(&errors.FormatErrorPrinter{Printer: p, Colored: colored, InclSource: inclSource}) -} - -func (e *typeError) FormatError(p xerrors.Printer) error { - var pp printer.Printer - - var colored, inclSource bool - if fep, ok := p.(*errors.FormatErrorPrinter); ok { - colored = fep.Colored - inclSource = fep.InclSource - } - - pos := fmt.Sprintf("[%d:%d] ", e.token.Position.Line, e.token.Position.Column) - msg := pp.PrintErrorMessage(fmt.Sprintf("%s%s", pos, e.Error()), colored) - if inclSource { - msg += "\n" + pp.PrintErrorToken(e.token, colored) - } - p.Print(msg) - - return nil -} - -func errTypeMismatch(dstType, srcType reflect.Type, token *token.Token) *typeError { - return &typeError{dstType: dstType, srcType: srcType, token: token} +func errTypeMismatch(dstType, srcType reflect.Type, token *token.Token) *errors.TypeError { + return &errors.TypeError{DstType: dstType, SrcType: srcType, Token: token} } type unknownFieldError struct { @@ -1077,14 +1039,14 @@ func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.N if foundErr != nil { continue } - var te *typeError + var te *errors.TypeError if xerrors.As(err, &te) { - if te.structFieldName != nil { - fieldName := fmt.Sprintf("%s.%s", structType.Name(), *te.structFieldName) - te.structFieldName = &fieldName + if te.StructFieldName != nil { + fieldName := fmt.Sprintf("%s.%s", structType.Name(), *te.StructFieldName) + te.StructFieldName = &fieldName } else { fieldName := fmt.Sprintf("%s.%s", structType.Name(), field.Name) - te.structFieldName = &fieldName + te.StructFieldName = &fieldName } foundErr = te continue @@ -1113,10 +1075,10 @@ func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.N if foundErr != nil { continue } - var te *typeError + var te *errors.TypeError if xerrors.As(err, &te) { fieldName := fmt.Sprintf("%s.%s", structType.Name(), field.Name) - te.structFieldName = &fieldName + te.StructFieldName = &fieldName foundErr = te } else { foundErr = err diff --git a/internal/errors/error.go b/internal/errors/error.go index 6b22460..7f1ea9a 100644 --- a/internal/errors/error.go +++ b/internal/errors/error.go @@ -3,6 +3,7 @@ package errors import ( "bytes" "fmt" + "reflect" "github.com/goccy/go-yaml/printer" "github.com/goccy/go-yaml/token" @@ -220,3 +221,40 @@ func (e *syntaxError) Error() string { e.PrettyPrint(&Sink{&buf}, defaultColorize, defaultIncludeSource) return buf.String() } + +type TypeError struct { + DstType reflect.Type + SrcType reflect.Type + StructFieldName *string + Token *token.Token +} + +func (e *TypeError) Error() string { + if e.StructFieldName != nil { + return fmt.Sprintf("cannot unmarshal %s into Go struct field %s of type %s", e.SrcType, *e.StructFieldName, e.DstType) + } + return fmt.Sprintf("cannot unmarshal %s into Go value of type %s", e.SrcType, e.DstType) +} + +func (e *TypeError) PrettyPrint(p xerrors.Printer, colored, inclSource bool) error { + return e.FormatError(&FormatErrorPrinter{Printer: p, Colored: colored, InclSource: inclSource}) +} + +func (e *TypeError) FormatError(p xerrors.Printer) error { + var pp printer.Printer + + var colored, inclSource bool + if fep, ok := p.(*FormatErrorPrinter); ok { + colored = fep.Colored + inclSource = fep.InclSource + } + + pos := fmt.Sprintf("[%d:%d] ", e.Token.Position.Line, e.Token.Position.Column) + msg := pp.PrintErrorMessage(fmt.Sprintf("%s%s", pos, e.Error()), colored) + if inclSource { + msg += "\n" + pp.PrintErrorToken(e.Token, colored) + } + p.Print(msg) + + return nil +}