Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use proto safe traversal accessors #570

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
65 changes: 30 additions & 35 deletions checker/checker.go
Expand Up @@ -80,10 +80,10 @@ func (c *checker) check(e *exprpb.Expr) {
return
}

switch e.ExprKind.(type) {
switch e.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
literal := e.GetConstExpr()
switch literal.ConstantKind.(type) {
switch literal.GetConstantKind().(type) {
case *exprpb.Constant_BoolValue:
c.checkBoolLiteral(e)
case *exprpb.Constant_BytesValue:
Expand Down Expand Up @@ -149,8 +149,8 @@ func (c *checker) checkIdent(e *exprpb.Expr) {
identExpr := e.GetIdentExpr()
// Check to see if the identifier is declared.
if ident := c.env.LookupIdent(identExpr.GetName()); ident != nil {
c.setType(e, ident.GetIdent().Type)
c.setReference(e, newIdentReference(ident.GetName(), ident.GetIdent().Value))
c.setType(e, ident.GetIdent().GetType())
c.setReference(e, newIdentReference(ident.GetName(), ident.GetIdent().GetValue()))
// Overwrite the identifier with its fully qualified name.
identExpr.Name = ident.GetName()
return
Expand Down Expand Up @@ -186,23 +186,20 @@ func (c *checker) checkSelect(e *exprpb.Expr) {
}

// Interpret as field selection, first traversing down the operand.
c.check(sel.Operand)
targetType := substitute(c.mappings, c.getType(sel.Operand), false)
c.check(sel.GetOperand())
targetType := substitute(c.mappings, c.getType(sel.GetOperand()), false)
// Assume error type by default as most types do not support field selection.
resultType := decls.Error
switch kindOf(targetType) {
case kindMap:
// Maps yield their value type as the selection result type.
mapType := targetType.GetMapType()
resultType = mapType.ValueType
resultType = mapType.GetValueType()
case kindObject:
// Objects yield their field type declaration as the selection result type, but only if
// the field is defined.
messageType := targetType
if fieldType, found := c.lookupFieldType(
c.location(e),
messageType.GetMessageType(),
sel.Field); found {
if fieldType, found := c.lookupFieldType(c.location(e), messageType.GetMessageType(), sel.GetField()); found {
resultType = fieldType.Type
}
case kindTypeParam:
Expand Down Expand Up @@ -320,15 +317,15 @@ func (c *checker) resolveOverload(

var resultType *exprpb.Type
var checkedRef *exprpb.Reference
for _, overload := range fn.GetFunction().Overloads {
for _, overload := range fn.GetFunction().GetOverloads() {
// Determine whether the overload is currently considered.
if c.env.isOverloadDisabled(overload.GetOverloadId()) {
continue
}

// Ensure the call style for the overload matches.
if (target == nil && overload.IsInstanceFunction) ||
(target != nil && !overload.IsInstanceFunction) {
if (target == nil && overload.GetIsInstanceFunction()) ||
(target != nil && !overload.GetIsInstanceFunction()) {
// not a compatible call style.
continue
}
Expand All @@ -348,13 +345,11 @@ func (c *checker) resolveOverload(
if checkedRef == nil {
checkedRef = newFunctionReference(overload.GetOverloadId())
} else {
checkedRef.OverloadId = append(checkedRef.OverloadId, overload.GetOverloadId())
checkedRef.OverloadId = append(checkedRef.GetOverloadId(), overload.GetOverloadId())
}

// First matching overload, determines result type.
fnResultType := substitute(c.mappings,
overloadType.GetFunction().GetResultType(),
false)
fnResultType := substitute(c.mappings, overloadType.GetFunction().GetResultType(), false)
if resultType == nil {
resultType = fnResultType
} else if !isDyn(resultType) && !proto.Equal(fnResultType, resultType) {
Expand All @@ -375,7 +370,7 @@ func (c *checker) resolveOverload(
func (c *checker) checkCreateList(e *exprpb.Expr) {
create := e.GetListExpr()
var elemType *exprpb.Type
for _, e := range create.Elements {
for _, e := range create.GetElements() {
c.check(e)
elemType = c.joinTypes(c.location(e), elemType, c.getType(e))
}
Expand All @@ -388,7 +383,7 @@ func (c *checker) checkCreateList(e *exprpb.Expr) {

func (c *checker) checkCreateStruct(e *exprpb.Expr) {
str := e.GetStructExpr()
if str.MessageName != "" {
if str.GetMessageName() != "" {
c.checkCreateMessage(e)
} else {
c.checkCreateMap(e)
Expand Down Expand Up @@ -419,22 +414,22 @@ func (c *checker) checkCreateMessage(e *exprpb.Expr) {
msgVal := e.GetStructExpr()
// Determine the type of the message.
messageType := decls.Error
decl := c.env.LookupIdent(msgVal.MessageName)
decl := c.env.LookupIdent(msgVal.GetMessageName())
if decl == nil {
c.errors.undeclaredReference(
c.location(e), c.env.container.Name(), msgVal.MessageName)
c.location(e), c.env.container.Name(), msgVal.GetMessageName())
return
}
// Ensure the type name is fully qualified in the AST.
msgVal.MessageName = decl.GetName()
c.setReference(e, newIdentReference(decl.GetName(), nil))
ident := decl.GetIdent()
identKind := kindOf(ident.Type)
identKind := kindOf(ident.GetType())
if identKind != kindError {
if identKind != kindType {
c.errors.notAType(c.location(e), ident.Type)
c.errors.notAType(c.location(e), ident.GetType())
} else {
messageType = ident.Type.GetType()
messageType = ident.GetType().GetType()
if kindOf(messageType) != kindObject {
c.errors.notAMessageType(c.location(e), messageType)
messageType = decls.Error
Expand All @@ -450,12 +445,12 @@ func (c *checker) checkCreateMessage(e *exprpb.Expr) {
// Check the field initializers.
for _, ent := range msgVal.GetEntries() {
field := ent.GetFieldKey()
value := ent.Value
value := ent.GetValue()
c.check(value)

fieldType := decls.Error
if t, found := c.lookupFieldType(
c.locationByID(ent.Id),
c.locationByID(ent.GetId()),
messageType.GetMessageType(),
field); found {
fieldType = t.Type
Expand Down Expand Up @@ -576,25 +571,25 @@ func (c *checker) lookupFieldType(l common.Location, messageType string, fieldNa
}

func (c *checker) setType(e *exprpb.Expr, t *exprpb.Type) {
if old, found := c.types[e.Id]; found && !proto.Equal(old, t) {
if old, found := c.types[e.GetId()]; found && !proto.Equal(old, t) {
c.errors.ReportError(c.location(e),
"(Incompatible) Type already exists for expression: %v(%d) old:%v, new:%v", e, e.GetId(), old, t)
return
}
c.types[e.Id] = t
c.types[e.GetId()] = t
}

func (c *checker) getType(e *exprpb.Expr) *exprpb.Type {
return c.types[e.Id]
return c.types[e.GetId()]
}

func (c *checker) setReference(e *exprpb.Expr, r *exprpb.Reference) {
if old, found := c.references[e.Id]; found && !proto.Equal(old, r) {
if old, found := c.references[e.GetId()]; found && !proto.Equal(old, r) {
c.errors.ReportError(c.location(e),
"Reference already exists for expression: %v(%d) old:%v, new:%v", e, e.Id, old, r)
"Reference already exists for expression: %v(%d) old:%v, new:%v", e, e.GetId(), old, r)
return
}
c.references[e.Id] = r
c.references[e.GetId()] = r
}

func (c *checker) assertType(e *exprpb.Expr, t *exprpb.Type) {
Expand All @@ -616,15 +611,15 @@ func newResolution(checkedRef *exprpb.Reference, t *exprpb.Type) *overloadResolu
}

func (c *checker) location(e *exprpb.Expr) common.Location {
return c.locationByID(e.Id)
return c.locationByID(e.GetId())
}

func (c *checker) locationByID(id int64) common.Location {
positions := c.sourceInfo.GetPositions()
var line = 1
if offset, found := positions[id]; found {
col := int(offset)
for _, lineOffset := range c.sourceInfo.LineOffsets {
for _, lineOffset := range c.sourceInfo.GetLineOffsets() {
if lineOffset < offset {
line++
col = int(offset - lineOffset)
Expand Down
16 changes: 8 additions & 8 deletions checker/cost.go
Expand Up @@ -88,9 +88,9 @@ func (e astNode) ComputedSize() *SizeEstimate {
return e.derivedSize
}
var v uint64
switch ek := e.expr.ExprKind.(type) {
switch ek := e.expr.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
switch ck := ek.ConstExpr.ConstantKind.(type) {
switch ck := ek.ConstExpr.GetConstantKind().(type) {
case *exprpb.Constant_StringValue:
v = uint64(len(ck.StringValue))
case *exprpb.Constant_BytesValue:
Expand All @@ -103,10 +103,10 @@ func (e astNode) ComputedSize() *SizeEstimate {
return nil
}
case *exprpb.Expr_ListExpr:
v = uint64(len(ek.ListExpr.Elements))
v = uint64(len(ek.ListExpr.GetElements()))
case *exprpb.Expr_StructExpr:
if ek.StructExpr.MessageName == "" {
v = uint64(len(ek.StructExpr.Entries))
if ek.StructExpr.GetMessageName() == "" {
v = uint64(len(ek.StructExpr.GetEntries()))
}
default:
return nil
Expand Down Expand Up @@ -297,7 +297,7 @@ func (c *coster) cost(e *exprpb.Expr) CostEstimate {
return CostEstimate{}
}
var cost CostEstimate
switch e.ExprKind.(type) {
switch e.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
cost = constCost
case *exprpb.Expr_IdentExpr:
Expand All @@ -323,7 +323,7 @@ func (c *coster) costIdent(e *exprpb.Expr) CostEstimate {

// build and track the field path
if iterRange, ok := c.iterRanges.peek(identExpr.GetName()); ok {
switch c.checkedExpr.TypeMap[iterRange].TypeKind.(type) {
switch c.checkedExpr.TypeMap[iterRange].GetTypeKind().(type) {
case *exprpb.Type_ListType_:
c.addPath(e, append(c.exprPath[iterRange], "@items"))
case *exprpb.Type_MapType_:
Expand All @@ -350,7 +350,7 @@ func (c *coster) costSelect(e *exprpb.Expr) CostEstimate {
}

// build and track the field path
c.addPath(e, append(c.getPath(sel.GetOperand()), sel.Field))
c.addPath(e, append(c.getPath(sel.GetOperand()), sel.GetField()))

return sum
}
Expand Down
8 changes: 4 additions & 4 deletions checker/printer.go
Expand Up @@ -32,22 +32,22 @@ func (a *semanticAdorner) GetMetadata(elem interface{}) string {
if !isExpr {
return result
}
t := a.checks.TypeMap[e.Id]
t := a.checks.TypeMap[e.GetId()]
if t != nil {
result += "~"
result += FormatCheckedType(t)
}

switch e.ExprKind.(type) {
switch e.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr,
*exprpb.Expr_CallExpr,
*exprpb.Expr_StructExpr,
*exprpb.Expr_SelectExpr:
if ref, found := a.checks.ReferenceMap[e.Id]; found {
if ref, found := a.checks.ReferenceMap[e.GetId()]; found {
if len(ref.GetOverloadId()) == 0 {
result += "^" + ref.Name
} else {
for i, overload := range ref.OverloadId {
for i, overload := range ref.GetOverloadId() {
if i == 0 {
result += "^"
} else {
Expand Down
36 changes: 16 additions & 20 deletions checker/types.go
Expand Up @@ -52,13 +52,13 @@ func FormatCheckedType(t *exprpb.Type) string {
t.GetFunction().GetArgTypes(),
false)
case kindList:
return fmt.Sprintf("list(%s)", FormatCheckedType(t.GetListType().ElemType))
return fmt.Sprintf("list(%s)", FormatCheckedType(t.GetListType().GetElemType()))
case kindObject:
return t.GetMessageType()
case kindMap:
return fmt.Sprintf("map(%s, %s)",
FormatCheckedType(t.GetMapType().KeyType),
FormatCheckedType(t.GetMapType().ValueType))
FormatCheckedType(t.GetMapType().GetKeyType()),
FormatCheckedType(t.GetMapType().GetValueType()))
case kindNull:
return "null"
case kindPrimitive:
Expand Down Expand Up @@ -152,20 +152,20 @@ func isEqualOrLessSpecific(t1 *exprpb.Type, t2 *exprpb.Type) bool {
}
return true
case kindList:
return isEqualOrLessSpecific(t1.GetListType().ElemType, t2.GetListType().ElemType)
return isEqualOrLessSpecific(t1.GetListType().GetElemType(), t2.GetListType().GetElemType())
case kindMap:
m1 := t1.GetMapType()
m2 := t2.GetMapType()
return isEqualOrLessSpecific(m1.KeyType, m2.KeyType) &&
isEqualOrLessSpecific(m1.ValueType, m2.ValueType)
return isEqualOrLessSpecific(m1.GetKeyType(), m2.GetKeyType()) &&
isEqualOrLessSpecific(m1.GetValueType(), m2.GetValueType())
case kindType:
return true
default:
return proto.Equal(t1, t2)
}
}

/// internalIsAssignable returns true if t1 is assignable to t2.
// / internalIsAssignable returns true if t1 is assignable to t2.
func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool {
// Process type parameters.
kind1, kind2 := kindOf(t1), kindOf(t2)
Expand Down Expand Up @@ -272,18 +272,14 @@ func isValidTypeSubstitution(m *mapping, t1, t2 *exprpb.Type) (valid, hasSub boo

// internalIsAssignableAbstractType returns true if the abstract type names agree and all type
// parameters are assignable.
func internalIsAssignableAbstractType(m *mapping,
a1 *exprpb.Type_AbstractType,
a2 *exprpb.Type_AbstractType) bool {
func internalIsAssignableAbstractType(m *mapping, a1 *exprpb.Type_AbstractType, a2 *exprpb.Type_AbstractType) bool {
return a1.GetName() == a2.GetName() &&
internalIsAssignableList(m, a1.GetParameterTypes(), a2.GetParameterTypes())
}

// internalIsAssignableFunction returns true if the function return type and arg types are
// assignable.
func internalIsAssignableFunction(m *mapping,
f1 *exprpb.Type_FunctionType,
f2 *exprpb.Type_FunctionType) bool {
func internalIsAssignableFunction(m *mapping, f1 *exprpb.Type_FunctionType, f2 *exprpb.Type_FunctionType) bool {
f1ArgTypes := flattenFunctionTypes(f1)
f2ArgTypes := flattenFunctionTypes(f2)
if internalIsAssignableList(m, f1ArgTypes, f2ArgTypes) {
Expand Down Expand Up @@ -363,7 +359,7 @@ func kindOf(t *exprpb.Type) int {
if t == nil || t.TypeKind == nil {
return kindUnknown
}
switch t.TypeKind.(type) {
switch t.GetTypeKind().(type) {
case *exprpb.Type_Error:
return kindError
case *exprpb.Type_Function:
Expand Down Expand Up @@ -425,10 +421,10 @@ func notReferencedIn(m *mapping, t *exprpb.Type, withinType *exprpb.Type) bool {
}
return true
case kindList:
return notReferencedIn(m, t, withinType.GetListType().ElemType)
return notReferencedIn(m, t, withinType.GetListType().GetElemType())
case kindMap:
mt := withinType.GetMapType()
return notReferencedIn(m, t, mt.KeyType) && notReferencedIn(m, t, mt.ValueType)
return notReferencedIn(m, t, mt.GetKeyType()) && notReferencedIn(m, t, mt.GetValueType())
case kindWrapper:
return notReferencedIn(m, t, decls.NewPrimitiveType(withinType.GetWrapper()))
default:
Expand Down Expand Up @@ -457,17 +453,17 @@ func substitute(m *mapping, t *exprpb.Type, typeParamToDyn bool) *exprpb.Type {
case kindFunction:
fn := t.GetFunction()
rt := substitute(m, fn.ResultType, typeParamToDyn)
args := make([]*exprpb.Type, len(fn.ArgTypes))
args := make([]*exprpb.Type, len(fn.GetArgTypes()))
for i, a := range fn.ArgTypes {
args[i] = substitute(m, a, typeParamToDyn)
}
return decls.NewFunctionType(rt, args...)
case kindList:
return decls.NewListType(substitute(m, t.GetListType().ElemType, typeParamToDyn))
return decls.NewListType(substitute(m, t.GetListType().GetElemType(), typeParamToDyn))
case kindMap:
mt := t.GetMapType()
return decls.NewMapType(substitute(m, mt.KeyType, typeParamToDyn),
substitute(m, mt.ValueType, typeParamToDyn))
return decls.NewMapType(substitute(m, mt.GetKeyType(), typeParamToDyn),
substitute(m, mt.GetValueType(), typeParamToDyn))
case kindType:
if t.GetType() != nil {
return decls.NewTypeType(substitute(m, t.GetType(), typeParamToDyn))
Expand Down