diff --git a/jadx-core/src/main/java/jadx/core/Jadx.java b/jadx-core/src/main/java/jadx/core/Jadx.java index f9dc7505..a1efbfb8 100644 --- a/jadx-core/src/main/java/jadx/core/Jadx.java +++ b/jadx-core/src/main/java/jadx/core/Jadx.java @@ -29,6 +29,7 @@ import jadx.core.dex.visitors.EnumVisitor; import jadx.core.dex.visitors.ExtractFieldInit; import jadx.core.dex.visitors.FallbackModeVisitor; import jadx.core.dex.visitors.FixAccessModifiers; +import jadx.core.dex.visitors.FixSwitchOverEnum; import jadx.core.dex.visitors.GenericTypesVisitor; import jadx.core.dex.visitors.IDexTreeVisitor; import jadx.core.dex.visitors.InitCodeVariables; @@ -43,7 +44,7 @@ import jadx.core.dex.visitors.PrepareForCodeGen; import jadx.core.dex.visitors.ProcessAnonymous; import jadx.core.dex.visitors.ProcessInstructionsVisitor; import jadx.core.dex.visitors.ProcessMethodsForInline; -import jadx.core.dex.visitors.ReSugarCode; +import jadx.core.dex.visitors.ReplaceNewArray; import jadx.core.dex.visitors.ShadowFieldVisitor; import jadx.core.dex.visitors.SignatureProcessor; import jadx.core.dex.visitors.SimplifyVisitor; @@ -154,7 +155,7 @@ public class Jadx { passes.add(new AnonymousClassVisitor()); passes.add(new ModVisitor()); passes.add(new CodeShrinkVisitor()); - passes.add(new ReSugarCode()); + passes.add(new ReplaceNewArray()); if (args.isCfgOutput()) { passes.add(DotGraphVisitor.dump()); } @@ -171,6 +172,7 @@ public class Jadx { passes.add(new CheckRegions()); passes.add(new EnumVisitor()); + passes.add(new FixSwitchOverEnum()); passes.add(new ExtractFieldInit()); passes.add(new FixAccessModifiers()); passes.add(new ClassModifier()); @@ -219,8 +221,7 @@ public class Jadx { passes.add(new DeboxingVisitor()); passes.add(new ModVisitor()); passes.add(new CodeShrinkVisitor()); - passes.add(new ReSugarCode()); - passes.add(new CodeShrinkVisitor()); + passes.add(new ReplaceNewArray()); passes.add(new SimplifyVisitor()); passes.add(new MethodVisitor("ForceGenerateAll", mth -> mth.remove(AFlag.DONT_GENERATE))); if (args.isCfgOutput()) { diff --git a/jadx-core/src/main/java/jadx/core/dex/attributes/AType.java b/jadx-core/src/main/java/jadx/core/dex/attributes/AType.java index 85c20f61..7faf60ee 100644 --- a/jadx-core/src/main/java/jadx/core/dex/attributes/AType.java +++ b/jadx-core/src/main/java/jadx/core/dex/attributes/AType.java @@ -25,6 +25,7 @@ import jadx.core.dex.attributes.nodes.MethodReplaceAttr; import jadx.core.dex.attributes.nodes.MethodTypeVarsAttr; import jadx.core.dex.attributes.nodes.PhiListAttr; import jadx.core.dex.attributes.nodes.RegDebugInfoAttr; +import jadx.core.dex.attributes.nodes.RegionRefAttr; import jadx.core.dex.attributes.nodes.RenameReasonAttr; import jadx.core.dex.attributes.nodes.SkipMethodArgsAttr; import jadx.core.dex.attributes.nodes.SpecialEdgeAttr; @@ -94,6 +95,7 @@ public final class AType implements IJadxAttrType { public static final AType> JUMP = new AType<>(); public static final AType METHOD_DETAILS = new AType<>(); public static final AType GENERIC_INFO = new AType<>(); + public static final AType REGION_REF = new AType<>(); // register public static final AType REG_DEBUG_INFO = new AType<>(); diff --git a/jadx-core/src/main/java/jadx/core/dex/attributes/nodes/RegionRefAttr.java b/jadx-core/src/main/java/jadx/core/dex/attributes/nodes/RegionRefAttr.java new file mode 100644 index 00000000..76d47dde --- /dev/null +++ b/jadx-core/src/main/java/jadx/core/dex/attributes/nodes/RegionRefAttr.java @@ -0,0 +1,30 @@ +package jadx.core.dex.attributes.nodes; + +import jadx.api.plugins.input.data.attributes.IJadxAttribute; +import jadx.core.dex.attributes.AType; +import jadx.core.dex.nodes.IRegion; + +/** + * Region created based on parent instruction + */ +public class RegionRefAttr implements IJadxAttribute { + private final IRegion region; + + public RegionRefAttr(IRegion region) { + this.region = region; + } + + public IRegion getRegion() { + return region; + } + + @Override + public AType getAttrType() { + return AType.REGION_REF; + } + + @Override + public String toString() { + return "RegionRef:" + region; + } +} diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/EnumVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/EnumVisitor.java index b3c2be24..32c4539c 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/EnumVisitor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/EnumVisitor.java @@ -63,7 +63,7 @@ import static jadx.core.utils.InsnUtils.getWrappedInsn; runAfter = { CodeShrinkVisitor.class, // all possible instructions already inlined ModVisitor.class, - ReSugarCode.class, + ReplaceNewArray.class, // values array normalized IfRegionVisitor.class, // ternary operator inlined CheckRegions.class // regions processing finished }, diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/FixSwitchOverEnum.java b/jadx-core/src/main/java/jadx/core/dex/visitors/FixSwitchOverEnum.java new file mode 100644 index 00000000..f63c5f44 --- /dev/null +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/FixSwitchOverEnum.java @@ -0,0 +1,297 @@ +package jadx.core.dex.visitors; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.IntFunction; + +import org.jetbrains.annotations.Nullable; + +import jadx.core.dex.attributes.AFlag; +import jadx.core.dex.attributes.AType; +import jadx.core.dex.attributes.nodes.EnumClassAttr; +import jadx.core.dex.attributes.nodes.EnumMapAttr; +import jadx.core.dex.attributes.nodes.RegionRefAttr; +import jadx.core.dex.info.AccessInfo; +import jadx.core.dex.info.FieldInfo; +import jadx.core.dex.info.MethodInfo; +import jadx.core.dex.instructions.IndexInsnNode; +import jadx.core.dex.instructions.InsnType; +import jadx.core.dex.instructions.InvokeNode; +import jadx.core.dex.instructions.SwitchInsn; +import jadx.core.dex.instructions.args.InsnArg; +import jadx.core.dex.instructions.args.InsnWrapArg; +import jadx.core.dex.instructions.args.LiteralArg; +import jadx.core.dex.nodes.BlockNode; +import jadx.core.dex.nodes.ClassNode; +import jadx.core.dex.nodes.FieldNode; +import jadx.core.dex.nodes.InsnNode; +import jadx.core.dex.nodes.MethodNode; +import jadx.core.dex.nodes.RootNode; +import jadx.core.dex.regions.SwitchRegion; +import jadx.core.dex.visitors.shrink.CodeShrinkVisitor; +import jadx.core.utils.Utils; +import jadx.core.utils.exceptions.JadxException; + +@JadxVisitor( + name = "FixSwitchOverEnum", + desc = "Simplify synthetic code in switch over enum", + runAfter = { + CodeShrinkVisitor.class, + EnumVisitor.class + } +) +public class FixSwitchOverEnum extends AbstractVisitor { + + @Override + public boolean visit(ClassNode cls) throws JadxException { + initClsEnumMap(cls); + return true; + } + + @Override + public void visit(MethodNode mth) throws JadxException { + if (mth.isNoCode()) { + return; + } + boolean changed = false; + for (BlockNode block : mth.getBasicBlocks()) { + for (InsnNode insn : block.getInstructions()) { + if (insn.getType() == InsnType.SWITCH && !insn.contains(AFlag.REMOVE)) { + changed |= processEnumSwitch(mth, (SwitchInsn) insn); + } + } + } + if (changed) { + CodeShrinkVisitor.shrinkMethod(mth); + } + } + + private static boolean processEnumSwitch(MethodNode mth, SwitchInsn insn) { + InsnArg arg = insn.getArg(0); + if (!arg.isInsnWrap()) { + return false; + } + InsnNode wrapInsn = ((InsnWrapArg) arg).getWrapInsn(); + switch (wrapInsn.getType()) { + case AGET: + return processRemappedEnumSwitch(mth, insn, wrapInsn, arg); + case INVOKE: + return processDirectEnumSwitch(mth, insn, (InvokeNode) wrapInsn, arg); + } + return false; + } + + private static boolean executeReplace(SwitchInsn swInsn, InsnArg arg, InsnArg invVar, IntFunction caseReplace) { + RegionRefAttr regionRefAttr = swInsn.get(AType.REGION_REF); + if (regionRefAttr == null) { + return false; + } + if (!swInsn.replaceArg(arg, invVar)) { + return false; + } + Map replaceMap = new HashMap<>(); + int caseCount = swInsn.getKeys().length; + for (int i = 0; i < caseCount; i++) { + Object key = swInsn.getKey(i); + Object replaceObj = caseReplace.apply(i); + swInsn.modifyKey(i, replaceObj); + replaceMap.put(key, replaceObj); + } + SwitchRegion region = (SwitchRegion) regionRefAttr.getRegion(); + for (SwitchRegion.CaseInfo caseInfo : region.getCases()) { + caseInfo.getKeys().replaceAll(key -> Utils.getOrElse(replaceMap.get(key), key)); + } + return true; + } + + private static boolean processDirectEnumSwitch(MethodNode mth, SwitchInsn swInsn, InvokeNode invInsn, InsnArg arg) { + MethodInfo callMth = invInsn.getCallMth(); + if (!callMth.getShortId().equals("ordinal()I")) { + return false; + } + InsnArg invVar = invInsn.getArg(0); + ClassNode enumCls = mth.root().resolveClass(invVar.getType()); + if (enumCls == null) { + return false; + } + EnumClassAttr enumClassAttr = enumCls.get(AType.ENUM_CLASS); + if (enumClassAttr == null) { + return false; + } + FieldNode[] casesReplaceArr = mapToCases(swInsn, enumClassAttr.getFields()); + if (casesReplaceArr == null) { + return false; + } + return executeReplace(swInsn, arg, invVar, i -> casesReplaceArr[i]); + } + + private static @Nullable FieldNode[] mapToCases(SwitchInsn swInsn, List fields) { + int caseCount = swInsn.getKeys().length; + if (fields.size() < caseCount) { + return null; + } + FieldNode[] casesMap = new FieldNode[caseCount]; + for (int i = 0; i < caseCount; i++) { + Object key = swInsn.getKey(i); + if (key instanceof Integer) { + int ordinal = (Integer) key; + try { + casesMap[ordinal] = fields.get(ordinal).getField(); + } catch (Exception e) { + return null; + } + } else { + return null; + } + } + return casesMap; + } + + private static boolean processRemappedEnumSwitch(MethodNode mth, SwitchInsn insn, InsnNode wrapInsn, InsnArg arg) { + EnumMapInfo enumMapInfo = checkEnumMapAccess(mth.root(), wrapInsn); + if (enumMapInfo == null) { + return false; + } + FieldNode enumMapField = enumMapInfo.getMapField(); + InsnArg invArg = enumMapInfo.getArg(); + + EnumMapAttr.KeyValueMap valueMap = getEnumMap(enumMapField); + if (valueMap == null) { + return false; + } + int caseCount = insn.getKeys().length; + for (int i = 0; i < caseCount; i++) { + Object key = insn.getKey(i); + Object newKey = valueMap.get(key); + if (newKey == null) { + return false; + } + } + if (executeReplace(insn, arg, invArg, i -> valueMap.get(insn.getKey(i)))) { + enumMapField.add(AFlag.DONT_GENERATE); + checkAndHideClass(enumMapField.getParentClass()); + return true; + } + return false; + } + + private static void initClsEnumMap(ClassNode enumCls) { + MethodNode clsInitMth = enumCls.getClassInitMth(); + if (clsInitMth == null || clsInitMth.isNoCode() || clsInitMth.getBasicBlocks() == null) { + return; + } + EnumMapAttr mapAttr = new EnumMapAttr(); + for (BlockNode block : clsInitMth.getBasicBlocks()) { + for (InsnNode insn : block.getInstructions()) { + if (insn.getType() == InsnType.APUT) { + addToEnumMap(enumCls.root(), mapAttr, insn); + } + } + } + if (!mapAttr.isEmpty()) { + enumCls.addAttr(mapAttr); + } + } + + private static @Nullable EnumMapAttr.KeyValueMap getEnumMap(FieldNode field) { + ClassNode syntheticClass = field.getParentClass(); + EnumMapAttr mapAttr = syntheticClass.get(AType.ENUM_MAP); + if (mapAttr == null) { + return null; + } + return mapAttr.getMap(field); + } + + private static void addToEnumMap(RootNode root, EnumMapAttr mapAttr, InsnNode aputInsn) { + InsnArg litArg = aputInsn.getArg(2); + if (!litArg.isLiteral()) { + return; + } + EnumMapInfo mapInfo = checkEnumMapAccess(root, aputInsn); + if (mapInfo == null) { + return; + } + InsnArg enumArg = mapInfo.getArg(); + FieldNode field = mapInfo.getMapField(); + if (field == null || !enumArg.isInsnWrap()) { + return; + } + InsnNode sget = ((InsnWrapArg) enumArg).getWrapInsn(); + if (!(sget instanceof IndexInsnNode)) { + return; + } + Object index = ((IndexInsnNode) sget).getIndex(); + if (!(index instanceof FieldInfo)) { + return; + } + FieldNode fieldNode = root.resolveField((FieldInfo) index); + if (fieldNode == null) { + return; + } + int literal = (int) ((LiteralArg) litArg).getLiteral(); + mapAttr.add(field, literal, fieldNode); + } + + private static @Nullable EnumMapInfo checkEnumMapAccess(RootNode root, InsnNode checkInsn) { + InsnArg sgetArg = checkInsn.getArg(0); + InsnArg invArg = checkInsn.getArg(1); + if (!sgetArg.isInsnWrap() || !invArg.isInsnWrap()) { + return null; + } + InsnNode invInsn = ((InsnWrapArg) invArg).getWrapInsn(); + InsnNode sgetInsn = ((InsnWrapArg) sgetArg).getWrapInsn(); + if (invInsn.getType() != InsnType.INVOKE || sgetInsn.getType() != InsnType.SGET) { + return null; + } + InvokeNode inv = (InvokeNode) invInsn; + if (!inv.getCallMth().getShortId().equals("ordinal()I")) { + return null; + } + ClassNode enumCls = root.resolveClass(inv.getCallMth().getDeclClass()); + if (enumCls == null || !enumCls.isEnum()) { + return null; + } + Object index = ((IndexInsnNode) sgetInsn).getIndex(); + if (!(index instanceof FieldInfo)) { + return null; + } + FieldNode enumMapField = root.resolveField((FieldInfo) index); + if (enumMapField == null || !enumMapField.getAccessFlags().isSynthetic()) { + return null; + } + return new EnumMapInfo(inv.getArg(0), enumMapField); + } + + /** + * If all static final synthetic fields have DONT_GENERATE => hide whole class + */ + private static void checkAndHideClass(ClassNode cls) { + for (FieldNode field : cls.getFields()) { + AccessInfo af = field.getAccessFlags(); + if (af.isSynthetic() && af.isStatic() && af.isFinal() + && !field.contains(AFlag.DONT_GENERATE)) { + return; + } + } + cls.add(AFlag.DONT_GENERATE); + } + + private static class EnumMapInfo { + private final InsnArg arg; + private final FieldNode mapField; + + public EnumMapInfo(InsnArg arg, FieldNode mapField) { + this.arg = arg; + this.mapField = mapField; + } + + public InsnArg getArg() { + return arg; + } + + public FieldNode getMapField() { + return mapField; + } + } +} diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/ReSugarCode.java b/jadx-core/src/main/java/jadx/core/dex/visitors/ReSugarCode.java deleted file mode 100644 index 9d581557..00000000 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/ReSugarCode.java +++ /dev/null @@ -1,368 +0,0 @@ -package jadx.core.dex.visitors; - -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.SortedMap; -import java.util.TreeMap; - -import org.jetbrains.annotations.Nullable; - -import jadx.core.dex.attributes.AFlag; -import jadx.core.dex.attributes.AType; -import jadx.core.dex.attributes.nodes.EnumMapAttr; -import jadx.core.dex.info.AccessInfo; -import jadx.core.dex.info.FieldInfo; -import jadx.core.dex.instructions.FilledNewArrayNode; -import jadx.core.dex.instructions.IndexInsnNode; -import jadx.core.dex.instructions.InsnType; -import jadx.core.dex.instructions.InvokeNode; -import jadx.core.dex.instructions.NewArrayNode; -import jadx.core.dex.instructions.SwitchInsn; -import jadx.core.dex.instructions.args.ArgType; -import jadx.core.dex.instructions.args.InsnArg; -import jadx.core.dex.instructions.args.InsnWrapArg; -import jadx.core.dex.instructions.args.LiteralArg; -import jadx.core.dex.instructions.args.RegisterArg; -import jadx.core.dex.nodes.BlockNode; -import jadx.core.dex.nodes.ClassNode; -import jadx.core.dex.nodes.FieldNode; -import jadx.core.dex.nodes.InsnNode; -import jadx.core.dex.nodes.MethodNode; -import jadx.core.dex.nodes.RootNode; -import jadx.core.dex.visitors.shrink.CodeShrinkVisitor; -import jadx.core.utils.BlockUtils; -import jadx.core.utils.InsnList; -import jadx.core.utils.InsnRemover; -import jadx.core.utils.InsnUtils; -import jadx.core.utils.exceptions.JadxException; - -@JadxVisitor( - name = "ReSugarCode", - desc = "Simplify synthetic or verbose code", - runAfter = CodeShrinkVisitor.class -) -public class ReSugarCode extends AbstractVisitor { - - @Override - public boolean visit(ClassNode cls) throws JadxException { - initClsEnumMap(cls); - return true; - } - - @Override - public void visit(MethodNode mth) throws JadxException { - if (mth.isNoCode()) { - return; - } - int k = 0; - while (true) { - boolean changed = false; - InsnRemover remover = new InsnRemover(mth); - for (BlockNode block : mth.getBasicBlocks()) { - remover.setBlock(block); - List instructions = block.getInstructions(); - int size = instructions.size(); - for (int i = 0; i < size; i++) { - changed |= process(mth, instructions, i, remover); - } - remover.perform(); - } - if (changed) { - CodeShrinkVisitor.shrinkMethod(mth); - } else { - break; - } - if (k++ > 100) { - mth.addWarnComment("Reached limit for ReSugarCode iterations"); - break; - } - } - } - - private static boolean process(MethodNode mth, List instructions, int i, InsnRemover remover) { - InsnNode insn = instructions.get(i); - if (insn.contains(AFlag.REMOVE)) { - return false; - } - switch (insn.getType()) { - case NEW_ARRAY: - return processNewArray(mth, (NewArrayNode) insn, instructions, remover); - - case SWITCH: - return processEnumSwitch(mth, (SwitchInsn) insn); - - default: - return false; - } - } - - /** - * Replace new-array and sequence of array-put to new filled-array instruction. - */ - private static boolean processNewArray(MethodNode mth, NewArrayNode newArrayInsn, List instructions, InsnRemover remover) { - Object arrayLenConst = InsnUtils.getConstValueByArg(mth.root(), newArrayInsn.getArg(0)); - if (!(arrayLenConst instanceof LiteralArg)) { - return false; - } - int len = (int) ((LiteralArg) arrayLenConst).getLiteral(); - if (len == 0) { - return false; - } - ArgType arrType = newArrayInsn.getArrayType(); - ArgType elemType = arrType.getArrayElement(); - boolean allowMissingKeys = arrType.getArrayDimension() == 1 && elemType.isPrimitive(); - int minLen = allowMissingKeys ? len / 2 : len; - - RegisterArg arrArg = newArrayInsn.getResult(); - List useList = arrArg.getSVar().getUseList(); - if (useList.size() < minLen) { - return false; - } - // quick check if APUT is used - boolean foundPut = false; - for (RegisterArg registerArg : useList) { - InsnNode parentInsn = registerArg.getParentInsn(); - if (parentInsn != null && parentInsn.getType() == InsnType.APUT) { - foundPut = true; - break; - } - } - if (!foundPut) { - return false; - } - // collect put instructions sorted by array index - SortedMap arrPuts = new TreeMap<>(); - for (RegisterArg registerArg : useList) { - InsnNode parentInsn = registerArg.getParentInsn(); - if (parentInsn == null || parentInsn.getType() != InsnType.APUT) { - continue; - } - if (!arrArg.sameRegAndSVar(parentInsn.getArg(0))) { - return false; - } - Object constVal = InsnUtils.getConstValueByArg(mth.root(), parentInsn.getArg(1)); - if (!(constVal instanceof LiteralArg)) { - return false; - } - long index = ((LiteralArg) constVal).getLiteral(); - if (index >= len) { - return false; - } - if (arrPuts.containsKey(index)) { - // stop on index rewrite - break; - } - arrPuts.put(index, parentInsn); - } - if (arrPuts.size() < minLen) { - return false; - } - // expect all puts to be in same block - if (!new HashSet<>(instructions).containsAll(arrPuts.values())) { - return false; - } - - // checks complete, apply - InsnNode filledArr = new FilledNewArrayNode(elemType, len); - filledArr.setResult(arrArg.duplicate()); - filledArr.copyAttributesFrom(newArrayInsn); - filledArr.inheritMetadata(newArrayInsn); - filledArr.setOffset(newArrayInsn.getOffset()); - - long prevIndex = -1; - for (Map.Entry entry : arrPuts.entrySet()) { - long index = entry.getKey(); - if (index != prevIndex) { - // use zero for missing keys - for (long i = prevIndex + 1; i < index; i++) { - filledArr.addArg(InsnArg.lit(0, elemType)); - } - } - InsnNode put = entry.getValue(); - filledArr.addArg(replaceConstInArg(mth, put.getArg(2))); - remover.addAndUnbind(put); - prevIndex = index; - } - remover.addAndUnbind(newArrayInsn); - - InsnNode lastPut = arrPuts.get(arrPuts.lastKey()); - int replaceIndex = InsnList.getIndex(instructions, lastPut); - instructions.set(replaceIndex, filledArr); - BlockUtils.replaceInsn(mth, lastPut, filledArr); - return true; - } - - private static InsnArg replaceConstInArg(MethodNode mth, InsnArg valueArg) { - if (valueArg.isLiteral()) { - FieldNode f = mth.getParentClass().getConstFieldByLiteralArg((LiteralArg) valueArg); - if (f != null) { - InsnNode fGet = new IndexInsnNode(InsnType.SGET, f.getFieldInfo(), 0); - InsnArg arg = InsnArg.wrapArg(fGet); - f.addUseIn(mth); - return arg; - } - } - return valueArg.duplicate(); - } - - private static boolean processEnumSwitch(MethodNode mth, SwitchInsn insn) { - InsnArg arg = insn.getArg(0); - if (!arg.isInsnWrap()) { - return false; - } - InsnNode wrapInsn = ((InsnWrapArg) arg).getWrapInsn(); - if (wrapInsn.getType() != InsnType.AGET) { - return false; - } - EnumMapInfo enumMapInfo = checkEnumMapAccess(mth.root(), wrapInsn); - if (enumMapInfo == null) { - return false; - } - FieldNode enumMapField = enumMapInfo.getMapField(); - InsnArg invArg = enumMapInfo.getArg(); - - EnumMapAttr.KeyValueMap valueMap = getEnumMap(mth, enumMapField); - if (valueMap == null) { - return false; - } - int caseCount = insn.getKeys().length; - for (int i = 0; i < caseCount; i++) { - Object key = insn.getKey(i); - Object newKey = valueMap.get(key); - if (newKey == null) { - return false; - } - } - // replace confirmed - if (!insn.replaceArg(arg, invArg)) { - return false; - } - for (int i = 0; i < caseCount; i++) { - insn.modifyKey(i, valueMap.get(insn.getKey(i))); - } - enumMapField.add(AFlag.DONT_GENERATE); - checkAndHideClass(enumMapField.getParentClass()); - return true; - } - - private static void initClsEnumMap(ClassNode enumCls) { - MethodNode clsInitMth = enumCls.getClassInitMth(); - if (clsInitMth == null || clsInitMth.isNoCode() || clsInitMth.getBasicBlocks() == null) { - return; - } - EnumMapAttr mapAttr = new EnumMapAttr(); - for (BlockNode block : clsInitMth.getBasicBlocks()) { - for (InsnNode insn : block.getInstructions()) { - if (insn.getType() == InsnType.APUT) { - addToEnumMap(enumCls.root(), mapAttr, insn); - } - } - } - if (!mapAttr.isEmpty()) { - enumCls.addAttr(mapAttr); - } - } - - @Nullable - private static EnumMapAttr.KeyValueMap getEnumMap(MethodNode mth, FieldNode field) { - ClassNode syntheticClass = field.getParentClass(); - EnumMapAttr mapAttr = syntheticClass.get(AType.ENUM_MAP); - if (mapAttr == null) { - return null; - } - return mapAttr.getMap(field); - } - - private static void addToEnumMap(RootNode root, EnumMapAttr mapAttr, InsnNode aputInsn) { - InsnArg litArg = aputInsn.getArg(2); - if (!litArg.isLiteral()) { - return; - } - EnumMapInfo mapInfo = checkEnumMapAccess(root, aputInsn); - if (mapInfo == null) { - return; - } - InsnArg enumArg = mapInfo.getArg(); - FieldNode field = mapInfo.getMapField(); - if (field == null || !enumArg.isInsnWrap()) { - return; - } - InsnNode sget = ((InsnWrapArg) enumArg).getWrapInsn(); - if (!(sget instanceof IndexInsnNode)) { - return; - } - Object index = ((IndexInsnNode) sget).getIndex(); - if (!(index instanceof FieldInfo)) { - return; - } - FieldNode fieldNode = root.resolveField((FieldInfo) index); - if (fieldNode == null) { - return; - } - int literal = (int) ((LiteralArg) litArg).getLiteral(); - mapAttr.add(field, literal, fieldNode); - } - - public static EnumMapInfo checkEnumMapAccess(RootNode root, InsnNode checkInsn) { - InsnArg sgetArg = checkInsn.getArg(0); - InsnArg invArg = checkInsn.getArg(1); - if (!sgetArg.isInsnWrap() || !invArg.isInsnWrap()) { - return null; - } - InsnNode invInsn = ((InsnWrapArg) invArg).getWrapInsn(); - InsnNode sgetInsn = ((InsnWrapArg) sgetArg).getWrapInsn(); - if (invInsn.getType() != InsnType.INVOKE || sgetInsn.getType() != InsnType.SGET) { - return null; - } - InvokeNode inv = (InvokeNode) invInsn; - if (!inv.getCallMth().getShortId().equals("ordinal()I")) { - return null; - } - ClassNode enumCls = root.resolveClass(inv.getCallMth().getDeclClass()); - if (enumCls == null || !enumCls.isEnum()) { - return null; - } - Object index = ((IndexInsnNode) sgetInsn).getIndex(); - if (!(index instanceof FieldInfo)) { - return null; - } - FieldNode enumMapField = root.resolveField((FieldInfo) index); - if (enumMapField == null || !enumMapField.getAccessFlags().isSynthetic()) { - return null; - } - return new EnumMapInfo(inv.getArg(0), enumMapField); - } - - /** - * If all static final synthetic fields have DONT_GENERATE => hide whole class - */ - private static void checkAndHideClass(ClassNode cls) { - for (FieldNode field : cls.getFields()) { - AccessInfo af = field.getAccessFlags(); - if (af.isSynthetic() && af.isStatic() && af.isFinal() - && !field.contains(AFlag.DONT_GENERATE)) { - return; - } - } - cls.add(AFlag.DONT_GENERATE); - } - - private static class EnumMapInfo { - private final InsnArg arg; - private final FieldNode mapField; - - public EnumMapInfo(InsnArg arg, FieldNode mapField) { - this.arg = arg; - this.mapField = mapField; - } - - public InsnArg getArg() { - return arg; - } - - public FieldNode getMapField() { - return mapField; - } - } -} diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/ReplaceNewArray.java b/jadx-core/src/main/java/jadx/core/dex/visitors/ReplaceNewArray.java new file mode 100644 index 00000000..c95918d9 --- /dev/null +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/ReplaceNewArray.java @@ -0,0 +1,180 @@ +package jadx.core.dex.visitors; + +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.SortedMap; +import java.util.TreeMap; + +import jadx.core.dex.attributes.AFlag; +import jadx.core.dex.instructions.FilledNewArrayNode; +import jadx.core.dex.instructions.IndexInsnNode; +import jadx.core.dex.instructions.InsnType; +import jadx.core.dex.instructions.NewArrayNode; +import jadx.core.dex.instructions.args.ArgType; +import jadx.core.dex.instructions.args.InsnArg; +import jadx.core.dex.instructions.args.LiteralArg; +import jadx.core.dex.instructions.args.RegisterArg; +import jadx.core.dex.nodes.BlockNode; +import jadx.core.dex.nodes.FieldNode; +import jadx.core.dex.nodes.InsnNode; +import jadx.core.dex.nodes.MethodNode; +import jadx.core.dex.visitors.shrink.CodeShrinkVisitor; +import jadx.core.utils.BlockUtils; +import jadx.core.utils.InsnList; +import jadx.core.utils.InsnRemover; +import jadx.core.utils.InsnUtils; +import jadx.core.utils.exceptions.JadxException; + +@JadxVisitor( + name = "ReplaceNewArray", + desc = "Replace new-array and sequence of array-put to new filled-array instruction", + runAfter = CodeShrinkVisitor.class +) +public class ReplaceNewArray extends AbstractVisitor { + + @Override + public void visit(MethodNode mth) throws JadxException { + if (mth.isNoCode()) { + return; + } + int k = 0; + while (true) { + boolean changed = false; + InsnRemover remover = new InsnRemover(mth); + for (BlockNode block : mth.getBasicBlocks()) { + remover.setBlock(block); + List instructions = block.getInstructions(); + int size = instructions.size(); + for (int i = 0; i < size; i++) { + changed |= processInsn(mth, instructions, i, remover); + } + remover.perform(); + } + if (changed) { + CodeShrinkVisitor.shrinkMethod(mth); + } else { + break; + } + if (k++ > 100) { + mth.addWarnComment("Reached limit for ReplaceNewArray iterations"); + break; + } + } + } + + private static boolean processInsn(MethodNode mth, List instructions, int i, InsnRemover remover) { + InsnNode insn = instructions.get(i); + if (insn.getType() == InsnType.NEW_ARRAY && !insn.contains(AFlag.REMOVE)) { + return processNewArray(mth, (NewArrayNode) insn, instructions, remover); + } + return false; + } + + private static boolean processNewArray(MethodNode mth, + NewArrayNode newArrayInsn, List instructions, InsnRemover remover) { + Object arrayLenConst = InsnUtils.getConstValueByArg(mth.root(), newArrayInsn.getArg(0)); + if (!(arrayLenConst instanceof LiteralArg)) { + return false; + } + int len = (int) ((LiteralArg) arrayLenConst).getLiteral(); + if (len == 0) { + return false; + } + ArgType arrType = newArrayInsn.getArrayType(); + ArgType elemType = arrType.getArrayElement(); + boolean allowMissingKeys = arrType.getArrayDimension() == 1 && elemType.isPrimitive(); + int minLen = allowMissingKeys ? len / 2 : len; + + RegisterArg arrArg = newArrayInsn.getResult(); + List useList = arrArg.getSVar().getUseList(); + if (useList.size() < minLen) { + return false; + } + // quick check if APUT is used + boolean foundPut = false; + for (RegisterArg registerArg : useList) { + InsnNode parentInsn = registerArg.getParentInsn(); + if (parentInsn != null && parentInsn.getType() == InsnType.APUT) { + foundPut = true; + break; + } + } + if (!foundPut) { + return false; + } + // collect put instructions sorted by array index + SortedMap arrPuts = new TreeMap<>(); + for (RegisterArg registerArg : useList) { + InsnNode parentInsn = registerArg.getParentInsn(); + if (parentInsn == null || parentInsn.getType() != InsnType.APUT) { + continue; + } + if (!arrArg.sameRegAndSVar(parentInsn.getArg(0))) { + return false; + } + Object constVal = InsnUtils.getConstValueByArg(mth.root(), parentInsn.getArg(1)); + if (!(constVal instanceof LiteralArg)) { + return false; + } + long index = ((LiteralArg) constVal).getLiteral(); + if (index >= len) { + return false; + } + if (arrPuts.containsKey(index)) { + // stop on index rewrite + break; + } + arrPuts.put(index, parentInsn); + } + if (arrPuts.size() < minLen) { + return false; + } + // expect all puts to be in same block + if (!new HashSet<>(instructions).containsAll(arrPuts.values())) { + return false; + } + + // checks complete, apply + InsnNode filledArr = new FilledNewArrayNode(elemType, len); + filledArr.setResult(arrArg.duplicate()); + filledArr.copyAttributesFrom(newArrayInsn); + filledArr.inheritMetadata(newArrayInsn); + filledArr.setOffset(newArrayInsn.getOffset()); + + long prevIndex = -1; + for (Map.Entry entry : arrPuts.entrySet()) { + long index = entry.getKey(); + if (index != prevIndex) { + // use zero for missing keys + for (long i = prevIndex + 1; i < index; i++) { + filledArr.addArg(InsnArg.lit(0, elemType)); + } + } + InsnNode put = entry.getValue(); + filledArr.addArg(replaceConstInArg(mth, put.getArg(2))); + remover.addAndUnbind(put); + prevIndex = index; + } + remover.addAndUnbind(newArrayInsn); + + InsnNode lastPut = arrPuts.get(arrPuts.lastKey()); + int replaceIndex = InsnList.getIndex(instructions, lastPut); + instructions.set(replaceIndex, filledArr); + BlockUtils.replaceInsn(mth, lastPut, filledArr); + return true; + } + + private static InsnArg replaceConstInArg(MethodNode mth, InsnArg valueArg) { + if (valueArg.isLiteral()) { + FieldNode f = mth.getParentClass().getConstFieldByLiteralArg((LiteralArg) valueArg); + if (f != null) { + InsnNode fGet = new IndexInsnNode(InsnType.SGET, f.getFieldInfo(), 0); + InsnArg arg = InsnArg.wrapArg(fGet); + f.addUseIn(mth); + return arg; + } + } + return valueArg.duplicate(); + } +} diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMaker.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMaker.java index 0ef9d216..e83188c6 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMaker.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMaker.java @@ -22,6 +22,7 @@ import jadx.core.dex.attributes.AType; import jadx.core.dex.attributes.nodes.EdgeInsnAttr; import jadx.core.dex.attributes.nodes.LoopInfo; import jadx.core.dex.attributes.nodes.LoopLabelAttr; +import jadx.core.dex.attributes.nodes.RegionRefAttr; import jadx.core.dex.instructions.IfNode; import jadx.core.dex.instructions.InsnType; import jadx.core.dex.instructions.SwitchInsn; @@ -827,6 +828,7 @@ public class RegionMaker { } SwitchRegion sw = new SwitchRegion(currentRegion, block); + insn.addAttr(new RegionRefAttr(sw)); currentRegion.getSubBlocks().add(sw); stack.push(sw); stack.addExit(out); diff --git a/jadx-core/src/test/java/jadx/tests/integration/enums/TestSwitchOverEnum.java b/jadx-core/src/test/java/jadx/tests/integration/enums/TestSwitchOverEnum.java index 10fa01f0..8416da8c 100644 --- a/jadx-core/src/test/java/jadx/tests/integration/enums/TestSwitchOverEnum.java +++ b/jadx-core/src/test/java/jadx/tests/integration/enums/TestSwitchOverEnum.java @@ -2,14 +2,11 @@ package jadx.tests.integration.enums; import org.junit.jupiter.api.Test; -import jadx.core.dex.nodes.ClassNode; -import jadx.tests.api.IntegrationTest; +import jadx.tests.api.SmaliTest; -import static jadx.tests.api.utils.JadxMatchers.countString; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.junit.jupiter.api.Assertions.assertEquals; +import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat; -public class TestSwitchOverEnum extends IntegrationTest { +public class TestSwitchOverEnum extends SmaliTest { public enum Count { ONE, TWO, THREE @@ -26,18 +23,29 @@ public class TestSwitchOverEnum extends IntegrationTest { } public void check() { - assertEquals(1, testEnum(Count.ONE)); - assertEquals(2, testEnum(Count.TWO)); - assertEquals(0, testEnum(Count.THREE)); + assertThat(testEnum(Count.ONE)).isEqualTo(1); + assertThat(testEnum(Count.TWO)).isEqualTo(2); + assertThat(testEnum(Count.THREE)).isEqualTo(0); } @Test public void test() { - ClassNode cls = getClassNode(TestSwitchOverEnum.class); - String code = cls.getCode().toString(); + // remapping array placed in top class, place test also in top class + assertThat(getClassNode(TestSwitchOverEnum.class)) + .code() + .countString(1, "synthetic") + .countString(2, "switch (c) {") + .countString(3, "case ONE:"); + } - assertThat(code, countString(1, "synthetic")); - assertThat(code, countString(2, "switch (c) {")); - assertThat(code, countString(2, "case ONE:")); + /** + * Java 21 compiler can omit a remapping array and use switch over ordinal directly + */ + @Test + public void testSmaliDirect() { + assertThat(getClassNodeFromSmaliFiles()) + .code() + .containsOne("switch (v) {") + .containsOne("case ONE:"); } } diff --git a/jadx-core/src/test/smali/enums/TestSwitchOverEnum/Count.smali b/jadx-core/src/test/smali/enums/TestSwitchOverEnum/Count.smali new file mode 100644 index 00000000..4b0fdc17 --- /dev/null +++ b/jadx-core/src/test/smali/enums/TestSwitchOverEnum/Count.smali @@ -0,0 +1,76 @@ +.class public final enum Lenums/TestSwitchOverEnum$Count; +.super Ljava/lang/Enum; + +.field private static final synthetic $VALUES:[Lenums/TestSwitchOverEnum$Count; +.field public static final enum ONE:Lenums/TestSwitchOverEnum$Count; +.field public static final enum THREE:Lenums/TestSwitchOverEnum$Count; +.field public static final enum TWO:Lenums/TestSwitchOverEnum$Count; + +.method private static synthetic $values()[Lenums/TestSwitchOverEnum$Count; + .registers 3 + const/4 v0, 0x3 + new-array v0, v0, [Lenums/TestSwitchOverEnum$Count; + const/4 v1, 0x0 + sget-object v2, Lenums/TestSwitchOverEnum$Count;->ONE:Lenums/TestSwitchOverEnum$Count; + aput-object v2, v0, v1 + const/4 v1, 0x1 + sget-object v2, Lenums/TestSwitchOverEnum$Count;->TWO:Lenums/TestSwitchOverEnum$Count; + aput-object v2, v0, v1 + const/4 v1, 0x2 + sget-object v2, Lenums/TestSwitchOverEnum$Count;->THREE:Lenums/TestSwitchOverEnum$Count; + aput-object v2, v0, v1 + return-object v0 +.end method + +.method static constructor ()V + .registers 3 + new-instance v0, Lenums/TestSwitchOverEnum$Count; + const-string v1, "ONE" + const/4 v2, 0x0 + invoke-direct {v0, v1, v2}, Lenums/TestSwitchOverEnum$Count;->(Ljava/lang/String;I)V + sput-object v0, Lenums/TestSwitchOverEnum$Count;->ONE:Lenums/TestSwitchOverEnum$Count; + new-instance v0, Lenums/TestSwitchOverEnum$Count; + const-string v1, "TWO" + const/4 v2, 0x1 + invoke-direct {v0, v1, v2}, Lenums/TestSwitchOverEnum$Count;->(Ljava/lang/String;I)V + sput-object v0, Lenums/TestSwitchOverEnum$Count;->TWO:Lenums/TestSwitchOverEnum$Count; + new-instance v0, Lenums/TestSwitchOverEnum$Count; + const-string v1, "THREE" + const/4 v2, 0x2 + invoke-direct {v0, v1, v2}, Lenums/TestSwitchOverEnum$Count;->(Ljava/lang/String;I)V + sput-object v0, Lenums/TestSwitchOverEnum$Count;->THREE:Lenums/TestSwitchOverEnum$Count; + invoke-static {}, Lenums/TestSwitchOverEnum$Count;->$values()[Lenums/TestSwitchOverEnum$Count; + move-result-object v0 + sput-object v0, Lenums/TestSwitchOverEnum$Count;->$VALUES:[Lenums/TestSwitchOverEnum$Count; + return-void +.end method + +.method private constructor (Ljava/lang/String;I)V + .registers 3 + .annotation system Ldalvik/annotation/Signature; + value = { + "()V" + } + .end annotation + + invoke-direct {p0, p1, p2}, Ljava/lang/Enum;->(Ljava/lang/String;I)V + return-void +.end method + +.method public static valueOf(Ljava/lang/String;)Lenums/TestSwitchOverEnum$Count; + .registers 2 + const-class v0, Lenums/TestSwitchOverEnum$Count; + invoke-static {v0, p0}, Ljava/lang/Enum;->valueOf(Ljava/lang/Class;Ljava/lang/String;)Ljava/lang/Enum; + move-result-object v0 + check-cast v0, Lenums/TestSwitchOverEnum$Count; + return-object v0 +.end method + +.method public static values()[Lenums/TestSwitchOverEnum$Count; + .registers 1 + sget-object v0, Lenums/TestSwitchOverEnum$Count;->$VALUES:[Lenums/TestSwitchOverEnum$Count; + invoke-virtual {v0}, [Lenums/TestSwitchOverEnum$Count;->clone()Ljava/lang/Object; + move-result-object v0 + check-cast v0, [Lenums/TestSwitchOverEnum$Count; + return-object v0 +.end method diff --git a/jadx-core/src/test/smali/enums/TestSwitchOverEnum/TestSwitchOverEnum.smali b/jadx-core/src/test/smali/enums/TestSwitchOverEnum/TestSwitchOverEnum.smali new file mode 100644 index 00000000..1f2bea37 --- /dev/null +++ b/jadx-core/src/test/smali/enums/TestSwitchOverEnum/TestSwitchOverEnum.smali @@ -0,0 +1,30 @@ +.class public Lenums/TestSwitchOverEnum; +.super Ljava/lang/Object; + +.method public test(Lenums/TestSwitchOverEnum$Count;)I + .registers 3 + .param p1, "v" + + invoke-virtual {p1}, Lenums/TestSwitchOverEnum$Count;->ordinal()I + move-result v0 + + packed-switch v0, :pswitch_data + const/4 v0, 0x0 + + :goto_8 + return v0 + + :pswitch_9 + const/4 v0, 0x1 + goto :goto_8 + + :pswitch_b + const/4 v0, 0x2 + goto :goto_8 + + :pswitch_data + .packed-switch 0x0 + :pswitch_9 + :pswitch_b + .end packed-switch +.end method