Skip to content

Commit

Permalink
format: don't linebreak when there are generated vars in func call (o…
Browse files Browse the repository at this point in the history
…pen-policy-agent#4019)

Fixes open-policy-agent#4018.

Signed-off-by: Stephan Renatus <stephan.renatus@gmail.com>
  • Loading branch information
srenatus committed Nov 17, 2021
1 parent 3d28e67 commit c174c2a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
14 changes: 14 additions & 0 deletions format/format.go
Expand Up @@ -853,6 +853,20 @@ func (w *writer) listWriter() entryWriter {
// groupIterable will group the `elements` slice into slices according to their
// location: anything on the same line will be put into a slice.
func groupIterable(elements []interface{}, last *ast.Location) [][]interface{} {
// Generated vars occur in the AST when we're rendering the result of
// partial evaluation in a bundle build with optimization. For those vars,
// there is no location, and the grouping based on source location will
// yield a bad result. So if there's a generated variable among elements,
// we'll render the elements all in one line.
vis := ast.NewVarVisitor()
for _, elem := range elements {
vis.Walk(elem)
}
for v := range vis.Vars() {
if v.IsGenerated() {
return [][]interface{}{elements}
}
}
sort.Slice(elements, func(i, j int) bool {
return locLess(elements[i], elements[j])
})
Expand Down
16 changes: 16 additions & 0 deletions format/format_test.go
Expand Up @@ -13,6 +13,7 @@ import (
"testing"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/ast/location"
)

func TestFormatNilLocation(t *testing.T) {
Expand Down Expand Up @@ -42,6 +43,21 @@ func TestFormatNilLocationEmptyBody(t *testing.T) {
}
}

func TestFormatNilLocationFunctionArgs(t *testing.T) {
b := ast.NewBody()
s := ast.StringTerm(" ")
s.SetLocation(location.NewLocation([]byte("foo"), "p.rego", 2, 2))
b.Append(ast.Split.Expr(ast.NewTerm(ast.Var("__local1__")), s, ast.NewTerm(ast.Var("__local2__"))))
exp := "split(__local1__, \" \", __local2__)\n"
bs, err := Ast(b)
if err != nil {
t.Fatal(err)
}
if string(bs) != exp {
t.Fatalf("Expected %q but got %q", exp, string(bs))
}
}

func TestFormatSourceError(t *testing.T) {
rego := "testfiles/test.rego.error"
contents, err := ioutil.ReadFile(rego)
Expand Down

0 comments on commit c174c2a

Please sign in to comment.