diff --git a/lib/Transforms/Scalar/InstructionCombining.cpp b/lib/Transforms/Scalar/InstructionCombining.cpp index 4be4f30ff67..807653b56d1 100644 --- a/lib/Transforms/Scalar/InstructionCombining.cpp +++ b/lib/Transforms/Scalar/InstructionCombining.cpp @@ -952,21 +952,23 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } Instruction *InstCombiner::visitDiv(BinaryOperator &I) { - if (isa(I.getOperand(0))) // undef / X -> 0 - return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); - if (isa(I.getOperand(1))) - return ReplaceInstUsesWith(I, I.getOperand(1)); // X / undef -> undef + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (ConstantInt *RHS = dyn_cast(I.getOperand(1))) { + if (isa(Op0)) // undef / X -> 0 + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + if (isa(Op1)) + return ReplaceInstUsesWith(I, Op1); // X / undef -> undef + + if (ConstantInt *RHS = dyn_cast(Op1)) { // div X, 1 == X if (RHS->equalsInt(1)) - return ReplaceInstUsesWith(I, I.getOperand(0)); + return ReplaceInstUsesWith(I, Op0); // div X, -1 == -X if (RHS->isAllOnesValue()) - return BinaryOperator::createNeg(I.getOperand(0)); + return BinaryOperator::createNeg(Op0); - if (Instruction *LHS = dyn_cast(I.getOperand(0))) + if (Instruction *LHS = dyn_cast(Op0)) if (LHS->getOpcode() == Instruction::Div) if (ConstantInt *LHSRHS = dyn_cast(LHS->getOperand(1))) { // (X / C1) / C2 -> X / (C1*C2) @@ -979,21 +981,54 @@ Instruction *InstCombiner::visitDiv(BinaryOperator &I) { if (ConstantUInt *C = dyn_cast(RHS)) if (uint64_t Val = C->getValue()) // Don't break X / 0 if (uint64_t C = Log2(Val)) - return new ShiftInst(Instruction::Shr, I.getOperand(0), + return new ShiftInst(Instruction::Shr, Op0, ConstantUInt::get(Type::UByteTy, C)); // -X/C -> X/-C if (RHS->getType()->isSigned()) - if (Value *LHSNeg = dyn_castNegVal(I.getOperand(0))) + if (Value *LHSNeg = dyn_castNegVal(Op0)) return BinaryOperator::createDiv(LHSNeg, ConstantExpr::getNeg(RHS)); - if (isa(I.getOperand(0)) && !RHS->isNullValue()) - if (Instruction *NV = FoldOpIntoPhi(I)) - return NV; + if (!RHS->isNullValue()) { + if (SelectInst *SI = dyn_cast(Op0)) + if (Instruction *R = FoldBinOpIntoSelect(I, SI, this)) + return R; + if (isa(Op0)) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + } } + // If this is 'udiv X, (Cond ? C1, C2)' where C1&C2 are powers of two, + // transform this into: '(Cond ? (udiv X, C1) : (udiv X, C2))'. + if (SelectInst *SI = dyn_cast(Op1)) + if (ConstantUInt *STO = dyn_cast(SI->getOperand(1))) + if (ConstantUInt *SFO = dyn_cast(SI->getOperand(2))) { + if (STO->getValue() == 0) { // Couldn't be this argument. + I.setOperand(1, SFO); + return &I; + } else if (SFO->getValue() == 0) { + I.setOperand(1, STO); + return &I; + } + + if (uint64_t TSA = Log2(STO->getValue())) + if (uint64_t FSA = Log2(SFO->getValue())) { + Constant *TC = ConstantUInt::get(Type::UByteTy, TSA); + Instruction *TSI = new ShiftInst(Instruction::Shr, Op0, + TC, SI->getName()+".t"); + TSI = InsertNewInstBefore(TSI, I); + + Constant *FC = ConstantUInt::get(Type::UByteTy, FSA); + Instruction *FSI = new ShiftInst(Instruction::Shr, Op0, + FC, SI->getName()+".f"); + FSI = InsertNewInstBefore(FSI, I); + return new SelectInst(SI->getOperand(0), TSI, FSI); + } + } + // 0 / X == 0, we don't need to preserve faults! - if (ConstantInt *LHS = dyn_cast(I.getOperand(0))) + if (ConstantInt *LHS = dyn_cast(Op0)) if (LHS->equalsInt(0)) return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); @@ -1002,8 +1037,9 @@ Instruction *InstCombiner::visitDiv(BinaryOperator &I) { Instruction *InstCombiner::visitRem(BinaryOperator &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (I.getType()->isSigned()) - if (Value *RHSNeg = dyn_castNegVal(I.getOperand(1))) + if (Value *RHSNeg = dyn_castNegVal(Op1)) if (!isa(RHSNeg) || cast(RHSNeg)->getValue() > 0) { // X % -Y -> X % Y @@ -1012,12 +1048,12 @@ Instruction *InstCombiner::visitRem(BinaryOperator &I) { return &I; } - if (isa(I.getOperand(0))) // undef % X -> 0 + if (isa(Op0)) // undef % X -> 0 return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); - if (isa(I.getOperand(1))) - return ReplaceInstUsesWith(I, I.getOperand(1)); // X % undef -> undef + if (isa(Op1)) + return ReplaceInstUsesWith(I, Op1); // X % undef -> undef - if (ConstantInt *RHS = dyn_cast(I.getOperand(1))) { + if (ConstantInt *RHS = dyn_cast(Op1)) { if (RHS->equalsInt(1)) // X % 1 == 0 return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); @@ -1026,15 +1062,44 @@ Instruction *InstCombiner::visitRem(BinaryOperator &I) { if (ConstantUInt *C = dyn_cast(RHS)) if (uint64_t Val = C->getValue()) // Don't break X % 0 (divide by zero) if (!(Val & (Val-1))) // Power of 2 - return BinaryOperator::createAnd(I.getOperand(0), - ConstantUInt::get(I.getType(), Val-1)); - if (isa(I.getOperand(0)) && !RHS->isNullValue()) - if (Instruction *NV = FoldOpIntoPhi(I)) - return NV; + return BinaryOperator::createAnd(Op0, + ConstantUInt::get(I.getType(), Val-1)); + + if (!RHS->isNullValue()) { + if (SelectInst *SI = dyn_cast(Op0)) + if (Instruction *R = FoldBinOpIntoSelect(I, SI, this)) + return R; + if (isa(Op0)) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + } } + // If this is 'urem X, (Cond ? C1, C2)' where C1&C2 are powers of two, + // transform this into: '(Cond ? (urem X, C1) : (urem X, C2))'. + if (SelectInst *SI = dyn_cast(Op1)) + if (ConstantUInt *STO = dyn_cast(SI->getOperand(1))) + if (ConstantUInt *SFO = dyn_cast(SI->getOperand(2))) { + if (STO->getValue() == 0) { // Couldn't be this argument. + I.setOperand(1, SFO); + return &I; + } else if (SFO->getValue() == 0) { + I.setOperand(1, STO); + return &I; + } + + if (!(STO->getValue() & (STO->getValue()-1)) && + !(SFO->getValue() & (SFO->getValue()-1))) { + Value *TrueAnd = InsertNewInstBefore(BinaryOperator::createAnd(Op0, + SubOne(STO), SI->getName()+".t"), I); + Value *FalseAnd = InsertNewInstBefore(BinaryOperator::createAnd(Op0, + SubOne(SFO), SI->getName()+".f"), I); + return new SelectInst(SI->getOperand(0), TrueAnd, FalseAnd); + } + } + // 0 % X == 0, we don't need to preserve faults! - if (ConstantInt *LHS = dyn_cast(I.getOperand(0))) + if (ConstantInt *LHS = dyn_cast(Op0)) if (LHS->equalsInt(0)) return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType()));