From 5b81ae9573505619570a213f1cd359edfc58e146 Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Wed, 13 Jul 2022 15:28:46 -0700 Subject: [PATCH 1/2] Ensure overloads are searched in the order they are declared during dynamic dispatch --- cel/decls.go | 78 ++++++++++++++++++++++++++-------------------------- 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/cel/decls.go b/cel/decls.go index 55532788..906aea02 100644 --- a/cel/decls.go +++ b/cel/decls.go @@ -221,7 +221,8 @@ func (t *Type) equals(other *Type) bool { // - The from types are the same instance // - The target type is dynamic // - The fromType has the same kind and type name as the target type, and all parameters of the target type -// are IsAssignableType() from the parameters of the fromType. +// +// are IsAssignableType() from the parameters of the fromType. func (t *Type) defaultIsAssignableType(fromType *Type) bool { if t == fromType || t.isDyn() { return true @@ -333,7 +334,7 @@ func Function(name string, opts ...FunctionOpt) EnvOption { return func(e *Env) (*Env, error) { fn := &functionDecl{ name: name, - overloads: map[string]*overloadDecl{}, + overloads: []*o{}, options: opts, } err := fn.init() @@ -445,12 +446,12 @@ func MemberOverload(overloadID string, args []*Type, resultType *Type, opts ...O } // OverloadOpt is a functional option for configuring a function overload. -type OverloadOpt func(*overloadDecl) (*overloadDecl, error) +type OverloadOpt func(*o) (*o, error) // UnaryBinding provides the implementation of a unary overload. The provided function is protected by a runtime // type-guard which ensures runtime type agreement between the overload signature and runtime argument types. func UnaryBinding(binding functions.UnaryOp) OverloadOpt { - return func(o *overloadDecl) (*overloadDecl, error) { + return func(o *o) (*o, error) { if o.hasBinding() { return nil, fmt.Errorf("overload already has a binding: %s", o.id) } @@ -465,7 +466,7 @@ func UnaryBinding(binding functions.UnaryOp) OverloadOpt { // BinaryBinding provides the implementation of a binary overload. The provided function is protected by a runtime // type-guard which ensures runtime type agreement between the overload signature and runtime argument types. func BinaryBinding(binding functions.BinaryOp) OverloadOpt { - return func(o *overloadDecl) (*overloadDecl, error) { + return func(o *o) (*o, error) { if o.hasBinding() { return nil, fmt.Errorf("overload already has a binding: %s", o.id) } @@ -480,7 +481,7 @@ func BinaryBinding(binding functions.BinaryOp) OverloadOpt { // FunctionBinding provides the implementation of a variadic overload. The provided function is protected by a runtime // type-guard which ensures runtime type agreement between the overload signature and runtime argument types. func FunctionBinding(binding functions.FunctionOp) OverloadOpt { - return func(o *overloadDecl) (*overloadDecl, error) { + return func(o *o) (*o, error) { if o.hasBinding() { return nil, fmt.Errorf("overload already has a binding: %s", o.id) } @@ -493,7 +494,7 @@ func FunctionBinding(binding functions.FunctionOp) OverloadOpt { // // Note: do not use this option unless absoluately necessary as it should be an uncommon feature. func OverloadIsNonStrict() OverloadOpt { - return func(o *overloadDecl) (*overloadDecl, error) { + return func(o *o) (*o, error) { o.nonStrict = true return o, nil } @@ -502,7 +503,7 @@ func OverloadIsNonStrict() OverloadOpt { // OverloadOperandTrait configures a set of traits which the first argument to the overload must implement in order to be // successfully invoked. func OverloadOperandTrait(trait int) OverloadOpt { - return func(o *overloadDecl) (*overloadDecl, error) { + return func(o *o) (*o, error) { o.operandTrait = trait return o, nil } @@ -510,7 +511,7 @@ func OverloadOperandTrait(trait int) OverloadOpt { type functionDecl struct { name string - overloads map[string]*overloadDecl + overloads []*o options []FunctionOpt singleton *functions.Overload initialized bool @@ -591,22 +592,22 @@ func (f *functionDecl) bindings() ([]*functions.Overload, error) { // performs dynamic dispatch to the proper overload based on the argument types. bindings := append([]*functions.Overload{}, overloads...) funcDispatch := func(args ...ref.Val) ref.Val { - for _, overloadDecl := range f.overloads { - if !overloadDecl.matchesRuntimeSignature(args...) { + for _, o := range f.overloads { + if !o.matchesRuntimeSignature(args...) { continue } switch len(args) { case 1: - if overloadDecl.unaryOp != nil { - return overloadDecl.unaryOp(args[0]) + if o.unaryOp != nil { + return o.unaryOp(args[0]) } case 2: - if overloadDecl.binaryOp != nil { - return overloadDecl.binaryOp(args[0], args[1]) + if o.binaryOp != nil { + return o.binaryOp(args[0], args[1]) } } - if overloadDecl.functionOp != nil { - return overloadDecl.functionOp(args...) + if o.functionOp != nil { + return o.functionOp(args...) } // eventually this will fall through to the noSuchOverload below. } @@ -639,14 +640,12 @@ func (f *functionDecl) merge(other *functionDecl) (*functionDecl, error) { } merged := &functionDecl{ name: f.name, - overloads: map[string]*overloadDecl{}, + overloads: make([]*o, len(f.overloads)), options: []FunctionOpt{}, initialized: true, singleton: f.singleton, } - for id, o := range f.overloads { - merged.overloads[id] = o - } + copy(merged.overloads, f.overloads) for _, o := range other.overloads { err := merged.addOverload(o) if err != nil { @@ -665,21 +664,22 @@ func (f *functionDecl) merge(other *functionDecl) (*functionDecl, error) { // addOverload ensures that the new overload does not collide with an existing overload signature; // however, if the function signatures are identical, the implementation may be rewritten as its // difficult to compare functions by object identity. -func (f *functionDecl) addOverload(overload *overloadDecl) error { - for id, o := range f.overloads { - if id != overload.id && o.signatureOverlaps(overload) { +func (f *functionDecl) addOverload(overload *o) error { + for index, o := range f.overloads { + if o.id != overload.id && o.signatureOverlaps(overload) { return fmt.Errorf("overload signature collision in function %s: %s collides with %s", f.name, o.id, overload.id) } - if id == overload.id { + if o.id == overload.id { if o.signatureEquals(overload) && o.nonStrict == overload.nonStrict { // Allow redefinition of an overload implementation so long as the signatures match. - f.overloads[id] = overload + f.overloads[index] = overload + return nil } else { return fmt.Errorf("overload redefinition in function. %s: %s has multiple definitions", f.name, o.id) } } } - f.overloads[overload.id] = overload + f.overloads = append(f.overloads, overload) return nil } @@ -692,8 +692,8 @@ func noSuchOverload(funcName string, args ...ref.Val) ref.Val { return types.NewErr("no such overload: %s(%s)", funcName, signature) } -// overloadDecl contains all of the relevant information regarding a specific function overload. -type overloadDecl struct { +// o contains all of the relevant information regarding a specific function overload. +type o struct { id string argTypes []*Type resultType *Type @@ -709,12 +709,12 @@ type overloadDecl struct { operandTrait int } -func (o *overloadDecl) hasBinding() bool { +func (o *o) hasBinding() bool { return o.unaryOp != nil || o.binaryOp != nil || o.functionOp != nil } // guardedUnaryOp creates an invocation guard around the provided unary operator, if one is defined. -func (o *overloadDecl) guardedUnaryOp(funcName string) functions.UnaryOp { +func (o *o) guardedUnaryOp(funcName string) functions.UnaryOp { if o.unaryOp == nil { return nil } @@ -727,7 +727,7 @@ func (o *overloadDecl) guardedUnaryOp(funcName string) functions.UnaryOp { } // guardedBinaryOp creates an invocation guard around the provided binary operator, if one is defined. -func (o *overloadDecl) guardedBinaryOp(funcName string) functions.BinaryOp { +func (o *o) guardedBinaryOp(funcName string) functions.BinaryOp { if o.binaryOp == nil { return nil } @@ -740,7 +740,7 @@ func (o *overloadDecl) guardedBinaryOp(funcName string) functions.BinaryOp { } // guardedFunctionOp creates an invocation guard around the provided variadic function binding, if one is provided. -func (o *overloadDecl) guardedFunctionOp(funcName string) functions.FunctionOp { +func (o *o) guardedFunctionOp(funcName string) functions.FunctionOp { if o.functionOp == nil { return nil } @@ -753,7 +753,7 @@ func (o *overloadDecl) guardedFunctionOp(funcName string) functions.FunctionOp { } // matchesRuntimeUnarySignature indicates whether the argument type is runtime assiganble to the overload's expected argument. -func (o *overloadDecl) matchesRuntimeUnarySignature(arg ref.Val) bool { +func (o *o) matchesRuntimeUnarySignature(arg ref.Val) bool { if o.nonStrict && types.IsUnknownOrError(arg) { return true } @@ -761,7 +761,7 @@ func (o *overloadDecl) matchesRuntimeUnarySignature(arg ref.Val) bool { } // matchesRuntimeBinarySignature indicates whether the argument types are runtime assiganble to the overload's expected arguments. -func (o *overloadDecl) matchesRuntimeBinarySignature(arg1, arg2 ref.Val) bool { +func (o *o) matchesRuntimeBinarySignature(arg1, arg2 ref.Val) bool { if o.nonStrict { if types.IsUnknownOrError(arg1) { return types.IsUnknownOrError(arg2) || o.argTypes[1].IsAssignableRuntimeType(arg2.Type()) @@ -773,7 +773,7 @@ func (o *overloadDecl) matchesRuntimeBinarySignature(arg1, arg2 ref.Val) bool { } // matchesRuntimeSignature indicates whether the argument types are runtime assiganble to the overload's expected arguments. -func (o *overloadDecl) matchesRuntimeSignature(args ...ref.Val) bool { +func (o *o) matchesRuntimeSignature(args ...ref.Val) bool { if len(args) != len(o.argTypes) { return false } @@ -795,7 +795,7 @@ func (o *overloadDecl) matchesRuntimeSignature(args ...ref.Val) bool { // signatureEquals indicates whether one overload has an identical signature to another overload. // // Providing a duplicate signature is not an issue, but an overloapping signature is problematic. -func (o *overloadDecl) signatureEquals(other *overloadDecl) bool { +func (o *o) signatureEquals(other *o) bool { if o.id != other.id || o.memberFunction != other.memberFunction || len(o.argTypes) != len(other.argTypes) { return false } @@ -811,7 +811,7 @@ func (o *overloadDecl) signatureEquals(other *overloadDecl) bool { // signatureOverlaps indicates whether one overload has an overlapping signature with another overload. // // The 'other' overload must first be checked for equality before determining whether it overlaps in order to be completely accurate. -func (o *overloadDecl) signatureOverlaps(other *overloadDecl) bool { +func (o *o) signatureOverlaps(other *o) bool { if o.memberFunction != other.memberFunction || len(o.argTypes) != len(other.argTypes) { return false } @@ -827,7 +827,7 @@ func (o *overloadDecl) signatureOverlaps(other *overloadDecl) bool { func newOverload(overloadID string, memberFunction bool, args []*Type, resultType *Type, opts ...OverloadOpt) FunctionOpt { return func(f *functionDecl) (*functionDecl, error) { - overload := &overloadDecl{ + overload := &o{ id: overloadID, argTypes: args, resultType: resultType, From a94354b6e50e17c7d1fef2a9799e486884f3e89b Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Wed, 13 Jul 2022 16:56:21 -0700 Subject: [PATCH 2/2] Improved support for dynamic dispatch --- cel/cel_test.go | 205 ++++++++++++++++++++++++++++++++++++++-------- cel/decls.go | 120 ++++++++++++++++++--------- cel/decls_test.go | 22 ++++- 3 files changed, 273 insertions(+), 74 deletions(-) diff --git a/cel/cel_test.go b/cel/cel_test.go index 36dbdcc4..5b8e0b69 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -674,7 +674,10 @@ func TestGlobalVars(t *testing.T) { t.Run("attrs_alt", func(t *testing.T) { vars := map[string]interface{}{ "attrs": map[string]interface{}{"second": "yep"}} - out, _, _ := prg.Eval(vars) + out, _, err := prg.Eval(vars) + if err != nil { + t.Fatalf("prg.Eval(vars) failed: %v", err) + } if out.Equal(types.String("yep")) != types.True { t.Errorf("got '%v', expected 'yep'.", out.Value()) } @@ -1657,7 +1660,7 @@ func TestDefaultUTCTimeZone(t *testing.T) { if err != nil { t.Fatalf("NewEnv() failed: %v", err) } - ast, iss := env.Compile(` + out, err := interpret(t, env, ` x.getFullYear() == 1970 && x.getMonth() == 0 && x.getDayOfYear() == 0 @@ -1687,16 +1690,10 @@ func TestDefaultUTCTimeZone(t *testing.T) { && x.getHours('23:15') == 1 && x.getMinutes('23:15') == 20 && x.getSeconds('23:15') == 6 - && x.getMilliseconds('23:15') == 1 - `) - if iss.Err() != nil { - t.Fatalf("env.Compile() failed: %v", iss.Err()) - } - prg, err := env.Program(ast) - if err != nil { - t.Fatalf("env.Program() failed: %v", err) - } - out, _, err := prg.Eval(map[string]interface{}{"x": time.Unix(7506, 1000000).Local()}) + && x.getMilliseconds('23:15') == 1`, + map[string]interface{}{ + "x": time.Unix(7506, 1000000).Local(), + }) if err != nil { t.Fatalf("prg.Eval() failed: %v", err) } @@ -1718,20 +1715,12 @@ func TestDefaultUTCTimeZoneExtension(t *testing.T) { if err != nil { t.Fatalf("env.Extend() failed: %v", err) } - ast, iss := env.Compile(` + out, err := interpret(t, env, ` x.getFullYear() == 1970 && y.getHours() == 2 && y.getMinutes() == 120 && y.getSeconds() == 7235 - && y.getMilliseconds() == 7235000`) - if iss.Err() != nil { - t.Fatalf("env.Compile() failed: %v", iss.Err()) - } - prg, err := env.Program(ast) - if err != nil { - t.Fatalf("env.Program() failed: %v", err) - } - out, _, err := prg.Eval( + && y.getMilliseconds() == 7235000`, map[string]interface{}{ "x": time.Unix(7506, 1000000).Local(), "y": time.Duration(7235) * time.Second, @@ -1750,7 +1739,7 @@ func TestDefaultUTCTimeZoneError(t *testing.T) { if err != nil { t.Fatalf("NewEnv() failed: %v", err) } - ast, iss := env.Compile(` + out, err := interpret(t, env, ` x.getFullYear(':xx') == 1969 || x.getDayOfYear('xx:') == 364 || x.getMonth('Am/Ph') == 11 @@ -1761,30 +1750,180 @@ func TestDefaultUTCTimeZoneError(t *testing.T) { || x.getMinutes('Am/Ph') == 5 || x.getSeconds('Am/Ph') == 6 || x.getMilliseconds('Am/Ph') == 1 - `) - if iss.Err() != nil { - t.Fatalf("env.Compile() failed: %v", iss.Err()) + `, map[string]interface{}{ + "x": time.Unix(7506, 1000000).Local(), + }, + ) + if err == nil { + t.Fatalf("prg.Eval() got %v wanted error", out) } - prg, err := env.Program(ast) +} + +func TestDynamicDispatch(t *testing.T) { + env, err := NewEnv( + HomogeneousAggregateLiterals(), + Function("first", + MemberOverload("first_list_int", []*Type{ListType(IntType)}, IntType, + UnaryBinding(func(list ref.Val) ref.Val { + l := list.(traits.Lister) + if l.Size() == types.IntZero { + return types.IntZero + } + return l.Get(types.IntZero) + }), + ), + MemberOverload("first_list_double", []*Type{ListType(DoubleType)}, DoubleType, + UnaryBinding(func(list ref.Val) ref.Val { + l := list.(traits.Lister) + if l.Size() == types.IntZero { + return types.Double(0.0) + } + return l.Get(types.IntZero) + }), + ), + MemberOverload("first_list_string", []*Type{ListType(StringType)}, StringType, + UnaryBinding(func(list ref.Val) ref.Val { + l := list.(traits.Lister) + if l.Size() == types.IntZero { + return types.String("") + } + return l.Get(types.IntZero) + }), + ), + MemberOverload("first_list_list_string", []*Type{ListType(ListType(StringType))}, ListType(StringType), + UnaryBinding(func(list ref.Val) ref.Val { + l := list.(traits.Lister) + if l.Size() == types.IntZero { + return types.DefaultTypeAdapter.NativeToValue([]string{}) + } + return l.Get(types.IntZero) + }), + ), + ), + ) if err != nil { - t.Fatalf("env.Program() failed: %v", err) + t.Fatalf("NewEnv() failed: %v", err) } - out, _, err := prg.Eval(map[string]interface{}{"x": time.Unix(7506, 1000000).Local()}) - if err == nil { - t.Fatalf("prg.Eval() got %v wanted error", out) + out, err := interpret(t, env, ` + [].first() == 0 + && [1, 2].first() == 1 + && [1.0, 2.0].first() == 1.0 + && ["hello", "world"].first() == "hello" + && [["hello"], ["world", "!"]].first().first() == "hello" + && [[], ["empty"]].first().first() == "" + && dyn([1, 2]).first() == 1 + && dyn([1.0, 2.0]).first() == 1.0 + && dyn(["hello", "world"]).first() == "hello" + && dyn([["hello"], ["world", "!"]]).first().first() == "hello" + `, map[string]interface{}{}, + ) + if err != nil { + t.Fatalf("prg.Eval() failed: %v", err) + } + if out != types.True { + t.Fatalf("prg.Eval() got %v wanted true", out) + } +} + +func BenchmarkDynamicDispatch(b *testing.B) { + env, err := NewEnv( + HomogeneousAggregateLiterals(), + Function("first", + MemberOverload("first_list_int", []*Type{ListType(IntType)}, IntType, + UnaryBinding(func(list ref.Val) ref.Val { + l := list.(traits.Lister) + if l.Size() == types.IntZero { + return types.IntZero + } + return l.Get(types.IntZero) + }), + ), + MemberOverload("first_list_double", []*Type{ListType(DoubleType)}, DoubleType, + UnaryBinding(func(list ref.Val) ref.Val { + l := list.(traits.Lister) + if l.Size() == types.IntZero { + return types.Double(0.0) + } + return l.Get(types.IntZero) + }), + ), + MemberOverload("first_list_string", []*Type{ListType(StringType)}, StringType, + UnaryBinding(func(list ref.Val) ref.Val { + l := list.(traits.Lister) + if l.Size() == types.IntZero { + return types.String("") + } + return l.Get(types.IntZero) + }), + ), + MemberOverload("first_list_list_string", []*Type{ListType(ListType(StringType))}, ListType(StringType), + UnaryBinding(func(list ref.Val) ref.Val { + l := list.(traits.Lister) + if l.Size() == types.IntZero { + return types.DefaultTypeAdapter.NativeToValue([]string{}) + } + return l.Get(types.IntZero) + }), + ), + ), + ) + if err != nil { + b.Fatalf("NewEnv() failed: %v", err) } + prg := compile(b, env, ` + [].first() == 0 + && [1, 2].first() == 1 + && [1.0, 2.0].first() == 1.0 + && ["hello", "world"].first() == "hello" + && [["hello"], ["world", "!"]].first().first() == "hello"`) + prgDyn := compile(b, env, ` + dyn([]).first() == 0 + && dyn([1, 2]).first() == 1 + && dyn([1.0, 2.0]).first() == 1.0 + && dyn(["hello", "world"]).first() == "hello" + && dyn([["hello"], ["world", "!"]]).first().first() == "hello"`) + b.ResetTimer() + b.Run("DirectDispatch", func(b *testing.B) { + for i := 0; i < b.N; i++ { + prg.Eval(NoVars()) + } + }) + b.ResetTimer() + b.Run("DynamicDispatch", func(b *testing.B) { + for i := 0; i < b.N; i++ { + prgDyn.Eval(NoVars()) + } + }) } -func interpret(t *testing.T, env *Env, expr string, vars interface{}) (ref.Val, error) { +func compile(t testing.TB, env *Env, expr string) Program { + t.Helper() + prg, err := compileOrError(t, env, expr) + if err != nil { + t.Fatal(err) + } + return prg +} + +func compileOrError(t testing.TB, env *Env, expr string) (Program, error) { t.Helper() ast, iss := env.Compile(expr) if iss.Err() != nil { return nil, fmt.Errorf("env.Compile(%s) failed: %v", expr, iss.Err()) } - prg, err := env.Program(ast) + prg, err := env.Program(ast, EvalOptions(OptOptimize)) if err != nil { return nil, fmt.Errorf("env.Program() failed: %v", err) } + return prg, nil +} + +func interpret(t testing.TB, env *Env, expr string, vars interface{}) (ref.Val, error) { + t.Helper() + prg, err := compileOrError(t, env, expr) + if err != nil { + return nil, err + } out, _, err := prg.Eval(vars) if err != nil { return nil, fmt.Errorf("prg.Eval(%v) failed: %v", vars, err) diff --git a/cel/decls.go b/cel/decls.go index 906aea02..f2df721d 100644 --- a/cel/decls.go +++ b/cel/decls.go @@ -21,6 +21,7 @@ import ( "github.com/google/cel-go/checker/decls" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" "github.com/google/cel-go/interpreter/functions" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" @@ -162,7 +163,7 @@ type Type struct { // isAssignableRuntimeType function determines whether the runtime type (with erasure) is assignable to this type. // A nil value for the isAssignableRuntimeType function falls back to the equality of the type or type name. - isAssignableRuntimeType func(other ref.Type) bool + isAssignableRuntimeType func(other ref.Val) bool } // IsAssignableType determines whether the current type is type-check assignable from the input fromType. @@ -177,11 +178,11 @@ func (t *Type) IsAssignableType(fromType *Type) bool { // // At runtime, parameterized types are erased and so a function which type-checks to support a map(string, string) // will have a runtime assignable type of a map. -func (t *Type) IsAssignableRuntimeType(runtimeType ref.Type) bool { +func (t *Type) IsAssignableRuntimeType(val ref.Val) bool { if t.isAssignableRuntimeType != nil { - return t.isAssignableRuntimeType(runtimeType) + return t.isAssignableRuntimeType(val) } - return t.defaultIsAssignableRuntimeType(runtimeType) + return t.defaultIsAssignableRuntimeType(val) } // String returns a human-readable definition of the type name. @@ -221,8 +222,7 @@ func (t *Type) equals(other *Type) bool { // - The from types are the same instance // - The target type is dynamic // - The fromType has the same kind and type name as the target type, and all parameters of the target type -// -// are IsAssignableType() from the parameters of the fromType. +// are IsAssignableType() from the parameters of the fromType. func (t *Type) defaultIsAssignableType(fromType *Type) bool { if t == fromType || t.isDyn() { return true @@ -241,8 +241,40 @@ func (t *Type) defaultIsAssignableType(fromType *Type) bool { return true } -func (t *Type) defaultIsAssignableRuntimeType(runtimeType ref.Type) bool { - return t.runtimeType == runtimeType || t.isDyn() || t.runtimeType.TypeName() == runtimeType.TypeName() +// defaultIsAssignableRuntimeType inspects the type and in the case of list and map elements, the key and element types +// to determine whether a ref.Val is assignable to the declared type for a function signature. +func (t *Type) defaultIsAssignableRuntimeType(val ref.Val) bool { + valType := val.Type() + if !(t.runtimeType == valType || t.isDyn() || t.runtimeType.TypeName() == valType.TypeName()) { + return false + } + switch t.runtimeType { + case types.ListType: + elemType := t.parameters[0] + l := val.(traits.Lister) + if l.Size() == types.IntZero { + return true + } + it := l.Iterator() + for it.HasNext() == types.True { + elemVal := it.Next() + return elemType.IsAssignableRuntimeType(elemVal) + } + case types.MapType: + keyType := t.parameters[0] + elemType := t.parameters[1] + m := val.(traits.Mapper) + if m.Size() == types.IntZero { + return true + } + it := m.Iterator() + for it.HasNext() == types.True { + keyVal := it.Next() + elemVal := m.Get(keyVal) + return keyType.IsAssignableRuntimeType(keyVal) && elemType.IsAssignableRuntimeType(elemVal) + } + } + return true } // ListType creates an instances of a list type value with the provided element type. @@ -274,7 +306,7 @@ func NullableType(wrapped *Type) *Type { isAssignableType: func(other *Type) bool { return NullType.IsAssignableType(other) || wrapped.IsAssignableType(other) }, - isAssignableRuntimeType: func(other ref.Type) bool { + isAssignableRuntimeType: func(other ref.Val) bool { return NullType.IsAssignableRuntimeType(other) || wrapped.IsAssignableRuntimeType(other) }, } @@ -329,12 +361,26 @@ func Variable(name string, t *Type) EnvOption { // One key difference with using Function() is that each FunctionDecl provided will handle dynamic // dispatch based on the type-signatures of the overloads provided which means overload resolution at // runtime is handled out of the box rather than via a custom binding for overload resolution via -// Functions(). +// Functions(): +// +// - Overloads are searched in the order they are declared +// - Dynamic dispatch for lists and maps is limited by inspection of the list and map contents +// at runtime. Empty lists and maps will result in a 'default dispatch' +// - In the event that a default dispatch occurs, the first overload provided is the one invoked +// +// If you intend to use overloads which differentiate based on the key or element type of a list or +// map, consider using a generic function instead: e.g. func(list(T)) or func(map(K, V)) as this +// will allow your implementation to determine how best to handle dispatch and the default behavior +// for empty lists and maps whose contents cannot be inspected. +// +// For functions which use parameterized opaque types (abstract types), consider using a singleton +// function which is capable of inspecting the contents of the type and resolving the appropriate +// overload as CEL can only make inferences by type-name regarding such types. func Function(name string, opts ...FunctionOpt) EnvOption { return func(e *Env) (*Env, error) { fn := &functionDecl{ name: name, - overloads: []*o{}, + overloads: []*overloadDecl{}, options: opts, } err := fn.init() @@ -446,12 +492,12 @@ func MemberOverload(overloadID string, args []*Type, resultType *Type, opts ...O } // OverloadOpt is a functional option for configuring a function overload. -type OverloadOpt func(*o) (*o, error) +type OverloadOpt func(*overloadDecl) (*overloadDecl, error) // UnaryBinding provides the implementation of a unary overload. The provided function is protected by a runtime // type-guard which ensures runtime type agreement between the overload signature and runtime argument types. func UnaryBinding(binding functions.UnaryOp) OverloadOpt { - return func(o *o) (*o, error) { + return func(o *overloadDecl) (*overloadDecl, error) { if o.hasBinding() { return nil, fmt.Errorf("overload already has a binding: %s", o.id) } @@ -466,7 +512,7 @@ func UnaryBinding(binding functions.UnaryOp) OverloadOpt { // BinaryBinding provides the implementation of a binary overload. The provided function is protected by a runtime // type-guard which ensures runtime type agreement between the overload signature and runtime argument types. func BinaryBinding(binding functions.BinaryOp) OverloadOpt { - return func(o *o) (*o, error) { + return func(o *overloadDecl) (*overloadDecl, error) { if o.hasBinding() { return nil, fmt.Errorf("overload already has a binding: %s", o.id) } @@ -481,7 +527,7 @@ func BinaryBinding(binding functions.BinaryOp) OverloadOpt { // FunctionBinding provides the implementation of a variadic overload. The provided function is protected by a runtime // type-guard which ensures runtime type agreement between the overload signature and runtime argument types. func FunctionBinding(binding functions.FunctionOp) OverloadOpt { - return func(o *o) (*o, error) { + return func(o *overloadDecl) (*overloadDecl, error) { if o.hasBinding() { return nil, fmt.Errorf("overload already has a binding: %s", o.id) } @@ -494,7 +540,7 @@ func FunctionBinding(binding functions.FunctionOp) OverloadOpt { // // Note: do not use this option unless absoluately necessary as it should be an uncommon feature. func OverloadIsNonStrict() OverloadOpt { - return func(o *o) (*o, error) { + return func(o *overloadDecl) (*overloadDecl, error) { o.nonStrict = true return o, nil } @@ -503,7 +549,7 @@ func OverloadIsNonStrict() OverloadOpt { // OverloadOperandTrait configures a set of traits which the first argument to the overload must implement in order to be // successfully invoked. func OverloadOperandTrait(trait int) OverloadOpt { - return func(o *o) (*o, error) { + return func(o *overloadDecl) (*overloadDecl, error) { o.operandTrait = trait return o, nil } @@ -511,7 +557,7 @@ func OverloadOperandTrait(trait int) OverloadOpt { type functionDecl struct { name string - overloads []*o + overloads []*overloadDecl options []FunctionOpt singleton *functions.Overload initialized bool @@ -640,7 +686,7 @@ func (f *functionDecl) merge(other *functionDecl) (*functionDecl, error) { } merged := &functionDecl{ name: f.name, - overloads: make([]*o, len(f.overloads)), + overloads: make([]*overloadDecl, len(f.overloads)), options: []FunctionOpt{}, initialized: true, singleton: f.singleton, @@ -664,7 +710,7 @@ func (f *functionDecl) merge(other *functionDecl) (*functionDecl, error) { // addOverload ensures that the new overload does not collide with an existing overload signature; // however, if the function signatures are identical, the implementation may be rewritten as its // difficult to compare functions by object identity. -func (f *functionDecl) addOverload(overload *o) error { +func (f *functionDecl) addOverload(overload *overloadDecl) error { for index, o := range f.overloads { if o.id != overload.id && o.signatureOverlaps(overload) { return fmt.Errorf("overload signature collision in function %s: %s collides with %s", f.name, o.id, overload.id) @@ -692,8 +738,8 @@ func noSuchOverload(funcName string, args ...ref.Val) ref.Val { return types.NewErr("no such overload: %s(%s)", funcName, signature) } -// o contains all of the relevant information regarding a specific function overload. -type o struct { +// overloadDecl contains all of the relevant information regarding a specific function overload. +type overloadDecl struct { id string argTypes []*Type resultType *Type @@ -709,12 +755,12 @@ type o struct { operandTrait int } -func (o *o) hasBinding() bool { +func (o *overloadDecl) hasBinding() bool { return o.unaryOp != nil || o.binaryOp != nil || o.functionOp != nil } // guardedUnaryOp creates an invocation guard around the provided unary operator, if one is defined. -func (o *o) guardedUnaryOp(funcName string) functions.UnaryOp { +func (o *overloadDecl) guardedUnaryOp(funcName string) functions.UnaryOp { if o.unaryOp == nil { return nil } @@ -727,7 +773,7 @@ func (o *o) guardedUnaryOp(funcName string) functions.UnaryOp { } // guardedBinaryOp creates an invocation guard around the provided binary operator, if one is defined. -func (o *o) guardedBinaryOp(funcName string) functions.BinaryOp { +func (o *overloadDecl) guardedBinaryOp(funcName string) functions.BinaryOp { if o.binaryOp == nil { return nil } @@ -740,7 +786,7 @@ func (o *o) guardedBinaryOp(funcName string) functions.BinaryOp { } // guardedFunctionOp creates an invocation guard around the provided variadic function binding, if one is provided. -func (o *o) guardedFunctionOp(funcName string) functions.FunctionOp { +func (o *overloadDecl) guardedFunctionOp(funcName string) functions.FunctionOp { if o.functionOp == nil { return nil } @@ -753,27 +799,27 @@ func (o *o) guardedFunctionOp(funcName string) functions.FunctionOp { } // matchesRuntimeUnarySignature indicates whether the argument type is runtime assiganble to the overload's expected argument. -func (o *o) matchesRuntimeUnarySignature(arg ref.Val) bool { +func (o *overloadDecl) matchesRuntimeUnarySignature(arg ref.Val) bool { if o.nonStrict && types.IsUnknownOrError(arg) { return true } - return o.argTypes[0].IsAssignableRuntimeType(arg.Type()) && (o.operandTrait == 0 || arg.Type().HasTrait(o.operandTrait)) + return o.argTypes[0].IsAssignableRuntimeType(arg) && (o.operandTrait == 0 || arg.Type().HasTrait(o.operandTrait)) } // matchesRuntimeBinarySignature indicates whether the argument types are runtime assiganble to the overload's expected arguments. -func (o *o) matchesRuntimeBinarySignature(arg1, arg2 ref.Val) bool { +func (o *overloadDecl) matchesRuntimeBinarySignature(arg1, arg2 ref.Val) bool { if o.nonStrict { if types.IsUnknownOrError(arg1) { - return types.IsUnknownOrError(arg2) || o.argTypes[1].IsAssignableRuntimeType(arg2.Type()) + return types.IsUnknownOrError(arg2) || o.argTypes[1].IsAssignableRuntimeType(arg2) } - } else if !o.argTypes[1].IsAssignableRuntimeType(arg2.Type()) { + } else if !o.argTypes[1].IsAssignableRuntimeType(arg2) { return false } - return o.argTypes[0].IsAssignableRuntimeType(arg1.Type()) && (o.operandTrait == 0 || arg1.Type().HasTrait(o.operandTrait)) + return o.argTypes[0].IsAssignableRuntimeType(arg1) && (o.operandTrait == 0 || arg1.Type().HasTrait(o.operandTrait)) } // matchesRuntimeSignature indicates whether the argument types are runtime assiganble to the overload's expected arguments. -func (o *o) matchesRuntimeSignature(args ...ref.Val) bool { +func (o *overloadDecl) matchesRuntimeSignature(args ...ref.Val) bool { if len(args) != len(o.argTypes) { return false } @@ -785,7 +831,7 @@ func (o *o) matchesRuntimeSignature(args ...ref.Val) bool { if o.nonStrict && types.IsUnknownOrError(arg) { continue } - allArgsMatch = allArgsMatch && o.argTypes[i].IsAssignableRuntimeType(arg.Type()) + allArgsMatch = allArgsMatch && o.argTypes[i].IsAssignableRuntimeType(arg) } arg := args[0] @@ -795,7 +841,7 @@ func (o *o) matchesRuntimeSignature(args ...ref.Val) bool { // signatureEquals indicates whether one overload has an identical signature to another overload. // // Providing a duplicate signature is not an issue, but an overloapping signature is problematic. -func (o *o) signatureEquals(other *o) bool { +func (o *overloadDecl) signatureEquals(other *overloadDecl) bool { if o.id != other.id || o.memberFunction != other.memberFunction || len(o.argTypes) != len(other.argTypes) { return false } @@ -811,7 +857,7 @@ func (o *o) signatureEquals(other *o) bool { // signatureOverlaps indicates whether one overload has an overlapping signature with another overload. // // The 'other' overload must first be checked for equality before determining whether it overlaps in order to be completely accurate. -func (o *o) signatureOverlaps(other *o) bool { +func (o *overloadDecl) signatureOverlaps(other *overloadDecl) bool { if o.memberFunction != other.memberFunction || len(o.argTypes) != len(other.argTypes) { return false } @@ -827,7 +873,7 @@ func (o *o) signatureOverlaps(other *o) bool { func newOverload(overloadID string, memberFunction bool, args []*Type, resultType *Type, opts ...OverloadOpt) FunctionOpt { return func(f *functionDecl) (*functionDecl, error) { - overload := &o{ + overload := &overloadDecl{ id: overloadID, argTypes: args, resultType: resultType, diff --git a/cel/decls_test.go b/cel/decls_test.go index 655b0ed9..d1cadd72 100644 --- a/cel/decls_test.go +++ b/cel/decls_test.go @@ -20,6 +20,7 @@ import ( "reflect" "strings" "testing" + "time" "github.com/google/cel-go/checker/decls" "github.com/google/cel-go/common/operators" @@ -659,14 +660,27 @@ func TestIsAssignableType(t *testing.T) { } func TestIsAssignableRuntimeType(t *testing.T) { - if !NullableType(DoubleType).IsAssignableRuntimeType(types.NullType) { + if !NullableType(DoubleType).IsAssignableRuntimeType(types.NullValue) { t.Error("nullable double cannot be assigned from null") } - if !NullableType(DoubleType).IsAssignableRuntimeType(types.DoubleType) { + if !NullableType(DoubleType).IsAssignableRuntimeType(types.Double(0.0)) { t.Error("nullable double cannot be assigned from double") } - if !MapType(StringType, DurationType).IsAssignableRuntimeType(types.MapType) { - t.Error("map(string, duration) not assibale to map at runtime") + if !MapType(StringType, DurationType).IsAssignableRuntimeType( + types.DefaultTypeAdapter.NativeToValue(map[string]time.Duration{})) { + t.Error("map(string, duration) not assignable to map at runtime") + } + if !MapType(StringType, DurationType).IsAssignableRuntimeType( + types.DefaultTypeAdapter.NativeToValue(map[string]time.Duration{"one": time.Duration(1)})) { + t.Error("map(string, duration) not assignable to map at runtime") + } + if !MapType(StringType, DynType).IsAssignableRuntimeType( + types.DefaultTypeAdapter.NativeToValue(map[string]time.Duration{"one": time.Duration(1)})) { + t.Error("map(string, dyn) not assignable to map at runtime") + } + if MapType(StringType, DynType).IsAssignableRuntimeType( + types.DefaultTypeAdapter.NativeToValue(map[int64]time.Duration{1: time.Duration(1)})) { + t.Error("map(string, dyn) must not be assignable to map(int, duration) at runtime") } }