Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Encourage when/thenReturn over doReturn/when. #4369

Merged
merged 1 commit into from Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
@@ -0,0 +1,198 @@
/*
* Copyright 2024 The Error Prone Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.errorprone.bugpatterns;

import static com.google.common.collect.Iterables.getLast;
import static com.google.errorprone.BugPattern.SeverityLevel.WARNING;
import static com.google.errorprone.matchers.Description.NO_MATCH;
import static com.google.errorprone.matchers.Matchers.instanceMethod;
import static com.google.errorprone.matchers.Matchers.staticMethod;
import static com.google.errorprone.util.ASTHelpers.getReceiver;
import static com.google.errorprone.util.ASTHelpers.getStartPosition;
import static com.google.errorprone.util.ASTHelpers.getSymbol;
import static com.google.errorprone.util.ASTHelpers.hasAnnotation;
import static com.google.errorprone.util.ASTHelpers.isSameType;
import static java.lang.String.format;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.errorprone.BugPattern;
import com.google.errorprone.VisitorState;
import com.google.errorprone.bugpatterns.BugChecker.CompilationUnitTreeMatcher;
import com.google.errorprone.fixes.SuggestedFix;
import com.google.errorprone.fixes.SuggestedFixes;
import com.google.errorprone.matchers.Description;
import com.google.errorprone.matchers.Matcher;
import com.sun.source.tree.AssignmentTree;
import com.sun.source.tree.CompilationUnitTree;
import com.sun.source.tree.ExpressionTree;
import com.sun.source.tree.MethodInvocationTree;
import com.sun.source.tree.Tree;
import com.sun.source.tree.VariableTree;
import com.sun.source.util.TreePath;
import com.sun.source.util.TreePathScanner;
import com.sun.tools.javac.code.Symbol.VarSymbol;

/** A BugPattern; see the summary. */
@BugPattern(
severity = WARNING,
summary = "Prefer using when/thenReturn over doReturn/when for additional type safety.")
public final class MockitoDoSetup extends BugChecker implements CompilationUnitTreeMatcher {
@Override
public Description matchCompilationUnit(CompilationUnitTree tree, VisitorState state) {
ImmutableSet<VarSymbol> spies = findSpies(state);
new SuppressibleTreePathScanner<Void, Void>(state) {

@Override
public Void visitMethodInvocation(MethodInvocationTree tree, Void unused) {
handle(tree);
return super.visitMethodInvocation(tree, null);
}

private void handle(MethodInvocationTree tree) {
if (!DO_STUBBER.matches(tree, state)) {
return;
}
TreePath whenPath = getCurrentPath().getParentPath().getParentPath();
Tree whenCall = whenPath.getLeaf();
if (!(whenCall instanceof MethodInvocationTree)
|| !INSTANCE_WHEN.matches((MethodInvocationTree) whenCall, state)) {
return;
}
if (isSpy(((MethodInvocationTree) whenCall).getArguments().get(0))) {
return;
}
Tree mockedMethod = whenPath.getParentPath().getParentPath().getLeaf();

if (!(mockedMethod instanceof MethodInvocationTree)) {
return;
}
if (isSameType(
getSymbol((MethodInvocationTree) mockedMethod).getReturnType(),
state.getSymtab().voidType,
state)) {
return;
}

SuggestedFix.Builder fix = SuggestedFix.builder();
var when = SuggestedFixes.qualifyStaticImport("org.mockito.Mockito.when", fix, state);
fix.replace(((MethodInvocationTree) whenCall).getMethodSelect(), when)
.replace(state.getEndPosition(whenCall) - 1, state.getEndPosition(whenCall), "")
.postfixWith(
mockedMethod,
format(
").%s(%s)",
NAME_MAPPINGS.get(getSymbol(tree).getSimpleName().toString()),
getParameterSource(tree, state)));

state.reportMatch(describeMatch(tree, fix.build()));
}

private boolean isSpy(ExpressionTree tree) {
var symbol = getSymbol(tree);
return symbol != null
&& (spies.contains(symbol) || hasAnnotation(symbol, "org.mockito.Spy", state));
}
}.scan(state.getPath(), null);
return NO_MATCH;
}

private static String getParameterSource(MethodInvocationTree tree, VisitorState state) {
return state
.getSourceCode()
.subSequence(
getStartPosition(tree.getArguments().get(0)),
state.getEndPosition(getLast(tree.getArguments())))
.toString();
}

private static ImmutableSet<VarSymbol> findSpies(VisitorState state) {
// NOTES: This is extremely conservative in at least two ways.
// 1) We ignore an entire mock if _any_ method is mocked to throw, not just the relevant method.
// 2) We could still refactor if the thenThrow comes _after_, or if the _only_ call is
// thenThrow.
ImmutableSet.Builder<VarSymbol> spiesOrThrows = ImmutableSet.builder();
new TreePathScanner<Void, Void>() {
@Override
public Void visitVariable(VariableTree tree, Void unused) {
if (tree.getInitializer() != null && SPY.matches(tree.getInitializer(), state)) {
spiesOrThrows.add(getSymbol(tree));
}
return super.visitVariable(tree, null);
}

@Override
public Void visitMethodInvocation(MethodInvocationTree tree, Void unused) {
if (DO_THROW.matches(tree, state)) {
var whenCall = getCurrentPath().getParentPath().getParentPath().getLeaf();
if ((whenCall instanceof MethodInvocationTree)
&& INSTANCE_WHEN.matches((MethodInvocationTree) whenCall, state)) {
var whenTarget = getSymbol(((MethodInvocationTree) whenCall).getArguments().get(0));
if (whenTarget instanceof VarSymbol) {
spiesOrThrows.add((VarSymbol) whenTarget);
}
}
}
if (THEN_THROW.matches(tree, state)) {
var receiver = getReceiver(tree);
if (STATIC_WHEN.matches(receiver, state)) {
var mock = getReceiver(((MethodInvocationTree) receiver).getArguments().get(0));
var mockSymbol = getSymbol(mock);
if (mockSymbol instanceof VarSymbol) {
spiesOrThrows.add((VarSymbol) mockSymbol);
}
}
}
return super.visitMethodInvocation(tree, null);
}

@Override
public Void visitAssignment(AssignmentTree tree, Void unused) {
if (SPY.matches(tree.getExpression(), state)) {
var symbol = getSymbol(tree.getVariable());
if (symbol instanceof VarSymbol) {
spiesOrThrows.add((VarSymbol) symbol);
}
}
return super.visitAssignment(tree, null);
}
}.scan(state.getPath().getCompilationUnit(), null);
return spiesOrThrows.build();
}

private static final ImmutableMap<String, String> NAME_MAPPINGS =
ImmutableMap.of(
"doAnswer", "thenAnswer",
"doReturn", "thenReturn",
"doThrow", "thenThrow");
private static final Matcher<ExpressionTree> DO_STUBBER =
staticMethod().onClass("org.mockito.Mockito").namedAnyOf(NAME_MAPPINGS.keySet());

private static final Matcher<ExpressionTree> INSTANCE_WHEN =
instanceMethod().onDescendantOf("org.mockito.stubbing.Stubber").named("when");

private static final Matcher<ExpressionTree> SPY =
staticMethod().onClass("org.mockito.Mockito").named("spy");

private static final Matcher<ExpressionTree> DO_THROW =
staticMethod().onClass("org.mockito.Mockito").named("doThrow");

private static final Matcher<ExpressionTree> STATIC_WHEN =
staticMethod().onClass("org.mockito.Mockito").named("when");

private static final Matcher<ExpressionTree> THEN_THROW =
instanceMethod().onDescendantOf("org.mockito.stubbing.OngoingStubbing").named("thenThrow");
}
Expand Up @@ -244,6 +244,7 @@
import com.google.errorprone.bugpatterns.MixedDescriptors;
import com.google.errorprone.bugpatterns.MixedMutabilityReturnType;
import com.google.errorprone.bugpatterns.MockNotUsedInProduction;
import com.google.errorprone.bugpatterns.MockitoDoSetup;
import com.google.errorprone.bugpatterns.MockitoUsage;
import com.google.errorprone.bugpatterns.ModifiedButNotUsed;
import com.google.errorprone.bugpatterns.ModifyCollectionInEnhancedForLoop;
Expand Down Expand Up @@ -1169,6 +1170,7 @@ public static ScannerSupplier warningChecks() {
MissingBraces.class,
MissingDefault.class,
MixedArrayDimensions.class,
MockitoDoSetup.class,
MoreThanOneQualifier.class,
MultiVariableDeclaration.class,
MultipleTopLevelClasses.class,
Expand Down
@@ -0,0 +1,120 @@
/*
* Copyright 2024 The Error Prone Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.errorprone.bugpatterns;

import com.google.errorprone.BugCheckerRefactoringTestHelper;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public final class MockitoDoSetupTest {
private final BugCheckerRefactoringTestHelper helper =
BugCheckerRefactoringTestHelper.newInstance(MockitoDoSetup.class, getClass());

@Test
public void happy() {
helper
.addInputLines(
"Test.java",
"import org.mockito.Mockito;",
"public class Test {",
" public int test(Test test) {",
" Mockito.doReturn(1).when(test).test(null);",
" return 1;",
" }",
"}")
.addOutputLines(
"Test.java",
"import static org.mockito.Mockito.when;",
"import org.mockito.Mockito;",
"public class Test {",
" public int test(Test test) {",
" when(test.test(null)).thenReturn(1);",
" return 1;",
" }",
"}")
.doTest();
}

@Test
public void ignoresSpiesCreatedByAnnotation() {
helper
.addInputLines(
"Test.java",
"import org.mockito.Mockito;",
"public class Test {",
" @org.mockito.Spy Test test;",
" public int test() {",
" Mockito.doReturn(1).when(test).test();",
" return 1;",
" }",
"}")
.expectUnchanged()
.doTest();
}

@Test
public void ignoresSpiesCreatedByStaticMethod() {
helper
.addInputLines(
"Test.java",
"import org.mockito.Mockito;",
"public class Test {",
" Test test = Mockito.spy(Test.class);",
" public int test() {",
" Mockito.doReturn(1).when(test).test();",
" return 1;",
" }",
"}")
.expectUnchanged()
.doTest();
}

@Test
public void ignoresMocksConfiguredToThrow_viaThenThrow() {
helper
.addInputLines(
"Test.java",
"import org.mockito.Mockito;",
"public class Test {",
" public int test(Test test) {",
" Mockito.doReturn(1).when(test).test(null);",
" Mockito.when(test.test(null)).thenThrow(new Exception());",
" return 1;",
" }",
"}")
.expectUnchanged()
.doTest();
}

@Test
public void ignoresMocksConfiguredToThrow_viaDoThrow() {
helper
.addInputLines(
"Test.java",
"import org.mockito.Mockito;",
"public class Test {",
" public int test(Test test) {",
" Mockito.doReturn(1).when(test).test(null);",
" Mockito.doThrow(new Exception()).when(test).test(null);",
" return 1;",
" }",
"}")
.expectUnchanged()
.doTest();
}
}
22 changes: 22 additions & 0 deletions docs/bugpattern/MockitoDoSetup.md
@@ -0,0 +1,22 @@
Prefer using the format

```java
when(mock.mockedMethod(...)).thenReturn(returnValue);
```

to initialise mocks, rather than,

```java
doReturn(returnValue).when(mock).mockedMethod(...);
```

Mockito recommends the `when`/`thenReturn` syntax as it is both more readable
and provides type-safety: the return type of the stubbed method is checked
against the stubbed value at compile time.

There are certain situations where `doReturn` is required:

* Overriding previous stubbing where the method will *throw*, as `when` makes
an actual method call.
* Overriding a `spy` where the method call where calling the spied method
brings undesired side-effects.