Skip to content

Commit

Permalink
ast/compile: check arity in undefined function stage
Browse files Browse the repository at this point in the history
Before, the "undefined function" check stage in the compiler (and query
compiler) only asserted that the function was known.

Now, we'll also check that the number of arguments _could be_ valid. If
it really is valid will be determined by the type checker at a later
stage.

However, asserting the arity early allows us to give more on-the-spot
error messages.

Fixes open-policy-agent#4054.

Signed-off-by: Stephan Renatus <stephan.renatus@gmail.com>
  • Loading branch information
srenatus committed Nov 30, 2021
1 parent 7f73ac1 commit 170f33f
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 19 deletions.
44 changes: 33 additions & 11 deletions ast/compile.go
Expand Up @@ -11,6 +11,7 @@ import (
"strconv"
"strings"

"github.com/open-policy-agent/opa/ast/location"
"github.com/open-policy-agent/opa/internal/debug"
"github.com/open-policy-agent/opa/internal/gojsonschema"
"github.com/open-policy-agent/opa/metrics"
Expand Down Expand Up @@ -850,7 +851,21 @@ func checkUndefinedFuncs(x interface{}, arity func(Ref) int, rwVars map[Var]Var)
return false
}
ref := expr.Operator()
if arity(ref) >= 0 {
if arity := arity(ref); arity > 0 {
operands := len(expr.Operands())
if expr.Generated { // an output var was added
if !expr.IsEquality() && operands != arity+1 {
ref = rewriteVarsInRef(rwVars)(ref)
errs = append(errs, arityMismatchError(ref, expr.Loc(), arity, operands-1))
return true
}
} else { // either output var or not
if operands != arity && operands != arity+1 {
ref = rewriteVarsInRef(rwVars)(ref)
errs = append(errs, arityMismatchError(ref, expr.Loc(), arity, operands))
return true
}
}
return false
}
ref = rewriteVarsInRef(rwVars)(ref)
Expand All @@ -861,6 +876,13 @@ func checkUndefinedFuncs(x interface{}, arity func(Ref) int, rwVars map[Var]Var)
return errs
}

func arityMismatchError(f Ref, loc *location.Location, exp, act int) *Error {
if act != 1 {
return NewError(TypeErr, loc, "function %v has arity %d, got %d arguments", f, exp, act)
}
return NewError(TypeErr, loc, "function %v has arity %d, got %d argument", f, exp, act)
}

// checkSafetyRuleBodies ensures that variables appearing in negated expressions or non-target
// positions of built-in expressions will be bound when evaluating the rule from left
// to right, re-ordering as necessary.
Expand Down Expand Up @@ -2040,7 +2062,7 @@ func (qc *queryCompiler) checkSafety(_ *QueryContext, body Body) (Body, error) {
return reordered, nil
}

func (qc *queryCompiler) checkTypes(qctx *QueryContext, body Body) (Body, error) {
func (qc *queryCompiler) checkTypes(_ *QueryContext, body Body) (Body, error) {
var errs Errors
checker := newTypeChecker().
WithSchemaSet(qc.compiler.schemaSet).
Expand All @@ -2054,7 +2076,7 @@ func (qc *queryCompiler) checkTypes(qctx *QueryContext, body Body) (Body, error)
return body, nil
}

func (qc *queryCompiler) checkUnsafeBuiltins(qctx *QueryContext, body Body) (Body, error) {
func (qc *queryCompiler) checkUnsafeBuiltins(_ *QueryContext, body Body) (Body, error) {
var unsafe map[string]struct{}
if qc.unsafeBuiltins != nil {
unsafe = qc.unsafeBuiltins
Expand All @@ -2068,7 +2090,7 @@ func (qc *queryCompiler) checkUnsafeBuiltins(qctx *QueryContext, body Body) (Bod
return body, nil
}

func (qc *queryCompiler) rewriteWithModifiers(qctx *QueryContext, body Body) (Body, error) {
func (qc *queryCompiler) rewriteWithModifiers(_ *QueryContext, body Body) (Body, error) {
f := newEqualityFactory(newLocalVarGenerator("q", body))
body, err := rewriteWithModifiersInBody(qc.compiler, f, body)
if err != nil {
Expand All @@ -2077,7 +2099,7 @@ func (qc *queryCompiler) rewriteWithModifiers(qctx *QueryContext, body Body) (Bo
return body, nil
}

func (qc *queryCompiler) buildComprehensionIndices(qctx *QueryContext, body Body) (Body, error) {
func (qc *queryCompiler) buildComprehensionIndices(_ *QueryContext, body Body) (Body, error) {
// NOTE(tsandall): The query compiler does not have a metrics object so we
// cannot record index metrics currently.
_ = buildComprehensionIndices(qc.compiler.debug, qc.compiler.GetArity, ReservedVars, qc.RewrittenVars(), body, qc.comprehensionIndices)
Expand Down Expand Up @@ -2918,10 +2940,10 @@ func OutputVarsFromBody(c *Compiler, body Body, safe VarSet) VarSet {
return outputVarsForBody(body, c.GetArity, safe)
}

func outputVarsForBody(body Body, getArity func(Ref) int, safe VarSet) VarSet {
func outputVarsForBody(body Body, arity func(Ref) int, safe VarSet) VarSet {
o := safe.Copy()
for _, e := range body {
o.Update(outputVarsForExpr(e, getArity, o))
o.Update(outputVarsForExpr(e, arity, o))
}
return o.Diff(safe)
}
Expand All @@ -2933,7 +2955,7 @@ func OutputVarsFromExpr(c *Compiler, expr *Expr, safe VarSet) VarSet {
return outputVarsForExpr(expr, c.GetArity, safe)
}

func outputVarsForExpr(expr *Expr, getArity func(Ref) int, safe VarSet) VarSet {
func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet) VarSet {

// Negated expressions must be safe.
if expr.Negated {
Expand Down Expand Up @@ -2968,12 +2990,12 @@ func outputVarsForExpr(expr *Expr, getArity func(Ref) int, safe VarSet) VarSet {
return VarSet{}
}

arity := getArity(operator)
if arity < 0 {
ar := arity(operator)
if ar < 0 {
return VarSet{}
}

return outputVarsForExprCall(expr, arity, safe, terms)
return outputVarsForExprCall(expr, ar, safe, terms)
default:
panic("illegal expression")
}
Expand Down
16 changes: 8 additions & 8 deletions ast/compile_test.go
Expand Up @@ -2693,20 +2693,19 @@ func TestCompileInvalidEqAssignExpr(t *testing.T) {
p {
# Type checking runs at a later stage so these errors will not be #
# caught until then. The stages before type checking should be tolerant
# of invalid eq and assign calls.
# Arity mismatches are caught in the checkUndefinedFuncs check,
# and invalid eq/assign calls are passed along until then.
assign()
assign(1)
eq()
eq(1)
}`)

var prev func()
checkRecursion := reflect.ValueOf(c.checkRecursion)
checkUndefinedFuncs := reflect.ValueOf(c.checkUndefinedFuncs)

for _, stage := range c.stages {
if reflect.ValueOf(stage.f).Pointer() == checkRecursion.Pointer() {
if reflect.ValueOf(stage.f).Pointer() == checkUndefinedFuncs.Pointer() {
break
}
prev = stage.f
Expand Down Expand Up @@ -4385,12 +4384,12 @@ func TestQueryCompiler(t *testing.T) {
{
note: "invalid eq",
q: "eq()",
expected: fmt.Errorf("too few arguments"),
expected: fmt.Errorf("1 error occurred: 1:1: rego_type_error: function eq has arity 2, got 0 arguments"),
},
{
note: "invalid eq",
q: "eq(1)",
expected: fmt.Errorf("too few arguments"),
expected: fmt.Errorf("1 error occurred: 1:1: rego_type_error: function eq has arity 2, got 1 argument"),
},
{
note: "rewrite assignment",
Expand Down Expand Up @@ -4472,7 +4471,7 @@ func TestQueryCompiler(t *testing.T) {
q: "count(sum())",
pkg: "",
imports: nil,
expected: fmt.Errorf("1 error occurred: 1:1: rego_unsafe_var_error: expression is unsafe"),
expected: fmt.Errorf("1 error occurred: 1:7: rego_type_error: function sum has arity 1, got 0 arguments"),
},
{
note: "check types",
Expand Down Expand Up @@ -4645,6 +4644,7 @@ func assertCompilerErrorStrings(t *testing.T, compiler *Compiler, expected []str
}

func assertNotFailed(t *testing.T, c *Compiler) {
t.Helper()
if c.Failed() {
t.Fatalf("Unexpected compilation error: %v", c.Errors)
}
Expand Down

0 comments on commit 170f33f

Please sign in to comment.