Skip to content

Commit

Permalink
Improve vararg handling
Browse files Browse the repository at this point in the history
Fixes: mockito#2796

Add an optional method to `VarargMatcher`, which implementations
can choose to override to return the type of object the matcher is matching.

This is used by `MatcherApplicationStrategy` to determine if the type of matcher used to match a vararg parameter is of a type compatible with the vararg parameter.

Where a vararg compatible matcher is found, the matcher is used to match the _raw_ parameters.
  • Loading branch information
big-andy-coates committed Nov 28, 2022
1 parent 8f4af18 commit 48a3014
Show file tree
Hide file tree
Showing 13 changed files with 274 additions and 28 deletions.
3 changes: 2 additions & 1 deletion src/main/java/org/mockito/ArgumentCaptor.java
Expand Up @@ -62,11 +62,12 @@
@CheckReturnValue
public class ArgumentCaptor<T> {

private final CapturingMatcher<T> capturingMatcher = new CapturingMatcher<T>();
private final CapturingMatcher<T> capturingMatcher;
private final Class<? extends T> clazz;

private ArgumentCaptor(Class<? extends T> clazz) {
this.clazz = clazz;
this.capturingMatcher = new CapturingMatcher<T>(clazz);
}

/**
Expand Down
Expand Up @@ -9,6 +9,8 @@
import org.mockito.ArgumentMatcher;
import org.mockito.internal.matchers.VarargMatcher;

import java.util.Optional;

public class HamcrestArgumentMatcher<T> implements ArgumentMatcher<T> {

private final Matcher matcher;
Expand All @@ -26,6 +28,10 @@ public boolean isVarargMatcher() {
return matcher instanceof VarargMatcher;
}

public Optional<VarargMatcher> varargMatcher() {
return isVarargMatcher() ? Optional.of((VarargMatcher) matcher) : Optional.empty();
}

@Override
public String toString() {
// TODO SF add unit tests and integ test coverage for toString()
Expand Down
Expand Up @@ -4,8 +4,10 @@
*/
package org.mockito.internal.invocation;

import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

import org.mockito.ArgumentMatcher;
import org.mockito.internal.hamcrest.HamcrestArgumentMatcher;
Expand Down Expand Up @@ -58,14 +60,25 @@ public static MatcherApplicationStrategy getMatcherApplicationStrategyFor(
* </ul>
*/
public boolean forEachMatcherAndArgument(ArgumentMatcherAction action) {
if (invocation.getArguments().length == matchers.size()) {
return argsMatch(invocation.getArguments(), matchers, action);
}

final boolean isVararg =
invocation.getMethod().isVarArgs()
&& invocation.getRawArguments().length == matchers.size()
&& isLastMatcherVarargMatcher(matchers);
&& getLastMatcherVarargMatcherType(matchers).isPresent();

if (isVararg) {
final Class<?> matcherType = getLastMatcherVarargMatcherType(matchers).get();
final Class<?> paramType =
invocation.getMethod()
.getParameterTypes()[
invocation.getMethod().getParameterTypes().length - 1];
if (paramType.isAssignableFrom(matcherType)) {
return argsMatch(invocation.getRawArguments(), matchers, action);
}
}

if (invocation.getArguments().length == matchers.size()) {
return argsMatch(invocation.getArguments(), matchers, action);
}

if (isVararg) {
int times = varargLength(invocation);
Expand All @@ -91,12 +104,17 @@ private boolean argsMatch(
return true;
}

private static boolean isLastMatcherVarargMatcher(List<? extends ArgumentMatcher<?>> matchers) {
private static Optional<Class<?>> getLastMatcherVarargMatcherType(
final List<? extends ArgumentMatcher<?>> matchers) {
ArgumentMatcher<?> argumentMatcher = lastMatcher(matchers);
if (argumentMatcher instanceof HamcrestArgumentMatcher<?>) {
return ((HamcrestArgumentMatcher<?>) argumentMatcher).isVarargMatcher();
return ((HamcrestArgumentMatcher<?>) argumentMatcher)
.varargMatcher()
.map(VarargMatcher::type);
}
return argumentMatcher instanceof VarargMatcher;
return argumentMatcher instanceof VarargMatcher
? Optional.of((VarargMatcher) argumentMatcher).map(VarargMatcher::type)
: Optional.empty();
}

private List<? extends ArgumentMatcher<?>> appendLastMatcherNTimes(
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/org/mockito/internal/matchers/Any.java
Expand Up @@ -21,4 +21,9 @@ public boolean matches(Object actual) {
public String toString() {
return "<any>";
}

@Override
public Class<?> type() {
return Object.class;
}
}
10 changes: 10 additions & 0 deletions src/main/java/org/mockito/internal/matchers/CapturingMatcher.java
Expand Up @@ -19,12 +19,17 @@
public class CapturingMatcher<T>
implements ArgumentMatcher<T>, CapturesArguments, VarargMatcher, Serializable {

private final Class<? extends T> clazz;
private final List<Object> arguments = new ArrayList<>();

private final ReadWriteLock lock = new ReentrantReadWriteLock();
private final Lock readLock = lock.readLock();
private final Lock writeLock = lock.writeLock();

public CapturingMatcher(final Class<? extends T> clazz) {
this.clazz = clazz;
}

@Override
public boolean matches(Object argument) {
return true;
Expand Down Expand Up @@ -66,4 +71,9 @@ public void captureFrom(Object argument) {
writeLock.unlock();
}
}

@Override
public Class<?> type() {
return clazz;
}
}
7 changes: 6 additions & 1 deletion src/main/java/org/mockito/internal/matchers/InstanceOf.java
Expand Up @@ -11,7 +11,7 @@

public class InstanceOf implements ArgumentMatcher<Object>, Serializable {

private final Class<?> clazz;
final Class<?> clazz;
private final String description;

public InstanceOf(Class<?> clazz) {
Expand Down Expand Up @@ -44,5 +44,10 @@ public VarArgAware(Class<?> clazz) {
public VarArgAware(Class<?> clazz, String describedAs) {
super(clazz, describedAs);
}

@Override
public Class<?> type() {
return clazz;
}
}
}
42 changes: 41 additions & 1 deletion src/main/java/org/mockito/internal/matchers/VarargMatcher.java
Expand Up @@ -10,4 +10,44 @@
* Internal interface that informs Mockito that the matcher is intended to capture varargs.
* This information is needed when mockito collects the arguments.
*/
public interface VarargMatcher extends Serializable {}
public interface VarargMatcher extends Serializable {

/**
* The type of the argument the matcher matches.
*
* <p>If a vararg aware matcher:
* <ul>
* <li>is at the parameter index of a vararg parameter</li>
* <li>is the last matcher passed</li>
* <li>matchers the raw type of the vararg parameter</li>
* </ul>
*
* Then the matcher is matched against the vararg raw parameter.
* Otherwise, the matcher will be matched against each element in the vararg raw parameters.
*
* <p>For example:
*
* <pre class="code"><code class="java">
* // Given vararg method with signature:
* int someVarargMethod(int x, String... args);
*
* // The following will match the last matcher against the contents of the `args` array:
* (as the above criteria are met)
* mock.someVarargMethod(eq(1), any(String[].class));
*
* // The following will match the last matcher against each element of the `args` array:
* // (as the type of the last matcher does not match the raw type of the vararg parameter)
* mock.someVarargMethod(eq(1), any(String.class));
*
* // The following will match only invocations with two strings in the 'args' array:
* // (as there are more matchers than raw arguments)
* mock.someVarargMethod(eq(1), any(), any());
* </code></pre>
*
* @return the type this matcher handles.
* @since 4.10.0
*/
default Class<?> type() {
return Void.class;
}
}
Expand Up @@ -136,7 +136,7 @@ public void should_be_similar_if_is_overloaded_but_used_with_different_arg() thr
public void should_capture_arguments_from_invocation() throws Exception {
// given
Invocation invocation = new InvocationBuilder().args("1", 100).toInvocation();
CapturingMatcher capturingMatcher = new CapturingMatcher();
CapturingMatcher capturingMatcher = new CapturingMatcher(List.class);
InvocationMatcher invocationMatcher =
new InvocationMatcher(invocation, (List) asList(new Equals("1"), capturingMatcher));

Expand Down Expand Up @@ -167,7 +167,7 @@ public void should_capture_varargs_as_vararg() throws Exception {
// given
mock.mixedVarargs(1, "a", "b");
Invocation invocation = getLastInvocation();
CapturingMatcher m = new CapturingMatcher();
CapturingMatcher m = new CapturingMatcher(List.class);
InvocationMatcher invocationMatcher =
new InvocationMatcher(invocation, Arrays.<ArgumentMatcher>asList(new Equals(1), m));

Expand Down
Expand Up @@ -9,6 +9,7 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.internal.invocation.MatcherApplicationStrategy.getMatcherApplicationStrategyFor;
import static org.mockito.internal.matchers.Any.ANY;

Expand Down Expand Up @@ -225,12 +226,32 @@ public void shouldMatchAnyEvenIfMatcherIsWrappedInHamcrestMatcher() {
recordAction.assertContainsExactly(argumentMatcher, argumentMatcher);
}

@Test
public void shouldMatchAnyThatMatchesRawVarArgType() {
// given
invocation = varargs("1", "2");
InstanceOf.VarArgAware any = new InstanceOf.VarArgAware(String[].class, "<any String[]>");
matchers = asList(any);

// when
getMatcherApplicationStrategyFor(invocation, matchers)
.forEachMatcherAndArgument(recordAction);

// then
recordAction.assertContainsExactly(any);
}

private static class IntMatcher extends BaseMatcher<Integer> implements VarargMatcher {
public boolean matches(Object o) {
return true;
}

public void describeTo(Description description) {}

@Override
public Class<?> type() {
return Integer.class;
}
}

private Invocation mixedVarargs(Object a, String... s) {
Expand Down
Expand Up @@ -19,7 +19,7 @@ public class CapturingMatcherTest extends TestBase {
@Test
public void should_capture_arguments() throws Exception {
// given
CapturingMatcher<String> m = new CapturingMatcher<String>();
CapturingMatcher<String> m = new CapturingMatcher<String>(String.class);

// when
m.captureFrom("foo");
Expand All @@ -32,7 +32,7 @@ public void should_capture_arguments() throws Exception {
@Test
public void should_know_last_captured_value() throws Exception {
// given
CapturingMatcher<String> m = new CapturingMatcher<String>();
CapturingMatcher<String> m = new CapturingMatcher<String>(String.class);

// when
m.captureFrom("foo");
Expand All @@ -45,7 +45,7 @@ public void should_know_last_captured_value() throws Exception {
@Test
public void should_scream_when_nothing_yet_captured() throws Exception {
// given
CapturingMatcher<String> m = new CapturingMatcher<String>();
CapturingMatcher<String> m = new CapturingMatcher<String>(String.class);

try {
// when
Expand All @@ -59,7 +59,7 @@ public void should_scream_when_nothing_yet_captured() throws Exception {
@Test
public void should_not_fail_when_used_in_concurrent_tests() throws Exception {
// given
final CapturingMatcher<String> m = new CapturingMatcher<String>();
final CapturingMatcher<String> m = new CapturingMatcher<String>(String.class);

// when
m.captureFrom("concurrent access");
Expand Down
4 changes: 4 additions & 0 deletions src/test/java/org/mockitousage/IMethods.java
Expand Up @@ -199,6 +199,10 @@ String sixArgumentVarArgsMethod(

Object[] mixedVarargsReturningObjectArray(Object i, String... string);

String methodWithVarargAndNonVarargVariants(String string);

String methodWithVarargAndNonVarargVariants(String... string);

List<String> listReturningMethod(Object... objects);

LinkedList<String> linkedListReturningMethod();
Expand Down
10 changes: 10 additions & 0 deletions src/test/java/org/mockitousage/MethodsImpl.java
Expand Up @@ -376,6 +376,16 @@ public Object[] mixedVarargsReturningObjectArray(Object i, String... string) {
return null;
}

@Override
public String methodWithVarargAndNonVarargVariants(String string) {
return "plain";
}

@Override
public String methodWithVarargAndNonVarargVariants(String... string) {
return "varargs";
}

public void varargsbyte(byte... bytes) {}

public List<String> listReturningMethod(Object... objects) {
Expand Down

0 comments on commit 48a3014

Please sign in to comment.