Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix nested CallExpr cost estimation #571

Merged
merged 10 commits into from Aug 17, 2022
33 changes: 33 additions & 0 deletions checker/cost.go
Expand Up @@ -476,6 +476,12 @@ func (c *coster) sizeEstimate(t AstNode) SizeEstimate {
if l := c.estimator.EstimateSize(t); l != nil {
return *l
}
// return an estimate of 1 for return types of set
// lengths, since strings/bytes/more complex objects could be of
// variable length
if isScalar(t.Type()) {
return SizeEstimate{Min: 1, Max: 1}
}
return SizeEstimate{Min: 0, Max: math.MaxUint64}
}

Expand Down Expand Up @@ -599,3 +605,30 @@ func (c *coster) newAstNode(e *exprpb.Expr) *astNode {
}
return &astNode{path: path, t: c.getType(e), expr: e, derivedSize: derivedSize}
}

func isScalar(t *exprpb.Type) bool {
DangerOnTheRanger marked this conversation as resolved.
Show resolved Hide resolved
switch kindOf(t) {
case kindPrimitive:
if t.GetPrimitive() != exprpb.Type_STRING && t.GetPrimitive() != exprpb.Type_BYTES {
DangerOnTheRanger marked this conversation as resolved.
Show resolved Hide resolved
return true
}
case kindWellKnown:
if t.GetWellKnown() == exprpb.Type_DURATION || t.GetWellKnown() == exprpb.Type_TIMESTAMP {
return true
}
case kindObject:
DangerOnTheRanger marked this conversation as resolved.
Show resolved Hide resolved
switch t.GetMessageType() {
case "google.protobuf.Duration", "google.protobuf.Timestamp",
"google.protobuf.BoolValue", "google.protobuf.BytesValue",
"google.protobuf.DoubleValue", "google.protobuf.FloatValue",
"google.protobuf.Int32Value", "google.protobuf.Int64Value",
"google.protobuf.UInt32Value", "google.protobuf.UInt64Value":
DangerOnTheRanger marked this conversation as resolved.
Show resolved Hide resolved
return true
default:
return false
}
default:
return false
}
return false
}
61 changes: 61 additions & 0 deletions checker/cost_test.go
Expand Up @@ -354,6 +354,67 @@ func TestCost(t *testing.T) {
hints: map[string]int64{"str1": 10, "str2": 10},
wanted: CostEstimate{Min: 2, Max: 6},
},
{
name: "list size comparison",
expr: `list1.size() == list2.size()`,
DangerOnTheRanger marked this conversation as resolved.
Show resolved Hide resolved
decls: []*exprpb.Decl{
decls.NewVar("list1", decls.NewListType(decls.Int)),
decls.NewVar("list2", decls.NewListType(decls.Int)),
},
wanted: CostEstimate{Min: 5, Max: 5},
},
{
name: "list size from ternary",
expr: `x > y ? list1.size() : list2.size()`,
decls: []*exprpb.Decl{
decls.NewVar("x", decls.Int),
decls.NewVar("y", decls.Int),
decls.NewVar("list1", decls.NewListType(decls.Int)),
decls.NewVar("list2", decls.NewListType(decls.Int)),
},
wanted: CostEstimate{Min: 5, Max: 5},
},
{
name: "str endsWith equality",
expr: `str1.endsWith("abcdefghijklmnopqrstuvwxyz") == str2.endsWith("abcdefghijklmnopqrstuvwxyz")`,
decls: []*exprpb.Decl{
decls.NewVar("str1", decls.String),
decls.NewVar("str2", decls.String),
},
wanted: CostEstimate{Min: 9, Max: 9},
},
{
name: "nested subexpression operators",
expr: `((5 != 6) == (1 == 2)) == ((3 <= 4) == (9 != 9))`,
wanted: CostEstimate{Min: 7, Max: 7},
},
{
name: "str size estimate",
expr: `string(timestamp1) == string(timestamp2)`,
decls: []*exprpb.Decl{
decls.NewVar("timestamp1", decls.Timestamp),
decls.NewVar("timestamp2", decls.Timestamp),
},
wanted: CostEstimate{Min: 5, Max: 1844674407370955268},
},
{
name: "timestamp equality check",
expr: `timestamp1 == timestamp2`,
decls: []*exprpb.Decl{
decls.NewVar("timestamp1", decls.Timestamp),
decls.NewVar("timestamp2", decls.Timestamp),
},
wanted: CostEstimate{Min: 3, Max: 3},
},
{
name: "duration inequality check",
expr: `duration1 != duration2`,
decls: []*exprpb.Decl{
decls.NewVar("duration1", decls.Duration),
decls.NewVar("duration2", decls.Duration),
},
wanted: CostEstimate{Min: 3, Max: 3},
},
}

for _, tc := range cases {
Expand Down