Skip to content

Commit

Permalink
Decoder: disallow modification of existing table (#704)
Browse files Browse the repository at this point in the history
Fixes #703
  • Loading branch information
pelletier committed Dec 15, 2021
1 parent facb2b1 commit 696dd25
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 95 deletions.
49 changes: 24 additions & 25 deletions internal/ast/ast.go
Expand Up @@ -20,8 +20,8 @@ type Iterator struct {
node *Node
}

// Next moves the iterator forward and returns true if points to a node, false
// otherwise.
// Next moves the iterator forward and returns true if points to a
// node, false otherwise.
func (c *Iterator) Next() bool {
if !c.started {
c.started = true
Expand All @@ -31,8 +31,8 @@ func (c *Iterator) Next() bool {
return c.node.Valid()
}

// IsLast returns true if the current node of the iterator is the last one.
// Subsequent call to Next() will return false.
// IsLast returns true if the current node of the iterator is the last
// one. Subsequent call to Next() will return false.
func (c *Iterator) IsLast() bool {
return c.node.next == 0
}
Expand Down Expand Up @@ -62,20 +62,20 @@ func (r *Root) at(idx Reference) *Node {
return &r.nodes[idx]
}

// Arrays have one child per element in the array.
// InlineTables have one child per key-value pair in the table.
// KeyValues have at least two children. The first one is the value. The
// rest make a potentially dotted key.
// Table and Array table have one child per element of the key they
// represent (same as KeyValue, but without the last node being the value).
// children []Node
// Arrays have one child per element in the array. InlineTables have
// one child per key-value pair in the table. KeyValues have at least
// two children. The first one is the value. The rest make a
// potentially dotted key. Table and Array table have one child per
// element of the key they represent (same as KeyValue, but without
// the last node being the value).
type Node struct {
Kind Kind
Raw Range // Raw bytes from the input.
Data []byte // Node value (could be either allocated or referencing the input).
Data []byte // Node value (either allocated or referencing the input).

// References to other nodes, as offsets in the backing array from this
// node. References can go backward, so those can be negative.
// References to other nodes, as offsets in the backing array
// from this node. References can go backward, so those can be
// negative.
next int // 0 if last element
child int // 0 if no child
}
Expand All @@ -85,8 +85,8 @@ type Range struct {
Length uint32
}

// Next returns a copy of the next node, or an invalid Node if there is no
// next node.
// Next returns a copy of the next node, or an invalid Node if there
// is no next node.
func (n *Node) Next() *Node {
if n.next == 0 {
return nil
Expand All @@ -96,9 +96,9 @@ func (n *Node) Next() *Node {
return (*Node)(danger.Stride(ptr, size, n.next))
}

// Child returns a copy of the first child node of this node. Other children
// can be accessed calling Next on the first child.
// Returns an invalid Node if there is none.
// Child returns a copy of the first child node of this node. Other
// children can be accessed calling Next on the first child. Returns
// an invalid Node if there is none.
func (n *Node) Child() *Node {
if n.child == 0 {
return nil
Expand All @@ -113,10 +113,9 @@ func (n *Node) Valid() bool {
return n != nil
}

// Key returns the child nodes making the Key on a supported node. Panics
// otherwise.
// They are guaranteed to be all be of the Kind Key. A simple key would return
// just one element.
// Key returns the child nodes making the Key on a supported
// node. Panics otherwise. They are guaranteed to be all be of the
// Kind Key. A simple key would return just one element.
func (n *Node) Key() Iterator {
switch n.Kind {
case KeyValue:
Expand All @@ -133,8 +132,8 @@ func (n *Node) Key() Iterator {
}

// Value returns a pointer to the value node of a KeyValue.
// Guaranteed to be non-nil.
// Panics if not called on a KeyValue node, or if the Children are malformed.
// Guaranteed to be non-nil. Panics if not called on a KeyValue node,
// or if the Children are malformed.
func (n *Node) Value() *Node {
return n.Child()
}
Expand Down
166 changes: 97 additions & 69 deletions internal/tracker/seen.go
Expand Up @@ -3,6 +3,7 @@ package tracker
import (
"bytes"
"fmt"
"sync"

"github.com/pelletier/go-toml/v2/internal/ast"
)
Expand Down Expand Up @@ -54,69 +55,103 @@ func (k keyKind) String() string {
type SeenTracker struct {
entries []entry
currentIdx int
nextID int
}

var pool sync.Pool

func (s *SeenTracker) reset() {
// Always contains a root element at index 0.
s.currentIdx = 0
if len(s.entries) == 0 {
s.entries = make([]entry, 1, 2)
} else {
s.entries = s.entries[:1]
}
s.entries[0].child = -1
s.entries[0].next = -1
}

type entry struct {
id int
parent int
// Use -1 to indicate no child or no sibling.
child int
next int

name []byte
kind keyKind
explicit bool
}

// Find the index of the child of parentIdx with key k. Returns -1 if
// it does not exist.
func (s *SeenTracker) find(parentIdx int, k []byte) int {
for i := s.entries[parentIdx].child; i >= 0; i = s.entries[i].next {
if bytes.Equal(s.entries[i].name, k) {
return i
}
}
return -1
}

// Remove all descendants of node at position idx.
func (s *SeenTracker) clear(idx int) {
p := s.entries[idx].id
rest := clear(p, s.entries[idx+1:])
s.entries = s.entries[:idx+1+len(rest)]
}
if idx >= len(s.entries) {
return
}

func clear(parentID int, entries []entry) []entry {
for i := 0; i < len(entries); {
if entries[i].parent == parentID {
id := entries[i].id
copy(entries[i:], entries[i+1:])
entries = entries[:len(entries)-1]
rest := clear(id, entries[i:])
entries = entries[:i+len(rest)]
} else {
i++
}
for i := s.entries[idx].child; i >= 0; {
next := s.entries[i].next
n := s.entries[0].next
s.entries[0].next = i
s.entries[i].next = n
s.entries[i].name = nil
s.clear(i)
i = next
}
return entries

s.entries[idx].child = -1
}

func (s *SeenTracker) create(parentIdx int, name []byte, kind keyKind, explicit bool) int {
parentID := s.id(parentIdx)
e := entry{
child: -1,
next: s.entries[parentIdx].child,

idx := len(s.entries)
s.entries = append(s.entries, entry{
id: s.nextID,
parent: parentID,
name: name,
kind: kind,
explicit: explicit,
})
s.nextID++
}
var idx int
if s.entries[0].next >= 0 {
idx = s.entries[0].next
s.entries[0].next = s.entries[idx].next
s.entries[idx] = e
} else {
idx = len(s.entries)
s.entries = append(s.entries, e)
}

s.entries[parentIdx].child = idx

return idx
}

func (s *SeenTracker) setExplicitFlag(parentIdx int) {
for i := s.entries[parentIdx].child; i >= 0; i = s.entries[i].next {
s.entries[i].explicit = true
s.setExplicitFlag(i)
}
}

// CheckExpression takes a top-level node and checks that it does not contain
// keys that have been seen in previous calls, and validates that types are
// consistent.
func (s *SeenTracker) CheckExpression(node *ast.Node) error {
if s.entries == nil {
// Skip ID = 0 to remove the confusion between nodes whose
// parent has id 0 and root nodes (parent id is 0 because it's
// the zero value).
s.nextID = 1
// Start unscoped, so idx is negative.
s.currentIdx = -1
s.reset()
}
switch node.Kind {
case ast.KeyValue:
return s.checkKeyValue(s.currentIdx, node)
return s.checkKeyValue(node)
case ast.Table:
return s.checkTable(node)
case ast.ArrayTable:
Expand All @@ -127,9 +162,13 @@ func (s *SeenTracker) CheckExpression(node *ast.Node) error {
}

func (s *SeenTracker) checkTable(node *ast.Node) error {
if s.currentIdx >= 0 {
s.setExplicitFlag(s.currentIdx)
}

it := node.Key()

parentIdx := -1
parentIdx := 0

// This code is duplicated in checkArrayTable. This is because factoring
// it in a function requires to copy the iterator, or allocate it to the
Expand Down Expand Up @@ -176,9 +215,13 @@ func (s *SeenTracker) checkTable(node *ast.Node) error {
}

func (s *SeenTracker) checkArrayTable(node *ast.Node) error {
if s.currentIdx >= 0 {
s.setExplicitFlag(s.currentIdx)
}

it := node.Key()

parentIdx := -1
parentIdx := 0

for it.Next() {
if it.IsLast() {
Expand Down Expand Up @@ -219,7 +262,8 @@ func (s *SeenTracker) checkArrayTable(node *ast.Node) error {
return nil
}

func (s *SeenTracker) checkKeyValue(parentIdx int, node *ast.Node) error {
func (s *SeenTracker) checkKeyValue(node *ast.Node) error {
parentIdx := s.currentIdx
it := node.Key()

for it.Next() {
Expand Down Expand Up @@ -249,45 +293,48 @@ func (s *SeenTracker) checkKeyValue(parentIdx int, node *ast.Node) error {

switch value.Kind {
case ast.InlineTable:
return s.checkInlineTable(parentIdx, value)
return s.checkInlineTable(value)
case ast.Array:
return s.checkArray(parentIdx, value)
return s.checkArray(value)
}

return nil
}

func (s *SeenTracker) checkArray(parentIdx int, node *ast.Node) error {
set := false
func (s *SeenTracker) checkArray(node *ast.Node) error {
it := node.Children()
for it.Next() {
if set {
s.clear(parentIdx)
}
n := it.Node()
switch n.Kind {
case ast.InlineTable:
err := s.checkInlineTable(parentIdx, n)
err := s.checkInlineTable(n)
if err != nil {
return err
}
set = true
case ast.Array:
err := s.checkArray(parentIdx, n)
err := s.checkArray(n)
if err != nil {
return err
}
set = true
}
}
return nil
}

func (s *SeenTracker) checkInlineTable(parentIdx int, node *ast.Node) error {
func (s *SeenTracker) checkInlineTable(node *ast.Node) error {
if pool.New == nil {
pool.New = func() interface{} {
return &SeenTracker{}
}
}

s = pool.Get().(*SeenTracker)
s.reset()

it := node.Children()
for it.Next() {
n := it.Node()
err := s.checkKeyValue(parentIdx, n)
err := s.checkKeyValue(n)
if err != nil {
return err
}
Expand All @@ -299,25 +346,6 @@ func (s *SeenTracker) checkInlineTable(parentIdx int, node *ast.Node) error {
// mark the presence of the inline table and prevent
// redefinition of its keys: check* functions cannot walk into
// a value.
s.clear(parentIdx)
pool.Put(s)
return nil
}

func (s *SeenTracker) id(idx int) int {
if idx >= 0 {
return s.entries[idx].id
}
return 0
}

func (s *SeenTracker) find(parentIdx int, k []byte) int {
parentID := s.id(parentIdx)

for i := parentIdx + 1; i < len(s.entries); i++ {
if s.entries[i].parent == parentID && bytes.Equal(s.entries[i].name, k) {
return i
}
}

return -1
}
8 changes: 7 additions & 1 deletion unmarshaler_test.go
Expand Up @@ -2129,7 +2129,7 @@ xz_hash = "1a48f723fea1f17d786ce6eadd9d00914d38062d28fd9c455ed3c3801905b388"

expected := doc{
Pkg: map[string]pkg{
"cargo": pkg{
"cargo": {
Target: map[string]target{
"aarch64-apple-darwin": {
XZ_URL: "https://static.rust-lang.org/dist/2021-07-29/cargo-1.54.0-aarch64-apple-darwin.tar.xz",
Expand Down Expand Up @@ -2298,6 +2298,12 @@ z=0
}
}

func TestIssue703(t *testing.T) {
var v interface{}
err := toml.Unmarshal([]byte("[a]\nx.y=0\n[a.x]"), &v)
require.Error(t, err)
}

func TestUnmarshalDecodeErrors(t *testing.T) {
examples := []struct {
desc string
Expand Down

0 comments on commit 696dd25

Please sign in to comment.