diff --git a/core/src/main/java/com/google/errorprone/bugpatterns/MockitoDoSetup.java b/core/src/main/java/com/google/errorprone/bugpatterns/MockitoDoSetup.java new file mode 100644 index 00000000000..b2a67cae0c7 --- /dev/null +++ b/core/src/main/java/com/google/errorprone/bugpatterns/MockitoDoSetup.java @@ -0,0 +1,198 @@ +/* + * Copyright 2024 The Error Prone Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.errorprone.bugpatterns; + +import static com.google.common.collect.Iterables.getLast; +import static com.google.errorprone.BugPattern.SeverityLevel.WARNING; +import static com.google.errorprone.matchers.Description.NO_MATCH; +import static com.google.errorprone.matchers.Matchers.instanceMethod; +import static com.google.errorprone.matchers.Matchers.staticMethod; +import static com.google.errorprone.util.ASTHelpers.getReceiver; +import static com.google.errorprone.util.ASTHelpers.getStartPosition; +import static com.google.errorprone.util.ASTHelpers.getSymbol; +import static com.google.errorprone.util.ASTHelpers.hasAnnotation; +import static com.google.errorprone.util.ASTHelpers.isSameType; +import static java.lang.String.format; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.errorprone.BugPattern; +import com.google.errorprone.VisitorState; +import com.google.errorprone.bugpatterns.BugChecker.CompilationUnitTreeMatcher; +import com.google.errorprone.fixes.SuggestedFix; +import com.google.errorprone.fixes.SuggestedFixes; +import com.google.errorprone.matchers.Description; +import com.google.errorprone.matchers.Matcher; +import com.sun.source.tree.AssignmentTree; +import com.sun.source.tree.CompilationUnitTree; +import com.sun.source.tree.ExpressionTree; +import com.sun.source.tree.MethodInvocationTree; +import com.sun.source.tree.Tree; +import com.sun.source.tree.VariableTree; +import com.sun.source.util.TreePath; +import com.sun.source.util.TreePathScanner; +import com.sun.tools.javac.code.Symbol.VarSymbol; + +/** A BugPattern; see the summary. */ +@BugPattern( + severity = WARNING, + summary = "Prefer using when/thenReturn over doReturn/when for additional type safety.") +public final class MockitoDoSetup extends BugChecker implements CompilationUnitTreeMatcher { + @Override + public Description matchCompilationUnit(CompilationUnitTree tree, VisitorState state) { + ImmutableSet spies = findSpies(state); + new SuppressibleTreePathScanner(state) { + + @Override + public Void visitMethodInvocation(MethodInvocationTree tree, Void unused) { + handle(tree); + return super.visitMethodInvocation(tree, null); + } + + private void handle(MethodInvocationTree tree) { + if (!DO_STUBBER.matches(tree, state)) { + return; + } + TreePath whenPath = getCurrentPath().getParentPath().getParentPath(); + Tree whenCall = whenPath.getLeaf(); + if (!(whenCall instanceof MethodInvocationTree) + || !INSTANCE_WHEN.matches((MethodInvocationTree) whenCall, state)) { + return; + } + if (isSpy(((MethodInvocationTree) whenCall).getArguments().get(0))) { + return; + } + Tree mockedMethod = whenPath.getParentPath().getParentPath().getLeaf(); + + if (!(mockedMethod instanceof MethodInvocationTree)) { + return; + } + if (isSameType( + getSymbol((MethodInvocationTree) mockedMethod).getReturnType(), + state.getSymtab().voidType, + state)) { + return; + } + + SuggestedFix.Builder fix = SuggestedFix.builder(); + var when = SuggestedFixes.qualifyStaticImport("org.mockito.Mockito.when", fix, state); + fix.replace(((MethodInvocationTree) whenCall).getMethodSelect(), when) + .replace(state.getEndPosition(whenCall) - 1, state.getEndPosition(whenCall), "") + .postfixWith( + mockedMethod, + format( + ").%s(%s)", + NAME_MAPPINGS.get(getSymbol(tree).getSimpleName().toString()), + getParameterSource(tree, state))); + + state.reportMatch(describeMatch(tree, fix.build())); + } + + private boolean isSpy(ExpressionTree tree) { + var symbol = getSymbol(tree); + return symbol != null + && (spies.contains(symbol) || hasAnnotation(symbol, "org.mockito.Spy", state)); + } + }.scan(state.getPath(), null); + return NO_MATCH; + } + + private static String getParameterSource(MethodInvocationTree tree, VisitorState state) { + return state + .getSourceCode() + .subSequence( + getStartPosition(tree.getArguments().get(0)), + state.getEndPosition(getLast(tree.getArguments()))) + .toString(); + } + + private static ImmutableSet findSpies(VisitorState state) { + // NOTES: This is extremely conservative in at least two ways. + // 1) We ignore an entire mock if _any_ method is mocked to throw, not just the relevant method. + // 2) We could still refactor if the thenThrow comes _after_, or if the _only_ call is + // thenThrow. + ImmutableSet.Builder spiesOrThrows = ImmutableSet.builder(); + new TreePathScanner() { + @Override + public Void visitVariable(VariableTree tree, Void unused) { + if (tree.getInitializer() != null && SPY.matches(tree.getInitializer(), state)) { + spiesOrThrows.add(getSymbol(tree)); + } + return super.visitVariable(tree, null); + } + + @Override + public Void visitMethodInvocation(MethodInvocationTree tree, Void unused) { + if (DO_THROW.matches(tree, state)) { + var whenCall = getCurrentPath().getParentPath().getParentPath().getLeaf(); + if ((whenCall instanceof MethodInvocationTree) + && INSTANCE_WHEN.matches((MethodInvocationTree) whenCall, state)) { + var whenTarget = getSymbol(((MethodInvocationTree) whenCall).getArguments().get(0)); + if (whenTarget instanceof VarSymbol) { + spiesOrThrows.add((VarSymbol) whenTarget); + } + } + } + if (THEN_THROW.matches(tree, state)) { + var receiver = getReceiver(tree); + if (STATIC_WHEN.matches(receiver, state)) { + var mock = getReceiver(((MethodInvocationTree) receiver).getArguments().get(0)); + var mockSymbol = getSymbol(mock); + if (mockSymbol instanceof VarSymbol) { + spiesOrThrows.add((VarSymbol) mockSymbol); + } + } + } + return super.visitMethodInvocation(tree, null); + } + + @Override + public Void visitAssignment(AssignmentTree tree, Void unused) { + if (SPY.matches(tree.getExpression(), state)) { + var symbol = getSymbol(tree.getVariable()); + if (symbol instanceof VarSymbol) { + spiesOrThrows.add((VarSymbol) symbol); + } + } + return super.visitAssignment(tree, null); + } + }.scan(state.getPath().getCompilationUnit(), null); + return spiesOrThrows.build(); + } + + private static final ImmutableMap NAME_MAPPINGS = + ImmutableMap.of( + "doAnswer", "thenAnswer", + "doReturn", "thenReturn", + "doThrow", "thenThrow"); + private static final Matcher DO_STUBBER = + staticMethod().onClass("org.mockito.Mockito").namedAnyOf(NAME_MAPPINGS.keySet()); + + private static final Matcher INSTANCE_WHEN = + instanceMethod().onDescendantOf("org.mockito.stubbing.Stubber").named("when"); + + private static final Matcher SPY = + staticMethod().onClass("org.mockito.Mockito").named("spy"); + + private static final Matcher DO_THROW = + staticMethod().onClass("org.mockito.Mockito").named("doThrow"); + + private static final Matcher STATIC_WHEN = + staticMethod().onClass("org.mockito.Mockito").named("when"); + + private static final Matcher THEN_THROW = + instanceMethod().onDescendantOf("org.mockito.stubbing.OngoingStubbing").named("thenThrow"); +} diff --git a/core/src/main/java/com/google/errorprone/scanner/BuiltInCheckerSuppliers.java b/core/src/main/java/com/google/errorprone/scanner/BuiltInCheckerSuppliers.java index 2506e23075c..f90f05f4a83 100644 --- a/core/src/main/java/com/google/errorprone/scanner/BuiltInCheckerSuppliers.java +++ b/core/src/main/java/com/google/errorprone/scanner/BuiltInCheckerSuppliers.java @@ -244,6 +244,7 @@ import com.google.errorprone.bugpatterns.MixedDescriptors; import com.google.errorprone.bugpatterns.MixedMutabilityReturnType; import com.google.errorprone.bugpatterns.MockNotUsedInProduction; +import com.google.errorprone.bugpatterns.MockitoDoSetup; import com.google.errorprone.bugpatterns.MockitoUsage; import com.google.errorprone.bugpatterns.ModifiedButNotUsed; import com.google.errorprone.bugpatterns.ModifyCollectionInEnhancedForLoop; @@ -1169,6 +1170,7 @@ public static ScannerSupplier warningChecks() { MissingBraces.class, MissingDefault.class, MixedArrayDimensions.class, + MockitoDoSetup.class, MoreThanOneQualifier.class, MultiVariableDeclaration.class, MultipleTopLevelClasses.class, diff --git a/core/src/test/java/com/google/errorprone/bugpatterns/MockitoDoSetupTest.java b/core/src/test/java/com/google/errorprone/bugpatterns/MockitoDoSetupTest.java new file mode 100644 index 00000000000..12cb2cbf03d --- /dev/null +++ b/core/src/test/java/com/google/errorprone/bugpatterns/MockitoDoSetupTest.java @@ -0,0 +1,120 @@ +/* + * Copyright 2024 The Error Prone Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.errorprone.bugpatterns; + +import com.google.errorprone.BugCheckerRefactoringTestHelper; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class MockitoDoSetupTest { + private final BugCheckerRefactoringTestHelper helper = + BugCheckerRefactoringTestHelper.newInstance(MockitoDoSetup.class, getClass()); + + @Test + public void happy() { + helper + .addInputLines( + "Test.java", + "import org.mockito.Mockito;", + "public class Test {", + " public int test(Test test) {", + " Mockito.doReturn(1).when(test).test(null);", + " return 1;", + " }", + "}") + .addOutputLines( + "Test.java", + "import static org.mockito.Mockito.when;", + "import org.mockito.Mockito;", + "public class Test {", + " public int test(Test test) {", + " when(test.test(null)).thenReturn(1);", + " return 1;", + " }", + "}") + .doTest(); + } + + @Test + public void ignoresSpiesCreatedByAnnotation() { + helper + .addInputLines( + "Test.java", + "import org.mockito.Mockito;", + "public class Test {", + " @org.mockito.Spy Test test;", + " public int test() {", + " Mockito.doReturn(1).when(test).test();", + " return 1;", + " }", + "}") + .expectUnchanged() + .doTest(); + } + + @Test + public void ignoresSpiesCreatedByStaticMethod() { + helper + .addInputLines( + "Test.java", + "import org.mockito.Mockito;", + "public class Test {", + " Test test = Mockito.spy(Test.class);", + " public int test() {", + " Mockito.doReturn(1).when(test).test();", + " return 1;", + " }", + "}") + .expectUnchanged() + .doTest(); + } + + @Test + public void ignoresMocksConfiguredToThrow_viaThenThrow() { + helper + .addInputLines( + "Test.java", + "import org.mockito.Mockito;", + "public class Test {", + " public int test(Test test) {", + " Mockito.doReturn(1).when(test).test(null);", + " Mockito.when(test.test(null)).thenThrow(new Exception());", + " return 1;", + " }", + "}") + .expectUnchanged() + .doTest(); + } + + @Test + public void ignoresMocksConfiguredToThrow_viaDoThrow() { + helper + .addInputLines( + "Test.java", + "import org.mockito.Mockito;", + "public class Test {", + " public int test(Test test) {", + " Mockito.doReturn(1).when(test).test(null);", + " Mockito.doThrow(new Exception()).when(test).test(null);", + " return 1;", + " }", + "}") + .expectUnchanged() + .doTest(); + } +} diff --git a/docs/bugpattern/MockitoDoSetup.md b/docs/bugpattern/MockitoDoSetup.md new file mode 100644 index 00000000000..6e906e3fdeb --- /dev/null +++ b/docs/bugpattern/MockitoDoSetup.md @@ -0,0 +1,22 @@ +Prefer using the format + +```java + when(mock.mockedMethod(...)).thenReturn(returnValue); +``` + +to initialise mocks, rather than, + +```java + doReturn(returnValue).when(mock).mockedMethod(...); +``` + +Mockito recommends the `when`/`thenReturn` syntax as it is both more readable +and provides type-safety: the return type of the stubbed method is checked +against the stubbed value at compile time. + +There are certain situations where `doReturn` is required: + +* Overriding previous stubbing where the method will *throw*, as `when` makes + an actual method call. +* Overriding a `spy` where the method call where calling the spied method + brings undesired side-effects.