diff --git a/lib/Transforms/Scalar/InstructionCombining.cpp b/lib/Transforms/Scalar/InstructionCombining.cpp index 3ba239b6a69..931ad8b232e 100644 --- a/lib/Transforms/Scalar/InstructionCombining.cpp +++ b/lib/Transforms/Scalar/InstructionCombining.cpp @@ -3648,68 +3648,101 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantUInt *Op1, } } - // If this is a shift of a shift, see if we can fold the two together. + // Find out if this is a shift of a shift by a constant. + ShiftInst *ShiftOp = 0; if (ShiftInst *Op0SI = dyn_cast(Op0)) - if (ConstantUInt *ShiftAmt1C = - dyn_cast(Op0SI->getOperand(1))) { - unsigned ShiftAmt1 = (unsigned)ShiftAmt1C->getValue(); - unsigned ShiftAmt2 = (unsigned)Op1->getValue(); + ShiftOp = Op0SI; + else if (CastInst *CI = dyn_cast(Op0)) { + // If this is a noop-integer case of a shift instruction, use the shift. + if (CI->getOperand(0)->getType()->isInteger() && + CI->getOperand(0)->getType()->getPrimitiveSizeInBits() == + CI->getType()->getPrimitiveSizeInBits() && + isa(CI->getOperand(0))) { + ShiftOp = cast(CI->getOperand(0)); + } + } + + if (ShiftOp && isa(ShiftOp->getOperand(1))) { + // Find the operands and properties of the input shift. Note that the + // signedness of the input shift may differ from the current shift if there + // is a noop cast between the two. + bool isShiftOfLeftShift = ShiftOp->getOpcode() == Instruction::Shl; + bool isShiftOfSignedShift = ShiftOp->getType()->isSigned(); + bool isShiftOfUnsignedShift = !isSignedShift; + + ConstantUInt *ShiftAmt1C = cast(ShiftOp->getOperand(1)); + + unsigned ShiftAmt1 = (unsigned)ShiftAmt1C->getValue(); + unsigned ShiftAmt2 = (unsigned)Op1->getValue(); + + // Check for (A << c1) << c2 and (A >> c1) >> c2. + if (isLeftShift == isShiftOfLeftShift) { + // Do not fold these shifts if the first one is signed and the second one + // is unsigned and this is a right shift. Further, don't do any folding + // on them. + if (isShiftOfSignedShift && isUnsignedShift && !isLeftShift) + return 0; - // Check for (A << c1) << c2 and (A >> c1) >> c2 - if (I.getOpcode() == Op0SI->getOpcode()) { - unsigned Amt = ShiftAmt1+ShiftAmt2; // Fold into one big shift. - if (Op0->getType()->getPrimitiveSizeInBits() < Amt) - Amt = Op0->getType()->getPrimitiveSizeInBits(); - return new ShiftInst(I.getOpcode(), Op0SI->getOperand(0), - ConstantUInt::get(Type::UByteTy, Amt)); - } + unsigned Amt = ShiftAmt1+ShiftAmt2; // Fold into one big shift. + if (Amt > Op0->getType()->getPrimitiveSizeInBits()) + Amt = Op0->getType()->getPrimitiveSizeInBits(); - // Check for (A << c1) >> c2 or visaversa. If we are dealing with - // signed types, we can only support the (A >> c1) << c2 configuration, - // because it can not turn an arbitrary bit of A into a sign bit. - if (isUnsignedShift || isLeftShift) { - // Calculate bitmask for what gets shifted off the edge... - Constant *C = ConstantIntegral::getAllOnesValue(I.getType()); - if (isLeftShift) - C = ConstantExpr::getShl(C, ShiftAmt1C); - else - C = ConstantExpr::getShr(C, ShiftAmt1C); - - Instruction *Mask = - BinaryOperator::createAnd(Op0SI->getOperand(0), C, - Op0SI->getOperand(0)->getName()+".mask"); - InsertNewInstBefore(Mask, I); - - // Figure out what flavor of shift we should use... - if (ShiftAmt1 == ShiftAmt2) - return ReplaceInstUsesWith(I, Mask); // (A << c) >> c === A & c2 - else if (ShiftAmt1 < ShiftAmt2) { - return new ShiftInst(I.getOpcode(), Mask, - ConstantUInt::get(Type::UByteTy, ShiftAmt2-ShiftAmt1)); - } else { - return new ShiftInst(Op0SI->getOpcode(), Mask, - ConstantUInt::get(Type::UByteTy, ShiftAmt1-ShiftAmt2)); - } + Value *Op = ShiftOp->getOperand(0); + if (isShiftOfSignedShift != isSignedShift) + Op = InsertNewInstBefore(new CastInst(Op, I.getType(), "tmp"), I); + return new ShiftInst(I.getOpcode(), Op, + ConstantUInt::get(Type::UByteTy, Amt)); + } + + // Check for (A << c1) >> c2 or (A >> c1) << c2. If we are dealing with + // signed types, we can only support the (A >> c1) << c2 configuration, + // because it can not turn an arbitrary bit of A into a sign bit. + if (isUnsignedShift || isLeftShift) { + // Calculate bitmask for what gets shifted off the edge. + Constant *C = ConstantIntegral::getAllOnesValue(I.getType()); + if (isLeftShift) + C = ConstantExpr::getShl(C, ShiftAmt1C); + else + C = ConstantExpr::getShr(C, ShiftAmt1C); // must be an unsigned shr. + + Value *Op = ShiftOp->getOperand(0); + if (isShiftOfSignedShift != isSignedShift) + Op = InsertNewInstBefore(new CastInst(Op, I.getType(),Op->getName()),I); + + Instruction *Mask = + BinaryOperator::createAnd(Op, C, Op->getName()+".mask"); + InsertNewInstBefore(Mask, I); + + // Figure out what flavor of shift we should use... + if (ShiftAmt1 == ShiftAmt2) + return ReplaceInstUsesWith(I, Mask); // (A << c) >> c === A & c2 + else if (ShiftAmt1 < ShiftAmt2) { + return new ShiftInst(I.getOpcode(), Mask, + ConstantUInt::get(Type::UByteTy, ShiftAmt2-ShiftAmt1)); } else { - // We can handle signed (X << C1) >> C2 if it's a sign extend. In - // this case, C1 == C2 and C1 is 8, 16, or 32. - if (ShiftAmt1 == ShiftAmt2) { - const Type *SExtType = 0; - switch (ShiftAmt1) { - case 8 : SExtType = Type::SByteTy; break; - case 16: SExtType = Type::ShortTy; break; - case 32: SExtType = Type::IntTy; break; - } - - if (SExtType) { - Instruction *NewTrunc = new CastInst(Op0SI->getOperand(0), - SExtType, "sext"); - InsertNewInstBefore(NewTrunc, I); - return new CastInst(NewTrunc, I.getType()); - } + return new ShiftInst(ShiftOp->getOpcode(), Mask, + ConstantUInt::get(Type::UByteTy, ShiftAmt1-ShiftAmt2)); + } + } else { + // We can handle signed (X << C1) >> C2 if it's a sign extend. In + // this case, C1 == C2 and C1 is 8, 16, or 32. + if (ShiftAmt1 == ShiftAmt2) { + const Type *SExtType = 0; + switch (ShiftAmt1) { + case 8 : SExtType = Type::SByteTy; break; + case 16: SExtType = Type::ShortTy; break; + case 32: SExtType = Type::IntTy; break; + } + + if (SExtType) { + Instruction *NewTrunc = new CastInst(ShiftOp->getOperand(0), + SExtType, "sext"); + InsertNewInstBefore(NewTrunc, I); + return new CastInst(NewTrunc, I.getType()); } } } + } return 0; }