Skip to content

Commit

Permalink
Add VarHandle and ThreadLocal fallback to StripedBuffer (#273)
Browse files Browse the repository at this point in the history
  • Loading branch information
ben-manes committed Jan 4, 2021
1 parent 7c9085e commit f8ebdae
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 31 deletions.
Expand Up @@ -15,6 +15,8 @@
*/
package com.github.benmanes.caffeine;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.util.Arrays;
import java.util.concurrent.ThreadLocalRandom;

Expand Down Expand Up @@ -50,8 +52,9 @@
*/
@State(Scope.Benchmark)
public class SlotLookupBenchmark {
static final int ARENA_SIZE = 2 << 6;
static final int SPARSE_SIZE = 2 << 14;
static final int ARENA_SIZE = 2 << 6;
static final VarHandle PROBE;

ThreadLocal<Integer> threadLocal;
long element;
Expand Down Expand Up @@ -148,30 +151,63 @@ public int threadHashCode() {
}

@Benchmark
public long striped64(Blackhole blackhole) {
public long striped64_unsafe(Blackhole blackhole) {
// Emulates finding the arena slot by reusing the thread-local random seed (j.u.c.a.Striped64)
int hash = getProbe();
int hash = getProbe_unsafe();
if (hash == 0) {
blackhole.consume(ThreadLocalRandom.current()); // force initialization
hash = getProbe();
hash = getProbe_unsafe();
}
advanceProbe(hash);
advanceProbe_unsafe(hash);
int index = selectSlot(hash);
return array[index];
}

private int getProbe() {
private int getProbe_unsafe() {
return UnsafeAccess.UNSAFE.getInt(Thread.currentThread(), probeOffset);
}

private void advanceProbe(int probe) {
private void advanceProbe_unsafe(int probe) {
probe ^= probe << 13; // xorshift
probe ^= probe >>> 17;
probe ^= probe << 5;
UnsafeAccess.UNSAFE.putInt(Thread.currentThread(), probeOffset, probe);
}

@Benchmark
public long striped64_varHandle(Blackhole blackhole) {
// Emulates finding the arena slot by reusing the thread-local random seed (j.u.c.a.Striped64)
int hash = getProbe_varHandle();
if (hash == 0) {
blackhole.consume(ThreadLocalRandom.current()); // force initialization
hash = getProbe_varHandle();
}
advanceProbe_varHandle(hash);
int index = selectSlot(hash);
return array[index];
}

private int getProbe_varHandle() {
return (int) PROBE.get(Thread.currentThread());
}

private void advanceProbe_varHandle(int probe) {
probe ^= probe << 13; // xorshift
probe ^= probe >>> 17;
probe ^= probe << 5;
PROBE.set(Thread.currentThread(), probe);
}

private static int selectSlot(int i) {
return i & (ARENA_SIZE - 1);
}

static {
try {
PROBE = MethodHandles.privateLookupIn(Thread.class, MethodHandles.lookup())
.findVarHandle(Thread.class, "threadLocalRandomProbe", int.class);
} catch (ReflectiveOperationException e) {
throw new ExceptionInInitializerError(e);
}
}
}
Expand Up @@ -21,10 +21,15 @@
package com.github.benmanes.caffeine.cache;

import static com.github.benmanes.caffeine.cache.Caffeine.ceilingPowerOfTwo;
import static java.util.Objects.requireNonNull;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Consumer;
import java.util.function.Supplier;

import org.checkerframework.checker.nullness.qual.Nullable;

Expand Down Expand Up @@ -84,8 +89,10 @@ abstract class StripedBuffer<E> implements Buffer<E> {
* again; and for short-lived ones, it does not matter.
*/

static final long TABLE_BUSY = UnsafeAccess.objectFieldOffset(StripedBuffer.class, "tableBusy");
static final long PROBE = UnsafeAccess.objectFieldOffset(Thread.class, "threadLocalRandomProbe");
static final VarHandle TABLE_BUSY;

/** A probe value for the current thread. */
static final Probe PROBE;

/** Number of CPUS. */
static final int NCPU = Runtime.getRuntime().availableProcessors();
Expand All @@ -97,33 +104,22 @@ abstract class StripedBuffer<E> implements Buffer<E> {
static final int ATTEMPTS = 3;

/** Table of buffers. When non-null, size is a power of 2. */
transient volatile Buffer<E> @Nullable[] table;
volatile Buffer<E> @Nullable[] table;

/** Spinlock (locked via CAS) used when resizing and/or creating Buffers. */
transient volatile int tableBusy;
volatile int tableBusy;

/** CASes the tableBusy field from 0 to 1 to acquire lock. */
final boolean casTableBusy() {
return UnsafeAccess.UNSAFE.compareAndSwapInt(this, TABLE_BUSY, 0, 1);
}

/**
* Returns the probe value for the current thread. Duplicated from ThreadLocalRandom because of
* packaging restrictions.
*/
static final int getProbe() {
return UnsafeAccess.UNSAFE.getInt(Thread.currentThread(), PROBE);
return TABLE_BUSY.compareAndSet(this, 0, 1);
}

/**
* Pseudo-randomly advances and records the given probe value for the given thread. Duplicated
* from ThreadLocalRandom because of packaging restrictions.
*/
/** Pseudo-randomly advances and records the given probe value for the given thread. */
static final int advanceProbe(int probe) {
probe ^= probe << 13; // xorshift
probe ^= probe >>> 17;
probe ^= probe << 5;
UnsafeAccess.UNSAFE.putInt(Thread.currentThread(), PROBE, probe);
PROBE.set(probe);
return probe;
}

Expand All @@ -144,7 +140,7 @@ public int offer(E e) {
Buffer<E>[] buffers = table;
if ((buffers == null)
|| (mask = buffers.length - 1) < 0
|| (buffer = buffers[getProbe() & mask]) == null
|| (buffer = buffers[PROBE.get() & mask]) == null
|| !(uncontended = ((result = buffer.offer(e)) != Buffer.FAILED))) {
expandOrRetry(e, uncontended);
}
Expand Down Expand Up @@ -205,9 +201,9 @@ public int writes() {
@SuppressWarnings("PMD.ConfusingTernary")
final void expandOrRetry(E e, boolean wasUncontended) {
int h;
if ((h = getProbe()) == 0) {
ThreadLocalRandom.current(); // force initialization
h = getProbe();
if ((h = PROBE.get()) == 0) {
PROBE.initialize();
h = PROBE.get();
wasUncontended = true;
}
boolean collide = false; // True if last slot nonempty
Expand Down Expand Up @@ -275,4 +271,87 @@ final void expandOrRetry(E e, boolean wasUncontended) {
}
}
}

static {
try {
TABLE_BUSY = MethodHandles.lookup()
.findVarHandle(StripedBuffer.class, "tableBusy", int.class);
} catch (ReflectiveOperationException e) {
throw new ExceptionInInitializerError(e);
}

Probe probe = null;
List<Supplier<Probe>> suppliers = List.of(
UnsafeProbe::new, VarHandleProbe::new, ThreadLocalProbe::new);
for (var supplier : suppliers) {
try {
probe = supplier.get();
break;
} catch (Throwable ignored) { /* Try next strategy */ }
}
PROBE = requireNonNull(probe, "Unable to determine a probe strategy");
}

interface Probe {
int get();
void set(int value);
void initialize();
}

/** Uses the Thread's random probe value, if accessible. */
static final class UnsafeProbe implements Probe {
static final long PROBE = UnsafeAccess.objectFieldOffset(
Thread.class, "threadLocalRandomProbe");

@Override public int get() {
return UnsafeAccess.UNSAFE.getInt(Thread.currentThread(), PROBE);
}
@Override public void set(int probe) {
UnsafeAccess.UNSAFE.putInt(Thread.currentThread(), PROBE, probe);
}
@Override public void initialize() {
ThreadLocalRandom.current(); // force initialization
}
}

/** Uses the Thread's random probe value, if accessible. */
static final class VarHandleProbe implements Probe {
static final VarHandle PROBE;

static {
try {
PROBE = MethodHandles.privateLookupIn(Thread.class, MethodHandles.lookup())
.findVarHandle(Thread.class, "threadLocalRandomProbe", int.class);
} catch (ReflectiveOperationException e) {
throw new ExceptionInInitializerError(e);
}
}

@Override public int get() {
return (int) PROBE.get(Thread.currentThread());
}
@Override public void set(int probe) {
PROBE.set(Thread.currentThread(), probe);
}
@Override public void initialize() {
ThreadLocalRandom.current(); // force initialization
}
}

/** Uses a thread local to maintain a random probe value. */
static final class ThreadLocalProbe implements Probe {
static final ThreadLocal<int[]> threadHashCode = new ThreadLocal<>();

@Override public int get() {
return threadHashCode.get()[0];
}
@Override public void set(int probe) {
threadHashCode.get()[0] = probe;
}
@Override public void initialize() {
// Avoid zero to allow xorShift rehash
int hash = 1 | ThreadLocalRandom.current().nextInt();
threadHashCode.set(new int[] { hash });
}
}
}
Expand Up @@ -18,13 +18,18 @@
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;

import java.util.function.Consumer;

import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import com.github.benmanes.caffeine.cache.StripedBuffer.Probe;
import com.github.benmanes.caffeine.cache.StripedBuffer.ThreadLocalProbe;
import com.github.benmanes.caffeine.cache.StripedBuffer.UnsafeProbe;
import com.github.benmanes.caffeine.cache.StripedBuffer.VarHandleProbe;
import com.github.benmanes.caffeine.testing.ConcurrentTestHarness;
import com.google.common.base.MoreObjects;

Expand All @@ -42,6 +47,24 @@ public void init(FakeBuffer<Integer> buffer) {
assertThat(buffer.table.length, is(1));
}

@Test(dataProvider = "probes")
public void probe(Probe probe) {
probe.initialize();
assertThat(probe.get(), is(not(0)));

probe.set(1);
assertThat(probe.get(), is(1));
}

@DataProvider(name = "probes")
public Object[][] providesProbes() {
return new Object[][] {
{ new UnsafeProbe() },
{ new VarHandleProbe() },
{ new ThreadLocalProbe() },
};
}

@Test(dataProvider = "buffers")
@SuppressWarnings("ThreadPriorityCheck")
public void produce(FakeBuffer<Integer> buffer) {
Expand All @@ -65,8 +88,8 @@ public void drain(FakeBuffer<Integer> buffer) {
assertThat(buffer.drains, is(1));
}

@DataProvider
public Object[][] buffers() {
@DataProvider(name = "buffers")
public Object[][] providesBuffers() {
return new Object[][] {
{ new FakeBuffer<Integer>(Buffer.FULL) },
{ new FakeBuffer<Integer>(Buffer.FAILED) },
Expand Down

0 comments on commit f8ebdae

Please sign in to comment.