Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce MapKeyNode interface to limit node types for map key #312

Merged
merged 1 commit into from Aug 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
92 changes: 15 additions & 77 deletions ast/ast.go
Expand Up @@ -191,13 +191,18 @@ type Node interface {
addReadLen(int)
// clean read length
clearLen()
}

// MapKeyNode type for map key node
type MapKeyNode interface {
Node
// String node to text without comment
stringWithoutComment() string
}

// ScalarNode type for scalar node
type ScalarNode interface {
Node
MapKeyNode
GetValue() interface{}
}

Expand Down Expand Up @@ -274,15 +279,15 @@ func readNode(p []byte, node Node) (int, error) {
}

// Null create node for null value
func Null(tk *token.Token) Node {
func Null(tk *token.Token) *NullNode {
return &NullNode{
BaseNode: &BaseNode{},
Token: tk,
}
}

// Bool create node for boolean value
func Bool(tk *token.Token) Node {
func Bool(tk *token.Token) *BoolNode {
b, _ := strconv.ParseBool(tk.Value)
return &BoolNode{
BaseNode: &BaseNode{},
Expand All @@ -292,7 +297,7 @@ func Bool(tk *token.Token) Node {
}

// Integer create node for integer value
func Integer(tk *token.Token) Node {
func Integer(tk *token.Token) *IntegerNode {
value := removeUnderScoreFromNumber(tk.Value)
switch tk.Type {
case token.BinaryIntegerType:
Expand Down Expand Up @@ -386,7 +391,7 @@ func Integer(tk *token.Token) Node {
}

// Float create node for float value
func Float(tk *token.Token) Node {
func Float(tk *token.Token) *FloatNode {
f, _ := strconv.ParseFloat(removeUnderScoreFromNumber(tk.Value), 64)
return &FloatNode{
BaseNode: &BaseNode{},
Expand Down Expand Up @@ -467,7 +472,7 @@ func Mapping(tk *token.Token, isFlowStyle bool, values ...*MappingValueNode) *Ma
}

// MappingValue create node for mapping value
func MappingValue(tk *token.Token, key Node, value Node) *MappingValueNode {
func MappingValue(tk *token.Token, key MapKeyNode, value Node) *MappingValueNode {
return &MappingValueNode{
BaseNode: &BaseNode{},
Start: tk,
Expand Down Expand Up @@ -564,10 +569,6 @@ func (f *File) String() string {
return strings.Join(docs, "\n")
}

func (f *File) stringWithoutComment() string {
return f.String()
}

// DocumentNode type of Document
type DocumentNode struct {
*BaseNode
Expand Down Expand Up @@ -609,10 +610,6 @@ func (d *DocumentNode) String() string {
return strings.Join(doc, "\n")
}

func (d *DocumentNode) stringWithoutComment() string {
return d.String()
}

// MarshalYAML encodes to a YAML text
func (d *DocumentNode) MarshalYAML() ([]byte, error) {
return []byte(d.String()), nil
Expand Down Expand Up @@ -956,7 +953,7 @@ func (n *MergeKeyNode) GetValue() interface{} {

// String returns '<<' value
func (n *MergeKeyNode) String() string {
return n.Token.Value
return n.stringWithoutComment()
}

func (n *MergeKeyNode) stringWithoutComment() string {
Expand Down Expand Up @@ -1137,7 +1134,7 @@ func (m *MapNodeIter) Next() bool {
}

// Key returns the key of the iterator's current map node entry.
func (m *MapNodeIter) Key() Node {
func (m *MapNodeIter) Key() MapKeyNode {
return m.values[m.idx].Key
}

Expand Down Expand Up @@ -1260,14 +1257,6 @@ func (n *MappingNode) String() string {
return n.blockStyleString(commentMode)
}

func (n *MappingNode) stringWithoutComment() string {
commentMode := false
if n.IsFlowStyle || len(n.Values) == 0 {
return n.flowStyleString(commentMode)
}
return n.blockStyleString(commentMode)
}

// MapRange implements MapNode protocol
func (n *MappingNode) MapRange() *MapNodeIter {
return &MapNodeIter{
Expand Down Expand Up @@ -1311,7 +1300,7 @@ func (n *MappingKeyNode) AddColumn(col int) {

// String tag to text
func (n *MappingKeyNode) String() string {
return fmt.Sprintf("%s %s", n.Start.Value, n.Value.String())
return n.stringWithoutComment()
}

func (n *MappingKeyNode) stringWithoutComment() string {
Expand All @@ -1327,7 +1316,7 @@ func (n *MappingKeyNode) MarshalYAML() ([]byte, error) {
type MappingValueNode struct {
*BaseNode
Start *token.Token
Key Node
Key MapKeyNode
Value Node
}

Expand Down Expand Up @@ -1434,26 +1423,6 @@ func (n *MappingValueNode) toString() string {
return fmt.Sprintf("%s%s:\n%s", space, n.Key.String(), n.Value.String())
}

func (n *MappingValueNode) stringWithoutComment() string {
space := strings.Repeat(" ", n.Key.GetToken().Position.Column-1)
keyIndentLevel := n.Key.GetToken().Position.IndentLevel
valueIndentLevel := n.Value.GetToken().Position.IndentLevel
if _, ok := n.Value.(ScalarNode); ok {
return fmt.Sprintf("%s%s: %s", space, n.Key.String(), n.Value.String())
} else if keyIndentLevel < valueIndentLevel {
return fmt.Sprintf("%s%s:\n%s", space, n.Key.String(), n.Value.String())
} else if m, ok := n.Value.(*MappingNode); ok && (m.IsFlowStyle || len(m.Values) == 0) {
return fmt.Sprintf("%s%s: %s", space, n.Key.String(), n.Value.String())
} else if s, ok := n.Value.(*SequenceNode); ok && (s.IsFlowStyle || len(s.Values) == 0) {
return fmt.Sprintf("%s%s: %s", space, n.Key.String(), n.Value.String())
} else if _, ok := n.Value.(*AnchorNode); ok {
return fmt.Sprintf("%s%s: %s", space, n.Key.String(), n.Value.String())
} else if _, ok := n.Value.(*AliasNode); ok {
return fmt.Sprintf("%s%s: %s", space, n.Key.String(), n.Value.String())
}
return fmt.Sprintf("%s%s:\n%s", space, n.Key.String(), n.Value.String())
}

// MapRange implements MapNode protocol
func (n *MappingValueNode) MapRange() *MapNodeIter {
return &MapNodeIter{
Expand Down Expand Up @@ -1619,13 +1588,6 @@ func (n *SequenceNode) String() string {
return n.blockStyleString()
}

func (n *SequenceNode) stringWithoutComment() string {
if n.IsFlowStyle || len(n.Values) == 0 {
return n.flowStyleString()
}
return n.blockStyleString()
}

// ArrayRange implements ArrayNode protocol
func (n *SequenceNode) ArrayRange() *ArrayNodeIter {
return &ArrayNodeIter{
Expand Down Expand Up @@ -1696,10 +1658,6 @@ func (n *AnchorNode) String() string {
return fmt.Sprintf("&%s %s", n.Name.String(), value)
}

func (n *AnchorNode) stringWithoutComment() string {
return n.String()
}

// MarshalYAML encodes to a YAML text
func (n *AnchorNode) MarshalYAML() ([]byte, error) {
return []byte(n.String()), nil
Expand Down Expand Up @@ -1750,10 +1708,6 @@ func (n *AliasNode) String() string {
return fmt.Sprintf("*%s", n.Value.String())
}

func (n *AliasNode) stringWithoutComment() string {
return fmt.Sprintf("*%s", n.Value.String())
}

// MarshalYAML encodes to a YAML text
func (n *AliasNode) MarshalYAML() ([]byte, error) {
return []byte(n.String()), nil
Expand Down Expand Up @@ -1791,10 +1745,6 @@ func (n *DirectiveNode) String() string {
return fmt.Sprintf("%s%s", n.Start.Value, n.Value.String())
}

func (n *DirectiveNode) stringWithoutComment() string {
return fmt.Sprintf("%s%s", n.Start.Value, n.Value.String())
}

// MarshalYAML encodes to a YAML text
func (n *DirectiveNode) MarshalYAML() ([]byte, error) {
return []byte(n.String()), nil
Expand Down Expand Up @@ -1833,10 +1783,6 @@ func (n *TagNode) String() string {
return fmt.Sprintf("%s %s", n.Start.Value, n.Value.String())
}

func (n *TagNode) stringWithoutComment() string {
return fmt.Sprintf("%s %s", n.Start.Value, n.Value.String())
}

// MarshalYAML encodes to a YAML text
func (n *TagNode) MarshalYAML() ([]byte, error) {
return []byte(n.String()), nil
Expand Down Expand Up @@ -1872,10 +1818,6 @@ func (n *CommentNode) String() string {
return fmt.Sprintf("#%s", n.Token.Value)
}

func (n *CommentNode) stringWithoutComment() string {
return ""
}

// MarshalYAML encodes to a YAML text
func (n *CommentNode) MarshalYAML() ([]byte, error) {
return []byte(n.String()), nil
Expand Down Expand Up @@ -1929,10 +1871,6 @@ func (n *CommentGroupNode) StringWithSpace(col int) string {

}

func (n *CommentGroupNode) stringWithoutComment() string {
return ""
}

// MarshalYAML encodes to a YAML text
func (n *CommentGroupNode) MarshalYAML() ([]byte, error) {
return []byte(n.String()), nil
Expand Down
4 changes: 2 additions & 2 deletions decode.go
Expand Up @@ -104,7 +104,7 @@ func (d *Decoder) mergeValueNode(value ast.Node) ast.Node {
return value
}

func (d *Decoder) mapKeyNodeToString(node ast.Node) string {
func (d *Decoder) mapKeyNodeToString(node ast.MapKeyNode) string {
key := d.nodeToValue(node)
if key == nil {
return "null"
Expand Down Expand Up @@ -283,7 +283,7 @@ func (d *Decoder) resolveAlias(node ast.Node) ast.Node {
value.AddColumn(requiredColumn)
n.Value = value
} else {
n.Key = d.resolveAlias(n.Key)
n.Key = d.resolveAlias(n.Key).(ast.MapKeyNode)
n.Value = d.resolveAlias(n.Value)
}
case *ast.SequenceNode:
Expand Down
25 changes: 12 additions & 13 deletions encode.go
Expand Up @@ -351,7 +351,6 @@ func (e *Encoder) encodeValue(ctx context.Context, v reflect.Value, column int)
default:
return nil, xerrors.Errorf("unknown value type %s", v.Type().String())
}
return nil, nil
}

func (e *Encoder) pos(column int) *token.Position {
Expand All @@ -364,17 +363,17 @@ func (e *Encoder) pos(column int) *token.Position {
}
}

func (e *Encoder) encodeNil() ast.Node {
func (e *Encoder) encodeNil() *ast.NullNode {
value := "null"
return ast.Null(token.New(value, value, e.pos(e.column)))
}

func (e *Encoder) encodeInt(v int64) ast.Node {
func (e *Encoder) encodeInt(v int64) *ast.IntegerNode {
value := fmt.Sprint(v)
return ast.Integer(token.New(value, value, e.pos(e.column)))
}

func (e *Encoder) encodeUint(v uint64) ast.Node {
func (e *Encoder) encodeUint(v uint64) *ast.IntegerNode {
value := fmt.Sprint(v)
return ast.Integer(token.New(value, value, e.pos(e.column)))
}
Expand Down Expand Up @@ -411,7 +410,7 @@ func (e *Encoder) isNeedQuoted(v string) bool {
return false
}

func (e *Encoder) encodeString(v string, column int) ast.Node {
func (e *Encoder) encodeString(v string, column int) *ast.StringNode {
if e.isNeedQuoted(v) {
if e.singleQuote {
v = quoteWith(v, '\'')
Expand All @@ -422,12 +421,12 @@ func (e *Encoder) encodeString(v string, column int) ast.Node {
return ast.String(token.New(v, v, e.pos(column)))
}

func (e *Encoder) encodeBool(v bool) ast.Node {
func (e *Encoder) encodeBool(v bool) *ast.BoolNode {
value := fmt.Sprint(v)
return ast.Bool(token.New(value, value, e.pos(e.column)))
}

func (e *Encoder) encodeSlice(ctx context.Context, value reflect.Value) (ast.Node, error) {
func (e *Encoder) encodeSlice(ctx context.Context, value reflect.Value) (*ast.SequenceNode, error) {
if e.indentSequence {
e.column += e.indent
}
Expand All @@ -446,7 +445,7 @@ func (e *Encoder) encodeSlice(ctx context.Context, value reflect.Value) (ast.Nod
return sequence, nil
}

func (e *Encoder) encodeArray(ctx context.Context, value reflect.Value) (ast.Node, error) {
func (e *Encoder) encodeArray(ctx context.Context, value reflect.Value) (*ast.SequenceNode, error) {
if e.indentSequence {
e.column += e.indent
}
Expand Down Expand Up @@ -482,7 +481,7 @@ func (e *Encoder) encodeMapItem(ctx context.Context, item MapItem, column int) (
), nil
}

func (e *Encoder) encodeMapSlice(ctx context.Context, value MapSlice, column int) (ast.Node, error) {
func (e *Encoder) encodeMapSlice(ctx context.Context, value MapSlice, column int) (*ast.MappingNode, error) {
node := ast.Mapping(token.New("", "", e.pos(column)), e.isFlowStyle)
for _, item := range value {
value, err := e.encodeMapItem(ctx, item, column)
Expand Down Expand Up @@ -569,23 +568,23 @@ func (e *Encoder) isZeroValue(v reflect.Value) bool {
return false
}

func (e *Encoder) encodeTime(v time.Time, column int) ast.Node {
func (e *Encoder) encodeTime(v time.Time, column int) *ast.StringNode {
value := v.Format(time.RFC3339Nano)
if e.isJSONStyle {
value = strconv.Quote(value)
}
return ast.String(token.New(value, value, e.pos(column)))
}

func (e *Encoder) encodeDuration(v time.Duration, column int) ast.Node {
func (e *Encoder) encodeDuration(v time.Duration, column int) *ast.StringNode {
value := v.String()
if e.isJSONStyle {
value = strconv.Quote(value)
}
return ast.String(token.New(value, value, e.pos(column)))
}

func (e *Encoder) encodeAnchor(anchorName string, value ast.Node, fieldValue reflect.Value, column int) (ast.Node, error) {
func (e *Encoder) encodeAnchor(anchorName string, value ast.Node, fieldValue reflect.Value, column int) (*ast.AnchorNode, error) {
anchorNode := ast.Anchor(token.New("&", "&", e.pos(column)))
anchorNode.Name = ast.String(token.New(anchorName, anchorName, e.pos(column)))
anchorNode.Value = value
Expand Down Expand Up @@ -637,7 +636,7 @@ func (e *Encoder) encodeStruct(ctx context.Context, value reflect.Value, column
s.SetIsFlowStyle(true)
}
}
key := e.encodeString(structField.RenderName, column)
var key ast.MapKeyNode = e.encodeString(structField.RenderName, column)
switch {
case structField.AnchorName != "":
anchorNode, err := e.encodeAnchor(structField.AnchorName, value, fieldValue, column)
Expand Down