[IR] Implement Constant::isNegativeZeroValue/isZeroValue/isAllOnesValue/isOneValue/isMinSignedValue for ConstantDataVector without going through getElementAsConstant

Summary:
Currently these methods call ConstantDataVector::getSplatValue which uses getElementsAsConstant to create a Constant object representing the element value. This method incurs a map lookup to see if we already have created such a Constant before and if not allocates a new Constant object.

This patch changes these methods to use getElementAsAPFloat and getElementAsInteger so we can just examine the data values directly.

Reviewers: spatel, pcc, dexonsmith, bogner, craig.topper

Reviewed By: craig.topper

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D35040

llvm-svn: 308112
This commit is contained in:
Craig Topper 2017-07-15 22:06:19 +00:00
parent d918d5b36b
commit 0b4b4e388d
2 changed files with 75 additions and 19 deletions

View File

@ -598,6 +598,10 @@ public:
/// specified element in the low bits of a uint64_t. /// specified element in the low bits of a uint64_t.
uint64_t getElementAsInteger(unsigned i) const; uint64_t getElementAsInteger(unsigned i) const;
/// If this is a sequential container of integers (of any size), return the
/// specified element as an APInt.
APInt getElementAsAPInt(unsigned i) const;
/// If this is a sequential container of floating point type, return the /// If this is a sequential container of floating point type, return the
/// specified element as an APFloat. /// specified element as an APFloat.
APFloat getElementAsAPFloat(unsigned i) const; APFloat getElementAsAPFloat(unsigned i) const;
@ -761,6 +765,10 @@ public:
/// i32/i64/float/double) and must be a ConstantFP or ConstantInt. /// i32/i64/float/double) and must be a ConstantFP or ConstantInt.
static Constant *getSplat(unsigned NumElts, Constant *Elt); static Constant *getSplat(unsigned NumElts, Constant *Elt);
/// Returns true if this is a splat constant, meaning that all elements have
/// the same value.
bool isSplat() const;
/// If this is a splat constant, meaning that all of the elements have the /// If this is a splat constant, meaning that all of the elements have the
/// same value, return that value. Otherwise return NULL. /// same value, return that value. Otherwise return NULL.
Constant *getSplatValue() const; Constant *getSplatValue() const;

View File

@ -44,8 +44,8 @@ bool Constant::isNegativeZeroValue() const {
// Equivalent for a vector of -0.0's. // Equivalent for a vector of -0.0's.
if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this)) if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this))
if (ConstantFP *SplatCFP = dyn_cast_or_null<ConstantFP>(CV->getSplatValue())) if (CV->getElementType()->isFloatingPointTy() && CV->isSplat())
if (SplatCFP && SplatCFP->isZero() && SplatCFP->isNegative()) if (CV->getElementAsAPFloat(0).isNegZero())
return true; return true;
if (const ConstantVector *CV = dyn_cast<ConstantVector>(this)) if (const ConstantVector *CV = dyn_cast<ConstantVector>(this))
@ -70,8 +70,8 @@ bool Constant::isZeroValue() const {
// Equivalent for a vector of -0.0's. // Equivalent for a vector of -0.0's.
if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this)) if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this))
if (ConstantFP *SplatCFP = dyn_cast_or_null<ConstantFP>(CV->getSplatValue())) if (CV->getElementType()->isFloatingPointTy() && CV->isSplat())
if (SplatCFP && SplatCFP->isZero()) if (CV->getElementAsAPFloat(0).isZero())
return true; return true;
if (const ConstantVector *CV = dyn_cast<ConstantVector>(this)) if (const ConstantVector *CV = dyn_cast<ConstantVector>(this))
@ -113,9 +113,13 @@ bool Constant::isAllOnesValue() const {
return Splat->isAllOnesValue(); return Splat->isAllOnesValue();
// Check for constant vectors which are splats of -1 values. // Check for constant vectors which are splats of -1 values.
if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this)) if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this)) {
if (Constant *Splat = CV->getSplatValue()) if (CV->isSplat()) {
return Splat->isAllOnesValue(); if (CV->getElementType()->isFloatingPointTy())
return CV->getElementAsAPFloat(0).bitcastToAPInt().isAllOnesValue();
return CV->getElementAsAPInt(0).isAllOnesValue();
}
}
return false; return false;
} }
@ -135,9 +139,13 @@ bool Constant::isOneValue() const {
return Splat->isOneValue(); return Splat->isOneValue();
// Check for constant vectors which are splats of 1 values. // Check for constant vectors which are splats of 1 values.
if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this)) if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this)) {
if (Constant *Splat = CV->getSplatValue()) if (CV->isSplat()) {
return Splat->isOneValue(); if (CV->getElementType()->isFloatingPointTy())
return CV->getElementAsAPFloat(0).bitcastToAPInt().isOneValue();
return CV->getElementAsAPInt(0).isOneValue();
}
}
return false; return false;
} }
@ -157,9 +165,13 @@ bool Constant::isMinSignedValue() const {
return Splat->isMinSignedValue(); return Splat->isMinSignedValue();
// Check for constant vectors which are splats of INT_MIN values. // Check for constant vectors which are splats of INT_MIN values.
if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this)) if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this)) {
if (Constant *Splat = CV->getSplatValue()) if (CV->isSplat()) {
return Splat->isMinSignedValue(); if (CV->getElementType()->isFloatingPointTy())
return CV->getElementAsAPFloat(0).bitcastToAPInt().isMinSignedValue();
return CV->getElementAsAPInt(0).isMinSignedValue();
}
}
return false; return false;
} }
@ -179,9 +191,13 @@ bool Constant::isNotMinSignedValue() const {
return Splat->isNotMinSignedValue(); return Splat->isNotMinSignedValue();
// Check for constant vectors which are splats of INT_MIN values. // Check for constant vectors which are splats of INT_MIN values.
if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this)) if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this)) {
if (Constant *Splat = CV->getSplatValue()) if (CV->isSplat()) {
return Splat->isNotMinSignedValue(); if (CV->getElementType()->isFloatingPointTy())
return !CV->getElementAsAPFloat(0).bitcastToAPInt().isMinSignedValue();
return !CV->getElementAsAPInt(0).isMinSignedValue();
}
}
// It *may* contain INT_MIN, we can't tell. // It *may* contain INT_MIN, we can't tell.
return false; return false;
@ -2565,6 +2581,34 @@ uint64_t ConstantDataSequential::getElementAsInteger(unsigned Elt) const {
} }
} }
APInt ConstantDataSequential::getElementAsAPInt(unsigned Elt) const {
assert(isa<IntegerType>(getElementType()) &&
"Accessor can only be used when element is an integer");
const char *EltPtr = getElementPointer(Elt);
// The data is stored in host byte order, make sure to cast back to the right
// type to load with the right endianness.
switch (getElementType()->getIntegerBitWidth()) {
default: llvm_unreachable("Invalid bitwidth for CDS");
case 8: {
auto EltVal = *reinterpret_cast<const uint8_t *>(EltPtr);
return APInt(8, EltVal);
}
case 16: {
auto EltVal = *reinterpret_cast<const uint16_t *>(EltPtr);
return APInt(16, EltVal);
}
case 32: {
auto EltVal = *reinterpret_cast<const uint32_t *>(EltPtr);
return APInt(32, EltVal);
}
case 64: {
auto EltVal = *reinterpret_cast<const uint64_t *>(EltPtr);
return APInt(64, EltVal);
}
}
}
APFloat ConstantDataSequential::getElementAsAPFloat(unsigned Elt) const { APFloat ConstantDataSequential::getElementAsAPFloat(unsigned Elt) const {
const char *EltPtr = getElementPointer(Elt); const char *EltPtr = getElementPointer(Elt);
@ -2623,17 +2667,21 @@ bool ConstantDataSequential::isCString() const {
return Str.drop_back().find(0) == StringRef::npos; return Str.drop_back().find(0) == StringRef::npos;
} }
Constant *ConstantDataVector::getSplatValue() const { bool ConstantDataVector::isSplat() const {
const char *Base = getRawDataValues().data(); const char *Base = getRawDataValues().data();
// Compare elements 1+ to the 0'th element. // Compare elements 1+ to the 0'th element.
unsigned EltSize = getElementByteSize(); unsigned EltSize = getElementByteSize();
for (unsigned i = 1, e = getNumElements(); i != e; ++i) for (unsigned i = 1, e = getNumElements(); i != e; ++i)
if (memcmp(Base, Base+i*EltSize, EltSize)) if (memcmp(Base, Base+i*EltSize, EltSize))
return nullptr; return false;
return true;
}
Constant *ConstantDataVector::getSplatValue() const {
// If they're all the same, return the 0th one as a representative. // If they're all the same, return the 0th one as a representative.
return getElementAsConstant(0); return isSplat() ? getElementAsConstant(0) : nullptr;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//