fix: prevent endless loop in pre header insertion mod (#2300)

This commit is contained in:
Skylot 2024-10-31 19:00:54 +00:00
parent cfbe5ab672
commit 7544d1a113
No known key found for this signature in database
GPG Key ID: 47A4975761262B6A
4 changed files with 184 additions and 47 deletions

View File

@ -8,11 +8,14 @@ import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import jadx.api.plugins.input.data.attributes.IJadxAttrType;
import jadx.api.plugins.input.data.attributes.IJadxAttribute;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.LoopInfo;
@ -36,6 +39,8 @@ import static jadx.core.dex.visitors.blocks.BlockSplitter.connect;
public class BlockProcessor extends AbstractVisitor {
private static final Logger LOG = LoggerFactory.getLogger(BlockProcessor.class);
private static final boolean DEBUG_MODS = false;
@Override
public void visit(MethodNode mth) {
if (mth.isNoCode() || mth.getBasicBlocks().isEmpty()) {
@ -57,13 +62,27 @@ public class BlockProcessor extends AbstractVisitor {
}
updateCleanSuccessors(mth);
int blocksCount = mth.getBasicBlocks().size();
int modLimit = Math.max(100, blocksCount);
if (DEBUG_MODS) {
mth.addAttr(new DebugModAttr());
}
int i = 0;
while (modifyBlocksTree(mth)) {
computeDominators(mth);
if (i++ > 100) {
throw new JadxRuntimeException("CFG modification limit reached, blocks count: " + mth.getBasicBlocks().size());
if (i++ > modLimit) {
mth.addWarn("CFG modification limit reached, blocks count: " + blocksCount);
break;
}
}
if (DEBUG_MODS && i != 0) {
String stats = "CFG modifications count: " + i
+ ", blocks count: " + blocksCount + '\n'
+ mth.get(DebugModAttr.TYPE).formatStats() + '\n';
mth.addDebugComment(stats);
LOG.debug("Method: {}\n{}", mth, stats);
mth.remove(DebugModAttr.TYPE);
}
checkForUnreachableBlocks(mth);
DominatorTree.computeDominanceFrontier(mth);
@ -298,6 +317,9 @@ public class BlockProcessor extends AbstractVisitor {
}
if (changed) {
removeMarkedBlocks(mth);
if (DEBUG_MODS) {
mth.get(DebugModAttr.TYPE).addEvent("Merge const return");
}
}
return changed;
}
@ -338,17 +360,20 @@ public class BlockProcessor extends AbstractVisitor {
private static boolean simplifyLoopEnd(MethodNode mth, LoopInfo loop) {
BlockNode loopEnd = loop.getEnd();
if (loopEnd.getSuccessors().size() > 1) {
// make loop end a simple path block
BlockNode newLoopEnd = BlockSplitter.startNewBlock(mth, -1);
newLoopEnd.add(AFlag.SYNTHETIC);
newLoopEnd.add(AFlag.LOOP_END);
BlockNode loopStart = loop.getStart();
BlockSplitter.replaceConnection(loopEnd, loopStart, newLoopEnd);
BlockSplitter.connect(newLoopEnd, loopStart);
return true;
if (loopEnd.getSuccessors().size() <= 1) {
return false;
}
return false;
// make loop end a simple path block
BlockNode newLoopEnd = BlockSplitter.startNewBlock(mth, -1);
newLoopEnd.add(AFlag.SYNTHETIC);
newLoopEnd.add(AFlag.LOOP_END);
BlockNode loopStart = loop.getStart();
BlockSplitter.replaceConnection(loopEnd, loopStart, newLoopEnd);
BlockSplitter.connect(newLoopEnd, loopStart);
if (DEBUG_MODS) {
mth.get(DebugModAttr.TYPE).addEvent("Simplify loop end");
}
return true;
}
private static boolean checkLoops(MethodNode mth, BlockNode block) {
@ -371,7 +396,6 @@ public class BlockProcessor extends AbstractVisitor {
if (loopsCount == 1) {
LoopInfo loop = loops.get(0);
return insertBlocksForContinue(mth, loop)
|| insertBlockForPredecessors(mth, loop)
|| insertPreHeader(mth, loop)
|| simplifyLoopEnd(mth, loop);
}
@ -398,15 +422,21 @@ public class BlockProcessor extends AbstractVisitor {
mth.setEnterBlock(newEnterBlock);
start.remove(AFlag.MTH_ENTER_BLOCK);
BlockSplitter.connect(newEnterBlock, start);
return true;
} else {
// multiple predecessors
BlockNode preHeader = BlockSplitter.startNewBlock(mth, -1);
preHeader.add(AFlag.SYNTHETIC);
BlockNode loopEnd = loop.getEnd();
for (BlockNode pred : new ArrayList<>(preds)) {
if (pred != loopEnd) {
BlockSplitter.replaceConnection(pred, start, preHeader);
}
}
BlockSplitter.connect(preHeader, start);
}
// multiple predecessors
BlockNode preHeader = BlockSplitter.startNewBlock(mth, -1);
preHeader.add(AFlag.SYNTHETIC);
for (BlockNode pred : new ArrayList<>(preds)) {
BlockSplitter.replaceConnection(pred, start, preHeader);
if (DEBUG_MODS) {
mth.get(DebugModAttr.TYPE).addEvent("Insert loop pre header");
}
BlockSplitter.connect(preHeader, start);
return true;
}
@ -426,6 +456,9 @@ public class BlockProcessor extends AbstractVisitor {
}
}
}
if (DEBUG_MODS && change) {
mth.get(DebugModAttr.TYPE).addEvent("Insert loop break blocks");
}
return change;
}
@ -444,24 +477,10 @@ public class BlockProcessor extends AbstractVisitor {
}
}
}
return change;
}
/**
* Insert additional block if loop header has several predecessors (exclude back edges)
*/
private static boolean insertBlockForPredecessors(MethodNode mth, LoopInfo loop) {
BlockNode loopHeader = loop.getStart();
List<BlockNode> preds = loopHeader.getPredecessors();
if (preds.size() > 2) {
List<BlockNode> blocks = new ArrayList<>(preds);
blocks.removeIf(block -> block.contains(AFlag.LOOP_END));
BlockNode first = blocks.remove(0);
BlockNode preHeader = BlockSplitter.insertBlockBetween(mth, first, loopHeader);
blocks.forEach(block -> BlockSplitter.replaceConnection(block, loopHeader, preHeader));
return true;
if (DEBUG_MODS && change) {
mth.get(DebugModAttr.TYPE).addEvent("Insert loop continue block");
}
return false;
return change;
}
private static boolean splitLoops(MethodNode mth, BlockNode block, List<LoopInfo> loops) {
@ -472,17 +491,20 @@ public class BlockProcessor extends AbstractVisitor {
break;
}
}
if (oneHeader) {
// several back edges connected to one loop header => make additional block
BlockNode newLoopEnd = BlockSplitter.startNewBlock(mth, block.getStartOffset());
newLoopEnd.add(AFlag.SYNTHETIC);
connect(newLoopEnd, block);
for (LoopInfo la : loops) {
BlockSplitter.replaceConnection(la.getEnd(), block, newLoopEnd);
}
return true;
if (!oneHeader) {
return false;
}
return false;
// several back edges connected to one loop header => make additional block
BlockNode newLoopEnd = BlockSplitter.startNewBlock(mth, block.getStartOffset());
newLoopEnd.add(AFlag.SYNTHETIC);
connect(newLoopEnd, block);
for (LoopInfo la : loops) {
BlockSplitter.replaceConnection(la.getEnd(), block, newLoopEnd);
}
if (DEBUG_MODS) {
mth.get(DebugModAttr.TYPE).addEvent("Split loops");
}
return true;
}
private static boolean splitExitBlocks(MethodNode mth) {
@ -496,6 +518,9 @@ public class BlockProcessor extends AbstractVisitor {
}
if (changed) {
updateExitBlockConnections(mth);
if (DEBUG_MODS) {
mth.get(DebugModAttr.TYPE).addEvent("Split exit block");
}
}
return changed;
}
@ -691,4 +716,25 @@ public class BlockProcessor extends AbstractVisitor {
block.getDominatesOn().clear();
});
}
private static final class DebugModAttr implements IJadxAttribute {
static final IJadxAttrType<DebugModAttr> TYPE = IJadxAttrType.create("DebugModAttr");
private final Map<String, Integer> statMap = new HashMap<>();
public void addEvent(String name) {
statMap.merge(name, 1, Integer::sum);
}
public String formatStats() {
return statMap.entrySet().stream()
.map(entry -> " " + entry.getKey() + ": " + entry.getValue())
.collect(Collectors.joining("\n"));
}
@Override
public IJadxAttrType<DebugModAttr> getAttrType() {
return TYPE;
}
}
}

View File

@ -0,0 +1,18 @@
package jadx.tests.integration.loops;
import org.junit.jupiter.api.Test;
import jadx.tests.api.SmaliTest;
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
public class TestLoopRestore3 extends SmaliTest {
@Test
public void test() {
disableCompilation();
assertThat(getClassNodeFromSmali())
.code()
.countString(3, "while (");
}
}

View File

@ -0,0 +1,64 @@
.class public Lloops/TestLoopRestore3;
.super Ljava/lang/Object;
.method public final b(Ljava/lang/String;Lb/U53$b;)V
.registers 8
iget-object v0, p0, Lb/X53;->e:Ljava/util/concurrent/atomic/AtomicReference;
:goto_2
invoke-virtual {v0}, Ljava/util/concurrent/atomic/AtomicReference;->get()Ljava/lang/Object;
move-result-object v1
move-object v2, v1
check-cast v2, Ljava/util/List;
move-object v3, v2
check-cast v3, Ljava/lang/Iterable;
instance-of v4, v3, Ljava/util/Collection;
if-eqz v4, :cond_1a
move-object v4, v3
check-cast v4, Ljava/util/Collection;
invoke-interface {v4}, Ljava/util/Collection;->isEmpty()Z
move-result v4
if-eqz v4, :cond_1a
goto :goto_33
:cond_1a
invoke-interface {v3}, Ljava/lang/Iterable;->iterator()Ljava/util/Iterator;
move-result-object v3
:cond_1e
invoke-interface {v3}, Ljava/util/Iterator;->hasNext()Z
move-result v4
if-eqz v4, :cond_33
invoke-interface {v3}, Ljava/util/Iterator;->next()Ljava/lang/Object;
move-result-object v4
check-cast v4, Lb/X53$c;
iget-object v4, v4, Lb/X53$c;->b:Ljava/lang/String;
invoke-static {v4, p1}, Lkotlin/jvm/internal/Intrinsics;->a(Ljava/lang/Object;Ljava/lang/Object;)Z
move-result v4
if-eqz v4, :cond_1e
goto :goto_40
:cond_33
:goto_33
check-cast v2, Ljava/util/Collection;
new-instance v3, Lb/X53$c;
sget-object v4, Lb/Pd2;->a:Lb/Pd2;
invoke-direct {v3, p2, p1, v4}, Lb/X53$c;-><init>(Lb/U53$b;Ljava/lang/String;Ljava/util/List;)V
invoke-static {v2, v3}, Lb/R31;->a0(Ljava/util/Collection;Ljava/lang/Object;)Ljava/util/ArrayList;
move-result-object v2
:cond_40
:goto_40
invoke-virtual {v0, v1, v2}, Ljava/util/concurrent/atomic/AtomicReference;->compareAndSet(Ljava/lang/Object;Ljava/lang/Object;)Z
move-result v3
if-eqz v3, :cond_47
return-void
:cond_47
invoke-virtual {v0}, Ljava/util/concurrent/atomic/AtomicReference;->get()Ljava/lang/Object;
move-result-object v3
if-eq v3, v1, :cond_40
goto :goto_2
.end method

View File

@ -18,4 +18,13 @@ public interface IJadxAttrType<T extends IJadxAttribute> {
return new IJadxAttrType<>() {
};
}
static <A extends IJadxAttribute> IJadxAttrType<A> create(String name) {
return new IJadxAttrType<>() {
@Override
public String toString() {
return name;
}
};
}
}