Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle unmarshall unknown aliases #317

Merged
merged 1 commit into from Oct 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
104 changes: 79 additions & 25 deletions decode.go
Expand Up @@ -265,36 +265,67 @@ func (d *Decoder) nodeToValue(node ast.Node) interface{} {
return nil
}

func (d *Decoder) resolveAlias(node ast.Node) ast.Node {
func (d *Decoder) resolveAlias(node ast.Node) (ast.Node, error) {
switch n := node.(type) {
case *ast.MappingNode:
for idx, value := range n.Values {
n.Values[idx] = d.resolveAlias(value).(*ast.MappingValueNode)
for idx, v := range n.Values {
value, err := d.resolveAlias(v)
if err != nil {
return nil, err
}
n.Values[idx] = value.(*ast.MappingValueNode)
}
case *ast.TagNode:
n.Value = d.resolveAlias(n.Value)
value, err := d.resolveAlias(n.Value)
if err != nil {
return nil, err
}
n.Value = value
case *ast.MappingKeyNode:
n.Value = d.resolveAlias(n.Value)
value, err := d.resolveAlias(n.Value)
if err != nil {
return nil, err
}
n.Value = value
case *ast.MappingValueNode:
if n.Key.Type() == ast.MergeKeyType && n.Value.Type() == ast.AliasType {
value := d.resolveAlias(n.Value)
value, err := d.resolveAlias(n.Value)
if err != nil {
return nil, err
}
keyColumn := n.Key.GetToken().Position.Column
requiredColumn := keyColumn + 2
value.AddColumn(requiredColumn)
n.Value = value
} else {
n.Key = d.resolveAlias(n.Key).(ast.MapKeyNode)
n.Value = d.resolveAlias(n.Value)
key, err := d.resolveAlias(n.Key)
if err != nil {
return nil, err
}
n.Key = key.(ast.MapKeyNode)
value, err := d.resolveAlias(n.Value)
if err != nil {
return nil, err
}
n.Value = value
}
case *ast.SequenceNode:
for idx, value := range n.Values {
n.Values[idx] = d.resolveAlias(value)
for idx, v := range n.Values {
value, err := d.resolveAlias(v)
if err != nil {
return nil, err
}
n.Values[idx] = value
}
case *ast.AliasNode:
aliasName := n.Value.GetToken().Value
return d.resolveAlias(d.anchorNodeMap[aliasName])
node := d.anchorNodeMap[aliasName]
if node == nil {
return nil, xerrors.Errorf("cannot find anchor by alias name %s", aliasName)
}
return d.resolveAlias(node)
}
return node
return node, nil
}

func (d *Decoder) getMapNode(node ast.Node) (ast.MapNode, error) {
Expand Down Expand Up @@ -481,33 +512,41 @@ func (d *Decoder) lastNode(node ast.Node) ast.Node {
return node
}

func (d *Decoder) unmarshalableDocument(node ast.Node) []byte {
node = d.resolveAlias(node)
func (d *Decoder) unmarshalableDocument(node ast.Node) ([]byte, error) {
var err error
node, err = d.resolveAlias(node)
if err != nil {
return nil, err
}
doc := node.String()
last := d.lastNode(node)
if last != nil && last.Type() == ast.LiteralType {
doc += "\n"
}
return []byte(doc)
return []byte(doc), nil
}

func (d *Decoder) unmarshalableText(node ast.Node) ([]byte, bool) {
node = d.resolveAlias(node)
func (d *Decoder) unmarshalableText(node ast.Node) ([]byte, bool, error) {
var err error
node, err = d.resolveAlias(node)
if err != nil {
return nil, false, err
}
if node.Type() == ast.AnchorType {
node = node.(*ast.AnchorNode).Value
}
switch n := node.(type) {
case *ast.StringNode:
return []byte(n.Value), true
return []byte(n.Value), true, nil
case *ast.LiteralNode:
return []byte(n.Value.GetToken().Value), true
return []byte(n.Value.GetToken().Value), true, nil
default:
scalar, ok := n.(ast.ScalarNode)
if ok {
return []byte(fmt.Sprint(scalar.GetValue())), true
return []byte(fmt.Sprint(scalar.GetValue())), true, nil
}
}
return nil, false
return nil, false, nil
}

type jsonUnmarshaler interface {
Expand Down Expand Up @@ -541,14 +580,22 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr
iface := dst.Addr().Interface()

if unmarshaler, ok := iface.(BytesUnmarshalerContext); ok {
if err := unmarshaler.UnmarshalYAML(ctx, d.unmarshalableDocument(src)); err != nil {
b, err := d.unmarshalableDocument(src)
if err != nil {
return errors.Wrapf(err, "failed to UnmarshalYAML")
}
if err := unmarshaler.UnmarshalYAML(ctx, b); err != nil {
return errors.Wrapf(err, "failed to UnmarshalYAML")
}
return nil
}

if unmarshaler, ok := iface.(BytesUnmarshaler); ok {
if err := unmarshaler.UnmarshalYAML(d.unmarshalableDocument(src)); err != nil {
b, err := d.unmarshalableDocument(src)
if err != nil {
return errors.Wrapf(err, "failed to UnmarshalYAML")
}
if err := unmarshaler.UnmarshalYAML(b); err != nil {
return errors.Wrapf(err, "failed to UnmarshalYAML")
}
return nil
Expand Down Expand Up @@ -595,7 +642,10 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr
}

if unmarshaler, isText := iface.(encoding.TextUnmarshaler); isText {
b, ok := d.unmarshalableText(src)
b, ok, err := d.unmarshalableText(src)
if err != nil {
return errors.Wrapf(err, "failed to UnmarshalText")
}
if ok {
if err := unmarshaler.UnmarshalText(b); err != nil {
return errors.Wrapf(err, "failed to UnmarshalText")
Expand All @@ -606,7 +656,11 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr

if d.useJSONUnmarshaler {
if unmarshaler, ok := iface.(jsonUnmarshaler); ok {
jsonBytes, err := YAMLToJSON(d.unmarshalableDocument(src))
b, err := d.unmarshalableDocument(src)
if err != nil {
return errors.Wrapf(err, "failed to UnmarshalJSON")
}
jsonBytes, err := YAMLToJSON(b)
if err != nil {
return errors.Wrapf(err, "failed to convert yaml to json")
}
Expand Down
96 changes: 68 additions & 28 deletions decode_test.go
Expand Up @@ -2368,39 +2368,79 @@ func (v *unmarshalYAMLWithAliasMap) UnmarshalYAML(b []byte) error {
}

func TestDecoder_UnmarshalYAMLWithAlias(t *testing.T) {
yml := `
type value struct {
String unmarshalYAMLWithAliasString
Map unmarshalYAMLWithAliasMap
}
tests := []struct {
name string
yaml string
expectedValue value
err string
}{
{
name: "ok",
yaml: `
anchors:
x: &x "\"hello\" \"world\""
map: &y
w: &w "\"hello\" \"world\""
map: &x
a: b
c: d
d: *x
a: *x
b:
<<: *y
d: *w
string: *w
map:
<<: *x
e: f
`
var v struct {
A unmarshalYAMLWithAliasString
B unmarshalYAMLWithAliasMap
}
if err := yaml.Unmarshal([]byte(yml), &v); err != nil {
t.Fatalf("%+v", err)
}
if v.A != `"hello" "world"` {
t.Fatal("failed to unmarshal with alias")
}
if len(v.B) != 4 {
t.Fatal("failed to unmarshal with alias")
}
if v.B["a"] != "b" {
t.Fatal("failed to unmarshal with alias")
}
if v.B["c"] != "d" {
t.Fatal("failed to unmarshal with alias")
`,
expectedValue: value{
String: unmarshalYAMLWithAliasString(`"hello" "world"`),
Map: unmarshalYAMLWithAliasMap(map[string]interface{}{
"a": "b",
"c": "d",
"d": `"hello" "world"`,
"e": "f",
}),
},
},
{
name: "unknown alias",
yaml: `
anchors:
w: &w "\"hello\" \"world\""
map: &x
a: b
c: d
d: *w
string: *y
map:
<<: *z
e: f
`,
err: "cannot find anchor by alias name y",
},
}
if v.B["d"] != `"hello" "world"` {
t.Fatal("failed to unmarshal with alias")

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var v value
err := yaml.Unmarshal([]byte(test.yaml), &v)

if test.err != "" {
if err == nil {
t.Fatal("expected to error")
}
if !strings.Contains(err.Error(), test.err) {
t.Fatalf("expected error message: %s to contain: %s", err.Error(), test.err)
}
} else {
if err != nil {
t.Fatalf("%+v", err)
}
if !reflect.DeepEqual(test.expectedValue, v) {
t.Fatalf("non matching values:\nexpected[%s]\ngot [%s]", test.expectedValue, v)
}
}
})
}
}

Expand Down