diff --git a/include/llvm/Constants.h b/include/llvm/Constants.h index 229ef9c76fd..738c90cee74 100644 --- a/include/llvm/Constants.h +++ b/include/llvm/Constants.h @@ -629,10 +629,12 @@ protected: Constant *C2); static Constant *getSelectTy(const Type *Ty, Constant *C1, Constant *C2, Constant *C3); + template static Constant *getGetElementPtrTy(const Type *Ty, Constant *C, - Value* const *Idxs, unsigned NumIdxs); + IndexTy const *Idxs, unsigned NumIdxs); + template static Constant *getInBoundsGetElementPtrTy(const Type *Ty, Constant *C, - Value* const *Idxs, + IndexTy const *Idxs, unsigned NumIdxs); static Constant *getExtractElementTy(const Type *Ty, Constant *Val, Constant *Idx); @@ -645,6 +647,14 @@ protected: static Constant *getInsertValueTy(const Type *Ty, Constant *Agg, Constant *Val, const unsigned *Idxs, unsigned NumIdxs); + template + static Constant *getGetElementPtrImpl(Constant *C, + IndexTy const *IdxList, + unsigned NumIdx); + template + static Constant *getInBoundsGetElementPtrImpl(Constant *C, + IndexTy const *IdxList, + unsigned NumIdx); public: // Static methods to construct a ConstantExpr of different kinds. Note that diff --git a/include/llvm/Instructions.h b/include/llvm/Instructions.h index 941227f81f2..c79fda0deb4 100644 --- a/include/llvm/Instructions.h +++ b/include/llvm/Instructions.h @@ -457,6 +457,9 @@ public: static const Type *getIndexedType(const Type *Ptr, Value* const *Idx, unsigned NumIdx); + static const Type *getIndexedType(const Type *Ptr, + Constant* const *Idx, unsigned NumIdx); + static const Type *getIndexedType(const Type *Ptr, uint64_t const *Idx, unsigned NumIdx); diff --git a/lib/VMCore/ConstantFold.cpp b/lib/VMCore/ConstantFold.cpp index 3dc78470ff8..a21b4a28e47 100644 --- a/lib/VMCore/ConstantFold.cpp +++ b/lib/VMCore/ConstantFold.cpp @@ -2067,53 +2067,52 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, /// isInBoundsIndices - Test whether the given sequence of *normalized* indices /// is "inbounds". -static bool isInBoundsIndices(Constant *const *Idxs, size_t NumIdx) { +template +static bool isInBoundsIndices(IndexTy const *Idxs, size_t NumIdx) { // No indices means nothing that could be out of bounds. if (NumIdx == 0) return true; // If the first index is zero, it's in bounds. - if (Idxs[0]->isNullValue()) return true; + if (cast(Idxs[0])->isNullValue()) return true; // If the first index is one and all the rest are zero, it's in bounds, // by the one-past-the-end rule. if (!cast(Idxs[0])->isOne()) return false; for (unsigned i = 1, e = NumIdx; i != e; ++i) - if (!Idxs[i]->isNullValue()) + if (!cast(Idxs[i])->isNullValue()) return false; return true; } -Constant *llvm::ConstantFoldGetElementPtr(Constant *C, - bool inBounds, - Constant* const *Idxs, - unsigned NumIdx) { +template +static Constant *ConstantFoldGetElementPtrImpl(Constant *C, + bool inBounds, + IndexTy const *Idxs, + unsigned NumIdx) { + Constant *Idx0 = cast(Idxs[0]); if (NumIdx == 0 || - (NumIdx == 1 && Idxs[0]->isNullValue())) + (NumIdx == 1 && Idx0->isNullValue())) return C; if (isa(C)) { const PointerType *Ptr = cast(C->getType()); - const Type *Ty = GetElementPtrInst::getIndexedType(Ptr, - (Value **)Idxs, - (Value **)Idxs+NumIdx); + const Type *Ty = GetElementPtrInst::getIndexedType(Ptr, Idxs, Idxs+NumIdx); assert(Ty != 0 && "Invalid indices for GEP!"); return UndefValue::get(PointerType::get(Ty, Ptr->getAddressSpace())); } - Constant *Idx0 = Idxs[0]; if (C->isNullValue()) { bool isNull = true; for (unsigned i = 0, e = NumIdx; i != e; ++i) - if (!Idxs[i]->isNullValue()) { + if (!cast(Idxs[i])->isNullValue()) { isNull = false; break; } if (isNull) { const PointerType *Ptr = cast(C->getType()); - const Type *Ty = GetElementPtrInst::getIndexedType(Ptr, - (Value**)Idxs, - (Value**)Idxs+NumIdx); + const Type *Ty = GetElementPtrInst::getIndexedType(Ptr, Idxs, + Idxs+NumIdx); assert(Ty != 0 && "Invalid indices for GEP!"); return ConstantPointerNull::get( PointerType::get(Ty,Ptr->getAddressSpace())); @@ -2208,7 +2207,7 @@ Constant *llvm::ConstantFoldGetElementPtr(Constant *C, ATy->getNumElements()); NewIdxs[i] = ConstantExpr::getSRem(CI, Factor); - Constant *PrevIdx = Idxs[i-1]; + Constant *PrevIdx = cast(Idxs[i-1]); Constant *Div = ConstantExpr::getSDiv(CI, Factor); // Before adding, extend both operands to i64 to avoid @@ -2236,7 +2235,7 @@ Constant *llvm::ConstantFoldGetElementPtr(Constant *C, // If we did any factoring, start over with the adjusted indices. if (!NewIdxs.empty()) { for (unsigned i = 0; i != NumIdx; ++i) - if (!NewIdxs[i]) NewIdxs[i] = Idxs[i]; + if (!NewIdxs[i]) NewIdxs[i] = cast(Idxs[i]); return inBounds ? ConstantExpr::getInBoundsGetElementPtr(C, NewIdxs.data(), NewIdxs.size()) : @@ -2251,3 +2250,17 @@ Constant *llvm::ConstantFoldGetElementPtr(Constant *C, return 0; } + +Constant *llvm::ConstantFoldGetElementPtr(Constant *C, + bool inBounds, + Constant* const *Idxs, + unsigned NumIdx) { + return ConstantFoldGetElementPtrImpl(C, inBounds, Idxs, NumIdx); +} + +Constant *llvm::ConstantFoldGetElementPtr(Constant *C, + bool inBounds, + Value* const *Idxs, + unsigned NumIdx) { + return ConstantFoldGetElementPtrImpl(C, inBounds, Idxs, NumIdx); +} diff --git a/lib/VMCore/ConstantFold.h b/lib/VMCore/ConstantFold.h index d2dbbdd74c2..0ecd7b49a48 100644 --- a/lib/VMCore/ConstantFold.h +++ b/lib/VMCore/ConstantFold.h @@ -49,6 +49,8 @@ namespace llvm { Constant *C1, Constant *C2); Constant *ConstantFoldGetElementPtr(Constant *C, bool inBounds, Constant* const *Idxs, unsigned NumIdx); + Constant *ConstantFoldGetElementPtr(Constant *C, bool inBounds, + Value* const *Idxs, unsigned NumIdx); } // End llvm namespace #endif diff --git a/lib/VMCore/Constants.cpp b/lib/VMCore/Constants.cpp index 29fc0dbdf8d..46daa61d998 100644 --- a/lib/VMCore/Constants.cpp +++ b/lib/VMCore/Constants.cpp @@ -1546,8 +1546,9 @@ Constant *ConstantExpr::getSelectTy(const Type *ReqTy, Constant *C, return pImpl->ExprConstants.getOrCreate(ReqTy, Key); } +template Constant *ConstantExpr::getGetElementPtrTy(const Type *ReqTy, Constant *C, - Value* const *Idxs, + IndexTy const *Idxs, unsigned NumIdx) { assert(GetElementPtrInst::getIndexedType(C->getType(), Idxs, Idxs+NumIdx) == @@ -1555,7 +1556,7 @@ Constant *ConstantExpr::getGetElementPtrTy(const Type *ReqTy, Constant *C, "GEP indices invalid!"); if (Constant *FC = ConstantFoldGetElementPtr(C, /*inBounds=*/false, - (Constant**)Idxs, NumIdx)) + Idxs, NumIdx)) return FC; // Fold a few common cases... assert(C->getType()->isPointerTy() && @@ -1572,9 +1573,10 @@ Constant *ConstantExpr::getGetElementPtrTy(const Type *ReqTy, Constant *C, return pImpl->ExprConstants.getOrCreate(ReqTy, Key); } +template Constant *ConstantExpr::getInBoundsGetElementPtrTy(const Type *ReqTy, Constant *C, - Value *const *Idxs, + IndexTy const *Idxs, unsigned NumIdx) { assert(GetElementPtrInst::getIndexedType(C->getType(), Idxs, Idxs+NumIdx) == @@ -1582,7 +1584,7 @@ Constant *ConstantExpr::getInBoundsGetElementPtrTy(const Type *ReqTy, "GEP indices invalid!"); if (Constant *FC = ConstantFoldGetElementPtr(C, /*inBounds=*/true, - (Constant**)Idxs, NumIdx)) + Idxs, NumIdx)) return FC; // Fold a few common cases... assert(C->getType()->isPointerTy() && @@ -1600,8 +1602,9 @@ Constant *ConstantExpr::getInBoundsGetElementPtrTy(const Type *ReqTy, return pImpl->ExprConstants.getOrCreate(ReqTy, Key); } -Constant *ConstantExpr::getGetElementPtr(Constant *C, Value* const *Idxs, - unsigned NumIdx) { +template +Constant *ConstantExpr::getGetElementPtrImpl(Constant *C, IndexTy const *Idxs, + unsigned NumIdx) { // Get the result type of the getelementptr! const Type *Ty = GetElementPtrInst::getIndexedType(C->getType(), Idxs, Idxs+NumIdx); @@ -1610,9 +1613,10 @@ Constant *ConstantExpr::getGetElementPtr(Constant *C, Value* const *Idxs, return getGetElementPtrTy(PointerType::get(Ty, As), C, Idxs, NumIdx); } -Constant *ConstantExpr::getInBoundsGetElementPtr(Constant *C, - Value* const *Idxs, - unsigned NumIdx) { +template +Constant *ConstantExpr::getInBoundsGetElementPtrImpl(Constant *C, + IndexTy const *Idxs, + unsigned NumIdx) { // Get the result type of the getelementptr! const Type *Ty = GetElementPtrInst::getIndexedType(C->getType(), Idxs, Idxs+NumIdx); @@ -1621,15 +1625,26 @@ Constant *ConstantExpr::getInBoundsGetElementPtr(Constant *C, return getInBoundsGetElementPtrTy(PointerType::get(Ty, As), C, Idxs, NumIdx); } -Constant *ConstantExpr::getGetElementPtr(Constant *C, Constant* const *Idxs, +Constant *ConstantExpr::getGetElementPtr(Constant *C, Value* const *Idxs, unsigned NumIdx) { - return getGetElementPtr(C, (Value* const *)Idxs, NumIdx); + return getGetElementPtrImpl(C, Idxs, NumIdx); +} + +Constant *ConstantExpr::getGetElementPtr(Constant *C, Constant *const *Idxs, + unsigned NumIdx) { + return getGetElementPtrImpl(C, Idxs, NumIdx); } Constant *ConstantExpr::getInBoundsGetElementPtr(Constant *C, - Constant* const *Idxs, + Value* const *Idxs, unsigned NumIdx) { - return getInBoundsGetElementPtr(C, (Value* const *)Idxs, NumIdx); + return getInBoundsGetElementPtrImpl(C, Idxs, NumIdx); +} + +Constant *ConstantExpr::getInBoundsGetElementPtr(Constant *C, + Constant *const *Idxs, + unsigned NumIdx) { + return getInBoundsGetElementPtrImpl(C, Idxs, NumIdx); } Constant * diff --git a/lib/VMCore/Instructions.cpp b/lib/VMCore/Instructions.cpp index 33acd0d040a..909dab9a651 100644 --- a/lib/VMCore/Instructions.cpp +++ b/lib/VMCore/Instructions.cpp @@ -1173,6 +1173,12 @@ const Type* GetElementPtrInst::getIndexedType(const Type *Ptr, return getIndexedTypeInternal(Ptr, Idxs, NumIdx); } +const Type* GetElementPtrInst::getIndexedType(const Type *Ptr, + Constant* const *Idxs, + unsigned NumIdx) { + return getIndexedTypeInternal(Ptr, Idxs, NumIdx); +} + const Type* GetElementPtrInst::getIndexedType(const Type *Ptr, uint64_t const *Idxs, unsigned NumIdx) {