Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support context propagation on overloads #558

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
52 changes: 52 additions & 0 deletions cel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -913,6 +913,58 @@ func TestContextEval(t *testing.T) {
}
}

func TestContextEvalPropagation(t *testing.T) {
env, err := NewEnv(
Declarations(
decls.NewFunction("sleep", decls.NewOverload(
"sleep", []*exprpb.Type{decls.Int}, decls.Null,
)),
),
)
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}
ast, iss := env.Compile("sleep(20)")
if iss.Err() != nil {
t.Fatalf("env.Compile(expr) failed: %v", iss.Err())
}
prg, err := env.Program(ast, EvalOptions(OptOptimize|OptTrackState), Functions(&functions.ContextOverload{
Operator: "sleep",
Unary: func(ctx context.Context, value ref.Val) ref.Val {
t := time.NewTimer(time.Duration(value.Value().(int64)) * time.Microsecond)
select {
case <-t.C:
return types.NullValue
case <-ctx.Done():
return types.NewErr("ctx done")
}
},
}))
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}

ctx := context.TODO()
out, _, err := prg.ContextEval(ctx, map[string]interface{}{})
if err != nil {
t.Fatalf("prg.ContextEval() failed: %v", err)
}
if out != types.NullValue {
t.Errorf("prg.ContextEval() got %v, wanted true", out)
}

evalCtx, cancel := context.WithTimeout(ctx, time.Microsecond)
defer cancel()

out, _, err = prg.ContextEval(evalCtx, map[string]interface{}{})
if err == nil {
t.Errorf("Got result %v, wanted timeout error", out)
}
if err != nil && err.Error() != "ctx done" {
t.Errorf("Got %v, wanted operation interrupted error", err)
}
}

func BenchmarkContextEval(b *testing.B) {
env, err := NewEnv(
Declarations(
Expand Down
2 changes: 1 addition & 1 deletion cel/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ func CustomDecorator(dec interpreter.InterpretableDecorator) ProgramOption {
}

// Functions adds function overloads that extend or override the set of CEL built-ins.
func Functions(funcs ...*functions.Overload) ProgramOption {
func Functions(funcs ...functions.Overloader) ProgramOption {
return func(p *prog) (*prog, error) {
if err := p.dispatcher.Add(funcs...); err != nil {
return nil, err
Expand Down
71 changes: 65 additions & 6 deletions cel/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"fmt"
"math"
"sync"
"time"

exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"

Expand Down Expand Up @@ -301,12 +302,12 @@ func (p *prog) ContextEval(ctx context.Context, input interface{}) (ref.Val, *Ev
var vars interpreter.Activation
switch v := input.(type) {
case interpreter.Activation:
vars = ctxActivationPool.Setup(v, ctx.Done(), p.interruptCheckFrequency)
vars = ctxActivationPool.Setup(ctx, v, p.interruptCheckFrequency)
defer ctxActivationPool.Put(vars)
case map[string]interface{}:
rawVars := activationPool.Setup(v)
defer activationPool.Put(rawVars)
vars = ctxActivationPool.Setup(rawVars, ctx.Done(), p.interruptCheckFrequency)
vars = ctxActivationPool.Setup(ctx, rawVars, p.interruptCheckFrequency)
defer ctxActivationPool.Put(vars)
default:
return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]interface{}, got: (%T)%v", input, input)
Expand Down Expand Up @@ -415,12 +416,63 @@ func estimateCost(i interface{}) (min, max int64) {
}

type ctxEvalActivation struct {
ctx context.Context
parent interpreter.Activation
interrupt <-chan struct{}
interruptCheckCount uint
interruptCheckFrequency uint
}

func (a *ctxEvalActivation) Deadline() (deadline time.Time, ok bool) {
if a.parent != nil {
if d1, ok := a.parent.Deadline(); ok {
if d2, ok := a.ctx.Deadline(); ok {
if d1.Before(d2) {
return d1, true
} else {
return d2, true
}
}
return d1, ok
}
}
return a.ctx.Deadline()
}

func (a *ctxEvalActivation) Done() <-chan struct{} {
if a.parent != nil {
if a.parent.Done() != nil {
c := make(chan struct{})
go func() {
select {
case c <- <-a.parent.Done():
case c <- <-a.ctx.Done():
}
}()
return c
}
}
return a.ctx.Done()
}

func (a *ctxEvalActivation) Err() error {
if a.parent != nil {
if err := a.parent.Err(); err != nil {
return err
}
}
return a.ctx.Err()
}

func (a *ctxEvalActivation) Value(key interface{}) interface{} {
if a.parent != nil {
if v := a.parent.Value(key); v != nil {
return v
}
}
return a.ctx.Value(key)
}

// ResolveName implements the Activation interface method, but adds a special #interrupted variable
// which is capable of testing whether a 'done' signal is provided from a context.Context channel.
func (a *ctxEvalActivation) ResolveName(name string) (interface{}, bool) {
Expand All @@ -447,7 +499,7 @@ func newCtxEvalActivationPool() *ctxEvalActivationPool {
return &ctxEvalActivationPool{
Pool: sync.Pool{
New: func() interface{} {
return &ctxEvalActivation{}
return &ctxEvalActivation{ctx: context.Background()}
},
},
}
Expand All @@ -458,20 +510,27 @@ type ctxEvalActivationPool struct {
}

// Setup initializes a pooled Activation with the ability check for context.Context cancellation
func (p *ctxEvalActivationPool) Setup(vars interpreter.Activation, done <-chan struct{}, interruptCheckRate uint) *ctxEvalActivation {
func (p *ctxEvalActivationPool) Setup(ctx context.Context, vars interpreter.Activation, interruptCheckRate uint) *ctxEvalActivation {
a := p.Pool.Get().(*ctxEvalActivation)
a.ctx = ctx
a.parent = vars
a.interrupt = done
a.interrupt = ctx.Done()
a.interruptCheckCount = 0
a.interruptCheckFrequency = interruptCheckRate
return a
}

type evalActivation struct {
ctx context.Context
vars map[string]interface{}
lazyVars map[string]interface{}
}

func (a *evalActivation) Deadline() (deadline time.Time, ok bool) { return a.ctx.Deadline() }
func (a *evalActivation) Done() <-chan struct{} { return a.ctx.Done() }
func (a *evalActivation) Err() error { return a.ctx.Err() }
func (a *evalActivation) Value(key interface{}) interface{} { return a.ctx.Value }

// ResolveName looks up the value of the input variable name, if found.
//
// Lazy bindings may be supplied within the map-based input in either of the following forms:
Expand Down Expand Up @@ -516,7 +575,7 @@ func newEvalActivationPool() *evalActivationPool {
return &evalActivationPool{
Pool: sync.Pool{
New: func() interface{} {
return &evalActivation{lazyVars: make(map[string]interface{})}
return &evalActivation{ctx: context.Background(), lazyVars: make(map[string]interface{})}
},
},
}
Expand Down
101 changes: 96 additions & 5 deletions interpreter/activation.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
package interpreter

import (
"context"
"errors"
"fmt"
"sync"
"time"

"github.com/google/cel-go/common/types/ref"
)
Expand All @@ -26,6 +28,7 @@ import (
//
// An Activation is the primary mechanism by which a caller supplies input into a CEL program.
type Activation interface {
context.Context
// ResolveName returns a value from the activation by qualified name, or false if the name
// could not be found.
ResolveName(name string) (interface{}, bool)
Expand All @@ -37,14 +40,20 @@ type Activation interface {

// EmptyActivation returns a variable-free activation.
func EmptyActivation() Activation {
return emptyActivation{}
return &emptyActivation{ctx: context.Background()}
}

// emptyActivation is a variable-free activation.
type emptyActivation struct{}
type emptyActivation struct {
ctx context.Context
}

func (emptyActivation) ResolveName(string) (interface{}, bool) { return nil, false }
func (emptyActivation) Parent() Activation { return nil }
func (a *emptyActivation) Deadline() (deadline time.Time, ok bool) { return a.ctx.Deadline() }
func (a *emptyActivation) Done() <-chan struct{} { return a.ctx.Done() }
func (a *emptyActivation) Err() error { return a.ctx.Err() }
func (a *emptyActivation) Value(key interface{}) interface{} { return a.ctx.Value }
func (a *emptyActivation) ResolveName(string) (interface{}, bool) { return nil, false }
func (a *emptyActivation) Parent() Activation { return nil }

// NewActivation returns an activation based on a map-based binding where the map keys are
// expected to be qualified names used with ResolveName calls.
Expand Down Expand Up @@ -73,17 +82,23 @@ func NewActivation(bindings interface{}) (Activation, error) {
"activation input must be an activation or map[string]interface: got %T",
bindings)
}
return &mapActivation{bindings: m}, nil
return &mapActivation{ctx: context.Background(), bindings: m}, nil
}

// mapActivation which implements Activation and maps of named values.
//
// Named bindings may lazily supply values by providing a function which accepts no arguments and
// produces an interface value.
type mapActivation struct {
ctx context.Context
bindings map[string]interface{}
}

func (a *mapActivation) Deadline() (deadline time.Time, ok bool) { return a.ctx.Deadline() }
func (a *mapActivation) Done() <-chan struct{} { return a.ctx.Done() }
func (a *mapActivation) Err() error { return a.ctx.Err() }
func (a *mapActivation) Value(key interface{}) interface{} { return a.ctx.Value }

// Parent implements the Activation interface method.
func (a *mapActivation) Parent() Activation {
return nil
Expand Down Expand Up @@ -115,6 +130,54 @@ type hierarchicalActivation struct {
child Activation
}

func (a *hierarchicalActivation) Deadline() (deadline time.Time, ok bool) {
if d1, ok := a.child.Deadline(); ok {
if d2, ok := a.parent.Deadline(); ok {
if d1.Before(d2) {
return d1, true
} else {
return d2, true
}
}
return d1, ok
}
return a.parent.Deadline()
}

func (a *hierarchicalActivation) Done() <-chan struct{} {
if a.parent.Done() != nil {
if a.child.Done() != nil {
c := make(chan struct{})
go func() {
select {
case c <- <-a.parent.Done():
case c <- <-a.child.Done():
}
}()
return c
} else {
return a.parent.Done()
}
}
return a.child.Done()
}

func (a *hierarchicalActivation) Err() error {
if err := a.child.Err(); err != nil {
return err
} else if err = a.parent.Err(); err != nil {
return err
}
return nil
}

func (a *hierarchicalActivation) Value(key interface{}) interface{} {
if v := a.child.Value(key); v != nil {
return v
}
return a.parent.Value(key)
}

// Parent implements the Activation interface method.
func (a *hierarchicalActivation) Parent() Activation {
return a.parent
Expand Down Expand Up @@ -178,6 +241,34 @@ type varActivation struct {
val ref.Val
}

func (a *varActivation) Deadline() (deadline time.Time, ok bool) {
if a.parent != nil {
return a.parent.Deadline()
}
return time.Time{}, ok
}

func (a *varActivation) Done() <-chan struct{} {
if a.parent != nil {
return a.parent.Done()
}
return nil
}

func (a *varActivation) Err() error {
if a.parent != nil {
return a.parent.Err()
}
return nil
}

func (a *varActivation) Value(key interface{}) interface{} {
if a.parent != nil {
return a.parent.Value(key)
}
return nil
}

// Parent implements the Activation interface method.
func (v *varActivation) Parent() Activation {
return v.parent
Expand Down