Skip to content

Commit

Permalink
ast/compile: check arity in undefined function stage (#4059)
Browse files Browse the repository at this point in the history
* ast/compile: check arity in undefined function stage

Before, the "undefined function" check stage in the compiler (and query
compiler) only asserted that the function was known.

Now, we'll also check that the number of arguments _could be_ valid.
Whether it really is valid will be determined by the type checker at a
later stage.

However, asserting the arity early allows us to give more on-the-spot
error messages.

Fixes #4054.

Signed-off-by: Stephan Renatus <stephan.renatus@gmail.com>
  • Loading branch information
srenatus committed Dec 1, 2021
1 parent 234a37b commit c97f58b
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 28 deletions.
56 changes: 42 additions & 14 deletions ast/compile.go
Expand Up @@ -835,13 +835,13 @@ func (c *Compiler) checkRuleConflicts() {
func (c *Compiler) checkUndefinedFuncs() {
for _, name := range c.sorted {
m := c.Modules[name]
for _, err := range checkUndefinedFuncs(m, c.GetArity, c.RewrittenVars) {
for _, err := range checkUndefinedFuncs(c.TypeEnv, m, c.GetArity, c.RewrittenVars) {
c.err(err)
}
}
}

func checkUndefinedFuncs(x interface{}, arity func(Ref) int, rwVars map[Var]Var) Errors {
func checkUndefinedFuncs(env *TypeEnv, x interface{}, arity func(Ref) int, rwVars map[Var]Var) Errors {

var errs Errors

Expand All @@ -850,7 +850,21 @@ func checkUndefinedFuncs(x interface{}, arity func(Ref) int, rwVars map[Var]Var)
return false
}
ref := expr.Operator()
if arity(ref) >= 0 {
if arity := arity(ref); arity >= 0 {
operands := len(expr.Operands())
if expr.Generated { // an output var was added
if !expr.IsEquality() && operands != arity+1 {
ref = rewriteVarsInRef(rwVars)(ref)
errs = append(errs, arityMismatchError(env, ref, expr, arity, operands-1))
return true
}
} else { // either output var or not
if operands != arity && operands != arity+1 {
ref = rewriteVarsInRef(rwVars)(ref)
errs = append(errs, arityMismatchError(env, ref, expr, arity, operands))
return true
}
}
return false
}
ref = rewriteVarsInRef(rwVars)(ref)
Expand All @@ -861,6 +875,20 @@ func checkUndefinedFuncs(x interface{}, arity func(Ref) int, rwVars map[Var]Var)
return errs
}

// arityMismatchError constructs the type error reported when the call in expr
// invokes function f with a number of arguments that does not fit its arity.
//
// If env resolves f to a *types.Function (the original inline note describes
// this as the built-in function case), the result is an argument error that
// lists the type env reports for every operand of expr next to the declared
// argument types of that function. Otherwise a plain message is produced that
// states the expected arity exp and the actual argument count act.
func arityMismatchError(env *TypeEnv, f Ref, expr *Expr, exp, act int) *Error {
	fn, isFunc := env.Get(f).(*types.Function)
	if !isFunc {
		// Use the singular "argument" only when exactly one argument was passed.
		if act == 1 {
			return NewError(TypeErr, expr.Loc(), "function %v has arity %d, got %d argument", f, exp, act)
		}
		return NewError(TypeErr, expr.Loc(), "function %v has arity %d, got %d arguments", f, exp, act)
	}

	// Collect the type of each operand as seen by env, in call order.
	operands := expr.Operands()
	have := make([]types.Type, 0, len(operands))
	for _, operand := range operands {
		have = append(have, env.Get(operand))
	}
	return newArgError(expr.Loc(), f, "arity mismatch", have, fn.FuncArgs())
}

// checkSafetyRuleBodies ensures that variables appearing in negated expressions or non-target
// positions of built-in expressions will be bound when evaluating the rule from left
// to right, re-ordering as necessary.
Expand Down Expand Up @@ -2025,7 +2053,7 @@ func (qc *queryCompiler) checkVoidCalls(_ *QueryContext, body Body) (Body, error
}

func (qc *queryCompiler) checkUndefinedFuncs(_ *QueryContext, body Body) (Body, error) {
if errs := checkUndefinedFuncs(body, qc.compiler.GetArity, qc.rewritten); len(errs) > 0 {
if errs := checkUndefinedFuncs(qc.compiler.TypeEnv, body, qc.compiler.GetArity, qc.rewritten); len(errs) > 0 {
return nil, errs
}
return body, nil
Expand All @@ -2040,7 +2068,7 @@ func (qc *queryCompiler) checkSafety(_ *QueryContext, body Body) (Body, error) {
return reordered, nil
}

func (qc *queryCompiler) checkTypes(qctx *QueryContext, body Body) (Body, error) {
func (qc *queryCompiler) checkTypes(_ *QueryContext, body Body) (Body, error) {
var errs Errors
checker := newTypeChecker().
WithSchemaSet(qc.compiler.schemaSet).
Expand All @@ -2054,7 +2082,7 @@ func (qc *queryCompiler) checkTypes(qctx *QueryContext, body Body) (Body, error)
return body, nil
}

func (qc *queryCompiler) checkUnsafeBuiltins(qctx *QueryContext, body Body) (Body, error) {
func (qc *queryCompiler) checkUnsafeBuiltins(_ *QueryContext, body Body) (Body, error) {
var unsafe map[string]struct{}
if qc.unsafeBuiltins != nil {
unsafe = qc.unsafeBuiltins
Expand All @@ -2068,7 +2096,7 @@ func (qc *queryCompiler) checkUnsafeBuiltins(qctx *QueryContext, body Body) (Bod
return body, nil
}

func (qc *queryCompiler) rewriteWithModifiers(qctx *QueryContext, body Body) (Body, error) {
func (qc *queryCompiler) rewriteWithModifiers(_ *QueryContext, body Body) (Body, error) {
f := newEqualityFactory(newLocalVarGenerator("q", body))
body, err := rewriteWithModifiersInBody(qc.compiler, f, body)
if err != nil {
Expand All @@ -2077,7 +2105,7 @@ func (qc *queryCompiler) rewriteWithModifiers(qctx *QueryContext, body Body) (Bo
return body, nil
}

func (qc *queryCompiler) buildComprehensionIndices(qctx *QueryContext, body Body) (Body, error) {
func (qc *queryCompiler) buildComprehensionIndices(_ *QueryContext, body Body) (Body, error) {
// NOTE(tsandall): The query compiler does not have a metrics object so we
// cannot record index metrics currently.
_ = buildComprehensionIndices(qc.compiler.debug, qc.compiler.GetArity, ReservedVars, qc.RewrittenVars(), body, qc.comprehensionIndices)
Expand Down Expand Up @@ -2918,10 +2946,10 @@ func OutputVarsFromBody(c *Compiler, body Body, safe VarSet) VarSet {
return outputVarsForBody(body, c.GetArity, safe)
}

func outputVarsForBody(body Body, getArity func(Ref) int, safe VarSet) VarSet {
func outputVarsForBody(body Body, arity func(Ref) int, safe VarSet) VarSet {
o := safe.Copy()
for _, e := range body {
o.Update(outputVarsForExpr(e, getArity, o))
o.Update(outputVarsForExpr(e, arity, o))
}
return o.Diff(safe)
}
Expand All @@ -2933,7 +2961,7 @@ func OutputVarsFromExpr(c *Compiler, expr *Expr, safe VarSet) VarSet {
return outputVarsForExpr(expr, c.GetArity, safe)
}

func outputVarsForExpr(expr *Expr, getArity func(Ref) int, safe VarSet) VarSet {
func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet) VarSet {

// Negated expressions must be safe.
if expr.Negated {
Expand Down Expand Up @@ -2968,12 +2996,12 @@ func outputVarsForExpr(expr *Expr, getArity func(Ref) int, safe VarSet) VarSet {
return VarSet{}
}

arity := getArity(operator)
if arity < 0 {
ar := arity(operator)
if ar < 0 {
return VarSet{}
}

return outputVarsForExprCall(expr, arity, safe, terms)
return outputVarsForExprCall(expr, ar, safe, terms)
default:
panic("illegal expression")
}
Expand Down
56 changes: 43 additions & 13 deletions ast/compile_test.go
Expand Up @@ -1118,6 +1118,18 @@ func TestCompilerCheckUndefinedFuncs(t *testing.T) {
undefined_dynamic_dispatch_declared_var_in_array {
z := "f"; data.test2[[z]](1)
}
arity_mismatch_1 {
data.test2.f(1,2,3)
}
arity_mismatch_2 {
data.test2.f()
}
arity_mismatch_3 {
x:= data.test2.f()
}
`

module2 := `
Expand All @@ -1141,6 +1153,9 @@ func TestCompilerCheckUndefinedFuncs(t *testing.T) {
"rego_type_error: undefined function data.test2[x]",
"rego_type_error: undefined function data.test2[y]",
"rego_type_error: undefined function data.test2[[z]]",
"rego_type_error: function data.test2.f has arity 1, got 3 arguments",
"test.rego:31: rego_type_error: function data.test2.f has arity 1, got 0 arguments",
"test.rego:35: rego_type_error: function data.test2.f has arity 1, got 0 arguments",
}
for _, w := range want {
if !strings.Contains(result, w) {
Expand Down Expand Up @@ -2693,20 +2708,19 @@ func TestCompileInvalidEqAssignExpr(t *testing.T) {
p {
# Type checking runs at a later stage so these errors will not be #
# caught until then. The stages before type checking should be tolerant
# of invalid eq and assign calls.
# Arity mismatches are caught in the checkUndefinedFuncs check,
# and invalid eq/assign calls are passed along until then.
assign()
assign(1)
eq()
eq(1)
}`)

var prev func()
checkRecursion := reflect.ValueOf(c.checkRecursion)
checkUndefinedFuncs := reflect.ValueOf(c.checkUndefinedFuncs)

for _, stage := range c.stages {
if reflect.ValueOf(stage.f).Pointer() == checkRecursion.Pointer() {
if reflect.ValueOf(stage.f).Pointer() == checkUndefinedFuncs.Pointer() {
break
}
prev = stage.f
Expand Down Expand Up @@ -4385,12 +4399,12 @@ func TestQueryCompiler(t *testing.T) {
{
note: "invalid eq",
q: "eq()",
expected: fmt.Errorf("too few arguments"),
expected: fmt.Errorf("1 error occurred: 1:1: rego_type_error: eq: arity mismatch\n\thave: ()\n\twant: (any, any)"),
},
{
note: "invalid eq",
q: "eq(1)",
expected: fmt.Errorf("too few arguments"),
expected: fmt.Errorf("1 error occurred: 1:1: rego_type_error: eq: arity mismatch\n\thave: (number)\n\twant: (any, any)"),
},
{
note: "rewrite assignment",
Expand Down Expand Up @@ -4468,11 +4482,25 @@ func TestQueryCompiler(t *testing.T) {
expected: `__localq1__ = data.a.b.c.z; __localq0__ = [__localq1__]; 1 with input as __localq0__`,
},
{
note: "unsafe exprs",
note: "built-in function arity mismatch",
q: `startswith("x")`,
pkg: "",
imports: nil,
expected: fmt.Errorf("1 error occurred: 1:1: rego_type_error: startswith: arity mismatch\n\thave: (string)\n\twant: (string, string)"),
},
{
note: "built-in function arity mismatch (arity 0)",
q: `x := opa.runtime("foo")`,
pkg: "",
imports: nil,
expected: fmt.Errorf("1 error occurred: 1:6: rego_type_error: opa.runtime: arity mismatch\n\thave: (string, ???)\n\twant: ()"),
},
{
note: "built-in function arity mismatch, nested",
q: "count(sum())",
pkg: "",
imports: nil,
expected: fmt.Errorf("1 error occurred: 1:1: rego_unsafe_var_error: expression is unsafe"),
expected: fmt.Errorf("1 error occurred: 1:7: rego_type_error: sum: arity mismatch\n\thave: (???)\n\twant: (any<array[number], set[number]>)"),
},
{
note: "check types",
Expand Down Expand Up @@ -4505,7 +4533,7 @@ func TestQueryCompiler(t *testing.T) {
},
}
for _, tc := range tests {
runQueryCompilerTest(t, tc.note, tc.q, tc.pkg, tc.imports, tc.expected)
t.Run(tc.note, runQueryCompilerTest(tc.q, tc.pkg, tc.imports, tc.expected))
}
}

Expand Down Expand Up @@ -4645,6 +4673,7 @@ func assertCompilerErrorStrings(t *testing.T, compiler *Compiler, expected []str
}

func assertNotFailed(t *testing.T, c *Compiler) {
t.Helper()
if c.Failed() {
t.Fatalf("Unexpected compilation error: %v", c.Errors)
}
Expand Down Expand Up @@ -4800,8 +4829,9 @@ func compilerErrsToStringSlice(errors []*Error) []string {
return result
}

func runQueryCompilerTest(t *testing.T, note, q, pkg string, imports []string, expected interface{}) {
t.Run(note, func(t *testing.T) {
func runQueryCompilerTest(q, pkg string, imports []string, expected interface{}) func(*testing.T) {
return func(t *testing.T) {
t.Helper()
c := NewCompiler().WithEnablePrintStatements(false)
c.Compile(getCompilerTestModules())
assertNotFailed(t, c)
Expand Down Expand Up @@ -4839,7 +4869,7 @@ func runQueryCompilerTest(t *testing.T, note, q, pkg string, imports []string, e
t.Fatalf("Expected error %v but got: %v", expected, err)
}
}
})
}
}

func TestCompilerCapabilitiesExtendedWithCustomBuiltins(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion repl/repl_test.go
Expand Up @@ -1255,7 +1255,7 @@ x := 2

buffer.Reset()
err := repl.OneShot(ctx, "assign()")
if err == nil || !strings.Contains(err.Error(), "too few arguments") {
if err == nil || !strings.Contains(err.Error(), "rego_type_error: assign: arity mismatch\n\thave: ()\n\twant: (any, any)") {
t.Fatal("Expected type check error but got:", err)
}
}
Expand Down

0 comments on commit c97f58b

Please sign in to comment.