Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure overloads are searched in the order they are declared #566

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
205 changes: 172 additions & 33 deletions cel/cel_test.go
Expand Up @@ -674,7 +674,10 @@ func TestGlobalVars(t *testing.T) {
t.Run("attrs_alt", func(t *testing.T) {
vars := map[string]interface{}{
"attrs": map[string]interface{}{"second": "yep"}}
out, _, _ := prg.Eval(vars)
out, _, err := prg.Eval(vars)
if err != nil {
t.Fatalf("prg.Eval(vars) failed: %v", err)
}
if out.Equal(types.String("yep")) != types.True {
t.Errorf("got '%v', expected 'yep'.", out.Value())
}
Expand Down Expand Up @@ -1657,7 +1660,7 @@ func TestDefaultUTCTimeZone(t *testing.T) {
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}
ast, iss := env.Compile(`
out, err := interpret(t, env, `
x.getFullYear() == 1970
&& x.getMonth() == 0
&& x.getDayOfYear() == 0
Expand Down Expand Up @@ -1687,16 +1690,10 @@ func TestDefaultUTCTimeZone(t *testing.T) {
&& x.getHours('23:15') == 1
&& x.getMinutes('23:15') == 20
&& x.getSeconds('23:15') == 6
&& x.getMilliseconds('23:15') == 1
`)
if iss.Err() != nil {
t.Fatalf("env.Compile() failed: %v", iss.Err())
}
prg, err := env.Program(ast)
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
out, _, err := prg.Eval(map[string]interface{}{"x": time.Unix(7506, 1000000).Local()})
&& x.getMilliseconds('23:15') == 1`,
map[string]interface{}{
"x": time.Unix(7506, 1000000).Local(),
})
if err != nil {
t.Fatalf("prg.Eval() failed: %v", err)
}
Expand All @@ -1718,20 +1715,12 @@ func TestDefaultUTCTimeZoneExtension(t *testing.T) {
if err != nil {
t.Fatalf("env.Extend() failed: %v", err)
}
ast, iss := env.Compile(`
out, err := interpret(t, env, `
x.getFullYear() == 1970
&& y.getHours() == 2
&& y.getMinutes() == 120
&& y.getSeconds() == 7235
&& y.getMilliseconds() == 7235000`)
if iss.Err() != nil {
t.Fatalf("env.Compile() failed: %v", iss.Err())
}
prg, err := env.Program(ast)
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
out, _, err := prg.Eval(
&& y.getMilliseconds() == 7235000`,
map[string]interface{}{
"x": time.Unix(7506, 1000000).Local(),
"y": time.Duration(7235) * time.Second,
Expand All @@ -1750,7 +1739,7 @@ func TestDefaultUTCTimeZoneError(t *testing.T) {
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}
ast, iss := env.Compile(`
out, err := interpret(t, env, `
x.getFullYear(':xx') == 1969
|| x.getDayOfYear('xx:') == 364
|| x.getMonth('Am/Ph') == 11
Expand All @@ -1761,30 +1750,180 @@ func TestDefaultUTCTimeZoneError(t *testing.T) {
|| x.getMinutes('Am/Ph') == 5
|| x.getSeconds('Am/Ph') == 6
|| x.getMilliseconds('Am/Ph') == 1
`)
if iss.Err() != nil {
t.Fatalf("env.Compile() failed: %v", iss.Err())
`, map[string]interface{}{
"x": time.Unix(7506, 1000000).Local(),
},
)
if err == nil {
t.Fatalf("prg.Eval() got %v wanted error", out)
}
prg, err := env.Program(ast)
}

func TestDynamicDispatch(t *testing.T) {
env, err := NewEnv(
HomogeneousAggregateLiterals(),
Function("first",
MemberOverload("first_list_int", []*Type{ListType(IntType)}, IntType,
UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
if l.Size() == types.IntZero {
return types.IntZero
}
return l.Get(types.IntZero)
}),
),
MemberOverload("first_list_double", []*Type{ListType(DoubleType)}, DoubleType,
UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
if l.Size() == types.IntZero {
return types.Double(0.0)
}
return l.Get(types.IntZero)
}),
),
MemberOverload("first_list_string", []*Type{ListType(StringType)}, StringType,
UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
if l.Size() == types.IntZero {
return types.String("")
}
return l.Get(types.IntZero)
}),
),
MemberOverload("first_list_list_string", []*Type{ListType(ListType(StringType))}, ListType(StringType),
UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
if l.Size() == types.IntZero {
return types.DefaultTypeAdapter.NativeToValue([]string{})
}
return l.Get(types.IntZero)
}),
),
),
)
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
t.Fatalf("NewEnv() failed: %v", err)
}
out, _, err := prg.Eval(map[string]interface{}{"x": time.Unix(7506, 1000000).Local()})
if err == nil {
t.Fatalf("prg.Eval() got %v wanted error", out)
out, err := interpret(t, env, `
[].first() == 0
&& [1, 2].first() == 1
&& [1.0, 2.0].first() == 1.0
&& ["hello", "world"].first() == "hello"
&& [["hello"], ["world", "!"]].first().first() == "hello"
&& [[], ["empty"]].first().first() == ""
&& dyn([1, 2]).first() == 1
&& dyn([1.0, 2.0]).first() == 1.0
&& dyn(["hello", "world"]).first() == "hello"
&& dyn([["hello"], ["world", "!"]]).first().first() == "hello"
`, map[string]interface{}{},
)
if err != nil {
t.Fatalf("prg.Eval() failed: %v", err)
}
if out != types.True {
t.Fatalf("prg.Eval() got %v wanted true", out)
}
}

func BenchmarkDynamicDispatch(b *testing.B) {
env, err := NewEnv(
HomogeneousAggregateLiterals(),
Function("first",
MemberOverload("first_list_int", []*Type{ListType(IntType)}, IntType,
UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
if l.Size() == types.IntZero {
return types.IntZero
}
return l.Get(types.IntZero)
}),
),
MemberOverload("first_list_double", []*Type{ListType(DoubleType)}, DoubleType,
UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
if l.Size() == types.IntZero {
return types.Double(0.0)
}
return l.Get(types.IntZero)
}),
),
MemberOverload("first_list_string", []*Type{ListType(StringType)}, StringType,
UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
if l.Size() == types.IntZero {
return types.String("")
}
return l.Get(types.IntZero)
}),
),
MemberOverload("first_list_list_string", []*Type{ListType(ListType(StringType))}, ListType(StringType),
UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
if l.Size() == types.IntZero {
return types.DefaultTypeAdapter.NativeToValue([]string{})
}
return l.Get(types.IntZero)
}),
),
),
)
if err != nil {
b.Fatalf("NewEnv() failed: %v", err)
}
prg := compile(b, env, `
[].first() == 0
&& [1, 2].first() == 1
&& [1.0, 2.0].first() == 1.0
&& ["hello", "world"].first() == "hello"
&& [["hello"], ["world", "!"]].first().first() == "hello"`)
prgDyn := compile(b, env, `
dyn([]).first() == 0
&& dyn([1, 2]).first() == 1
&& dyn([1.0, 2.0]).first() == 1.0
&& dyn(["hello", "world"]).first() == "hello"
&& dyn([["hello"], ["world", "!"]]).first().first() == "hello"`)
b.ResetTimer()
b.Run("DirectDispatch", func(b *testing.B) {
for i := 0; i < b.N; i++ {
prg.Eval(NoVars())
}
})
b.ResetTimer()
b.Run("DynamicDispatch", func(b *testing.B) {
for i := 0; i < b.N; i++ {
prgDyn.Eval(NoVars())
}
})
}

func interpret(t *testing.T, env *Env, expr string, vars interface{}) (ref.Val, error) {
func compile(t testing.TB, env *Env, expr string) Program {
t.Helper()
prg, err := compileOrError(t, env, expr)
if err != nil {
t.Fatal(err)
}
return prg
}

func compileOrError(t testing.TB, env *Env, expr string) (Program, error) {
t.Helper()
ast, iss := env.Compile(expr)
if iss.Err() != nil {
return nil, fmt.Errorf("env.Compile(%s) failed: %v", expr, iss.Err())
}
prg, err := env.Program(ast)
prg, err := env.Program(ast, EvalOptions(OptOptimize))
if err != nil {
return nil, fmt.Errorf("env.Program() failed: %v", err)
}
return prg, nil
}

func interpret(t testing.TB, env *Env, expr string, vars interface{}) (ref.Val, error) {
t.Helper()
prg, err := compileOrError(t, env, expr)
if err != nil {
return nil, err
}
out, _, err := prg.Eval(vars)
if err != nil {
return nil, fmt.Errorf("prg.Eval(%v) failed: %v", vars, err)
Expand Down