From c174c2a46860d47e2e9c3c2a5a4526a1292ee2ea Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Wed, 17 Nov 2021 10:41:07 +0100 Subject: [PATCH] format: don't linebreak when there are generated vars in func call (#4019) Fixes #4018. Signed-off-by: Stephan Renatus --- format/format.go | 14 ++++++++++++++ format/format_test.go | 16 ++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/format/format.go b/format/format.go index 180c1f1250..48d2318327 100644 --- a/format/format.go +++ b/format/format.go @@ -853,6 +853,20 @@ func (w *writer) listWriter() entryWriter { // groupIterable will group the `elements` slice into slices according to their // location: anything on the same line will be put into a slice. func groupIterable(elements []interface{}, last *ast.Location) [][]interface{} { + // Generated vars occur in the AST when we're rendering the result of + // partial evaluation in a bundle build with optimization. For those vars, + // there is no location, and the grouping based on source location will + // yield a bad result. So if there's a generated variable among elements, + // we'll render the elements all in one line. + vis := ast.NewVarVisitor() + for _, elem := range elements { + vis.Walk(elem) + } + for v := range vis.Vars() { + if v.IsGenerated() { + return [][]interface{}{elements} + } + } sort.Slice(elements, func(i, j int) bool { return locLess(elements[i], elements[j]) }) diff --git a/format/format_test.go b/format/format_test.go index 68b47fe654..9efb307367 100644 --- a/format/format_test.go +++ b/format/format_test.go @@ -13,6 +13,7 @@ import ( "testing" "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/ast/location" ) func TestFormatNilLocation(t *testing.T) { @@ -42,6 +43,21 @@ func TestFormatNilLocationEmptyBody(t *testing.T) { } } +func TestFormatNilLocationFunctionArgs(t *testing.T) { + b := ast.NewBody() + s := ast.StringTerm(" ") + s.SetLocation(location.NewLocation([]byte("foo"), "p.rego", 2, 2)) + b.Append(ast.Split.Expr(ast.NewTerm(ast.Var("__local1__")), s, ast.NewTerm(ast.Var("__local2__")))) + exp := "split(__local1__, \" \", __local2__)\n" + bs, err := Ast(b) + if err != nil { + t.Fatal(err) + } + if string(bs) != exp { + t.Fatalf("Expected %q but got %q", exp, string(bs)) + } +} + func TestFormatSourceError(t *testing.T) { rego := "testfiles/test.rego.error" contents, err := ioutil.ReadFile(rego)