fix: support switch over enum by ordinal

This commit is contained in:
Skylot 2024-01-16 20:01:37 +00:00
parent f994abee21
commit 8e7ffc8ddb
No known key found for this signature in database
GPG Key ID: 47866607B16F25C8
11 changed files with 645 additions and 387 deletions

View File

@ -29,6 +29,7 @@ import jadx.core.dex.visitors.EnumVisitor;
import jadx.core.dex.visitors.ExtractFieldInit;
import jadx.core.dex.visitors.FallbackModeVisitor;
import jadx.core.dex.visitors.FixAccessModifiers;
import jadx.core.dex.visitors.FixSwitchOverEnum;
import jadx.core.dex.visitors.GenericTypesVisitor;
import jadx.core.dex.visitors.IDexTreeVisitor;
import jadx.core.dex.visitors.InitCodeVariables;
@ -43,7 +44,7 @@ import jadx.core.dex.visitors.PrepareForCodeGen;
import jadx.core.dex.visitors.ProcessAnonymous;
import jadx.core.dex.visitors.ProcessInstructionsVisitor;
import jadx.core.dex.visitors.ProcessMethodsForInline;
import jadx.core.dex.visitors.ReSugarCode;
import jadx.core.dex.visitors.ReplaceNewArray;
import jadx.core.dex.visitors.ShadowFieldVisitor;
import jadx.core.dex.visitors.SignatureProcessor;
import jadx.core.dex.visitors.SimplifyVisitor;
@ -154,7 +155,7 @@ public class Jadx {
passes.add(new AnonymousClassVisitor());
passes.add(new ModVisitor());
passes.add(new CodeShrinkVisitor());
passes.add(new ReSugarCode());
passes.add(new ReplaceNewArray());
if (args.isCfgOutput()) {
passes.add(DotGraphVisitor.dump());
}
@ -171,6 +172,7 @@ public class Jadx {
passes.add(new CheckRegions());
passes.add(new EnumVisitor());
passes.add(new FixSwitchOverEnum());
passes.add(new ExtractFieldInit());
passes.add(new FixAccessModifiers());
passes.add(new ClassModifier());
@ -219,8 +221,7 @@ public class Jadx {
passes.add(new DeboxingVisitor());
passes.add(new ModVisitor());
passes.add(new CodeShrinkVisitor());
passes.add(new ReSugarCode());
passes.add(new CodeShrinkVisitor());
passes.add(new ReplaceNewArray());
passes.add(new SimplifyVisitor());
passes.add(new MethodVisitor("ForceGenerateAll", mth -> mth.remove(AFlag.DONT_GENERATE)));
if (args.isCfgOutput()) {

View File

@ -25,6 +25,7 @@ import jadx.core.dex.attributes.nodes.MethodReplaceAttr;
import jadx.core.dex.attributes.nodes.MethodTypeVarsAttr;
import jadx.core.dex.attributes.nodes.PhiListAttr;
import jadx.core.dex.attributes.nodes.RegDebugInfoAttr;
import jadx.core.dex.attributes.nodes.RegionRefAttr;
import jadx.core.dex.attributes.nodes.RenameReasonAttr;
import jadx.core.dex.attributes.nodes.SkipMethodArgsAttr;
import jadx.core.dex.attributes.nodes.SpecialEdgeAttr;
@ -94,6 +95,7 @@ public final class AType<T extends IJadxAttribute> implements IJadxAttrType<T> {
public static final AType<AttrList<JumpInfo>> JUMP = new AType<>();
public static final AType<IMethodDetails> METHOD_DETAILS = new AType<>();
public static final AType<GenericInfoAttr> GENERIC_INFO = new AType<>();
public static final AType<RegionRefAttr> REGION_REF = new AType<>();
// register
public static final AType<RegDebugInfoAttr> REG_DEBUG_INFO = new AType<>();

View File

@ -0,0 +1,30 @@
package jadx.core.dex.attributes.nodes;
import jadx.api.plugins.input.data.attributes.IJadxAttribute;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.nodes.IRegion;
/**
* Region created based on parent instruction
*/
public class RegionRefAttr implements IJadxAttribute {
private final IRegion region;
public RegionRefAttr(IRegion region) {
this.region = region;
}
public IRegion getRegion() {
return region;
}
@Override
public AType<RegionRefAttr> getAttrType() {
return AType.REGION_REF;
}
@Override
public String toString() {
return "RegionRef:" + region;
}
}

View File

@ -63,7 +63,7 @@ import static jadx.core.utils.InsnUtils.getWrappedInsn;
runAfter = {
CodeShrinkVisitor.class, // all possible instructions already inlined
ModVisitor.class,
ReSugarCode.class,
ReplaceNewArray.class, // values array normalized
IfRegionVisitor.class, // ternary operator inlined
CheckRegions.class // regions processing finished
},

View File

@ -0,0 +1,297 @@
package jadx.core.dex.visitors;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.IntFunction;
import org.jetbrains.annotations.Nullable;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.EnumClassAttr;
import jadx.core.dex.attributes.nodes.EnumMapAttr;
import jadx.core.dex.attributes.nodes.RegionRefAttr;
import jadx.core.dex.info.AccessInfo;
import jadx.core.dex.info.FieldInfo;
import jadx.core.dex.info.MethodInfo;
import jadx.core.dex.instructions.IndexInsnNode;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.InvokeNode;
import jadx.core.dex.instructions.SwitchInsn;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.nodes.FieldNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.dex.regions.SwitchRegion;
import jadx.core.dex.visitors.shrink.CodeShrinkVisitor;
import jadx.core.utils.Utils;
import jadx.core.utils.exceptions.JadxException;
@JadxVisitor(
name = "FixSwitchOverEnum",
desc = "Simplify synthetic code in switch over enum",
runAfter = {
CodeShrinkVisitor.class,
EnumVisitor.class
}
)
public class FixSwitchOverEnum extends AbstractVisitor {
@Override
public boolean visit(ClassNode cls) throws JadxException {
initClsEnumMap(cls);
return true;
}
@Override
public void visit(MethodNode mth) throws JadxException {
if (mth.isNoCode()) {
return;
}
boolean changed = false;
for (BlockNode block : mth.getBasicBlocks()) {
for (InsnNode insn : block.getInstructions()) {
if (insn.getType() == InsnType.SWITCH && !insn.contains(AFlag.REMOVE)) {
changed |= processEnumSwitch(mth, (SwitchInsn) insn);
}
}
}
if (changed) {
CodeShrinkVisitor.shrinkMethod(mth);
}
}
private static boolean processEnumSwitch(MethodNode mth, SwitchInsn insn) {
InsnArg arg = insn.getArg(0);
if (!arg.isInsnWrap()) {
return false;
}
InsnNode wrapInsn = ((InsnWrapArg) arg).getWrapInsn();
switch (wrapInsn.getType()) {
case AGET:
return processRemappedEnumSwitch(mth, insn, wrapInsn, arg);
case INVOKE:
return processDirectEnumSwitch(mth, insn, (InvokeNode) wrapInsn, arg);
}
return false;
}
private static boolean executeReplace(SwitchInsn swInsn, InsnArg arg, InsnArg invVar, IntFunction<Object> caseReplace) {
RegionRefAttr regionRefAttr = swInsn.get(AType.REGION_REF);
if (regionRefAttr == null) {
return false;
}
if (!swInsn.replaceArg(arg, invVar)) {
return false;
}
Map<Object, Object> replaceMap = new HashMap<>();
int caseCount = swInsn.getKeys().length;
for (int i = 0; i < caseCount; i++) {
Object key = swInsn.getKey(i);
Object replaceObj = caseReplace.apply(i);
swInsn.modifyKey(i, replaceObj);
replaceMap.put(key, replaceObj);
}
SwitchRegion region = (SwitchRegion) regionRefAttr.getRegion();
for (SwitchRegion.CaseInfo caseInfo : region.getCases()) {
caseInfo.getKeys().replaceAll(key -> Utils.getOrElse(replaceMap.get(key), key));
}
return true;
}
private static boolean processDirectEnumSwitch(MethodNode mth, SwitchInsn swInsn, InvokeNode invInsn, InsnArg arg) {
MethodInfo callMth = invInsn.getCallMth();
if (!callMth.getShortId().equals("ordinal()I")) {
return false;
}
InsnArg invVar = invInsn.getArg(0);
ClassNode enumCls = mth.root().resolveClass(invVar.getType());
if (enumCls == null) {
return false;
}
EnumClassAttr enumClassAttr = enumCls.get(AType.ENUM_CLASS);
if (enumClassAttr == null) {
return false;
}
FieldNode[] casesReplaceArr = mapToCases(swInsn, enumClassAttr.getFields());
if (casesReplaceArr == null) {
return false;
}
return executeReplace(swInsn, arg, invVar, i -> casesReplaceArr[i]);
}
private static @Nullable FieldNode[] mapToCases(SwitchInsn swInsn, List<EnumClassAttr.EnumField> fields) {
int caseCount = swInsn.getKeys().length;
if (fields.size() < caseCount) {
return null;
}
FieldNode[] casesMap = new FieldNode[caseCount];
for (int i = 0; i < caseCount; i++) {
Object key = swInsn.getKey(i);
if (key instanceof Integer) {
int ordinal = (Integer) key;
try {
casesMap[ordinal] = fields.get(ordinal).getField();
} catch (Exception e) {
return null;
}
} else {
return null;
}
}
return casesMap;
}
private static boolean processRemappedEnumSwitch(MethodNode mth, SwitchInsn insn, InsnNode wrapInsn, InsnArg arg) {
EnumMapInfo enumMapInfo = checkEnumMapAccess(mth.root(), wrapInsn);
if (enumMapInfo == null) {
return false;
}
FieldNode enumMapField = enumMapInfo.getMapField();
InsnArg invArg = enumMapInfo.getArg();
EnumMapAttr.KeyValueMap valueMap = getEnumMap(enumMapField);
if (valueMap == null) {
return false;
}
int caseCount = insn.getKeys().length;
for (int i = 0; i < caseCount; i++) {
Object key = insn.getKey(i);
Object newKey = valueMap.get(key);
if (newKey == null) {
return false;
}
}
if (executeReplace(insn, arg, invArg, i -> valueMap.get(insn.getKey(i)))) {
enumMapField.add(AFlag.DONT_GENERATE);
checkAndHideClass(enumMapField.getParentClass());
return true;
}
return false;
}
private static void initClsEnumMap(ClassNode enumCls) {
MethodNode clsInitMth = enumCls.getClassInitMth();
if (clsInitMth == null || clsInitMth.isNoCode() || clsInitMth.getBasicBlocks() == null) {
return;
}
EnumMapAttr mapAttr = new EnumMapAttr();
for (BlockNode block : clsInitMth.getBasicBlocks()) {
for (InsnNode insn : block.getInstructions()) {
if (insn.getType() == InsnType.APUT) {
addToEnumMap(enumCls.root(), mapAttr, insn);
}
}
}
if (!mapAttr.isEmpty()) {
enumCls.addAttr(mapAttr);
}
}
private static @Nullable EnumMapAttr.KeyValueMap getEnumMap(FieldNode field) {
ClassNode syntheticClass = field.getParentClass();
EnumMapAttr mapAttr = syntheticClass.get(AType.ENUM_MAP);
if (mapAttr == null) {
return null;
}
return mapAttr.getMap(field);
}
private static void addToEnumMap(RootNode root, EnumMapAttr mapAttr, InsnNode aputInsn) {
InsnArg litArg = aputInsn.getArg(2);
if (!litArg.isLiteral()) {
return;
}
EnumMapInfo mapInfo = checkEnumMapAccess(root, aputInsn);
if (mapInfo == null) {
return;
}
InsnArg enumArg = mapInfo.getArg();
FieldNode field = mapInfo.getMapField();
if (field == null || !enumArg.isInsnWrap()) {
return;
}
InsnNode sget = ((InsnWrapArg) enumArg).getWrapInsn();
if (!(sget instanceof IndexInsnNode)) {
return;
}
Object index = ((IndexInsnNode) sget).getIndex();
if (!(index instanceof FieldInfo)) {
return;
}
FieldNode fieldNode = root.resolveField((FieldInfo) index);
if (fieldNode == null) {
return;
}
int literal = (int) ((LiteralArg) litArg).getLiteral();
mapAttr.add(field, literal, fieldNode);
}
private static @Nullable EnumMapInfo checkEnumMapAccess(RootNode root, InsnNode checkInsn) {
InsnArg sgetArg = checkInsn.getArg(0);
InsnArg invArg = checkInsn.getArg(1);
if (!sgetArg.isInsnWrap() || !invArg.isInsnWrap()) {
return null;
}
InsnNode invInsn = ((InsnWrapArg) invArg).getWrapInsn();
InsnNode sgetInsn = ((InsnWrapArg) sgetArg).getWrapInsn();
if (invInsn.getType() != InsnType.INVOKE || sgetInsn.getType() != InsnType.SGET) {
return null;
}
InvokeNode inv = (InvokeNode) invInsn;
if (!inv.getCallMth().getShortId().equals("ordinal()I")) {
return null;
}
ClassNode enumCls = root.resolveClass(inv.getCallMth().getDeclClass());
if (enumCls == null || !enumCls.isEnum()) {
return null;
}
Object index = ((IndexInsnNode) sgetInsn).getIndex();
if (!(index instanceof FieldInfo)) {
return null;
}
FieldNode enumMapField = root.resolveField((FieldInfo) index);
if (enumMapField == null || !enumMapField.getAccessFlags().isSynthetic()) {
return null;
}
return new EnumMapInfo(inv.getArg(0), enumMapField);
}
/**
* If all static final synthetic fields have DONT_GENERATE => hide whole class
*/
private static void checkAndHideClass(ClassNode cls) {
for (FieldNode field : cls.getFields()) {
AccessInfo af = field.getAccessFlags();
if (af.isSynthetic() && af.isStatic() && af.isFinal()
&& !field.contains(AFlag.DONT_GENERATE)) {
return;
}
}
cls.add(AFlag.DONT_GENERATE);
}
private static class EnumMapInfo {
private final InsnArg arg;
private final FieldNode mapField;
public EnumMapInfo(InsnArg arg, FieldNode mapField) {
this.arg = arg;
this.mapField = mapField;
}
public InsnArg getArg() {
return arg;
}
public FieldNode getMapField() {
return mapField;
}
}
}

View File

@ -1,368 +0,0 @@
package jadx.core.dex.visitors;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import org.jetbrains.annotations.Nullable;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.EnumMapAttr;
import jadx.core.dex.info.AccessInfo;
import jadx.core.dex.info.FieldInfo;
import jadx.core.dex.instructions.FilledNewArrayNode;
import jadx.core.dex.instructions.IndexInsnNode;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.InvokeNode;
import jadx.core.dex.instructions.NewArrayNode;
import jadx.core.dex.instructions.SwitchInsn;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.nodes.FieldNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.dex.visitors.shrink.CodeShrinkVisitor;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.InsnList;
import jadx.core.utils.InsnRemover;
import jadx.core.utils.InsnUtils;
import jadx.core.utils.exceptions.JadxException;
@JadxVisitor(
name = "ReSugarCode",
desc = "Simplify synthetic or verbose code",
runAfter = CodeShrinkVisitor.class
)
public class ReSugarCode extends AbstractVisitor {
@Override
public boolean visit(ClassNode cls) throws JadxException {
initClsEnumMap(cls);
return true;
}
@Override
public void visit(MethodNode mth) throws JadxException {
if (mth.isNoCode()) {
return;
}
int k = 0;
while (true) {
boolean changed = false;
InsnRemover remover = new InsnRemover(mth);
for (BlockNode block : mth.getBasicBlocks()) {
remover.setBlock(block);
List<InsnNode> instructions = block.getInstructions();
int size = instructions.size();
for (int i = 0; i < size; i++) {
changed |= process(mth, instructions, i, remover);
}
remover.perform();
}
if (changed) {
CodeShrinkVisitor.shrinkMethod(mth);
} else {
break;
}
if (k++ > 100) {
mth.addWarnComment("Reached limit for ReSugarCode iterations");
break;
}
}
}
private static boolean process(MethodNode mth, List<InsnNode> instructions, int i, InsnRemover remover) {
InsnNode insn = instructions.get(i);
if (insn.contains(AFlag.REMOVE)) {
return false;
}
switch (insn.getType()) {
case NEW_ARRAY:
return processNewArray(mth, (NewArrayNode) insn, instructions, remover);
case SWITCH:
return processEnumSwitch(mth, (SwitchInsn) insn);
default:
return false;
}
}
/**
* Replace new-array and sequence of array-put to new filled-array instruction.
*/
private static boolean processNewArray(MethodNode mth, NewArrayNode newArrayInsn, List<InsnNode> instructions, InsnRemover remover) {
Object arrayLenConst = InsnUtils.getConstValueByArg(mth.root(), newArrayInsn.getArg(0));
if (!(arrayLenConst instanceof LiteralArg)) {
return false;
}
int len = (int) ((LiteralArg) arrayLenConst).getLiteral();
if (len == 0) {
return false;
}
ArgType arrType = newArrayInsn.getArrayType();
ArgType elemType = arrType.getArrayElement();
boolean allowMissingKeys = arrType.getArrayDimension() == 1 && elemType.isPrimitive();
int minLen = allowMissingKeys ? len / 2 : len;
RegisterArg arrArg = newArrayInsn.getResult();
List<RegisterArg> useList = arrArg.getSVar().getUseList();
if (useList.size() < minLen) {
return false;
}
// quick check if APUT is used
boolean foundPut = false;
for (RegisterArg registerArg : useList) {
InsnNode parentInsn = registerArg.getParentInsn();
if (parentInsn != null && parentInsn.getType() == InsnType.APUT) {
foundPut = true;
break;
}
}
if (!foundPut) {
return false;
}
// collect put instructions sorted by array index
SortedMap<Long, InsnNode> arrPuts = new TreeMap<>();
for (RegisterArg registerArg : useList) {
InsnNode parentInsn = registerArg.getParentInsn();
if (parentInsn == null || parentInsn.getType() != InsnType.APUT) {
continue;
}
if (!arrArg.sameRegAndSVar(parentInsn.getArg(0))) {
return false;
}
Object constVal = InsnUtils.getConstValueByArg(mth.root(), parentInsn.getArg(1));
if (!(constVal instanceof LiteralArg)) {
return false;
}
long index = ((LiteralArg) constVal).getLiteral();
if (index >= len) {
return false;
}
if (arrPuts.containsKey(index)) {
// stop on index rewrite
break;
}
arrPuts.put(index, parentInsn);
}
if (arrPuts.size() < minLen) {
return false;
}
// expect all puts to be in same block
if (!new HashSet<>(instructions).containsAll(arrPuts.values())) {
return false;
}
// checks complete, apply
InsnNode filledArr = new FilledNewArrayNode(elemType, len);
filledArr.setResult(arrArg.duplicate());
filledArr.copyAttributesFrom(newArrayInsn);
filledArr.inheritMetadata(newArrayInsn);
filledArr.setOffset(newArrayInsn.getOffset());
long prevIndex = -1;
for (Map.Entry<Long, InsnNode> entry : arrPuts.entrySet()) {
long index = entry.getKey();
if (index != prevIndex) {
// use zero for missing keys
for (long i = prevIndex + 1; i < index; i++) {
filledArr.addArg(InsnArg.lit(0, elemType));
}
}
InsnNode put = entry.getValue();
filledArr.addArg(replaceConstInArg(mth, put.getArg(2)));
remover.addAndUnbind(put);
prevIndex = index;
}
remover.addAndUnbind(newArrayInsn);
InsnNode lastPut = arrPuts.get(arrPuts.lastKey());
int replaceIndex = InsnList.getIndex(instructions, lastPut);
instructions.set(replaceIndex, filledArr);
BlockUtils.replaceInsn(mth, lastPut, filledArr);
return true;
}
private static InsnArg replaceConstInArg(MethodNode mth, InsnArg valueArg) {
if (valueArg.isLiteral()) {
FieldNode f = mth.getParentClass().getConstFieldByLiteralArg((LiteralArg) valueArg);
if (f != null) {
InsnNode fGet = new IndexInsnNode(InsnType.SGET, f.getFieldInfo(), 0);
InsnArg arg = InsnArg.wrapArg(fGet);
f.addUseIn(mth);
return arg;
}
}
return valueArg.duplicate();
}
private static boolean processEnumSwitch(MethodNode mth, SwitchInsn insn) {
InsnArg arg = insn.getArg(0);
if (!arg.isInsnWrap()) {
return false;
}
InsnNode wrapInsn = ((InsnWrapArg) arg).getWrapInsn();
if (wrapInsn.getType() != InsnType.AGET) {
return false;
}
EnumMapInfo enumMapInfo = checkEnumMapAccess(mth.root(), wrapInsn);
if (enumMapInfo == null) {
return false;
}
FieldNode enumMapField = enumMapInfo.getMapField();
InsnArg invArg = enumMapInfo.getArg();
EnumMapAttr.KeyValueMap valueMap = getEnumMap(mth, enumMapField);
if (valueMap == null) {
return false;
}
int caseCount = insn.getKeys().length;
for (int i = 0; i < caseCount; i++) {
Object key = insn.getKey(i);
Object newKey = valueMap.get(key);
if (newKey == null) {
return false;
}
}
// replace confirmed
if (!insn.replaceArg(arg, invArg)) {
return false;
}
for (int i = 0; i < caseCount; i++) {
insn.modifyKey(i, valueMap.get(insn.getKey(i)));
}
enumMapField.add(AFlag.DONT_GENERATE);
checkAndHideClass(enumMapField.getParentClass());
return true;
}
private static void initClsEnumMap(ClassNode enumCls) {
MethodNode clsInitMth = enumCls.getClassInitMth();
if (clsInitMth == null || clsInitMth.isNoCode() || clsInitMth.getBasicBlocks() == null) {
return;
}
EnumMapAttr mapAttr = new EnumMapAttr();
for (BlockNode block : clsInitMth.getBasicBlocks()) {
for (InsnNode insn : block.getInstructions()) {
if (insn.getType() == InsnType.APUT) {
addToEnumMap(enumCls.root(), mapAttr, insn);
}
}
}
if (!mapAttr.isEmpty()) {
enumCls.addAttr(mapAttr);
}
}
@Nullable
private static EnumMapAttr.KeyValueMap getEnumMap(MethodNode mth, FieldNode field) {
ClassNode syntheticClass = field.getParentClass();
EnumMapAttr mapAttr = syntheticClass.get(AType.ENUM_MAP);
if (mapAttr == null) {
return null;
}
return mapAttr.getMap(field);
}
private static void addToEnumMap(RootNode root, EnumMapAttr mapAttr, InsnNode aputInsn) {
InsnArg litArg = aputInsn.getArg(2);
if (!litArg.isLiteral()) {
return;
}
EnumMapInfo mapInfo = checkEnumMapAccess(root, aputInsn);
if (mapInfo == null) {
return;
}
InsnArg enumArg = mapInfo.getArg();
FieldNode field = mapInfo.getMapField();
if (field == null || !enumArg.isInsnWrap()) {
return;
}
InsnNode sget = ((InsnWrapArg) enumArg).getWrapInsn();
if (!(sget instanceof IndexInsnNode)) {
return;
}
Object index = ((IndexInsnNode) sget).getIndex();
if (!(index instanceof FieldInfo)) {
return;
}
FieldNode fieldNode = root.resolveField((FieldInfo) index);
if (fieldNode == null) {
return;
}
int literal = (int) ((LiteralArg) litArg).getLiteral();
mapAttr.add(field, literal, fieldNode);
}
public static EnumMapInfo checkEnumMapAccess(RootNode root, InsnNode checkInsn) {
InsnArg sgetArg = checkInsn.getArg(0);
InsnArg invArg = checkInsn.getArg(1);
if (!sgetArg.isInsnWrap() || !invArg.isInsnWrap()) {
return null;
}
InsnNode invInsn = ((InsnWrapArg) invArg).getWrapInsn();
InsnNode sgetInsn = ((InsnWrapArg) sgetArg).getWrapInsn();
if (invInsn.getType() != InsnType.INVOKE || sgetInsn.getType() != InsnType.SGET) {
return null;
}
InvokeNode inv = (InvokeNode) invInsn;
if (!inv.getCallMth().getShortId().equals("ordinal()I")) {
return null;
}
ClassNode enumCls = root.resolveClass(inv.getCallMth().getDeclClass());
if (enumCls == null || !enumCls.isEnum()) {
return null;
}
Object index = ((IndexInsnNode) sgetInsn).getIndex();
if (!(index instanceof FieldInfo)) {
return null;
}
FieldNode enumMapField = root.resolveField((FieldInfo) index);
if (enumMapField == null || !enumMapField.getAccessFlags().isSynthetic()) {
return null;
}
return new EnumMapInfo(inv.getArg(0), enumMapField);
}
/**
* If all static final synthetic fields have DONT_GENERATE => hide whole class
*/
private static void checkAndHideClass(ClassNode cls) {
for (FieldNode field : cls.getFields()) {
AccessInfo af = field.getAccessFlags();
if (af.isSynthetic() && af.isStatic() && af.isFinal()
&& !field.contains(AFlag.DONT_GENERATE)) {
return;
}
}
cls.add(AFlag.DONT_GENERATE);
}
private static class EnumMapInfo {
private final InsnArg arg;
private final FieldNode mapField;
public EnumMapInfo(InsnArg arg, FieldNode mapField) {
this.arg = arg;
this.mapField = mapField;
}
public InsnArg getArg() {
return arg;
}
public FieldNode getMapField() {
return mapField;
}
}
}

View File

@ -0,0 +1,180 @@
package jadx.core.dex.visitors;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.FilledNewArrayNode;
import jadx.core.dex.instructions.IndexInsnNode;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.NewArrayNode;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.FieldNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.visitors.shrink.CodeShrinkVisitor;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.InsnList;
import jadx.core.utils.InsnRemover;
import jadx.core.utils.InsnUtils;
import jadx.core.utils.exceptions.JadxException;
@JadxVisitor(
name = "ReplaceNewArray",
desc = "Replace new-array and sequence of array-put to new filled-array instruction",
runAfter = CodeShrinkVisitor.class
)
public class ReplaceNewArray extends AbstractVisitor {
@Override
public void visit(MethodNode mth) throws JadxException {
if (mth.isNoCode()) {
return;
}
int k = 0;
while (true) {
boolean changed = false;
InsnRemover remover = new InsnRemover(mth);
for (BlockNode block : mth.getBasicBlocks()) {
remover.setBlock(block);
List<InsnNode> instructions = block.getInstructions();
int size = instructions.size();
for (int i = 0; i < size; i++) {
changed |= processInsn(mth, instructions, i, remover);
}
remover.perform();
}
if (changed) {
CodeShrinkVisitor.shrinkMethod(mth);
} else {
break;
}
if (k++ > 100) {
mth.addWarnComment("Reached limit for ReplaceNewArray iterations");
break;
}
}
}
private static boolean processInsn(MethodNode mth, List<InsnNode> instructions, int i, InsnRemover remover) {
InsnNode insn = instructions.get(i);
if (insn.getType() == InsnType.NEW_ARRAY && !insn.contains(AFlag.REMOVE)) {
return processNewArray(mth, (NewArrayNode) insn, instructions, remover);
}
return false;
}
private static boolean processNewArray(MethodNode mth,
NewArrayNode newArrayInsn, List<InsnNode> instructions, InsnRemover remover) {
Object arrayLenConst = InsnUtils.getConstValueByArg(mth.root(), newArrayInsn.getArg(0));
if (!(arrayLenConst instanceof LiteralArg)) {
return false;
}
int len = (int) ((LiteralArg) arrayLenConst).getLiteral();
if (len == 0) {
return false;
}
ArgType arrType = newArrayInsn.getArrayType();
ArgType elemType = arrType.getArrayElement();
boolean allowMissingKeys = arrType.getArrayDimension() == 1 && elemType.isPrimitive();
int minLen = allowMissingKeys ? len / 2 : len;
RegisterArg arrArg = newArrayInsn.getResult();
List<RegisterArg> useList = arrArg.getSVar().getUseList();
if (useList.size() < minLen) {
return false;
}
// quick check if APUT is used
boolean foundPut = false;
for (RegisterArg registerArg : useList) {
InsnNode parentInsn = registerArg.getParentInsn();
if (parentInsn != null && parentInsn.getType() == InsnType.APUT) {
foundPut = true;
break;
}
}
if (!foundPut) {
return false;
}
// collect put instructions sorted by array index
SortedMap<Long, InsnNode> arrPuts = new TreeMap<>();
for (RegisterArg registerArg : useList) {
InsnNode parentInsn = registerArg.getParentInsn();
if (parentInsn == null || parentInsn.getType() != InsnType.APUT) {
continue;
}
if (!arrArg.sameRegAndSVar(parentInsn.getArg(0))) {
return false;
}
Object constVal = InsnUtils.getConstValueByArg(mth.root(), parentInsn.getArg(1));
if (!(constVal instanceof LiteralArg)) {
return false;
}
long index = ((LiteralArg) constVal).getLiteral();
if (index >= len) {
return false;
}
if (arrPuts.containsKey(index)) {
// stop on index rewrite
break;
}
arrPuts.put(index, parentInsn);
}
if (arrPuts.size() < minLen) {
return false;
}
// expect all puts to be in same block
if (!new HashSet<>(instructions).containsAll(arrPuts.values())) {
return false;
}
// checks complete, apply
InsnNode filledArr = new FilledNewArrayNode(elemType, len);
filledArr.setResult(arrArg.duplicate());
filledArr.copyAttributesFrom(newArrayInsn);
filledArr.inheritMetadata(newArrayInsn);
filledArr.setOffset(newArrayInsn.getOffset());
long prevIndex = -1;
for (Map.Entry<Long, InsnNode> entry : arrPuts.entrySet()) {
long index = entry.getKey();
if (index != prevIndex) {
// use zero for missing keys
for (long i = prevIndex + 1; i < index; i++) {
filledArr.addArg(InsnArg.lit(0, elemType));
}
}
InsnNode put = entry.getValue();
filledArr.addArg(replaceConstInArg(mth, put.getArg(2)));
remover.addAndUnbind(put);
prevIndex = index;
}
remover.addAndUnbind(newArrayInsn);
InsnNode lastPut = arrPuts.get(arrPuts.lastKey());
int replaceIndex = InsnList.getIndex(instructions, lastPut);
instructions.set(replaceIndex, filledArr);
BlockUtils.replaceInsn(mth, lastPut, filledArr);
return true;
}
private static InsnArg replaceConstInArg(MethodNode mth, InsnArg valueArg) {
if (valueArg.isLiteral()) {
FieldNode f = mth.getParentClass().getConstFieldByLiteralArg((LiteralArg) valueArg);
if (f != null) {
InsnNode fGet = new IndexInsnNode(InsnType.SGET, f.getFieldInfo(), 0);
InsnArg arg = InsnArg.wrapArg(fGet);
f.addUseIn(mth);
return arg;
}
}
return valueArg.duplicate();
}
}

View File

@ -22,6 +22,7 @@ import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.EdgeInsnAttr;
import jadx.core.dex.attributes.nodes.LoopInfo;
import jadx.core.dex.attributes.nodes.LoopLabelAttr;
import jadx.core.dex.attributes.nodes.RegionRefAttr;
import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.SwitchInsn;
@ -827,6 +828,7 @@ public class RegionMaker {
}
SwitchRegion sw = new SwitchRegion(currentRegion, block);
insn.addAttr(new RegionRefAttr(sw));
currentRegion.getSubBlocks().add(sw);
stack.push(sw);
stack.addExit(out);

View File

@ -2,14 +2,11 @@ package jadx.tests.integration.enums;
import org.junit.jupiter.api.Test;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import jadx.tests.api.SmaliTest;
import static jadx.tests.api.utils.JadxMatchers.countString;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
public class TestSwitchOverEnum extends IntegrationTest {
public class TestSwitchOverEnum extends SmaliTest {
public enum Count {
ONE, TWO, THREE
@ -26,18 +23,29 @@ public class TestSwitchOverEnum extends IntegrationTest {
}
public void check() {
assertEquals(1, testEnum(Count.ONE));
assertEquals(2, testEnum(Count.TWO));
assertEquals(0, testEnum(Count.THREE));
assertThat(testEnum(Count.ONE)).isEqualTo(1);
assertThat(testEnum(Count.TWO)).isEqualTo(2);
assertThat(testEnum(Count.THREE)).isEqualTo(0);
}
@Test
public void test() {
ClassNode cls = getClassNode(TestSwitchOverEnum.class);
String code = cls.getCode().toString();
// remapping array placed in top class, place test also in top class
assertThat(getClassNode(TestSwitchOverEnum.class))
.code()
.countString(1, "synthetic")
.countString(2, "switch (c) {")
.countString(3, "case ONE:");
}
assertThat(code, countString(1, "synthetic"));
assertThat(code, countString(2, "switch (c) {"));
assertThat(code, countString(2, "case ONE:"));
/**
* Java 21 compiler can omit a remapping array and use switch over ordinal directly
*/
@Test
public void testSmaliDirect() {
assertThat(getClassNodeFromSmaliFiles())
.code()
.containsOne("switch (v) {")
.containsOne("case ONE:");
}
}

View File

@ -0,0 +1,76 @@
.class public final enum Lenums/TestSwitchOverEnum$Count;
.super Ljava/lang/Enum;
.field private static final synthetic $VALUES:[Lenums/TestSwitchOverEnum$Count;
.field public static final enum ONE:Lenums/TestSwitchOverEnum$Count;
.field public static final enum THREE:Lenums/TestSwitchOverEnum$Count;
.field public static final enum TWO:Lenums/TestSwitchOverEnum$Count;
.method private static synthetic $values()[Lenums/TestSwitchOverEnum$Count;
.registers 3
const/4 v0, 0x3
new-array v0, v0, [Lenums/TestSwitchOverEnum$Count;
const/4 v1, 0x0
sget-object v2, Lenums/TestSwitchOverEnum$Count;->ONE:Lenums/TestSwitchOverEnum$Count;
aput-object v2, v0, v1
const/4 v1, 0x1
sget-object v2, Lenums/TestSwitchOverEnum$Count;->TWO:Lenums/TestSwitchOverEnum$Count;
aput-object v2, v0, v1
const/4 v1, 0x2
sget-object v2, Lenums/TestSwitchOverEnum$Count;->THREE:Lenums/TestSwitchOverEnum$Count;
aput-object v2, v0, v1
return-object v0
.end method
.method static constructor <clinit>()V
.registers 3
new-instance v0, Lenums/TestSwitchOverEnum$Count;
const-string v1, "ONE"
const/4 v2, 0x0
invoke-direct {v0, v1, v2}, Lenums/TestSwitchOverEnum$Count;-><init>(Ljava/lang/String;I)V
sput-object v0, Lenums/TestSwitchOverEnum$Count;->ONE:Lenums/TestSwitchOverEnum$Count;
new-instance v0, Lenums/TestSwitchOverEnum$Count;
const-string v1, "TWO"
const/4 v2, 0x1
invoke-direct {v0, v1, v2}, Lenums/TestSwitchOverEnum$Count;-><init>(Ljava/lang/String;I)V
sput-object v0, Lenums/TestSwitchOverEnum$Count;->TWO:Lenums/TestSwitchOverEnum$Count;
new-instance v0, Lenums/TestSwitchOverEnum$Count;
const-string v1, "THREE"
const/4 v2, 0x2
invoke-direct {v0, v1, v2}, Lenums/TestSwitchOverEnum$Count;-><init>(Ljava/lang/String;I)V
sput-object v0, Lenums/TestSwitchOverEnum$Count;->THREE:Lenums/TestSwitchOverEnum$Count;
invoke-static {}, Lenums/TestSwitchOverEnum$Count;->$values()[Lenums/TestSwitchOverEnum$Count;
move-result-object v0
sput-object v0, Lenums/TestSwitchOverEnum$Count;->$VALUES:[Lenums/TestSwitchOverEnum$Count;
return-void
.end method
.method private constructor <init>(Ljava/lang/String;I)V
.registers 3
.annotation system Ldalvik/annotation/Signature;
value = {
"()V"
}
.end annotation
invoke-direct {p0, p1, p2}, Ljava/lang/Enum;-><init>(Ljava/lang/String;I)V
return-void
.end method
.method public static valueOf(Ljava/lang/String;)Lenums/TestSwitchOverEnum$Count;
.registers 2
const-class v0, Lenums/TestSwitchOverEnum$Count;
invoke-static {v0, p0}, Ljava/lang/Enum;->valueOf(Ljava/lang/Class;Ljava/lang/String;)Ljava/lang/Enum;
move-result-object v0
check-cast v0, Lenums/TestSwitchOverEnum$Count;
return-object v0
.end method
.method public static values()[Lenums/TestSwitchOverEnum$Count;
.registers 1
sget-object v0, Lenums/TestSwitchOverEnum$Count;->$VALUES:[Lenums/TestSwitchOverEnum$Count;
invoke-virtual {v0}, [Lenums/TestSwitchOverEnum$Count;->clone()Ljava/lang/Object;
move-result-object v0
check-cast v0, [Lenums/TestSwitchOverEnum$Count;
return-object v0
.end method

View File

@ -0,0 +1,30 @@
.class public Lenums/TestSwitchOverEnum;
.super Ljava/lang/Object;
.method public test(Lenums/TestSwitchOverEnum$Count;)I
.registers 3
.param p1, "v"
invoke-virtual {p1}, Lenums/TestSwitchOverEnum$Count;->ordinal()I
move-result v0
packed-switch v0, :pswitch_data
const/4 v0, 0x0
:goto_8
return v0
:pswitch_9
const/4 v0, 0x1
goto :goto_8
:pswitch_b
const/4 v0, 0x2
goto :goto_8
:pswitch_data
.packed-switch 0x0
:pswitch_9
:pswitch_b
.end packed-switch
.end method