fix: improve switch out block search if all method exits are inside (#2264)
Some checks failed
Build Artifacts / build (push) Failing after 1s
Build Test / tests (ubuntu-latest) (push) Failing after 1s
CodeQL / Analyze (java) (push) Failing after 1s
Validate Gradle Wrapper / Validation (push) Failing after 1s
Build Artifacts / build-win-bundle (push) Has been cancelled
Build Test / tests (windows-latest) (push) Has been cancelled

This commit is contained in:
Skylot 2024-09-22 21:09:10 +01:00
parent 9c30aeacdb
commit 7abbc81886
No known key found for this signature in database
GPG Key ID: 47A4975761262B6A
3 changed files with 106 additions and 7 deletions

View File

@ -15,6 +15,9 @@ import jadx.core.dex.attributes.nodes.LoopInfo;
import jadx.core.dex.attributes.nodes.RegionRefAttr;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.SwitchInsn;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode;
@ -132,11 +135,6 @@ final class SwitchRegionMaker {
}
outs.clear(block.getId());
outs.clear(mth.getExitBlock().getId());
if (outs.isEmpty()) {
// switch already contains method exit
// add everything, out block not needed
return mth.getExitBlock();
}
BlockNode out = null;
if (outs.cardinality() == 1) {
@ -161,6 +159,10 @@ final class SwitchRegionMaker {
out = possibleOut;
}
}
if (outs.isEmpty()) {
// all exits inside switch, keep inside to exit from loop
return mth.getExitBlock();
}
}
if (out == null) {
BlockNode imPostDom = block.getIPostDom();
@ -177,6 +179,11 @@ final class SwitchRegionMaker {
out = mth.getExitBlock();
}
BlockNode imPostDom = block.getIPostDom();
if (out == null && imPostDom == mth.getExitBlock()) {
// all exits inside switch
// check if all returns are equals and should be treated as single out block
return allSameReturns(stack);
}
if (out != imPostDom && !mth.isPreExitBlock(imPostDom)) {
// stop other paths at common exit
stack.addExit(imPostDom);
@ -197,6 +204,58 @@ final class SwitchRegionMaker {
return out;
}
private BlockNode allSameReturns(RegionStack stack) {
BlockNode exitBlock = mth.getExitBlock();
List<BlockNode> preds = exitBlock.getPredecessors();
int count = preds.size();
if (count == 1) {
return preds.get(0);
}
if (mth.getReturnType() == ArgType.VOID) {
for (BlockNode pred : preds) {
InsnNode insn = BlockUtils.getLastInsn(pred);
if (insn == null || insn.getType() != InsnType.RETURN) {
return exitBlock;
}
}
} else {
List<InsnArg> returnArgs = new ArrayList<>();
for (BlockNode pred : preds) {
InsnNode insn = BlockUtils.getLastInsn(pred);
if (insn == null || insn.getType() != InsnType.RETURN) {
return exitBlock;
}
returnArgs.add(insn.getArg(0));
}
InsnArg firstArg = returnArgs.get(0);
if (firstArg.isRegister()) {
RegisterArg reg = (RegisterArg) firstArg;
for (int i = 1; i < count; i++) {
InsnArg arg = returnArgs.get(1);
if (!arg.isRegister() || !((RegisterArg) arg).sameCodeVar(reg)) {
return exitBlock;
}
}
} else {
for (int i = 1; i < count; i++) {
InsnArg arg = returnArgs.get(1);
if (!arg.equals(firstArg)) {
return exitBlock;
}
}
}
}
// confirmed
stack.addExits(preds);
// ignore other returns
for (int i = 1; i < count; i++) {
BlockNode block = preds.get(i);
block.add(AFlag.REMOVE);
block.add(AFlag.ADDED_TO_REGION);
}
return preds.get(0);
}
/**
* Remove empty case blocks:
* 1. single 'default' case

View File

@ -45,7 +45,7 @@ public class TestSwitch3 extends IntegrationTest {
public void test() {
assertThat(getClassNode(TestCls.class))
.code()
.countString(0, "break;")
.countString(3, "return;");
.countString(3, "break;")
.countString(0, "return;");
}
}

View File

@ -0,0 +1,40 @@
package jadx.tests.integration.switches;
import org.junit.jupiter.api.Test;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
public class TestSwitch4 extends IntegrationTest {
public static class TestCls {
@SuppressWarnings({ "FallThrough", "unused" })
private static int parse(char[] ch, int off, int len) {
int num = ch[off + len - 1] - '0';
switch (len) {
case 4:
num += (ch[off++] - '0') * 1000;
case 3:
num += (ch[off++] - '0') * 100;
case 2:
num += (ch[off] - '0') * 10;
}
return num;
}
public void check() {
assertThat(parse("123".toCharArray(), 0, 3)).isEqualTo(123);
assertThat(parse("a=1234".toCharArray(), 2, 4)).isEqualTo(1234);
}
}
@Test
public void test() {
assertThat(getClassNode(TestCls.class))
.code()
.containsOne("switch (")
.countString(3, "case ")
.doesNotContain("break");
}
}