Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the StringSwitch and EnumSwitch constructs #143

Merged
merged 1 commit into from Jan 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 34 additions & 0 deletions src/main/java/io/quarkus/gizmo/AbstractSwitch.java
@@ -0,0 +1,34 @@
package io.quarkus.gizmo;

import java.util.Objects;
import java.util.function.Consumer;

abstract class AbstractSwitch<T> extends BytecodeCreatorImpl implements Switch<T> {

protected static final Consumer<BytecodeCreator> EMPTY_BLOCK = bc -> {
};

protected boolean fallThrough;
protected Consumer<BytecodeCreator> defaultBlockConsumer;

AbstractSwitch(BytecodeCreatorImpl enclosing) {
super(enclosing);
}

@Override
public void fallThrough() {
fallThrough = true;
}

@Override
public void defaultCase(Consumer<BytecodeCreator> defatultBlockConsumer) {
Objects.requireNonNull(defatultBlockConsumer);
this.defaultBlockConsumer = defatultBlockConsumer;
}

@Override
public void doBreak(BytecodeCreator creator) {
creator.breakScope(this);
}

}
18 changes: 18 additions & 0 deletions src/main/java/io/quarkus/gizmo/BytecodeCreator.java
Expand Up @@ -1033,6 +1033,24 @@ default ResultHandle increment(ResultHandle toIncrement) {
return add(toIncrement, load(1));
}

/**
* Create a new switch construct for a string value.
*
* @param value The string value to switch on
* @return the switch construct
*/
Switch.StringSwitch stringSwitch(ResultHandle value);

/**
* Create a new switch construct for an enum constant.
*
* @param <E>
* @param value The enum constant to switch on
* @param enumClass
* @return the switch construct
*/
<E extends Enum<E>> Switch.EnumSwitch<E> enumSwitch(ResultHandle value, Class<E> enumClass);

/**
* Indicate that the scope is no longer in use. The scope may refuse additional instructions after this method
* is called.
Expand Down
27 changes: 27 additions & 0 deletions src/main/java/io/quarkus/gizmo/BytecodeCreatorImpl.java
Expand Up @@ -801,6 +801,16 @@ public BytecodeCreator createScope() {
operations.add(new BlockOperation(enclosed));
return enclosed;
}

/**
* Go the the top of the given scope. Unlike {@link #continueScope(BytecodeCreator)} this method does not verify if this
* bytecode creator is scoped within the given bytecode creator.
*
* @param scope
*/
void jumpTo(BytecodeCreator scope) {
operations.add(new JumpOperation(((BytecodeCreatorImpl) scope).top));
}

static void storeResultHandle(MethodVisitor methodVisitor, ResultHandle handle) {
if (handle.getResultType() == ResultHandle.ResultType.UNUSED) {
Expand Down Expand Up @@ -1318,6 +1328,23 @@ public ResultHandle bitwiseXor(ResultHandle a1, ResultHandle a2) {
return emitBinaryArithmetic(Opcodes.IXOR, a1, a2);
}

@Override
public Switch.StringSwitch stringSwitch(ResultHandle value) {
Objects.requireNonNull(value);
StringSwitchImpl stringSwitch = new StringSwitchImpl(value, this);
operations.add(new BlockOperation(stringSwitch));
return stringSwitch;
}

@Override
public <E extends Enum<E>> Switch.EnumSwitch<E> enumSwitch(ResultHandle value, Class<E> enumClass) {
Objects.requireNonNull(value);
Objects.requireNonNull(enumClass);
EnumSwitchImpl<E> enumSwitch = new EnumSwitchImpl<>(value, enumClass, this);
operations.add(new BlockOperation(enumSwitch));
return enumSwitch;
}

private ResultHandle emitBinaryArithmetic(int intOpcode, ResultHandle a1, ResultHandle a2) {
Objects.requireNonNull(a1);
Objects.requireNonNull(a2);
Expand Down
193 changes: 193 additions & 0 deletions src/main/java/io/quarkus/gizmo/EnumSwitchImpl.java
@@ -0,0 +1,193 @@
package io.quarkus.gizmo;

import static org.objectweb.asm.Opcodes.ACC_PRIVATE;
import static org.objectweb.asm.Opcodes.ACC_STATIC;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.function.Consumer;

import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;

class EnumSwitchImpl<E extends Enum<E>> extends AbstractSwitch<E> implements Switch.EnumSwitch<E> {

private final Map<Integer, Consumer<BytecodeCreator>> ordinalToCaseBlocks;

public EnumSwitchImpl(ResultHandle value, Class<E> enumClass, BytecodeCreatorImpl enclosing) {
super(enclosing);
this.ordinalToCaseBlocks = new LinkedHashMap<>();

MethodDescriptor enumOrdinal = MethodDescriptor.ofMethod(enumClass, "ordinal", int.class);
ResultHandle ordinal = invokeVirtualMethod(enumOrdinal, value);

// Generate the int[] switch table needed for binary compatibility
ResultHandle switchTable;
MethodCreatorImpl methodCreator = findMethodCreator(enclosing);
if (methodCreator != null) {
// Generate a static method that returns the switch table
char sep = '$';
ClassCreator classCreator = methodCreator.getClassCreator();
// $GIZMO_SWITCH_TABLE$org$acme$MyEnum()
StringBuilder methodName = new StringBuilder();
methodName.append(sep).append("GIZMO_SWITCH_TABLE");
for (String part : enumClass.getName().split("\\.")) {
methodName.append(sep).append(part);
}
MethodDescriptor gizmoSwitchTableDescriptor = MethodDescriptor.ofMethod(classCreator.getClassName(),
methodName.toString(), int[].class);
if (!classCreator.getExistingMethods()
.contains(gizmoSwitchTableDescriptor)) {
MethodCreator gizmoSwitchTable = classCreator.getMethodCreator(gizmoSwitchTableDescriptor)
.setModifiers(ACC_PRIVATE | ACC_STATIC);
gizmoSwitchTable.returnValue(generateSwitchTable(enumClass, gizmoSwitchTable, enumOrdinal));
}
switchTable = invokeStaticMethod(gizmoSwitchTableDescriptor);
} else {
// This is suboptimal - the switch table is generated for each switch construct
switchTable = generateSwitchTable(enumClass, methodCreator, enumOrdinal);
}
ResultHandle effectiveOrdinal = readArrayValue(switchTable, ordinal);

Set<ResultHandle> inputHandles = new HashSet<>();
inputHandles.add(effectiveOrdinal);

operations.add(new Operation() {

@Override
void writeBytecode(MethodVisitor methodVisitor) {
E[] constants = enumClass.getEnumConstants();
Map<Integer, Label> ordinalToLabel = new HashMap<>();
List<BytecodeCreatorImpl> caseBlocks = new ArrayList<>();

BytecodeCreatorImpl defaultBlock = new BytecodeCreatorImpl(EnumSwitchImpl.this);
if (defaultBlockConsumer != null) {
defaultBlockConsumer.accept(defaultBlock);
}

// Initialize the case blocks
for (Entry<Integer, Consumer<BytecodeCreator>> caseEntry : ordinalToCaseBlocks.entrySet()) {
BytecodeCreatorImpl caseBlock = new BytecodeCreatorImpl(EnumSwitchImpl.this);
Consumer<BytecodeCreator> blockConsumer = caseEntry.getValue();
blockConsumer.accept(caseBlock);
if (blockConsumer != EMPTY_BLOCK && !fallThrough) {
caseBlock.breakScope(EnumSwitchImpl.this);
}
caseBlock.findActiveResultHandles(inputHandles);
caseBlocks.add(caseBlock);
ordinalToLabel.put(caseEntry.getKey(), caseBlock.getTop());
}

int min = ordinalToLabel.keySet().stream().mapToInt(Integer::intValue).min().orElse(0);
int max = ordinalToLabel.keySet().stream().mapToInt(Integer::intValue).max().orElse(0);

// Add empty blocks for missing ordinals
// This would be suboptimal for cases if there is a large number of missing ordinals
for (int i = 0; i < constants.length; i++) {
if (i >= min && i <= max) {
if (ordinalToLabel.get(i) == null) {
BytecodeCreatorImpl emptyCaseBlock = new BytecodeCreatorImpl(EnumSwitchImpl.this);
caseBlocks.add(emptyCaseBlock);
ordinalToLabel.put(i, emptyCaseBlock.getTop());
}
}
}

// Load the ordinal of the tested value
loadResultHandle(methodVisitor, effectiveOrdinal, EnumSwitchImpl.this, "I");

int[] ordinals = ordinalToLabel.keySet().stream().mapToInt(Integer::intValue).sorted().toArray();
Label[] labels = new Label[ordinals.length];
for (int i = 0; i < ordinals.length; i++) {
labels[i] = ordinalToLabel.get(ordinals[i]);
}
methodVisitor.visitTableSwitchInsn(min, max, defaultBlock.getTop(), labels);

// Write the case blocks
for (BytecodeCreatorImpl caseBlock : caseBlocks) {
caseBlock.writeOperations(methodVisitor);
}

// Write the default block
defaultBlock.writeOperations(methodVisitor);
}

@Override
ResultHandle getTopResultHandle() {
return null;
}

@Override
ResultHandle getOutgoingResultHandle() {
return null;
}

@Override
Set<ResultHandle> getInputResultHandles() {
return inputHandles;
}

});
}

@Override
public void caseOf(E value, Consumer<BytecodeCreator> caseBlockConsumer) {
Objects.requireNonNull(value);
Objects.requireNonNull(caseBlockConsumer);
addCaseBlock(value, caseBlockConsumer);
}

@Override
public void caseOf(List<E> values, Consumer<BytecodeCreator> caseBlockConsumer) {
Objects.requireNonNull(values);
Objects.requireNonNull(caseBlockConsumer);
for (Iterator<E> it = values.iterator(); it.hasNext();) {
E e = it.next();
if (it.hasNext()) {
addCaseBlock(e, EMPTY_BLOCK);
} else {
addCaseBlock(e, caseBlockConsumer);
}
}
}

private void addCaseBlock(E value, Consumer<BytecodeCreator> caseBlockConsumer) {
int ordinal = value.ordinal();
if (ordinalToCaseBlocks.containsKey(ordinal)) {
throw new IllegalArgumentException("A case block for the enum value " + value + " already exists");
}
ordinalToCaseBlocks.put(ordinal, caseBlockConsumer);
}

private MethodCreatorImpl findMethodCreator(BytecodeCreatorImpl enclosing) {
if (enclosing instanceof MethodCreatorImpl) {
return (MethodCreatorImpl) enclosing;
}
if (enclosing.getOwner() != null) {
return findMethodCreator(enclosing.getOwner());
}
return null;
}

private ResultHandle generateSwitchTable(Class<E> enumClass, BytecodeCreator bytecodeCreator,
MethodDescriptor enumOrdinal) {
E[] constants = enumClass.getEnumConstants();
ResultHandle switchTable = bytecodeCreator.newArray(int.class, constants.length);
for (int i = 0; i < constants.length; i++) {
ResultHandle currentConstant = bytecodeCreator
.readStaticField(FieldDescriptor.of(enumClass, constants[i].name(), enumClass));
ResultHandle currentOrdinal = bytecodeCreator.invokeVirtualMethod(enumOrdinal, currentConstant);
bytecodeCreator.writeArrayValue(switchTable, i, currentOrdinal);
}
return switchTable;
}

}