From f66340904a2fad821efead3b3b5d155cb41fe596 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Thu, 5 Jan 2023 14:32:20 -0800 Subject: [PATCH] Fix nested CallExpr cost estimation (#571) (#624) * Fix nested CallExpr cost estimation. * Add TODO about merging isScalar/ComputedSize. Co-authored-by: Kermit Alexander II --- checker/cost.go | 26 +++++++++++++++++++ checker/cost_test.go | 61 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+) diff --git a/checker/cost.go b/checker/cost.go index 4e1ea5aa..6f832d96 100644 --- a/checker/cost.go +++ b/checker/cost.go @@ -476,6 +476,15 @@ func (c *coster) sizeEstimate(t AstNode) SizeEstimate { if l := c.estimator.EstimateSize(t); l != nil { return *l } + // return an estimate of 1 for return types of set + // lengths, since strings/bytes/more complex objects could be of + // variable length + if isScalar(t.Type()) { + // TODO: since the logic for size estimation is split between + // ComputedSize and isScalar, changing one will likely require changing + // the other, so they should be merged in the future if possible + return SizeEstimate{Min: 1, Max: 1} + } return SizeEstimate{Min: 0, Max: math.MaxUint64} } @@ -599,3 +608,20 @@ func (c *coster) newAstNode(e *exprpb.Expr) *astNode { } return &astNode{path: path, t: c.getType(e), expr: e, derivedSize: derivedSize} } + +// isScalar returns true if the given type is known to be of a constant size at +// compile time. isScalar will return false for strings (they are variable-width) +// in addition to protobuf.Any and protobuf.Value (their size is not knowable at compile time). +func isScalar(t *exprpb.Type) bool { + switch kindOf(t) { + case kindPrimitive: + if t.GetPrimitive() != exprpb.Type_STRING && t.GetPrimitive() != exprpb.Type_BYTES { + return true + } + case kindWellKnown: + if t.GetWellKnown() == exprpb.Type_DURATION || t.GetWellKnown() == exprpb.Type_TIMESTAMP { + return true + } + } + return false +} diff --git a/checker/cost_test.go b/checker/cost_test.go index e2118f87..4e8cb9b4 100644 --- a/checker/cost_test.go +++ b/checker/cost_test.go @@ -354,6 +354,67 @@ func TestCost(t *testing.T) { hints: map[string]int64{"str1": 10, "str2": 10}, wanted: CostEstimate{Min: 2, Max: 6}, }, + { + name: "list size comparison", + expr: `list1.size() == list2.size()`, + decls: []*exprpb.Decl{ + decls.NewVar("list1", decls.NewListType(decls.Int)), + decls.NewVar("list2", decls.NewListType(decls.Int)), + }, + wanted: CostEstimate{Min: 5, Max: 5}, + }, + { + name: "list size from ternary", + expr: `x > y ? list1.size() : list2.size()`, + decls: []*exprpb.Decl{ + decls.NewVar("x", decls.Int), + decls.NewVar("y", decls.Int), + decls.NewVar("list1", decls.NewListType(decls.Int)), + decls.NewVar("list2", decls.NewListType(decls.Int)), + }, + wanted: CostEstimate{Min: 5, Max: 5}, + }, + { + name: "str endsWith equality", + expr: `str1.endsWith("abcdefghijklmnopqrstuvwxyz") == str2.endsWith("abcdefghijklmnopqrstuvwxyz")`, + decls: []*exprpb.Decl{ + decls.NewVar("str1", decls.String), + decls.NewVar("str2", decls.String), + }, + wanted: CostEstimate{Min: 9, Max: 9}, + }, + { + name: "nested subexpression operators", + expr: `((5 != 6) == (1 == 2)) == ((3 <= 4) == (9 != 9))`, + wanted: CostEstimate{Min: 7, Max: 7}, + }, + { + name: "str size estimate", + expr: `string(timestamp1) == string(timestamp2)`, + decls: []*exprpb.Decl{ + decls.NewVar("timestamp1", decls.Timestamp), + decls.NewVar("timestamp2", decls.Timestamp), + }, + wanted: CostEstimate{Min: 5, Max: 1844674407370955268}, + }, + { + name: "timestamp equality check", + expr: `timestamp1 == timestamp2`, + decls: []*exprpb.Decl{ + decls.NewVar("timestamp1", decls.Timestamp), + decls.NewVar("timestamp2", decls.Timestamp), + }, + wanted: CostEstimate{Min: 3, Max: 3}, + }, + { + name: "duration inequality check", + expr: `duration1 != duration2`, + decls: []*exprpb.Decl{ + decls.NewVar("duration1", decls.Duration), + decls.NewVar("duration2", decls.Duration), + }, + wanted: CostEstimate{Min: 3, Max: 3}, + }, } for _, tc := range cases {