Skip to content

Commit

Permalink
Merge pull request #3340 from onflow/supun/port-220
Browse files Browse the repository at this point in the history
Use parent to determine whether an expression is target of an invocation
  • Loading branch information
SupunS committed May 14, 2024
2 parents 282c93c + a0e71bd commit 49c7f8f
Show file tree
Hide file tree
Showing 27 changed files with 146 additions and 71 deletions.
2 changes: 1 addition & 1 deletion runtime/repl.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ func (r *REPL) Accept(code []byte, eval bool) (inputIsComplete bool, err error)
var expressionType sema.Type
expressionStatement, isExpression := statement.(*ast.ExpressionStatement)
if isExpression {
expressionType = r.checker.VisitExpression(expressionStatement.Expression, nil)
expressionType = r.checker.VisitExpression(expressionStatement.Expression, expressionStatement, nil)
if !eval && expressionType != sema.InvalidType {
r.onExpressionType(expressionType)
}
Expand Down
18 changes: 9 additions & 9 deletions runtime/sema/check_array_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package sema

import "github.com/onflow/cadence/runtime/ast"

func (checker *Checker) VisitArrayExpression(expression *ast.ArrayExpression) Type {
func (checker *Checker) VisitArrayExpression(arrayExpression *ast.ArrayExpression) Type {

// visit all elements, ensure they are all the same type

Expand All @@ -29,7 +29,7 @@ func (checker *Checker) VisitArrayExpression(expression *ast.ArrayExpression) Ty
var elementType Type
var resultType ArrayType

elementCount := len(expression.Values)
elementCount := len(arrayExpression.Values)

switch typ := expectedType.(type) {

Expand All @@ -43,7 +43,7 @@ func (checker *Checker) VisitArrayExpression(expression *ast.ArrayExpression) Ty
&ConstantSizedArrayLiteralSizeError{
ExpectedSize: typ.Size,
ActualSize: literalCount,
Range: expression.Range,
Range: arrayExpression.Range,
},
)
}
Expand All @@ -68,13 +68,13 @@ func (checker *Checker) VisitArrayExpression(expression *ast.ArrayExpression) Ty
if elementCount > 0 {
argumentTypes = make([]Type, elementCount)

for i, value := range expression.Values {
valueType := checker.VisitExpression(value, elementType)
for i, element := range arrayExpression.Values {
valueType := checker.VisitExpression(element, arrayExpression, elementType)

argumentTypes[i] = valueType

checker.checkVariableMove(value)
checker.checkResourceMoveOperation(value, valueType)
checker.checkVariableMove(element)
checker.checkResourceMoveOperation(element, valueType)
}
}

Expand All @@ -87,7 +87,7 @@ func (checker *Checker) VisitArrayExpression(expression *ast.ArrayExpression) Ty
checker.report(
&TypeAnnotationRequiredError{
Cause: "cannot infer type from array literal:",
Pos: expression.StartPos,
Pos: arrayExpression.StartPos,
},
)

Expand All @@ -100,7 +100,7 @@ func (checker *Checker) VisitArrayExpression(expression *ast.ArrayExpression) Ty
}

checker.Elaboration.SetArrayExpressionTypes(
expression,
arrayExpression,
ArrayExpressionTypes{
ArgumentTypes: argumentTypes,
ArrayType: resultType,
Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/check_assignment.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (checker *Checker) checkAssignment(
if checker.accessedSelfMember(target) == nil {
checkValue = checker.VisitExpressionWithReferenceCheck
}
valueType = checkValue(value, targetType)
valueType = checkValue(value, assignment, targetType)

// NOTE: Visiting the `value` checks the compatibility between value and target types.
// Check for the *target* type, so that assignment using non-resource typed value (e.g. `nil`)
Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/check_attach_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (checker *Checker) VisitAttachExpression(expression *ast.AttachExpression)
attachment := expression.Attachment
baseExpression := expression.Base

baseType := checker.VisitExpression(baseExpression, checker.expectedType)
baseType := checker.VisitExpression(baseExpression, expression, checker.expectedType)
attachmentType := checker.checkInvocationExpression(attachment)

if attachmentType.IsInvalidType() || baseType.IsInvalidType() {
Expand Down
21 changes: 18 additions & 3 deletions runtime/sema/check_binary_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,12 @@ func (checker *Checker) VisitBinaryExpression(expression *ast.BinaryExpression)
// Visit the expression, with contextually expected type. Use the expected type
// only for inferring wherever possible, but do not check for compatibility.
// Compatibility is checked separately for each operand kind.
leftType = checker.VisitExpressionWithForceType(expression.Left, expectedType, false)
leftType = checker.VisitExpressionWithForceType(
expression.Left,
expression,
expectedType,
false,
)

leftIsInvalid := leftType.IsInvalidType()

Expand Down Expand Up @@ -123,7 +128,12 @@ func (checker *Checker) VisitBinaryExpression(expression *ast.BinaryExpression)
expectedType = leftType
}

rightType = checker.VisitExpressionWithForceType(expression.Right, expectedType, false)
rightType = checker.VisitExpressionWithForceType(
expression.Right,
expression,
expectedType,
false,
)

rightIsInvalid := rightType.IsInvalidType()

Expand Down Expand Up @@ -174,7 +184,12 @@ func (checker *Checker) VisitBinaryExpression(expression *ast.BinaryExpression)
expectedType = optionalLeftType.Type
}
}
return checker.VisitExpressionWithForceType(expression.Right, expectedType, false)
return checker.VisitExpressionWithForceType(
expression.Right,
expression,
expectedType,
false,
)
})

rightIsInvalid := rightType.IsInvalidType()
Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/check_casting_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (checker *Checker) VisitCastingExpression(expression *ast.CastingExpression

beforeErrors := len(checker.errors)

leftHandType, exprActualType := checker.visitExpression(leftHandExpression, expectedType)
leftHandType, exprActualType := checker.visitExpression(leftHandExpression, expression, expectedType)

hasErrors := len(checker.errors) > beforeErrors

Expand Down
12 changes: 10 additions & 2 deletions runtime/sema/check_composite_declaration.go
Original file line number Diff line number Diff line change
Expand Up @@ -2087,6 +2087,7 @@ func (checker *Checker) checkDefaultDestroyParamExpressionKind(

func (checker *Checker) checkDefaultDestroyEventParam(
param Parameter,
eventDeclaration ast.CompositeLikeDeclaration,
astParam *ast.Parameter,
containerType EntitlementSupportingType,
containerDeclaration ast.Declaration,
Expand All @@ -2113,7 +2114,8 @@ func (checker *Checker) checkDefaultDestroyEventParam(
compositeContainer,
compositeContainer.baseTypeDocString)
}
param.DefaultArgument = checker.VisitExpression(paramDefaultArgument, paramType)

param.DefaultArgument = checker.VisitExpression(paramDefaultArgument, eventDeclaration, paramType)

// default events must have default arguments for all their parameters; this is enforced in the parser
// we want to check that these arguments are all either literals or field accesses, and have primitive types
Expand Down Expand Up @@ -2143,7 +2145,13 @@ func (checker *Checker) checkDefaultDestroyEvent(
defer checker.leaveValueScope(eventDeclaration.EndPosition, true)

for index, param := range eventType.ConstructorParameters {
checker.checkDefaultDestroyEventParam(param, constructorFunctionParameters[index], containerType, containerDeclaration)
checker.checkDefaultDestroyEventParam(
param,
eventDeclaration,
constructorFunctionParameters[index],
containerType,
containerDeclaration,
)
}
}

Expand Down
8 changes: 4 additions & 4 deletions runtime/sema/check_conditional.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (checker *Checker) VisitIfStatement(statement *ast.IfStatement) (_ struct{}

switch test := statement.Test.(type) {
case ast.Expression:
checker.VisitExpression(test, BoolType)
checker.VisitExpression(test, statement, BoolType)

checker.checkConditionalBranches(
func() Type {
Expand Down Expand Up @@ -90,14 +90,14 @@ func (checker *Checker) VisitConditionalExpression(expression *ast.ConditionalEx

expectedType := checker.expectedType

checker.VisitExpression(expression.Test, BoolType)
checker.VisitExpression(expression.Test, expression, BoolType)

thenType, elseType := checker.checkConditionalBranches(
func() Type {
return checker.VisitExpression(expression.Then, expectedType)
return checker.VisitExpression(expression.Then, expression, expectedType)
},
func() Type {
return checker.VisitExpression(expression.Else, expectedType)
return checker.VisitExpression(expression.Else, expression, expectedType)
},
)

Expand Down
4 changes: 2 additions & 2 deletions runtime/sema/check_conditions.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ func (checker *Checker) checkCondition(condition ast.Condition) {
case *ast.TestCondition:

// check test expression is boolean
checker.VisitExpression(condition.Test, BoolType)
checker.VisitExpression(condition.Test, condition, BoolType)

// check message expression results in a string
if condition.Message != nil {
checker.VisitExpression(condition.Message, StringType)
checker.VisitExpression(condition.Message, condition, StringType)
}

case *ast.EmitCondition:
Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/check_create_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (checker *Checker) VisitCreateExpression(expression *ast.CreateExpression)

invocation := expression.InvocationExpression

ty := checker.VisitExpression(invocation, nil)
ty := checker.VisitExpression(invocation, expression, nil)

if ty.IsInvalidType() {
return ty
Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/check_destroy_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
func (checker *Checker) VisitDestroyExpression(expression *ast.DestroyExpression) (resultType Type) {
resultType = VoidType

valueType := checker.VisitExpression(expression.Expression, nil)
valueType := checker.VisitExpression(expression.Expression, expression, nil)

checker.ObserveImpureOperation(expression)
checker.recordResourceInvalidation(
Expand Down
4 changes: 2 additions & 2 deletions runtime/sema/check_dictionary_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ func (checker *Checker) VisitDictionaryExpression(expression *ast.DictionaryExpr
// NOTE: important to check move after each type check,
// not combined after both type checks!

entryKeyType := checker.VisitExpression(entry.Key, keyType)
entryKeyType := checker.VisitExpression(entry.Key, expression, keyType)
checker.checkVariableMove(entry.Key)
checker.checkResourceMoveOperation(entry.Key, entryKeyType)

entryValueType := checker.VisitExpression(entry.Value, valueType)
entryValueType := checker.VisitExpression(entry.Value, expression, valueType)
checker.checkVariableMove(entry.Value)
checker.checkResourceMoveOperation(entry.Value, entryValueType)

Expand Down
5 changes: 3 additions & 2 deletions runtime/sema/check_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func (checker *Checker) checkResourceVariableCapturingInFunction(variable *Varia
func (checker *Checker) VisitExpressionStatement(statement *ast.ExpressionStatement) (_ struct{}) {
expression := statement.Expression

ty := checker.VisitExpression(expression, nil)
ty := checker.VisitExpression(expression, statement, nil)

if ty.IsResourceType() {
checker.report(
Expand Down Expand Up @@ -270,7 +270,7 @@ func (checker *Checker) visitIndexExpression(
) Type {

targetExpression := indexExpression.TargetExpression
targetType := checker.VisitExpression(targetExpression, nil)
targetType := checker.VisitExpression(targetExpression, indexExpression, nil)

// NOTE: check indexed type first for UX reasons

Expand Down Expand Up @@ -309,6 +309,7 @@ func (checker *Checker) visitIndexExpression(
}
indexingType := checker.VisitExpression(
indexExpression.IndexingExpression,
indexExpression,
valueIndexedType.IndexingType(),
)

Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/check_for.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (checker *Checker) VisitForStatement(statement *ast.ForStatement) (_ struct
}
}

valueType := checker.VisitExpression(valueExpression, expectedType)
valueType := checker.VisitExpression(valueExpression, statement, expectedType)

// Only get the element type if the array is not a resource array.
// Otherwise, in addition to the `UnsupportedResourceForLoopError`,
Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/check_force_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (checker *Checker) VisitForceExpression(expression *ast.ForceExpression) Ty
// i.e: if `x!` is `String`, then `x` is expected to be `String?`.
expectedType := wrapWithOptionalIfNotNil(checker.expectedType)

valueType := checker.VisitExpression(expression.Expression, expectedType)
valueType := checker.VisitExpression(expression.Expression, expression, expectedType)

if valueType.IsInvalidType() {
return valueType
Expand Down
16 changes: 8 additions & 8 deletions runtime/sema/check_invocation_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (checker *Checker) checkInvocationExpression(invocationExpression *ast.Invo
// check the invoked expression can be invoked

invokedExpression := invocationExpression.InvokedExpression
expressionType := checker.VisitExpression(invokedExpression, nil)
expressionType := checker.VisitExpression(invokedExpression, invocationExpression, nil)

// `inInvocation` should be reset before visiting arguments
checker.inInvocation = false
Expand Down Expand Up @@ -131,7 +131,7 @@ func (checker *Checker) checkInvocationExpression(invocationExpression *ast.Invo
argumentTypes = make([]Type, 0, argumentCount)

for _, argument := range invocationExpression.Arguments {
argumentType := checker.VisitExpression(argument.Expression, nil)
argumentType := checker.VisitExpression(argument.Expression, invocationExpression, nil)
argumentTypes = append(argumentTypes, argumentType)
}

Expand Down Expand Up @@ -469,7 +469,7 @@ func (checker *Checker) checkInvocation(

parameterTypes[argumentIndex] =
checker.checkInvocationRequiredArgument(
invocationExpression.Arguments,
invocationExpression,
argumentIndex,
functionType,
argumentTypes,
Expand All @@ -482,7 +482,7 @@ func (checker *Checker) checkInvocation(
for i := minCount; i < argumentCount; i++ {
argument := invocationExpression.Arguments[i]
// TODO: pass the expected type to support type inferring for parameters
argumentTypes[i] = checker.VisitExpression(argument.Expression, nil)
argumentTypes[i] = checker.VisitExpression(argument.Expression, invocationExpression, nil)
}
}

Expand Down Expand Up @@ -571,15 +571,15 @@ func (checker *Checker) checkTypeParameterInference(
}

func (checker *Checker) checkInvocationRequiredArgument(
arguments ast.Arguments,
invocationExpression *ast.InvocationExpression,
argumentIndex int,
functionType *FunctionType,
argumentTypes []Type,
typeParameters *TypeParameterTypeOrderedMap,
) (
parameterType Type,
) {
argument := arguments[argumentIndex]
argument := invocationExpression.Arguments[argumentIndex]

parameter := functionType.Parameters[argumentIndex]
parameterType = parameter.TypeAnnotation.Type
Expand Down Expand Up @@ -637,7 +637,7 @@ func (checker *Checker) checkInvocationRequiredArgument(
expectedType = nil
}

argumentType = checker.VisitExpression(argument.Expression, expectedType)
argumentType = checker.VisitExpression(argument.Expression, invocationExpression, expectedType)

// If we did not pass an expected type,
// we must manually check that the argument type and the parameter type are compatible.
Expand All @@ -659,7 +659,7 @@ func (checker *Checker) checkInvocationRequiredArgument(
// We will then have to manually check that the argument type is compatible
// with the parameter type (see below).

argumentType = checker.VisitExpression(argument.Expression, nil)
argumentType = checker.VisitExpression(argument.Expression, invocationExpression, nil)

// Try to unify the parameter type with the argument type.
// If unification fails, fall back to the parameter type for now.
Expand Down

0 comments on commit 49c7f8f

Please sign in to comment.