[MLIR][Presburger] PWMAFunction::valueAt: support local ids

Add a baseline implementation of support for local ids for `PWMAFunction::valueAt`. This can be made more efficient later if needed by handling locals with known div representations separately.

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D122144
This commit is contained in:
Arjun P 2022-03-23 23:11:28 +00:00
parent 5630143af3
commit 4418669f1e
6 changed files with 72 additions and 21 deletions

View File

@ -285,11 +285,12 @@ public:
Optional<uint64_t> computeVolume() const;
/// Returns true if the given point satisfies the constraints, or false
/// otherwise.
///
/// Note: currently, if the relation contains local ids, the values of
/// the local ids must also be provided.
/// otherwise. Takes the values of all ids including locals.
bool containsPoint(ArrayRef<int64_t> point) const;
/// Given the values of non-local ids, return a satisfying assignment to the
/// local if one exists, or an empty optional otherwise.
Optional<SmallVector<int64_t, 8>>
containsPointNoLocal(ArrayRef<int64_t> point) const;
/// Find equality and pairs of inequality contraints identified by their
/// position indices, using which an explicit representation for each local

View File

@ -41,8 +41,7 @@ namespace presburger {
/// each id, and an extra column at the end for the constant term.
///
/// Checking equality of two such functions is supported, as well as finding the
/// value of the function at a specified point. Note that local ids in the
/// domain are not yet supported for finding the value at a point.
/// value of the function at a specified point.
class MultiAffineFunction : protected IntegerPolyhedron {
public:
/// We use protected inheritance to avoid inheriting the whole public
@ -114,8 +113,6 @@ public:
/// Get the value of the function at the specified point. If the point lies
/// outside the domain, an empty optional is returned.
///
/// Note: domains with local ids are not yet supported, and will assert-fail.
Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
void print(raw_ostream &os) const;
@ -145,8 +142,7 @@ private:
/// symbolic ids.
///
/// Support is provided to compare equality of two such functions as well as
/// finding the value of the function at a point. Note that local ids in the
/// piece are not supported for the latter.
/// finding the value of the function at a point.
class PWMAFunction : public PresburgerSpace {
public:
PWMAFunction(unsigned numDims, unsigned numSymbols, unsigned numOutputs)
@ -170,8 +166,6 @@ public:
/// Return the value at the specified point and an empty optional if the
/// point does not lie in the domain.
///
/// Note: domains with local ids are not yet supported, and will assert-fail.
Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
/// Return whether `this` and `other` are equal as PWMAFunctions, i.e. whether

View File

@ -784,6 +784,25 @@ bool IntegerRelation::containsPoint(ArrayRef<int64_t> point) const {
return true;
}
/// Just substitute the values given and check if an integer sample exists for
/// the local ids.
///
/// TODO: this could be made more efficient by handling divisions separately.
/// Instead of finding an integer sample over all the locals, we can first
/// compute the values of the locals that have division representations and
/// only use the integer emptiness check for the locals that don't have this.
/// Handling this correctly requires ordering the divs, though.
Optional<SmallVector<int64_t, 8>>
IntegerRelation::containsPointNoLocal(ArrayRef<int64_t> point) const {
assert(point.size() == getNumIds() - getNumLocalIds() &&
"Point should contain all ids except locals!");
assert(getIdKindOffset(IdKind::Local) == getNumIds() - getNumLocalIds() &&
"This function depends on locals being stored last!");
IntegerRelation copy = *this;
copy.setAndEliminate(0, point);
return copy.findIntegerSample();
}
void IntegerRelation::getLocalReprs(std::vector<MaybeLocalRepr> &repr) const {
std::vector<SmallVector<int64_t, 8>> dividends(getNumLocalIds());
SmallVector<unsigned, 4> denominators(getNumLocalIds());

View File

@ -36,19 +36,26 @@ PresburgerSet PWMAFunction::getDomain() const {
Optional<SmallVector<int64_t, 8>>
MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const {
assert(getNumLocalIds() == 0 && "Local ids are not yet supported!");
assert(point.size() == getNumIds() && "Point has incorrect dimensionality!");
assert(point.size() == getNumDimAndSymbolIds() &&
"Point has incorrect dimensionality!");
if (!getDomain().containsPoint(point))
Optional<SmallVector<int64_t, 8>> maybeLocalValues =
getDomain().containsPointNoLocal(point);
if (!maybeLocalValues)
return {};
// The point lies in the domain, so we need to compute the output value.
SmallVector<int64_t, 8> pointHomogenous{llvm::to_vector(point)};
// The given point didn't include the values of locals which the output is a
// function of; we have computed one possible set of values and use them
// here. The function is not allowed to have local ids that take more than
// one possible value.
pointHomogenous.append(*maybeLocalValues);
// The matrix `output` has an affine expression in the ith row, corresponding
// to the expression for the ith value in the output vector. The last column
// of the matrix contains the constant term. Let v be the input point with
// a 1 appended at the end. We can see that output * v gives the desired
// output vector.
SmallVector<int64_t, 8> pointHomogenous{llvm::to_vector(point)};
pointHomogenous.push_back(1);
SmallVector<int64_t, 8> result =
output.postMultiplyWithColumn(pointHomogenous);

View File

@ -1187,3 +1187,18 @@ TEST(IntegerPolyhedronTest, computeVolume) {
parsePoly("(x, y) : (2*x - y >= 0, y - 3*x >= 0)"),
/*trueVolume=*/{}, /*resultBound=*/{});
}
TEST(IntegerPolyhedronTest, containsPointNoLocal) {
IntegerPolyhedron poly1 = parsePoly("(x) : ((x floordiv 2) - x == 0)");
EXPECT_TRUE(poly1.containsPointNoLocal({0}));
EXPECT_FALSE(poly1.containsPointNoLocal({1}));
IntegerPolyhedron poly2 = parsePoly(
"(x) : (x - 2*(x floordiv 2) == 0, x - 4*(x floordiv 4) - 2 == 0)");
EXPECT_TRUE(poly2.containsPointNoLocal({6}));
EXPECT_FALSE(poly2.containsPointNoLocal({4}));
IntegerPolyhedron poly3 = parsePoly("(x, y) : (2*x - y >= 0, y - 3*x >= 0)");
EXPECT_TRUE(poly3.containsPointNoLocal({0, 0}));
EXPECT_FALSE(poly3.containsPointNoLocal({1, 0}));
}

View File

@ -129,16 +129,31 @@ TEST(PWAFunctionTest, isEqual) {
}
TEST(PWMAFunction, valueAt) {
PWMAFunction nonNegPWAF = parsePWMAF(
PWMAFunction nonNegPWMAF = parsePWMAF(
/*numInputs=*/2, /*numOutputs=*/2,
{
{"(x, y) : (x >= 0)", {{1, 2, 3}, {3, 4, 5}}}, // (x, y).
{"(x, y) : (y >= 0, -x - 1 >= 0)", {{-1, 2, 3}, {-3, 4, 5}}} // (x, y)
});
EXPECT_THAT(*nonNegPWAF.valueAt({2, 3}), ElementsAre(11, 23));
EXPECT_THAT(*nonNegPWAF.valueAt({-2, 3}), ElementsAre(11, 23));
EXPECT_THAT(*nonNegPWAF.valueAt({2, -3}), ElementsAre(-1, -1));
EXPECT_FALSE(nonNegPWAF.valueAt({-2, -3}).hasValue());
EXPECT_THAT(*nonNegPWMAF.valueAt({2, 3}), ElementsAre(11, 23));
EXPECT_THAT(*nonNegPWMAF.valueAt({-2, 3}), ElementsAre(11, 23));
EXPECT_THAT(*nonNegPWMAF.valueAt({2, -3}), ElementsAre(-1, -1));
EXPECT_FALSE(nonNegPWMAF.valueAt({-2, -3}).hasValue());
PWMAFunction divPWMAF = parsePWMAF(
/*numInputs=*/2, /*numOutputs=*/2,
{
{"(x, y) : (x >= 0, x - 2*(x floordiv 2) == 0)",
{{0, 2, 1, 3}, {0, 4, 3, 5}}}, // (x, y).
{"(x, y) : (y >= 0, -x - 1 >= 0)", {{-1, 2, 3}, {-3, 4, 5}}} // (x, y)
});
EXPECT_THAT(*divPWMAF.valueAt({4, 3}), ElementsAre(11, 23));
EXPECT_THAT(*divPWMAF.valueAt({4, -3}), ElementsAre(-1, -1));
EXPECT_FALSE(divPWMAF.valueAt({3, 3}).hasValue());
EXPECT_FALSE(divPWMAF.valueAt({3, -3}).hasValue());
EXPECT_THAT(*divPWMAF.valueAt({-2, 3}), ElementsAre(11, 23));
EXPECT_FALSE(divPWMAF.valueAt({-2, -3}).hasValue());
}
TEST(PWMAFunction, removeIdRangeRegressionTest) {