[MLIR] AffineExpr lightweight value type for operators

This CL proposes adding MLIRContext* to AffineExpr as discussed previously.
This allows the value class to not require the context in its constructor and
makes it a POD that it makes sense to pass by value everywhere.
A list of other RFC CLs will build on this. The RFC CLs are small incremental
pushes of the API which would be a pretty big change otherwise.

Pushing the thinking a little bit more it seems reasonable to use implicit
cast/constructor to/from AffineExpr*.
As this thing evolves, it looks to me like IR (and
probably Parser, for not so good reasons) want to operate on AffineExpr* and
the rest of the code wants to operate on the value type.

For this reason I think AffineExprImpl*/AffineExpr may also make sense but I
do not have a particular naming preference.
The jury is still out for naming decision between the above and
AffineExprBase*/AffineExpr or AffineExpr*/AffineExprRef.

PiperOrigin-RevId: 215641596
This commit is contained in:
Nicolas Vasilache 2018-10-03 15:36:53 -07:00 committed by jpienaar
parent 4805e629c5
commit 9ef87c4b6b
5 changed files with 104 additions and 125 deletions

View File

@ -76,8 +76,11 @@ public:
/// Return true if the affine expression is a multiple of 'factor'.
bool isMultipleOf(int64_t factor) const;
MLIRContext *getContext() const;
protected:
explicit AffineExpr(Kind kind) : kind(kind) {}
explicit AffineExpr(Kind kind, MLIRContext *context)
: kind(kind), context(context) {}
~AffineExpr() {}
private:
@ -86,6 +89,7 @@ private:
/// Classification of the subclass
const Kind kind;
MLIRContext *context;
};
inline raw_ostream &operator<<(raw_ostream &os, const AffineExpr &expr) {
@ -93,6 +97,37 @@ inline raw_ostream &operator<<(raw_ostream &os, const AffineExpr &expr) {
return os;
}
// Helper structure to build AffineExpr with intuitive operators in order to
// operate on chainable, lightweight value types instead of pointer types.
struct AffineExprWrap {
/* implicit */ AffineExprWrap(mlir::AffineExpr *expr) : expr(expr) {}
AffineExprWrap(const AffineExprWrap &other) : expr(other.expr){};
AffineExprWrap &operator=(AffineExprWrap other) {
expr = other.expr;
return *this;
};
/* implicit */ operator mlir::AffineExpr *() { return expr; }
bool operator!() { return expr == nullptr; }
AffineExprWrap operator+(int64_t v) const;
AffineExprWrap operator+(AffineExprWrap other) const;
AffineExprWrap operator-() const;
AffineExprWrap operator-(int64_t v) const;
AffineExprWrap operator-(AffineExprWrap other) const;
AffineExprWrap operator*(int64_t v) const;
AffineExprWrap operator*(AffineExprWrap other) const;
AffineExprWrap floorDiv(uint64_t v) const;
AffineExprWrap floorDiv(AffineExprWrap other) const;
AffineExprWrap ceilDiv(uint64_t v) const;
AffineExprWrap ceilDiv(AffineExprWrap other) const;
AffineExprWrap operator%(uint64_t v) const;
AffineExprWrap operator%(AffineExprWrap other) const;
AffineExpr *expr;
};
/// Affine binary operation expression. An affine binary operation could be an
/// add, mul, floordiv, ceildiv, or a modulo operation. (Subtraction is
/// represented through a multiply by -1 and add.) These expressions are always
@ -146,7 +181,8 @@ public:
}
protected:
explicit AffineBinaryOpExpr(Kind kind, AffineExpr *lhs, AffineExpr *rhs);
explicit AffineBinaryOpExpr(Kind kind, AffineExpr *lhs, AffineExpr *rhs,
MLIRContext *context);
AffineExpr *const lhs;
AffineExpr *const rhs;
@ -184,8 +220,8 @@ public:
private:
~AffineDimExpr() = delete;
explicit AffineDimExpr(unsigned position)
: AffineExpr(Kind::DimId), position(position) {}
explicit AffineDimExpr(unsigned position, MLIRContext *context)
: AffineExpr(Kind::DimId, context), position(position) {}
/// Position of this identifier in the argument list.
unsigned position;
@ -209,8 +245,8 @@ public:
private:
~AffineSymbolExpr() = delete;
explicit AffineSymbolExpr(unsigned position)
: AffineExpr(Kind::SymbolId), position(position) {}
explicit AffineSymbolExpr(unsigned position, MLIRContext *context)
: AffineExpr(Kind::SymbolId, context), position(position) {}
/// Position of this identifier in the symbol list.
unsigned position;
@ -230,111 +266,13 @@ public:
private:
~AffineConstantExpr() = delete;
explicit AffineConstantExpr(int64_t constant)
: AffineExpr(Kind::Constant), constant(constant) {}
explicit AffineConstantExpr(int64_t constant, MLIRContext *context)
: AffineExpr(Kind::Constant, context), constant(constant) {}
// The constant.
int64_t constant;
};
// Helper structure to build AffineExpr with intuitive operators instead of all
// the IR boilerplate. To do this we need to operate on chainable, lightweight
// value types instead of pointer types.
// The more general proposal is that builders directly return such value type
// objects for the cases where we want composition with operators.
// Once these things are available, matchers and simplifiers can be written much
// more nicely.
// The base version of an operator is alway AffineExprWrap op AffineExpr.
// The other versions reuse that base version.
struct AffineExprWrap {
AffineExprWrap(int64_t v, MLIRContext *c)
: e(AffineConstantExpr::get(v, c)), context(c) {}
AffineExprWrap(mlir::AffineExpr *expr, MLIRContext *c)
: e(expr), context(c) {}
/* implicit */ operator mlir::AffineExpr *() { return e; }
bool operator!() { return e == nullptr; }
// Base version for operator+
inline AffineExprWrap operator+(mlir::AffineExpr *expr) const {
return AffineExprWrap(AffineBinaryOpExpr::getAdd(e, expr, context),
context);
}
inline AffineExprWrap operator+(int64_t v) const {
return AffineExprWrap(AffineBinaryOpExpr::getAdd(e, v, context), context);
}
inline AffineExprWrap operator+(const AffineExprWrap &other) const {
return *this + other.e;
}
// Unary minus, delegate to operator*
inline AffineExprWrap operator-() const { return *this * (-1); }
// Base version for operator-, delegate to operator+
inline AffineExprWrap operator-(mlir::AffineExpr *expr) const {
return *this + (-AffineExprWrap(expr, context));
}
inline AffineExprWrap operator-(int64_t v) const { return *this + (-v); }
inline AffineExprWrap operator-(const AffineExprWrap &other) const {
return *this - other.e;
}
// Base version for operator*
inline AffineExprWrap operator*(mlir::AffineExpr *expr) const {
return AffineExprWrap(AffineBinaryOpExpr::getMul(e, expr, context),
context);
}
inline AffineExprWrap operator*(int64_t v) const {
return AffineExprWrap(AffineBinaryOpExpr::getMul(e, v, context), context);
}
inline AffineExprWrap operator*(const AffineExprWrap &other) const {
return *this * other.e;
}
// Base version for floorDiv
inline AffineExprWrap floorDiv(AffineExpr *expr) const {
return AffineExprWrap(AffineBinaryOpExpr::getFloorDiv(e, expr, context),
context);
}
inline AffineExprWrap floorDiv(uint64_t v) const {
return AffineExprWrap(AffineBinaryOpExpr::getFloorDiv(e, v, context),
context);
}
inline AffineExprWrap floorDiv(const AffineExprWrap &other) const {
return this->floorDiv(other.e);
}
// Base version for ceilDiv
inline AffineExprWrap ceilDiv(AffineExpr *expr) const {
return AffineExprWrap(AffineBinaryOpExpr::getCeilDiv(e, expr, context),
context);
}
inline AffineExprWrap ceilDiv(uint64_t v) const {
return AffineExprWrap(AffineBinaryOpExpr::getCeilDiv(e, v, context),
context);
}
inline AffineExprWrap ceilDiv(const AffineExprWrap &other) const {
return this->ceilDiv(other.e);
}
// Base version for operator%
inline AffineExprWrap operator%(mlir::AffineExpr *expr) const {
return AffineExprWrap(AffineBinaryOpExpr::getMod(e, expr, context),
context);
}
inline AffineExprWrap operator%(uint64_t v) const {
return AffineExprWrap(AffineBinaryOpExpr::getMod(e, v, context), context);
}
inline AffineExprWrap operator%(const AffineExprWrap &other) const {
return *this % other.e;
}
AffineExpr *e;
MLIRContext *context;
};
} // end namespace mlir
#endif // MLIR_IR_AFFINE_EXPR_H

View File

@ -17,13 +17,13 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/Support/STLExtras.h"
#include "third_party/llvm/llvm/include/llvm/ADT/STLExtras.h"
#include "llvm/ADT/STLExtras.h"
using namespace mlir;
AffineBinaryOpExpr::AffineBinaryOpExpr(Kind kind, AffineExpr *lhs,
AffineExpr *rhs)
: AffineExpr(kind), lhs(lhs), rhs(rhs) {
AffineExpr *rhs, MLIRContext *context)
: AffineExpr(kind, context), lhs(lhs), rhs(rhs) {
// We verify affine op expr forms at construction time.
switch (kind) {
case Kind::Add:
@ -193,3 +193,47 @@ bool AffineExpr::isMultipleOf(int64_t factor) const {
0;
}
}
MLIRContext *AffineExpr::getContext() const { return context; }
AffineExprWrap AffineExprWrap::operator+(int64_t v) const {
return AffineBinaryOpExpr::getAdd(expr, v, expr->getContext());
}
AffineExprWrap AffineExprWrap::operator+(AffineExprWrap other) const {
return AffineBinaryOpExpr::getAdd(expr, other.expr, expr->getContext());
}
// Unary minus, delegate to operator*.
AffineExprWrap AffineExprWrap::operator-() const { return *this * (-1); }
// Delegate to operator+.
AffineExprWrap AffineExprWrap::operator-(int64_t v) const {
return *this + (-v);
}
AffineExprWrap AffineExprWrap::operator-(AffineExprWrap other) const {
return *this + (-other);
}
AffineExprWrap AffineExprWrap::operator*(int64_t v) const {
return AffineBinaryOpExpr::getMul(expr, v, expr->getContext());
}
AffineExprWrap AffineExprWrap::operator*(AffineExprWrap other) const {
return AffineBinaryOpExpr::getMul(expr, other.expr, expr->getContext());
}
AffineExprWrap AffineExprWrap::floorDiv(uint64_t v) const {
return AffineBinaryOpExpr::getFloorDiv(expr, v, expr->getContext());
}
AffineExprWrap AffineExprWrap::floorDiv(AffineExprWrap other) const {
return AffineBinaryOpExpr::getFloorDiv(expr, other.expr, expr->getContext());
}
AffineExprWrap AffineExprWrap::ceilDiv(uint64_t v) const {
return AffineBinaryOpExpr::getCeilDiv(expr, v, expr->getContext());
}
AffineExprWrap AffineExprWrap::ceilDiv(AffineExprWrap other) const {
return AffineBinaryOpExpr::getCeilDiv(expr, other.expr, expr->getContext());
}
AffineExprWrap AffineExprWrap::operator%(uint64_t v) const {
return AffineBinaryOpExpr::getMod(expr, v, expr->getContext());
}
AffineExprWrap AffineExprWrap::operator%(AffineExprWrap other) const {
return AffineBinaryOpExpr::getMod(expr, other.expr, expr->getContext());
}

View File

@ -246,7 +246,7 @@ AffineMap *Builder::getSymbolIdentityMap() {
AffineMap *Builder::getSingleDimShiftAffineMap(int64_t shift) {
// expr = d0 + shift.
auto expr = AffineExprWrap(getDimExpr(0), context) + shift;
auto expr = AffineExprWrap(getDimExpr(0)) + shift;
return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {expr}, {}, context);
}

View File

@ -278,8 +278,7 @@ public:
/// Copy the specified array of elements into memory managed by our bump
/// pointer allocator. This assumes the elements are all PODs.
template <typename T>
ArrayRef<T> copyInto(ArrayRef<T> elements) {
template <typename T> ArrayRef<T> copyInto(ArrayRef<T> elements) {
auto result = allocator.Allocate<T>(elements.size());
std::uninitialized_copy(elements.begin(), elements.end(), result);
return ArrayRef<T>(result, elements.size());
@ -879,7 +878,7 @@ AffineExpr *AffineBinaryOpExpr::get(AffineExpr::Kind kind, AffineExpr *lhs,
// simplified/canonical form. Create and store it.
auto *result = impl.allocator.Allocate<AffineBinaryOpExpr>();
// Initialize the memory using placement new.
new (result) AffineBinaryOpExpr(kind, lhs, rhs);
new (result) AffineBinaryOpExpr(kind, lhs, rhs, context);
bool inserted = impl.affineExprs.insert({keyValue, result}).second;
assert(inserted && "the expression shouldn't already exist in the map");
(void)inserted;
@ -899,7 +898,7 @@ AffineDimExpr *AffineDimExpr::get(unsigned position, MLIRContext *context) {
result = impl.allocator.Allocate<AffineDimExpr>();
// Initialize the memory using placement new.
new (result) AffineDimExpr(position);
new (result) AffineDimExpr(position, context);
return result;
}
@ -917,7 +916,7 @@ AffineSymbolExpr *AffineSymbolExpr::get(unsigned position,
result = impl.allocator.Allocate<AffineSymbolExpr>();
// Initialize the memory using placement new.
new (result) AffineSymbolExpr(position);
new (result) AffineSymbolExpr(position, context);
return result;
}
@ -931,7 +930,7 @@ AffineConstantExpr *AffineConstantExpr::get(int64_t constant,
result = impl.allocator.Allocate<AffineConstantExpr>();
// Initialize the memory using placement new.
new (result) AffineConstantExpr(constant);
new (result) AffineConstantExpr(constant, context);
return result;
}

View File

@ -45,14 +45,13 @@ AffineMap *mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
return nullptr;
// Sometimes, the trip count cannot be expressed as an affine expression.
auto tripCount =
AffineExprWrap(getTripCountExpr(forStmt), builder->getContext());
AffineExprWrap tripCount(getTripCountExpr(forStmt));
if (!tripCount)
return nullptr;
auto lb = AffineExprWrap(lbMap->getResult(0), builder->getContext());
auto step = AffineExprWrap(forStmt.getStep(), builder->getContext());
auto newUb = lb + step * (tripCount - tripCount % unrollFactor - 1);
AffineExprWrap lb(lbMap->getResult(0));
unsigned step = forStmt.getStep();
auto newUb = lb + (tripCount - tripCount % unrollFactor - 1) * step;
return builder->getAffineMap(lbMap->getNumDims(), lbMap->getNumSymbols(),
{newUb}, {});
@ -72,13 +71,12 @@ AffineMap *mlir::getCleanupLoopLowerBound(const ForStmt &forStmt,
return nullptr;
// Sometimes the trip count cannot be expressed as an affine expression.
auto tripCount =
AffineExprWrap(getTripCountExpr(forStmt), builder->getContext());
AffineExprWrap tripCount(getTripCountExpr(forStmt));
if (!tripCount)
return nullptr;
auto lb = AffineExprWrap(lbMap->getResult(0), builder->getContext());
auto step = AffineExprWrap(forStmt.getStep(), builder->getContext());
AffineExprWrap lb(lbMap->getResult(0));
unsigned step = forStmt.getStep();
auto newLb = lb + (tripCount - tripCount % unrollFactor) * step;
return builder->getAffineMap(lbMap->getNumDims(), lbMap->getNumSymbols(),
{newLb}, {});