[MLIR] [TOSA]: Move reshape(reshape(x)) -> reshape(x) from canonicalization to fold

reshape(reshape(x)) -> reshape(x) can be expressed directly as a fold instead of a canonicalization pattern,
which lets other passes clean up redundant reshapes while they work.
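
As a rough illustration (the shapes and the generic-form IR below are assumptions for this note, not taken from the patch), the fold rewrites the outer reshape in place so that it reads from the original operand:

  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6, 4>} : (tensor<2x3x4xf32>) -> tensor<6x4xf32>
  %1 = "tosa.reshape"(%0) {new_shape = array<i64: 24>} : (tensor<6x4xf32>) -> tensor<24xf32>

  // After folding, %1 reshapes %arg0 directly; %0 becomes dead if it has no other users.
  %1 = "tosa.reshape"(%arg0) {new_shape = array<i64: 24>} : (tensor<2x3x4xf32>) -> tensor<24xf32>

Because it is a fold, other passes pick this up automatically whenever they fold operations, without having to schedule the TOSA canonicalization patterns.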

This initially broke ReshapeConverterExpand/Collapse, which relied on creating foldable intermediate reshapes and a carefully crafted
benefit priority among the patterns.
I turned this into a single pattern on reshape that performs the collapse and/or expand as needed in one go.
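
For illustration (shapes assumed, not from the patch's tests): a reshape with no single reassociation, e.g. tensor<2x3x4xf32> to tensor<4x6xf32>, now lowers in one step to a collapse into an intermediate shape followed by an expand:

  // tosa.reshape: tensor<2x3x4xf32> -> tensor<4x6xf32>, via the intermediate tensor<24xf32>
  %0 = tensor.collapse_shape %arg0 [[0, 1, 2]] : tensor<2x3x4xf32> into tensor<24xf32>
  %1 = tensor.expand_shape %0 [[0, 1]] : tensor<24xf32> into tensor<4x6xf32>

The sketch assumes the fully collapsed shape is chosen as the intermediate; the actual intermediate shape comes from findIntermediateShape.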

Differential Revision: https://reviews.llvm.org/D155266
Author: Matthias Gehre
Date:   2023-07-13 08:53:47 +02:00
Parent: fdf36c3d4b
Commit: 0ebb050311
3 changed files with 99 additions and 127 deletions

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

@@ -1480,7 +1480,6 @@ def Tosa_ReshapeOp: Tosa_Op<"reshape", [
     No data conversion happens during a reshape operation.
   }];
-  let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasVerifier = 1;

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

@@ -129,81 +129,74 @@ static bool createReassociationMapsForCollapse(
 }
 namespace {
-class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> {
-public:
-  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
+Value createCollapse(ConversionPatternRewriter &rewriter, Location loc,
+                     ShapedType resultTy, Value operand) {
+  ShapedType operandTy = cast<ShapedType>(operand.getType());
+  if (resultTy == operandTy)
+    return operand;
-  LogicalResult
-  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
-    ShapedType resultTy = cast<ShapedType>(reshape.getType());
-    bool isDynamic = !operandTy.hasStaticShape();
+  bool isDynamic = !operandTy.hasStaticShape();
-    if (isDynamic && resultTy.getRank() != 1) {
-      return rewriter.notifyMatchFailure(
-          reshape, "Cannot collapse dynamic dims to more than one dimension");
-    }
-    SmallVector<ReassociationExprs, 4> reassociationMap;
-    if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
-                                            resultTy.getShape(),
-                                            reassociationMap, isDynamic)) {
-      return rewriter.notifyMatchFailure(
-          reshape,
-          "tosa.reshape Attempting to collapse into an incompatible shape");
-    }
-    SmallVector<int64_t> intermediateShape;
-    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
-                               intermediateShape, isDynamic)) {
-      return rewriter.notifyMatchFailure(
-          reshape, "tosa.reshape Cannot collapse into given shape");
-    }
-    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
-        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
-    return success();
+  if (isDynamic && resultTy.getRank() != 1) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "Cannot collapse dynamic dims to more than one dimension");
+    return {};
+  }
-};
-class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> {
-public:
-  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
-    ShapedType resultTy = cast<ShapedType>(reshape.getType());
-    bool isDynamic = !operandTy.hasStaticShape();
-    if (isDynamic && operandTy.getRank() != 1) {
-      return rewriter.notifyMatchFailure(
-          reshape, "Cannot expand dynamic dims from more than one dimension");
-    }
-    SmallVector<ReassociationExprs, 4> reassociationMap;
-    if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
-                                            operandTy.getShape(),
-                                            reassociationMap, isDynamic)) {
-      return rewriter.notifyMatchFailure(
-          reshape,
-          "tosa.reshape Attempting to expand into an incompatible shape");
-    }
-    SmallVector<int64_t> intermediateShape;
-    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
-                               intermediateShape, isDynamic) ||
-        intermediateShape != operandTy.getShape()) {
-      return rewriter.notifyMatchFailure(
-          reshape, "tosa.reshape Cannot expand into given shape");
-    }
-    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
-    return success();
+  SmallVector<ReassociationExprs, 4> reassociationMap;
+  if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
+                                          resultTy.getShape(),
+                                          reassociationMap, isDynamic)) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "tosa.reshape Attempting to collapse into an incompatible shape");
+    return {};
+  }
-};
+  SmallVector<int64_t> intermediateShape;
+  if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
+                             intermediateShape, isDynamic)) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "tosa.reshape Cannot collapse into given shape");
+    return {};
+  }
+  return rewriter.create<tensor::CollapseShapeOp>(loc, resultTy, operand,
+                                                  reassociationMap);
+}
+Value createExpand(ConversionPatternRewriter &rewriter, Location loc,
+                   ShapedType resultTy, Value operand) {
+  ShapedType operandTy = cast<ShapedType>(operand.getType());
+  if (resultTy == operandTy)
+    return operand;
+  bool isDynamic = !operandTy.hasStaticShape();
+  if (isDynamic && operandTy.getRank() != 1) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "Cannot expand dynamic dims from more than one dimension");
+    return {};
+  }
+  SmallVector<ReassociationExprs, 4> reassociationMap;
+  if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
+                                          operandTy.getShape(),
+                                          reassociationMap, isDynamic)) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "tosa.reshape Attempting to expand into an incompatible shape");
+    return {};
+  }
+  SmallVector<int64_t> intermediateShape;
+  if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
+                             intermediateShape, isDynamic) ||
+      intermediateShape != operandTy.getShape()) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "tosa.reshape Cannot expand into given shape");
+    return {};
+  }
+  return rewriter.create<tensor::ExpandShapeOp>(loc, resultTy, operand,
+                                                reassociationMap);
+}
 class ReshapeConverterCollapseExpand
     : public OpConversionPattern<tosa::ReshapeOp> {
@@ -224,17 +217,19 @@ public:
           reshape, "tosa.reshape Cannot identify an intermediate shape between "
                    "the given two shapes");
     }
+    auto intermediateTy = RankedTensorType::get(
+        intermediateShape, reshape.getType().getElementType());
+    Value collapse = createCollapse(rewriter, reshape.getLoc(), intermediateTy,
+                                    adaptor.getInput1());
+    if (!collapse)
+      return failure();
+    Value expand = createExpand(rewriter, reshape.getLoc(), resultTy, collapse);
+    if (!expand)
+      return failure();
-    Value collapse = rewriter.create<tosa::ReshapeOp>(
-        reshape.getLoc(),
-        RankedTensorType::get(intermediateShape,
-                              reshape.getType().getElementType()),
-        adaptor.getInput1(), rewriter.getDenseI64ArrayAttr(intermediateShape));
-    Value expand = rewriter.create<tosa::ReshapeOp>(
-        reshape.getLoc(), resultTy, collapse,
-        rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
     rewriter.replaceOp(reshape, expand);
     return success();
   }
 };
@@ -420,10 +415,6 @@ void mlir::tosa::populateTosaToTensorConversionPatterns(
     RewritePatternSet *patterns) {
   patterns->add<SliceConverter, PadConverter, ConcatConverter>(
       patterns->getContext());
-  patterns->add<ReshapeConverterCollapse>(patterns->getContext(),
-                                          /*benefit=*/100);
-  patterns->add<ReshapeConverterExpand>(patterns->getContext(),
-                                        /*benefit=*/200);
-  patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext(),
-                                                /*benefit=*/300);
+  patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext());
 }

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

@@ -62,31 +62,6 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ConcatOptimization>(context);
 }
-struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
-  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.getInput1();
-    Operation *definingOp = input.getDefiningOp();
-    if (!definingOp)
-      return failure();
-    if (tosa::ReshapeOp reshapeOp = dyn_cast<tosa::ReshapeOp>(definingOp)) {
-      rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
-          op, op.getType(), reshapeOp.getInput1(), op.getNewShape());
-      return success();
-    }
-    return failure();
-  }
-};
-void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                            MLIRContext *context) {
-  results.add<ReshapeReshapeOptimization>(context);
-}
 LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
   auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
   if (!notOp)
@@ -820,25 +795,32 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
   if (inputTy == outputTy)
     return getInput1();
-  // Constants must have static shape.
-  if (!outputTy.hasStaticShape())
-    return {};
-  auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
-  if (!operand)
-    return {};
-  // Okay to duplicate splat constants.
-  if (operand.isSplat()) {
-    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
+  // reshape(reshape(x)) -> reshape(x)
+  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
+          getInput1().getDefiningOp())) {
+    getInput1Mutable().assign(reshapeOp.getInput1());
+    return getResult();
   }
-  // Don't duplicate other constants.
-  if (!getInput1().hasOneUse())
-    return {};
+  // reshape(const(x)) -> const(reshape-attr(x))
+  if (auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
+    // Constants must have static shape.
+    if (!outputTy.hasStaticShape())
+      return {};
-  return operand.reshape(
-      llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
+    // Okay to duplicate splat constants.
+    if (operand.isSplat())
+      return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
+    // Don't duplicate other constants.
+    if (!getInput1().hasOneUse())
+      return {};
+    return operand.reshape(
+        llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
+  }
+  return {};
 }
 OpFoldResult PadOp::fold(FoldAdaptor adaptor) {