diff --git a/ast/compile.go b/ast/compile.go index 9a1e73b299..df4695331d 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -4798,6 +4798,14 @@ func rewriteWithModifier(c *Compiler, f *equalityFactory, expr *Expr) ([]*Expr, func validateWith(c *Compiler, expr *Expr, i int) (bool, *Error) { target, value := expr.With[i].Target, expr.With[i].Value + + // Ensure that values that are built-ins are rewritten to Ref (not Var) + if v, ok := value.Value.(Var); ok { + if _, ok := c.builtins[v.String()]; ok { + value.Value = Ref([]*Term{NewTerm(v)}) + } + } + switch { case isDataRef(target): ref := target.Value.(Ref) @@ -4813,11 +4821,16 @@ func validateWith(c *Compiler, expr *Expr, i int) (bool, *Error) { } if node != nil { + // NOTE(sr): at this point in the compiler stages, we don't have a fully-populated + // TypeEnv yet -- so we have to make do with this check to see if the replacement + // target is a function. It's probably wrong for arity-0 functions, but those are + // and edge case anyways. if child := node.Child(ref[len(ref)-1].Value); child != nil { - for _, value := range child.Values { - if len(value.(*Rule).Head.Args) > 0 { - // TODO(sr): UDF - return false, NewError(CompileErr, target.Loc(), "with keyword used on non-built-in function") + for _, v := range child.Values { + if len(v.(*Rule).Head.Args) > 0 { + if validateWithFunctionValue(c.builtins, c.RuleTree, value) { + return false, nil + } } } } @@ -4830,11 +4843,6 @@ func validateWith(c *Compiler, expr *Expr, i int) (bool, *Error) { if v, ok := target.Value.(Var); ok { target.Value = Ref([]*Term{NewTerm(v)}) } - if v, ok := value.Value.(Var); ok { - if _, ok := c.builtins[v.String()]; ok { - value.Value = Ref([]*Term{NewTerm(v)}) - } - } targetRef := target.Value.(Ref) bi := c.builtins[targetRef.String()] // safe because isBuiltinRefOrVar checked this @@ -4842,17 +4850,11 @@ func validateWith(c *Compiler, expr *Expr, i int) (bool, *Error) { return false, err } - if v, ok := value.Value.(Ref); ok { - if c.RuleTree.Find(v) != nil { // ref exists in rule tree - return false, nil - } - if _, ok := c.builtins[v.String()]; ok { // built-in replaced by other built-in - return false, nil - } + if validateWithFunctionValue(c.builtins, c.RuleTree, value) { + return false, nil } - default: - return false, NewError(TypeErr, target.Location, "with keyword target must reference existing %v, %v, or a built-in function", InputRootDocument, DefaultRootDocument) + return false, NewError(TypeErr, target.Location, "with keyword target must reference existing %v, %v, or a function", InputRootDocument, DefaultRootDocument) } return requiresEval(value), nil } @@ -4878,6 +4880,15 @@ func validateWithBuiltinTarget(bi *Builtin, target Ref, loc *location.Location) return nil } +func validateWithFunctionValue(bs map[string]*Builtin, ruleTree *TreeNode, value *Term) bool { + if v, ok := value.Value.(Ref); ok { + if ruleTree.Find(v) != nil { // ref exists in rule tree + return true + } + } + return isBuiltinRefOrVar(bs, value) +} + func isInputRef(term *Term) bool { if ref, ok := term.Value.(Ref); ok { if ref.HasPrefix(InputRootRef) { diff --git a/ast/compile_test.go b/ast/compile_test.go index 344dc28bac..d00dfffb9c 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -4040,7 +4040,7 @@ func TestCompilerRewriteWithValue(t *testing.T) { { note: "invalid target", input: `p { true with foo.q as 1 }`, - wantErr: fmt.Errorf("rego_type_error: with keyword target must reference existing input, data, or a built-in function"), + wantErr: fmt.Errorf("rego_type_error: with keyword target must reference existing input, data, or a function"), }, { note: "built-in function: replaced by (unknown) var", @@ -4495,21 +4495,6 @@ func TestRewritePrintCallsWithElseImplicitArgs(t *testing.T) { } func TestCompilerMockFunction(t *testing.T) { - c := NewCompiler() - c.Modules["test"] = MustParseModule(` - package test - - is_allowed(label) { - label == "test_label" - } - - p {true with data.test.is_allowed as "blah" } - `) - compileStages(c, c.rewriteWithModifiers) - assertCompilerErrorStrings(t, c, []string{"rego_compile_error: with keyword used on non-built-in function"}) -} - -func TestCompilerMockBuiltinFunction(t *testing.T) { tests := []struct { note string module, extra string @@ -4645,6 +4630,54 @@ func TestCompilerMockBuiltinFunction(t *testing.T) { p { bar(foo.bar("one")) with bar as mock with foo.bar as mock_mock } `, }, + { + note: "non-built-in function replaced value", + module: `package test + original(_) + p { original(true) with original as 123 } + `, + }, + { + note: "non-built-in function replaced by another, arity 0", + module: `package test + original() = 1 + mock() = 2 + p { original() with original as mock } + `, + err: "rego_type_error: undefined function data.test.original", // TODO(sr): file bug -- this doesn't depend on "with" used or not + }, + { + note: "non-built-in function replaced by another, arity 1", + module: `package test + original(_) + mock(_) + p { original(true) with original as mock } + `, + }, + { + note: "non-built-in function replaced by built-in", + module: `package test + original(_) + p { original([1]) with original as count } + `, + }, + { + note: "non-built-in function replaced by another, arity mismatch", + module: `package test + original(_) + mock(_, _) + p { original([1]) with original as mock } + `, + err: "rego_type_error: data.test.original: arity mismatch\n\thave: (any, any)\n\twant: (any)", + }, + { + note: "non-built-in function replaced by built-in, arity mismatch", + module: `package test + original(_) + p { original([1]) with original as concat } + `, + err: "rego_type_error: data.test.original: arity mismatch\n\thave: (string, any)\n\twant: (any)", + }, } for _, tc := range tests { @@ -6415,7 +6448,7 @@ func TestQueryCompiler(t *testing.T) { q: "x = 1 with foo.p as null", pkg: "", imports: nil, - expected: fmt.Errorf("1 error occurred: 1:12: rego_type_error: with keyword target must reference existing input, data, or a built-in function"), + expected: fmt.Errorf("1 error occurred: 1:12: rego_type_error: with keyword target must reference existing input, data, or a function"), }, { note: "rewrite with value", diff --git a/docs/content/policy-language.md b/docs/content/policy-language.md index a9516eafbf..10e688bd4a 100644 --- a/docs/content/policy-language.md +++ b/docs/content/policy-language.md @@ -1487,7 +1487,7 @@ following syntax: ``` The ``s must be references to values in the input document (or the input -document itself) or data document, or references to built-in functions. +document itself) or data document, or references to functions (built-in or not). {{< info >}} When applied to the `data` document, the `` must not attempt to @@ -1516,7 +1516,7 @@ outer := result { } ``` -When `` is a reference to a built-in function, like `http.send`, then +When `` is a reference to a function, like `http.send`, then its `` can be any of the following: 1. a value: `with http.send as {"body": {"success": true }}` 2. a reference to another function: `with http.send as mock_http_send` @@ -1533,10 +1533,10 @@ See the following example: package opa.examples import future.keywords.in -f(x) = count(x) +f(x) := count(x) -mock_count(x) = 0 { "x" in x } -mock_count(x) = count(x) { not "x" in x } +mock_count(x) := 0 { "x" in x } +mock_count(x) := count(x) { not "x" in x } ``` ```live:with_builtins/1:query:merge_down @@ -1558,12 +1558,12 @@ Each replacement function evaluation will start a new scope: it's valid to use package opa.examples import future.keywords.in -f(x) = count(x) { +f(x) := count(x) { rule_using_concat with concat as "foo,bar" } -mock_count(x) = 0 { "x" in x } -mock_count(x) = count(x) { not "x" in x } +mock_count(x) := 0 { "x" in x } +mock_count(x) := count(x) { not "x" in x } rule_using_concat { concat(",", input.x) == "foo,bar" diff --git a/docs/content/policy-testing.md b/docs/content/policy-testing.md index 0a6c542300..349ad872bd 100644 --- a/docs/content/policy-testing.md +++ b/docs/content/policy-testing.md @@ -233,15 +233,15 @@ opa test --format=json pass_fail_error_test.rego ## Data and Function Mocking -OPA's `with` keyword can be used to replace the data document or built-in functions by mocks. +OPA's `with` keyword can be used to replace the data document or called functions with mocks. Both base and virtual documents can be replaced. -When replacing built-in functions, the following constraints are in place: +When replacing functions, built-in or otherwise, the following constraints are in place: 1. Replacing `internal.*` functions, or `rego.metadata.*`, or `eq`; or relations (`walk`) is not allowed. 2. Replacement and replaced function need to have the same arity. 3. Replaced functions can call the functions they're replacing, and those calls - will call out to the original built-in function, and not cause recursion. + will call out to the original function, and not cause recursion. Below is a simple policy that depends on the data document. @@ -360,7 +360,7 @@ data.authz.test_allow: PASS (458.752µs) PASS: 1/1 ``` -In simple cases, a built-in function can also be replaced with a value, as in +In simple cases, a function can also be replaced with a value, as in ```live:with_keyword_builtins/tests/value:module:read_only test_allow_value { @@ -374,22 +374,19 @@ test_allow_value { Every invocation of the function will then return the replacement value, regardless of the function's arguments. -Note that it's also possible to replace one built-in function by another. - - -**User-defined functions** cannot be replaced by the `with` keyword. -For example, in the below policy the function `cannot_replace` cannot be replaced. +Note that it's also possible to replace one built-in function by another; or a non-built-in +function by a built-in function. **authz.rego**: ```live:with_keyword_funcs:module:read_only package authz -invalid_replace { - cannot_replace(input.label) +replace_rule { + replace(input.label) } -cannot_replace(label) { +replace(label) { label == "test_label" } ``` @@ -399,14 +396,16 @@ cannot_replace(label) { ```live:with_keyword_funcs/tests:module:read_only package authz -test_invalid_replace { - invalid_replace with input as {"label": "test_label"} with cannot_replace as true +test_replace_rule { + replace_rule with input.label as "does-not-matter" with replace as true } ``` ```console $ opa test -v authz.rego authz_test.rego -1 error occurred: authz_test.rego:4: rego_compile_error: with keyword cannot replace rego functions +data.authz.test_replace_rule: PASS (648.314µs) +-------------------------------------------------------------------------------- +PASS: 1/1 ``` diff --git a/format/testfiles/test_with.rego b/format/testfiles/test_with.rego index 5993e12c5d..b5ef60f7e0 100644 --- a/format/testfiles/test_with.rego +++ b/format/testfiles/test_with.rego @@ -22,4 +22,11 @@ func_replacements { count(array.concat(input.x, [])) with input.x as "foo" with array.concat as true with count as mock_f +} + +original(x) = x+1 + +more_func_replacements { + original(1) with original as mock_f + original(1) with original as 1234 } \ No newline at end of file diff --git a/format/testfiles/test_with.rego.formatted b/format/testfiles/test_with.rego.formatted index c792ebbf5b..4552aa85c0 100644 --- a/format/testfiles/test_with.rego.formatted +++ b/format/testfiles/test_with.rego.formatted @@ -22,3 +22,10 @@ func_replacements { with array.concat as true with count as mock_f } + +original(x) = x + 1 + +more_func_replacements { + original(1) with original as mock_f + original(1) with original as 1234 +} diff --git a/internal/planner/planner.go b/internal/planner/planner.go index 6e4447dedf..4d27fa1e38 100644 --- a/internal/planner/planner.go +++ b/internal/planner/planner.go @@ -558,27 +558,29 @@ func (p *Planner) planWith(e *ast.Expr, iter planiter) error { values := make([]*ast.Term, 0, len(e.With)) // NOTE(sr): we could be overallocating if there are builtin replacements targets := make([]ast.Ref, 0, len(e.With)) - builtins := frame{} + mocks := frame{} for _, w := range e.With { - switch v := w.Target.Value.(type) { - case ast.Ref: - if ast.DefaultRootDocument.Equal(v[0]) || - ast.InputRootDocument.Equal(v[0]) { + v := w.Target.Value.(ast.Ref) - values = append(values, w.Value) - targets = append(targets, w.Target.Value.(ast.Ref)) - continue - } + switch { + case p.isFunction(v): // nothing to do + + case ast.DefaultRootDocument.Equal(v[0]) || + ast.InputRootDocument.Equal(v[0]): + + values = append(values, w.Value) + targets = append(targets, w.Target.Value.(ast.Ref)) + + continue // not a mock } - // target is a builtin - builtins[w.Target.String()] = w.Value + mocks[w.Target.String()] = w.Value } return p.planTermSlice(values, func(locals []ir.Operand) error { - p.mocks.PushFrame(builtins) + p.mocks.PushFrame(mocks) paths := make([][]int, len(targets)) saveVars := ast.NewVarSet() @@ -637,7 +639,7 @@ func (p *Planner) planWith(e *ast.Expr, iter planiter) error { err := iter() - p.mocks.PushFrame(builtins) + p.mocks.PushFrame(mocks) if shadowing { p.funcs.Push(map[string]string{}) for _, ref := range dataRefs { @@ -826,13 +828,13 @@ func (p *Planner) planExprCall(e *ast.Expr, iter planiter) error { switch r := replacement.Value.(type) { case ast.Ref: if !r.HasPrefix(ast.DefaultRootRef) && !r.HasPrefix(ast.InputRootRef) { - // replacement is other builtin + // replacement is builtin operator = r.String() - decl := p.decls[operator] - p.externs[operator] = decl + bi := p.decls[operator] + p.externs[operator] = bi // void functions and relations are forbidden; arity validation happened in compiler - return p.planExprCallFunc(operator, len(decl.Decl.Args()), void, operands, args, iter) + return p.planExprCallFunc(operator, len(bi.Decl.FuncArgs().Args), void, operands, args, iter) } // replacement is a function (rule) @@ -848,8 +850,13 @@ func (p *Planner) planExprCall(e *ast.Expr, iter planiter) error { return fmt.Errorf("illegal replacement of operator %q by %v", operator, replacement) - default: // target is a builtin, replacement a value - return p.planExprCallValue(replacement, len(p.decls[operator].Decl.Args()), operands, iter) + default: // replacement is a value + if bi, ok := p.decls[operator]; ok { + return p.planExprCallValue(replacement, len(bi.Decl.FuncArgs().Args), operands, iter) + } + if node := p.rules.Lookup(op); node != nil { + return p.planExprCallValue(replacement, node.Arity(), operands, iter) + } } } @@ -2264,6 +2271,13 @@ func (p *Planner) defaultOperands() []ir.Operand { } } +func (p *Planner) isFunction(r ast.Ref) bool { + if node := p.rules.Lookup(r); node != nil { + return node.Arity() > 0 + } + return false +} + func op(v ir.Val) ir.Operand { return ir.Operand{Value: v} } diff --git a/internal/planner/rules.go b/internal/planner/rules.go index 919cc7ff65..5e94fb1c5e 100644 --- a/internal/planner/rules.go +++ b/internal/planner/rules.go @@ -203,10 +203,10 @@ func (s *functionMocksStack) PopFrame() { *current = (*current)[:len(*current)-1] } -func (s *functionMocksStack) Lookup(builtinName string) *ast.Term { +func (s *functionMocksStack) Lookup(f string) *ast.Term { current := *s.stack[len(s.stack)-1] for i := len(current) - 1; i >= 0; i-- { - if t, ok := current[i][builtinName]; ok { + if t, ok := current[i][f]; ok { return t } } diff --git a/test/cases/testdata/withkeyword/test-with-function-mock.yaml b/test/cases/testdata/withkeyword/test-with-function-mock.yaml new file mode 100644 index 0000000000..70bf9f243a --- /dev/null +++ b/test/cases/testdata/withkeyword/test-with-function-mock.yaml @@ -0,0 +1,110 @@ +cases: +- data: + modules: + - | + package test + f(_) = 2 + p = y { + y = f(true) with f as 1 + } + note: 'withkeyword/function: direct call, value replacement, arity 1' # NOTE(sr): arity-0 functions fail typechecking + query: data.test.p = x + want_result: + - x: 1 +- data: + modules: + - | + package test + f(_) = 2 + g(_) = 1 + p = y { + y = f(true) with f as g + } + note: 'withkeyword/function: direct call, function replacement, arity 1' + query: data.test.p = x + want_result: + - x: 1 +- data: + modules: + - | + package test + f(_) = 2 + g(_) = 1 + p { + f(true, 1) with f as g + } + note: 'withkeyword/function: direct call, function replacement, arity 1, result captured' + query: data.test.p = x + want_result: + - x: true +- data: + modules: + - | + package test + f(_) = 2 + p = y { + y = f([1]) with f as count + } + note: 'withkeyword/function: direct call, built-in replacement, arity 1' + query: data.test.p = x + want_result: + - x: 1 +- data: + modules: + - | + package test + f(_) = 2 + p { + f([1], 1) with f as count + } + note: 'withkeyword/function: direct call, built-in replacement, arity 1, result captured' + query: data.test.p = x + want_result: + - x: true +- data: + modules: + - | + package test + + f1(x) = object.union_n(x) + f2(x) = count(x) + f3(x) = array.reverse(x) + f(x) = f1(x) + g(x) = 123 { + f2(x) + s with f3 as h + } + h(_) = ["replaced"] + p { q with f1 as f } + q { r with f2 as g } + r { x := [{"foo": 4}, {"baz": 5}]; f2(x) == 123; f1(x) == {"foo": 4, "baz": 5} } + s { x := [{}]; f3(x) == ["replaced"] } + note: 'withkeyword/function: nested scope handling' + query: data.test.p = x + want_result: + - x: true +- data: + modules: + - | + package test + + f(x) = 2 + g(x) = f(x) + p = y { y := f(1) with f as g } + note: 'withkeyword/function: simple scope handling (no recursion here)' + query: data.test.p = x + want_result: + - x: 2 +- data: + modules: + - | + package test + + f(_) = 1 { + input.x = "x" + } + p = y { y := f(1) with f as 2 } + note: 'withkeyword/function: rule indexing irrelevant' + query: data.test.p = x + want_result: + - x: 2 diff --git a/topdown/cache.go b/topdown/cache.go index 3db817581e..1b7c455eec 100644 --- a/topdown/cache.go +++ b/topdown/cache.go @@ -280,10 +280,10 @@ func (s *functionMocksStack) Put(el frame) { *current = append(*current, el) } -func (s *functionMocksStack) Get(builtinName string) (*ast.Term, bool) { +func (s *functionMocksStack) Get(f ast.Ref) (*ast.Term, bool) { current := *s.stack[len(s.stack)-1] for i := len(current) - 1; i >= 0; i-- { - if r, ok := current[i][builtinName]; ok { + if r, ok := current[i][f.String()]; ok { return r, true } } diff --git a/topdown/eval.go b/topdown/eval.go index 4163ada181..580f056e1a 100644 --- a/topdown/eval.go +++ b/topdown/eval.go @@ -16,6 +16,7 @@ import ( "github.com/open-policy-agent/opa/topdown/copypropagation" "github.com/open-policy-agent/opa/topdown/print" "github.com/open-policy-agent/opa/tracing" + "github.com/open-policy-agent/opa/types" ) type evalIterator func(*eval) error @@ -464,9 +465,10 @@ func (e *eval) evalWith(iter evalIterator) error { // could be relaxed in certain cases (e.g., if the with statement would // have no effect.) for _, with := range expr.With { - if isOtherRef(with.Target) { - // built-in replaced - _ = disableRef(with.Value.Value.(ast.Ref)) + if isFunction(e.compiler.TypeEnv, with.Target) || // non-builtin function replaced + isOtherRef(with.Target) { // built-in replaced + + ast.WalkRefs(with.Value, disableRef) continue } @@ -476,7 +478,6 @@ func (e *eval) evalWith(iter evalIterator) error { return e.next(iter) }) } - ast.WalkRefs(with.Target, disableRef) ast.WalkRefs(with.Value, disableRef) } @@ -490,21 +491,26 @@ func (e *eval) evalWith(iter evalIterator) error { targets := []ast.Ref{} for i := range expr.With { - target := expr.With[i].Target.Value + target := expr.With[i].Target plugged := e.bindings.Plug(expr.With[i].Value) switch { - case isInputRef(expr.With[i].Target): - pairsInput = append(pairsInput, [...]*ast.Term{expr.With[i].Target, plugged}) - case isDataRef(expr.With[i].Target): - pairsData = append(pairsData, [...]*ast.Term{expr.With[i].Target, plugged}) + // NOTE(sr): ordering matters here: isFunction's ref is also covered by isDataRef + case isFunction(e.compiler.TypeEnv, target): + functionMocks = append(functionMocks, [...]*ast.Term{target, plugged}) + + case isInputRef(target): + pairsInput = append(pairsInput, [...]*ast.Term{target, plugged}) + + case isDataRef(target): + pairsData = append(pairsData, [...]*ast.Term{target, plugged}) + default: // target must be builtin - _, _, ok := e.builtinFunc(target.String()) - if ok { - functionMocks = append(functionMocks, [...]*ast.Term{expr.With[i].Target, plugged}) + if _, _, ok := e.builtinFunc(target.String()); ok { + functionMocks = append(functionMocks, [...]*ast.Term{target, plugged}) + continue // don't append to disabled targets below } - continue } - targets = append(targets, target.(ast.Ref)) + targets = append(targets, target.Value.(ast.Ref)) } input, err := mergeTermWithValues(e.input, pairsInput) @@ -708,7 +714,32 @@ func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error { ref := terms[0].Value.(ast.Ref) + var mocked bool + mock, mocked := e.functionMocks.Get(ref) + if mocked { + if m, ok := mock.Value.(ast.Ref); ok { // builtin or data function + mockCall := append([]*ast.Term{ast.NewTerm(m)}, terms[1:]...) + + e.functionMocks.Push() + err := e.evalCall(mockCall, func() error { + e.functionMocks.Pop() + err := iter() + e.functionMocks.Push() + return err + }) + e.functionMocks.Pop() + return err + } + } + // 'mocked' true now indicates that the replacement is a value: if + // it was a ref to a function, we'd have called that above. + if ref[0].Equal(ast.DefaultRootDocument) { + if mocked { + f := e.compiler.TypeEnv.Get(ref).(*types.Function) + return e.evalCallValue(len(f.FuncArgs().Args), terms, mock, iter) + } + var ir *ast.IndexResult var err error if e.partial() { @@ -719,6 +750,7 @@ func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error { if err != nil { return err } + eval := evalFunc{ e: e, ref: ref, @@ -734,33 +766,8 @@ func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error { return unsupportedBuiltinErr(e.query[e.index].Location) } - if mock, ok := e.functionMocks.Get(builtinName); ok { - switch m := mock.Value.(type) { - case ast.Ref: // builtin or data function - mockCall := append([]*ast.Term{ast.NewTerm(m)}, terms[1:]...) - - e.functionMocks.Push() - err := e.evalCall(mockCall, func() error { - e.functionMocks.Pop() - err := iter() - e.functionMocks.Push() - return err - }) - e.functionMocks.Pop() - return err - - default: // value replacement - switch { - case len(terms) == len(bi.Decl.Args())+2: // captured var - return e.unify(terms[len(terms)-1], mock, iter) - - case len(terms) == len(bi.Decl.Args())+1: - if mock.Value.Compare(ast.Boolean(false)) != 0 { - return iter() - } - return nil - } - } + if mocked { // value replacement of built-in call + return e.evalCallValue(len(bi.Decl.Args()), terms, mock, iter) } if e.unknown(e.query[e.index], e.bindings) { @@ -806,6 +813,20 @@ func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error { return eval.eval(iter) } +func (e *eval) evalCallValue(arity int, terms []*ast.Term, mock *ast.Term, iter unifyIterator) error { + switch { + case len(terms) == arity+2: // captured var + return e.unify(terms[len(terms)-1], mock, iter) + + case len(terms) == arity+1: + if mock.Value.Compare(ast.Boolean(false)) != 0 { + return iter() + } + return nil + } + panic("unreachable") +} + func (e *eval) unify(a, b *ast.Term, iter unifyIterator) error { return e.biunify(a, b, e.bindings, e.bindings, iter) } @@ -3254,6 +3275,15 @@ func isOtherRef(term *ast.Term) bool { return !ref.HasPrefix(ast.DefaultRootRef) && !ref.HasPrefix(ast.InputRootRef) } +func isFunction(env *ast.TypeEnv, ref *ast.Term) bool { + r, ok := ref.Value.(ast.Ref) + if !ok { + return false + } + _, ok = env.Get(r).(*types.Function) + return ok +} + func merge(a, b ast.Value) (ast.Value, bool) { aObj, ok1 := a.(ast.Object) bObj, ok2 := b.(ast.Object) @@ -3323,17 +3353,10 @@ func suppressEarlyExit(err error) error { func (e *eval) updateSavedMocks(withs []*ast.With) []*ast.With { ret := make([]*ast.With, 0, len(withs)) for _, w := range withs { - v := w.Copy() - if isOtherRef(w.Target) { - ref := v.Value.Value.(ast.Ref) - nref := e.namespaceRef(ref) - if e.saveSupport.Exists(nref) { - v.Value.Value = nref - } else { - continue // skip - } + if isOtherRef(w.Target) || isFunction(e.compiler.TypeEnv, w.Target) { + continue } - ret = append(ret, v) + ret = append(ret, w.Copy()) } return ret } diff --git a/topdown/topdown_partial_test.go b/topdown/topdown_partial_test.go index d330ebf067..8c48c130d3 100644 --- a/topdown/topdown_partial_test.go +++ b/topdown/topdown_partial_test.go @@ -766,11 +766,45 @@ func TestTopDownPartialEval(t *testing.T) { `package test mock_concat(_, _) = "foo/bar" - p { q with concat as mock_concat} + p { q with concat as mock_concat } + q { concat("/", ["a", "b"], "foo/bar") }`, + }, + wantQueries: []string{`a = true`}, + }, + { + note: "with+builtin: value replacement", + query: "data.test.p = a", + modules: []string{ + `package test + + p { q with concat as "foo/bar" } q { concat("/", ["a", "b"], "foo/bar") }`, }, wantQueries: []string{`a = true`}, }, + { + note: "with+function: no unknowns", + query: "data.test.p = a", + modules: []string{ + `package test + f(_, _) = "x" + mock_f(_, _) = "foo/bar" + p { q with f as mock_f } + q { f("/", ["a", "b"], "foo/bar") }`, + }, + wantQueries: []string{`a = true`}, + }, + { + note: "with+function: value replacement", + query: "data.test.p = a", + modules: []string{ + `package test + f(_, _) = "x" + p { q with f as "foo/bar" } + q { f("/", ["a", "b"], "foo/bar") }`, + }, + wantQueries: []string{`a = true`}, + }, { note: "with+builtin: unknowns in replacement function", query: "data.test.p = a", @@ -791,6 +825,26 @@ func TestTopDownPartialEval(t *testing.T) { }`, }, }, + { + note: "with+function: unknowns in replacement function", + query: "data.test.p = a", + modules: []string{ + `package test + f(_) = "x/y" + mock_f(_) = "foo/bar" { input.y } + p { q with f as mock_f} + q { f("/", "foo/bar") }`, + }, + wantQueries: []string{`data.partial.test.mock_f("/", "foo/bar"); a = true`}, + wantSupport: []string{ + `package partial.test + + mock_f(__local1__3) = "foo/bar" { + input.y = x_term_3_03 + x_term_3_03 + }`, + }, + }, { note: "with+builtin: unknowns in replaced function's args", query: "data.test.p = a", @@ -806,8 +860,7 @@ func TestTopDownPartialEval(t *testing.T) { }`, }, wantQueries: []string{` - data.partial.test.q = x_term_1_01 with array.concat as data.partial.test.mock_concat - x_term_1_01 with array.concat as data.partial.test.mock_concat + data.partial.test.q a = true `}, wantSupport: []string{`package partial.test @@ -818,6 +871,32 @@ func TestTopDownPartialEval(t *testing.T) { mock_concat(__local0__3, __local1__3) = ["foo", "bar"] `}, }, + { + note: "with+function: unknowns in replaced function's args", + query: "data.test.p = a", + modules: []string{ + `package test + my_concat(x, y) = concat(x, y) + mock_concat(_, _) = "foo,bar" + p { + q with my_concat as mock_concat + } + q { + my_concat("/", input, "foo,bar") + }`, + }, + wantQueries: []string{` + data.partial.test.q + a = true + `}, + wantSupport: []string{`package partial.test + + q { + data.partial.test.mock_concat("/", input, "foo,bar") + } + mock_concat(__local2__3, __local3__3) = "foo,bar" + `}, + }, { note: "with+builtin: unknowns in replacement function's bodies", query: "data.test.p = a", @@ -831,8 +910,7 @@ func TestTopDownPartialEval(t *testing.T) { q { x := array.concat(["foo"], input) }`, }, wantQueries: []string{` - data.partial.test.q = x_term_1_01 with array.concat as data.partial.test.mock_concat - x_term_1_01 with array.concat as data.partial.test.mock_concat + data.partial.test.q a = true `}, wantSupport: []string{`package partial.test @@ -851,6 +929,38 @@ func TestTopDownPartialEval(t *testing.T) { x_term_4_04 }`}, }, + { + note: "with+function: unknowns in replacement function's bodies", + query: "data.test.p = a", + modules: []string{ + `package test + my_concat(x, y) = concat(x, y) + mock_concat(_, _) = "foo,bar" { input.foo } + mock_concat(_, _) = "bar,baz" { input.bar } + + p { q with my_concat as mock_concat } + q { x := my_concat(",", input) }`, + }, + wantQueries: []string{` + data.partial.test.q + a = true + `}, + wantSupport: []string{`package partial.test + + q { + __local9__2 = input + data.partial.test.mock_concat(",", __local9__2, __local8__2) + __local6__2 = __local8__2 + } + mock_concat(__local2__3, __local3__3) = "foo,bar" { + input.foo = x_term_3_03 + x_term_3_03 + } + mock_concat(__local4__4, __local5__4) = "bar,baz" { + input.bar = x_term_4_04 + x_term_4_04 + }`}, + }, { note: "with+builtin+negation: when replacement has no unknowns (args, defs), save negated expr without replacement", query: "data.test.p = true", @@ -873,6 +983,28 @@ func TestTopDownPartialEval(t *testing.T) { q { 100 = input.x } `}, }, + { + note: "with+function+negation: when replacement has no unknowns (args, defs), save negated expr without replacement", + query: "data.test.p = true", + modules: []string{` + package test + my_count(x) = count(x) + mock_count(_) = 100 + p { + not q with input.x as 1 with my_count as mock_count + } + + q { + my_count([1,2,3]) = input.x + } + `}, + wantQueries: []string{"not data.partial.test.q with input.x as 1"}, + wantSupport: []string{` + package partial.test + + q { 100 = input.x } + `}, + }, { note: "with+builtin+negation: when replacement args have unknowns, save negated expr with replacement", query: "data.test.p = true", @@ -888,16 +1020,35 @@ func TestTopDownPartialEval(t *testing.T) { count(input.y) = input.x # unknown arg for mocked func } `}, - wantQueryASTs: func() []ast.Body { - b := ast.MustParseBody("not data.partial.test.q with input.x as 1 with count as data.partial.test.mock_count") - b[0].With[1].Target.Value = ast.Ref([]*ast.Term{ast.VarTerm("count")}) - return []ast.Body{b} - }(), + wantQueries: []string{"not data.partial.test.q with input.x as 1"}, wantSupport: []string{` package partial.test q { data.partial.test.mock_count(input.y, __local1__3); __local1__3 = input.x } - mock_count(__local0__4) = 100 { true } + mock_count(__local0__4) = 100 + `}, + }, + { + note: "with+function+negation: when replacement args have unknowns, save negated expr with replacement", + query: "data.test.p = true", + modules: []string{` + package test + my_count(x) = count(x) + mock_count(_) = 100 + p { + not q with input.x as 1 with my_count as mock_count + } + + q { + my_count(input.y) = input.x # unknown arg for mocked func + } + `}, + wantQueries: []string{`not data.partial.test.q with input.x as 1`}, + wantSupport: []string{` + package partial.test + + q { data.partial.test.mock_count(input.y, __local3__3); __local3__3 = input.x } + mock_count(__local1__4) = 100 `}, }, { @@ -916,11 +1067,7 @@ func TestTopDownPartialEval(t *testing.T) { count([1]) = input.x # unknown arg for mocked func } `}, - wantQueryASTs: func() []ast.Body { - b := ast.MustParseBody("not data.partial.test.q with input.x as 1 with count as data.partial.test.mock_count") - b[0].With[1].Target.Value = ast.Ref([]*ast.Term{ast.VarTerm("count")}) - return []ast.Body{b} - }(), + wantQueries: []string{"not data.partial.test.q with input.x as 1"}, wantSupport: []string{` package partial.test @@ -929,6 +1076,31 @@ func TestTopDownPartialEval(t *testing.T) { mock_count(__local1__5) = 101 { input.z = x_term_5_05; x_term_5_05 } `}, }, + { + note: "with+function+negation: when replacement defs have unknowns, save negated expr with replacement", + query: "data.test.p = true", + modules: []string{` + package test + my_count(x) = count(x) + mock_count(_) = 100 { input.y } + mock_count(_) = 101 { input.z } + p { + not q with input.x as 1 with my_count as mock_count + } + + q { + my_count([1]) = input.x # unknown arg for mocked func + } + `}, + wantQueries: []string{"not data.partial.test.q with input.x as 1"}, + wantSupport: []string{` + package partial.test + + q { data.partial.test.mock_count([1], __local4__3); __local4__3 = input.x } + mock_count(__local1__4) = 100 { input.y = x_term_4_04; x_term_4_04 } + mock_count(__local2__5) = 101 { input.z = x_term_5_05; x_term_5_05 } + `}, + }, { note: "save: sub path", query: "input.x = 1; input.y = 2; input.z.a = 3; input.z.b = x",