Skip to content

Commit

Permalink
Introduce MapKeyNode to limit node types for map key
Browse files Browse the repository at this point in the history
  • Loading branch information
itchyny committed Aug 18, 2022
1 parent 883a73b commit 2bdb41e
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 107 deletions.
92 changes: 15 additions & 77 deletions ast/ast.go
Expand Up @@ -191,13 +191,18 @@ type Node interface {
addReadLen(int)
// clean read length
clearLen()
}

// MapKeyNode type for map key node
type MapKeyNode interface {
Node
// String node to text without comment
stringWithoutComment() string
}

// ScalarNode type for scalar node
type ScalarNode interface {
Node
MapKeyNode
GetValue() interface{}
}

Expand Down Expand Up @@ -274,15 +279,15 @@ func readNode(p []byte, node Node) (int, error) {
}

// Null create node for null value
func Null(tk *token.Token) Node {
func Null(tk *token.Token) *NullNode {
return &NullNode{
BaseNode: &BaseNode{},
Token: tk,
}
}

// Bool create node for boolean value
func Bool(tk *token.Token) Node {
func Bool(tk *token.Token) *BoolNode {
b, _ := strconv.ParseBool(tk.Value)
return &BoolNode{
BaseNode: &BaseNode{},
Expand All @@ -292,7 +297,7 @@ func Bool(tk *token.Token) Node {
}

// Integer create node for integer value
func Integer(tk *token.Token) Node {
func Integer(tk *token.Token) *IntegerNode {
value := removeUnderScoreFromNumber(tk.Value)
switch tk.Type {
case token.BinaryIntegerType:
Expand Down Expand Up @@ -386,7 +391,7 @@ func Integer(tk *token.Token) Node {
}

// Float create node for float value
func Float(tk *token.Token) Node {
func Float(tk *token.Token) *FloatNode {
f, _ := strconv.ParseFloat(removeUnderScoreFromNumber(tk.Value), 64)
return &FloatNode{
BaseNode: &BaseNode{},
Expand Down Expand Up @@ -467,7 +472,7 @@ func Mapping(tk *token.Token, isFlowStyle bool, values ...*MappingValueNode) *Ma
}

// MappingValue create node for mapping value
func MappingValue(tk *token.Token, key Node, value Node) *MappingValueNode {
func MappingValue(tk *token.Token, key MapKeyNode, value Node) *MappingValueNode {
return &MappingValueNode{
BaseNode: &BaseNode{},
Start: tk,
Expand Down Expand Up @@ -564,10 +569,6 @@ func (f *File) String() string {
return strings.Join(docs, "\n")
}

func (f *File) stringWithoutComment() string {
return f.String()
}

// DocumentNode type of Document
type DocumentNode struct {
*BaseNode
Expand Down Expand Up @@ -609,10 +610,6 @@ func (d *DocumentNode) String() string {
return strings.Join(doc, "\n")
}

func (d *DocumentNode) stringWithoutComment() string {
return d.String()
}

// MarshalYAML encodes to a YAML text
func (d *DocumentNode) MarshalYAML() ([]byte, error) {
return []byte(d.String()), nil
Expand Down Expand Up @@ -956,7 +953,7 @@ func (n *MergeKeyNode) GetValue() interface{} {

// String returns '<<' value
func (n *MergeKeyNode) String() string {
return n.Token.Value
return n.stringWithoutComment()
}

func (n *MergeKeyNode) stringWithoutComment() string {
Expand Down Expand Up @@ -1137,7 +1134,7 @@ func (m *MapNodeIter) Next() bool {
}

// Key returns the key of the iterator's current map node entry.
func (m *MapNodeIter) Key() Node {
func (m *MapNodeIter) Key() MapKeyNode {
return m.values[m.idx].Key
}

Expand Down Expand Up @@ -1260,14 +1257,6 @@ func (n *MappingNode) String() string {
return n.blockStyleString(commentMode)
}

func (n *MappingNode) stringWithoutComment() string {
commentMode := false
if n.IsFlowStyle || len(n.Values) == 0 {
return n.flowStyleString(commentMode)
}
return n.blockStyleString(commentMode)
}

// MapRange implements MapNode protocol
func (n *MappingNode) MapRange() *MapNodeIter {
return &MapNodeIter{
Expand Down Expand Up @@ -1311,7 +1300,7 @@ func (n *MappingKeyNode) AddColumn(col int) {

// String tag to text
func (n *MappingKeyNode) String() string {
return fmt.Sprintf("%s %s", n.Start.Value, n.Value.String())
return n.stringWithoutComment()
}

func (n *MappingKeyNode) stringWithoutComment() string {
Expand All @@ -1327,7 +1316,7 @@ func (n *MappingKeyNode) MarshalYAML() ([]byte, error) {
type MappingValueNode struct {
*BaseNode
Start *token.Token
Key Node
Key MapKeyNode
Value Node
}

Expand Down Expand Up @@ -1434,26 +1423,6 @@ func (n *MappingValueNode) toString() string {
return fmt.Sprintf("%s%s:\n%s", space, n.Key.String(), n.Value.String())
}

func (n *MappingValueNode) stringWithoutComment() string {
space := strings.Repeat(" ", n.Key.GetToken().Position.Column-1)
keyIndentLevel := n.Key.GetToken().Position.IndentLevel
valueIndentLevel := n.Value.GetToken().Position.IndentLevel
if _, ok := n.Value.(ScalarNode); ok {
return fmt.Sprintf("%s%s: %s", space, n.Key.String(), n.Value.String())
} else if keyIndentLevel < valueIndentLevel {
return fmt.Sprintf("%s%s:\n%s", space, n.Key.String(), n.Value.String())
} else if m, ok := n.Value.(*MappingNode); ok && (m.IsFlowStyle || len(m.Values) == 0) {
return fmt.Sprintf("%s%s: %s", space, n.Key.String(), n.Value.String())
} else if s, ok := n.Value.(*SequenceNode); ok && (s.IsFlowStyle || len(s.Values) == 0) {
return fmt.Sprintf("%s%s: %s", space, n.Key.String(), n.Value.String())
} else if _, ok := n.Value.(*AnchorNode); ok {
return fmt.Sprintf("%s%s: %s", space, n.Key.String(), n.Value.String())
} else if _, ok := n.Value.(*AliasNode); ok {
return fmt.Sprintf("%s%s: %s", space, n.Key.String(), n.Value.String())
}
return fmt.Sprintf("%s%s:\n%s", space, n.Key.String(), n.Value.String())
}

// MapRange implements MapNode protocol
func (n *MappingValueNode) MapRange() *MapNodeIter {
return &MapNodeIter{
Expand Down Expand Up @@ -1619,13 +1588,6 @@ func (n *SequenceNode) String() string {
return n.blockStyleString()
}

func (n *SequenceNode) stringWithoutComment() string {
if n.IsFlowStyle || len(n.Values) == 0 {
return n.flowStyleString()
}
return n.blockStyleString()
}

// ArrayRange implements ArrayNode protocol
func (n *SequenceNode) ArrayRange() *ArrayNodeIter {
return &ArrayNodeIter{
Expand Down Expand Up @@ -1696,10 +1658,6 @@ func (n *AnchorNode) String() string {
return fmt.Sprintf("&%s %s", n.Name.String(), value)
}

func (n *AnchorNode) stringWithoutComment() string {
return n.String()
}

// MarshalYAML encodes to a YAML text
func (n *AnchorNode) MarshalYAML() ([]byte, error) {
return []byte(n.String()), nil
Expand Down Expand Up @@ -1750,10 +1708,6 @@ func (n *AliasNode) String() string {
return fmt.Sprintf("*%s", n.Value.String())
}

func (n *AliasNode) stringWithoutComment() string {
return fmt.Sprintf("*%s", n.Value.String())
}

// MarshalYAML encodes to a YAML text
func (n *AliasNode) MarshalYAML() ([]byte, error) {
return []byte(n.String()), nil
Expand Down Expand Up @@ -1791,10 +1745,6 @@ func (n *DirectiveNode) String() string {
return fmt.Sprintf("%s%s", n.Start.Value, n.Value.String())
}

func (n *DirectiveNode) stringWithoutComment() string {
return fmt.Sprintf("%s%s", n.Start.Value, n.Value.String())
}

// MarshalYAML encodes to a YAML text
func (n *DirectiveNode) MarshalYAML() ([]byte, error) {
return []byte(n.String()), nil
Expand Down Expand Up @@ -1833,10 +1783,6 @@ func (n *TagNode) String() string {
return fmt.Sprintf("%s %s", n.Start.Value, n.Value.String())
}

func (n *TagNode) stringWithoutComment() string {
return fmt.Sprintf("%s %s", n.Start.Value, n.Value.String())
}

// MarshalYAML encodes to a YAML text
func (n *TagNode) MarshalYAML() ([]byte, error) {
return []byte(n.String()), nil
Expand Down Expand Up @@ -1872,10 +1818,6 @@ func (n *CommentNode) String() string {
return fmt.Sprintf("#%s", n.Token.Value)
}

func (n *CommentNode) stringWithoutComment() string {
return ""
}

// MarshalYAML encodes to a YAML text
func (n *CommentNode) MarshalYAML() ([]byte, error) {
return []byte(n.String()), nil
Expand Down Expand Up @@ -1929,10 +1871,6 @@ func (n *CommentGroupNode) StringWithSpace(col int) string {

}

func (n *CommentGroupNode) stringWithoutComment() string {
return ""
}

// MarshalYAML encodes to a YAML text
func (n *CommentGroupNode) MarshalYAML() ([]byte, error) {
return []byte(n.String()), nil
Expand Down
4 changes: 2 additions & 2 deletions decode.go
Expand Up @@ -104,7 +104,7 @@ func (d *Decoder) mergeValueNode(value ast.Node) ast.Node {
return value
}

func (d *Decoder) mapKeyNodeToString(node ast.Node) string {
func (d *Decoder) mapKeyNodeToString(node ast.MapKeyNode) string {
key := d.nodeToValue(node)
if key == nil {
return "null"
Expand Down Expand Up @@ -283,7 +283,7 @@ func (d *Decoder) resolveAlias(node ast.Node) ast.Node {
value.AddColumn(requiredColumn)
n.Value = value
} else {
n.Key = d.resolveAlias(n.Key)
n.Key = d.resolveAlias(n.Key).(ast.MapKeyNode)
n.Value = d.resolveAlias(n.Value)
}
case *ast.SequenceNode:
Expand Down
25 changes: 12 additions & 13 deletions encode.go
Expand Up @@ -351,7 +351,6 @@ func (e *Encoder) encodeValue(ctx context.Context, v reflect.Value, column int)
default:
return nil, xerrors.Errorf("unknown value type %s", v.Type().String())
}
return nil, nil
}

func (e *Encoder) pos(column int) *token.Position {
Expand All @@ -364,17 +363,17 @@ func (e *Encoder) pos(column int) *token.Position {
}
}

func (e *Encoder) encodeNil() ast.Node {
func (e *Encoder) encodeNil() *ast.NullNode {
value := "null"
return ast.Null(token.New(value, value, e.pos(e.column)))
}

func (e *Encoder) encodeInt(v int64) ast.Node {
func (e *Encoder) encodeInt(v int64) *ast.IntegerNode {
value := fmt.Sprint(v)
return ast.Integer(token.New(value, value, e.pos(e.column)))
}

func (e *Encoder) encodeUint(v uint64) ast.Node {
func (e *Encoder) encodeUint(v uint64) *ast.IntegerNode {
value := fmt.Sprint(v)
return ast.Integer(token.New(value, value, e.pos(e.column)))
}
Expand Down Expand Up @@ -411,7 +410,7 @@ func (e *Encoder) isNeedQuoted(v string) bool {
return false
}

func (e *Encoder) encodeString(v string, column int) ast.Node {
func (e *Encoder) encodeString(v string, column int) *ast.StringNode {
if e.isNeedQuoted(v) {
if e.singleQuote {
v = quoteWith(v, '\'')
Expand All @@ -422,12 +421,12 @@ func (e *Encoder) encodeString(v string, column int) ast.Node {
return ast.String(token.New(v, v, e.pos(column)))
}

func (e *Encoder) encodeBool(v bool) ast.Node {
func (e *Encoder) encodeBool(v bool) *ast.BoolNode {
value := fmt.Sprint(v)
return ast.Bool(token.New(value, value, e.pos(e.column)))
}

func (e *Encoder) encodeSlice(ctx context.Context, value reflect.Value) (ast.Node, error) {
func (e *Encoder) encodeSlice(ctx context.Context, value reflect.Value) (*ast.SequenceNode, error) {
if e.indentSequence {
e.column += e.indent
}
Expand All @@ -446,7 +445,7 @@ func (e *Encoder) encodeSlice(ctx context.Context, value reflect.Value) (ast.Nod
return sequence, nil
}

func (e *Encoder) encodeArray(ctx context.Context, value reflect.Value) (ast.Node, error) {
func (e *Encoder) encodeArray(ctx context.Context, value reflect.Value) (*ast.SequenceNode, error) {
if e.indentSequence {
e.column += e.indent
}
Expand Down Expand Up @@ -482,7 +481,7 @@ func (e *Encoder) encodeMapItem(ctx context.Context, item MapItem, column int) (
), nil
}

func (e *Encoder) encodeMapSlice(ctx context.Context, value MapSlice, column int) (ast.Node, error) {
func (e *Encoder) encodeMapSlice(ctx context.Context, value MapSlice, column int) (*ast.MappingNode, error) {
node := ast.Mapping(token.New("", "", e.pos(column)), e.isFlowStyle)
for _, item := range value {
value, err := e.encodeMapItem(ctx, item, column)
Expand Down Expand Up @@ -569,23 +568,23 @@ func (e *Encoder) isZeroValue(v reflect.Value) bool {
return false
}

func (e *Encoder) encodeTime(v time.Time, column int) ast.Node {
func (e *Encoder) encodeTime(v time.Time, column int) *ast.StringNode {
value := v.Format(time.RFC3339Nano)
if e.isJSONStyle {
value = strconv.Quote(value)
}
return ast.String(token.New(value, value, e.pos(column)))
}

func (e *Encoder) encodeDuration(v time.Duration, column int) ast.Node {
func (e *Encoder) encodeDuration(v time.Duration, column int) *ast.StringNode {
value := v.String()
if e.isJSONStyle {
value = strconv.Quote(value)
}
return ast.String(token.New(value, value, e.pos(column)))
}

func (e *Encoder) encodeAnchor(anchorName string, value ast.Node, fieldValue reflect.Value, column int) (ast.Node, error) {
func (e *Encoder) encodeAnchor(anchorName string, value ast.Node, fieldValue reflect.Value, column int) (*ast.AnchorNode, error) {
anchorNode := ast.Anchor(token.New("&", "&", e.pos(column)))
anchorNode.Name = ast.String(token.New(anchorName, anchorName, e.pos(column)))
anchorNode.Value = value
Expand Down Expand Up @@ -637,7 +636,7 @@ func (e *Encoder) encodeStruct(ctx context.Context, value reflect.Value, column
s.SetIsFlowStyle(true)
}
}
key := e.encodeString(structField.RenderName, column)
var key ast.MapKeyNode = e.encodeString(structField.RenderName, column)
switch {
case structField.AnchorName != "":
anchorNode, err := e.encodeAnchor(structField.AnchorName, value, fieldValue, column)
Expand Down

0 comments on commit 2bdb41e

Please sign in to comment.