[mlir][sparse] Adding new Merger::addLat overload

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D146559
This commit is contained in:
wren romano 2023-03-21 13:13:42 -07:00
parent 087b5f3277
commit 13e9afd16d
2 changed files with 9 additions and 2 deletions

View File

@ -280,6 +280,7 @@ public:
/// Constructs a new iteration lattice point, and returns its identifier.
LatPointId addLat(TensorId t, LoopId i, ExprId e);
LatPointId addLat(const BitVector &bits, ExprId e);
/// Constructs a new (initially empty) set, and returns its identifier.
LatSetId addSet();

View File

@ -247,6 +247,13 @@ LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) {
return p;
}
LatPointId Merger::addLat(const BitVector &bits, ExprId e) {
assert(bits.size() == numLoops * numTensors);
const LatPointId p = latPoints.size();
latPoints.emplace_back(bits, e);
return p;
}
LatSetId Merger::addSet() {
const LatSetId s = latSets.size();
latSets.emplace_back();
@ -322,8 +329,7 @@ LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
const LatSetId s = addSet();
for (const LatPointId p : latSets[s0]) {
const ExprId e = addExp(kind, latPoints[p].exp, v, op);
latPoints.emplace_back(latPoints[p].bits, e);
latSets[s].push_back(latPoints.size() - 1);
latSets[s].push_back(addLat(latPoints[p].bits, e));
}
return s;
}