fix: improve 'continue' insertion for switch in loop (#2249)

This commit is contained in:
Skylot 2024-09-01 23:01:52 +01:00
parent 2df69bbfb4
commit cca706c94f
No known key found for this signature in database
GPG Key ID: 47A4975761262B6A
3 changed files with 70 additions and 16 deletions

View File

@ -871,13 +871,14 @@ public class RegionMaker {
outs.or(s.getDomFrontier());
}
outs.clear(block.getId());
outs.clear(mth.getExitBlock().getId());
if (outs.isEmpty()) {
// switch already contains method exit
// add everything, out block not needed
return mth.getExitBlock();
}
BlockNode out;
BlockNode out = null;
if (outs.cardinality() == 1) {
// single exit
out = BlockUtils.bitSetToOneBlock(mth, outs);
@ -886,23 +887,29 @@ public class RegionMaker {
// possible 'return', 'continue' or fallthrough in one of the cases
LoopInfo loop = mth.getLoopForBlock(block);
if (loop != null) {
outs.andNot(block.getPostDoms());
out = BlockUtils.bitSetToOneBlock(mth, outs);
if (out != null) {
insertContinueInSwitch(block, out, loop.getEnd());
if (out == loop.getStart()) {
// no other outs instead back edge to loop start
return null;
outs.andNot(loop.getStart().getPostDoms());
outs.andNot(loop.getEnd().getPostDoms());
BlockNode loopEnd = loop.getEnd();
if (outs.cardinality() == 2 && outs.get(loopEnd.getId())) {
// insert 'continue' for cases lead to loop end
// expect only 2 exits: loop end and switch out
List<BlockNode> outList = BlockUtils.bitSetToBlocks(mth, outs);
outList.remove(loopEnd);
BlockNode possibleOut = Utils.getOne(outList);
if (possibleOut != null && insertContinueInSwitch(block, possibleOut, loopEnd)) {
outs.clear(loopEnd.getId());
out = possibleOut;
}
}
} else {
outs.clear(mth.getExitBlock().getId());
}
if (out == null) {
BlockNode imPostDom = block.getIPostDom();
if (outs.get(imPostDom.getId())) {
return imPostDom;
out = imPostDom;
} else {
outs.andNot(block.getPostDoms());
out = BlockUtils.bitSetToOneBlock(mth, outs);
}
outs.andNot(block.getPostDoms());
out = BlockUtils.bitSetToOneBlock(mth, outs);
}
}
if (out != null && mth.isPreExitBlock(out)) {
@ -994,7 +1001,8 @@ public class RegionMaker {
return newBlocksMap;
}
private void insertContinueInSwitch(BlockNode switchBlock, BlockNode switchOut, BlockNode loopEnd) {
private boolean insertContinueInSwitch(BlockNode switchBlock, BlockNode switchOut, BlockNode loopEnd) {
boolean inserted = false;
for (BlockNode caseBlock : switchBlock.getCleanSuccessors()) {
if (caseBlock.getDomFrontier().get(loopEnd.getId()) && caseBlock != switchOut) {
// search predecessor of loop end on path from this successor
@ -1006,6 +1014,7 @@ public class RegionMaker {
if (list.contains(p)) {
if (p.isSynthetic()) {
p.getInstructions().add(new InsnNode(InsnType.CONTINUE, 0));
inserted = true;
}
break;
}
@ -1013,6 +1022,7 @@ public class RegionMaker {
}
}
}
return inserted;
}
public IRegion processTryCatchBlocks(MethodNode mth) {

View File

@ -3,10 +3,12 @@ package jadx.tests.integration.switches;
import org.junit.jupiter.api.Test;
import jadx.tests.api.IntegrationTest;
import jadx.tests.api.utils.assertj.JadxAssertions;
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
public class TestSwitchContinue extends IntegrationTest {
@SuppressWarnings({ "StringConcatenationInLoop", "DataFlowIssue" })
public static class TestCls {
public String test(int a) {
String s = "";
@ -32,7 +34,7 @@ public class TestSwitchContinue extends IntegrationTest {
@Test
public void test() {
JadxAssertions.assertThat(getClassNode(TestCls.class))
assertThat(getClassNode(TestCls.class))
.code()
.contains("switch (a % 4) {")
.countString(4, "case ")

View File

@ -0,0 +1,42 @@
package jadx.tests.integration.switches;
import jadx.tests.api.IntegrationTest;
import jadx.tests.api.extensions.profiles.TestProfile;
import jadx.tests.api.extensions.profiles.TestWithProfiles;
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
public class TestSwitchInLoop4 extends IntegrationTest {
@SuppressWarnings("SwitchStatementWithTooFewBranches")
public static class TestCls {
private static boolean test(String s, int start) {
boolean foundSeparator = false;
for (int i = start; i < s.length(); i++) {
char c = s.charAt(i);
switch (c) {
case '.':
foundSeparator = true;
break;
}
if (foundSeparator) {
break;
}
}
return foundSeparator;
}
public void check() {
assertThat(test("a.b", 0)).isTrue();
assertThat(test("abc", 1)).isFalse();
}
}
@TestWithProfiles({ TestProfile.DX_J8, TestProfile.D8_J11, TestProfile.JAVA11 })
public void test() {
assertThat(getClassNode(TestCls.class))
.code()
.containsOne("switch (c) {")
.containsOne("break;"); // allow replacing second 'break' with 'return'
}
}