[MLIR] Avoid creation of buggy affine maps when incorrect values of number of dimensions and number of symbols are provided.

We check whether the maximum index of dimensional identifier present
in the result expressions is less than dimCount (number of dimensional
identifiers) argument passed in the AffineMap::get() and the maximum index
of symbolic identifier present in the result expressions is less than
symbolCount (number of symbolic identifiers) argument passed in AffineMap::get().

Reviewed By: nicolasvasilache, bondhugula

Differential Revision: https://reviews.llvm.org/D114238
This commit is contained in:
Arnab Dutta 2021-11-27 00:36:09 +05:30 committed by Uday Bondhugula
parent e4e4da86af
commit c2280b5517
3 changed files with 42 additions and 15 deletions

View File

@ -546,6 +546,23 @@ SmallVector<T> applyPermutationMap(AffineMap map, llvm::ArrayRef<T> source) {
return result;
}
/// Calculates maxmimum dimension and symbol positions from the expressions
/// in `exprsLists` and stores them in `maxDim` and `maxSym` respectively.
template <typename AffineExprContainer>
static void getMaxDimAndSymbol(ArrayRef<AffineExprContainer> exprsList,
int64_t &maxDim, int64_t &maxSym) {
for (const auto &exprs : exprsList) {
for (auto expr : exprs) {
expr.walk([&maxDim, &maxSym](AffineExpr e) {
if (auto d = e.dyn_cast<AffineDimExpr>())
maxDim = std::max(maxDim, static_cast<int64_t>(d.getPosition()));
if (auto s = e.dyn_cast<AffineSymbolExpr>())
maxSym = std::max(maxSym, static_cast<int64_t>(s.getPosition()));
});
}
}
}
inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
map.print(os);
return os;

View File

@ -215,21 +215,6 @@ AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
return permutationMap;
}
template <typename AffineExprContainer>
static void getMaxDimAndSymbol(ArrayRef<AffineExprContainer> exprsList,
int64_t &maxDim, int64_t &maxSym) {
for (const auto &exprs : exprsList) {
for (auto expr : exprs) {
expr.walk([&maxDim, &maxSym](AffineExpr e) {
if (auto d = e.dyn_cast<AffineDimExpr>())
maxDim = std::max(maxDim, static_cast<int64_t>(d.getPosition()));
if (auto s = e.dyn_cast<AffineSymbolExpr>())
maxSym = std::max(maxSym, static_cast<int64_t>(s.getPosition()));
});
}
}
}
template <typename AffineExprContainer>
static SmallVector<AffineMap, 4>
inferFromExprList(ArrayRef<AffineExprContainer> exprsList) {

View File

@ -1012,6 +1012,29 @@ AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount,
});
}
/// Check whether the arguments passed to the AffineMap::get() are consistent.
/// This method checks whether the highest index of dimensional identifier
/// present in result expressions is less than `dimCount` and the highest index
/// of symbolic identifier present in result expressions is less than
/// `symbolCount`.
[[nodiscard]] static bool willBeValidAffineMap(unsigned dimCount,
unsigned symbolCount,
ArrayRef<AffineExpr> results) {
int64_t maxDimPosition = -1;
int64_t maxSymbolPosition = -1;
getMaxDimAndSymbol(ArrayRef<ArrayRef<AffineExpr>>(results), maxDimPosition,
maxSymbolPosition);
if ((maxDimPosition >= dimCount) || (maxSymbolPosition >= symbolCount)) {
LLVM_DEBUG(
llvm::dbgs()
<< "maximum dimensional identifier position in result expression must "
"be less than `dimCount` and maximum symbolic identifier position "
"in result expression must be less than `symbolCount`\n");
return false;
}
return true;
}
AffineMap AffineMap::get(MLIRContext *context) {
return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context);
}
@ -1023,11 +1046,13 @@ AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
AffineExpr result) {
assert(willBeValidAffineMap(dimCount, symbolCount, {result}));
return getImpl(dimCount, symbolCount, {result}, result.getContext());
}
AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr> results, MLIRContext *context) {
assert(willBeValidAffineMap(dimCount, symbolCount, results));
return getImpl(dimCount, symbolCount, results, context);
}