diff --git a/src/main/java/io/quarkus/gizmo/AbstractSwitch.java b/src/main/java/io/quarkus/gizmo/AbstractSwitch.java index 0089509..28689bf 100644 --- a/src/main/java/io/quarkus/gizmo/AbstractSwitch.java +++ b/src/main/java/io/quarkus/gizmo/AbstractSwitch.java @@ -5,14 +5,12 @@ abstract class AbstractSwitch extends BytecodeCreatorImpl implements Switch { - protected static final Consumer EMPTY_BLOCK = bc -> { - }; - protected boolean fallThrough; - protected Consumer defaultBlockConsumer; + protected final BytecodeCreatorImpl defaultBlock; AbstractSwitch(BytecodeCreatorImpl enclosing) { super(enclosing); + this.defaultBlock = new BytecodeCreatorImpl(this); } @Override @@ -23,7 +21,7 @@ public void fallThrough() { @Override public void defaultCase(Consumer defatultBlockConsumer) { Objects.requireNonNull(defatultBlockConsumer); - this.defaultBlockConsumer = defatultBlockConsumer; + defatultBlockConsumer.accept(defaultBlock); } @Override diff --git a/src/main/java/io/quarkus/gizmo/EnumSwitchImpl.java b/src/main/java/io/quarkus/gizmo/EnumSwitchImpl.java index 19c6728..39354ea 100644 --- a/src/main/java/io/quarkus/gizmo/EnumSwitchImpl.java +++ b/src/main/java/io/quarkus/gizmo/EnumSwitchImpl.java @@ -20,7 +20,7 @@ class EnumSwitchImpl> extends AbstractSwitch implements Switch.EnumSwitch { - private final Map> ordinalToCaseBlocks; + private final Map ordinalToCaseBlocks; public EnumSwitchImpl(ResultHandle value, Class enumClass, BytecodeCreatorImpl enclosing) { super(enclosing); @@ -68,20 +68,16 @@ void writeBytecode(MethodVisitor methodVisitor) { Map ordinalToLabel = new HashMap<>(); List caseBlocks = new ArrayList<>(); - BytecodeCreatorImpl defaultBlock = new BytecodeCreatorImpl(EnumSwitchImpl.this); - if (defaultBlockConsumer != null) { - defaultBlockConsumer.accept(defaultBlock); - } - // Initialize the case blocks - for (Entry> caseEntry : ordinalToCaseBlocks.entrySet()) { - BytecodeCreatorImpl caseBlock = new BytecodeCreatorImpl(EnumSwitchImpl.this); - Consumer blockConsumer = caseEntry.getValue(); - blockConsumer.accept(caseBlock); - if (blockConsumer != EMPTY_BLOCK && !fallThrough) { + for (Entry caseEntry : ordinalToCaseBlocks.entrySet()) { + BytecodeCreatorImpl caseBlock = caseEntry.getValue(); + if (caseBlock != null && !fallThrough) { caseBlock.breakScope(EnumSwitchImpl.this); + } else if (caseBlock == null) { + // Add empty fall through block + caseBlock = new BytecodeCreatorImpl(EnumSwitchImpl.this); + caseEntry.setValue(caseBlock); } - caseBlock.findActiveResultHandles(inputHandles); caseBlocks.add(caseBlock); ordinalToLabel.put(caseEntry.getKey(), caseBlock.getTop()); } @@ -152,19 +148,35 @@ public void caseOf(List values, Consumer caseBlockConsumer) for (Iterator it = values.iterator(); it.hasNext();) { E e = it.next(); if (it.hasNext()) { - addCaseBlock(e, EMPTY_BLOCK); + addCaseBlock(e, null); } else { addCaseBlock(e, caseBlockConsumer); } } } + @Override + void findActiveResultHandles(final Set handlesToAllocate) { + super.findActiveResultHandles(handlesToAllocate); + for (BytecodeCreatorImpl caseBlock : ordinalToCaseBlocks.values()) { + if (caseBlock != null) { + caseBlock.findActiveResultHandles(handlesToAllocate); + } + } + defaultBlock.findActiveResultHandles(handlesToAllocate); + } + private void addCaseBlock(E value, Consumer caseBlockConsumer) { int ordinal = value.ordinal(); if (ordinalToCaseBlocks.containsKey(ordinal)) { - throw new IllegalArgumentException("A case block for the enum value " + value + " already exists"); + throw new IllegalArgumentException("A case block for the enum value [" + value + "] already exists"); + } + BytecodeCreatorImpl caseBlock = null; + if (caseBlockConsumer != null) { + caseBlock = new BytecodeCreatorImpl(this); + caseBlockConsumer.accept(caseBlock); } - ordinalToCaseBlocks.put(ordinal, caseBlockConsumer); + ordinalToCaseBlocks.put(ordinal, caseBlock); } private MethodCreatorImpl findMethodCreator(BytecodeCreatorImpl enclosing) { diff --git a/src/main/java/io/quarkus/gizmo/StringSwitchImpl.java b/src/main/java/io/quarkus/gizmo/StringSwitchImpl.java index 703774d..1d824d3 100644 --- a/src/main/java/io/quarkus/gizmo/StringSwitchImpl.java +++ b/src/main/java/io/quarkus/gizmo/StringSwitchImpl.java @@ -16,12 +16,12 @@ import org.objectweb.asm.MethodVisitor; class StringSwitchImpl extends AbstractSwitch implements Switch.StringSwitch { - - private final Map>>> hashToCaseBlocks; + + private final Map caseBlocks; public StringSwitchImpl(ResultHandle value, BytecodeCreatorImpl enclosing) { super(enclosing); - this.hashToCaseBlocks = new LinkedHashMap<>(); + this.caseBlocks = new LinkedHashMap<>(); ResultHandle strHash = invokeVirtualMethod(MethodDescriptor.ofMethod(Object.class, "hashCode", int.class), value); Set inputHandles = new HashSet<>(); @@ -34,24 +34,31 @@ public StringSwitchImpl(ResultHandle value, BytecodeCreatorImpl enclosing) { void writeBytecode(MethodVisitor methodVisitor) { Map hashToLabel = new HashMap<>(); List lookupBlocks = new ArrayList<>(); - Map caseBlocks = new LinkedHashMap<>(); - BytecodeCreatorImpl defaultBlock = new BytecodeCreatorImpl(StringSwitchImpl.this); - if (defaultBlockConsumer != null) { - defaultBlockConsumer.accept(defaultBlock); + Map>> hashToCaseBlocks = new LinkedHashMap<>(); + + for (Entry e : caseBlocks.entrySet()) { + int hashCode = e.getKey().hashCode(); + List> list = hashToCaseBlocks.get(hashCode); + if (list == null) { + list = new ArrayList<>(); + hashToCaseBlocks.put(hashCode, list); + } + list.add(e); } // Initialize the case blocks and lookup blocks - for (Entry>>> hashEntry : hashToCaseBlocks.entrySet()) { + for (Entry>> hashEntry : hashToCaseBlocks.entrySet()) { BytecodeCreatorImpl lookupBlock = new BytecodeCreatorImpl(StringSwitchImpl.this); - for (Entry> caseEntry : hashEntry.getValue()) { - BytecodeCreatorImpl caseBlock = new BytecodeCreatorImpl(StringSwitchImpl.this); - Consumer blockConsumer = caseEntry.getValue(); - blockConsumer.accept(caseBlock); - if (blockConsumer != EMPTY_BLOCK && !fallThrough) { + for (Entry caseEntry : hashEntry.getValue()) { + BytecodeCreatorImpl caseBlock = caseEntry.getValue(); + if (caseBlock != null && !fallThrough) { caseBlock.breakScope(StringSwitchImpl.this); + } else if (caseBlock == null) { + // TODO empty block + caseBlock = new BytecodeCreatorImpl(StringSwitchImpl.this); + caseEntry.setValue(caseBlock); } - caseBlock.findActiveResultHandles(inputHandles); - caseBlocks.put(caseEntry.getKey(), caseBlock); + // caseBlock.findActiveResultHandles(inputHandles); BytecodeCreatorImpl isEqual = (BytecodeCreatorImpl) lookupBlock .ifTrue(Gizmo.equals(lookupBlock, lookupBlock.load(caseEntry.getKey()), value)).trueBranch(); isEqual.jumpTo(caseBlock); @@ -102,6 +109,7 @@ Set getInputResultHandles() { } }); + } @Override @@ -118,27 +126,34 @@ public void caseOf(List values, Consumer caseBlockConsu for (Iterator it = values.iterator(); it.hasNext();) { String s = it.next(); if (it.hasNext()) { - addCaseBlock(s, EMPTY_BLOCK); + addCaseBlock(s, null); } else { addCaseBlock(s, caseBlockConsumer); } } } - private void addCaseBlock(String value, Consumer blockConsumer) { - int hashCode = value.hashCode(); - List>> caseBlocks = hashToCaseBlocks.get(hashCode); - if (caseBlocks == null) { - caseBlocks = new ArrayList<>(); - hashToCaseBlocks.put(hashCode, caseBlocks); - } else { - for (Entry> e : caseBlocks) { - if (e.getKey().equals(value)) { - throw new IllegalArgumentException("A case block for the string value " + value + " already exists"); - } + @Override + void findActiveResultHandles(final Set handlesToAllocate) { + super.findActiveResultHandles(handlesToAllocate); + for (BytecodeCreatorImpl caseBlock : caseBlocks.values()) { + if (caseBlock != null) { + caseBlock.findActiveResultHandles(handlesToAllocate); } } - caseBlocks.add(Map.entry(value, blockConsumer)); + defaultBlock.findActiveResultHandles(handlesToAllocate); + } + + private void addCaseBlock(String value, Consumer blockConsumer) { + if (caseBlocks.containsKey(value)) { + throw new IllegalArgumentException("A case block for the string value [" + value + "] already exists"); + } + BytecodeCreatorImpl caseBlock = null; + if (blockConsumer != null) { + caseBlock = new BytecodeCreatorImpl(this); + blockConsumer.accept(caseBlock); + } + caseBlocks.put(value, caseBlock); } } diff --git a/src/test/java/io/quarkus/gizmo/SwitchTest.java b/src/test/java/io/quarkus/gizmo/SwitchTest.java index aa65aa7..bce0bde 100644 --- a/src/test/java/io/quarkus/gizmo/SwitchTest.java +++ b/src/test/java/io/quarkus/gizmo/SwitchTest.java @@ -254,7 +254,7 @@ public void testEnumSwitch() throws InstantiationException, IllegalAccessExcepti // default: -> return null; // } EnumSwitch s = method.enumSwitch(method.getMethodParam(0), Status.class); - s.caseOf(List.of(Status.ON, Status.OFF), bc -> { + s.caseOf(List.of(Status.OFF, Status.ON), bc -> { bc.returnValue(Gizmo.toString(bc, method.getMethodParam(0))); }); s.caseOf(Status.UNKNOWN, bc -> { @@ -287,8 +287,8 @@ public void testEnumSwitchFallThrough() throws InstantiationException, IllegalAc // } EnumSwitch s = method.enumSwitch(method.getMethodParam(0), Status.class); s.fallThrough(); - s.caseOf(Status.ON, bc -> bc.assign(ret, bc.load("on"))); s.caseOf(Status.OFF, bc -> bc.assign(ret, bc.load("off"))); + s.caseOf(Status.ON, bc -> bc.assign(ret, bc.load("on"))); s.defaultCase(bc -> bc.assign(ret, bc.load("??"))); method.returnValue(ret); } @@ -373,6 +373,74 @@ public void testEnumSwitchDuplicateCase() throws InstantiationException, Illegal } } + @SuppressWarnings("unchecked") + @Test + public void testStringSwitchConsumesHandleFromEnclosingScope() + throws InstantiationException, IllegalAccessException, ClassNotFoundException { + TestClassLoader cl = new TestClassLoader(getClass().getClassLoader()); + try (ClassCreator creator = ClassCreator.builder().classOutput(cl).className("com.MyTest").interfaces(Function.class) + .build()) { + MethodCreator method = creator.getMethodCreator("apply", Object.class, Object.class); + ResultHandle prefix = Gizmo.toString(method, method.load("p_")); + ResultHandle placeholder = Gizmo.toString(method, method.load("placeholder")); + AssignableResultHandle ret = method.createVariable(String.class); + // String placeholder = "placeholder".toString(); + // String prefix = "_p".toString(); + // String ret; + // switch(arg) { + // case "bar" -> ret = prefix + "barr"; + // default -> ret = placeholder; + // } + // return ret; + StringSwitch s = method.stringSwitch(method.getMethodParam(0)); + s.caseOf("bar", bc -> { + bc.assign(ret, + bc.invokeVirtualMethod(MethodDescriptor.ofMethod(String.class, "concat", String.class, String.class), + prefix, bc.load("barr"))); + }); + s.defaultCase(bc -> bc.assign(ret, placeholder)); + + method.returnValue(ret); + } + Function myInterface = (Function) cl.loadClass("com.MyTest").newInstance(); + assertEquals("p_barr", myInterface.apply("bar")); + assertEquals("placeholder", myInterface.apply("unknown")); + } + + @SuppressWarnings("unchecked") + @Test + public void testEnumSwitchConsumesHandleFromEnclosingScope() + throws InstantiationException, IllegalAccessException, ClassNotFoundException { + TestClassLoader cl = new TestClassLoader(getClass().getClassLoader()); + try (ClassCreator creator = ClassCreator.builder().classOutput(cl).className("com.MyTest").interfaces(Function.class) + .build()) { + MethodCreator method = creator.getMethodCreator("apply", Object.class, Object.class); + ResultHandle prefix = Gizmo.toString(method, method.load("p_")); + ResultHandle placeholder = Gizmo.toString(method, method.load("placeholder")); + AssignableResultHandle ret = method.createVariable(String.class); + // String placeholder = "placeholder".toString(); + // String prefix = "_p".toString(); + // String ret; + // switch(status) { + // case ON -> ret = prefix + "on"; + // default: -> return placeholder; + // } + // return ret; + EnumSwitch s = method.enumSwitch(method.getMethodParam(0), Status.class); + s.caseOf(Status.ON, bc -> { + bc.assign(ret, + bc.invokeVirtualMethod(MethodDescriptor.ofMethod(String.class, "concat", String.class, String.class), + prefix, bc.load("on"))); + }); + s.defaultCase(bc -> bc.assign(ret, placeholder)); + + method.returnValue(ret); + } + Function myInterface = (Function) cl.loadClass("com.MyTest").newInstance(); + assertEquals("p_on", myInterface.apply(Status.ON)); + assertEquals("placeholder", myInterface.apply(Status.UNKNOWN)); + } + public enum Status { ON, OFF,