Skip to content

Commit

Permalink
Use proto safe traversal accessors within cel-go (#570)
Browse files Browse the repository at this point in the history
  • Loading branch information
TristonianJones committed Aug 3, 2022
1 parent 52567e2 commit 25bb4c6
Show file tree
Hide file tree
Showing 12 changed files with 155 additions and 157 deletions.
65 changes: 30 additions & 35 deletions checker/checker.go
Expand Up @@ -80,10 +80,10 @@ func (c *checker) check(e *exprpb.Expr) {
return
}

switch e.ExprKind.(type) {
switch e.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
literal := e.GetConstExpr()
switch literal.ConstantKind.(type) {
switch literal.GetConstantKind().(type) {
case *exprpb.Constant_BoolValue:
c.checkBoolLiteral(e)
case *exprpb.Constant_BytesValue:
Expand Down Expand Up @@ -149,8 +149,8 @@ func (c *checker) checkIdent(e *exprpb.Expr) {
identExpr := e.GetIdentExpr()
// Check to see if the identifier is declared.
if ident := c.env.LookupIdent(identExpr.GetName()); ident != nil {
c.setType(e, ident.GetIdent().Type)
c.setReference(e, newIdentReference(ident.GetName(), ident.GetIdent().Value))
c.setType(e, ident.GetIdent().GetType())
c.setReference(e, newIdentReference(ident.GetName(), ident.GetIdent().GetValue()))
// Overwrite the identifier with its fully qualified name.
identExpr.Name = ident.GetName()
return
Expand Down Expand Up @@ -186,23 +186,20 @@ func (c *checker) checkSelect(e *exprpb.Expr) {
}

// Interpret as field selection, first traversing down the operand.
c.check(sel.Operand)
targetType := substitute(c.mappings, c.getType(sel.Operand), false)
c.check(sel.GetOperand())
targetType := substitute(c.mappings, c.getType(sel.GetOperand()), false)
// Assume error type by default as most types do not support field selection.
resultType := decls.Error
switch kindOf(targetType) {
case kindMap:
// Maps yield their value type as the selection result type.
mapType := targetType.GetMapType()
resultType = mapType.ValueType
resultType = mapType.GetValueType()
case kindObject:
// Objects yield their field type declaration as the selection result type, but only if
// the field is defined.
messageType := targetType
if fieldType, found := c.lookupFieldType(
c.location(e),
messageType.GetMessageType(),
sel.Field); found {
if fieldType, found := c.lookupFieldType(c.location(e), messageType.GetMessageType(), sel.GetField()); found {
resultType = fieldType.Type
}
case kindTypeParam:
Expand Down Expand Up @@ -320,15 +317,15 @@ func (c *checker) resolveOverload(

var resultType *exprpb.Type
var checkedRef *exprpb.Reference
for _, overload := range fn.GetFunction().Overloads {
for _, overload := range fn.GetFunction().GetOverloads() {
// Determine whether the overload is currently considered.
if c.env.isOverloadDisabled(overload.GetOverloadId()) {
continue
}

// Ensure the call style for the overload matches.
if (target == nil && overload.IsInstanceFunction) ||
(target != nil && !overload.IsInstanceFunction) {
if (target == nil && overload.GetIsInstanceFunction()) ||
(target != nil && !overload.GetIsInstanceFunction()) {
// not a compatible call style.
continue
}
Expand All @@ -348,13 +345,11 @@ func (c *checker) resolveOverload(
if checkedRef == nil {
checkedRef = newFunctionReference(overload.GetOverloadId())
} else {
checkedRef.OverloadId = append(checkedRef.OverloadId, overload.GetOverloadId())
checkedRef.OverloadId = append(checkedRef.GetOverloadId(), overload.GetOverloadId())
}

// First matching overload, determines result type.
fnResultType := substitute(c.mappings,
overloadType.GetFunction().GetResultType(),
false)
fnResultType := substitute(c.mappings, overloadType.GetFunction().GetResultType(), false)
if resultType == nil {
resultType = fnResultType
} else if !isDyn(resultType) && !proto.Equal(fnResultType, resultType) {
Expand All @@ -375,7 +370,7 @@ func (c *checker) resolveOverload(
func (c *checker) checkCreateList(e *exprpb.Expr) {
create := e.GetListExpr()
var elemType *exprpb.Type
for _, e := range create.Elements {
for _, e := range create.GetElements() {
c.check(e)
elemType = c.joinTypes(c.location(e), elemType, c.getType(e))
}
Expand All @@ -388,7 +383,7 @@ func (c *checker) checkCreateList(e *exprpb.Expr) {

func (c *checker) checkCreateStruct(e *exprpb.Expr) {
str := e.GetStructExpr()
if str.MessageName != "" {
if str.GetMessageName() != "" {
c.checkCreateMessage(e)
} else {
c.checkCreateMap(e)
Expand Down Expand Up @@ -419,22 +414,22 @@ func (c *checker) checkCreateMessage(e *exprpb.Expr) {
msgVal := e.GetStructExpr()
// Determine the type of the message.
messageType := decls.Error
decl := c.env.LookupIdent(msgVal.MessageName)
decl := c.env.LookupIdent(msgVal.GetMessageName())
if decl == nil {
c.errors.undeclaredReference(
c.location(e), c.env.container.Name(), msgVal.MessageName)
c.location(e), c.env.container.Name(), msgVal.GetMessageName())
return
}
// Ensure the type name is fully qualified in the AST.
msgVal.MessageName = decl.GetName()
c.setReference(e, newIdentReference(decl.GetName(), nil))
ident := decl.GetIdent()
identKind := kindOf(ident.Type)
identKind := kindOf(ident.GetType())
if identKind != kindError {
if identKind != kindType {
c.errors.notAType(c.location(e), ident.Type)
c.errors.notAType(c.location(e), ident.GetType())
} else {
messageType = ident.Type.GetType()
messageType = ident.GetType().GetType()
if kindOf(messageType) != kindObject {
c.errors.notAMessageType(c.location(e), messageType)
messageType = decls.Error
Expand All @@ -450,12 +445,12 @@ func (c *checker) checkCreateMessage(e *exprpb.Expr) {
// Check the field initializers.
for _, ent := range msgVal.GetEntries() {
field := ent.GetFieldKey()
value := ent.Value
value := ent.GetValue()
c.check(value)

fieldType := decls.Error
if t, found := c.lookupFieldType(
c.locationByID(ent.Id),
c.locationByID(ent.GetId()),
messageType.GetMessageType(),
field); found {
fieldType = t.Type
Expand Down Expand Up @@ -576,25 +571,25 @@ func (c *checker) lookupFieldType(l common.Location, messageType string, fieldNa
}

func (c *checker) setType(e *exprpb.Expr, t *exprpb.Type) {
if old, found := c.types[e.Id]; found && !proto.Equal(old, t) {
if old, found := c.types[e.GetId()]; found && !proto.Equal(old, t) {
c.errors.ReportError(c.location(e),
"(Incompatible) Type already exists for expression: %v(%d) old:%v, new:%v", e, e.GetId(), old, t)
return
}
c.types[e.Id] = t
c.types[e.GetId()] = t
}

func (c *checker) getType(e *exprpb.Expr) *exprpb.Type {
return c.types[e.Id]
return c.types[e.GetId()]
}

func (c *checker) setReference(e *exprpb.Expr, r *exprpb.Reference) {
if old, found := c.references[e.Id]; found && !proto.Equal(old, r) {
if old, found := c.references[e.GetId()]; found && !proto.Equal(old, r) {
c.errors.ReportError(c.location(e),
"Reference already exists for expression: %v(%d) old:%v, new:%v", e, e.Id, old, r)
"Reference already exists for expression: %v(%d) old:%v, new:%v", e, e.GetId(), old, r)
return
}
c.references[e.Id] = r
c.references[e.GetId()] = r
}

func (c *checker) assertType(e *exprpb.Expr, t *exprpb.Type) {
Expand All @@ -616,15 +611,15 @@ func newResolution(checkedRef *exprpb.Reference, t *exprpb.Type) *overloadResolu
}

func (c *checker) location(e *exprpb.Expr) common.Location {
return c.locationByID(e.Id)
return c.locationByID(e.GetId())
}

func (c *checker) locationByID(id int64) common.Location {
positions := c.sourceInfo.GetPositions()
var line = 1
if offset, found := positions[id]; found {
col := int(offset)
for _, lineOffset := range c.sourceInfo.LineOffsets {
for _, lineOffset := range c.sourceInfo.GetLineOffsets() {
if lineOffset < offset {
line++
col = int(offset - lineOffset)
Expand Down
16 changes: 8 additions & 8 deletions checker/cost.go
Expand Up @@ -88,9 +88,9 @@ func (e astNode) ComputedSize() *SizeEstimate {
return e.derivedSize
}
var v uint64
switch ek := e.expr.ExprKind.(type) {
switch ek := e.expr.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
switch ck := ek.ConstExpr.ConstantKind.(type) {
switch ck := ek.ConstExpr.GetConstantKind().(type) {
case *exprpb.Constant_StringValue:
v = uint64(len(ck.StringValue))
case *exprpb.Constant_BytesValue:
Expand All @@ -103,10 +103,10 @@ func (e astNode) ComputedSize() *SizeEstimate {
return nil
}
case *exprpb.Expr_ListExpr:
v = uint64(len(ek.ListExpr.Elements))
v = uint64(len(ek.ListExpr.GetElements()))
case *exprpb.Expr_StructExpr:
if ek.StructExpr.MessageName == "" {
v = uint64(len(ek.StructExpr.Entries))
if ek.StructExpr.GetMessageName() == "" {
v = uint64(len(ek.StructExpr.GetEntries()))
}
default:
return nil
Expand Down Expand Up @@ -297,7 +297,7 @@ func (c *coster) cost(e *exprpb.Expr) CostEstimate {
return CostEstimate{}
}
var cost CostEstimate
switch e.ExprKind.(type) {
switch e.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
cost = constCost
case *exprpb.Expr_IdentExpr:
Expand All @@ -323,7 +323,7 @@ func (c *coster) costIdent(e *exprpb.Expr) CostEstimate {

// build and track the field path
if iterRange, ok := c.iterRanges.peek(identExpr.GetName()); ok {
switch c.checkedExpr.TypeMap[iterRange].TypeKind.(type) {
switch c.checkedExpr.TypeMap[iterRange].GetTypeKind().(type) {
case *exprpb.Type_ListType_:
c.addPath(e, append(c.exprPath[iterRange], "@items"))
case *exprpb.Type_MapType_:
Expand All @@ -350,7 +350,7 @@ func (c *coster) costSelect(e *exprpb.Expr) CostEstimate {
}

// build and track the field path
c.addPath(e, append(c.getPath(sel.GetOperand()), sel.Field))
c.addPath(e, append(c.getPath(sel.GetOperand()), sel.GetField()))

return sum
}
Expand Down
8 changes: 4 additions & 4 deletions checker/printer.go
Expand Up @@ -32,22 +32,22 @@ func (a *semanticAdorner) GetMetadata(elem interface{}) string {
if !isExpr {
return result
}
t := a.checks.TypeMap[e.Id]
t := a.checks.TypeMap[e.GetId()]
if t != nil {
result += "~"
result += FormatCheckedType(t)
}

switch e.ExprKind.(type) {
switch e.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr,
*exprpb.Expr_CallExpr,
*exprpb.Expr_StructExpr,
*exprpb.Expr_SelectExpr:
if ref, found := a.checks.ReferenceMap[e.Id]; found {
if ref, found := a.checks.ReferenceMap[e.GetId()]; found {
if len(ref.GetOverloadId()) == 0 {
result += "^" + ref.Name
} else {
for i, overload := range ref.OverloadId {
for i, overload := range ref.GetOverloadId() {
if i == 0 {
result += "^"
} else {
Expand Down
36 changes: 16 additions & 20 deletions checker/types.go
Expand Up @@ -52,13 +52,13 @@ func FormatCheckedType(t *exprpb.Type) string {
t.GetFunction().GetArgTypes(),
false)
case kindList:
return fmt.Sprintf("list(%s)", FormatCheckedType(t.GetListType().ElemType))
return fmt.Sprintf("list(%s)", FormatCheckedType(t.GetListType().GetElemType()))
case kindObject:
return t.GetMessageType()
case kindMap:
return fmt.Sprintf("map(%s, %s)",
FormatCheckedType(t.GetMapType().KeyType),
FormatCheckedType(t.GetMapType().ValueType))
FormatCheckedType(t.GetMapType().GetKeyType()),
FormatCheckedType(t.GetMapType().GetValueType()))
case kindNull:
return "null"
case kindPrimitive:
Expand Down Expand Up @@ -152,20 +152,20 @@ func isEqualOrLessSpecific(t1 *exprpb.Type, t2 *exprpb.Type) bool {
}
return true
case kindList:
return isEqualOrLessSpecific(t1.GetListType().ElemType, t2.GetListType().ElemType)
return isEqualOrLessSpecific(t1.GetListType().GetElemType(), t2.GetListType().GetElemType())
case kindMap:
m1 := t1.GetMapType()
m2 := t2.GetMapType()
return isEqualOrLessSpecific(m1.KeyType, m2.KeyType) &&
isEqualOrLessSpecific(m1.ValueType, m2.ValueType)
return isEqualOrLessSpecific(m1.GetKeyType(), m2.GetKeyType()) &&
isEqualOrLessSpecific(m1.GetValueType(), m2.GetValueType())
case kindType:
return true
default:
return proto.Equal(t1, t2)
}
}

/// internalIsAssignable returns true if t1 is assignable to t2.
// / internalIsAssignable returns true if t1 is assignable to t2.
func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool {
// Process type parameters.
kind1, kind2 := kindOf(t1), kindOf(t2)
Expand Down Expand Up @@ -272,18 +272,14 @@ func isValidTypeSubstitution(m *mapping, t1, t2 *exprpb.Type) (valid, hasSub boo

// internalIsAssignableAbstractType returns true if the abstract type names agree and all type
// parameters are assignable.
func internalIsAssignableAbstractType(m *mapping,
a1 *exprpb.Type_AbstractType,
a2 *exprpb.Type_AbstractType) bool {
func internalIsAssignableAbstractType(m *mapping, a1 *exprpb.Type_AbstractType, a2 *exprpb.Type_AbstractType) bool {
return a1.GetName() == a2.GetName() &&
internalIsAssignableList(m, a1.GetParameterTypes(), a2.GetParameterTypes())
}

// internalIsAssignableFunction returns true if the function return type and arg types are
// assignable.
func internalIsAssignableFunction(m *mapping,
f1 *exprpb.Type_FunctionType,
f2 *exprpb.Type_FunctionType) bool {
func internalIsAssignableFunction(m *mapping, f1 *exprpb.Type_FunctionType, f2 *exprpb.Type_FunctionType) bool {
f1ArgTypes := flattenFunctionTypes(f1)
f2ArgTypes := flattenFunctionTypes(f2)
if internalIsAssignableList(m, f1ArgTypes, f2ArgTypes) {
Expand Down Expand Up @@ -363,7 +359,7 @@ func kindOf(t *exprpb.Type) int {
if t == nil || t.TypeKind == nil {
return kindUnknown
}
switch t.TypeKind.(type) {
switch t.GetTypeKind().(type) {
case *exprpb.Type_Error:
return kindError
case *exprpb.Type_Function:
Expand Down Expand Up @@ -425,10 +421,10 @@ func notReferencedIn(m *mapping, t *exprpb.Type, withinType *exprpb.Type) bool {
}
return true
case kindList:
return notReferencedIn(m, t, withinType.GetListType().ElemType)
return notReferencedIn(m, t, withinType.GetListType().GetElemType())
case kindMap:
mt := withinType.GetMapType()
return notReferencedIn(m, t, mt.KeyType) && notReferencedIn(m, t, mt.ValueType)
return notReferencedIn(m, t, mt.GetKeyType()) && notReferencedIn(m, t, mt.GetValueType())
case kindWrapper:
return notReferencedIn(m, t, decls.NewPrimitiveType(withinType.GetWrapper()))
default:
Expand Down Expand Up @@ -457,17 +453,17 @@ func substitute(m *mapping, t *exprpb.Type, typeParamToDyn bool) *exprpb.Type {
case kindFunction:
fn := t.GetFunction()
rt := substitute(m, fn.ResultType, typeParamToDyn)
args := make([]*exprpb.Type, len(fn.ArgTypes))
args := make([]*exprpb.Type, len(fn.GetArgTypes()))
for i, a := range fn.ArgTypes {
args[i] = substitute(m, a, typeParamToDyn)
}
return decls.NewFunctionType(rt, args...)
case kindList:
return decls.NewListType(substitute(m, t.GetListType().ElemType, typeParamToDyn))
return decls.NewListType(substitute(m, t.GetListType().GetElemType(), typeParamToDyn))
case kindMap:
mt := t.GetMapType()
return decls.NewMapType(substitute(m, mt.KeyType, typeParamToDyn),
substitute(m, mt.ValueType, typeParamToDyn))
return decls.NewMapType(substitute(m, mt.GetKeyType(), typeParamToDyn),
substitute(m, mt.GetValueType(), typeParamToDyn))
case kindType:
if t.GetType() != nil {
return decls.NewTypeType(substitute(m, t.GetType(), typeParamToDyn))
Expand Down

0 comments on commit 25bb4c6

Please sign in to comment.