Skip to content

Commit

Permalink
format: ensure future keyword import with in (#4115)
Browse files Browse the repository at this point in the history
When a module is formatted that has calls to `internal.member_2` or
`internal.member_3`, which get pretty-printed as infix `in` operator
calls, the formatter now ensures that the corresponding future keyword
import is present.

Fixes #4111.

Signed-off-by: Stephan Renatus <stephan.renatus@gmail.com>
  • Loading branch information
srenatus committed Dec 9, 2021
1 parent ab48bd7 commit 9795126
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 11 deletions.
67 changes: 56 additions & 11 deletions format/format.go
Expand Up @@ -51,6 +51,12 @@ func Ast(x interface{}) ([]byte, error) {

wildcards := map[ast.Var]*ast.Term{}

// NOTE(sr): When the formatter encounters a call to internal.member_2
// or internal.member_3, it will sugarize them into usage of the `in`
// operator. It has to ensure that the proper future keyword import is
// present.
extraFutureKeywordImports := map[string]bool{}

// Preprocess the AST. Set any required defaults and calculate
// values required for printing the formatted output.
ast.WalkNodes(x, func(x ast.Node) bool {
Expand All @@ -61,6 +67,13 @@ func Ast(x interface{}) ([]byte, error) {
}
case *ast.Term:
unmangleWildcardVar(wildcards, n)

case *ast.Expr:
if n.IsCall() &&
ast.Member.Ref().Equal(n.Operator()) ||
ast.MemberWithKey.Ref().Equal(n.Operator()) {
extraFutureKeywordImports["in"] = true
}
}
if x.Loc() == nil {
x.SetLoc(defaultLocation(x))
Expand All @@ -74,6 +87,9 @@ func Ast(x interface{}) ([]byte, error) {

switch x := x.(type) {
case *ast.Module:
for kw := range extraFutureKeywordImports {
x.Imports = ensureFutureKeywordImport(x.Imports, kw)
}
w.writeModule(x)
case *ast.Package:
w.writePackage(x, nil)
Expand Down Expand Up @@ -910,20 +926,30 @@ func mapImportsToComments(imports []*ast.Import, comments []*ast.Comment) (map[*
return m, leftovers
}

func groupImports(imports []*ast.Import) (groups [][]*ast.Import) {
if len(imports) == 0 {
func groupImports(imports []*ast.Import) [][]*ast.Import {
switch len(imports) { // shortcuts
case 0:
return nil
case 1:
return [][]*ast.Import{imports}
}
// there are >=2 imports to group

var groups [][]*ast.Import
group := []*ast.Import{imports[0]}

for _, i := range imports[1:] {
last := group[len(group)-1]

// nil-location imports have been sorted up to come first
if i.Loc() != nil && last.Loc() != nil && // first import with a location, or
i.Loc().Row-last.Loc().Row > 1 { // more than one row apart from previous import

last := imports[0]
var group []*ast.Import
for _, i := range imports {
if i.Loc().Row-last.Loc().Row > 1 {
// start a new group
groups = append(groups, group)
group = []*ast.Import{}
}
group = append(group, i)
last = i
}
if len(group) > 0 {
groups = append(groups, group)
Expand Down Expand Up @@ -982,21 +1008,29 @@ func locLess(a, b interface{}) bool {
func locCmp(a, b interface{}) int {
al := getLoc(a)
bl := getLoc(b)
switch {
case al == nil && bl == nil:
return 0
case al == nil:
return -1
case bl == nil:
return 1
}

if cmp := al.Row - bl.Row; cmp != 0 {
return cmp

}
return al.Col - bl.Col
}

func getLoc(x interface{}) *ast.Location {
switch x := x.(type) {
case ast.Statement:
// Implicitly matches *ast.Head, *ast.Expr, *ast.With, *ast.Term.
case ast.Node: // *ast.Head, *ast.Expr, *ast.With, *ast.Term
return x.Loc()
case *ast.Location:
return x
case [2]*ast.Term:
// Special case to allow for easy printing of objects.
case [2]*ast.Term: // Special case to allow for easy printing of objects.
return x[0].Location
default:
panic("Not reached")
Expand Down Expand Up @@ -1154,3 +1188,14 @@ func (w *writer) down() {
}
w.level--
}

func ensureFutureKeywordImport(imps []*ast.Import, kw string) []*ast.Import {
allKeywords := ast.MustParseTerm("future.keywords")
kwPath := ast.MustParseTerm("future.keywords." + kw)
for _, imp := range imps {
if allKeywords.Equal(imp.Path) || imp.Path.Equal(kwPath) {
return imps
}
}
return append(imps, &ast.Import{Path: kwPath})
}
@@ -0,0 +1,9 @@
package p

import input.foo
import future.keywords

r {
internal.member_2(1, [1])
internal.member_3(0, 1, [1])
}
@@ -0,0 +1,9 @@
package p

import future.keywords
import input.foo

r {
1 in [1]
0, 1 in [1]
}
8 changes: 8 additions & 0 deletions format/testfiles/test_in_operator_without_import.rego
@@ -0,0 +1,8 @@
package p

import input.foo

r {
internal.member_2(1, [1])
internal.member_3(0, 1, [1])
}
@@ -0,0 +1,9 @@
package p

import future.keywords.in
import input.foo

r {
1 in [1]
0, 1 in [1]
}

0 comments on commit 9795126

Please sign in to comment.