Skip to content

Commit

Permalink
Merge pull request #317 from nervo/handle-unmarshall-unknown-aliases
Browse files Browse the repository at this point in the history
Handle unmarshall unknown aliases
  • Loading branch information
goccy committed Oct 26, 2022
2 parents 48a606c + 941abdb commit 8607d4f
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 53 deletions.
104 changes: 79 additions & 25 deletions decode.go
Expand Up @@ -265,36 +265,67 @@ func (d *Decoder) nodeToValue(node ast.Node) interface{} {
return nil
}

func (d *Decoder) resolveAlias(node ast.Node) ast.Node {
func (d *Decoder) resolveAlias(node ast.Node) (ast.Node, error) {
switch n := node.(type) {
case *ast.MappingNode:
for idx, value := range n.Values {
n.Values[idx] = d.resolveAlias(value).(*ast.MappingValueNode)
for idx, v := range n.Values {
value, err := d.resolveAlias(v)
if err != nil {
return nil, err
}
n.Values[idx] = value.(*ast.MappingValueNode)
}
case *ast.TagNode:
n.Value = d.resolveAlias(n.Value)
value, err := d.resolveAlias(n.Value)
if err != nil {
return nil, err
}
n.Value = value
case *ast.MappingKeyNode:
n.Value = d.resolveAlias(n.Value)
value, err := d.resolveAlias(n.Value)
if err != nil {
return nil, err
}
n.Value = value
case *ast.MappingValueNode:
if n.Key.Type() == ast.MergeKeyType && n.Value.Type() == ast.AliasType {
value := d.resolveAlias(n.Value)
value, err := d.resolveAlias(n.Value)
if err != nil {
return nil, err
}
keyColumn := n.Key.GetToken().Position.Column
requiredColumn := keyColumn + 2
value.AddColumn(requiredColumn)
n.Value = value
} else {
n.Key = d.resolveAlias(n.Key).(ast.MapKeyNode)
n.Value = d.resolveAlias(n.Value)
key, err := d.resolveAlias(n.Key)
if err != nil {
return nil, err
}
n.Key = key.(ast.MapKeyNode)
value, err := d.resolveAlias(n.Value)
if err != nil {
return nil, err
}
n.Value = value
}
case *ast.SequenceNode:
for idx, value := range n.Values {
n.Values[idx] = d.resolveAlias(value)
for idx, v := range n.Values {
value, err := d.resolveAlias(v)
if err != nil {
return nil, err
}
n.Values[idx] = value
}
case *ast.AliasNode:
aliasName := n.Value.GetToken().Value
return d.resolveAlias(d.anchorNodeMap[aliasName])
node := d.anchorNodeMap[aliasName]
if node == nil {
return nil, xerrors.Errorf("cannot find anchor by alias name %s", aliasName)
}
return d.resolveAlias(node)
}
return node
return node, nil
}

func (d *Decoder) getMapNode(node ast.Node) (ast.MapNode, error) {
Expand Down Expand Up @@ -481,33 +512,41 @@ func (d *Decoder) lastNode(node ast.Node) ast.Node {
return node
}

func (d *Decoder) unmarshalableDocument(node ast.Node) []byte {
node = d.resolveAlias(node)
func (d *Decoder) unmarshalableDocument(node ast.Node) ([]byte, error) {
var err error
node, err = d.resolveAlias(node)
if err != nil {
return nil, err
}
doc := node.String()
last := d.lastNode(node)
if last != nil && last.Type() == ast.LiteralType {
doc += "\n"
}
return []byte(doc)
return []byte(doc), nil
}

func (d *Decoder) unmarshalableText(node ast.Node) ([]byte, bool) {
node = d.resolveAlias(node)
func (d *Decoder) unmarshalableText(node ast.Node) ([]byte, bool, error) {
var err error
node, err = d.resolveAlias(node)
if err != nil {
return nil, false, err
}
if node.Type() == ast.AnchorType {
node = node.(*ast.AnchorNode).Value
}
switch n := node.(type) {
case *ast.StringNode:
return []byte(n.Value), true
return []byte(n.Value), true, nil
case *ast.LiteralNode:
return []byte(n.Value.GetToken().Value), true
return []byte(n.Value.GetToken().Value), true, nil
default:
scalar, ok := n.(ast.ScalarNode)
if ok {
return []byte(fmt.Sprint(scalar.GetValue())), true
return []byte(fmt.Sprint(scalar.GetValue())), true, nil
}
}
return nil, false
return nil, false, nil
}

type jsonUnmarshaler interface {
Expand Down Expand Up @@ -541,14 +580,22 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr
iface := dst.Addr().Interface()

if unmarshaler, ok := iface.(BytesUnmarshalerContext); ok {
if err := unmarshaler.UnmarshalYAML(ctx, d.unmarshalableDocument(src)); err != nil {
b, err := d.unmarshalableDocument(src)
if err != nil {
return errors.Wrapf(err, "failed to UnmarshalYAML")
}
if err := unmarshaler.UnmarshalYAML(ctx, b); err != nil {
return errors.Wrapf(err, "failed to UnmarshalYAML")
}
return nil
}

if unmarshaler, ok := iface.(BytesUnmarshaler); ok {
if err := unmarshaler.UnmarshalYAML(d.unmarshalableDocument(src)); err != nil {
b, err := d.unmarshalableDocument(src)
if err != nil {
return errors.Wrapf(err, "failed to UnmarshalYAML")
}
if err := unmarshaler.UnmarshalYAML(b); err != nil {
return errors.Wrapf(err, "failed to UnmarshalYAML")
}
return nil
Expand Down Expand Up @@ -595,7 +642,10 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr
}

if unmarshaler, isText := iface.(encoding.TextUnmarshaler); isText {
b, ok := d.unmarshalableText(src)
b, ok, err := d.unmarshalableText(src)
if err != nil {
return errors.Wrapf(err, "failed to UnmarshalText")
}
if ok {
if err := unmarshaler.UnmarshalText(b); err != nil {
return errors.Wrapf(err, "failed to UnmarshalText")
Expand All @@ -606,7 +656,11 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr

if d.useJSONUnmarshaler {
if unmarshaler, ok := iface.(jsonUnmarshaler); ok {
jsonBytes, err := YAMLToJSON(d.unmarshalableDocument(src))
b, err := d.unmarshalableDocument(src)
if err != nil {
return errors.Wrapf(err, "failed to UnmarshalJSON")
}
jsonBytes, err := YAMLToJSON(b)
if err != nil {
return errors.Wrapf(err, "failed to convert yaml to json")
}
Expand Down
96 changes: 68 additions & 28 deletions decode_test.go
Expand Up @@ -2368,39 +2368,79 @@ func (v *unmarshalYAMLWithAliasMap) UnmarshalYAML(b []byte) error {
}

func TestDecoder_UnmarshalYAMLWithAlias(t *testing.T) {
yml := `
type value struct {
String unmarshalYAMLWithAliasString
Map unmarshalYAMLWithAliasMap
}
tests := []struct {
name string
yaml string
expectedValue value
err string
}{
{
name: "ok",
yaml: `
anchors:
x: &x "\"hello\" \"world\""
map: &y
w: &w "\"hello\" \"world\""
map: &x
a: b
c: d
d: *x
a: *x
b:
<<: *y
d: *w
string: *w
map:
<<: *x
e: f
`
var v struct {
A unmarshalYAMLWithAliasString
B unmarshalYAMLWithAliasMap
}
if err := yaml.Unmarshal([]byte(yml), &v); err != nil {
t.Fatalf("%+v", err)
}
if v.A != `"hello" "world"` {
t.Fatal("failed to unmarshal with alias")
}
if len(v.B) != 4 {
t.Fatal("failed to unmarshal with alias")
}
if v.B["a"] != "b" {
t.Fatal("failed to unmarshal with alias")
}
if v.B["c"] != "d" {
t.Fatal("failed to unmarshal with alias")
`,
expectedValue: value{
String: unmarshalYAMLWithAliasString(`"hello" "world"`),
Map: unmarshalYAMLWithAliasMap(map[string]interface{}{
"a": "b",
"c": "d",
"d": `"hello" "world"`,
"e": "f",
}),
},
},
{
name: "unknown alias",
yaml: `
anchors:
w: &w "\"hello\" \"world\""
map: &x
a: b
c: d
d: *w
string: *y
map:
<<: *z
e: f
`,
err: "cannot find anchor by alias name y",
},
}
if v.B["d"] != `"hello" "world"` {
t.Fatal("failed to unmarshal with alias")

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var v value
err := yaml.Unmarshal([]byte(test.yaml), &v)

if test.err != "" {
if err == nil {
t.Fatal("expected to error")
}
if !strings.Contains(err.Error(), test.err) {
t.Fatalf("expected error message: %s to contain: %s", err.Error(), test.err)
}
} else {
if err != nil {
t.Fatalf("%+v", err)
}
if !reflect.DeepEqual(test.expectedValue, v) {
t.Fatalf("non matching values:\nexpected[%s]\ngot [%s]", test.expectedValue, v)
}
}
})
}
}

Expand Down

0 comments on commit 8607d4f

Please sign in to comment.