From 8f45c5862f82ca7e32c09ef84317daf66d278757 Mon Sep 17 00:00:00 2001 From: jacquesguan Date: Thu, 30 Jun 2022 19:24:31 +0800 Subject: [PATCH] [mlir][Vector] Fold InsertStridedSliceOp of ExtractStridedSliceOp. This patch supports to fold InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst) to dst. Differential Revision: https://reviews.llvm.org/D128903 --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 34 +++++++++++++++++++++- mlir/test/Dialect/Vector/canonicalize.mlir | 14 +++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 3edd23fef624..38f38f886705 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2205,11 +2205,43 @@ public: return success(); } }; + +/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst) +/// to dst. +class FoldInsertStridedSliceOfExtract final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp, + PatternRewriter &rewriter) const override { + auto extractStridedSliceOp = + insertStridedSliceOp.getSource() + .getDefiningOp(); + + if (!extractStridedSliceOp) + return failure(); + + if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest()) + return failure(); + + // Check if have the same strides and offsets. + if (extractStridedSliceOp.getStrides() != + insertStridedSliceOp.getStrides() || + extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets()) + return failure(); + + rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest()); + return success(); + } +}; + } // namespace void vector::InsertStridedSliceOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add( + context); } OpFoldResult InsertStridedSliceOp::fold(ArrayRef operands) { diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 7f50d9038045..515a2d1726b6 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1641,3 +1641,17 @@ func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) { : vector<4x4xf32> into vector<8x16xf32> return %0 : vector<8x16xf32> } + + +// ----- + +// CHECK-LABEL: @insert_extract_strided_slice +// CHECK-SAME: (%[[ARG:.*]]: vector<8x16xf32>) +// CHECK-NEXT: return %[[ARG]] : vector<8x16xf32> +func.func @insert_extract_strided_slice(%x: vector<8x16xf32>) -> (vector<8x16xf32>) { + %0 = vector.extract_strided_slice %x {offsets = [0, 8], sizes = [2, 4], strides = [1, 1]} + : vector<8x16xf32> to vector<2x4xf32> + %1 = vector.insert_strided_slice %0, %x {offsets = [0, 8], strides = [1, 1]} + : vector<2x4xf32> into vector<8x16xf32> + return %1 : vector<8x16xf32> +}