From 941abdb723fb46c4a9cb02afec4e64d240edc51d Mon Sep 17 00:00:00 2001 From: Florian Rey Date: Fri, 16 Sep 2022 12:52:14 +0200 Subject: [PATCH] Handle unmarshall unknown aliases --- decode.go | 104 +++++++++++++++++++++++++++++++++++++------------ decode_test.go | 96 ++++++++++++++++++++++++++++++++------------- 2 files changed, 147 insertions(+), 53 deletions(-) diff --git a/decode.go b/decode.go index 2682b2a..8cd7d85 100644 --- a/decode.go +++ b/decode.go @@ -265,36 +265,67 @@ func (d *Decoder) nodeToValue(node ast.Node) interface{} { return nil } -func (d *Decoder) resolveAlias(node ast.Node) ast.Node { +func (d *Decoder) resolveAlias(node ast.Node) (ast.Node, error) { switch n := node.(type) { case *ast.MappingNode: - for idx, value := range n.Values { - n.Values[idx] = d.resolveAlias(value).(*ast.MappingValueNode) + for idx, v := range n.Values { + value, err := d.resolveAlias(v) + if err != nil { + return nil, err + } + n.Values[idx] = value.(*ast.MappingValueNode) } case *ast.TagNode: - n.Value = d.resolveAlias(n.Value) + value, err := d.resolveAlias(n.Value) + if err != nil { + return nil, err + } + n.Value = value case *ast.MappingKeyNode: - n.Value = d.resolveAlias(n.Value) + value, err := d.resolveAlias(n.Value) + if err != nil { + return nil, err + } + n.Value = value case *ast.MappingValueNode: if n.Key.Type() == ast.MergeKeyType && n.Value.Type() == ast.AliasType { - value := d.resolveAlias(n.Value) + value, err := d.resolveAlias(n.Value) + if err != nil { + return nil, err + } keyColumn := n.Key.GetToken().Position.Column requiredColumn := keyColumn + 2 value.AddColumn(requiredColumn) n.Value = value } else { - n.Key = d.resolveAlias(n.Key).(ast.MapKeyNode) - n.Value = d.resolveAlias(n.Value) + key, err := d.resolveAlias(n.Key) + if err != nil { + return nil, err + } + n.Key = key.(ast.MapKeyNode) + value, err := d.resolveAlias(n.Value) + if err != nil { + return nil, err + } + n.Value = value } case *ast.SequenceNode: - for idx, value := range n.Values { - n.Values[idx] = d.resolveAlias(value) + for idx, v := range n.Values { + value, err := d.resolveAlias(v) + if err != nil { + return nil, err + } + n.Values[idx] = value } case *ast.AliasNode: aliasName := n.Value.GetToken().Value - return d.resolveAlias(d.anchorNodeMap[aliasName]) + node := d.anchorNodeMap[aliasName] + if node == nil { + return nil, xerrors.Errorf("cannot find anchor by alias name %s", aliasName) + } + return d.resolveAlias(node) } - return node + return node, nil } func (d *Decoder) getMapNode(node ast.Node) (ast.MapNode, error) { @@ -481,33 +512,41 @@ func (d *Decoder) lastNode(node ast.Node) ast.Node { return node } -func (d *Decoder) unmarshalableDocument(node ast.Node) []byte { - node = d.resolveAlias(node) +func (d *Decoder) unmarshalableDocument(node ast.Node) ([]byte, error) { + var err error + node, err = d.resolveAlias(node) + if err != nil { + return nil, err + } doc := node.String() last := d.lastNode(node) if last != nil && last.Type() == ast.LiteralType { doc += "\n" } - return []byte(doc) + return []byte(doc), nil } -func (d *Decoder) unmarshalableText(node ast.Node) ([]byte, bool) { - node = d.resolveAlias(node) +func (d *Decoder) unmarshalableText(node ast.Node) ([]byte, bool, error) { + var err error + node, err = d.resolveAlias(node) + if err != nil { + return nil, false, err + } if node.Type() == ast.AnchorType { node = node.(*ast.AnchorNode).Value } switch n := node.(type) { case *ast.StringNode: - return []byte(n.Value), true + return []byte(n.Value), true, nil case *ast.LiteralNode: - return []byte(n.Value.GetToken().Value), true + return []byte(n.Value.GetToken().Value), true, nil default: scalar, ok := n.(ast.ScalarNode) if ok { - return []byte(fmt.Sprint(scalar.GetValue())), true + return []byte(fmt.Sprint(scalar.GetValue())), true, nil } } - return nil, false + return nil, false, nil } type jsonUnmarshaler interface { @@ -541,14 +580,22 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr iface := dst.Addr().Interface() if unmarshaler, ok := iface.(BytesUnmarshalerContext); ok { - if err := unmarshaler.UnmarshalYAML(ctx, d.unmarshalableDocument(src)); err != nil { + b, err := d.unmarshalableDocument(src) + if err != nil { + return errors.Wrapf(err, "failed to UnmarshalYAML") + } + if err := unmarshaler.UnmarshalYAML(ctx, b); err != nil { return errors.Wrapf(err, "failed to UnmarshalYAML") } return nil } if unmarshaler, ok := iface.(BytesUnmarshaler); ok { - if err := unmarshaler.UnmarshalYAML(d.unmarshalableDocument(src)); err != nil { + b, err := d.unmarshalableDocument(src) + if err != nil { + return errors.Wrapf(err, "failed to UnmarshalYAML") + } + if err := unmarshaler.UnmarshalYAML(b); err != nil { return errors.Wrapf(err, "failed to UnmarshalYAML") } return nil @@ -595,7 +642,10 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr } if unmarshaler, isText := iface.(encoding.TextUnmarshaler); isText { - b, ok := d.unmarshalableText(src) + b, ok, err := d.unmarshalableText(src) + if err != nil { + return errors.Wrapf(err, "failed to UnmarshalText") + } if ok { if err := unmarshaler.UnmarshalText(b); err != nil { return errors.Wrapf(err, "failed to UnmarshalText") @@ -606,7 +656,11 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr if d.useJSONUnmarshaler { if unmarshaler, ok := iface.(jsonUnmarshaler); ok { - jsonBytes, err := YAMLToJSON(d.unmarshalableDocument(src)) + b, err := d.unmarshalableDocument(src) + if err != nil { + return errors.Wrapf(err, "failed to UnmarshalJSON") + } + jsonBytes, err := YAMLToJSON(b) if err != nil { return errors.Wrapf(err, "failed to convert yaml to json") } diff --git a/decode_test.go b/decode_test.go index b671eb0..af6e454 100644 --- a/decode_test.go +++ b/decode_test.go @@ -2368,39 +2368,79 @@ func (v *unmarshalYAMLWithAliasMap) UnmarshalYAML(b []byte) error { } func TestDecoder_UnmarshalYAMLWithAlias(t *testing.T) { - yml := ` + type value struct { + String unmarshalYAMLWithAliasString + Map unmarshalYAMLWithAliasMap + } + tests := []struct { + name string + yaml string + expectedValue value + err string + }{ + { + name: "ok", + yaml: ` anchors: - x: &x "\"hello\" \"world\"" - map: &y + w: &w "\"hello\" \"world\"" + map: &x a: b c: d - d: *x -a: *x -b: - <<: *y + d: *w +string: *w +map: + <<: *x e: f -` - var v struct { - A unmarshalYAMLWithAliasString - B unmarshalYAMLWithAliasMap - } - if err := yaml.Unmarshal([]byte(yml), &v); err != nil { - t.Fatalf("%+v", err) - } - if v.A != `"hello" "world"` { - t.Fatal("failed to unmarshal with alias") - } - if len(v.B) != 4 { - t.Fatal("failed to unmarshal with alias") - } - if v.B["a"] != "b" { - t.Fatal("failed to unmarshal with alias") - } - if v.B["c"] != "d" { - t.Fatal("failed to unmarshal with alias") +`, + expectedValue: value{ + String: unmarshalYAMLWithAliasString(`"hello" "world"`), + Map: unmarshalYAMLWithAliasMap(map[string]interface{}{ + "a": "b", + "c": "d", + "d": `"hello" "world"`, + "e": "f", + }), + }, + }, + { + name: "unknown alias", + yaml: ` +anchors: + w: &w "\"hello\" \"world\"" + map: &x + a: b + c: d + d: *w +string: *y +map: + <<: *z + e: f +`, + err: "cannot find anchor by alias name y", + }, } - if v.B["d"] != `"hello" "world"` { - t.Fatal("failed to unmarshal with alias") + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var v value + err := yaml.Unmarshal([]byte(test.yaml), &v) + + if test.err != "" { + if err == nil { + t.Fatal("expected to error") + } + if !strings.Contains(err.Error(), test.err) { + t.Fatalf("expected error message: %s to contain: %s", err.Error(), test.err) + } + } else { + if err != nil { + t.Fatalf("%+v", err) + } + if !reflect.DeepEqual(test.expectedValue, v) { + t.Fatalf("non matching values:\nexpected[%s]\ngot [%s]", test.expectedValue, v) + } + } + }) } }