mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-10-07 19:03:57 +00:00
[MLIR] [TOSA]: Move reshape(reshape(x)) -> reshape(x) from canonicalization to fold
reshape(reshape(x)) -> reshape(x) can be directly written as a fold instead of a canonicalization, to help other passes cleanup while they work. This initially broke ReshapeConverterExpand/Collapse, which relies on creating foldable reshapes and a carefully crafted benefit priority of patterns. I turned this into a single pattern on reshapes, which does expand and/or collapse as needed in one go. Differential Revision: https://reviews.llvm.org/D155266
This commit is contained in:
parent
fdf36c3d4b
commit
0ebb050311
@ -1480,7 +1480,6 @@ def Tosa_ReshapeOp: Tosa_Op<"reshape", [
|
||||
No data conversion happens during a reshape operation.
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
let hasFolder = 1;
|
||||
let hasVerifier = 1;
|
||||
|
||||
|
@ -129,81 +129,74 @@ static bool createReassociationMapsForCollapse(
|
||||
}
|
||||
|
||||
namespace {
|
||||
class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> {
|
||||
public:
|
||||
using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
|
||||
Value createCollapse(ConversionPatternRewriter &rewriter, Location loc,
|
||||
ShapedType resultTy, Value operand) {
|
||||
ShapedType operandTy = cast<ShapedType>(operand.getType());
|
||||
if (resultTy == operandTy)
|
||||
return operand;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
|
||||
ShapedType resultTy = cast<ShapedType>(reshape.getType());
|
||||
bool isDynamic = !operandTy.hasStaticShape();
|
||||
bool isDynamic = !operandTy.hasStaticShape();
|
||||
|
||||
if (isDynamic && resultTy.getRank() != 1) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
reshape, "Cannot collapse dynamic dims to more than one dimension");
|
||||
}
|
||||
|
||||
SmallVector<ReassociationExprs, 4> reassociationMap;
|
||||
if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
|
||||
resultTy.getShape(),
|
||||
reassociationMap, isDynamic)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
reshape,
|
||||
"tosa.reshape Attempting to collapse into an incompatible shape");
|
||||
}
|
||||
|
||||
SmallVector<int64_t> intermediateShape;
|
||||
if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
|
||||
intermediateShape, isDynamic)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
reshape, "tosa.reshape Cannot collapse into given shape");
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
|
||||
reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
|
||||
return success();
|
||||
if (isDynamic && resultTy.getRank() != 1) {
|
||||
(void)rewriter.notifyMatchFailure(
|
||||
loc, "Cannot collapse dynamic dims to more than one dimension");
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> {
|
||||
public:
|
||||
using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
|
||||
ShapedType resultTy = cast<ShapedType>(reshape.getType());
|
||||
bool isDynamic = !operandTy.hasStaticShape();
|
||||
|
||||
if (isDynamic && operandTy.getRank() != 1) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
reshape, "Cannot expand dynamic dims from more than one dimension");
|
||||
}
|
||||
|
||||
SmallVector<ReassociationExprs, 4> reassociationMap;
|
||||
if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
|
||||
operandTy.getShape(),
|
||||
reassociationMap, isDynamic)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
reshape,
|
||||
"tosa.reshape Attempting to expand into an incompatible shape");
|
||||
}
|
||||
|
||||
SmallVector<int64_t> intermediateShape;
|
||||
if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
|
||||
intermediateShape, isDynamic) ||
|
||||
intermediateShape != operandTy.getShape()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
reshape, "tosa.reshape Cannot expand into given shape");
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
|
||||
reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
|
||||
return success();
|
||||
SmallVector<ReassociationExprs, 4> reassociationMap;
|
||||
if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
|
||||
resultTy.getShape(),
|
||||
reassociationMap, isDynamic)) {
|
||||
(void)rewriter.notifyMatchFailure(
|
||||
loc, "tosa.reshape Attempting to collapse into an incompatible shape");
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
SmallVector<int64_t> intermediateShape;
|
||||
if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
|
||||
intermediateShape, isDynamic)) {
|
||||
(void)rewriter.notifyMatchFailure(
|
||||
loc, "tosa.reshape Cannot collapse into given shape");
|
||||
return {};
|
||||
}
|
||||
return rewriter.create<tensor::CollapseShapeOp>(loc, resultTy, operand,
|
||||
reassociationMap);
|
||||
}
|
||||
|
||||
Value createExpand(ConversionPatternRewriter &rewriter, Location loc,
|
||||
ShapedType resultTy, Value operand) {
|
||||
ShapedType operandTy = cast<ShapedType>(operand.getType());
|
||||
if (resultTy == operandTy)
|
||||
return operand;
|
||||
|
||||
bool isDynamic = !operandTy.hasStaticShape();
|
||||
|
||||
if (isDynamic && operandTy.getRank() != 1) {
|
||||
(void)rewriter.notifyMatchFailure(
|
||||
loc, "Cannot expand dynamic dims from more than one dimension");
|
||||
return {};
|
||||
}
|
||||
|
||||
SmallVector<ReassociationExprs, 4> reassociationMap;
|
||||
if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
|
||||
operandTy.getShape(),
|
||||
reassociationMap, isDynamic)) {
|
||||
(void)rewriter.notifyMatchFailure(
|
||||
loc, "tosa.reshape Attempting to expand into an incompatible shape");
|
||||
return {};
|
||||
}
|
||||
|
||||
SmallVector<int64_t> intermediateShape;
|
||||
if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
|
||||
intermediateShape, isDynamic) ||
|
||||
intermediateShape != operandTy.getShape()) {
|
||||
(void)rewriter.notifyMatchFailure(
|
||||
loc, "tosa.reshape Cannot expand into given shape");
|
||||
return {};
|
||||
}
|
||||
return rewriter.create<tensor::ExpandShapeOp>(loc, resultTy, operand,
|
||||
reassociationMap);
|
||||
}
|
||||
|
||||
class ReshapeConverterCollapseExpand
|
||||
: public OpConversionPattern<tosa::ReshapeOp> {
|
||||
@ -224,17 +217,19 @@ public:
|
||||
reshape, "tosa.reshape Cannot identify an intermediate shape between "
|
||||
"the given two shapes");
|
||||
}
|
||||
auto intermediateTy = RankedTensorType::get(
|
||||
intermediateShape, reshape.getType().getElementType());
|
||||
|
||||
Value collapse = createCollapse(rewriter, reshape.getLoc(), intermediateTy,
|
||||
adaptor.getInput1());
|
||||
if (!collapse)
|
||||
return failure();
|
||||
|
||||
Value expand = createExpand(rewriter, reshape.getLoc(), resultTy, collapse);
|
||||
if (!expand)
|
||||
return failure();
|
||||
|
||||
Value collapse = rewriter.create<tosa::ReshapeOp>(
|
||||
reshape.getLoc(),
|
||||
RankedTensorType::get(intermediateShape,
|
||||
reshape.getType().getElementType()),
|
||||
adaptor.getInput1(), rewriter.getDenseI64ArrayAttr(intermediateShape));
|
||||
Value expand = rewriter.create<tosa::ReshapeOp>(
|
||||
reshape.getLoc(), resultTy, collapse,
|
||||
rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
|
||||
rewriter.replaceOp(reshape, expand);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@ -420,10 +415,6 @@ void mlir::tosa::populateTosaToTensorConversionPatterns(
|
||||
RewritePatternSet *patterns) {
|
||||
patterns->add<SliceConverter, PadConverter, ConcatConverter>(
|
||||
patterns->getContext());
|
||||
patterns->add<ReshapeConverterCollapse>(patterns->getContext(),
|
||||
/*benefit=*/100);
|
||||
patterns->add<ReshapeConverterExpand>(patterns->getContext(),
|
||||
/*benefit=*/200);
|
||||
patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext(),
|
||||
/*benefit=*/300);
|
||||
|
||||
patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext());
|
||||
}
|
||||
|
@ -62,31 +62,6 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
results.add<ConcatOptimization>(context);
|
||||
}
|
||||
|
||||
struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
|
||||
using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(tosa::ReshapeOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value input = op.getInput1();
|
||||
Operation *definingOp = input.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return failure();
|
||||
|
||||
if (tosa::ReshapeOp reshapeOp = dyn_cast<tosa::ReshapeOp>(definingOp)) {
|
||||
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
|
||||
op, op.getType(), reshapeOp.getInput1(), op.getNewShape());
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<ReshapeReshapeOptimization>(context);
|
||||
}
|
||||
|
||||
LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
|
||||
auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
|
||||
if (!notOp)
|
||||
@ -820,25 +795,32 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
|
||||
if (inputTy == outputTy)
|
||||
return getInput1();
|
||||
|
||||
// Constants must have static shape.
|
||||
if (!outputTy.hasStaticShape())
|
||||
return {};
|
||||
|
||||
auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||
if (!operand)
|
||||
return {};
|
||||
|
||||
// Okay to duplicate splat constants.
|
||||
if (operand.isSplat()) {
|
||||
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
|
||||
// reshape(reshape(x)) -> reshape(x)
|
||||
if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
|
||||
getInput1().getDefiningOp())) {
|
||||
getInput1Mutable().assign(reshapeOp.getInput1());
|
||||
return getResult();
|
||||
}
|
||||
|
||||
// Don't duplicate other constants.
|
||||
if (!getInput1().hasOneUse())
|
||||
return {};
|
||||
// reshape(const(x)) -> const(reshape-attr(x))
|
||||
if (auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
|
||||
// Constants must have static shape.
|
||||
if (!outputTy.hasStaticShape())
|
||||
return {};
|
||||
|
||||
return operand.reshape(
|
||||
llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
|
||||
// Okay to duplicate splat constants.
|
||||
if (operand.isSplat())
|
||||
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
|
||||
|
||||
// Don't duplicate other constants.
|
||||
if (!getInput1().hasOneUse())
|
||||
return {};
|
||||
|
||||
return operand.reshape(
|
||||
llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
|
||||
|
Loading…
Reference in New Issue
Block a user