From 4ee9281ee4df195714275c87f33fe4c36c8d21af Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Mon, 16 Jan 2017 21:24:41 +0000 Subject: [PATCH] [InstCombine] use m_APInt instead of faking it git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@292164 91177308-0d34-0410-b5e6-96231b3b80d8 --- .../InstCombine/InstCombineShifts.cpp | 34 ++++++++----------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index 1df0afc05e5..bf0ab82e89d 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -312,7 +312,7 @@ static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, /// Try to fold (X << C1) << C2, where the shifts are some combination of /// shl/ashr/lshr. static Instruction * -foldShiftByConstOfShiftByConst(BinaryOperator &I, ConstantInt *COp1, +foldShiftByConstOfShiftByConst(BinaryOperator &I, const APInt *COp1, InstCombiner::BuilderTy *Builder) { Value *Op0 = I.getOperand(0); uint32_t TypeBits = Op0->getType()->getScalarSizeInBits(); @@ -475,33 +475,26 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, BinaryOperator &I) { bool isLeftShift = I.getOpcode() == Instruction::Shl; - ConstantInt *COp1 = nullptr; - if (ConstantDataVector *CV = dyn_cast(Op1)) - COp1 = dyn_cast_or_null(CV->getSplatValue()); - else if (ConstantVector *CV = dyn_cast(Op1)) - COp1 = dyn_cast_or_null(CV->getSplatValue()); - else - COp1 = dyn_cast(Op1); - - if (!COp1) + const APInt *Op1C; + if (!match(Op1, m_APInt(Op1C))) return nullptr; // See if we can propagate this shift into the input, this covers the trivial // cast of lshr(shl(x,c1),c2) as well as other more complex cases. if (I.getOpcode() != Instruction::AShr && - canEvaluateShifted(Op0, COp1->getZExtValue(), isLeftShift, *this, &I)) { + canEvaluateShifted(Op0, Op1C->getZExtValue(), isLeftShift, *this, &I)) { DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression" " to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n"); return replaceInstUsesWith( - I, getShiftedValue(Op0, COp1->getZExtValue(), isLeftShift, *this, DL)); + I, getShiftedValue(Op0, Op1C->getZExtValue(), isLeftShift, *this, DL)); } // See if we can simplify any instructions used by the instruction whose sole // purpose is to compute bits we don't care about. - uint32_t TypeBits = Op0->getType()->getScalarSizeInBits(); + unsigned TypeBits = Op0->getType()->getScalarSizeInBits(); - assert(!COp1->uge(TypeBits) && + assert(!Op1C->uge(TypeBits) && "Shift over the type width should have been removed already"); // ((X*C1) << C2) == (X * (C1 << C2)) @@ -525,7 +518,8 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, if (TrOp && I.isLogicalShift() && TrOp->isShift() && isa(TrOp->getOperand(1))) { // Okay, we'll do this xform. Make the shift of shift. - Constant *ShAmt = ConstantExpr::getZExt(COp1, TrOp->getType()); + Constant *ShAmt = + ConstantExpr::getZExt(cast(Op1), TrOp->getType()); // (shift2 (shift1 & 0x00FF), c2) Value *NSh = Builder->CreateBinOp(I.getOpcode(), TrOp, ShAmt,I.getName()); @@ -542,10 +536,10 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, // shift. We know that it is a logical shift by a constant, so adjust the // mask as appropriate. if (I.getOpcode() == Instruction::Shl) - MaskV <<= COp1->getZExtValue(); + MaskV <<= Op1C->getZExtValue(); else { assert(I.getOpcode() == Instruction::LShr && "Unknown logical shift"); - MaskV = MaskV.lshr(COp1->getZExtValue()); + MaskV = MaskV.lshr(Op1C->getZExtValue()); } // shift1 & 0x00FF @@ -579,7 +573,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, // (X + (Y << C)) Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), YS, V1, Op0BO->getOperand(1)->getName()); - uint32_t Op1Val = COp1->getLimitedValue(TypeBits); + unsigned Op1Val = Op1C->getLimitedValue(TypeBits); APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); Constant *Mask = ConstantInt::get(I.getContext(), Bits); @@ -615,7 +609,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, // (X + (Y << C)) Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), V1, YS, Op0BO->getOperand(0)->getName()); - uint32_t Op1Val = COp1->getLimitedValue(TypeBits); + unsigned Op1Val = Op1C->getLimitedValue(TypeBits); APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); Constant *Mask = ConstantInt::get(I.getContext(), Bits); @@ -686,7 +680,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, } } - if (Instruction *Folded = foldShiftByConstOfShiftByConst(I, COp1, Builder)) + if (Instruction *Folded = foldShiftByConstOfShiftByConst(I, Op1C, Builder)) return Folded; return nullptr;