Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve vararg handling #2807

Merged
merged 16 commits into from Dec 22, 2022
3 changes: 2 additions & 1 deletion src/main/java/org/mockito/ArgumentCaptor.java
Expand Up @@ -62,11 +62,12 @@
@CheckReturnValue
public class ArgumentCaptor<T> {

private final CapturingMatcher<T> capturingMatcher = new CapturingMatcher<T>();
private final CapturingMatcher<T> capturingMatcher;
private final Class<? extends T> clazz;

private ArgumentCaptor(Class<? extends T> clazz) {
this.clazz = clazz;
this.capturingMatcher = new CapturingMatcher<T>(clazz);
}

/**
Expand Down
44 changes: 44 additions & 0 deletions src/main/java/org/mockito/ArgumentMatcher.java
Expand Up @@ -125,4 +125,48 @@ public interface ArgumentMatcher<T> {
* @return true if this matcher accepts the given argument.
*/
boolean matches(T argument);

/**
* The type of the argument the matcher matches.
*
* <p>Only defaulted to maintain backwards compatability.
* Implementations should provide their own implementation for this method.
*
* <p>Initially, this method is only being used to determine if a matcher should be used to match
* a raw vararg parameter or not. This may change in future releases.
*
* <p>Where a matcher:
* <ul>
* <li>is at the parameter index of a vararg parameter</li>
* <li>is the last matcher passed</li>
* <li>matchers the raw type of the vararg parameter</li>
* </ul>
*
* Then the matcher is matched against the vararg raw parameter.
*
* <p>For example:
*
* <pre class="code"><code class="java">
* // Given vararg method with signature:
* int someVarargMethod(int x, String... args);
*
* // The following will match the last matcher against the contents of the `args` array:
* (as the above criteria are met)
* mock.someVarargMethod(eq(1), any(String[].class));
*
* // The following will match the last matcher against each element of the `args` array:
* // (as the type of the last matcher does not match the raw type of the vararg parameter)
* mock.someVarargMethod(eq(1), any(String.class));
*
* // The following will match only invocations with two strings in the 'args' array:
* // (as there are more matchers than raw arguments)
* mock.someVarargMethod(eq(1), any(), any());
* </code></pre>
*
* @return the type this matcher handles.
* @since 4.10.0
*/
default Class<?> type() {
TimvdLippe marked this conversation as resolved.
Show resolved Hide resolved
return Void.class;
}
TimvdLippe marked this conversation as resolved.
Show resolved Hide resolved
}
54 changes: 54 additions & 0 deletions src/main/java/org/mockito/ArgumentMatchers.java
Expand Up @@ -699,6 +699,22 @@ public static <T> T isNull() {
return null;
}

/**
* <code>null</code> argument.
*
* <p>
* See examples in javadoc for {@link ArgumentMatchers} class
* </p>
*
* @param type the type of the argument being matched.
* @return <code>null</code>.
* @see #isNotNull(Class)
*/
public static <T> T isNull(Class<T> type) {
reportMatcher(new Null(type));
return null;
}

/**
* Not <code>null</code> argument.
*
Expand All @@ -717,6 +733,25 @@ public static <T> T notNull() {
return null;
}

/**
* Not <code>null</code> argument.
*
* <p>
* Alias to {@link ArgumentMatchers#isNotNull()}
* </p>
*
* <p>
* See examples in javadoc for {@link ArgumentMatchers} class
* </p>
*
* @param type the type of the argument being matched.
* @return <code>null</code>.
*/
public static <T> T notNull(Class<T> type) {
reportMatcher(new NotNull(type));
return null;
}

/**
* Not <code>null</code> argument.
*
Expand All @@ -735,6 +770,25 @@ public static <T> T isNotNull() {
return notNull();
}

/**
* Not <code>null</code> argument.
*
* <p>
* Alias to {@link ArgumentMatchers#notNull(Class)}
* </p>
*
* <p>
* See examples in javadoc for {@link ArgumentMatchers} class
* </p>
*
* @param type the type of the argument being matched.
* @return <code>null</code>.
* @see #isNull()
*/
public static <T> T isNotNull(Class<T> type) {
return notNull(type);
}

/**
* Argument that is either <code>null</code> or of the given type.
*
Expand Down
Expand Up @@ -22,6 +22,7 @@ public boolean matches(Object argument) {
return this.matcher.matches(argument);
}

@Deprecated
public boolean isVarargMatcher() {
return matcher instanceof VarargMatcher;
}
Expand Down
Expand Up @@ -58,17 +58,24 @@ public static MatcherApplicationStrategy getMatcherApplicationStrategyFor(
* </ul>
*/
public boolean forEachMatcherAndArgument(ArgumentMatcherAction action) {
if (invocation.getArguments().length == matchers.size()) {
return argsMatch(invocation.getArguments(), matchers, action);
}

final boolean isVararg =
invocation.getMethod().isVarArgs()
&& invocation.getRawArguments().length == matchers.size()
&& isLastMatcherVarargMatcher(matchers);
&& invocation.getRawArguments().length == matchers.size();

if (isVararg) {
int times = varargLength(invocation);
final Class<?> matcherType = lastMatcher().type();
final Class<?> paramType = lastParameterType();
if (paramType.isAssignableFrom(matcherType)) {
return argsMatch(invocation.getRawArguments(), matchers, action);
}
}
TimvdLippe marked this conversation as resolved.
Show resolved Hide resolved

if (invocation.getArguments().length == matchers.size()) {
return argsMatch(invocation.getArguments(), matchers, action);
}

if (isVararg && isLastMatcherVarargMatcher()) {
int times = varargLength();
final List<? extends ArgumentMatcher<?>> matchers = appendLastMatcherNTimes(times);
return argsMatch(invocation.getArguments(), matchers, action);
}
Expand All @@ -91,8 +98,8 @@ private boolean argsMatch(
return true;
}

private static boolean isLastMatcherVarargMatcher(List<? extends ArgumentMatcher<?>> matchers) {
ArgumentMatcher<?> argumentMatcher = lastMatcher(matchers);
private boolean isLastMatcherVarargMatcher() {
ArgumentMatcher<?> argumentMatcher = lastMatcher();
if (argumentMatcher instanceof HamcrestArgumentMatcher<?>) {
return ((HamcrestArgumentMatcher<?>) argumentMatcher).isVarargMatcher();
}
Expand All @@ -101,7 +108,7 @@ private static boolean isLastMatcherVarargMatcher(List<? extends ArgumentMatcher

private List<? extends ArgumentMatcher<?>> appendLastMatcherNTimes(
int timesToAppendLastMatcher) {
ArgumentMatcher<?> lastMatcher = lastMatcher(matchers);
ArgumentMatcher<?> lastMatcher = lastMatcher();

List<ArgumentMatcher<?>> expandedMatchers = new ArrayList<ArgumentMatcher<?>>(matchers);
for (int i = 0; i < timesToAppendLastMatcher; i++) {
Expand All @@ -110,13 +117,18 @@ private List<? extends ArgumentMatcher<?>> appendLastMatcherNTimes(
return expandedMatchers;
}

private static int varargLength(Invocation invocation) {
private int varargLength() {
int rawArgumentCount = invocation.getRawArguments().length;
int expandedArgumentCount = invocation.getArguments().length;
return expandedArgumentCount - rawArgumentCount;
}

private static ArgumentMatcher<?> lastMatcher(List<? extends ArgumentMatcher<?>> matchers) {
private ArgumentMatcher<?> lastMatcher() {
return matchers.get(matchers.size() - 1);
}

private Class<?> lastParameterType() {
final Class<?>[] parameterTypes = invocation.getMethod().getParameterTypes();
return parameterTypes[parameterTypes.length - 1];
}
}
5 changes: 5 additions & 0 deletions src/main/java/org/mockito/internal/matchers/Any.java
Expand Up @@ -21,4 +21,9 @@ public boolean matches(Object actual) {
public String toString() {
return "<any>";
}

@Override
public Class<?> type() {
return Object.class;
}
}
10 changes: 10 additions & 0 deletions src/main/java/org/mockito/internal/matchers/CapturingMatcher.java
Expand Up @@ -19,12 +19,17 @@
public class CapturingMatcher<T>
implements ArgumentMatcher<T>, CapturesArguments, VarargMatcher, Serializable {

private final Class<? extends T> clazz;
private final List<Object> arguments = new ArrayList<>();

private final ReadWriteLock lock = new ReentrantReadWriteLock();
private final Lock readLock = lock.readLock();
private final Lock writeLock = lock.writeLock();

public CapturingMatcher(final Class<? extends T> clazz) {
this.clazz = clazz;
}

@Override
public boolean matches(Object argument) {
return true;
Expand Down Expand Up @@ -66,4 +71,9 @@ public void captureFrom(Object argument) {
writeLock.unlock();
}
}

@Override
public Class<?> type() {
return clazz;
}
}
12 changes: 11 additions & 1 deletion src/main/java/org/mockito/internal/matchers/InstanceOf.java
Expand Up @@ -11,7 +11,7 @@

public class InstanceOf implements ArgumentMatcher<Object>, Serializable {

private final Class<?> clazz;
final Class<?> clazz;
private final String description;

public InstanceOf(Class<?> clazz) {
Expand All @@ -30,6 +30,11 @@ public boolean matches(Object actual) {
|| clazz.isAssignableFrom(actual.getClass()));
}

@Override
public Class<?> type() {
return clazz;
}
TimvdLippe marked this conversation as resolved.
Show resolved Hide resolved

@Override
public String toString() {
return description;
Expand All @@ -44,5 +49,10 @@ public VarArgAware(Class<?> clazz) {
public VarArgAware(Class<?> clazz, String describedAs) {
super(clazz, describedAs);
}

@Override
public Class<?> type() {
return clazz;
}
}
}
15 changes: 12 additions & 3 deletions src/main/java/org/mockito/internal/matchers/NotNull.java
Expand Up @@ -8,17 +8,26 @@

import org.mockito.ArgumentMatcher;

public class NotNull implements ArgumentMatcher<Object>, Serializable {
public class NotNull<T> implements ArgumentMatcher<T>, Serializable {

public static final NotNull NOT_NULL = new NotNull();
public static final NotNull<Object> NOT_NULL = new NotNull<>(Object.class);

private NotNull() {}
private final Class<T> type;

public NotNull(Class<T> type) {
this.type = type;
}

@Override
public boolean matches(Object actual) {
return actual != null;
}

@Override
public Class<T> type() {
return type;
}
TimvdLippe marked this conversation as resolved.
Show resolved Hide resolved

@Override
public String toString() {
return "notNull()";
Expand Down
14 changes: 11 additions & 3 deletions src/main/java/org/mockito/internal/matchers/Null.java
Expand Up @@ -8,17 +8,25 @@

import org.mockito.ArgumentMatcher;

public class Null implements ArgumentMatcher<Object>, Serializable {
public class Null<T> implements ArgumentMatcher<T>, Serializable {

public static final Null NULL = new Null();
public static final Null<Object> NULL = new Null<>(Object.class);
private final Class<T> type;

private Null() {}
public Null(Class<T> type) {
this.type = type;
}

@Override
public boolean matches(Object actual) {
return actual == null;
}

@Override
public Class<T> type() {
return type;
}
TimvdLippe marked this conversation as resolved.
Show resolved Hide resolved

@Override
public String toString() {
return "isNull()";
Expand Down
Expand Up @@ -9,5 +9,8 @@
/**
* Internal interface that informs Mockito that the matcher is intended to capture varargs.
* This information is needed when mockito collects the arguments.
*
* @deprecated this interface is maintained for backwards compatability with user code only.
*/
@Deprecated
public interface VarargMatcher extends Serializable {}
TimvdLippe marked this conversation as resolved.
Show resolved Hide resolved
Expand Up @@ -136,7 +136,7 @@ public void should_be_similar_if_is_overloaded_but_used_with_different_arg() thr
public void should_capture_arguments_from_invocation() throws Exception {
// given
Invocation invocation = new InvocationBuilder().args("1", 100).toInvocation();
CapturingMatcher capturingMatcher = new CapturingMatcher();
CapturingMatcher capturingMatcher = new CapturingMatcher(List.class);
InvocationMatcher invocationMatcher =
new InvocationMatcher(invocation, (List) asList(new Equals("1"), capturingMatcher));

Expand Down Expand Up @@ -167,7 +167,7 @@ public void should_capture_varargs_as_vararg() throws Exception {
// given
mock.mixedVarargs(1, "a", "b");
Invocation invocation = getLastInvocation();
CapturingMatcher m = new CapturingMatcher();
CapturingMatcher m = new CapturingMatcher(List.class);
InvocationMatcher invocationMatcher =
new InvocationMatcher(invocation, Arrays.<ArgumentMatcher>asList(new Equals(1), m));

Expand Down