Skip to content

Commit

Permalink
support context propagation
Browse files Browse the repository at this point in the history
- context.Context instance passed in ContextEval can be propagated to binding function to cancel the process.
  • Loading branch information
goccy committed Apr 18, 2024
1 parent 2337cc0 commit 3f26b91
Show file tree
Hide file tree
Showing 14 changed files with 348 additions and 227 deletions.
37 changes: 36 additions & 1 deletion cel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1014,6 +1014,41 @@ func TestContextEval(t *testing.T) {
}
}

func TestContextEvalPropagation(t *testing.T) {
env, err := NewEnv(Function("test",
Overload("test_int", []*Type{}, IntType,
FunctionBindingContext(func(ctx context.Context, _ ...ref.Val) ref.Val {
md := ctx.Value("metadata")
if md == nil {
return types.NewErr("cannot find metadata value")
}
return types.Int(md.(int))
}),
),
))
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}
ast, iss := env.Compile("test()")
if iss.Err() != nil {
t.Fatalf("env.Compile(expr) failed: %v", iss.Err())
}
prg, err := env.Program(ast)
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}

expected := 10
ctx := context.WithValue(context.Background(), "metadata", expected)
out, _, err := prg.ContextEval(ctx, map[string]interface{}{})
if err != nil {
t.Fatalf("prg.ContextEval() failed: %v", err)
}
if out != types.Int(expected) {
t.Errorf("prg.ContextEval() got %v, but wanted %d", out, expected)
}
}

func BenchmarkContextEval(b *testing.B) {
env := testEnv(b,
Variable("items", ListType(IntType)),
Expand Down Expand Up @@ -1428,7 +1463,7 @@ func TestCustomInterpreterDecorator(t *testing.T) {
if !lhsIsConst || !rhsIsConst {
return i, nil
}
val := call.Eval(interpreter.EmptyActivation())
val := call.Eval(context.Background(), interpreter.EmptyActivation())
if types.IsError(val) {
return nil, val.(*types.Err)
}
Expand Down
18 changes: 18 additions & 0 deletions cel/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,24 @@ func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
return decls.FunctionBinding(binding)
}

// UnaryBindingContext provides the implementation of a unary overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func UnaryBindingContext(binding functions.UnaryContextOp) OverloadOpt {
return decls.UnaryBindingContext(binding)
}

// BinaryBindingContext provides the implementation of a binary overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func BinaryBindingContext(binding functions.BinaryContextOp) OverloadOpt {
return decls.BinaryBindingContext(binding)
}

// FunctionBindingContext provides the implementation of a variadic overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func FunctionBindingContext(binding functions.FunctionContextOp) OverloadOpt {
return decls.FunctionBindingContext(binding)
}

// OverloadIsNonStrict enables the function to be called with error and unknown argument values.
//
// Note: do not use this option unless absoluately necessary as it should be an uncommon feature.
Expand Down
5 changes: 3 additions & 2 deletions cel/decls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package cel

import (
"context"
"fmt"
"math"
"reflect"
Expand Down Expand Up @@ -673,7 +674,7 @@ func TestExprDeclToDeclaration(t *testing.T) {
}
prg, err := e.Program(ast, Functions(&functions.Overload{
Operator: overloads.SizeString,
Unary: func(arg ref.Val) ref.Val {
Unary: func(ctx context.Context, arg ref.Val) ref.Val {
str, ok := arg.(types.String)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
Expand All @@ -682,7 +683,7 @@ func TestExprDeclToDeclaration(t *testing.T) {
},
}, &functions.Overload{
Operator: overloads.SizeStringInst,
Unary: func(arg ref.Val) ref.Val {
Unary: func(ctx context.Context, arg ref.Val) ref.Val {
str, ok := arg.(types.String)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
Expand Down
13 changes: 7 additions & 6 deletions cel/library.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package cel

import (
"context"
"math"
"strconv"
"strings"
Expand Down Expand Up @@ -494,17 +495,17 @@ func (opt *evalOptionalOr) ID() int64 {

// Eval evaluates the left-hand side optional to determine whether it contains a value, else
// proceeds with the right-hand side evaluation.
func (opt *evalOptionalOr) Eval(ctx interpreter.Activation) ref.Val {
func (opt *evalOptionalOr) Eval(ctx context.Context, vars interpreter.Activation) ref.Val {
// short-circuit lhs.
optLHS := opt.lhs.Eval(ctx)
optLHS := opt.lhs.Eval(ctx, vars)
optVal, ok := optLHS.(*types.Optional)
if !ok {
return optLHS
}
if optVal.HasValue() {
return optVal
}
return opt.rhs.Eval(ctx)
return opt.rhs.Eval(ctx, vars)
}

// evalOptionalOrValue selects between an optional or a concrete value. If the optional has a value,
Expand All @@ -522,17 +523,17 @@ func (opt *evalOptionalOrValue) ID() int64 {

// Eval evaluates the left-hand side optional to determine whether it contains a value, else
// proceeds with the right-hand side evaluation.
func (opt *evalOptionalOrValue) Eval(ctx interpreter.Activation) ref.Val {
func (opt *evalOptionalOrValue) Eval(ctx context.Context, vars interpreter.Activation) ref.Val {
// short-circuit lhs.
optLHS := opt.lhs.Eval(ctx)
optLHS := opt.lhs.Eval(ctx, vars)
optVal, ok := optLHS.(*types.Optional)
if !ok {
return optLHS
}
if optVal.HasValue() {
return optVal.GetValue()
}
return opt.rhs.Eval(ctx)
return opt.rhs.Eval(ctx, vars)
}

type timeUTCLibrary struct{}
Expand Down
16 changes: 13 additions & 3 deletions cel/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,11 @@ func (p *prog) initInterpretable(a *Ast, decs []interpreter.InterpretableDecorat

// Eval implements the Program interface method.
func (p *prog) Eval(input any) (v ref.Val, det *EvalDetails, err error) {
return p.eval(context.Background(), input)
}

// Eval implements the Program interface method.
func (p *prog) eval(ctx context.Context, input any) (v ref.Val, det *EvalDetails, err error) {
// Configure error recovery for unexpected panics during evaluation. Note, the use of named
// return values makes it possible to modify the error response during the recovery
// function.
Expand Down Expand Up @@ -291,7 +296,7 @@ func (p *prog) Eval(input any) (v ref.Val, det *EvalDetails, err error) {
if p.defaultVars != nil {
vars = interpreter.NewHierarchicalActivation(p.defaultVars, vars)
}
v = p.interpretable.Eval(vars)
v = p.interpretable.Eval(ctx, vars)
// The output of an internal Eval may have a value (`v`) that is a types.Err. This step
// translates the CEL value to a Go error response. This interface does not quite match the
// RPC signature which allows for multiple errors to be returned, but should be sufficient.
Expand Down Expand Up @@ -321,7 +326,7 @@ func (p *prog) ContextEval(ctx context.Context, input any) (ref.Val, *EvalDetail
default:
return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]any, got: (%T)%v", input, input)
}
return p.Eval(vars)
return p.eval(ctx, vars)
}

// progFactory is a helper alias for marking a program creation factory function.
Expand Down Expand Up @@ -349,6 +354,11 @@ func newProgGen(factory progFactory) (Program, error) {

// Eval implements the Program interface method.
func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) {
return gen.eval(context.Background(), input)
}

// Eval implements the Program interface method.
func (gen *progGen) eval(ctx context.Context, input any) (ref.Val, *EvalDetails, error) {
// The factory based Eval() differs from the standard evaluation model in that it generates a
// new EvalState instance for each call to ensure that unique evaluations yield unique stateful
// results.
Expand All @@ -368,7 +378,7 @@ func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) {
}

// Evaluate the input, returning the result and the 'state' within EvalDetails.
v, _, err := p.Eval(input)
v, _, err := p.ContextEval(ctx, input)
if err != nil {
return v, det, err
}
Expand Down
83 changes: 57 additions & 26 deletions common/decls/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package decls

import (
"context"
"fmt"
"strings"

Expand Down Expand Up @@ -242,23 +243,23 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
// All of the defined overloads are wrapped into a top-level function which
// performs dynamic dispatch to the proper overload based on the argument types.
bindings := append([]*functions.Overload{}, overloads...)
funcDispatch := func(args ...ref.Val) ref.Val {
funcDispatch := func(ctx context.Context, args ...ref.Val) ref.Val {
for _, oID := range f.overloadOrdinals {
o := f.overloads[oID]
// During dynamic dispatch over multiple functions, signature agreement checks
// are preserved in order to assist with the function resolution step.
switch len(args) {
case 1:
if o.unaryOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) {
return o.unaryOp(args[0])
return o.unaryOp(ctx, args[0])
}
case 2:
if o.binaryOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) {
return o.binaryOp(args[0], args[1])
return o.binaryOp(ctx, args[0], args[1])
}
}
if o.functionOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) {
return o.functionOp(args...)
return o.functionOp(ctx, args...)
}
// eventually this will fall through to the noSuchOverload below.
}
Expand Down Expand Up @@ -333,8 +334,10 @@ func SingletonUnaryBinding(fn functions.UnaryOp, traits ...int) FunctionOpt {
return nil, fmt.Errorf("function already has a singleton binding: %s", f.Name())
}
f.singleton = &functions.Overload{
Operator: f.Name(),
Unary: fn,
Operator: f.Name(),
Unary: func(ctx context.Context, val ref.Val) ref.Val {
return fn(val)
},
OperandTrait: trait,
}
return f, nil
Expand All @@ -355,8 +358,10 @@ func SingletonBinaryBinding(fn functions.BinaryOp, traits ...int) FunctionOpt {
return nil, fmt.Errorf("function already has a singleton binding: %s", f.Name())
}
f.singleton = &functions.Overload{
Operator: f.Name(),
Binary: fn,
Operator: f.Name(),
Binary: func(ctx context.Context, lhs ref.Val, rhs ref.Val) ref.Val {
return fn(lhs, rhs)
},
OperandTrait: trait,
}
return f, nil
Expand All @@ -377,8 +382,10 @@ func SingletonFunctionBinding(fn functions.FunctionOp, traits ...int) FunctionOp
return nil, fmt.Errorf("function already has a singleton binding: %s", f.Name())
}
f.singleton = &functions.Overload{
Operator: f.Name(),
Function: fn,
Operator: f.Name(),
Function: func(ctx context.Context, values ...ref.Val) ref.Val {
return fn(values...)
},
OperandTrait: trait,
}
return f, nil
Expand Down Expand Up @@ -460,11 +467,11 @@ type OverloadDecl struct {

// Function implementation options. Optional, but encouraged.
// unaryOp is a function binding that takes a single argument.
unaryOp functions.UnaryOp
unaryOp functions.UnaryContextOp
// binaryOp is a function binding that takes two arguments.
binaryOp functions.BinaryOp
binaryOp functions.BinaryContextOp
// functionOp is a catch-all for zero-arity and three-plus arity functions.
functionOp functions.FunctionOp
functionOp functions.FunctionContextOp
}

// ID mirrors the overload signature and provides a unique id which may be referenced within the type-checker
Expand Down Expand Up @@ -580,41 +587,41 @@ func (o *OverloadDecl) hasBinding() bool {
}

// guardedUnaryOp creates an invocation guard around the provided unary operator, if one is defined.
func (o *OverloadDecl) guardedUnaryOp(funcName string, disableTypeGuards bool) functions.UnaryOp {
func (o *OverloadDecl) guardedUnaryOp(funcName string, disableTypeGuards bool) functions.UnaryContextOp {
if o.unaryOp == nil {
return nil
}
return func(arg ref.Val) ref.Val {
return func(ctx context.Context, arg ref.Val) ref.Val {
if !o.matchesRuntimeUnarySignature(disableTypeGuards, arg) {
return MaybeNoSuchOverload(funcName, arg)
}
return o.unaryOp(arg)
return o.unaryOp(ctx, arg)
}
}

// guardedBinaryOp creates an invocation guard around the provided binary operator, if one is defined.
func (o *OverloadDecl) guardedBinaryOp(funcName string, disableTypeGuards bool) functions.BinaryOp {
func (o *OverloadDecl) guardedBinaryOp(funcName string, disableTypeGuards bool) functions.BinaryContextOp {
if o.binaryOp == nil {
return nil
}
return func(arg1, arg2 ref.Val) ref.Val {
return func(ctx context.Context, arg1, arg2 ref.Val) ref.Val {
if !o.matchesRuntimeBinarySignature(disableTypeGuards, arg1, arg2) {
return MaybeNoSuchOverload(funcName, arg1, arg2)
}
return o.binaryOp(arg1, arg2)
return o.binaryOp(ctx, arg1, arg2)
}
}

// guardedFunctionOp creates an invocation guard around the provided variadic function binding, if one is provided.
func (o *OverloadDecl) guardedFunctionOp(funcName string, disableTypeGuards bool) functions.FunctionOp {
func (o *OverloadDecl) guardedFunctionOp(funcName string, disableTypeGuards bool) functions.FunctionContextOp {
if o.functionOp == nil {
return nil
}
return func(args ...ref.Val) ref.Val {
return func(ctx context.Context, args ...ref.Val) ref.Val {
if !o.matchesRuntimeSignature(disableTypeGuards, args...) {
return MaybeNoSuchOverload(funcName, args...)
}
return o.functionOp(args...)
return o.functionOp(ctx, args...)
}
}

Expand Down Expand Up @@ -667,6 +674,30 @@ type OverloadOpt func(*OverloadDecl) (*OverloadDecl, error)
// UnaryBinding provides the implementation of a unary overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func UnaryBinding(binding functions.UnaryOp) OverloadOpt {
return UnaryBindingContext(func(ctx context.Context, val ref.Val) ref.Val {
return binding(val)
})
}

// BinaryBinding provides the implementation of a binary overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func BinaryBinding(binding functions.BinaryOp) OverloadOpt {
return BinaryBindingContext(func(ctx context.Context, lhs ref.Val, rhs ref.Val) ref.Val {
return binding(lhs, rhs)
})
}

// FunctionBinding provides the implementation of a variadic overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
return FunctionBindingContext(func(ctx context.Context, values ...ref.Val) ref.Val {
return binding(values...)
})
}

// UnaryBindingContext provides the implementation of a unary overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func UnaryBindingContext(binding functions.UnaryContextOp) OverloadOpt {
return func(o *OverloadDecl) (*OverloadDecl, error) {
if o.hasBinding() {
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
Expand All @@ -679,9 +710,9 @@ func UnaryBinding(binding functions.UnaryOp) OverloadOpt {
}
}

// BinaryBinding provides the implementation of a binary overload. The provided function is protected by a runtime
// BinaryBindingContext provides the implementation of a binary overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func BinaryBinding(binding functions.BinaryOp) OverloadOpt {
func BinaryBindingContext(binding functions.BinaryContextOp) OverloadOpt {
return func(o *OverloadDecl) (*OverloadDecl, error) {
if o.hasBinding() {
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
Expand All @@ -694,9 +725,9 @@ func BinaryBinding(binding functions.BinaryOp) OverloadOpt {
}
}

// FunctionBinding provides the implementation of a variadic overload. The provided function is protected by a runtime
// FunctionBindingContext provides the implementation of a variadic overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
func FunctionBindingContext(binding functions.FunctionContextOp) OverloadOpt {
return func(o *OverloadDecl) (*OverloadDecl, error) {
if o.hasBinding() {
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
Expand Down

0 comments on commit 3f26b91

Please sign in to comment.