Skip to content

Commit

Permalink
Hack about to fix mockito#2796
Browse files Browse the repository at this point in the history
This PR contains changes NOT intended to be committed 'as-is', but as a showcase for a potential solution to:

* mockito#2796
* mockito#1593

And potentially other vararg related issues.

The crux of the issue is that Mockito needs to handle the last matcher passed to a method, when that matcher aligns with a vararg parameter and has the same type as the vararg parameter.

For example,

```java
public interface Foo {
 String m1(String... args);  // Vararg param at index 0 and type String[]
}

@test
public void shouldWork2() throws Exception  {
  // Last matcher at index 0, and with type String[]: needs special handling!
  given(foo.m1(any(String[].class))).willReturn("var arg method");
  ...
}
```

In such situations that code needs to match the raw argument, _not_ the current functionality, which is to use the last matcher to match the last _non raw_ argument.

Unfortunately, I'm not aware of a way to get at the type of the matcher without adding a method to `VarargMatcher`
to get this information. This is the downside of this approach.
  • Loading branch information
big-andy-coates committed Nov 23, 2022
1 parent 3faa002 commit 4b1d533
Show file tree
Hide file tree
Showing 11 changed files with 193 additions and 57 deletions.
3 changes: 2 additions & 1 deletion src/main/java/org/mockito/ArgumentCaptor.java
Expand Up @@ -62,11 +62,12 @@
@CheckReturnValue
public class ArgumentCaptor<T> {

private final CapturingMatcher<T> capturingMatcher = new CapturingMatcher<T>();
private final CapturingMatcher<T> capturingMatcher;
private final Class<? extends T> clazz;

private ArgumentCaptor(Class<? extends T> clazz) {
this.clazz = clazz;
this.capturingMatcher = new CapturingMatcher<T>(clazz);
}

/**
Expand Down
Expand Up @@ -9,6 +9,8 @@
import org.mockito.ArgumentMatcher;
import org.mockito.internal.matchers.VarargMatcher;

import java.util.Optional;

public class HamcrestArgumentMatcher<T> implements ArgumentMatcher<T> {

private final Matcher matcher;
Expand All @@ -26,6 +28,10 @@ public boolean isVarargMatcher() {
return matcher instanceof VarargMatcher;
}

public Optional<VarargMatcher> varargMatcher() {
return isVarargMatcher() ? Optional.of((VarargMatcher) matcher) : Optional.empty();
}

@Override
public String toString() {
// TODO SF add unit tests and integ test coverage for toString()
Expand Down
Expand Up @@ -4,12 +4,10 @@
*/
package org.mockito.internal.invocation;

import static org.mockito.internal.invocation.MatcherApplicationStrategy.MatcherApplicationType.ERROR_UNSUPPORTED_NUMBER_OF_MATCHERS;
import static org.mockito.internal.invocation.MatcherApplicationStrategy.MatcherApplicationType.MATCH_EACH_VARARGS_WITH_LAST_MATCHER;
import static org.mockito.internal.invocation.MatcherApplicationStrategy.MatcherApplicationType.ONE_MATCHER_PER_ARGUMENT;

import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

import org.mockito.ArgumentMatcher;
import org.mockito.internal.hamcrest.HamcrestArgumentMatcher;
Expand All @@ -21,21 +19,12 @@ public class MatcherApplicationStrategy {

private final Invocation invocation;
private final List<ArgumentMatcher<?>> matchers;
private final MatcherApplicationType matchingType;

private MatcherApplicationStrategy(
Invocation invocation,
List<ArgumentMatcher<?>> matchers,
MatcherApplicationType matchingType) {
List<ArgumentMatcher<?>> matchers) {
this.invocation = invocation;
if (matchingType == MATCH_EACH_VARARGS_WITH_LAST_MATCHER) {
int times = varargLength(invocation);
this.matchers = appendLastMatcherNTimes(matchers, times);
} else {
this.matchers = matchers;
}

this.matchingType = matchingType;
this.matchers = matchers;
}

/**
Expand All @@ -53,8 +42,7 @@ private MatcherApplicationStrategy(
public static MatcherApplicationStrategy getMatcherApplicationStrategyFor(
Invocation invocation, List<ArgumentMatcher<?>> matchers) {

MatcherApplicationType type = getMatcherApplicationType(invocation, matchers);
return new MatcherApplicationStrategy(invocation, matchers, type);
return new MatcherApplicationStrategy(invocation, matchers);
}

/**
Expand All @@ -74,11 +62,35 @@ public static MatcherApplicationStrategy getMatcherApplicationStrategyFor(
* </ul>
*/
public boolean forEachMatcherAndArgument(ArgumentMatcherAction action) {
if (matchingType == ERROR_UNSUPPORTED_NUMBER_OF_MATCHERS) {
return false;
final boolean isVararg = invocation.getMethod().isVarArgs()
&& invocation.getRawArguments().length == matchers.size()
&& getLastVarargMatcher(matchers).isPresent();

if (isVararg) {
final Type type = getLastVarargMatcher(matchers).get().type();
final Class<?> varArgType = invocation.getMethod().getParameterTypes()[invocation.getMethod().getParameterTypes().length - 1];
if (type.equals(varArgType)) {
return argsMatch(invocation.getRawArguments(), matchers, action);
}
}

if (invocation.getArguments().length == matchers.size()) {
return argsMatch(invocation.getArguments(), matchers, action);
}

if (isVararg) {
int times = varargLength(invocation);
final List<ArgumentMatcher<?>> matchers = appendLastMatcherNTimes(this.matchers, times);
return argsMatch(invocation.getArguments(), matchers, action);
}

Object[] arguments = invocation.getArguments();
return false;
}

private boolean argsMatch(final Object[] arguments,
final List<ArgumentMatcher<?>> matchers,
final ArgumentMatcherAction action) {

for (int i = 0; i < arguments.length; i++) {
ArgumentMatcher<?> matcher = matchers.get(i);
Object argument = arguments[i];
Expand All @@ -90,29 +102,12 @@ public boolean forEachMatcherAndArgument(ArgumentMatcherAction action) {
return true;
}

private static MatcherApplicationType getMatcherApplicationType(
Invocation invocation, List<ArgumentMatcher<?>> matchers) {
final int rawArguments = invocation.getRawArguments().length;
final int expandedArguments = invocation.getArguments().length;
final int matcherCount = matchers.size();

if (expandedArguments == matcherCount) {
return ONE_MATCHER_PER_ARGUMENT;
}

if (rawArguments == matcherCount && isLastMatcherVarargMatcher(matchers)) {
return MATCH_EACH_VARARGS_WITH_LAST_MATCHER;
}

return ERROR_UNSUPPORTED_NUMBER_OF_MATCHERS;
}

private static boolean isLastMatcherVarargMatcher(final List<ArgumentMatcher<?>> matchers) {
private static Optional<VarargMatcher> getLastVarargMatcher(final List<ArgumentMatcher<?>> matchers) {
ArgumentMatcher<?> argumentMatcher = lastMatcher(matchers);
if (argumentMatcher instanceof HamcrestArgumentMatcher<?>) {
return ((HamcrestArgumentMatcher<?>) argumentMatcher).isVarargMatcher();
return ((HamcrestArgumentMatcher<?>) argumentMatcher).varargMatcher();
}
return argumentMatcher instanceof VarargMatcher;
return argumentMatcher instanceof VarargMatcher ? Optional.of((VarargMatcher) argumentMatcher) : Optional.empty();
}

private static List<ArgumentMatcher<?>> appendLastMatcherNTimes(
Expand All @@ -135,10 +130,4 @@ private static int varargLength(Invocation invocation) {
private static ArgumentMatcher<?> lastMatcher(List<ArgumentMatcher<?>> matchers) {
return matchers.get(matchers.size() - 1);
}

enum MatcherApplicationType {
ONE_MATCHER_PER_ARGUMENT,
MATCH_EACH_VARARGS_WITH_LAST_MATCHER,
ERROR_UNSUPPORTED_NUMBER_OF_MATCHERS;
}
}
6 changes: 6 additions & 0 deletions src/main/java/org/mockito/internal/matchers/Any.java
Expand Up @@ -5,6 +5,7 @@
package org.mockito.internal.matchers;

import java.io.Serializable;
import java.lang.reflect.Type;

import org.mockito.ArgumentMatcher;

Expand All @@ -21,4 +22,9 @@ public boolean matches(Object actual) {
public String toString() {
return "<any>";
}

@Override
public Type type() {
return Object.class;
}
}
11 changes: 11 additions & 0 deletions src/main/java/org/mockito/internal/matchers/CapturingMatcher.java
Expand Up @@ -7,6 +7,7 @@
import static org.mockito.internal.exceptions.Reporter.noArgumentValueWasCaptured;

import java.io.Serializable;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.locks.Lock;
Expand All @@ -19,12 +20,17 @@
public class CapturingMatcher<T>
implements ArgumentMatcher<T>, CapturesArguments, VarargMatcher, Serializable {

private final Class<? extends T> clazz;
private final List<Object> arguments = new ArrayList<>();

private final ReadWriteLock lock = new ReentrantReadWriteLock();
private final Lock readLock = lock.readLock();
private final Lock writeLock = lock.writeLock();

public CapturingMatcher(final Class<? extends T> clazz) {
this.clazz = clazz;
}

@Override
public boolean matches(Object argument) {
return true;
Expand Down Expand Up @@ -66,4 +72,9 @@ public void captureFrom(Object argument) {
writeLock.unlock();
}
}

@Override
public Type type() {
return clazz;
}
}
8 changes: 7 additions & 1 deletion src/main/java/org/mockito/internal/matchers/InstanceOf.java
Expand Up @@ -5,13 +5,14 @@
package org.mockito.internal.matchers;

import java.io.Serializable;
import java.lang.reflect.Type;

import org.mockito.ArgumentMatcher;
import org.mockito.internal.util.Primitives;

public class InstanceOf implements ArgumentMatcher<Object>, Serializable {

private final Class<?> clazz;
final Class<?> clazz;
private final String description;

public InstanceOf(Class<?> clazz) {
Expand Down Expand Up @@ -44,5 +45,10 @@ public VarArgAware(Class<?> clazz) {
public VarArgAware(Class<?> clazz, String describedAs) {
super(clazz, describedAs);
}

@Override
public Type type() {
return clazz;
}
}
}
Expand Up @@ -5,9 +5,14 @@
package org.mockito.internal.matchers;

import java.io.Serializable;
import java.lang.reflect.Type;

/**
* Internal interface that informs Mockito that the matcher is intended to capture varargs.
* This information is needed when mockito collects the arguments.
*/
public interface VarargMatcher extends Serializable {}
public interface VarargMatcher extends Serializable {

// Todo: default impl to avoid compatability issues?
Type type();
}
Expand Up @@ -136,7 +136,7 @@ public void should_be_similar_if_is_overloaded_but_used_with_different_arg() thr
public void should_capture_arguments_from_invocation() throws Exception {
// given
Invocation invocation = new InvocationBuilder().args("1", 100).toInvocation();
CapturingMatcher capturingMatcher = new CapturingMatcher();
CapturingMatcher capturingMatcher = new CapturingMatcher(List.class);
InvocationMatcher invocationMatcher =
new InvocationMatcher(invocation, (List) asList(new Equals("1"), capturingMatcher));

Expand Down Expand Up @@ -167,7 +167,7 @@ public void should_capture_varargs_as_vararg() throws Exception {
// given
mock.mixedVarargs(1, "a", "b");
Invocation invocation = getLastInvocation();
CapturingMatcher m = new CapturingMatcher();
CapturingMatcher m = new CapturingMatcher(List.class);
InvocationMatcher invocationMatcher =
new InvocationMatcher(invocation, Arrays.<ArgumentMatcher>asList(new Equals(1), m));

Expand Down
Expand Up @@ -12,6 +12,7 @@
import static org.mockito.internal.invocation.MatcherApplicationStrategy.getMatcherApplicationStrategyFor;
import static org.mockito.internal.matchers.Any.ANY;

import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.List;

Expand Down Expand Up @@ -230,6 +231,11 @@ public boolean matches(Object o) {
}

public void describeTo(Description description) {}

@Override
public Type type() {
return Integer.class;
}
}

private Invocation mixedVarargs(Object a, String... s) {
Expand Down
Expand Up @@ -19,7 +19,7 @@ public class CapturingMatcherTest extends TestBase {
@Test
public void should_capture_arguments() throws Exception {
// given
CapturingMatcher<String> m = new CapturingMatcher<String>();
CapturingMatcher<String> m = new CapturingMatcher<String>(String.class);

// when
m.captureFrom("foo");
Expand All @@ -32,7 +32,7 @@ public void should_capture_arguments() throws Exception {
@Test
public void should_know_last_captured_value() throws Exception {
// given
CapturingMatcher<String> m = new CapturingMatcher<String>();
CapturingMatcher<String> m = new CapturingMatcher<String>(String.class);

// when
m.captureFrom("foo");
Expand All @@ -45,7 +45,7 @@ public void should_know_last_captured_value() throws Exception {
@Test
public void should_scream_when_nothing_yet_captured() throws Exception {
// given
CapturingMatcher<String> m = new CapturingMatcher<String>();
CapturingMatcher<String> m = new CapturingMatcher<String>(String.class);

try {
// when
Expand All @@ -59,7 +59,7 @@ public void should_scream_when_nothing_yet_captured() throws Exception {
@Test
public void should_not_fail_when_used_in_concurrent_tests() throws Exception {
// given
final CapturingMatcher<String> m = new CapturingMatcher<String>();
final CapturingMatcher<String> m = new CapturingMatcher<String>(String.class);

// when
m.captureFrom("concurrent access");
Expand Down

0 comments on commit 4b1d533

Please sign in to comment.