From 6b2907be8cc7e4978010336ba52a37a17d234212 Mon Sep 17 00:00:00 2001 From: cici37 Date: Tue, 1 Mar 2022 11:34:32 -0800 Subject: [PATCH] Wire runtime cost into validation. --- .../apiextensions/validation/validation.go | 2 +- .../pkg/apiserver/schema/cel/compilation.go | 19 ++- .../apiserver/schema/cel/compilation_test.go | 2 +- .../pkg/apiserver/schema/cel/validation.go | 117 ++++++++++---- .../apiserver/schema/cel/validation_test.go | 69 ++++++-- .../apiserver/schema/defaulting/validation.go | 52 ++++-- .../schema/defaulting/validation_test.go | 148 ++++++++++++++++++ .../customresource/status_strategy.go | 4 +- .../pkg/registry/customresource/strategy.go | 8 +- 9 files changed, 355 insertions(+), 66 deletions(-) create mode 100644 staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/defaulting/validation_test.go diff --git a/staging/src/k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/validation/validation.go b/staging/src/k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/validation/validation.go index 91aab9462b41..8d8d09449973 100644 --- a/staging/src/k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/validation/validation.go +++ b/staging/src/k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/validation/validation.go @@ -946,7 +946,7 @@ func ValidateCustomResourceDefinitionOpenAPISchema(schema *apiextensions.JSONSch structural, err := structuralschema.NewStructural(schema) if err == nil { - compResults, err := cel.Compile(structural, isRoot) + compResults, err := cel.Compile(structural, isRoot, cel.PerCallLimit) if err != nil { allErrs = append(allErrs, field.InternalError(fldPath.Child("x-kubernetes-validations"), err)) } else { diff --git a/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/compilation.go b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/compilation.go index efa4c8201c21..5ca78b043a61 100644 --- a/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/compilation.go +++ b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/compilation.go @@ -18,6 +18,7 @@ package cel import ( "fmt" + "math" "strings" "time" @@ -40,6 +41,14 @@ const ( // OldScopedVarName is the variable name assigned to the existing value of the locally scoped data element of a // CEL validation expression. OldScopedVarName = "oldSelf" + + // PerCallLimit specify the actual cost limit per CEL validation call + //TODO: pick the number for PerCallLimit + PerCallLimit = uint64(math.MaxInt64) + + // RuntimeCELCostBudget is the overall cost budget for runtime CEL validation cost per CustomResource + //TODO: pick the RuntimeCELCostBudget + RuntimeCELCostBudget = math.MaxInt64 ) // CompilationResult represents the cel compilation result for one rule @@ -58,7 +67,8 @@ type CompilationResult struct { /// - non-nil Program, nil Error: The program was compiled successfully // - nil Program, non-nil Error: Compilation resulted in an error // - nil Program, nil Error: The provided rule was empty so compilation was not attempted -func Compile(s *schema.Structural, isResourceRoot bool) ([]CompilationResult, error) { +// perCallLimit was added for testing purpose only. Callers should always use const PerCallLimit as input. +func Compile(s *schema.Structural, isResourceRoot bool, perCallLimit uint64) ([]CompilationResult, error) { if len(s.Extensions.XValidations) == 0 { return nil, nil } @@ -106,13 +116,13 @@ func Compile(s *schema.Structural, isResourceRoot bool) ([]CompilationResult, er // compResults is the return value which saves a list of compilation results in the same order as x-kubernetes-validations rules. compResults := make([]CompilationResult, len(celRules)) for i, rule := range celRules { - compResults[i] = compileRule(rule, env) + compResults[i] = compileRule(rule, env, perCallLimit) } return compResults, nil } -func compileRule(rule apiextensions.ValidationRule, env *cel.Env) (compilationResult CompilationResult) { +func compileRule(rule apiextensions.ValidationRule, env *cel.Env, perCallLimit uint64) (compilationResult CompilationResult) { if len(strings.TrimSpace(rule.Rule)) == 0 { // include a compilation result, but leave both program and error nil per documented return semantics of this // function @@ -141,7 +151,8 @@ func compileRule(rule apiextensions.ValidationRule, env *cel.Env) (compilationRe } } - prog, err := env.Program(ast, cel.EvalOptions(cel.OptOptimize)) + // TODO: Ideally we could configure the per expression limit at validation time and set it to the remaining overall budget, but we would either need a way to pass in a limit at evaluation time or move program creation to validation time + prog, err := env.Program(ast, cel.EvalOptions(cel.OptOptimize, cel.OptTrackCost), cel.CostLimit(perCallLimit)) if err != nil { compilationResult.Error = &Error{ErrorTypeInvalid, "program instantiation failed: " + err.Error()} return diff --git a/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/compilation_test.go b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/compilation_test.go index 84befdc2a9fb..9f418ad21d25 100644 --- a/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/compilation_test.go +++ b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/compilation_test.go @@ -637,7 +637,7 @@ func TestCelCompilation(t *testing.T) { for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { - compilationResults, err := Compile(&tt.input, false) + compilationResults, err := Compile(&tt.input, false, PerCallLimit) if err != nil { t.Errorf("Expected no error, but got: %v", err) } diff --git a/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/validation.go b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/validation.go index fb6feebfb471..4a466189ec34 100644 --- a/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/validation.go +++ b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/validation.go @@ -18,6 +18,7 @@ package cel import ( "fmt" + "math" "strings" "github.com/google/cel-go/common/types" @@ -55,27 +56,28 @@ type Validator struct { // validators for all items, properties and additionalProperties that transitively contain validator rules. // Returns nil only if there no validator rules in the Structural schema. May return a validator containing // only errors. -func NewValidator(s *schema.Structural) *Validator { - return validator(s, true) +// Adding perCallLimit as input arg for testing purpose only. Callers should always use const PerCallLimit as input +func NewValidator(s *schema.Structural, perCallLimit uint64) *Validator { + return validator(s, true, perCallLimit) } -func validator(s *schema.Structural, isResourceRoot bool) *Validator { - compiledRules, err := Compile(s, isResourceRoot) +func validator(s *schema.Structural, isResourceRoot bool, perCallLimit uint64) *Validator { + compiledRules, err := Compile(s, isResourceRoot, perCallLimit) var itemsValidator, additionalPropertiesValidator *Validator var propertiesValidators map[string]Validator if s.Items != nil { - itemsValidator = validator(s.Items, s.Items.XEmbeddedResource) + itemsValidator = validator(s.Items, s.Items.XEmbeddedResource, perCallLimit) } if len(s.Properties) > 0 { propertiesValidators = make(map[string]Validator, len(s.Properties)) for k, prop := range s.Properties { - if p := validator(&prop, prop.XEmbeddedResource); p != nil { + if p := validator(&prop, prop.XEmbeddedResource, perCallLimit); p != nil { propertiesValidators[k] = *p } } } if s.AdditionalProperties != nil && s.AdditionalProperties.Structural != nil { - additionalPropertiesValidator = validator(s.AdditionalProperties.Structural, s.AdditionalProperties.Structural.XEmbeddedResource) + additionalPropertiesValidator = validator(s.AdditionalProperties.Structural, s.AdditionalProperties.Structural.XEmbeddedResource, perCallLimit) } if len(compiledRules) > 0 || err != nil || itemsValidator != nil || additionalPropertiesValidator != nil || len(propertiesValidators) > 0 { return &Validator{ @@ -92,34 +94,51 @@ func validator(s *schema.Structural, isResourceRoot bool) *Validator { } // Validate validates all x-kubernetes-validations rules in Validator against obj and returns any errors. -func (s *Validator) Validate(fldPath *field.Path, sts *schema.Structural, obj interface{}) field.ErrorList { +// If the validation rules exceed the costBudget, subsequent evaluations will be skipped, the list of errs returned will not be empty, and a negative remainingBudget will be returned. +// Most callers can ignore the returned remainingBudget value unless another validate call is going to be made +func (s *Validator) Validate(fldPath *field.Path, sts *schema.Structural, obj interface{}, costBudget int64) (errs field.ErrorList, remainingBudget int64) { + remainingBudget = costBudget if s == nil || obj == nil { - return nil + return nil, remainingBudget } - errs := s.validateExpressions(fldPath, sts, obj) + errs, remainingBudget = s.validateExpressions(fldPath, sts, obj, remainingBudget) + if remainingBudget < 0 { + return errs, remainingBudget + } switch obj := obj.(type) { case []interface{}: - return append(errs, s.validateArray(fldPath, sts, obj)...) + var arrayErrs field.ErrorList + arrayErrs, remainingBudget = s.validateArray(fldPath, sts, obj, remainingBudget) + errs = append(errs, arrayErrs...) + return errs, remainingBudget case map[string]interface{}: - return append(errs, s.validateMap(fldPath, sts, obj)...) + var mapErrs field.ErrorList + mapErrs, remainingBudget = s.validateMap(fldPath, sts, obj, remainingBudget) + errs = append(errs, mapErrs...) + return errs, remainingBudget } - return errs + return errs, remainingBudget } -func (s *Validator) validateExpressions(fldPath *field.Path, sts *schema.Structural, obj interface{}) (errs field.ErrorList) { +func (s *Validator) validateExpressions(fldPath *field.Path, sts *schema.Structural, obj interface{}, costBudget int64) (errs field.ErrorList, remainingBudget int64) { + remainingBudget = costBudget if obj == nil { // We only validate non-null values. Rules that need to check for the state of a nullable value or the presence of an optional // field must do so from the surrounding schema. E.g. if an array has nullable string items, a rule on the array // schema can check if items are null, but a rule on the nullable string schema only validates the non-null strings. - return nil + return nil, remainingBudget } if s.compilationErr != nil { errs = append(errs, field.Invalid(fldPath, obj, fmt.Sprintf("rule compiler initialization error: %v", s.compilationErr))) - return errs + return errs, remainingBudget } if len(s.compiledRules) == 0 { - return nil // nothing to do + return nil, remainingBudget // nothing to do + } + if remainingBudget <= 0 { + errs = append(errs, field.Invalid(fldPath, obj, fmt.Sprintf("validation failed due to running out of cost budget, no further validation rules will be run"))) + return errs, -1 } if s.isResourceRoot { sts = model.WithTypeAndObjectMeta(sts) @@ -140,7 +159,23 @@ func (s *Validator) validateExpressions(fldPath *field.Path, sts *schema.Structu errs = append(errs, field.InternalError(fldPath, fmt.Errorf("oldSelf validation not implemented"))) continue // todo: wire oldObj parameter } - evalResult, _, err := compiled.Program.Eval(activation) + evalResult, evalDetails, err := compiled.Program.Eval(activation) + if evalDetails == nil { + errs = append(errs, field.InternalError(fldPath, fmt.Errorf("runtime cost could not be calculated for validation rule: %v, no further validation rules will be run", ruleErrorString(rule)))) + return errs, -1 + } else { + rtCost := evalDetails.ActualCost() + if rtCost == nil { + errs = append(errs, field.Invalid(fldPath, obj, fmt.Sprintf("runtime cost could not be calculated for validation rule: %v, no further validation rules will be run", ruleErrorString(rule)))) + return errs, -1 + } else { + if *rtCost > math.MaxInt64 || int64(*rtCost) > remainingBudget { + errs = append(errs, field.Invalid(fldPath, obj, fmt.Sprintf("validation failed due to running out of cost budget, no further validation rules will be run"))) + return errs, -1 + } + remainingBudget -= int64(*rtCost) + } + } if err != nil { // see types.Err for list of well defined error types if strings.HasPrefix(err.Error(), "no such overload") { @@ -149,12 +184,15 @@ func (s *Validator) validateExpressions(fldPath *field.Path, sts *schema.Structu // append a more descriptive error message. This error can only occur when static type checking has // been bypassed. int-or-string is typed as dynamic and so bypasses compiler type checking. errs = append(errs, field.Invalid(fldPath, obj, fmt.Sprintf("'%v': call arguments did not match a supported operator, function or macro signature for rule: %v", err, ruleErrorString(rule)))) + } else if strings.HasPrefix(err.Error(), "operation cancelled: actual cost limit exceeded") { + errs = append(errs, field.Invalid(fldPath, obj, fmt.Sprintf("'%v': call cost exceeds limit for rule: %v", err, ruleErrorString(rule)))) } else { // no such key: {key}, index out of bounds: {index}, integer overflow, division by zero, ... errs = append(errs, field.Invalid(fldPath, obj, fmt.Sprintf("%v evaluating rule: %v", err, ruleErrorString(rule)))) } continue } + if evalResult != types.True { if len(rule.Message) != 0 { errs = append(errs, field.Invalid(fldPath, obj, rule.Message)) @@ -163,7 +201,7 @@ func (s *Validator) validateExpressions(fldPath *field.Path, sts *schema.Structu } } } - return errs + return errs, remainingBudget } func ruleErrorString(rule apiextensions.ValidationRule) string { @@ -192,14 +230,23 @@ func (a *validationActivation) Parent() interpreter.Activation { return nil } -func (s *Validator) validateMap(fldPath *field.Path, sts *schema.Structural, obj map[string]interface{}) (errs field.ErrorList) { +func (s *Validator) validateMap(fldPath *field.Path, sts *schema.Structural, obj map[string]interface{}, costBudget int64) (errs field.ErrorList, remainingBudget int64) { + remainingBudget = costBudget + if remainingBudget < 0 { + return errs, remainingBudget + } if s == nil || obj == nil { - return nil + return nil, remainingBudget } if s.AdditionalProperties != nil && sts.AdditionalProperties != nil && sts.AdditionalProperties.Structural != nil { for k, v := range obj { - errs = append(errs, s.AdditionalProperties.Validate(fldPath.Key(k), sts.AdditionalProperties.Structural, v)...) + var err field.ErrorList + err, remainingBudget = s.AdditionalProperties.Validate(fldPath.Key(k), sts.AdditionalProperties.Structural, v, remainingBudget) + errs = append(errs, err...) + if remainingBudget < 0 { + return errs, remainingBudget + } } } if s.Properties != nil && sts.Properties != nil { @@ -207,22 +254,34 @@ func (s *Validator) validateMap(fldPath *field.Path, sts *schema.Structural, obj stsProp, stsOk := sts.Properties[k] sub, ok := s.Properties[k] if ok && stsOk { - errs = append(errs, sub.Validate(fldPath.Child(k), &stsProp, v)...) + var err field.ErrorList + err, remainingBudget = sub.Validate(fldPath.Child(k), &stsProp, v, remainingBudget) + errs = append(errs, err...) + if remainingBudget < 0 { + return errs, remainingBudget + } } } } - return errs + return errs, remainingBudget } -func (s *Validator) validateArray(fldPath *field.Path, sts *schema.Structural, obj []interface{}) field.ErrorList { - var errs field.ErrorList - +func (s *Validator) validateArray(fldPath *field.Path, sts *schema.Structural, obj []interface{}, costBudget int64) (errs field.ErrorList, remainingBudget int64) { + remainingBudget = costBudget + if remainingBudget < 0 { + return errs, remainingBudget + } if s.Items != nil && sts.Items != nil { for i := range obj { - errs = append(errs, s.Items.Validate(fldPath.Index(i), sts.Items, obj[i])...) + var err field.ErrorList + err, remainingBudget = s.Items.Validate(fldPath.Index(i), sts.Items, obj[i], remainingBudget) + errs = append(errs, err...) + if remainingBudget < 0 { + return errs, remainingBudget + } } } - return errs + return errs, remainingBudget } diff --git a/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/validation_test.go b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/validation_test.go index 02cac25af7ed..cc17621ce366 100644 --- a/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/validation_test.go +++ b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/validation_test.go @@ -30,11 +30,12 @@ import ( // TestValidationExpressions tests CEL integration with custom resource values and OpenAPIv3. func TestValidationExpressions(t *testing.T) { tests := []struct { - name string - schema *schema.Structural - obj map[string]interface{} - valid []string - errors map[string]string // rule -> string that error message must contain + name string + schema *schema.Structural + obj map[string]interface{} + valid []string + errors map[string]string // rule -> string that error message must contain + costBudget int64 }{ // tests where val1 and val2 are equal but val3 is different // equality, comparisons and type specific functions @@ -1683,27 +1684,61 @@ func TestValidationExpressions(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + // set costBudget to maxInt64 for current test + tt.costBudget = math.MaxInt64 for _, validRule := range tt.valid { t.Run(validRule, func(t *testing.T) { s := withRule(*tt.schema, validRule) - celValidator := NewValidator(&s) + celValidator := NewValidator(&s, PerCallLimit) if celValidator == nil { t.Fatal("expected non nil validator") } - errs := celValidator.Validate(field.NewPath("root"), &s, tt.obj) + errs, _ := celValidator.Validate(field.NewPath("root"), &s, tt.obj, tt.costBudget) for _, err := range errs { t.Errorf("unexpected error: %v", err) } + + // test with cost budget exceeded + errs, _ = celValidator.Validate(field.NewPath("root"), &s, tt.obj, 0) + var found bool + for _, err := range errs { + if err.Type == field.ErrorTypeInvalid && strings.Contains(err.Error(), "validation failed due to running out of cost budget, no further validation rules will be run") { + found = true + } + } + if !found { + t.Errorf("expect cost limit exceed err but did not find") + } + if len(errs) > 1 { + t.Errorf("expect to return cost budget exceed err once") + } + + // test with PerCallLimit exceeded + found = false + celValidator = NewValidator(&s, 0) + if celValidator == nil { + t.Fatal("expected non nil validator") + } + errs, _ = celValidator.Validate(field.NewPath("root"), &s, tt.obj, tt.costBudget) + for _, err := range errs { + if err.Type == field.ErrorTypeInvalid && strings.Contains(err.Error(), "call cost exceeds limit for rule") { + found = true + break + } + } + if !found { + t.Errorf("expect PerCostLimit exceed err but did not find") + } }) } for rule, expectErrToContain := range tt.errors { t.Run(rule, func(t *testing.T) { s := withRule(*tt.schema, rule) - celValidator := NewValidator(&s) + celValidator := NewValidator(&s, PerCallLimit) if celValidator == nil { t.Fatal("expected non nil validator") } - errs := celValidator.Validate(field.NewPath("root"), &s, tt.obj) + errs, _ := celValidator.Validate(field.NewPath("root"), &s, tt.obj, tt.costBudget) if len(errs) == 0 { t.Error("expected validation errors but got none") } @@ -1712,9 +1747,23 @@ func TestValidationExpressions(t *testing.T) { t.Errorf("expected error to contain '%s', but got: %v", expectErrToContain, err) } } + + // test with cost budget exceeded + errs, _ = celValidator.Validate(field.NewPath("root"), &s, tt.obj, 0) + var found bool + for _, err := range errs { + if err.Type == field.ErrorTypeInvalid && strings.Contains(err.Error(), "validation failed due to running out of cost budget, no further validation rules will be run") { + found = true + } + } + if !found { + t.Errorf("expect cost limit exceed err but did not find") + } + if len(errs) > 1 { + t.Errorf("expect to return cost budget exceed err once") + } }) } - }) } } diff --git a/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/defaulting/validation.go b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/defaulting/validation.go index 88fa3eba9b37..a181de0ddb44 100644 --- a/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/defaulting/validation.go +++ b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/defaulting/validation.go @@ -47,15 +47,17 @@ func ValidateDefaults(pth *field.Path, s *structuralschema.Structural, isResourc } } - return validate(pth, s, s, f, false, requirePrunedDefaults) + allErr, error, _ := validate(pth, s, s, f, false, requirePrunedDefaults, cel.RuntimeCELCostBudget) + return allErr, error } // validate is the recursive step func for the validation. insideMeta is true if s specifies // TypeMeta or ObjectMeta. The SurroundingObjectFunc f is used to validate defaults of // TypeMeta or ObjectMeta fields. -func validate(pth *field.Path, s *structuralschema.Structural, rootSchema *structuralschema.Structural, f SurroundingObjectFunc, insideMeta, requirePrunedDefaults bool) (field.ErrorList, error) { +func validate(pth *field.Path, s *structuralschema.Structural, rootSchema *structuralschema.Structural, f SurroundingObjectFunc, insideMeta, requirePrunedDefaults bool, costBudget int64) (allErrs field.ErrorList, error error, remainingCost int64) { + remainingCost = costBudget if s == nil { - return nil, nil + return nil, nil, remainingCost } if s.XEmbeddedResource { @@ -64,8 +66,6 @@ func validate(pth *field.Path, s *structuralschema.Structural, rootSchema *struc rootSchema = s } - allErrs := field.ErrorList{} - if s.Default.Object != nil { validator := kubeopenapivalidate.NewSchemaValidator(s.ToKubeOpenAPI(), nil, "", strfmt.Default) @@ -75,7 +75,7 @@ func validate(pth *field.Path, s *structuralschema.Structural, rootSchema *struc // this should never happen. f(s.Default.Object) only gives an error if f is the // root object func, but the default value is not a map. But then we wouldn't be // in this case. - return nil, fmt.Errorf("failed to validate default value inside metadata: %v", err) + return nil, fmt.Errorf("failed to validate default value inside metadata: %v", err), remainingCost } // check ObjectMeta/TypeMeta and everything else @@ -85,8 +85,13 @@ func validate(pth *field.Path, s *structuralschema.Structural, rootSchema *struc allErrs = append(allErrs, field.Invalid(pth.Child("default"), s.Default.Object, fmt.Sprintf("must result in valid metadata: %v", errs.ToAggregate()))) } else if errs := apiservervalidation.ValidateCustomResource(pth.Child("default"), s.Default.Object, validator); len(errs) > 0 { allErrs = append(allErrs, errs...) - } else if celValidator := cel.NewValidator(s); celValidator != nil { - allErrs = append(allErrs, celValidator.Validate(pth.Child("default"), s, s.Default.Object)...) + } else if celValidator := cel.NewValidator(s, cel.PerCallLimit); celValidator != nil { + celErrs, rmCost := celValidator.Validate(pth.Child("default"), s, s.Default.Object, remainingCost) + remainingCost = rmCost + allErrs = append(allErrs, celErrs...) + if remainingCost < 0 { + return allErrs, nil, remainingCost + } } } else { // check whether default is pruned @@ -105,8 +110,13 @@ func validate(pth *field.Path, s *structuralschema.Structural, rootSchema *struc allErrs = append(allErrs, errs...) } else if errs := apiservervalidation.ValidateCustomResource(pth.Child("default"), s.Default.Object, validator); len(errs) > 0 { allErrs = append(allErrs, errs...) - } else if celValidator := cel.NewValidator(s); celValidator != nil { - allErrs = append(allErrs, celValidator.Validate(pth.Child("default"), s, s.Default.Object)...) + } else if celValidator := cel.NewValidator(s, cel.PerCallLimit); celValidator != nil { + celErrs, rmCost := celValidator.Validate(pth.Child("default"), s, s.Default.Object, remainingCost) + remainingCost = rmCost + allErrs = append(allErrs, celErrs...) + if remainingCost < 0 { + return allErrs, nil, remainingCost + } } } } @@ -114,11 +124,15 @@ func validate(pth *field.Path, s *structuralschema.Structural, rootSchema *struc // do not follow additionalProperties because defaults are forbidden there if s.Items != nil { - errs, err := validate(pth.Child("items"), s.Items, rootSchema, f.Index(), insideMeta, requirePrunedDefaults) + errs, err, rCost := validate(pth.Child("items"), s.Items, rootSchema, f.Index(), insideMeta, requirePrunedDefaults, remainingCost) + remainingCost = rCost + allErrs = append(allErrs, errs...) if err != nil { - return nil, err + return nil, err, remainingCost + } + if remainingCost < 0 { + return allErrs, nil, remainingCost } - allErrs = append(allErrs, errs...) } for k, subSchema := range s.Properties { @@ -126,12 +140,16 @@ func validate(pth *field.Path, s *structuralschema.Structural, rootSchema *struc if s.XEmbeddedResource && (k == "metadata" || k == "apiVersion" || k == "kind") { subInsideMeta = true } - errs, err := validate(pth.Child("properties").Key(k), &subSchema, rootSchema, f.Child(k), subInsideMeta, requirePrunedDefaults) + errs, err, rCost := validate(pth.Child("properties").Key(k), &subSchema, rootSchema, f.Child(k), subInsideMeta, requirePrunedDefaults, remainingCost) + remainingCost = rCost + allErrs = append(allErrs, errs...) if err != nil { - return nil, err + return nil, err, remainingCost + } + if remainingCost < 0 { + return allErrs, nil, remainingCost } - allErrs = append(allErrs, errs...) } - return allErrs, nil + return allErrs, nil, remainingCost } diff --git a/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/defaulting/validation_test.go b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/defaulting/validation_test.go new file mode 100644 index 000000000000..3f29600f1aef --- /dev/null +++ b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/defaulting/validation_test.go @@ -0,0 +1,148 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package defaulting + +import ( + "strings" + "testing" + + "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions" + structuralschema "k8s.io/apiextensions-apiserver/pkg/apiserver/schema" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/validation/field" +) + +func jsonPtr(x interface{}) *apiextensions.JSON { + ret := apiextensions.JSON(x) + return &ret +} + +func TestDefaultValidationWithCostBudget(t *testing.T) { + tests := []struct { + name string + input apiextensions.CustomResourceValidation + }{ + { + name: "default cel validation", + input: apiextensions.CustomResourceValidation{ + OpenAPIV3Schema: &apiextensions.JSONSchemaProps{ + Type: "object", + Properties: map[string]apiextensions.JSONSchemaProps{ + "embedded": { + Type: "object", + Properties: map[string]apiextensions.JSONSchemaProps{ + "metadata": { + Type: "object", + XEmbeddedResource: true, + Properties: map[string]apiextensions.JSONSchemaProps{ + "name": { + Type: "string", + XValidations: apiextensions.ValidationRules{ + { + Rule: "self == 'singleton'", + }, + }, + Default: jsonPtr("singleton"), + }, + }, + }, + }, + }, + "value": { + Type: "string", + XValidations: apiextensions.ValidationRules{ + { + Rule: "self.startsWith('kube')", + }, + }, + Default: jsonPtr("kube-everything"), + }, + "object": { + Type: "object", + Properties: map[string]apiextensions.JSONSchemaProps{ + "field1": { + Type: "integer", + }, + "field2": { + Type: "integer", + }, + }, + XValidations: apiextensions.ValidationRules{ + { + Rule: "self.field1 < self.field2", + }, + }, + Default: jsonPtr(map[string]interface{}{"field1": 1, "field2": 2}), + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + schema := tt.input.OpenAPIV3Schema + ss, err := structuralschema.NewStructural(schema) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + f := NewRootObjectFunc().WithTypeMeta(metav1.TypeMeta{APIVersion: "validation/v1", Kind: "Validation"}) + + // cost budget is large enough to pass all validation rules + allErrs, err, _ := validate(field.NewPath("test"), ss, ss, f, false, false, 10) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + for _, valErr := range allErrs { + t.Errorf("unexpected error: %v", valErr) + } + + // cost budget exceeded for the first validation rule + allErrs, err, _ = validate(field.NewPath("test"), ss, ss, f, false, false, 0) + meet := 0 + for _, er := range allErrs { + if er.Type == field.ErrorTypeInvalid && strings.Contains(er.Error(), "validation failed due to running out of cost budget, no further validation rules will be run") { + meet += 1 + } + } + if meet != 1 { + t.Errorf("expected to get cost budget exceed error once but got %v cost budget exceed error", meet) + } + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + // cost budget exceeded for the last validation rule + allErrs, err, _ = validate(field.NewPath("test"), ss, ss, f, false, false, 9) + meet = 0 + for _, er := range allErrs { + if er.Type == field.ErrorTypeInvalid && strings.Contains(er.Error(), "validation failed due to running out of cost budget, no further validation rules will be run") { + meet += 1 + } + } + if meet != 1 { + t.Errorf("expected to get cost budget exceed error once but got %v cost budget exceed error", meet) + } + if err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + } +} diff --git a/staging/src/k8s.io/apiextensions-apiserver/pkg/registry/customresource/status_strategy.go b/staging/src/k8s.io/apiextensions-apiserver/pkg/registry/customresource/status_strategy.go index 5582633f00cd..0826212a8b17 100644 --- a/staging/src/k8s.io/apiextensions-apiserver/pkg/registry/customresource/status_strategy.go +++ b/staging/src/k8s.io/apiextensions-apiserver/pkg/registry/customresource/status_strategy.go @@ -19,6 +19,7 @@ package customresource import ( "context" + "k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/validation/field" @@ -90,7 +91,8 @@ func (a statusStrategy) ValidateUpdate(ctx context.Context, obj, old runtime.Obj // validate x-kubernetes-validations rules if celValidator, ok := a.customResourceStrategy.celValidators[v]; ok { - errs = append(errs, celValidator.Validate(nil, a.customResourceStrategy.structuralSchemas[v], u.Object)...) + err, _ := celValidator.Validate(nil, a.customResourceStrategy.structuralSchemas[v], u.Object, cel.RuntimeCELCostBudget) + errs = append(errs, err...) } } return errs diff --git a/staging/src/k8s.io/apiextensions-apiserver/pkg/registry/customresource/strategy.go b/staging/src/k8s.io/apiextensions-apiserver/pkg/registry/customresource/strategy.go index aae84fd3cfde..4de94269dee2 100644 --- a/staging/src/k8s.io/apiextensions-apiserver/pkg/registry/customresource/strategy.go +++ b/staging/src/k8s.io/apiextensions-apiserver/pkg/registry/customresource/strategy.go @@ -60,7 +60,7 @@ func NewStrategy(typer runtime.ObjectTyper, namespaceScoped bool, kind schema.Gr celValidators := map[string]*cel.Validator{} if utilfeature.DefaultFeatureGate.Enabled(features.CustomResourceValidationExpressions) { for name, s := range structuralSchemas { - v := cel.NewValidator(s) // CEL programs are compiled and cached here + v := cel.NewValidator(s, cel.PerCallLimit) // CEL programs are compiled and cached here if v != nil { celValidators[name] = v } @@ -174,7 +174,8 @@ func (a customResourceStrategy) Validate(ctx context.Context, obj runtime.Object // validate x-kubernetes-validations rules if celValidator, ok := a.celValidators[v]; ok { - errs = append(errs, celValidator.Validate(nil, a.structuralSchemas[v], u.Object)...) + err, _ := celValidator.Validate(nil, a.structuralSchemas[v], u.Object, cel.RuntimeCELCostBudget) + errs = append(errs, err...) } } @@ -226,7 +227,8 @@ func (a customResourceStrategy) ValidateUpdate(ctx context.Context, obj, old run // validate x-kubernetes-validations rules if celValidator, ok := a.celValidators[v]; ok { - errs = append(errs, celValidator.Validate(nil, a.structuralSchemas[v], uNew.Object)...) + err, _ := celValidator.Validate(nil, a.structuralSchemas[v], uNew.Object, cel.RuntimeCELCostBudget) + errs = append(errs, err...) } return errs