[MLIR][Presburger] introduce SetCoalescer

This patch refactors the current coalesce implementation. It introduces
the `SetCoalescer`, a class in which all coalescing functionality lives.
The main advantage over the old design is the fact that the vectors of
constraints do not have to be passed around, but are implemented as
private fields of the SetCoalescer. This will become especially
important once more inequality types are introduced.

Reviewed By: arjunp

Differential Revision: https://reviews.llvm.org/D121364
This commit is contained in:
Michel Weber 2022-03-18 07:58:46 +00:00 committed by Arjun P
parent 26c95ae389
commit ae3e3c6362
2 changed files with 288 additions and 223 deletions

View File

@ -18,6 +18,11 @@
namespace mlir {
namespace presburger {
/// The SetCoalescer class contains all functionality concerning the coalesce
/// heuristic. It is built from a `PresburgerRelation` and has the `coalesce()`
/// function as its main API.
class SetCoalescer;
/// A PresburgerRelation represents a union of IntegerRelations that live in
/// the same PresburgerSpace with support for union, intersection, subtraction,
/// and complement operations, as well as sampling.
@ -120,6 +125,8 @@ protected:
/// The list of disjuncts that this set is the union of.
SmallVector<IntegerRelation, 2> integerRelations;
friend class SetCoalescer;
};
class PresburgerSet : public PresburgerRelation {

View File

@ -393,242 +393,99 @@ Optional<uint64_t> PresburgerRelation::computeVolume() const {
return result;
}
/// Given an IntegerRelation `p` and one of its inequalities `ineq`, check
/// that all inequalities of `cuttingIneqs` are redundant for the facet of `p`
/// where `ineq` holds as an equality. `simp` must be the Simplex constructed
/// from `p`.
static bool isFacetContained(ArrayRef<int64_t> ineq, Simplex &simp,
IntegerRelation &p,
ArrayRef<ArrayRef<int64_t>> cuttingIneqs) {
unsigned snapshot = simp.getSnapshot();
simp.addEquality(ineq);
if (llvm::any_of(cuttingIneqs, [&simp](ArrayRef<int64_t> curr) {
return !simp.isRedundantInequality(curr);
})) {
simp.rollback(snapshot);
return false;
}
simp.rollback(snapshot);
return true;
}
/// The SetCoalescer class contains all functionality concerning the coalesce
/// heuristic. It is built from a `PresburgerRelation` and has the `coalesce()`
/// function as its main API. The coalesce heuristic simplifies the
/// representation of a PresburgerRelation. In particular, it removes all
/// disjuncts which are subsets of other disjuncts in the union and it combines
/// sets that overlap and can be combined in a convex way.
class presburger::SetCoalescer {
/// Adds `disjunct` to `disjuncts` and removes the disjuncts at position `i` and
/// `j`. Updates `simplices` to reflect the changes. `i` and `j` cannot be
/// equal.
static void addCoalescedDisjunct(SmallVectorImpl<IntegerRelation> &disjuncts,
unsigned i, unsigned j,
const IntegerRelation &disjunct,
SmallVectorImpl<Simplex> &simplices) {
assert(i != j && "The indices must refer to different disjuncts");
public:
/// Simplifies the representation of a PresburgerSet.
PresburgerRelation coalesce();
unsigned n = disjuncts.size();
if (j == n - 1) {
// This case needs special handling since position `n` - 1 is removed from
// the vector, hence the `IntegerRelation` at position `n` - 2 is lost
// otherwise.
disjuncts[i] = disjuncts[n - 2];
disjuncts.pop_back();
disjuncts[n - 2] = disjunct;
/// Construct a SetCoalescer from a PresburgerSet.
SetCoalescer(const PresburgerRelation &s);
simplices[i] = simplices[n - 2];
simplices.pop_back();
simplices[n - 2] = Simplex(disjunct);
private:
/// The dimensionality of the set the SetCoalescer is coalescing.
unsigned numDomainIds;
unsigned numRangeIds;
unsigned numSymbolIds;
} else {
// Other possible edge cases are correct since for `j` or `i` == `n` - 2,
// the `IntegerRelation` at position `n` - 2 should be lost. The case
// `i` == `n` - 1 makes the first following statement a noop. Hence, in this
// case the same thing is done as above, but with `j` rather than `i`.
disjuncts[i] = disjuncts[n - 1];
disjuncts[j] = disjuncts[n - 2];
disjuncts.pop_back();
disjuncts[n - 2] = disjunct;
/// The current list of `IntegerRelation`s that the currently coalesced set is
/// the union of.
SmallVector<IntegerRelation, 2> disjuncts;
/// The list of `Simplex`s constructed from the elements of `disjuncts`.
SmallVector<Simplex, 2> simplices;
simplices[i] = simplices[n - 1];
simplices[j] = simplices[n - 2];
simplices.pop_back();
simplices[n - 2] = Simplex(disjunct);
}
}
/// Given two disjuncts `a` and `b` at positions `i` and `j` in `disjuncts`
/// and `redundantIneqsA` being the inequalities of `a` that are redundant for
/// `b` (similarly for `cuttingIneqsA`, `redundantIneqsB`, and `cuttingIneqsB`),
/// checks whether the facets of all cutting inequalites of `a` are contained in
/// `b`. If so, a new disjunct consisting of all redundant inequalites of `a`
/// and `b` and all equalities of both is created.
///
/// An example of this case:
/// ___________ ___________
/// / / | / / /
/// \ \ | / ==> \ /
/// \ \ | / \ /
/// \___\|/ \_____/
///
///
static LogicalResult
coalescePairCutCase(SmallVectorImpl<IntegerRelation> &disjuncts,
SmallVectorImpl<Simplex> &simplices, unsigned i, unsigned j,
ArrayRef<ArrayRef<int64_t>> redundantIneqsA,
ArrayRef<ArrayRef<int64_t>> cuttingIneqsA,
ArrayRef<ArrayRef<int64_t>> redundantIneqsB,
ArrayRef<ArrayRef<int64_t>> cuttingIneqsB) {
/// All inequalities of `b` need to be redundant. We already know that the
/// redundant ones are, so only the cutting ones remain to be checked.
Simplex &simp = simplices[i];
IntegerRelation &disjunct = disjuncts[i];
if (llvm::any_of(cuttingIneqsA, [&simp, &disjunct,
&cuttingIneqsB](ArrayRef<int64_t> curr) {
return !isFacetContained(curr, simp, disjunct, cuttingIneqsB);
}))
return failure();
IntegerRelation newSet(disjunct.getNumDomainIds(), disjunct.getNumRangeIds(),
disjunct.getNumSymbolIds(), disjunct.getNumLocalIds());
for (ArrayRef<int64_t> curr : redundantIneqsA)
newSet.addInequality(curr);
for (ArrayRef<int64_t> curr : redundantIneqsB)
newSet.addInequality(curr);
addCoalescedDisjunct(disjuncts, i, j, newSet, simplices);
return success();
}
/// Types the inequality `ineq` according to its `IneqType` for `simp` into
/// `redundantIneqs` and `cuttingIneqs`. Returns success, if no separate
/// inequalities were encountered. Otherwise, returns failure.
static LogicalResult
typeInequality(ArrayRef<int64_t> ineq, Simplex &simp,
SmallVectorImpl<ArrayRef<int64_t>> &redundantIneqs,
SmallVectorImpl<ArrayRef<int64_t>> &cuttingIneqs) {
Simplex::IneqType type = simp.findIneqType(ineq);
if (type == Simplex::IneqType::Redundant)
redundantIneqs.push_back(ineq);
else if (type == Simplex::IneqType::Cut)
cuttingIneqs.push_back(ineq);
else
return failure();
return success();
}
/// Types the equality `eq`, i.e. for `eq` == 0, types both `eq` >= 0 and -`eq`
/// >= 0 according to their `IneqType` for `simp` into `redundantIneqs` and
/// `cuttingIneqs`. Returns success, if no separate inequalities were
/// encountered. Otherwise, returns failure.
static LogicalResult
typeEquality(ArrayRef<int64_t> eq, Simplex &simp,
SmallVectorImpl<ArrayRef<int64_t>> &redundantIneqs,
SmallVectorImpl<ArrayRef<int64_t>> &cuttingIneqs,
SmallVectorImpl<SmallVector<int64_t, 2>> &negEqs) {
if (typeInequality(eq, simp, redundantIneqs, cuttingIneqs).failed())
return failure();
negEqs.push_back(getNegatedCoeffs(eq));
ArrayRef<int64_t> inv(negEqs.back());
if (typeInequality(inv, simp, redundantIneqs, cuttingIneqs).failed())
return failure();
return success();
}
/// Replaces the element at position `i` with the last element and erases the
/// last element for both `disjuncts` and `simplices`.
static void eraseDisjunct(unsigned i,
SmallVectorImpl<IntegerRelation> &disjuncts,
SmallVectorImpl<Simplex> &simplices) {
assert(simplices.size() == disjuncts.size() &&
"simplices and disjuncts must be equally as long");
disjuncts[i] = disjuncts.back();
disjuncts.pop_back();
simplices[i] = simplices.back();
simplices.pop_back();
}
/// Attempts to coalesce the two IntegerRelations at position `i` and `j` in
/// `disjuncts` in-place. Returns whether the disjuncts were successfully
/// coalesced. The simplices in `simplices` need to be the ones constructed from
/// `disjuncts`. At this point, there are no empty disjuncts in
/// `disjuncts` left.
static LogicalResult coalescePair(unsigned i, unsigned j,
SmallVectorImpl<IntegerRelation> &disjuncts,
SmallVectorImpl<Simplex> &simplices) {
IntegerRelation &a = disjuncts[i];
IntegerRelation &b = disjuncts[j];
/// Handling of local ids is not yet implemented, so these cases are skipped.
/// TODO: implement local id support.
if (a.getNumLocalIds() != 0 || b.getNumLocalIds() != 0)
return failure();
Simplex &simpA = simplices[i];
Simplex &simpB = simplices[j];
SmallVector<ArrayRef<int64_t>, 2> redundantIneqsA;
SmallVector<ArrayRef<int64_t>, 2> cuttingIneqsA;
/// The list of all inversed equalities during typing. This ensures that
/// the constraints exist even after the typing function has concluded.
SmallVector<SmallVector<int64_t, 2>, 2> negEqs;
// Organize all inequalities and equalities of `a` according to their type for
// `b` into `redundantIneqsA` and `cuttingIneqsA` (and vice versa for all
// inequalities of `b` according to their type in `a`). If a separate
// inequality is encountered during typing, the two IntegerRelations cannot
// be coalesced.
for (int k = 0, e = a.getNumInequalities(); k < e; ++k)
if (typeInequality(a.getInequality(k), simpB, redundantIneqsA,
cuttingIneqsA)
.failed())
return failure();
for (int k = 0, e = a.getNumEqualities(); k < e; ++k)
if (typeEquality(a.getEquality(k), simpB, redundantIneqsA, cuttingIneqsA,
negEqs)
.failed())
return failure();
/// `redundantIneqsA` is the inequalities of `a` that are redundant for `b`
/// (similarly for `cuttingIneqsA`, `redundantIneqsB`, and `cuttingIneqsB`).
SmallVector<ArrayRef<int64_t>, 2> redundantIneqsA;
SmallVector<ArrayRef<int64_t>, 2> cuttingIneqsA;
SmallVector<ArrayRef<int64_t>, 2> redundantIneqsB;
SmallVector<ArrayRef<int64_t>, 2> cuttingIneqsB;
for (int k = 0, e = b.getNumInequalities(); k < e; ++k)
if (typeInequality(b.getInequality(k), simpA, redundantIneqsB,
cuttingIneqsB)
.failed())
return failure();
for (int k = 0, e = b.getNumEqualities(); k < e; ++k)
if (typeEquality(b.getEquality(k), simpA, redundantIneqsB, cuttingIneqsB,
negEqs)
.failed())
return failure();
/// Given a Simplex `simp` and one of its inequalities `ineq`, check
/// that the facet of `simp` where `ineq` holds as an equality is contained
/// within `a`.
bool isFacetContained(ArrayRef<int64_t> ineq, Simplex &simp);
// If there are no cutting inequalities of `a`, `b` is contained
// within `a` (and vice versa for `b`).
if (cuttingIneqsA.empty()) {
eraseDisjunct(j, disjuncts, simplices);
return success();
}
/// Adds `disjunct` to `disjuncts` and removes the disjuncts at position `i`
/// and `j`. Updates `simplices` to reflect the changes. `i` and `j` cannot
/// be equal.
void addCoalescedDisjunct(unsigned i, unsigned j,
const IntegerRelation &disjunct);
if (cuttingIneqsB.empty()) {
eraseDisjunct(i, disjuncts, simplices);
return success();
}
/// Checks whether `a` and `b` can be combined in a convex sense, if there
/// exist cutting inequalities.
///
/// An example of this case:
/// ___________ ___________
/// / / | / / /
/// \ \ | / ==> \ /
/// \ \ | / \ /
/// \___\|/ \_____/
///
///
LogicalResult coalescePairCutCase(unsigned i, unsigned j);
// Try to apply the cut case
if (coalescePairCutCase(disjuncts, simplices, i, j, redundantIneqsA,
cuttingIneqsA, redundantIneqsB, cuttingIneqsB)
.succeeded())
return success();
/// Types the inequality `ineq` according to its `IneqType` for `simp` into
/// `redundantIneqsB` and `cuttingIneqsB`. Returns success, if no separate
/// inequalities were encountered. Otherwise, returns failure.
LogicalResult typeInequality(ArrayRef<int64_t> ineq, Simplex &simp);
if (coalescePairCutCase(disjuncts, simplices, j, i, redundantIneqsB,
cuttingIneqsB, redundantIneqsA, cuttingIneqsA)
.succeeded())
return success();
/// Types the equality `eq`, i.e. for `eq` == 0, types both `eq` >= 0 and
/// -`eq` >= 0 according to their `IneqType` for `simp` into
/// `redundantIneqsB` and `cuttingIneqsB`. Returns success, if no separate
/// inequalities were encountered. Otherwise, returns failure.
LogicalResult typeEquality(ArrayRef<int64_t> eq, Simplex &simp);
return failure();
}
/// Replaces the element at position `i` with the last element and erases
/// the last element for both `disjuncts` and `simplices`.
void eraseDisjunct(unsigned i);
PresburgerRelation PresburgerRelation::coalesce() const {
PresburgerRelation newSet = PresburgerRelation::getEmpty(
getNumDomainIds(), getNumRangeIds(), getNumSymbolIds());
SmallVector<IntegerRelation, 2> disjuncts = integerRelations;
SmallVector<Simplex, 2> simplices;
/// Attempts to coalesce the two IntegerRelations at position `i` and `j`
/// in `disjuncts` in-place. Returns whether the disjuncts were
/// successfully coalesced. The simplices in `simplices` need to be the ones
/// constructed from `disjuncts`. At this point, there are no empty
/// disjuncts in `disjuncts` left.
LogicalResult coalescePair(unsigned i, unsigned j);
};
simplices.reserve(getNumDisjuncts());
/// Constructs a `SetCoalescer` from a `PresburgerRelation`. Only adds non-empty
/// `IntegerRelation`s to the `disjuncts` vector.
SetCoalescer::SetCoalescer(const PresburgerRelation &s) {
disjuncts = s.integerRelations;
simplices.reserve(s.getNumDisjuncts());
// Note that disjuncts.size() changes during the loop.
for (unsigned i = 0; i < disjuncts.size();) {
Simplex simp(disjuncts[i]);
@ -640,20 +497,31 @@ PresburgerRelation PresburgerRelation::coalesce() const {
++i;
simplices.push_back(simp);
}
numDomainIds = s.getNumDomainIds();
numRangeIds = s.getNumRangeIds();
numSymbolIds = s.getNumSymbolIds();
}
// For all tuples of IntegerRelations, check whether they can be coalesced.
// When coalescing is successful, the contained IntegerRelation is swapped
// with the last element of `disjuncts` and subsequently erased and
// similarly for simplices.
/// Simplifies the representation of a PresburgerSet.
PresburgerRelation SetCoalescer::coalesce() {
// For all tuples of IntegerRelations, check whether they can be
// coalesced. When coalescing is successful, the contained IntegerRelation
// is swapped with the last element of `disjuncts` and subsequently erased
// and similarly for simplices.
for (unsigned i = 0; i < disjuncts.size();) {
// TODO: This does some comparisons two times (index 0 with 1 and index 1
// with 0).
bool broken = false;
for (unsigned j = 0, e = disjuncts.size(); j < e; ++j) {
negEqs.clear();
redundantIneqsA.clear();
redundantIneqsB.clear();
cuttingIneqsA.clear();
cuttingIneqsB.clear();
if (i == j)
continue;
if (coalescePair(i, j, disjuncts, simplices).succeeded()) {
if (coalescePair(i, j).succeeded()) {
broken = true;
break;
}
@ -666,12 +534,202 @@ PresburgerRelation PresburgerRelation::coalesce() const {
++i;
}
PresburgerRelation newSet =
PresburgerRelation::getEmpty(numDomainIds, numRangeIds, numSymbolIds);
for (unsigned i = 0, e = disjuncts.size(); i < e; ++i)
newSet.unionInPlace(disjuncts[i]);
return newSet;
}
/// Given a Simplex `simp` and one of its inequalities `ineq`, check
/// that all inequalities of `cuttingIneqsB` are redundant for the facet of
/// `simp` where `ineq` holds as an equality is contained within `a`.
bool SetCoalescer::isFacetContained(ArrayRef<int64_t> ineq, Simplex &simp) {
unsigned snapshot = simp.getSnapshot();
simp.addEquality(ineq);
if (llvm::any_of(cuttingIneqsB, [&simp](ArrayRef<int64_t> curr) {
return !simp.isRedundantInequality(curr);
})) {
simp.rollback(snapshot);
return false;
}
simp.rollback(snapshot);
return true;
}
void SetCoalescer::addCoalescedDisjunct(unsigned i, unsigned j,
const IntegerRelation &disjunct) {
assert(i != j && "The indices must refer to different disjuncts");
unsigned n = disjuncts.size();
if (j == n - 1) {
// This case needs special handling since position `n` - 1 is removed
// from the vector, hence the `IntegerRelation` at position `n` - 2 is
// lost otherwise.
disjuncts[i] = disjuncts[n - 2];
disjuncts.pop_back();
disjuncts[n - 2] = disjunct;
simplices[i] = simplices[n - 2];
simplices.pop_back();
simplices[n - 2] = Simplex(disjunct);
} else {
// Other possible edge cases are correct since for `j` or `i` == `n` -
// 2, the `IntegerRelation` at position `n` - 2 should be lost. The
// case `i` == `n` - 1 makes the first following statement a noop.
// Hence, in this case the same thing is done as above, but with `j`
// rather than `i`.
disjuncts[i] = disjuncts[n - 1];
disjuncts[j] = disjuncts[n - 2];
disjuncts.pop_back();
disjuncts[n - 2] = disjunct;
simplices[i] = simplices[n - 1];
simplices[j] = simplices[n - 2];
simplices.pop_back();
simplices[n - 2] = Simplex(disjunct);
}
}
/// Given two polyhedra `a` and `b` at positions `i` and `j` in
/// `disjuncts` and `redundantIneqsA` being the inequalities of `a` that
/// are redundant for `b` (similarly for `cuttingIneqsA`, `redundantIneqsB`,
/// and `cuttingIneqsB`), Checks whether the facets of all cutting
/// inequalites of `a` are contained in `b`. If so, a new polyhedron
/// consisting of all redundant inequalites of `a` and `b` and all
/// equalities of both is created.
///
/// An example of this case:
/// ___________ ___________
/// / / | / / /
/// \ \ | / ==> \ /
/// \ \ | / \ /
/// \___\|/ \_____/
///
///
LogicalResult SetCoalescer::coalescePairCutCase(unsigned i, unsigned j) {
/// All inequalities of `b` need to be redundant. We already know that the
/// redundant ones are, so only the cutting ones remain to be checked.
Simplex &simp = simplices[i];
IntegerRelation &disjunct = disjuncts[i];
if (llvm::any_of(cuttingIneqsA, [this, &simp](ArrayRef<int64_t> curr) {
return !isFacetContained(curr, simp);
}))
return failure();
IntegerRelation newSet(disjunct.getNumDomainIds(), disjunct.getNumRangeIds(),
disjunct.getNumSymbolIds(), disjunct.getNumLocalIds());
for (ArrayRef<int64_t> curr : redundantIneqsA)
newSet.addInequality(curr);
for (ArrayRef<int64_t> curr : redundantIneqsB)
newSet.addInequality(curr);
addCoalescedDisjunct(i, j, newSet);
return success();
}
LogicalResult SetCoalescer::typeInequality(ArrayRef<int64_t> ineq,
Simplex &simp) {
Simplex::IneqType type = simp.findIneqType(ineq);
if (type == Simplex::IneqType::Redundant)
redundantIneqsB.push_back(ineq);
else if (type == Simplex::IneqType::Cut)
cuttingIneqsB.push_back(ineq);
else
return failure();
return success();
}
LogicalResult SetCoalescer::typeEquality(ArrayRef<int64_t> eq, Simplex &simp) {
if (typeInequality(eq, simp).failed())
return failure();
negEqs.push_back(getNegatedCoeffs(eq));
ArrayRef<int64_t> inv(negEqs.back());
if (typeInequality(inv, simp).failed())
return failure();
return success();
}
void SetCoalescer::eraseDisjunct(unsigned i) {
assert(simplices.size() == disjuncts.size() &&
"simplices and disjuncts must be equally as long");
disjuncts[i] = disjuncts.back();
disjuncts.pop_back();
simplices[i] = simplices.back();
simplices.pop_back();
}
LogicalResult SetCoalescer::coalescePair(unsigned i, unsigned j) {
IntegerRelation &a = disjuncts[i];
IntegerRelation &b = disjuncts[j];
/// Handling of local ids is not yet implemented, so these cases are
/// skipped.
/// TODO: implement local id support.
if (a.getNumLocalIds() != 0 || b.getNumLocalIds() != 0)
return failure();
Simplex &simpA = simplices[i];
Simplex &simpB = simplices[j];
// Organize all inequalities and equalities of `a` according to their type
// for `b` into `redundantIneqsA` and `cuttingIneqsA` (and vice versa for
// all inequalities of `b` according to their type in `a`). If a separate
// inequality is encountered during typing, the two IntegerRelations
// cannot be coalesced.
for (int k = 0, e = a.getNumInequalities(); k < e; ++k)
if (typeInequality(a.getInequality(k), simpB).failed())
return failure();
for (int k = 0, e = a.getNumEqualities(); k < e; ++k)
if (typeEquality(a.getEquality(k), simpB).failed())
return failure();
std::swap(redundantIneqsA, redundantIneqsB);
std::swap(cuttingIneqsA, cuttingIneqsB);
for (int k = 0, e = b.getNumInequalities(); k < e; ++k)
if (typeInequality(b.getInequality(k), simpA).failed())
return failure();
for (int k = 0, e = b.getNumEqualities(); k < e; ++k)
if (typeEquality(b.getEquality(k), simpA).failed())
return failure();
// If there are no cutting inequalities of `a`, `b` is contained
// within `a`.
if (cuttingIneqsA.empty()) {
eraseDisjunct(j);
return success();
}
// Try to apply the cut case
if (coalescePairCutCase(i, j).succeeded())
return success();
// Swap the vectors to compare the pair (j,i) instead of (i,j).
std::swap(redundantIneqsA, redundantIneqsB);
std::swap(cuttingIneqsA, cuttingIneqsB);
// If there are no cutting inequalities of `a`, `b` is contained
// within `a`.
if (cuttingIneqsA.empty()) {
eraseDisjunct(i);
return success();
}
// Try to apply the cut case
if (coalescePairCutCase(j, i).succeeded())
return success();
return failure();
}
PresburgerRelation PresburgerRelation::coalesce() const {
return SetCoalescer(*this).coalesce();
}
void PresburgerRelation::print(raw_ostream &os) const {
os << "Number of Disjuncts: " << getNumDisjuncts() << "\n";
for (const IntegerRelation &disjunct : integerRelations) {