mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-27 07:31:28 +00:00
When performing constant folding on the affineApplyOp, there is a division of 0 in the affine map. [related issue](https://github.com/llvm/llvm-project/issues/64622) --------- Co-authored-by: Javier Setoain <jsetoain@users.noreply.github.com>
This commit is contained in:
parent
c093383ffa
commit
dc4786b487
@ -14,6 +14,7 @@
|
||||
#define MLIR_IR_AFFINEEXPRVISITOR_H
|
||||
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
|
||||
namespace mlir {
|
||||
@ -65,88 +66,41 @@ namespace mlir {
|
||||
/// just as efficient as having your own switch instruction over the instruction
|
||||
/// opcode.
|
||||
|
||||
template <typename SubClass, typename RetTy = void>
|
||||
class AffineExprVisitor {
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Interface code - This is the public interface of the AffineExprVisitor
|
||||
// that you use to visit affine expressions...
|
||||
template <typename SubClass, typename RetTy>
|
||||
class AffineExprVisitorBase {
|
||||
public:
|
||||
// Function to walk an AffineExpr (in post order).
|
||||
RetTy walkPostOrder(AffineExpr expr) {
|
||||
static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
|
||||
"Must instantiate with a derived type of AffineExprVisitor");
|
||||
switch (expr.getKind()) {
|
||||
case AffineExprKind::Add: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::Mul: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::Mod: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::FloorDiv: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::CeilDiv: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::Constant:
|
||||
return static_cast<SubClass *>(this)->visitConstantExpr(
|
||||
cast<AffineConstantExpr>(expr));
|
||||
case AffineExprKind::DimId:
|
||||
return static_cast<SubClass *>(this)->visitDimExpr(
|
||||
cast<AffineDimExpr>(expr));
|
||||
case AffineExprKind::SymbolId:
|
||||
return static_cast<SubClass *>(this)->visitSymbolExpr(
|
||||
cast<AffineSymbolExpr>(expr));
|
||||
}
|
||||
}
|
||||
|
||||
// Function to visit an AffineExpr.
|
||||
RetTy visit(AffineExpr expr) {
|
||||
static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
|
||||
static_assert(std::is_base_of<AffineExprVisitorBase, SubClass>::value,
|
||||
"Must instantiate with a derived type of AffineExprVisitor");
|
||||
auto self = static_cast<SubClass *>(this);
|
||||
switch (expr.getKind()) {
|
||||
case AffineExprKind::Add: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
|
||||
return self->visitAddExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::Mul: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
|
||||
return self->visitMulExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::Mod: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
|
||||
return self->visitModExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::FloorDiv: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
|
||||
return self->visitFloorDivExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::CeilDiv: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
|
||||
return self->visitCeilDivExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::Constant:
|
||||
return static_cast<SubClass *>(this)->visitConstantExpr(
|
||||
cast<AffineConstantExpr>(expr));
|
||||
return self->visitConstantExpr(cast<AffineConstantExpr>(expr));
|
||||
case AffineExprKind::DimId:
|
||||
return static_cast<SubClass *>(this)->visitDimExpr(
|
||||
cast<AffineDimExpr>(expr));
|
||||
return self->visitDimExpr(cast<AffineDimExpr>(expr));
|
||||
case AffineExprKind::SymbolId:
|
||||
return static_cast<SubClass *>(this)->visitSymbolExpr(
|
||||
cast<AffineSymbolExpr>(expr));
|
||||
return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr));
|
||||
}
|
||||
llvm_unreachable("Unknown AffineExpr");
|
||||
}
|
||||
@ -180,6 +134,54 @@ public:
|
||||
RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
|
||||
RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
|
||||
RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
|
||||
};
|
||||
|
||||
template <typename SubClass, typename RetTy = void>
|
||||
class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Interface code - This is the public interface of the AffineExprVisitor
|
||||
// that you use to visit affine expressions...
|
||||
public:
|
||||
// Function to walk an AffineExpr (in post order).
|
||||
RetTy walkPostOrder(AffineExpr expr) {
|
||||
static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
|
||||
"Must instantiate with a derived type of AffineExprVisitor");
|
||||
auto self = static_cast<SubClass *>(this);
|
||||
switch (expr.getKind()) {
|
||||
case AffineExprKind::Add: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
return self->visitAddExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::Mul: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
return self->visitMulExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::Mod: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
return self->visitModExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::FloorDiv: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
return self->visitFloorDivExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::CeilDiv: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
return self->visitCeilDivExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::Constant:
|
||||
return self->visitConstantExpr(cast<AffineConstantExpr>(expr));
|
||||
case AffineExprKind::DimId:
|
||||
return self->visitDimExpr(cast<AffineDimExpr>(expr));
|
||||
case AffineExprKind::SymbolId:
|
||||
return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr));
|
||||
}
|
||||
llvm_unreachable("Unknown AffineExpr");
|
||||
}
|
||||
|
||||
private:
|
||||
// Walk the operands - each operand is itself walked in post order.
|
||||
@ -189,6 +191,70 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SubClass>
|
||||
class AffineExprVisitor<SubClass, LogicalResult>
|
||||
: public AffineExprVisitorBase<SubClass, LogicalResult> {
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Interface code - This is the public interface of the AffineExprVisitor
|
||||
// that you use to visit affine expressions...
|
||||
public:
|
||||
// Function to walk an AffineExpr (in post order).
|
||||
LogicalResult walkPostOrder(AffineExpr expr) {
|
||||
static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
|
||||
"Must instantiate with a derived type of AffineExprVisitor");
|
||||
auto self = static_cast<SubClass *>(this);
|
||||
switch (expr.getKind()) {
|
||||
case AffineExprKind::Add: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
if (failed(walkOperandsPostOrder(binOpExpr)))
|
||||
return failure();
|
||||
return self->visitAddExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::Mul: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
if (failed(walkOperandsPostOrder(binOpExpr)))
|
||||
return failure();
|
||||
return self->visitMulExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::Mod: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
if (failed(walkOperandsPostOrder(binOpExpr)))
|
||||
return failure();
|
||||
return self->visitModExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::FloorDiv: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
if (failed(walkOperandsPostOrder(binOpExpr)))
|
||||
return failure();
|
||||
return self->visitFloorDivExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::CeilDiv: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
if (failed(walkOperandsPostOrder(binOpExpr)))
|
||||
return failure();
|
||||
return self->visitCeilDivExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::Constant:
|
||||
return self->visitConstantExpr(cast<AffineConstantExpr>(expr));
|
||||
case AffineExprKind::DimId:
|
||||
return self->visitDimExpr(cast<AffineDimExpr>(expr));
|
||||
case AffineExprKind::SymbolId:
|
||||
return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr));
|
||||
}
|
||||
llvm_unreachable("Unknown AffineExpr");
|
||||
}
|
||||
|
||||
private:
|
||||
// Walk the operands - each operand is itself walked in post order.
|
||||
LogicalResult walkOperandsPostOrder(AffineBinaryOpExpr expr) {
|
||||
if (failed(walkPostOrder(expr.getLHS())))
|
||||
return failure();
|
||||
if (failed(walkPostOrder(expr.getRHS())))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// This class is used to flatten a pure affine expression (AffineExpr,
|
||||
// which is in a tree form) into a sum of products (w.r.t constants) when
|
||||
// possible, and in that process simplifying the expression. For a modulo,
|
||||
@ -246,7 +312,7 @@ private:
|
||||
// expressions are mapped to the same local identifier (same column position in
|
||||
// 'localVarCst').
|
||||
class SimpleAffineExprFlattener
|
||||
: public AffineExprVisitor<SimpleAffineExprFlattener> {
|
||||
: public AffineExprVisitor<SimpleAffineExprFlattener, LogicalResult> {
|
||||
public:
|
||||
// Flattend expression layout: [dims, symbols, locals, constant]
|
||||
// Stack that holds the LHS and RHS operands while visiting a binary op expr.
|
||||
@ -275,13 +341,13 @@ public:
|
||||
virtual ~SimpleAffineExprFlattener() = default;
|
||||
|
||||
// Visitor method overrides.
|
||||
void visitMulExpr(AffineBinaryOpExpr expr);
|
||||
void visitAddExpr(AffineBinaryOpExpr expr);
|
||||
void visitDimExpr(AffineDimExpr expr);
|
||||
void visitSymbolExpr(AffineSymbolExpr expr);
|
||||
void visitConstantExpr(AffineConstantExpr expr);
|
||||
void visitCeilDivExpr(AffineBinaryOpExpr expr);
|
||||
void visitFloorDivExpr(AffineBinaryOpExpr expr);
|
||||
LogicalResult visitMulExpr(AffineBinaryOpExpr expr);
|
||||
LogicalResult visitAddExpr(AffineBinaryOpExpr expr);
|
||||
LogicalResult visitDimExpr(AffineDimExpr expr);
|
||||
LogicalResult visitSymbolExpr(AffineSymbolExpr expr);
|
||||
LogicalResult visitConstantExpr(AffineConstantExpr expr);
|
||||
LogicalResult visitCeilDivExpr(AffineBinaryOpExpr expr);
|
||||
LogicalResult visitFloorDivExpr(AffineBinaryOpExpr expr);
|
||||
|
||||
//
|
||||
// t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
|
||||
@ -289,7 +355,7 @@ public:
|
||||
// A mod expression "expr mod c" is thus flattened by introducing a new local
|
||||
// variable q (= expr floordiv c), such that expr mod c is replaced with
|
||||
// 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
|
||||
void visitModExpr(AffineBinaryOpExpr expr);
|
||||
LogicalResult visitModExpr(AffineBinaryOpExpr expr);
|
||||
|
||||
protected:
|
||||
// Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
|
||||
@ -328,7 +394,7 @@ private:
|
||||
//
|
||||
// A ceildiv is similarly flattened:
|
||||
// t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
|
||||
void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);
|
||||
LogicalResult visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);
|
||||
|
||||
int findLocalId(AffineExpr localExpr);
|
||||
|
||||
|
@ -310,7 +310,8 @@ public:
|
||||
/// Folds the results of the application of an affine map on the provided
|
||||
/// operands to a constant if possible.
|
||||
LogicalResult constantFold(ArrayRef<Attribute> operandConstants,
|
||||
SmallVectorImpl<Attribute> &results) const;
|
||||
SmallVectorImpl<Attribute> &results,
|
||||
bool *hasPoison = nullptr) const;
|
||||
|
||||
/// Propagates the constant operands into this affine map. Operands are
|
||||
/// allowed to be null, at which point they are treated as non-constant. This
|
||||
@ -318,9 +319,9 @@ public:
|
||||
/// which may be equal to the old map if no folding happened. If `results` is
|
||||
/// provided and if all expressions in the map were folded to constants,
|
||||
/// `results` will contain the values of these constants.
|
||||
AffineMap
|
||||
partialConstantFold(ArrayRef<Attribute> operandConstants,
|
||||
SmallVectorImpl<int64_t> *results = nullptr) const;
|
||||
AffineMap partialConstantFold(ArrayRef<Attribute> operandConstants,
|
||||
SmallVectorImpl<int64_t> *results = nullptr,
|
||||
bool *hasPoison = nullptr) const;
|
||||
|
||||
/// Returns the AffineMap resulting from composing `this` with `map`.
|
||||
/// The resulting AffineMap has as many AffineDimExpr as `map` and as many
|
||||
|
@ -67,7 +67,9 @@ private:
|
||||
} // namespace
|
||||
|
||||
// Flattens the expressions in map. Returns failure if 'expr' was unable to be
|
||||
// flattened (i.e., semi-affine expressions not handled yet).
|
||||
// flattened. For example two specific cases:
|
||||
// 1. semi-affine expressions not handled yet.
|
||||
// 2. has poison expression (i.e., division by zero).
|
||||
static LogicalResult
|
||||
getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
|
||||
unsigned numSymbols,
|
||||
@ -85,8 +87,10 @@ getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
|
||||
for (auto expr : exprs) {
|
||||
if (!expr.isPureAffine())
|
||||
return failure();
|
||||
|
||||
flattener.walkPostOrder(expr);
|
||||
// has poison expression
|
||||
auto flattenResult = flattener.walkPostOrder(expr);
|
||||
if (failed(flattenResult))
|
||||
return failure();
|
||||
}
|
||||
|
||||
assert(flattener.operandExprStack.size() == exprs.size());
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/UB/IR/UBOps.h"
|
||||
#include "mlir/IR/AffineExprVisitor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
@ -226,6 +227,8 @@ void AffineDialect::initialize() {
|
||||
Operation *AffineDialect::materializeConstant(OpBuilder &builder,
|
||||
Attribute value, Type type,
|
||||
Location loc) {
|
||||
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
|
||||
return builder.create<ub::PoisonOp>(loc, type, poison);
|
||||
return arith::ConstantOp::materialize(builder, value, type, loc);
|
||||
}
|
||||
|
||||
@ -580,7 +583,12 @@ OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
|
||||
|
||||
// Otherwise, default to folding the map.
|
||||
SmallVector<Attribute, 1> result;
|
||||
if (failed(map.constantFold(adaptor.getMapOperands(), result)))
|
||||
bool hasPoison = false;
|
||||
auto foldResult =
|
||||
map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
|
||||
if (hasPoison)
|
||||
return ub::PoisonAttr::get(getContext());
|
||||
if (failed(foldResult))
|
||||
return {};
|
||||
return result[0];
|
||||
}
|
||||
@ -3379,7 +3387,9 @@ static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
|
||||
return failure();
|
||||
|
||||
SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
|
||||
flattener.walkPostOrder(resultExpr);
|
||||
auto flattenResult = flattener.walkPostOrder(resultExpr);
|
||||
if (failed(flattenResult))
|
||||
return failure();
|
||||
|
||||
// Fail if the flattened expression has local variables.
|
||||
if (flattener.operandExprStack.back().size() !=
|
||||
|
@ -19,5 +19,6 @@ add_mlir_dialect_library(MLIRAffineDialect
|
||||
MLIRMemRefDialect
|
||||
MLIRShapedOpInterfaces
|
||||
MLIRSideEffectInterfaces
|
||||
MLIRUBDialect
|
||||
MLIRValueBoundsOpInterface
|
||||
)
|
||||
|
@ -1216,7 +1216,7 @@ SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
|
||||
// In case of semi affine multiplication expressions, t = expr * symbolic_expr,
|
||||
// introduce a local variable p (= expr * symbolic_expr), and the affine
|
||||
// expression expr * symbolic_expr is added to `localExprs`.
|
||||
void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
|
||||
LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
|
||||
assert(operandExprStack.size() >= 2);
|
||||
SmallVector<int64_t, 8> rhs = operandExprStack.back();
|
||||
operandExprStack.pop_back();
|
||||
@ -1232,7 +1232,7 @@ void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
|
||||
AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
|
||||
localExprs, context);
|
||||
addLocalVariableSemiAffine(a * b, lhs, lhs.size());
|
||||
return;
|
||||
return success();
|
||||
}
|
||||
|
||||
// Get the RHS constant.
|
||||
@ -1240,9 +1240,10 @@ void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
|
||||
for (unsigned i = 0, e = lhs.size(); i < e; i++) {
|
||||
lhs[i] *= rhsConst;
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
|
||||
LogicalResult SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
|
||||
assert(operandExprStack.size() >= 2);
|
||||
const auto &rhs = operandExprStack.back();
|
||||
auto &lhs = operandExprStack[operandExprStack.size() - 2];
|
||||
@ -1253,6 +1254,7 @@ void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
|
||||
}
|
||||
// Pop off the RHS.
|
||||
operandExprStack.pop_back();
|
||||
return success();
|
||||
}
|
||||
|
||||
//
|
||||
@ -1265,7 +1267,7 @@ void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
|
||||
// In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
|
||||
// introduce a local variable m (= expr mod symbolic_expr), and the affine
|
||||
// expression expr mod symbolic_expr is added to `localExprs`.
|
||||
void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
|
||||
LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
|
||||
assert(operandExprStack.size() >= 2);
|
||||
|
||||
SmallVector<int64_t, 8> rhs = operandExprStack.back();
|
||||
@ -1283,13 +1285,12 @@ void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
|
||||
localExprs, context);
|
||||
AffineExpr modExpr = dividendExpr % divisorExpr;
|
||||
addLocalVariableSemiAffine(modExpr, lhs, lhs.size());
|
||||
return;
|
||||
return success();
|
||||
}
|
||||
|
||||
int64_t rhsConst = rhs[getConstantIndex()];
|
||||
// TODO: handle modulo by zero case when this issue is fixed
|
||||
// at the other places in the IR.
|
||||
assert(rhsConst > 0 && "RHS constant has to be positive");
|
||||
if (rhsConst <= 0)
|
||||
return failure();
|
||||
|
||||
// Check if the LHS expression is a multiple of modulo factor.
|
||||
unsigned i, e;
|
||||
@ -1299,7 +1300,7 @@ void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
|
||||
// If yes, modulo expression here simplifies to zero.
|
||||
if (i == lhs.size()) {
|
||||
std::fill(lhs.begin(), lhs.end(), 0);
|
||||
return;
|
||||
return success();
|
||||
}
|
||||
|
||||
// Add a local variable for the quotient, i.e., expr % c is replaced by
|
||||
@ -1331,33 +1332,41 @@ void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
|
||||
// Reuse the existing local id.
|
||||
lhs[getLocalVarStartIndex() + loc] = -rhsConst;
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
void SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
|
||||
visitDivExpr(expr, /*isCeil=*/true);
|
||||
LogicalResult
|
||||
SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
|
||||
return visitDivExpr(expr, /*isCeil=*/true);
|
||||
}
|
||||
void SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
|
||||
visitDivExpr(expr, /*isCeil=*/false);
|
||||
LogicalResult
|
||||
SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
|
||||
return visitDivExpr(expr, /*isCeil=*/false);
|
||||
}
|
||||
|
||||
void SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
|
||||
LogicalResult SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
|
||||
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
|
||||
auto &eq = operandExprStack.back();
|
||||
assert(expr.getPosition() < numDims && "Inconsistent number of dims");
|
||||
eq[getDimStartIndex() + expr.getPosition()] = 1;
|
||||
return success();
|
||||
}
|
||||
|
||||
void SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
|
||||
LogicalResult
|
||||
SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
|
||||
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
|
||||
auto &eq = operandExprStack.back();
|
||||
assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
|
||||
eq[getSymbolStartIndex() + expr.getPosition()] = 1;
|
||||
return success();
|
||||
}
|
||||
|
||||
void SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
|
||||
LogicalResult
|
||||
SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
|
||||
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
|
||||
auto &eq = operandExprStack.back();
|
||||
eq[getConstantIndex()] = expr.getValue();
|
||||
return success();
|
||||
}
|
||||
|
||||
void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
|
||||
@ -1388,8 +1397,8 @@ void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
|
||||
// or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
|
||||
// floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
|
||||
// `localExprs`.
|
||||
void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
|
||||
bool isCeil) {
|
||||
LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
|
||||
bool isCeil) {
|
||||
assert(operandExprStack.size() >= 2);
|
||||
|
||||
MLIRContext *context = expr.getContext();
|
||||
@ -1407,14 +1416,13 @@ void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
|
||||
localExprs, context);
|
||||
AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
|
||||
addLocalVariableSemiAffine(divExpr, lhs, lhs.size());
|
||||
return;
|
||||
return success();
|
||||
}
|
||||
|
||||
// This is a pure affine expr; the RHS is a positive constant.
|
||||
int64_t rhsConst = rhs[getConstantIndex()];
|
||||
// TODO: handle division by zero at the same time the issue is
|
||||
// fixed at other places.
|
||||
assert(rhsConst > 0 && "RHS constant has to be positive");
|
||||
if (rhsConst <= 0)
|
||||
return failure();
|
||||
|
||||
// Simplify the floordiv, ceildiv if possible by canceling out the greatest
|
||||
// common divisors of the numerator and denominator.
|
||||
@ -1430,7 +1438,7 @@ void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
|
||||
// If the divisor becomes 1, the updated LHS is the result. (The
|
||||
// divisor can't be negative since rhsConst is positive).
|
||||
if (divisor == 1)
|
||||
return;
|
||||
return success();
|
||||
|
||||
// If the divisor cannot be simplified to one, we will have to retain
|
||||
// the ceil/floor expr (simplified up until here). Add an existential
|
||||
@ -1460,6 +1468,7 @@ void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
|
||||
lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
|
||||
else
|
||||
lhs[getLocalVarStartIndex() + loc] = 1;
|
||||
return success();
|
||||
}
|
||||
|
||||
// Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
|
||||
@ -1500,7 +1509,9 @@ AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
|
||||
expr = simplifySemiAffine(expr, numDims, numSymbols);
|
||||
|
||||
SimpleAffineExprFlattener flattener(numDims, numSymbols);
|
||||
flattener.walkPostOrder(expr);
|
||||
// has poison expression
|
||||
if (failed(flattener.walkPostOrder(expr)))
|
||||
return expr;
|
||||
ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
|
||||
if (!expr.isPureAffine() &&
|
||||
expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
|
||||
@ -1573,7 +1584,10 @@ std::optional<int64_t> mlir::getBoundForAffineExpr(
|
||||
}
|
||||
// Flatten the expression.
|
||||
SimpleAffineExprFlattener flattener(numDims, numSymbols);
|
||||
flattener.walkPostOrder(expr);
|
||||
auto simpleResult = flattener.walkPostOrder(expr);
|
||||
// has poison expression
|
||||
if (failed(simpleResult))
|
||||
return std::nullopt;
|
||||
ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
|
||||
// TODO: Handle local variables. We can get hold of flattener.localExprs and
|
||||
// get bound on the local expr recursively.
|
||||
|
@ -8,6 +8,7 @@
|
||||
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "AffineMapDetail.h"
|
||||
#include "mlir/Dialect/UB/IR/UBOps.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
@ -59,13 +60,34 @@ private:
|
||||
expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; });
|
||||
case AffineExprKind::Mod:
|
||||
return constantFoldBinExpr(
|
||||
expr, [](int64_t lhs, int64_t rhs) { return mod(lhs, rhs); });
|
||||
expr,
|
||||
[expr, this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
|
||||
if (rhs < 1) {
|
||||
hasPoison_ = true;
|
||||
return std::nullopt;
|
||||
}
|
||||
return mod(lhs, rhs);
|
||||
});
|
||||
case AffineExprKind::FloorDiv:
|
||||
return constantFoldBinExpr(
|
||||
expr, [](int64_t lhs, int64_t rhs) { return floorDiv(lhs, rhs); });
|
||||
expr,
|
||||
[expr, this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
|
||||
if (rhs == 0) {
|
||||
hasPoison_ = true;
|
||||
return std::nullopt;
|
||||
}
|
||||
return floorDiv(lhs, rhs);
|
||||
});
|
||||
case AffineExprKind::CeilDiv:
|
||||
return constantFoldBinExpr(
|
||||
expr, [](int64_t lhs, int64_t rhs) { return ceilDiv(lhs, rhs); });
|
||||
expr,
|
||||
[expr, this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
|
||||
if (rhs == 0) {
|
||||
hasPoison_ = true;
|
||||
return std::nullopt;
|
||||
}
|
||||
return ceilDiv(lhs, rhs);
|
||||
});
|
||||
case AffineExprKind::Constant:
|
||||
return cast<AffineConstantExpr>(expr).getValue();
|
||||
case AffineExprKind::DimId:
|
||||
@ -387,12 +409,12 @@ std::optional<unsigned> AffineMap::getResultPosition(AffineExpr input) const {
|
||||
/// Folds the results of the application of an affine map on the provided
|
||||
/// operands to a constant if possible. Returns false if the folding happens,
|
||||
/// true otherwise.
|
||||
LogicalResult
|
||||
AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
|
||||
SmallVectorImpl<Attribute> &results) const {
|
||||
LogicalResult AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
|
||||
SmallVectorImpl<Attribute> &results,
|
||||
bool *hasPoison) const {
|
||||
// Attempt partial folding.
|
||||
SmallVector<int64_t, 2> integers;
|
||||
partialConstantFold(operandConstants, &integers);
|
||||
partialConstantFold(operandConstants, &integers, hasPoison);
|
||||
|
||||
// If all expressions folded to a constant, populate results with attributes
|
||||
// containing those constants.
|
||||
@ -406,9 +428,9 @@ AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
|
||||
return success();
|
||||
}
|
||||
|
||||
AffineMap
|
||||
AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants,
|
||||
SmallVectorImpl<int64_t> *results) const {
|
||||
AffineMap AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants,
|
||||
SmallVectorImpl<int64_t> *results,
|
||||
bool *hasPoison) const {
|
||||
assert(getNumInputs() == operandConstants.size());
|
||||
|
||||
// Fold each of the result expressions.
|
||||
@ -418,6 +440,10 @@ AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants,
|
||||
|
||||
for (auto expr : getResults()) {
|
||||
auto folded = exprFolder.constantFold(expr);
|
||||
if (exprFolder.hasPoison() && hasPoison) {
|
||||
*hasPoison = true;
|
||||
return {};
|
||||
}
|
||||
// If did not fold to a constant, keep the original expression, and clear
|
||||
// the integer results vector.
|
||||
if (folded) {
|
||||
|
@ -60,3 +60,24 @@ func.func @affine_min(%variable: index) -> (index, index) {
|
||||
// CHECK: return %[[r]], %[[C44]]
|
||||
return %0, %1 : index, index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @affine_apply_poison_division_zero() {
|
||||
// This is just for mlir::context to load ub dailect
|
||||
%ub = ub.poison : index
|
||||
%c16 = arith.constant 16 : index
|
||||
%0 = affine.apply affine_map<(d0)[s0] -> (d0 mod (s0 - s0))>(%c16)[%c16]
|
||||
%1 = affine.apply affine_map<(d0)[s0] -> (d0 floordiv (s0 - s0))>(%c16)[%c16]
|
||||
%2 = affine.apply affine_map<(d0)[s0] -> (d0 ceildiv (s0 - s0))>(%c16)[%c16]
|
||||
%alloc = memref.alloc(%0, %1, %2) : memref<?x?x?xi1>
|
||||
%3 = affine.load %alloc[%0, %1, %2] : memref<?x?x?xi1>
|
||||
affine.store %3, %alloc[%0, %1, %2] : memref<?x?x?xi1>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-NOT: affine.apply
|
||||
// CHECK: %[[poison:.*]] = ub.poison : index
|
||||
// CHECK-NEXT: %[[alloc:.*]] = memref.alloc(%[[poison]], %[[poison]], %[[poison]])
|
||||
// CHECK-NEXT: %[[load:.*]] = affine.load %[[alloc]][%[[poison]], %[[poison]], %[[poison]]] : memref<?x?x?xi1>
|
||||
// CHECK-NEXT: affine.store %[[load]], %alloc[%[[poison]], %[[poison]], %[[poison]]] : memref<?x?x?xi1>
|
||||
|
Loading…
Reference in New Issue
Block a user