diff --git a/detekt-rules-style/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/style/CanBeNonNullable.kt b/detekt-rules-style/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/style/CanBeNonNullable.kt index 570f7ce22bd..a65ed2e7ceb 100644 --- a/detekt-rules-style/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/style/CanBeNonNullable.kt +++ b/detekt-rules-style/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/style/CanBeNonNullable.kt @@ -20,6 +20,7 @@ import org.jetbrains.kotlin.descriptors.PropertyDescriptor import org.jetbrains.kotlin.lexer.KtTokens import org.jetbrains.kotlin.name.FqName import org.jetbrains.kotlin.psi.KtBinaryExpression +import org.jetbrains.kotlin.psi.KtBlockExpression import org.jetbrains.kotlin.psi.KtCallExpression import org.jetbrains.kotlin.psi.KtClass import org.jetbrains.kotlin.psi.KtConstantExpression @@ -46,6 +47,7 @@ import org.jetbrains.kotlin.psi.KtWhenConditionWithExpression import org.jetbrains.kotlin.psi.KtWhenExpression import org.jetbrains.kotlin.psi.psiUtil.allChildren import org.jetbrains.kotlin.psi.psiUtil.collectDescendantsOfType +import org.jetbrains.kotlin.psi.psiUtil.isFirstStatement import org.jetbrains.kotlin.psi.psiUtil.isPrivate import org.jetbrains.kotlin.resolve.BindingContext import org.jetbrains.kotlin.resolve.calls.smartcasts.getKotlinTypeForComparison @@ -83,6 +85,11 @@ import org.jetbrains.kotlin.types.isNullable * println(a) * } * } + * + * fun foo(a: Int?) { + * if (a == null) return + * println(a) + * } * * * @@ -175,7 +182,7 @@ class CanBeNonNullable(config: Config = Config.empty) : Rule(config) { // the param, either via an if/when check or with a safe-qualified expression. .filter { val onlyNonNullCheck = validSingleChildExpression && it.isNonNullChecked && !it.isNullChecked - it.isNonNullForced || onlyNonNullCheck + it.isNonNullForced || it.isNullCheckReturnsUnit || onlyNonNullCheck } .forEach { nullableParam -> report( @@ -235,9 +242,24 @@ class CanBeNonNullable(config: Config = Config.empty) : Rule(config) { override fun visitIfExpression(expression: KtIfExpression) { expression.condition.evaluateCheckStatement(expression.`else`) + if (expression.isFirstStatement()) { + evaluateNullCheckReturnsUnit(expression.condition, expression.then) + } super.visitIfExpression(expression) } + private fun evaluateNullCheckReturnsUnit(condition: KtExpression?, then: KtExpression?) { + val thenExpression = if (then is KtBlockExpression) then.firstStatement else then + if (thenExpression !is KtReturnExpression) return + if (thenExpression.returnedExpression != null) return + + if (condition is KtBinaryExpression && condition.isNullCheck()) { + getDescriptor(condition.left, condition.right) + ?.let { nullableParams[it] } + ?.let { it.isNullCheckReturnsUnit = true } + } + } + override fun visitSafeQualifiedExpression(expression: KtSafeQualifiedExpression) { updateNullableParam(expression.receiverExpression) { it.isNonNullChecked = true } super.visitSafeQualifiedExpression(expression) @@ -323,15 +345,6 @@ class CanBeNonNullable(config: Config = Config.empty) : Rule(config) { val rightExpression = right val nonNullChecks = mutableListOf() - fun getDescriptor(leftExpression: KtExpression?, rightExpression: KtExpression?): CallableDescriptor? { - return when { - leftExpression is KtNameReferenceExpression -> leftExpression - rightExpression is KtNameReferenceExpression -> rightExpression - else -> null - }?.getResolvedCall(bindingContext) - ?.resultingDescriptor - } - if (isNullCheck()) { getDescriptor(leftExpression, rightExpression) ?.let { nullableParams[it] } @@ -346,6 +359,15 @@ class CanBeNonNullable(config: Config = Config.empty) : Rule(config) { return nonNullChecks } + private fun getDescriptor(leftExpression: KtExpression?, rightExpression: KtExpression?): CallableDescriptor? { + return when { + leftExpression is KtNameReferenceExpression -> leftExpression + rightExpression is KtNameReferenceExpression -> rightExpression + else -> null + }?.getResolvedCall(bindingContext) + ?.resultingDescriptor + } + private fun KtIsExpression.evaluateIsExpression(): List { val descriptor = this.leftHandSide.getResolvedCall(bindingContext)?.resultingDescriptor ?: return emptyList() @@ -417,6 +439,7 @@ class CanBeNonNullable(config: Config = Config.empty) : Rule(config) { var isNullChecked = false var isNonNullChecked = false var isNonNullForced = false + var isNullCheckReturnsUnit = false } private inner class PropertyCheckVisitor : DetektVisitor() { diff --git a/detekt-rules-style/src/test/kotlin/io/gitlab/arturbosch/detekt/rules/style/CanBeNonNullableSpec.kt b/detekt-rules-style/src/test/kotlin/io/gitlab/arturbosch/detekt/rules/style/CanBeNonNullableSpec.kt index d0e67b42caf..4b20456b0b4 100644 --- a/detekt-rules-style/src/test/kotlin/io/gitlab/arturbosch/detekt/rules/style/CanBeNonNullableSpec.kt +++ b/detekt-rules-style/src/test/kotlin/io/gitlab/arturbosch/detekt/rules/style/CanBeNonNullableSpec.kt @@ -897,6 +897,52 @@ class CanBeNonNullableSpec(val env: KotlinCoreEnvironment) { assertThat(subject.compileAndLintWithContext(env, code)).hasSize(1) } + @Test + fun `does report null-check returning unit type`() { + val code = """ + fun foo(a: Int?) { + if (a == null) return + println(a) + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).hasSize(1) + } + + @Test + fun `does report null-check returning unit type in block`() { + val code = """ + fun foo(a: Int?) { + if (a == null) { return } + println(a) + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).hasSize(1) + } + + @Test + fun `does not report guard statement with side effect ahead`() { + val code = """ + fun foo(a: Int?) { + println("side effect") + if (a == null) return + println(a) + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).hasSize(0) + } + + @Test + fun `does not report null-check returning non-unit type`() { + val code = """ + fun foo(a: Int?): Int { + if (a == null) return 0 + println(a) + return a + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).hasSize(0) + } + @Test fun `does not report when the parameter is checked on non-nullity with an else statement`() { val code = """