diff --git a/ast/compile.go b/ast/compile.go index 42db098829..f449fddcd7 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -2196,7 +2196,7 @@ func (c *Compiler) rewriteWithModifiers() { if !ok { return x, nil } - body, err := rewriteWithModifiersInBody(c, f, body) + body, err := rewriteWithModifiersInBody(c, c.unsafeBuiltinsMap, f, body) if err != nil { c.err(err) } @@ -2475,19 +2475,20 @@ func (qc *queryCompiler) checkTypes(_ *QueryContext, body Body) (Body, error) { } func (qc *queryCompiler) checkUnsafeBuiltins(_ *QueryContext, body Body) (Body, error) { - var unsafe map[string]struct{} - if qc.unsafeBuiltins != nil { - unsafe = qc.unsafeBuiltins - } else { - unsafe = qc.compiler.unsafeBuiltinsMap - } - errs := checkUnsafeBuiltins(unsafe, body) + errs := checkUnsafeBuiltins(qc.unsafeBuiltinsMap(), body) if len(errs) > 0 { return nil, errs } return body, nil } +func (qc *queryCompiler) unsafeBuiltinsMap() map[string]struct{} { + if qc.unsafeBuiltins != nil { + return qc.unsafeBuiltins + } + return qc.compiler.unsafeBuiltinsMap +} + func (qc *queryCompiler) checkDeprecatedBuiltins(_ *QueryContext, body Body) (Body, error) { errs := checkDeprecatedBuiltins(qc.compiler.deprecatedBuiltinsMap, body, qc.compiler.strict) if len(errs) > 0 { @@ -2498,7 +2499,7 @@ func (qc *queryCompiler) checkDeprecatedBuiltins(_ *QueryContext, body Body) (Bo func (qc *queryCompiler) rewriteWithModifiers(_ *QueryContext, body Body) (Body, error) { f := newEqualityFactory(newLocalVarGenerator("q", body)) - body, err := rewriteWithModifiersInBody(qc.compiler, f, body) + body, err := rewriteWithModifiersInBody(qc.compiler, qc.unsafeBuiltinsMap(), f, body) if err != nil { return nil, Errors{err} } @@ -4785,10 +4786,10 @@ func rewriteDeclaredVar(g *localVarGenerator, stack *localDeclaredVars, v Var, o // rewriteWithModifiersInBody will rewrite the body so that with modifiers do // not contain terms that require evaluation as values. If this function // encounters an invalid with modifier target then it will raise an error. -func rewriteWithModifiersInBody(c *Compiler, f *equalityFactory, body Body) (Body, *Error) { +func rewriteWithModifiersInBody(c *Compiler, unsafeBuiltinsMap map[string]struct{}, f *equalityFactory, body Body) (Body, *Error) { var result Body for i := range body { - exprs, err := rewriteWithModifier(c, f, body[i]) + exprs, err := rewriteWithModifier(c, unsafeBuiltinsMap, f, body[i]) if err != nil { return nil, err } @@ -4803,11 +4804,11 @@ func rewriteWithModifiersInBody(c *Compiler, f *equalityFactory, body Body) (Bod return result, nil } -func rewriteWithModifier(c *Compiler, f *equalityFactory, expr *Expr) ([]*Expr, *Error) { +func rewriteWithModifier(c *Compiler, unsafeBuiltinsMap map[string]struct{}, f *equalityFactory, expr *Expr) ([]*Expr, *Error) { var result []*Expr for i := range expr.With { - eval, err := validateWith(c, expr, i) + eval, err := validateWith(c, unsafeBuiltinsMap, expr, i) if err != nil { return nil, err } @@ -4822,7 +4823,7 @@ func rewriteWithModifier(c *Compiler, f *equalityFactory, expr *Expr) ([]*Expr, return append(result, expr), nil } -func validateWith(c *Compiler, expr *Expr, i int) (bool, *Error) { +func validateWith(c *Compiler, unsafeBuiltinsMap map[string]struct{}, expr *Expr, i int) (bool, *Error) { target, value := expr.With[i].Target, expr.With[i].Value // Ensure that values that are built-ins are rewritten to Ref (not Var) @@ -4831,6 +4832,10 @@ func validateWith(c *Compiler, expr *Expr, i int) (bool, *Error) { value.Value = Ref([]*Term{NewTerm(v)}) } } + isBuiltinRefOrVar, err := isBuiltinRefOrVar(c.builtins, unsafeBuiltinsMap, target) + if err != nil { + return false, err + } switch { case isDataRef(target): @@ -4854,15 +4859,15 @@ func validateWith(c *Compiler, expr *Expr, i int) (bool, *Error) { if child := node.Child(ref[len(ref)-1].Value); child != nil { for _, v := range child.Values { if len(v.(*Rule).Head.Args) > 0 { - if validateWithFunctionValue(c.builtins, c.RuleTree, value) { - return false, nil + if ok, err := validateWithFunctionValue(c.builtins, unsafeBuiltinsMap, c.RuleTree, value); err != nil || ok { + return false, err // may be nil } } } } } case isInputRef(target): // ok, valid - case isBuiltinRefOrVar(c.builtins, target): + case isBuiltinRefOrVar: // NOTE(sr): first we ensure that parsed Var builtins (`count`, `concat`, etc) // are rewritten to their proper Ref convention @@ -4876,8 +4881,8 @@ func validateWith(c *Compiler, expr *Expr, i int) (bool, *Error) { return false, err } - if validateWithFunctionValue(c.builtins, c.RuleTree, value) { - return false, nil + if ok, err := validateWithFunctionValue(c.builtins, unsafeBuiltinsMap, c.RuleTree, value); err != nil || ok { + return false, err // may be nil } default: return false, NewError(TypeErr, target.Location, "with keyword target must reference existing %v, %v, or a function", InputRootDocument, DefaultRootDocument) @@ -4906,13 +4911,13 @@ func validateWithBuiltinTarget(bi *Builtin, target Ref, loc *location.Location) return nil } -func validateWithFunctionValue(bs map[string]*Builtin, ruleTree *TreeNode, value *Term) bool { +func validateWithFunctionValue(bs map[string]*Builtin, unsafeMap map[string]struct{}, ruleTree *TreeNode, value *Term) (bool, *Error) { if v, ok := value.Value.(Ref); ok { if ruleTree.Find(v) != nil { // ref exists in rule tree - return true + return true, nil } } - return isBuiltinRefOrVar(bs, value) + return isBuiltinRefOrVar(bs, unsafeMap, value) } func isInputRef(term *Term) bool { @@ -4933,13 +4938,16 @@ func isDataRef(term *Term) bool { return false } -func isBuiltinRefOrVar(bs map[string]*Builtin, term *Term) bool { +func isBuiltinRefOrVar(bs map[string]*Builtin, unsafeBuiltinsMap map[string]struct{}, term *Term) (bool, *Error) { switch v := term.Value.(type) { case Ref, Var: + if _, ok := unsafeBuiltinsMap[v.String()]; ok { + return false, NewError(CompileErr, term.Location, "with keyword replacing built-in function: target must not be unsafe: %q", v) + } _, ok := bs[v.String()] - return ok + return ok, nil } - return false + return false, nil } func isVirtual(node *TreeNode, ref Ref) bool { diff --git a/ast/compile_test.go b/ast/compile_test.go index 0bb3cc41ca..aa3165114a 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -4049,6 +4049,7 @@ func TestCompilerRewriteWithValue(t *testing.T) { tests := []struct { note string input string + opts func(*Compiler) *Compiler expected string expectedRule *Rule wantErr error @@ -4154,6 +4155,26 @@ func TestCompilerRewriteWithValue(t *testing.T) { return r }(), }, + { + note: "built-in function: replaced by another built-in that's marked unsafe", + input: ` + q := is_object({"url": "https://httpbin.org", "method": "GET"}) + p { q with is_object as http.send } + `, + opts: func(c *Compiler) *Compiler { return c.WithUnsafeBuiltins(map[string]struct{}{"http.send": {}}) }, + wantErr: fmt.Errorf("rego_compile_error: with keyword replacing built-in function: target must not be unsafe: \"http.send\""), + }, + { + note: "non-built-in function: replaced by another built-in that's marked unsafe", + input: ` + r(_) = {} + q := r({"url": "https://httpbin.org", "method": "GET"}) + p { + q with r as http.send + }`, + opts: func(c *Compiler) *Compiler { return c.WithUnsafeBuiltins(map[string]struct{}{"http.send": {}}) }, + wantErr: fmt.Errorf("rego_compile_error: with keyword replacing built-in function: target must not be unsafe: \"http.send\""), + }, { note: "built-in function: valid, arity 1, non-compound name", input: ` @@ -4171,6 +4192,9 @@ func TestCompilerRewriteWithValue(t *testing.T) { for _, tc := range tests { t.Run(tc.note, func(t *testing.T) { c := NewCompiler() + if tc.opts != nil { + c = tc.opts(c) + } module := fixture + tc.input c.Modules["test"] = MustParseModule(module) compileStages(c, c.rewriteWithModifiers) @@ -6676,13 +6700,63 @@ func TestQueryCompilerWithStageAfterWithMetrics(t *testing.T) { } func TestQueryCompilerWithUnsafeBuiltins(t *testing.T) { - c := NewCompiler().WithUnsafeBuiltins(map[string]struct{}{ - "count": {}, - }) + tests := []struct { + note string + query string + compiler *Compiler + opts func(QueryCompiler) QueryCompiler + err string + }{ + { + note: "builtin unsafe via compiler", + query: "count([])", + compiler: NewCompiler().WithUnsafeBuiltins(map[string]struct{}{"count": {}}), + err: "unsafe built-in function calls in expression: count", + }, + { + note: "builtin unsafe via query compiler", + query: "count([])", + compiler: NewCompiler(), + opts: func(qc QueryCompiler) QueryCompiler { + return qc.WithUnsafeBuiltins(map[string]struct{}{"count": {}}) + }, + err: "unsafe built-in function calls in expression: count", + }, + { + note: "builtin unsafe via compiler, 'with' mocking", + query: "is_array([]) with is_array as count", + compiler: NewCompiler().WithUnsafeBuiltins(map[string]struct{}{"count": {}}), + err: `with keyword replacing built-in function: target must not be unsafe: "count"`, + }, + { + note: "builtin unsafe via query compiler, 'with' mocking", + query: "is_array([]) with is_array as count", + compiler: NewCompiler(), + opts: func(qc QueryCompiler) QueryCompiler { + return qc.WithUnsafeBuiltins(map[string]struct{}{"count": {}}) + }, + err: `with keyword replacing built-in function: target must not be unsafe: "count"`, + }, + } - _, err := c.QueryCompiler().WithUnsafeBuiltins(map[string]struct{}{}).Compile(MustParseBody("count([])")) - if err != nil { - t.Fatal(err) + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + qc := tc.compiler.QueryCompiler() + if tc.opts != nil { + qc = tc.opts(qc) + } + _, err := qc.Compile(MustParseBody(tc.query)) + var errs Errors + if !errors.As(err, &errs) { + t.Fatalf("expected error type %T, got %v %[2]T", errs, err) + } + if exp, act := 1, len(errs); exp != act { + t.Fatalf("expected %d error(s), got %d", exp, act) + } + if exp, act := tc.err, errs[0].Message; exp != act { + t.Errorf("expected message %q, got %q", exp, act) + } + }) } } diff --git a/rego/rego_test.go b/rego/rego_test.go index 6f20c05027..40dd15f58a 100644 --- a/rego/rego_test.go +++ b/rego/rego_test.go @@ -1437,6 +1437,7 @@ func TestUnsafeBuiltins(t *testing.T) { ctx := context.Background() unsafeCountExpr := "unsafe built-in function calls in expression: count" + unsafeCountExprWith := `with keyword replacing built-in function: target must not be unsafe: "count"` t.Run("unsafe query", func(t *testing.T) { r := New( @@ -1448,6 +1449,16 @@ func TestUnsafeBuiltins(t *testing.T) { } }) + t.Run("unsafe query, 'with' replacement", func(t *testing.T) { + r := New( + Query(`is_array([1, 2, 3]) with is_array as count`), + UnsafeBuiltins(map[string]struct{}{"count": {}}), + ) + if _, err := r.Eval(ctx); err == nil || !strings.Contains(err.Error(), unsafeCountExprWith) { + t.Fatalf("Expected unsafe built-in error but got %v", err) + } + }) + t.Run("unsafe module", func(t *testing.T) { r := New( Query(`data.pkg.deny`), @@ -1463,6 +1474,36 @@ func TestUnsafeBuiltins(t *testing.T) { } }) + t.Run("unsafe module, 'with' replacement in query", func(t *testing.T) { + r := New( + Query(`data.pkg.deny with is_array as count`), + Module("pkg.rego", `package pkg + deny { + is_array(input.requests) > 10 + } + `), + UnsafeBuiltins(map[string]struct{}{"count": {}}), + ) + if _, err := r.Eval(ctx); err == nil || !strings.Contains(err.Error(), unsafeCountExprWith) { + t.Fatalf("Expected unsafe built-in error but got %v", err) + } + }) + + t.Run("unsafe module, 'with' replacement in module", func(t *testing.T) { + r := New( + Query(`data.pkg.deny`), + Module("pkg.rego", `package pkg + deny { + is_array(input.requests) > 10 with is_array as count + } + `), + UnsafeBuiltins(map[string]struct{}{"count": {}}), + ) + if _, err := r.Eval(ctx); err == nil || !strings.Contains(err.Error(), unsafeCountExprWith) { + t.Fatalf("Expected unsafe built-in error but got %v", err) + } + }) + t.Run("inherit in query", func(t *testing.T) { r := New( Compiler(ast.NewCompiler().WithUnsafeBuiltins(map[string]struct{}{"count": {}})), @@ -1473,6 +1514,16 @@ func TestUnsafeBuiltins(t *testing.T) { } }) + t.Run("inherit in query, 'with' replacement", func(t *testing.T) { + r := New( + Compiler(ast.NewCompiler().WithUnsafeBuiltins(map[string]struct{}{"count": {}})), + Query("is_array([]) with is_array as count"), + ) + if _, err := r.Eval(ctx); err == nil || !strings.Contains(err.Error(), unsafeCountExprWith) { + t.Fatalf("Expected unsafe built-in error but got %v", err) + } + }) + t.Run("override/disable in query", func(t *testing.T) { r := New( Compiler(ast.NewCompiler().WithUnsafeBuiltins(map[string]struct{}{"count": {}})),