Skip to content

Commit

Permalink
ast: Deprecating any() and all() built-in functions
Browse files Browse the repository at this point in the history
Updating compiler strict mode to produce error when these deprecated methods are used.

Fixes: #2437
Signed-off-by: Johan Fylling <johan.dev@fylling.se>
  • Loading branch information
johanfylling committed Jan 28, 2022
1 parent f7a084e commit a4dba1d
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 67 deletions.
76 changes: 42 additions & 34 deletions ast/builtins.go
Expand Up @@ -663,36 +663,6 @@ var Min = &Builtin{
),
}

// All takes a list and returns true if all of the items
// are true. A collection of length 0 returns true.
var All = &Builtin{
Name: "all",
Decl: types.NewFunction(
types.Args(
types.NewAny(
types.NewSet(types.A),
types.NewArray(nil, types.A),
),
),
types.B,
),
}

// Any takes a collection and returns true if any of the items
// is true. A collection of length 0 returns false.
var Any = &Builtin{
Name: "any",
Decl: types.NewFunction(
types.Args(
types.NewAny(
types.NewSet(types.A),
types.NewArray(nil, types.A),
),
),
types.B,
),
}

/**
* Arrays
*/
Expand Down Expand Up @@ -2516,13 +2486,51 @@ var RegexMatchDeprecated = &Builtin{
),
}

// All takes a list and returns true if all of the items
// are true. A collection of length 0 returns true.
var All = &Builtin{
Name: "all",
Decl: types.NewFunction(
types.Args(
types.NewAny(
types.NewSet(types.A),
types.NewArray(nil, types.A),
),
),
types.B,
),
deprecated: true,
}

// Any takes a collection and returns true if any of the items
// is true. A collection of length 0 returns false.
var Any = &Builtin{
Name: "any",
Decl: types.NewFunction(
types.Args(
types.NewAny(
types.NewSet(types.A),
types.NewArray(nil, types.A),
),
),
types.B,
),
deprecated: true,
}

// Builtin represents a built-in function supported by OPA. Every built-in
// function is uniquely identified by a name.
type Builtin struct {
Name string `json:"name"` // Unique name of built-in function, e.g., <name>(arg1,arg2,...,argN)
Decl *types.Function `json:"decl"` // Built-in function type declaration.
Infix string `json:"infix,omitempty"` // Unique name of infix operator. Default should be unset.
Relation bool `json:"relation,omitempty"` // Indicates if the built-in acts as a relation.
Name string `json:"name"` // Unique name of built-in function, e.g., <name>(arg1,arg2,...,argN)
Decl *types.Function `json:"decl"` // Built-in function type declaration.
Infix string `json:"infix,omitempty"` // Unique name of infix operator. Default should be unset.
Relation bool `json:"relation,omitempty"` // Indicates if the built-in acts as a relation.
deprecated bool // Indicates if the built-in has been deprecated.
}

// IsDeprecated returns true if the Builtin function is deprecated and will be removed in a future release.
func (b *Builtin) IsDeprecated() bool {
return b.deprecated
}

// Expr creates a new expression for the built-in with the given operands.
Expand Down
24 changes: 16 additions & 8 deletions ast/compile.go
Expand Up @@ -103,6 +103,7 @@ type Compiler struct {
builtins map[string]*Builtin // universe of built-in functions
customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use capabilities)
unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities)
deprecatedBuiltinsMap map[string]struct{} // set of deprecated, but not removed, built-in functions
enablePrintStatements bool // indicates if print statements should be elided (default)
comprehensionIndices map[*Term]*ComprehensionIndex // comprehension key index
initialized bool // indicates if init() has been called
Expand Down Expand Up @@ -240,11 +241,12 @@ func NewCompiler() *Compiler {
}, func(x util.T) int {
return x.(Ref).Hash()
}),
maxErrs: CompileErrorLimitDefault,
after: map[string][]CompilerStageDefinition{},
unsafeBuiltinsMap: map[string]struct{}{},
comprehensionIndices: map[*Term]*ComprehensionIndex{},
debug: debug.Discard(),
maxErrs: CompileErrorLimitDefault,
after: map[string][]CompilerStageDefinition{},
unsafeBuiltinsMap: map[string]struct{}{},
deprecatedBuiltinsMap: map[string]struct{}{},
comprehensionIndices: map[*Term]*ComprehensionIndex{},
debug: debug.Discard(),
}

c.ModuleTree = NewModuleTree(nil)
Expand Down Expand Up @@ -1165,7 +1167,7 @@ func (c *Compiler) checkTypes() {

func (c *Compiler) checkUnsafeBuiltins() {
for _, name := range c.sorted {
errs := checkUnsafeBuiltins(c.unsafeBuiltinsMap, c.Modules[name])
errs := checkUnsafeBuiltins(c.unsafeBuiltinsMap, c.deprecatedBuiltinsMap, c.Modules[name])
for _, err := range errs {
c.err(err)
}
Expand Down Expand Up @@ -1224,6 +1226,9 @@ func (c *Compiler) init() {

for _, bi := range c.capabilities.Builtins {
c.builtins[bi.Name] = bi
if c.strict && bi.IsDeprecated() {
c.deprecatedBuiltinsMap[bi.Name] = struct{}{}
}
}

for name, bi := range c.customBuiltins {
Expand Down Expand Up @@ -2126,7 +2131,7 @@ func (qc *queryCompiler) checkUnsafeBuiltins(_ *QueryContext, body Body) (Body,
} else {
unsafe = qc.compiler.unsafeBuiltinsMap
}
errs := checkUnsafeBuiltins(unsafe, body)
errs := checkUnsafeBuiltins(unsafe, qc.compiler.deprecatedBuiltinsMap, body)
if len(errs) > 0 {
return nil, errs
}
Expand Down Expand Up @@ -4485,14 +4490,17 @@ func safetyErrorSlice(unsafe unsafeVars, rewritten map[Var]Var) (result Errors)
return
}

func checkUnsafeBuiltins(unsafeBuiltinsMap map[string]struct{}, node interface{}) Errors {
func checkUnsafeBuiltins(unsafeBuiltinsMap map[string]struct{}, deprecatedBuiltinsMap map[string]struct{}, node interface{}) Errors {
errs := make(Errors, 0)
WalkExprs(node, func(x *Expr) bool {
if x.IsCall() {
operator := x.Operator().String()
if _, ok := unsafeBuiltinsMap[operator]; ok {
errs = append(errs, NewError(TypeErr, x.Loc(), "unsafe built-in function calls in expression: %v", operator))
}
if _, ok := deprecatedBuiltinsMap[operator]; ok {
errs = append(errs, NewError(TypeErr, x.Loc(), "deprecated built-in function calls in expression: %v", operator))
}
}
return false
})
Expand Down
93 changes: 69 additions & 24 deletions ast/compile_test.go
Expand Up @@ -1464,12 +1464,7 @@ func TestCompilerRewriteExprTerms(t *testing.T) {
}

func TestCompilerCheckDuplicateImports(t *testing.T) {
cases := []struct {
note string
module string
expectedErrors Errors
strict bool
}{
cases := []strictnessTestCase{
{
note: "shadow",
module: `package test
Expand All @@ -1488,7 +1483,6 @@ func TestCompilerCheckDuplicateImports(t *testing.T) {
Message: "import must not shadow import input.foo",
},
},
strict: true,
}, {
note: "alias shadow",
module: `package test
Expand All @@ -1502,35 +1496,86 @@ func TestCompilerCheckDuplicateImports(t *testing.T) {
Message: "import must not shadow import input.foo",
},
},
strict: true,
}, {
note: "no strict",
},
}

runStrictnessTestCase(t, cases, true)
}

func TestCompilerCheckDeprecatedMethods(t *testing.T) {
cases := []strictnessTestCase{
{
note: "all() built-in",
module: `package test
import input.noconflict
import input.foo
import data.foo
import data.bar.foo
import input.bar as foo
p := all([true, false])
`,
expectedErrors: Errors{
&Error{
Location: NewLocation([]byte("all([true, false])"), "", 2, 10),
Message: "deprecated built-in function calls in expression: all",
},
},
},
{
note: "user-defined all()",
module: `package test
import future.keywords.in
all(arr) = {x | some x in arr} == {true}
p := all([true, false])
`,
},
{
note: "any() built-in",
module: `package test
p := any([true, false])
`,
expectedErrors: Errors{
&Error{
Location: NewLocation([]byte("any([true, false])"), "", 2, 10),
Message: "deprecated built-in function calls in expression: any",
},
},
},
{
note: "user-defined any()",
module: `package test
import future.keywords.in
any(arr) = true in arr
p := any([true, false])
`,
strict: false,
},
}

for _, tc := range cases {
t.Run(tc.note, func(t *testing.T) {
compiler := NewCompiler().WithStrict(tc.strict)
runStrictnessTestCase(t, cases, true)
}

type strictnessTestCase struct {
note string
module string
expectedErrors Errors
}

func runStrictnessTestCase(t *testing.T, cases []strictnessTestCase, assertLocation bool) {
t.Helper()
makeTestRunner := func(tc strictnessTestCase, strict bool) func(t *testing.T) {
return func(t *testing.T) {
compiler := NewCompiler().WithStrict(strict)
compiler.Modules = map[string]*Module{
"test": MustParseModule(tc.module),
}
compileStages(compiler, nil)

compileStages(compiler, compiler.checkDuplicateImports)

if len(tc.expectedErrors) > 0 {
assertErrors(t, compiler.Errors, tc.expectedErrors, true)
if strict {
assertErrors(t, compiler.Errors, tc.expectedErrors, false)
} else {
assertNotFailed(t, compiler)
}
})
}
}

for _, tc := range cases {
t.Run(tc.note+"_strict", makeTestRunner(tc, true))
t.Run(tc.note+"_non-strict", makeTestRunner(tc, false))
}
}

Expand Down
3 changes: 2 additions & 1 deletion docs/content/strict.md
Expand Up @@ -19,4 +19,5 @@ Compiler Strict mode is supported by the `check` command, and can be enabled thr
Name | Description | Enforced by default in OPA version
--- | --- | ---
Duplicate imports | Duplicate [imports](../policy-language/#imports), where one import shadows another, are prohibited. | 1.0
Unused local assignments | Unused [assignments](../policy-reference/#assignment-and-equality) local to a rule, function or comprehension are prohibited | 1.0
Unused local assignments | Unused [assignments](../policy-reference/#assignment-and-equality) local to a rule, function or comprehension are prohibited | 1.0
`any()` and `all()` removed | The `any()` and `all()` built-in functions have been deprecated, and will be removed in OPA 1.0. Use of these functions is prohibited. | 1.0

0 comments on commit a4dba1d

Please sign in to comment.