Factor a bunch of code out into a helper method.

llvm-svn: 61852
This commit is contained in:
Chris Lattner 2009-01-07 07:18:45 +00:00
parent 1a9f7818cd
commit 794f7e91f4

View File

@ -117,6 +117,11 @@ namespace {
void RewriteBitCastUserOfAlloca(Instruction *BCInst, AllocationInst *AI, void RewriteBitCastUserOfAlloca(Instruction *BCInst, AllocationInst *AI,
SmallVector<AllocaInst*, 32> &NewElts); SmallVector<AllocaInst*, 32> &NewElts);
void RewriteMemIntrinUserOfAlloca(MemIntrinsic *MI, Instruction *BCInst,
AllocationInst *AI,
SmallVector<AllocaInst*, 32> &NewElts);
const Type *CanConvertToScalar(Value *V, bool &IsNotTrivial); const Type *CanConvertToScalar(Value *V, bool &IsNotTrivial);
void ConvertToScalar(AllocationInst *AI, const Type *Ty); void ConvertToScalar(AllocationInst *AI, const Type *Ty);
void ConvertUsesToScalar(Value *Ptr, AllocaInst *NewAI, unsigned Offset); void ConvertUsesToScalar(Value *Ptr, AllocaInst *NewAI, unsigned Offset);
@ -593,179 +598,182 @@ void SROA::isSafeUseOfBitCastedAllocation(BitCastInst *BC, AllocationInst *AI,
/// instead. /// instead.
void SROA::RewriteBitCastUserOfAlloca(Instruction *BCInst, AllocationInst *AI, void SROA::RewriteBitCastUserOfAlloca(Instruction *BCInst, AllocationInst *AI,
SmallVector<AllocaInst*, 32> &NewElts) { SmallVector<AllocaInst*, 32> &NewElts) {
Constant *Zero = Constant::getNullValue(Type::Int32Ty);
Value::use_iterator UI = BCInst->use_begin(), UE = BCInst->use_end(); Value::use_iterator UI = BCInst->use_begin(), UE = BCInst->use_end();
while (UI != UE) { while (UI != UE) {
if (BitCastInst *BCU = dyn_cast<BitCastInst>(*UI)) { Instruction *User = cast<Instruction>(*UI++);
if (BitCastInst *BCU = dyn_cast<BitCastInst>(User)) {
RewriteBitCastUserOfAlloca(BCU, AI, NewElts); RewriteBitCastUserOfAlloca(BCU, AI, NewElts);
++UI;
BCU->eraseFromParent(); BCU->eraseFromParent();
continue; continue;
} }
// Otherwise, must be memcpy/memmove/memset of the entire aggregate. Split if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(User)) {
// into one per element. // This must be memcpy/memmove/memset of the entire aggregate.
MemIntrinsic *MI = dyn_cast<MemIntrinsic>(*UI); // Split into one per element.
RewriteMemIntrinUserOfAlloca(MI, BCInst, AI, NewElts);
// If it's not a mem intrinsic, it must be some other user of a gep of the MI->eraseFromParent();
// first pointer. Just leave these alone.
if (!MI) {
++UI;
continue; continue;
} }
// If this is a memcpy/memmove, construct the other pointer as the // If it's not a mem intrinsic, it must be some other user of a gep of the
// appropriate type. // first pointer. Just leave these alone.
Value *OtherPtr = 0; continue;
if (MemCpyInst *MCI = dyn_cast<MemCpyInst>(MI)) { }
if (BCInst == MCI->getRawDest()) }
OtherPtr = MCI->getRawSource();
else { /// RewriteMemIntrinUserOfAlloca - MI is a memcpy/memset/memmove from or to AI.
assert(BCInst == MCI->getRawSource()); /// Rewrite it to copy or set the elements of the scalarized memory.
OtherPtr = MCI->getRawDest(); void SROA::RewriteMemIntrinUserOfAlloca(MemIntrinsic *MI, Instruction *BCInst,
} AllocationInst *AI,
} else if (MemMoveInst *MMI = dyn_cast<MemMoveInst>(MI)) { SmallVector<AllocaInst*, 32> &NewElts) {
if (BCInst == MMI->getRawDest())
OtherPtr = MMI->getRawSource(); // If this is a memcpy/memmove, construct the other pointer as the
else { // appropriate type.
assert(BCInst == MMI->getRawSource()); Value *OtherPtr = 0;
OtherPtr = MMI->getRawDest(); if (MemCpyInst *MCI = dyn_cast<MemCpyInst>(MI)) {
} if (BCInst == MCI->getRawDest())
OtherPtr = MCI->getRawSource();
else {
assert(BCInst == MCI->getRawSource());
OtherPtr = MCI->getRawDest();
} }
} else if (MemMoveInst *MMI = dyn_cast<MemMoveInst>(MI)) {
if (BCInst == MMI->getRawDest())
OtherPtr = MMI->getRawSource();
else {
assert(BCInst == MMI->getRawSource());
OtherPtr = MMI->getRawDest();
}
}
// If there is an other pointer, we want to convert it to the same pointer
// type as AI has, so we can GEP through it safely.
if (OtherPtr) {
// It is likely that OtherPtr is a bitcast, if so, remove it.
if (BitCastInst *BC = dyn_cast<BitCastInst>(OtherPtr))
OtherPtr = BC->getOperand(0);
// All zero GEPs are effectively bitcasts.
if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(OtherPtr))
if (GEP->hasAllZeroIndices())
OtherPtr = GEP->getOperand(0);
// If there is an other pointer, we want to convert it to the same pointer if (ConstantExpr *BCE = dyn_cast<ConstantExpr>(OtherPtr))
// type as AI has, so we can GEP through it. if (BCE->getOpcode() == Instruction::BitCast)
OtherPtr = BCE->getOperand(0);
// If the pointer is not the right type, insert a bitcast to the right
// type.
if (OtherPtr->getType() != AI->getType())
OtherPtr = new BitCastInst(OtherPtr, AI->getType(), OtherPtr->getName(),
MI);
}
// Process each element of the aggregate.
Value *TheFn = MI->getOperand(0);
const Type *BytePtrTy = MI->getRawDest()->getType();
bool SROADest = MI->getRawDest() == BCInst;
Constant *Zero = Constant::getNullValue(Type::Int32Ty);
for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
// If this is a memcpy/memmove, emit a GEP of the other element address.
Value *OtherElt = 0;
if (OtherPtr) { if (OtherPtr) {
// It is likely that OtherPtr is a bitcast, if so, remove it. Value *Idx[2] = { Zero, ConstantInt::get(Type::Int32Ty, i) };
if (BitCastInst *BC = dyn_cast<BitCastInst>(OtherPtr)) OtherElt = GetElementPtrInst::Create(OtherPtr, Idx, Idx + 2,
OtherPtr = BC->getOperand(0);
// All zero GEPs are effectively bitcasts.
if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(OtherPtr))
if (GEP->hasAllZeroIndices())
OtherPtr = GEP->getOperand(0);
if (ConstantExpr *BCE = dyn_cast<ConstantExpr>(OtherPtr))
if (BCE->getOpcode() == Instruction::BitCast)
OtherPtr = BCE->getOperand(0);
// If the pointer is not the right type, insert a bitcast to the right
// type.
if (OtherPtr->getType() != AI->getType())
OtherPtr = new BitCastInst(OtherPtr, AI->getType(), OtherPtr->getName(),
MI);
}
// Process each element of the aggregate.
Value *TheFn = MI->getOperand(0);
const Type *BytePtrTy = MI->getRawDest()->getType();
bool SROADest = MI->getRawDest() == BCInst;
for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
// If this is a memcpy/memmove, emit a GEP of the other element address.
Value *OtherElt = 0;
if (OtherPtr) {
Value *Idx[2] = { Zero, ConstantInt::get(Type::Int32Ty, i) };
OtherElt = GetElementPtrInst::Create(OtherPtr, Idx, Idx + 2,
OtherPtr->getNameStr()+"."+utostr(i), OtherPtr->getNameStr()+"."+utostr(i),
MI); MI);
} }
Value *EltPtr = NewElts[i]; Value *EltPtr = NewElts[i];
const Type *EltTy =cast<PointerType>(EltPtr->getType())->getElementType(); const Type *EltTy =cast<PointerType>(EltPtr->getType())->getElementType();
// If we got down to a scalar, insert a load or store as appropriate. // If we got down to a scalar, insert a load or store as appropriate.
if (EltTy->isSingleValueType()) { if (EltTy->isSingleValueType()) {
if (isa<MemCpyInst>(MI) || isa<MemMoveInst>(MI)) { if (isa<MemCpyInst>(MI) || isa<MemMoveInst>(MI)) {
Value *Elt = new LoadInst(SROADest ? OtherElt : EltPtr, "tmp", Value *Elt = new LoadInst(SROADest ? OtherElt : EltPtr, "tmp",
MI); MI);
new StoreInst(Elt, SROADest ? EltPtr : OtherElt, MI); new StoreInst(Elt, SROADest ? EltPtr : OtherElt, MI);
continue; continue;
} else { }
assert(isa<MemSetInst>(MI)); assert(isa<MemSetInst>(MI));
// If the stored element is zero (common case), just store a null // If the stored element is zero (common case), just store a null
// constant. // constant.
Constant *StoreVal; Constant *StoreVal;
if (ConstantInt *CI = dyn_cast<ConstantInt>(MI->getOperand(2))) { if (ConstantInt *CI = dyn_cast<ConstantInt>(MI->getOperand(2))) {
if (CI->isZero()) { if (CI->isZero()) {
StoreVal = Constant::getNullValue(EltTy); // 0.0, null, 0, <0,0> StoreVal = Constant::getNullValue(EltTy); // 0.0, null, 0, <0,0>
} else { } else {
// If EltTy is a vector type, get the element type. // If EltTy is a vector type, get the element type.
const Type *ValTy = EltTy; const Type *ValTy = EltTy;
if (const VectorType *VTy = dyn_cast<VectorType>(ValTy)) if (const VectorType *VTy = dyn_cast<VectorType>(ValTy))
ValTy = VTy->getElementType(); ValTy = VTy->getElementType();
// Construct an integer with the right value. // Construct an integer with the right value.
unsigned EltSize = TD->getTypeSizeInBits(ValTy); unsigned EltSize = TD->getTypeSizeInBits(ValTy);
APInt OneVal(EltSize, CI->getZExtValue()); APInt OneVal(EltSize, CI->getZExtValue());
APInt TotalVal(OneVal); APInt TotalVal(OneVal);
// Set each byte. // Set each byte.
for (unsigned i = 0; 8*i < EltSize; ++i) { for (unsigned i = 0; 8*i < EltSize; ++i) {
TotalVal = TotalVal.shl(8); TotalVal = TotalVal.shl(8);
TotalVal |= OneVal; TotalVal |= OneVal;
} }
// Convert the integer value to the appropriate type. // Convert the integer value to the appropriate type.
StoreVal = ConstantInt::get(TotalVal); StoreVal = ConstantInt::get(TotalVal);
if (isa<PointerType>(ValTy)) if (isa<PointerType>(ValTy))
StoreVal = ConstantExpr::getIntToPtr(StoreVal, ValTy); StoreVal = ConstantExpr::getIntToPtr(StoreVal, ValTy);
else if (ValTy->isFloatingPoint()) else if (ValTy->isFloatingPoint())
StoreVal = ConstantExpr::getBitCast(StoreVal, ValTy); StoreVal = ConstantExpr::getBitCast(StoreVal, ValTy);
assert(StoreVal->getType() == ValTy && "Type mismatch!"); assert(StoreVal->getType() == ValTy && "Type mismatch!");
// If the requested value was a vector constant, create it. // If the requested value was a vector constant, create it.
if (EltTy != ValTy) { if (EltTy != ValTy) {
unsigned NumElts = cast<VectorType>(ValTy)->getNumElements(); unsigned NumElts = cast<VectorType>(ValTy)->getNumElements();
SmallVector<Constant*, 16> Elts(NumElts, StoreVal); SmallVector<Constant*, 16> Elts(NumElts, StoreVal);
StoreVal = ConstantVector::get(&Elts[0], NumElts); StoreVal = ConstantVector::get(&Elts[0], NumElts);
} }
} }
new StoreInst(StoreVal, EltPtr, MI); new StoreInst(StoreVal, EltPtr, MI);
continue; continue;
} }
// Otherwise, if we're storing a byte variable, use a memset call for // Otherwise, if we're storing a byte variable, use a memset call for
// this element. // this element.
} }
}
// Cast the element pointer to BytePtrTy.
// Cast the element pointer to BytePtrTy. if (EltPtr->getType() != BytePtrTy)
if (EltPtr->getType() != BytePtrTy) EltPtr = new BitCastInst(EltPtr, BytePtrTy, EltPtr->getNameStr(), MI);
EltPtr = new BitCastInst(EltPtr, BytePtrTy, EltPtr->getNameStr(), MI);
// Cast the other pointer (if we have one) to BytePtrTy.
// Cast the other pointer (if we have one) to BytePtrTy. if (OtherElt && OtherElt->getType() != BytePtrTy)
if (OtherElt && OtherElt->getType() != BytePtrTy) OtherElt = new BitCastInst(OtherElt, BytePtrTy,OtherElt->getNameStr(),
OtherElt = new BitCastInst(OtherElt, BytePtrTy,OtherElt->getNameStr(), MI);
MI);
unsigned EltSize = TD->getABITypeSize(EltTy);
unsigned EltSize = TD->getABITypeSize(EltTy);
// Finally, insert the meminst for this element.
// Finally, insert the meminst for this element. if (isa<MemCpyInst>(MI) || isa<MemMoveInst>(MI)) {
if (isa<MemCpyInst>(MI) || isa<MemMoveInst>(MI)) { Value *Ops[] = {
Value *Ops[] = { SROADest ? EltPtr : OtherElt, // Dest ptr
SROADest ? EltPtr : OtherElt, // Dest ptr SROADest ? OtherElt : EltPtr, // Src ptr
SROADest ? OtherElt : EltPtr, // Src ptr ConstantInt::get(MI->getOperand(3)->getType(), EltSize), // Size
ConstantInt::get(MI->getOperand(3)->getType(), EltSize), // Size Zero // Align
Zero // Align };
}; CallInst::Create(TheFn, Ops, Ops + 4, "", MI);
CallInst::Create(TheFn, Ops, Ops + 4, "", MI); } else {
} else { assert(isa<MemSetInst>(MI));
assert(isa<MemSetInst>(MI)); Value *Ops[] = {
Value *Ops[] = { EltPtr, MI->getOperand(2), // Dest, Value,
EltPtr, MI->getOperand(2), // Dest, Value, ConstantInt::get(MI->getOperand(3)->getType(), EltSize), // Size
ConstantInt::get(MI->getOperand(3)->getType(), EltSize), // Size Zero // Align
Zero // Align };
}; CallInst::Create(TheFn, Ops, Ops + 4, "", MI);
CallInst::Create(TheFn, Ops, Ops + 4, "", MI);
}
} }
// Finally, MI is now dead, as we've modified its actions to occur on all of
// the elements of the aggregate.
++UI;
MI->eraseFromParent();
} }
} }
/// HasPadding - Return true if the specified type has any structure or /// HasPadding - Return true if the specified type has any structure or
/// alignment padding, false otherwise. /// alignment padding, false otherwise.