Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

String/Enum switch - make result handles from enclosing scope accesible #144

Merged
merged 1 commit into from Jan 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/main/java/io/quarkus/gizmo/AbstractSwitch.java
Expand Up @@ -5,14 +5,12 @@

abstract class AbstractSwitch<T> extends BytecodeCreatorImpl implements Switch<T> {

protected static final Consumer<BytecodeCreator> EMPTY_BLOCK = bc -> {
};

protected boolean fallThrough;
protected Consumer<BytecodeCreator> defaultBlockConsumer;
protected final BytecodeCreatorImpl defaultBlock;

AbstractSwitch(BytecodeCreatorImpl enclosing) {
super(enclosing);
this.defaultBlock = new BytecodeCreatorImpl(this);
}

@Override
Expand All @@ -23,7 +21,7 @@ public void fallThrough() {
@Override
public void defaultCase(Consumer<BytecodeCreator> defatultBlockConsumer) {
Objects.requireNonNull(defatultBlockConsumer);
this.defaultBlockConsumer = defatultBlockConsumer;
defatultBlockConsumer.accept(defaultBlock);
}

@Override
Expand Down
42 changes: 27 additions & 15 deletions src/main/java/io/quarkus/gizmo/EnumSwitchImpl.java
Expand Up @@ -20,7 +20,7 @@

class EnumSwitchImpl<E extends Enum<E>> extends AbstractSwitch<E> implements Switch.EnumSwitch<E> {

private final Map<Integer, Consumer<BytecodeCreator>> ordinalToCaseBlocks;
private final Map<Integer, BytecodeCreatorImpl> ordinalToCaseBlocks;

public EnumSwitchImpl(ResultHandle value, Class<E> enumClass, BytecodeCreatorImpl enclosing) {
super(enclosing);
Expand Down Expand Up @@ -68,20 +68,16 @@ void writeBytecode(MethodVisitor methodVisitor) {
Map<Integer, Label> ordinalToLabel = new HashMap<>();
List<BytecodeCreatorImpl> caseBlocks = new ArrayList<>();

BytecodeCreatorImpl defaultBlock = new BytecodeCreatorImpl(EnumSwitchImpl.this);
if (defaultBlockConsumer != null) {
defaultBlockConsumer.accept(defaultBlock);
}

// Initialize the case blocks
for (Entry<Integer, Consumer<BytecodeCreator>> caseEntry : ordinalToCaseBlocks.entrySet()) {
BytecodeCreatorImpl caseBlock = new BytecodeCreatorImpl(EnumSwitchImpl.this);
Consumer<BytecodeCreator> blockConsumer = caseEntry.getValue();
blockConsumer.accept(caseBlock);
if (blockConsumer != EMPTY_BLOCK && !fallThrough) {
for (Entry<Integer, BytecodeCreatorImpl> caseEntry : ordinalToCaseBlocks.entrySet()) {
BytecodeCreatorImpl caseBlock = caseEntry.getValue();
if (caseBlock != null && !fallThrough) {
caseBlock.breakScope(EnumSwitchImpl.this);
} else if (caseBlock == null) {
// Add empty fall through block
caseBlock = new BytecodeCreatorImpl(EnumSwitchImpl.this);
caseEntry.setValue(caseBlock);
}
caseBlock.findActiveResultHandles(inputHandles);
caseBlocks.add(caseBlock);
ordinalToLabel.put(caseEntry.getKey(), caseBlock.getTop());
}
Expand Down Expand Up @@ -152,19 +148,35 @@ public void caseOf(List<E> values, Consumer<BytecodeCreator> caseBlockConsumer)
for (Iterator<E> it = values.iterator(); it.hasNext();) {
E e = it.next();
if (it.hasNext()) {
addCaseBlock(e, EMPTY_BLOCK);
addCaseBlock(e, null);
} else {
addCaseBlock(e, caseBlockConsumer);
}
}
}

@Override
void findActiveResultHandles(final Set<ResultHandle> handlesToAllocate) {
super.findActiveResultHandles(handlesToAllocate);
for (BytecodeCreatorImpl caseBlock : ordinalToCaseBlocks.values()) {
if (caseBlock != null) {
caseBlock.findActiveResultHandles(handlesToAllocate);
}
}
defaultBlock.findActiveResultHandles(handlesToAllocate);
}

private void addCaseBlock(E value, Consumer<BytecodeCreator> caseBlockConsumer) {
int ordinal = value.ordinal();
if (ordinalToCaseBlocks.containsKey(ordinal)) {
throw new IllegalArgumentException("A case block for the enum value " + value + " already exists");
throw new IllegalArgumentException("A case block for the enum value [" + value + "] already exists");
}
BytecodeCreatorImpl caseBlock = null;
if (caseBlockConsumer != null) {
caseBlock = new BytecodeCreatorImpl(this);
caseBlockConsumer.accept(caseBlock);
}
ordinalToCaseBlocks.put(ordinal, caseBlockConsumer);
ordinalToCaseBlocks.put(ordinal, caseBlock);
}

private MethodCreatorImpl findMethodCreator(BytecodeCreatorImpl enclosing) {
Expand Down
71 changes: 43 additions & 28 deletions src/main/java/io/quarkus/gizmo/StringSwitchImpl.java
Expand Up @@ -16,12 +16,12 @@
import org.objectweb.asm.MethodVisitor;

class StringSwitchImpl extends AbstractSwitch<String> implements Switch.StringSwitch {
private final Map<Integer, List<Entry<String, Consumer<BytecodeCreator>>>> hashToCaseBlocks;

private final Map<String, BytecodeCreatorImpl> caseBlocks;

public StringSwitchImpl(ResultHandle value, BytecodeCreatorImpl enclosing) {
super(enclosing);
this.hashToCaseBlocks = new LinkedHashMap<>();
this.caseBlocks = new LinkedHashMap<>();
ResultHandle strHash = invokeVirtualMethod(MethodDescriptor.ofMethod(Object.class, "hashCode", int.class), value);

Set<ResultHandle> inputHandles = new HashSet<>();
Expand All @@ -34,24 +34,31 @@ public StringSwitchImpl(ResultHandle value, BytecodeCreatorImpl enclosing) {
void writeBytecode(MethodVisitor methodVisitor) {
Map<Integer, Label> hashToLabel = new HashMap<>();
List<BytecodeCreatorImpl> lookupBlocks = new ArrayList<>();
Map<String, BytecodeCreatorImpl> caseBlocks = new LinkedHashMap<>();
BytecodeCreatorImpl defaultBlock = new BytecodeCreatorImpl(StringSwitchImpl.this);
if (defaultBlockConsumer != null) {
defaultBlockConsumer.accept(defaultBlock);
Map<Integer, List<Entry<String, BytecodeCreatorImpl>>> hashToCaseBlocks = new LinkedHashMap<>();

for (Entry<String, BytecodeCreatorImpl> e : caseBlocks.entrySet()) {
int hashCode = e.getKey().hashCode();
List<Entry<String, BytecodeCreatorImpl>> list = hashToCaseBlocks.get(hashCode);
if (list == null) {
list = new ArrayList<>();
hashToCaseBlocks.put(hashCode, list);
}
list.add(e);
}

// Initialize the case blocks and lookup blocks
for (Entry<Integer, List<Entry<String, Consumer<BytecodeCreator>>>> hashEntry : hashToCaseBlocks.entrySet()) {
for (Entry<Integer, List<Entry<String, BytecodeCreatorImpl>>> hashEntry : hashToCaseBlocks.entrySet()) {
BytecodeCreatorImpl lookupBlock = new BytecodeCreatorImpl(StringSwitchImpl.this);
for (Entry<String, Consumer<BytecodeCreator>> caseEntry : hashEntry.getValue()) {
BytecodeCreatorImpl caseBlock = new BytecodeCreatorImpl(StringSwitchImpl.this);
Consumer<BytecodeCreator> blockConsumer = caseEntry.getValue();
blockConsumer.accept(caseBlock);
if (blockConsumer != EMPTY_BLOCK && !fallThrough) {
for (Entry<String, BytecodeCreatorImpl> caseEntry : hashEntry.getValue()) {
BytecodeCreatorImpl caseBlock = caseEntry.getValue();
if (caseBlock != null && !fallThrough) {
caseBlock.breakScope(StringSwitchImpl.this);
} else if (caseBlock == null) {
// TODO empty block
caseBlock = new BytecodeCreatorImpl(StringSwitchImpl.this);
caseEntry.setValue(caseBlock);
}
caseBlock.findActiveResultHandles(inputHandles);
caseBlocks.put(caseEntry.getKey(), caseBlock);
// caseBlock.findActiveResultHandles(inputHandles);
BytecodeCreatorImpl isEqual = (BytecodeCreatorImpl) lookupBlock
.ifTrue(Gizmo.equals(lookupBlock, lookupBlock.load(caseEntry.getKey()), value)).trueBranch();
isEqual.jumpTo(caseBlock);
Expand Down Expand Up @@ -102,6 +109,7 @@ Set<ResultHandle> getInputResultHandles() {
}

});

}

@Override
Expand All @@ -118,27 +126,34 @@ public void caseOf(List<String> values, Consumer<BytecodeCreator> caseBlockConsu
for (Iterator<String> it = values.iterator(); it.hasNext();) {
String s = it.next();
if (it.hasNext()) {
addCaseBlock(s, EMPTY_BLOCK);
addCaseBlock(s, null);
} else {
addCaseBlock(s, caseBlockConsumer);
}
}
}

private void addCaseBlock(String value, Consumer<BytecodeCreator> blockConsumer) {
int hashCode = value.hashCode();
List<Entry<String, Consumer<BytecodeCreator>>> caseBlocks = hashToCaseBlocks.get(hashCode);
if (caseBlocks == null) {
caseBlocks = new ArrayList<>();
hashToCaseBlocks.put(hashCode, caseBlocks);
} else {
for (Entry<String, Consumer<BytecodeCreator>> e : caseBlocks) {
if (e.getKey().equals(value)) {
throw new IllegalArgumentException("A case block for the string value " + value + " already exists");
}
@Override
void findActiveResultHandles(final Set<ResultHandle> handlesToAllocate) {
super.findActiveResultHandles(handlesToAllocate);
for (BytecodeCreatorImpl caseBlock : caseBlocks.values()) {
if (caseBlock != null) {
caseBlock.findActiveResultHandles(handlesToAllocate);
}
}
caseBlocks.add(Map.entry(value, blockConsumer));
defaultBlock.findActiveResultHandles(handlesToAllocate);
}

private void addCaseBlock(String value, Consumer<BytecodeCreator> blockConsumer) {
if (caseBlocks.containsKey(value)) {
throw new IllegalArgumentException("A case block for the string value [" + value + "] already exists");
}
BytecodeCreatorImpl caseBlock = null;
if (blockConsumer != null) {
caseBlock = new BytecodeCreatorImpl(this);
blockConsumer.accept(caseBlock);
}
caseBlocks.put(value, caseBlock);
}

}
72 changes: 70 additions & 2 deletions src/test/java/io/quarkus/gizmo/SwitchTest.java
Expand Up @@ -254,7 +254,7 @@ public void testEnumSwitch() throws InstantiationException, IllegalAccessExcepti
// default: -> return null;
// }
EnumSwitch<Status> s = method.enumSwitch(method.getMethodParam(0), Status.class);
s.caseOf(List.of(Status.ON, Status.OFF), bc -> {
s.caseOf(List.of(Status.OFF, Status.ON), bc -> {
bc.returnValue(Gizmo.toString(bc, method.getMethodParam(0)));
});
s.caseOf(Status.UNKNOWN, bc -> {
Expand Down Expand Up @@ -287,8 +287,8 @@ public void testEnumSwitchFallThrough() throws InstantiationException, IllegalAc
// }
EnumSwitch<Status> s = method.enumSwitch(method.getMethodParam(0), Status.class);
s.fallThrough();
s.caseOf(Status.ON, bc -> bc.assign(ret, bc.load("on")));
s.caseOf(Status.OFF, bc -> bc.assign(ret, bc.load("off")));
s.caseOf(Status.ON, bc -> bc.assign(ret, bc.load("on")));
s.defaultCase(bc -> bc.assign(ret, bc.load("??")));
method.returnValue(ret);
}
Expand Down Expand Up @@ -373,6 +373,74 @@ public void testEnumSwitchDuplicateCase() throws InstantiationException, Illegal
}
}

@SuppressWarnings("unchecked")
@Test
public void testStringSwitchConsumesHandleFromEnclosingScope()
throws InstantiationException, IllegalAccessException, ClassNotFoundException {
TestClassLoader cl = new TestClassLoader(getClass().getClassLoader());
try (ClassCreator creator = ClassCreator.builder().classOutput(cl).className("com.MyTest").interfaces(Function.class)
.build()) {
MethodCreator method = creator.getMethodCreator("apply", Object.class, Object.class);
ResultHandle prefix = Gizmo.toString(method, method.load("p_"));
ResultHandle placeholder = Gizmo.toString(method, method.load("placeholder"));
AssignableResultHandle ret = method.createVariable(String.class);
// String placeholder = "placeholder".toString();
// String prefix = "_p".toString();
// String ret;
// switch(arg) {
// case "bar" -> ret = prefix + "barr";
// default -> ret = placeholder;
// }
// return ret;
StringSwitch s = method.stringSwitch(method.getMethodParam(0));
s.caseOf("bar", bc -> {
bc.assign(ret,
bc.invokeVirtualMethod(MethodDescriptor.ofMethod(String.class, "concat", String.class, String.class),
prefix, bc.load("barr")));
});
s.defaultCase(bc -> bc.assign(ret, placeholder));

method.returnValue(ret);
}
Function<String, String> myInterface = (Function<String, String>) cl.loadClass("com.MyTest").newInstance();
assertEquals("p_barr", myInterface.apply("bar"));
assertEquals("placeholder", myInterface.apply("unknown"));
}

@SuppressWarnings("unchecked")
@Test
public void testEnumSwitchConsumesHandleFromEnclosingScope()
throws InstantiationException, IllegalAccessException, ClassNotFoundException {
TestClassLoader cl = new TestClassLoader(getClass().getClassLoader());
try (ClassCreator creator = ClassCreator.builder().classOutput(cl).className("com.MyTest").interfaces(Function.class)
.build()) {
MethodCreator method = creator.getMethodCreator("apply", Object.class, Object.class);
ResultHandle prefix = Gizmo.toString(method, method.load("p_"));
ResultHandle placeholder = Gizmo.toString(method, method.load("placeholder"));
AssignableResultHandle ret = method.createVariable(String.class);
// String placeholder = "placeholder".toString();
// String prefix = "_p".toString();
// String ret;
// switch(status) {
// case ON -> ret = prefix + "on";
// default: -> return placeholder;
// }
// return ret;
EnumSwitch<Status> s = method.enumSwitch(method.getMethodParam(0), Status.class);
s.caseOf(Status.ON, bc -> {
bc.assign(ret,
bc.invokeVirtualMethod(MethodDescriptor.ofMethod(String.class, "concat", String.class, String.class),
prefix, bc.load("on")));
});
s.defaultCase(bc -> bc.assign(ret, placeholder));

method.returnValue(ret);
}
Function<Status, String> myInterface = (Function<Status, String>) cl.loadClass("com.MyTest").newInstance();
assertEquals("p_on", myInterface.apply(Status.ON));
assertEquals("placeholder", myInterface.apply(Status.UNKNOWN));
}

public enum Status {
ON,
OFF,
Expand Down