From e9efecbf470100696355f32ea8b6ab942183ac6c Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Tue, 14 Mar 2006 16:04:29 +0000 Subject: [PATCH] Implement a FIXME, recusively reassociating A*A*B + A*A*C --> A*(A*B+A*C) --> A*(A*(B+C)) This implements Reassociate/mul-factor3.ll git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@26757 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Transforms/Scalar/Reassociate.cpp | 91 +++++++++++++++++++-------- 1 file changed, 65 insertions(+), 26 deletions(-) diff --git a/lib/Transforms/Scalar/Reassociate.cpp b/lib/Transforms/Scalar/Reassociate.cpp index dc44ad593f2..e495ffafbb9 100644 --- a/lib/Transforms/Scalar/Reassociate.cpp +++ b/lib/Transforms/Scalar/Reassociate.cpp @@ -79,8 +79,8 @@ namespace { void BuildRankMap(Function &F); unsigned getRank(Value *V); void ReassociateExpression(BinaryOperator *I); - void RewriteExprTree(BinaryOperator *I, unsigned Idx, - std::vector &Ops); + void RewriteExprTree(BinaryOperator *I, std::vector &Ops, + unsigned Idx = 0); Value *OptimizeExpression(BinaryOperator *I, std::vector &Ops); void LinearizeExprTree(BinaryOperator *I, std::vector &Ops); void LinearizeExpr(BinaryOperator *I); @@ -174,7 +174,7 @@ unsigned Reassociate::getRank(Value *V) { /// isReassociableOp - Return true if V is an instruction of the specified /// opcode and if it only has one use. static BinaryOperator *isReassociableOp(Value *V, unsigned Opcode) { - if (V->hasOneUse() && isa(V) && + if ((V->hasOneUse() || V->use_empty()) && isa(V) && cast(V)->getOpcode() == Opcode) return cast(V); return 0; @@ -234,6 +234,10 @@ void Reassociate::LinearizeExpr(BinaryOperator *I) { /// form of the the expression (((a+b)+c)+d), and collects information about the /// rank of the non-tree operands. /// +/// NOTE: These intentionally destroys the expression tree operands (turning +/// them into undef values) to reduce #uses of the values. This means that the +/// caller MUST use something like RewriteExprTree to put the values back in. +/// void Reassociate::LinearizeExprTree(BinaryOperator *I, std::vector &Ops) { Value *LHS = I->getOperand(0), *RHS = I->getOperand(1); @@ -262,6 +266,10 @@ void Reassociate::LinearizeExprTree(BinaryOperator *I, // such, just remember these operands and their rank. Ops.push_back(ValueEntry(getRank(LHS), LHS)); Ops.push_back(ValueEntry(getRank(RHS), RHS)); + + // Clear the leaves out. + I->setOperand(0, UndefValue::get(I->getType())); + I->setOperand(1, UndefValue::get(I->getType())); return; } else { // Turn X+(Y+Z) -> (Y+Z)+X @@ -293,13 +301,17 @@ void Reassociate::LinearizeExprTree(BinaryOperator *I, // Remember the RHS operand and its rank. Ops.push_back(ValueEntry(getRank(RHS), RHS)); + + // Clear the RHS leaf out. + I->setOperand(1, UndefValue::get(I->getType())); } // RewriteExprTree - Now that the operands for this expression tree are // linearized and optimized, emit them in-order. This function is written to be // tail recursive. -void Reassociate::RewriteExprTree(BinaryOperator *I, unsigned i, - std::vector &Ops) { +void Reassociate::RewriteExprTree(BinaryOperator *I, + std::vector &Ops, + unsigned i) { if (i+2 == Ops.size()) { if (I->getOperand(0) != Ops[i].Op || I->getOperand(1) != Ops[i+1].Op) { @@ -334,7 +346,7 @@ void Reassociate::RewriteExprTree(BinaryOperator *I, unsigned i, // Compactify the tree instructions together with each other to guarantee // that the expression tree is dominated by all of Ops. LHS->moveBefore(I); - RewriteExprTree(LHS, i+1, Ops); + RewriteExprTree(LHS, Ops, i+1); } @@ -474,14 +486,36 @@ Value *Reassociate::RemoveFactorFromExpression(Value *V, Value *Factor) { Factors.erase(Factors.begin()+i); break; } - if (!FoundFactor) return 0; + if (!FoundFactor) { + // Make sure to restore the operands to the expression tree. + RewriteExprTree(BO, Factors); + return 0; + } if (Factors.size() == 1) return Factors[0].Op; - RewriteExprTree(BO, 0, Factors); + RewriteExprTree(BO, Factors); return BO; } +/// FindSingleUseMultiplyFactors - If V is a single-use multiply, recursively +/// add its operands as factors, otherwise add V to the list of factors. +static void FindSingleUseMultiplyFactors(Value *V, + std::vector &Factors) { + BinaryOperator *BO; + if ((!V->hasOneUse() && !V->use_empty()) || + !(BO = dyn_cast(V)) || + BO->getOpcode() != Instruction::Mul) { + Factors.push_back(V); + return; + } + + // Otherwise, add the LHS and RHS to the list of factors. + FindSingleUseMultiplyFactors(BO->getOperand(1), Factors); + FindSingleUseMultiplyFactors(BO->getOperand(0), Factors); +} + + Value *Reassociate::OptimizeExpression(BinaryOperator *I, std::vector &Ops) { @@ -627,26 +661,26 @@ Value *Reassociate::OptimizeExpression(BinaryOperator *I, if (!I->getType()->isFloatingPoint()) { for (unsigned i = 0, e = Ops.size(); i != e; ++i) { if (BinaryOperator *BOp = dyn_cast(Ops[i].Op)) - if (BOp->getOpcode() == Instruction::Mul && BOp->hasOneUse()) { + if (BOp->getOpcode() == Instruction::Mul && BOp->use_empty()) { // Compute all of the factors of this added value. - std::vector Factors; - LinearizeExprTree(BOp, Factors); + std::vector Factors; + FindSingleUseMultiplyFactors(BOp, Factors); assert(Factors.size() > 1 && "Bad linearize!"); // Add one to FactorOccurrences for each unique factor in this op. if (Factors.size() == 2) { - unsigned Occ = ++FactorOccurrences[Factors[0].Op]; - if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[0].Op; } - if (Factors[0].Op != Factors[1].Op) { // Don't double count A*A. - Occ = ++FactorOccurrences[Factors[1].Op]; - if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[1].Op; } + unsigned Occ = ++FactorOccurrences[Factors[0]]; + if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[0]; } + if (Factors[0] != Factors[1]) { // Don't double count A*A. + Occ = ++FactorOccurrences[Factors[1]]; + if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[1]; } } } else { std::set Duplicates; for (unsigned i = 0, e = Factors.size(); i != e; ++i) - if (Duplicates.insert(Factors[i].Op).second) { - unsigned Occ = ++FactorOccurrences[Factors[i].Op]; - if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[i].Op; } + if (Duplicates.insert(Factors[i]).second) { + unsigned Occ = ++FactorOccurrences[Factors[i]]; + if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[i]; } } } } @@ -675,21 +709,26 @@ Value *Reassociate::OptimizeExpression(BinaryOperator *I, // No need for extra uses anymore. delete DummyInst; + unsigned NumAddedValues = NewMulOps.size(); Value *V = EmitAddTreeOfValues(I, NewMulOps); - // FIXME: Must optimize V now, to handle this case: - // A*A*B + A*A*C -> A*(A*B+A*C) -> A*(A*(B+C)) - V = BinaryOperator::createMul(V, MaxOccVal, "tmp", I); + Value *V2 = BinaryOperator::createMul(V, MaxOccVal, "tmp", I); + // Now that we have inserted V and its sole use, optimize it. This allows + // us to handle cases that require multiple factoring steps, such as this: + // A*A*B + A*A*C --> A*(A*B+A*C) --> A*(A*(B+C)) + if (NumAddedValues > 1) + ReassociateExpression(cast(V)); + ++NumFactor; if (Ops.size() == 0) - return V; + return V2; // Add the new value to the list of things being added. - Ops.insert(Ops.begin(), ValueEntry(getRank(V), V)); + Ops.insert(Ops.begin(), ValueEntry(getRank(V2), V2)); // Rewrite the tree so that there is now a use of V. - RewriteExprTree(I, 0, Ops); + RewriteExprTree(I, Ops); return OptimizeExpression(I, Ops); } break; @@ -808,7 +847,7 @@ void Reassociate::ReassociateExpression(BinaryOperator *I) { } else { // Now that we ordered and optimized the expressions, splat them back into // the expression tree, removing any unneeded nodes. - RewriteExprTree(I, 0, Ops); + RewriteExprTree(I, Ops); } }