Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve vararg handling #2807

Merged
merged 16 commits into from Dec 22, 2022
3 changes: 2 additions & 1 deletion src/main/java/org/mockito/ArgumentCaptor.java
Expand Up @@ -62,11 +62,12 @@
@CheckReturnValue
public class ArgumentCaptor<T> {

private final CapturingMatcher<T> capturingMatcher = new CapturingMatcher<T>();
private final CapturingMatcher<T> capturingMatcher;
private final Class<? extends T> clazz;

private ArgumentCaptor(Class<? extends T> clazz) {
this.clazz = clazz;
this.capturingMatcher = new CapturingMatcher<T>(clazz);
}

/**
Expand Down
40 changes: 40 additions & 0 deletions src/main/java/org/mockito/ArgumentMatcher.java
Expand Up @@ -125,4 +125,44 @@ public interface ArgumentMatcher<T> {
* @return true if this matcher accepts the given argument.
*/
boolean matches(T argument);

/**
* The type of the argument this matcher matches.
*
* <p>This method is used to differentiate between a matcher used to match a raw vararg array parameter
* from a matcher used to match a single value passed as a vararg parameter.
*
* <p>Where the matcher:
* <ul>
* <li>is at the parameter index of a vararg parameter</li>
* <li>is the last matcher passed</li>
* <li>this method returns a type assignable to the vararg parameter's raw type, i.e. its array type.</li>
* </ul>
*
* ...then the matcher is matched against the raw vararg parameter, rather than the first element of the raw parameter.
*
* <p>For example:
*
* <pre class="code"><code class="java">
* // Given vararg method with signature:
* int someVarargMethod(String... args);
*
* // The following will match invocations with any number of parameters, i.e. any number of elements in the raw array.
* mock.someVarargMethod(isA(String[].class));
*
* // The following will match invocations with a single parameter, i.e. one string in the raw array.
* mock.someVarargMethod(isA(String.class));
*
* // The following will match invocations with two parameters, i.e. two strings in the raw array
* mock.someVarargMethod(isA(String.class), isA(String.class));
* </code></pre>
*
* <p>Only matcher implementations that can conceptually match a raw vararg parameter should override this method.
*
* @return the type this matcher handles. The default value of {@link Void} means the type is not known.
* @since 4.10.0
*/
default Class<?> type() {
TimvdLippe marked this conversation as resolved.
Show resolved Hide resolved
return Void.class;
}
}
57 changes: 57 additions & 0 deletions src/main/java/org/mockito/ArgumentMatchers.java
Expand Up @@ -699,6 +699,23 @@ public static <T> T isNull() {
return null;
}

/**
* <code>null</code> argument.
*
* <p>
* See examples in javadoc for {@link ArgumentMatchers} class
* </p>
*
* @param type the type of the argument being matched.
* @return <code>null</code>.
* @see #isNotNull(Class)
* @since 4.10.0
*/
public static <T> T isNull(Class<T> type) {
reportMatcher(new Null<>(type));
return null;
}

/**
* Not <code>null</code> argument.
*
Expand All @@ -717,6 +734,26 @@ public static <T> T notNull() {
return null;
}

/**
* Not <code>null</code> argument.
*
* <p>
* Alias to {@link ArgumentMatchers#isNotNull()}
* </p>
*
* <p>
* See examples in javadoc for {@link ArgumentMatchers} class
* </p>
*
* @param type the type of the argument being matched.
* @return <code>null</code>.
* @since 4.10.0
*/
public static <T> T notNull(Class<T> type) {
reportMatcher(new NotNull<>(type));
return null;
}

/**
* Not <code>null</code> argument.
*
Expand All @@ -735,6 +772,26 @@ public static <T> T isNotNull() {
return notNull();
}

/**
* Not <code>null</code> argument.
*
* <p>
* Alias to {@link ArgumentMatchers#notNull(Class)}
* </p>
*
* <p>
* See examples in javadoc for {@link ArgumentMatchers} class
* </p>
*
* @param type the type of the argument being matched.
* @return <code>null</code>.
* @see #isNull()
* @since 4.10.0
*/
public static <T> T isNotNull(Class<T> type) {
return notNull(type);
}

/**
* Argument that is either <code>null</code> or of the given type.
*
Expand Down
Expand Up @@ -22,6 +22,7 @@ public boolean matches(Object argument) {
return this.matcher.matches(argument);
}

@SuppressWarnings("deprecation")
public boolean isVarargMatcher() {
return matcher instanceof VarargMatcher;
}
Expand Down
Expand Up @@ -58,17 +58,24 @@ public static MatcherApplicationStrategy getMatcherApplicationStrategyFor(
* </ul>
*/
public boolean forEachMatcherAndArgument(ArgumentMatcherAction action) {
final boolean maybeVararg =
invocation.getMethod().isVarArgs()
&& invocation.getRawArguments().length == matchers.size();

if (maybeVararg) {
final Class<?> matcherType = lastMatcher().type();
final Class<?> paramType = lastParameterType();
if (paramType.isAssignableFrom(matcherType)) {
return argsMatch(invocation.getRawArguments(), matchers, action);
}
}
TimvdLippe marked this conversation as resolved.
Show resolved Hide resolved

if (invocation.getArguments().length == matchers.size()) {
return argsMatch(invocation.getArguments(), matchers, action);
}

final boolean isVararg =
invocation.getMethod().isVarArgs()
&& invocation.getRawArguments().length == matchers.size()
&& isLastMatcherVarargMatcher(matchers);

if (isVararg) {
int times = varargLength(invocation);
if (maybeVararg && isLastMatcherVarargMatcher()) {
int times = varargLength();
final List<? extends ArgumentMatcher<?>> matchers = appendLastMatcherNTimes(times);
return argsMatch(invocation.getArguments(), matchers, action);
}
Expand All @@ -91,8 +98,8 @@ private boolean argsMatch(
return true;
}

private static boolean isLastMatcherVarargMatcher(List<? extends ArgumentMatcher<?>> matchers) {
ArgumentMatcher<?> argumentMatcher = lastMatcher(matchers);
private boolean isLastMatcherVarargMatcher() {
ArgumentMatcher<?> argumentMatcher = lastMatcher();
if (argumentMatcher instanceof HamcrestArgumentMatcher<?>) {
return ((HamcrestArgumentMatcher<?>) argumentMatcher).isVarargMatcher();
}
Expand All @@ -101,7 +108,7 @@ private static boolean isLastMatcherVarargMatcher(List<? extends ArgumentMatcher

private List<? extends ArgumentMatcher<?>> appendLastMatcherNTimes(
int timesToAppendLastMatcher) {
ArgumentMatcher<?> lastMatcher = lastMatcher(matchers);
ArgumentMatcher<?> lastMatcher = lastMatcher();

List<ArgumentMatcher<?>> expandedMatchers = new ArrayList<ArgumentMatcher<?>>(matchers);
for (int i = 0; i < timesToAppendLastMatcher; i++) {
Expand All @@ -110,13 +117,18 @@ private List<? extends ArgumentMatcher<?>> appendLastMatcherNTimes(
return expandedMatchers;
}

private static int varargLength(Invocation invocation) {
private int varargLength() {
int rawArgumentCount = invocation.getRawArguments().length;
int expandedArgumentCount = invocation.getArguments().length;
return expandedArgumentCount - rawArgumentCount;
}

private static ArgumentMatcher<?> lastMatcher(List<? extends ArgumentMatcher<?>> matchers) {
private ArgumentMatcher<?> lastMatcher() {
return matchers.get(matchers.size() - 1);
}

private Class<?> lastParameterType() {
final Class<?>[] parameterTypes = invocation.getMethod().getParameterTypes();
return parameterTypes[parameterTypes.length - 1];
}
}
7 changes: 7 additions & 0 deletions src/main/java/org/mockito/internal/matchers/And.java
Expand Up @@ -23,6 +23,13 @@ public boolean matches(Object actual) {
return m1.matches(actual) && m2.matches(actual);
}

@Override
public Class<?> type() {
return m1.type().isAssignableFrom(m2.type())
? m1.type()
: m2.type().isAssignableFrom(m1.type()) ? m2.type() : ArgumentMatcher.super.type();
}

@Override
public String toString() {
return "and(" + m1 + ", " + m2 + ")";
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/org/mockito/internal/matchers/Any.java
Expand Up @@ -21,4 +21,9 @@ public boolean matches(Object actual) {
public String toString() {
return "<any>";
}

@Override
public Class<?> type() {
return Object.class;
}
}
11 changes: 11 additions & 0 deletions src/main/java/org/mockito/internal/matchers/CapturingMatcher.java
Expand Up @@ -9,6 +9,7 @@
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
Expand All @@ -19,12 +20,17 @@
public class CapturingMatcher<T>
implements ArgumentMatcher<T>, CapturesArguments, VarargMatcher, Serializable {

private final Class<? extends T> clazz;
private final List<Object> arguments = new ArrayList<>();

private final ReadWriteLock lock = new ReentrantReadWriteLock();
private final Lock readLock = lock.readLock();
private final Lock writeLock = lock.writeLock();

public CapturingMatcher(final Class<? extends T> clazz) {
this.clazz = Objects.requireNonNull(clazz);
}

@Override
public boolean matches(Object argument) {
return true;
Expand Down Expand Up @@ -66,4 +72,9 @@ public void captureFrom(Object argument) {
writeLock.unlock();
}
}

@Override
public Class<?> type() {
return clazz;
}
}
5 changes: 5 additions & 0 deletions src/main/java/org/mockito/internal/matchers/Equals.java
Expand Up @@ -22,6 +22,11 @@ public boolean matches(Object actual) {
return Equality.areEqual(this.wanted, actual);
}

@Override
public Class<?> type() {
return wanted != null ? wanted.getClass() : ArgumentMatcher.super.type();
}

@Override
public String toString() {
return describe(wanted);
Expand Down
12 changes: 11 additions & 1 deletion src/main/java/org/mockito/internal/matchers/InstanceOf.java
Expand Up @@ -11,7 +11,7 @@

public class InstanceOf implements ArgumentMatcher<Object>, Serializable {

private final Class<?> clazz;
final Class<?> clazz;
private final String description;

public InstanceOf(Class<?> clazz) {
Expand All @@ -30,6 +30,11 @@ public boolean matches(Object actual) {
|| clazz.isAssignableFrom(actual.getClass()));
}

@Override
public Class<?> type() {
return clazz;
}
TimvdLippe marked this conversation as resolved.
Show resolved Hide resolved

@Override
public String toString() {
return description;
Expand All @@ -44,5 +49,10 @@ public VarArgAware(Class<?> clazz) {
public VarArgAware(Class<?> clazz, String describedAs) {
super(clazz, describedAs);
}

@Override
public Class<?> type() {
return clazz;
}
}
}
5 changes: 5 additions & 0 deletions src/main/java/org/mockito/internal/matchers/Not.java
Expand Up @@ -22,6 +22,11 @@ public boolean matches(Object actual) {
return !matcher.matches(actual);
}

@Override
public Class<?> type() {
return matcher.type();
}

@Override
public String toString() {
return "not(" + matcher + ")";
Expand Down
16 changes: 13 additions & 3 deletions src/main/java/org/mockito/internal/matchers/NotNull.java
Expand Up @@ -5,20 +5,30 @@
package org.mockito.internal.matchers;

import java.io.Serializable;
import java.util.Objects;

import org.mockito.ArgumentMatcher;

public class NotNull implements ArgumentMatcher<Object>, Serializable {
public class NotNull<T> implements ArgumentMatcher<T>, Serializable {

public static final NotNull NOT_NULL = new NotNull();
public static final NotNull<Object> NOT_NULL = new NotNull<>(Object.class);

private NotNull() {}
private final Class<T> type;

public NotNull(Class<T> type) {
this.type = Objects.requireNonNull(type);
}

@Override
public boolean matches(Object actual) {
return actual != null;
}

@Override
public Class<T> type() {
return type;
}
TimvdLippe marked this conversation as resolved.
Show resolved Hide resolved

@Override
public String toString() {
return "notNull()";
Expand Down