mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-10-07 10:54:01 +00:00
[mlir][Linalg] NFC: Combine elementwise fusion test passes.
There are a few different test passes that check elementwise fusion in Linalg. Consolidate them to a single pass controlled by different pass options (in keeping with how `TestLinalgTransforms` exists).
This commit is contained in:
parent
bf02586c57
commit
d730336411
@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns -split-input-file | FileCheck %s
|
||||
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops -split-input-file | FileCheck %s
|
||||
|
||||
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
||||
#binary2Dpointwise = {
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt %s -test-linalg-push-reshape -split-input-file | FileCheck %s
|
||||
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=push-expanding-reshape -split-input-file | FileCheck %s
|
||||
|
||||
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt -test-linalg-control-fusion-by-expansion %s -split-input-file | FileCheck %s
|
||||
// RUN: mlir-opt -test-linalg-elementwise-fusion-patterns=control-fusion-by-expansion %s -split-input-file | FileCheck %s
|
||||
|
||||
func @control_producer_reshape_fusion(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?xf32>) -> tensor<?x?xf32> {
|
||||
%c0 = arith.constant 0 : index
|
||||
|
@ -58,87 +58,77 @@ struct TestLinalgElementwiseFusion
|
||||
return "Test Linalg element wise operation fusion patterns";
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &this->getContext();
|
||||
FuncOp funcOp = this->getOperation();
|
||||
RewritePatternSet fusionPatterns(context);
|
||||
Option<bool>
|
||||
fuseGenericOps(*this, "fuse-generic-ops",
|
||||
llvm::cl::desc("Test fusion of generic operations."),
|
||||
llvm::cl::init(false));
|
||||
|
||||
linalg::populateElementwiseOpsFusionPatterns(
|
||||
fusionPatterns,
|
||||
linalg::LinalgElementwiseFusionOptions()
|
||||
.setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>));
|
||||
Option<bool> controlFuseByExpansion(
|
||||
*this, "control-fusion-by-expansion",
|
||||
llvm::cl::desc(
|
||||
"Test controlling fusion of reshape with generic op by expansion"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(fusionPatterns));
|
||||
}
|
||||
};
|
||||
|
||||
struct TestLinalgControlFuseByExpansion
|
||||
: public PassWrapper<TestLinalgControlFuseByExpansion,
|
||||
OperationPass<FuncOp>> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry
|
||||
.insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
|
||||
}
|
||||
StringRef getArgument() const final {
|
||||
return "test-linalg-control-fusion-by-expansion";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test controlling of fusion of elementwise ops with reshape by "
|
||||
"expansion";
|
||||
}
|
||||
Option<bool>
|
||||
pushExpandingReshape(*this, "push-expanding-reshape",
|
||||
llvm::cl::desc("Test linalg expand_shape -> generic "
|
||||
"to generic -> expand_shape pattern"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &this->getContext();
|
||||
FuncOp funcOp = this->getOperation();
|
||||
RewritePatternSet fusionPatterns(context);
|
||||
|
||||
linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
|
||||
[](const OpResult &producer, OpOperand &consumer) {
|
||||
if (auto collapseOp =
|
||||
producer.getDefiningOp<tensor::CollapseShapeOp>()) {
|
||||
if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) {
|
||||
return false;
|
||||
if (fuseGenericOps) {
|
||||
RewritePatternSet fusionPatterns(context);
|
||||
linalg::populateElementwiseOpsFusionPatterns(
|
||||
fusionPatterns,
|
||||
linalg::LinalgElementwiseFusionOptions()
|
||||
.setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>));
|
||||
|
||||
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(fusionPatterns));
|
||||
return;
|
||||
}
|
||||
|
||||
if (controlFuseByExpansion) {
|
||||
RewritePatternSet fusionPatterns(context);
|
||||
|
||||
linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
|
||||
[](const OpResult &producer, OpOperand &consumer) {
|
||||
if (auto collapseOp =
|
||||
producer.getDefiningOp<tensor::CollapseShapeOp>()) {
|
||||
if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (auto expandOp =
|
||||
dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
|
||||
if (expandOp->hasOneUse()) {
|
||||
OpOperand &use = *expandOp->getUses().begin();
|
||||
auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
|
||||
if (linalgOp && linalgOp.isOutputTensor(&use))
|
||||
return true;
|
||||
if (auto expandOp =
|
||||
dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
|
||||
if (expandOp->hasOneUse()) {
|
||||
OpOperand &use = *expandOp->getUses().begin();
|
||||
auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
|
||||
if (linalgOp && linalgOp.isOutputTensor(&use))
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return linalg::skipUnitDimReshape(producer, consumer);
|
||||
};
|
||||
return linalg::skipUnitDimReshape(producer, consumer);
|
||||
};
|
||||
|
||||
linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
|
||||
controlReshapeFusionFn);
|
||||
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(fusionPatterns));
|
||||
linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
|
||||
controlReshapeFusionFn);
|
||||
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(fusionPatterns));
|
||||
return;
|
||||
}
|
||||
|
||||
if (pushExpandingReshape) {
|
||||
RewritePatternSet patterns(context);
|
||||
linalg::populatePushReshapeOpsPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct TestPushExpandingReshape
|
||||
: public PassWrapper<TestPushExpandingReshape, OperationPass<FuncOp>> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry
|
||||
.insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
|
||||
}
|
||||
StringRef getArgument() const final { return "test-linalg-push-reshape"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test Linalg reshape push patterns";
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &this->getContext();
|
||||
FuncOp funcOp = this->getOperation();
|
||||
RewritePatternSet patterns(context);
|
||||
linalg::populatePushReshapeOpsPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace test {
|
||||
|
@ -81,10 +81,8 @@ void registerTestGenericIRVisitorsPass();
|
||||
void registerTestGenericIRVisitorsInterruptPass();
|
||||
void registerTestInterfaces();
|
||||
void registerTestLinalgCodegenStrategy();
|
||||
void registerTestLinalgControlFuseByExpansion();
|
||||
void registerTestLinalgDistribution();
|
||||
void registerTestLinalgElementwiseFusion();
|
||||
void registerTestPushExpandingReshape();
|
||||
void registerTestLinalgFusionTransforms();
|
||||
void registerTestLinalgTensorFusionTransforms();
|
||||
void registerTestLinalgTiledLoopFusionTransforms();
|
||||
@ -172,10 +170,8 @@ void registerTestPasses() {
|
||||
mlir::test::registerTestGenericIRVisitorsPass();
|
||||
mlir::test::registerTestInterfaces();
|
||||
mlir::test::registerTestLinalgCodegenStrategy();
|
||||
mlir::test::registerTestLinalgControlFuseByExpansion();
|
||||
mlir::test::registerTestLinalgDistribution();
|
||||
mlir::test::registerTestLinalgElementwiseFusion();
|
||||
mlir::test::registerTestPushExpandingReshape();
|
||||
mlir::test::registerTestLinalgFusionTransforms();
|
||||
mlir::test::registerTestLinalgTensorFusionTransforms();
|
||||
mlir::test::registerTestLinalgTiledLoopFusionTransforms();
|
||||
|
Loading…
Reference in New Issue
Block a user