[mlir][Vector] Fold InsertStridedSliceOp of ExtractStridedSliceOp.

This patch supports to fold InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst) to dst.

Differential Revision: https://reviews.llvm.org/D128903
This commit is contained in:
jacquesguan 2022-06-30 19:24:31 +08:00
parent 91ab4d4231
commit 8f45c5862f
2 changed files with 47 additions and 1 deletions

View File

@ -2205,11 +2205,43 @@ public:
return success();
}
};
/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
/// to dst.
class FoldInsertStridedSliceOfExtract final
: public OpRewritePattern<InsertStridedSliceOp> {
public:
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
PatternRewriter &rewriter) const override {
auto extractStridedSliceOp =
insertStridedSliceOp.getSource()
.getDefiningOp<vector::ExtractStridedSliceOp>();
if (!extractStridedSliceOp)
return failure();
if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
return failure();
// Check if have the same strides and offsets.
if (extractStridedSliceOp.getStrides() !=
insertStridedSliceOp.getStrides() ||
extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
return failure();
rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
return success();
}
};
} // namespace
void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<FoldInsertStridedSliceSplat>(context);
results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract>(
context);
}
OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {

View File

@ -1641,3 +1641,17 @@ func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) {
: vector<4x4xf32> into vector<8x16xf32>
return %0 : vector<8x16xf32>
}
// -----
// CHECK-LABEL: @insert_extract_strided_slice
// CHECK-SAME: (%[[ARG:.*]]: vector<8x16xf32>)
// CHECK-NEXT: return %[[ARG]] : vector<8x16xf32>
func.func @insert_extract_strided_slice(%x: vector<8x16xf32>) -> (vector<8x16xf32>) {
%0 = vector.extract_strided_slice %x {offsets = [0, 8], sizes = [2, 4], strides = [1, 1]}
: vector<8x16xf32> to vector<2x4xf32>
%1 = vector.insert_strided_slice %0, %x {offsets = [0, 8], strides = [1, 1]}
: vector<2x4xf32> into vector<8x16xf32>
return %1 : vector<8x16xf32>
}