Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Excessive locking in TypeCachingBytecodeGenerator#BOOTSTRAP_LOCK #3095

Merged
merged 1 commit into from Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -6,6 +6,7 @@

import java.lang.ref.ReferenceQueue;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

Expand All @@ -14,38 +15,68 @@

class TypeCachingBytecodeGenerator extends ReferenceQueue<ClassLoader>
implements BytecodeGenerator {

private static final Object BOOTSTRAP_LOCK = new Object();
/**
* The size of the {@link #cacheLocks}.
* <p>Caution: This must match the {@link #CACHE_LOCK_MASK}.
*/
private static final int CACHE_LOCK_SIZE = 16;

/**
* The mask to use to mask out the {@link MockitoMockKey#hashCode()} to find the {@link #cacheLocks}.
* <p>Caution: this must match the bits of the {@link #CACHE_LOCK_SIZE}.
*/
private static final int CACHE_LOCK_MASK = 0x0F;

private final BytecodeGenerator bytecodeGenerator;

private final TypeCache<MockitoMockKey> typeCache;

private final ReadWriteLock lock = new ReentrantReadWriteLock();

/**
* This array contains {@link TypeCachingLock} instances, which are used as java monitor locks for
* {@link TypeCache#findOrInsert(ClassLoader, Object, Callable, Object)}.
* The locks spread the lock to acquire over multiple locks instead of using a single lock
* {@code BOOTSTRAP_LOCK} for all {@link MockitoMockKey}.
*
* <p>Note: We can't simply use the mockedType class lock as a lock,
* because the {@link MockitoMockKey}, will be the same for different mockTypes + interfaces.
*
* <p>#3035: Excessive locking in TypeCachingBytecodeGenerator#BOOTSTRAP_LOCK
*/
private final TypeCachingLock[] cacheLocks;

public TypeCachingBytecodeGenerator(BytecodeGenerator bytecodeGenerator, boolean weak) {
this.bytecodeGenerator = bytecodeGenerator;
typeCache =
new TypeCache.WithInlineExpunction<>(
weak ? TypeCache.Sort.WEAK : TypeCache.Sort.SOFT);

this.cacheLocks = new TypeCachingLock[CACHE_LOCK_SIZE];
for (int i = 0; i < CACHE_LOCK_SIZE; i++) {
cacheLocks[i] = new TypeCachingLock();
}
}

@SuppressWarnings("unchecked")
@Override
public <T> Class<T> mockClass(final MockFeatures<T> params) {
lock.readLock().lock();
try {
ClassLoader classLoader = params.mockedType.getClassLoader();
Class<T> mockedType = params.mockedType;
ClassLoader classLoader = mockedType.getClassLoader();
MockitoMockKey key =
new MockitoMockKey(
mockedType,
params.interfaces,
params.serializableMode,
params.stripAnnotations);
return (Class<T>)
typeCache.findOrInsert(
classLoader,
new MockitoMockKey(
params.mockedType,
params.interfaces,
params.serializableMode,
params.stripAnnotations),
key,
() -> bytecodeGenerator.mockClass(params),
BOOTSTRAP_LOCK);
getCacheLockForKey(key));
} catch (IllegalArgumentException exception) {
Throwable cause = exception.getCause();
if (cause instanceof RuntimeException) {
Expand All @@ -58,6 +89,20 @@ public <T> Class<T> mockClass(final MockFeatures<T> params) {
}
}

/**
* Returns a {@link TypeCachingLock}, which locks the {@link TypeCache#findOrInsert(ClassLoader, Object, Callable, Object)}.
*
* @param key the key to lock
* @return the {@link TypeCachingLock} to use to lock the {@link TypeCache}
*/
private TypeCachingLock getCacheLockForKey(MockitoMockKey key) {
int hashCode = key.hashCode();
// Try to spread some higher bits with XOR to lower bits, because we only use lower bits.
hashCode = hashCode ^ (hashCode >>> 16);
int index = hashCode & CACHE_LOCK_MASK;
return cacheLocks[index];
}

@Override
public void mockClassStatic(Class<?> type) {
bytecodeGenerator.mockClassStatic(type);
Expand All @@ -79,6 +124,8 @@ public void clearAllCaches() {
}
}

private static final class TypeCachingLock {}

private static class MockitoMockKey extends TypeCache.SimpleKey {

private final SerializableMode serializableMode;
Expand Down
Expand Up @@ -13,10 +13,17 @@
import java.lang.ref.Reference;
import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Set;
import java.util.WeakHashMap;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Phaser;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;

import org.junit.Before;
import org.junit.Test;
Expand Down Expand Up @@ -208,6 +215,134 @@ public void ensure_cache_returns_same_instance_defaultAnswer() throws Exception
assertThat(cache).isEmpty();
}

@Test
public void cacheLockingStressTest_same_hashcode_different_interface()
throws InterruptedException, TimeoutException {
Class<?>[] classes = cacheLockingInMemClassLoaderClasses();
Class<?> ifA = classes[0];
Class<?> ifB = classes[1];
var featA = newMockFeatures(ifA, ifB);
var featB = newMockFeatures(ifB, ifA);
cacheLockingStressTestImpl(featA, featB);
}

@Test
public void cacheLockingStressTest_same_hashcode_same_interface()
throws InterruptedException, TimeoutException {
Class<?>[] classes = cacheLockingInMemClassLoaderClasses();
Class<?> ifA = classes[0];
var featA = newMockFeatures(ifA);
cacheLockingStressTestImpl(featA, featA);
}

@Test
public void cacheLockingStressTest_different_hashcode()
throws InterruptedException, TimeoutException {
Class<?>[] classes = cacheLockingInMemClassLoaderClasses();
Class<?> ifA = classes[0];
Class<?> ifB = classes[1];
Class<?> ifC = classes[2];
var featA = newMockFeatures(ifA, ifB);
var featB = newMockFeatures(ifB, ifC);
cacheLockingStressTestImpl(featA, featB);
}

@Test
public void cacheLockingStressTest_unrelated_classes()
throws InterruptedException, TimeoutException {
Class<?>[] classes = cacheLockingInMemClassLoaderClasses();
Class<?> ifA = classes[0];
Class<?> ifB = classes[1];
var featA = newMockFeatures(ifA);
var featB = newMockFeatures(ifB);
cacheLockingStressTestImpl(featA, featB);
}

private void cacheLockingStressTestImpl(MockFeatures<?> featA, MockFeatures<?> featB)
throws InterruptedException, TimeoutException {
int iterations = 10_000;

TypeCachingBytecodeGenerator bytecodeGenerator =
new TypeCachingBytecodeGenerator(new SubclassBytecodeGenerator(), true);

Phaser phaser = new Phaser(4);
Function<Runnable, CompletableFuture<Void>> runCode =
code ->
CompletableFuture.runAsync(
() -> {
phaser.arriveAndAwaitAdvance();
try {
for (int i = 0; i < iterations; i++) {
code.run();
}
} finally {
phaser.arrive();
}
});
var mockFeatAFuture =
runCode.apply(
() -> {
Class<?> mockClass = bytecodeGenerator.mockClass(featA);
assertValidMockClass(featA, mockClass);
});

var mockFeatBFuture =
runCode.apply(
() -> {
Class<?> mockClass = bytecodeGenerator.mockClass(featB);
assertValidMockClass(featB, mockClass);
});
var cacheFuture = runCode.apply(bytecodeGenerator::clearAllCaches);
// Start test
phaser.arriveAndAwaitAdvance();
// Wait for test to end
int phase = phaser.arrive();
try {

phaser.awaitAdvanceInterruptibly(phase, 30, TimeUnit.SECONDS);
} finally {
// Collect exceptions from the futures, to make issues visible.
mockFeatAFuture.getNow(null);
mockFeatBFuture.getNow(null);
cacheFuture.getNow(null);
}
}

private static <T> MockFeatures<T> newMockFeatures(
Class<T> mockedType, Class<?>... interfaces) {
return MockFeatures.withMockFeatures(
mockedType,
new HashSet<>(Arrays.asList(interfaces)),
SerializableMode.NONE,
false,
null);
}

private static Class<?>[] cacheLockingInMemClassLoaderClasses() {
ClassLoader inMemClassLoader =
inMemoryClassLoader()
.withClassDefinition("foo.IfA", makeMarkerInterface("foo.IfA"))
.withClassDefinition("foo.IfB", makeMarkerInterface("foo.IfB"))
.withClassDefinition("foo.IfC", makeMarkerInterface("foo.IfC"))
.build();
try {
return new Class[] {
inMemClassLoader.loadClass("foo.IfA"),
inMemClassLoader.loadClass("foo.IfB"),
inMemClassLoader.loadClass("foo.IfC")
};
} catch (ClassNotFoundException e) {
throw new IllegalStateException(e);
}
}

private void assertValidMockClass(MockFeatures<?> mockFeature, Class<?> mockClass) {
assertThat(mockClass).isAssignableTo(mockFeature.mockedType);
for (Class<?> anInterface : mockFeature.interfaces) {
assertThat(mockClass).isAssignableTo(anInterface);
}
}

static class HoldingAReference {
final WeakReference<Class<?>> a;

Expand Down