diff --git a/ast/ast.go b/ast/ast.go index 415fad4..795f525 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -191,13 +191,18 @@ type Node interface { addReadLen(int) // clean read length clearLen() +} + +// MapKeyNode type for map key node +type MapKeyNode interface { + Node // String node to text without comment stringWithoutComment() string } // ScalarNode type for scalar node type ScalarNode interface { - Node + MapKeyNode GetValue() interface{} } @@ -274,7 +279,7 @@ func readNode(p []byte, node Node) (int, error) { } // Null create node for null value -func Null(tk *token.Token) Node { +func Null(tk *token.Token) *NullNode { return &NullNode{ BaseNode: &BaseNode{}, Token: tk, @@ -282,7 +287,7 @@ func Null(tk *token.Token) Node { } // Bool create node for boolean value -func Bool(tk *token.Token) Node { +func Bool(tk *token.Token) *BoolNode { b, _ := strconv.ParseBool(tk.Value) return &BoolNode{ BaseNode: &BaseNode{}, @@ -292,7 +297,7 @@ func Bool(tk *token.Token) Node { } // Integer create node for integer value -func Integer(tk *token.Token) Node { +func Integer(tk *token.Token) *IntegerNode { value := removeUnderScoreFromNumber(tk.Value) switch tk.Type { case token.BinaryIntegerType: @@ -386,7 +391,7 @@ func Integer(tk *token.Token) Node { } // Float create node for float value -func Float(tk *token.Token) Node { +func Float(tk *token.Token) *FloatNode { f, _ := strconv.ParseFloat(removeUnderScoreFromNumber(tk.Value), 64) return &FloatNode{ BaseNode: &BaseNode{}, @@ -467,7 +472,7 @@ func Mapping(tk *token.Token, isFlowStyle bool, values ...*MappingValueNode) *Ma } // MappingValue create node for mapping value -func MappingValue(tk *token.Token, key Node, value Node) *MappingValueNode { +func MappingValue(tk *token.Token, key MapKeyNode, value Node) *MappingValueNode { return &MappingValueNode{ BaseNode: &BaseNode{}, Start: tk, @@ -564,10 +569,6 @@ func (f *File) String() string { return strings.Join(docs, "\n") } -func (f *File) stringWithoutComment() string { - return f.String() -} - // DocumentNode type of Document type DocumentNode struct { *BaseNode @@ -609,10 +610,6 @@ func (d *DocumentNode) String() string { return strings.Join(doc, "\n") } -func (d *DocumentNode) stringWithoutComment() string { - return d.String() -} - // MarshalYAML encodes to a YAML text func (d *DocumentNode) MarshalYAML() ([]byte, error) { return []byte(d.String()), nil @@ -956,7 +953,7 @@ func (n *MergeKeyNode) GetValue() interface{} { // String returns '<<' value func (n *MergeKeyNode) String() string { - return n.Token.Value + return n.stringWithoutComment() } func (n *MergeKeyNode) stringWithoutComment() string { @@ -1137,7 +1134,7 @@ func (m *MapNodeIter) Next() bool { } // Key returns the key of the iterator's current map node entry. -func (m *MapNodeIter) Key() Node { +func (m *MapNodeIter) Key() MapKeyNode { return m.values[m.idx].Key } @@ -1260,14 +1257,6 @@ func (n *MappingNode) String() string { return n.blockStyleString(commentMode) } -func (n *MappingNode) stringWithoutComment() string { - commentMode := false - if n.IsFlowStyle || len(n.Values) == 0 { - return n.flowStyleString(commentMode) - } - return n.blockStyleString(commentMode) -} - // MapRange implements MapNode protocol func (n *MappingNode) MapRange() *MapNodeIter { return &MapNodeIter{ @@ -1311,7 +1300,7 @@ func (n *MappingKeyNode) AddColumn(col int) { // String tag to text func (n *MappingKeyNode) String() string { - return fmt.Sprintf("%s %s", n.Start.Value, n.Value.String()) + return n.stringWithoutComment() } func (n *MappingKeyNode) stringWithoutComment() string { @@ -1327,7 +1316,7 @@ func (n *MappingKeyNode) MarshalYAML() ([]byte, error) { type MappingValueNode struct { *BaseNode Start *token.Token - Key Node + Key MapKeyNode Value Node } @@ -1434,26 +1423,6 @@ func (n *MappingValueNode) toString() string { return fmt.Sprintf("%s%s:\n%s", space, n.Key.String(), n.Value.String()) } -func (n *MappingValueNode) stringWithoutComment() string { - space := strings.Repeat(" ", n.Key.GetToken().Position.Column-1) - keyIndentLevel := n.Key.GetToken().Position.IndentLevel - valueIndentLevel := n.Value.GetToken().Position.IndentLevel - if _, ok := n.Value.(ScalarNode); ok { - return fmt.Sprintf("%s%s: %s", space, n.Key.String(), n.Value.String()) - } else if keyIndentLevel < valueIndentLevel { - return fmt.Sprintf("%s%s:\n%s", space, n.Key.String(), n.Value.String()) - } else if m, ok := n.Value.(*MappingNode); ok && (m.IsFlowStyle || len(m.Values) == 0) { - return fmt.Sprintf("%s%s: %s", space, n.Key.String(), n.Value.String()) - } else if s, ok := n.Value.(*SequenceNode); ok && (s.IsFlowStyle || len(s.Values) == 0) { - return fmt.Sprintf("%s%s: %s", space, n.Key.String(), n.Value.String()) - } else if _, ok := n.Value.(*AnchorNode); ok { - return fmt.Sprintf("%s%s: %s", space, n.Key.String(), n.Value.String()) - } else if _, ok := n.Value.(*AliasNode); ok { - return fmt.Sprintf("%s%s: %s", space, n.Key.String(), n.Value.String()) - } - return fmt.Sprintf("%s%s:\n%s", space, n.Key.String(), n.Value.String()) -} - // MapRange implements MapNode protocol func (n *MappingValueNode) MapRange() *MapNodeIter { return &MapNodeIter{ @@ -1619,13 +1588,6 @@ func (n *SequenceNode) String() string { return n.blockStyleString() } -func (n *SequenceNode) stringWithoutComment() string { - if n.IsFlowStyle || len(n.Values) == 0 { - return n.flowStyleString() - } - return n.blockStyleString() -} - // ArrayRange implements ArrayNode protocol func (n *SequenceNode) ArrayRange() *ArrayNodeIter { return &ArrayNodeIter{ @@ -1696,10 +1658,6 @@ func (n *AnchorNode) String() string { return fmt.Sprintf("&%s %s", n.Name.String(), value) } -func (n *AnchorNode) stringWithoutComment() string { - return n.String() -} - // MarshalYAML encodes to a YAML text func (n *AnchorNode) MarshalYAML() ([]byte, error) { return []byte(n.String()), nil @@ -1750,10 +1708,6 @@ func (n *AliasNode) String() string { return fmt.Sprintf("*%s", n.Value.String()) } -func (n *AliasNode) stringWithoutComment() string { - return fmt.Sprintf("*%s", n.Value.String()) -} - // MarshalYAML encodes to a YAML text func (n *AliasNode) MarshalYAML() ([]byte, error) { return []byte(n.String()), nil @@ -1791,10 +1745,6 @@ func (n *DirectiveNode) String() string { return fmt.Sprintf("%s%s", n.Start.Value, n.Value.String()) } -func (n *DirectiveNode) stringWithoutComment() string { - return fmt.Sprintf("%s%s", n.Start.Value, n.Value.String()) -} - // MarshalYAML encodes to a YAML text func (n *DirectiveNode) MarshalYAML() ([]byte, error) { return []byte(n.String()), nil @@ -1833,10 +1783,6 @@ func (n *TagNode) String() string { return fmt.Sprintf("%s %s", n.Start.Value, n.Value.String()) } -func (n *TagNode) stringWithoutComment() string { - return fmt.Sprintf("%s %s", n.Start.Value, n.Value.String()) -} - // MarshalYAML encodes to a YAML text func (n *TagNode) MarshalYAML() ([]byte, error) { return []byte(n.String()), nil @@ -1872,10 +1818,6 @@ func (n *CommentNode) String() string { return fmt.Sprintf("#%s", n.Token.Value) } -func (n *CommentNode) stringWithoutComment() string { - return "" -} - // MarshalYAML encodes to a YAML text func (n *CommentNode) MarshalYAML() ([]byte, error) { return []byte(n.String()), nil @@ -1929,10 +1871,6 @@ func (n *CommentGroupNode) StringWithSpace(col int) string { } -func (n *CommentGroupNode) stringWithoutComment() string { - return "" -} - // MarshalYAML encodes to a YAML text func (n *CommentGroupNode) MarshalYAML() ([]byte, error) { return []byte(n.String()), nil diff --git a/decode.go b/decode.go index d519b78..2682b2a 100644 --- a/decode.go +++ b/decode.go @@ -104,7 +104,7 @@ func (d *Decoder) mergeValueNode(value ast.Node) ast.Node { return value } -func (d *Decoder) mapKeyNodeToString(node ast.Node) string { +func (d *Decoder) mapKeyNodeToString(node ast.MapKeyNode) string { key := d.nodeToValue(node) if key == nil { return "null" @@ -283,7 +283,7 @@ func (d *Decoder) resolveAlias(node ast.Node) ast.Node { value.AddColumn(requiredColumn) n.Value = value } else { - n.Key = d.resolveAlias(n.Key) + n.Key = d.resolveAlias(n.Key).(ast.MapKeyNode) n.Value = d.resolveAlias(n.Value) } case *ast.SequenceNode: diff --git a/encode.go b/encode.go index 31eaf94..f582b48 100644 --- a/encode.go +++ b/encode.go @@ -351,7 +351,6 @@ func (e *Encoder) encodeValue(ctx context.Context, v reflect.Value, column int) default: return nil, xerrors.Errorf("unknown value type %s", v.Type().String()) } - return nil, nil } func (e *Encoder) pos(column int) *token.Position { @@ -364,17 +363,17 @@ func (e *Encoder) pos(column int) *token.Position { } } -func (e *Encoder) encodeNil() ast.Node { +func (e *Encoder) encodeNil() *ast.NullNode { value := "null" return ast.Null(token.New(value, value, e.pos(e.column))) } -func (e *Encoder) encodeInt(v int64) ast.Node { +func (e *Encoder) encodeInt(v int64) *ast.IntegerNode { value := fmt.Sprint(v) return ast.Integer(token.New(value, value, e.pos(e.column))) } -func (e *Encoder) encodeUint(v uint64) ast.Node { +func (e *Encoder) encodeUint(v uint64) *ast.IntegerNode { value := fmt.Sprint(v) return ast.Integer(token.New(value, value, e.pos(e.column))) } @@ -411,7 +410,7 @@ func (e *Encoder) isNeedQuoted(v string) bool { return false } -func (e *Encoder) encodeString(v string, column int) ast.Node { +func (e *Encoder) encodeString(v string, column int) *ast.StringNode { if e.isNeedQuoted(v) { if e.singleQuote { v = quoteWith(v, '\'') @@ -422,12 +421,12 @@ func (e *Encoder) encodeString(v string, column int) ast.Node { return ast.String(token.New(v, v, e.pos(column))) } -func (e *Encoder) encodeBool(v bool) ast.Node { +func (e *Encoder) encodeBool(v bool) *ast.BoolNode { value := fmt.Sprint(v) return ast.Bool(token.New(value, value, e.pos(e.column))) } -func (e *Encoder) encodeSlice(ctx context.Context, value reflect.Value) (ast.Node, error) { +func (e *Encoder) encodeSlice(ctx context.Context, value reflect.Value) (*ast.SequenceNode, error) { if e.indentSequence { e.column += e.indent } @@ -446,7 +445,7 @@ func (e *Encoder) encodeSlice(ctx context.Context, value reflect.Value) (ast.Nod return sequence, nil } -func (e *Encoder) encodeArray(ctx context.Context, value reflect.Value) (ast.Node, error) { +func (e *Encoder) encodeArray(ctx context.Context, value reflect.Value) (*ast.SequenceNode, error) { if e.indentSequence { e.column += e.indent } @@ -482,7 +481,7 @@ func (e *Encoder) encodeMapItem(ctx context.Context, item MapItem, column int) ( ), nil } -func (e *Encoder) encodeMapSlice(ctx context.Context, value MapSlice, column int) (ast.Node, error) { +func (e *Encoder) encodeMapSlice(ctx context.Context, value MapSlice, column int) (*ast.MappingNode, error) { node := ast.Mapping(token.New("", "", e.pos(column)), e.isFlowStyle) for _, item := range value { value, err := e.encodeMapItem(ctx, item, column) @@ -569,7 +568,7 @@ func (e *Encoder) isZeroValue(v reflect.Value) bool { return false } -func (e *Encoder) encodeTime(v time.Time, column int) ast.Node { +func (e *Encoder) encodeTime(v time.Time, column int) *ast.StringNode { value := v.Format(time.RFC3339Nano) if e.isJSONStyle { value = strconv.Quote(value) @@ -577,7 +576,7 @@ func (e *Encoder) encodeTime(v time.Time, column int) ast.Node { return ast.String(token.New(value, value, e.pos(column))) } -func (e *Encoder) encodeDuration(v time.Duration, column int) ast.Node { +func (e *Encoder) encodeDuration(v time.Duration, column int) *ast.StringNode { value := v.String() if e.isJSONStyle { value = strconv.Quote(value) @@ -585,7 +584,7 @@ func (e *Encoder) encodeDuration(v time.Duration, column int) ast.Node { return ast.String(token.New(value, value, e.pos(column))) } -func (e *Encoder) encodeAnchor(anchorName string, value ast.Node, fieldValue reflect.Value, column int) (ast.Node, error) { +func (e *Encoder) encodeAnchor(anchorName string, value ast.Node, fieldValue reflect.Value, column int) (*ast.AnchorNode, error) { anchorNode := ast.Anchor(token.New("&", "&", e.pos(column))) anchorNode.Name = ast.String(token.New(anchorName, anchorName, e.pos(column))) anchorNode.Value = value @@ -637,7 +636,7 @@ func (e *Encoder) encodeStruct(ctx context.Context, value reflect.Value, column s.SetIsFlowStyle(true) } } - key := e.encodeString(structField.RenderName, column) + var key ast.MapKeyNode = e.encodeString(structField.RenderName, column) switch { case structField.AnchorName != "": anchorNode, err := e.encodeAnchor(structField.AnchorName, value, fieldValue, column) diff --git a/parser/parser.go b/parser/parser.go index 80c9f51..70937c9 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -14,7 +14,7 @@ import ( type parser struct{} -func (p *parser) parseMapping(ctx *context) (ast.Node, error) { +func (p *parser) parseMapping(ctx *context) (*ast.MappingNode, error) { mapTk := ctx.currentToken() node := ast.Mapping(mapTk, true) node.SetPath(ctx.path) @@ -43,7 +43,7 @@ func (p *parser) parseMapping(ctx *context) (ast.Node, error) { return nil, errors.ErrSyntax("unterminated flow mapping", node.GetToken()) } -func (p *parser) parseSequence(ctx *context) (ast.Node, error) { +func (p *parser) parseSequence(ctx *context) (*ast.SequenceNode, error) { node := ast.Sequence(ctx.currentToken(), true) node.SetPath(ctx.path) ctx.progress(1) // skip SequenceStart token @@ -67,7 +67,7 @@ func (p *parser) parseSequence(ctx *context) (ast.Node, error) { return node, nil } -func (p *parser) parseTag(ctx *context) (ast.Node, error) { +func (p *parser) parseTag(ctx *context) (*ast.TagNode, error) { tagToken := ctx.currentToken() node := ast.Tag(tagToken) node.SetPath(ctx.path) @@ -138,7 +138,7 @@ func (p *parser) createNullToken(base *token.Token) *token.Token { return token.New("null", "null", &pos) } -func (p *parser) parseMapValue(ctx *context, key ast.Node, colonToken *token.Token) (ast.Node, error) { +func (p *parser) parseMapValue(ctx *context, key ast.MapKeyNode, colonToken *token.Token) (ast.Node, error) { node, err := p.createMapValueNode(ctx, key, colonToken) if err != nil { return nil, errors.Wrapf(err, "failed to create map value node") @@ -149,7 +149,7 @@ func (p *parser) parseMapValue(ctx *context, key ast.Node, colonToken *token.Tok return node, nil } -func (p *parser) createMapValueNode(ctx *context, key ast.Node, colonToken *token.Token) (ast.Node, error) { +func (p *parser) createMapValueNode(ctx *context, key ast.MapKeyNode, colonToken *token.Token) (ast.Node, error) { tk := ctx.currentToken() if tk == nil { nullToken := p.createNullToken(colonToken) @@ -273,7 +273,7 @@ func (p *parser) parseMappingValue(ctx *context) (ast.Node, error) { return node, nil } -func (p *parser) parseSequenceEntry(ctx *context) (ast.Node, error) { +func (p *parser) parseSequenceEntry(ctx *context) (*ast.SequenceNode, error) { tk := ctx.currentToken() sequenceNode := ast.Sequence(tk, false) sequenceNode.SetPath(ctx.path) @@ -319,7 +319,7 @@ func (p *parser) parseSequenceEntry(ctx *context) (ast.Node, error) { return sequenceNode, nil } -func (p *parser) parseAnchor(ctx *context) (ast.Node, error) { +func (p *parser) parseAnchor(ctx *context) (*ast.AnchorNode, error) { tk := ctx.currentToken() anchor := ast.Anchor(tk) anchor.SetPath(ctx.path) @@ -346,7 +346,7 @@ func (p *parser) parseAnchor(ctx *context) (ast.Node, error) { return anchor, nil } -func (p *parser) parseAlias(ctx *context) (ast.Node, error) { +func (p *parser) parseAlias(ctx *context) (*ast.AliasNode, error) { tk := ctx.currentToken() alias := ast.Alias(tk) alias.SetPath(ctx.path) @@ -363,7 +363,7 @@ func (p *parser) parseAlias(ctx *context) (ast.Node, error) { return alias, nil } -func (p *parser) parseMapKey(ctx *context) (ast.Node, error) { +func (p *parser) parseMapKey(ctx *context) (ast.MapKeyNode, error) { tk := ctx.currentToken() if value := p.parseScalarValue(tk); value != nil { return value, nil @@ -377,7 +377,7 @@ func (p *parser) parseMapKey(ctx *context) (ast.Node, error) { return nil, errors.ErrSyntax("unexpected mapping key", tk) } -func (p *parser) parseStringValue(tk *token.Token) ast.Node { +func (p *parser) parseStringValue(tk *token.Token) *ast.StringNode { switch tk.Type { case token.StringType, token.SingleQuoteType, @@ -387,7 +387,7 @@ func (p *parser) parseStringValue(tk *token.Token) ast.Node { return nil } -func (p *parser) parseScalarValueWithComment(ctx *context, tk *token.Token) (ast.Node, error) { +func (p *parser) parseScalarValueWithComment(ctx *context, tk *token.Token) (ast.ScalarNode, error) { node := p.parseScalarValue(tk) if node == nil { return nil, nil @@ -402,7 +402,7 @@ func (p *parser) parseScalarValueWithComment(ctx *context, tk *token.Token) (ast return node, nil } -func (p *parser) parseScalarValue(tk *token.Token) ast.Node { +func (p *parser) parseScalarValue(tk *token.Token) ast.ScalarNode { if node := p.parseStringValue(tk); node != nil { return node } @@ -426,7 +426,7 @@ func (p *parser) parseScalarValue(tk *token.Token) ast.Node { return nil } -func (p *parser) parseDirective(ctx *context) (ast.Node, error) { +func (p *parser) parseDirective(ctx *context) (*ast.DirectiveNode, error) { node := ast.Directive(ctx.currentToken()) ctx.progress(1) // skip directive token value, err := p.parseToken(ctx, ctx.currentToken()) @@ -447,7 +447,7 @@ func (p *parser) parseDirective(ctx *context) (ast.Node, error) { return node, nil } -func (p *parser) parseLiteral(ctx *context) (ast.Node, error) { +func (p *parser) parseLiteral(ctx *context) (*ast.LiteralNode, error) { node := ast.Literal(ctx.currentToken()) ctx.progress(1) // skip literal/folded token @@ -543,7 +543,7 @@ func (p *parser) parseComment(ctx *context) (ast.Node, error) { return node, nil } -func (p *parser) parseMappingKey(ctx *context) (ast.Node, error) { +func (p *parser) parseMappingKey(ctx *context) (*ast.MappingKeyNode, error) { keyTk := ctx.currentToken() node := ast.MappingKey(keyTk) node.SetPath(ctx.path)