[mlir][Vector] Fold InsertStridedSliceOp of two splat with the same input to splat.

This patch folds InsertStridedSliceOp(SplatOp(X):src_type, SplatOp(X):dst_type) to SplatOp(X):dst_type.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D128891
This commit is contained in:
jacquesguan 2022-06-30 16:30:59 +08:00
parent 2ceb9c347f
commit 91ab4d4231
3 changed files with 47 additions and 0 deletions

View File

@ -886,6 +886,7 @@ def Vector_InsertStridedSliceOp :
let hasFolder = 1;
let hasVerifier = 1;
let hasCanonicalizer = 1;
}
def Vector_OuterProductOp :

View File

@ -2180,6 +2180,38 @@ LogicalResult InsertStridedSliceOp::verify() {
return success();
}
namespace {
/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
/// SplatOp(X):dst_type) to SplatOp(X):dst_type.
class FoldInsertStridedSliceSplat final
: public OpRewritePattern<InsertStridedSliceOp> {
public:
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
PatternRewriter &rewriter) const override {
auto srcSplatOp =
insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
auto destSplatOp =
insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
if (!srcSplatOp || !destSplatOp)
return failure();
if (srcSplatOp.getInput() != destSplatOp.getInput())
return failure();
rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
return success();
}
};
} // namespace
void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<FoldInsertStridedSliceSplat>(context);
}
OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
if (getSourceVectorType() == getDestVectorType())
return getSource();

View File

@ -1627,3 +1627,17 @@ func.func @bitcast(%a: vector<4x8xf32>) -> vector<4x16xi16> {
%1 = vector.bitcast %0 : vector<4x8xi32> to vector<4x16xi16>
return %1 : vector<4x16xi16>
}
// -----
// CHECK-LABEL: @insert_strided_slice_splat
// CHECK-SAME: (%[[ARG:.*]]: f32)
// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8x16xf32>
// CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32>
func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) {
%splat0 = vector.splat %x : vector<4x4xf32>
%splat1 = vector.splat %x : vector<8x16xf32>
%0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]}
: vector<4x4xf32> into vector<8x16xf32>
return %0 : vector<8x16xf32>
}