mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-02-11 12:16:07 +00:00
[MLIR][SCF] Allow combining subsequent if statements that yield & negated condition
This patch extends the existing if combining canonicalization to also handle the case where a value returned by the first if is used within the body of the second if. This patch also extends if combining to support if's whose conditions are logical negations of each other. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D120924
This commit is contained in:
parent
0e96d95d13
commit
62f84c73d2
@ -1519,51 +1519,98 @@ struct CombineIfs : public OpRewritePattern<IfOp> {
|
||||
if (!prevIf)
|
||||
return failure();
|
||||
|
||||
if (nextIf.getCondition() != prevIf.getCondition())
|
||||
// Determine the logical then/else blocks when prevIf's
|
||||
// condition is used. Null means the block does not exist
|
||||
// in that case (e.g. empty else). If neither of these
|
||||
// are set, the two conditions cannot be compared.
|
||||
Block *nextThen = nullptr;
|
||||
Block *nextElse = nullptr;
|
||||
if (nextIf.getCondition() == prevIf.getCondition()) {
|
||||
nextThen = nextIf.thenBlock();
|
||||
if (!nextIf.getElseRegion().empty())
|
||||
nextElse = nextIf.elseBlock();
|
||||
}
|
||||
if (arith::XOrIOp notv =
|
||||
nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
|
||||
if (notv.getLhs() == prevIf.getCondition() &&
|
||||
matchPattern(notv.getRhs(), m_One())) {
|
||||
nextElse = nextIf.thenBlock();
|
||||
if (!nextIf.getElseRegion().empty())
|
||||
nextThen = nextIf.elseBlock();
|
||||
}
|
||||
}
|
||||
if (arith::XOrIOp notv =
|
||||
prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
|
||||
if (notv.getLhs() == nextIf.getCondition() &&
|
||||
matchPattern(notv.getRhs(), m_One())) {
|
||||
nextElse = nextIf.thenBlock();
|
||||
if (!nextIf.getElseRegion().empty())
|
||||
nextThen = nextIf.elseBlock();
|
||||
}
|
||||
}
|
||||
|
||||
if (!nextThen && !nextElse)
|
||||
return failure();
|
||||
|
||||
// Don't permit merging if a result of the first if is used
|
||||
// within the second.
|
||||
if (llvm::any_of(prevIf->getUsers(),
|
||||
[&](Operation *user) { return nextIf->isAncestor(user); }))
|
||||
return failure();
|
||||
SmallVector<Value> prevElseYielded;
|
||||
if (!prevIf.getElseRegion().empty())
|
||||
prevElseYielded = prevIf.elseYield().getOperands();
|
||||
// Replace all uses of return values of op within nextIf with the
|
||||
// corresponding yields
|
||||
for (auto it : llvm::zip(prevIf.getResults(),
|
||||
prevIf.thenYield().getOperands(), prevElseYielded))
|
||||
for (OpOperand &use :
|
||||
llvm::make_early_inc_range(std::get<0>(it).getUses())) {
|
||||
if (nextThen && nextThen->getParent()->isAncestor(
|
||||
use.getOwner()->getParentRegion())) {
|
||||
rewriter.startRootUpdate(use.getOwner());
|
||||
use.set(std::get<1>(it));
|
||||
rewriter.finalizeRootUpdate(use.getOwner());
|
||||
} else if (nextElse && nextElse->getParent()->isAncestor(
|
||||
use.getOwner()->getParentRegion())) {
|
||||
rewriter.startRootUpdate(use.getOwner());
|
||||
use.set(std::get<2>(it));
|
||||
rewriter.finalizeRootUpdate(use.getOwner());
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<Type> mergedTypes(prevIf.getResultTypes());
|
||||
llvm::append_range(mergedTypes, nextIf.getResultTypes());
|
||||
|
||||
IfOp combinedIf = rewriter.create<IfOp>(
|
||||
nextIf.getLoc(), mergedTypes, nextIf.getCondition(), /*hasElse=*/false);
|
||||
nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
|
||||
rewriter.eraseBlock(&combinedIf.getThenRegion().back());
|
||||
|
||||
YieldOp thenYield = prevIf.thenYield();
|
||||
YieldOp thenYield2 = nextIf.thenYield();
|
||||
rewriter.inlineRegionBefore(prevIf.getThenRegion(),
|
||||
combinedIf.getThenRegion(),
|
||||
combinedIf.getThenRegion().begin());
|
||||
|
||||
combinedIf.getThenRegion().getBlocks().splice(
|
||||
combinedIf.getThenRegion().getBlocks().begin(),
|
||||
prevIf.getThenRegion().getBlocks());
|
||||
if (nextThen) {
|
||||
YieldOp thenYield = combinedIf.thenYield();
|
||||
YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
|
||||
rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
|
||||
rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
|
||||
|
||||
rewriter.mergeBlocks(nextIf.thenBlock(), combinedIf.thenBlock());
|
||||
rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
|
||||
SmallVector<Value> mergedYields(thenYield.getOperands());
|
||||
llvm::append_range(mergedYields, thenYield2.getOperands());
|
||||
rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
|
||||
rewriter.eraseOp(thenYield);
|
||||
rewriter.eraseOp(thenYield2);
|
||||
}
|
||||
|
||||
SmallVector<Value> mergedYields(thenYield.getOperands());
|
||||
llvm::append_range(mergedYields, thenYield2.getOperands());
|
||||
rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
|
||||
rewriter.eraseOp(thenYield);
|
||||
rewriter.eraseOp(thenYield2);
|
||||
rewriter.inlineRegionBefore(prevIf.getElseRegion(),
|
||||
combinedIf.getElseRegion(),
|
||||
combinedIf.getElseRegion().begin());
|
||||
|
||||
combinedIf.getElseRegion().getBlocks().splice(
|
||||
combinedIf.getElseRegion().getBlocks().begin(),
|
||||
prevIf.getElseRegion().getBlocks());
|
||||
|
||||
if (!nextIf.getElseRegion().empty()) {
|
||||
if (nextElse) {
|
||||
if (combinedIf.getElseRegion().empty()) {
|
||||
combinedIf.getElseRegion().getBlocks().splice(
|
||||
combinedIf.getElseRegion().getBlocks().begin(),
|
||||
nextIf.getElseRegion().getBlocks());
|
||||
rewriter.inlineRegionBefore(*nextElse->getParent(),
|
||||
combinedIf.getElseRegion(),
|
||||
combinedIf.getElseRegion().begin());
|
||||
} else {
|
||||
YieldOp elseYield = combinedIf.elseYield();
|
||||
YieldOp elseYield2 = nextIf.elseYield();
|
||||
rewriter.mergeBlocks(nextIf.elseBlock(), combinedIf.elseBlock());
|
||||
YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
|
||||
rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
|
||||
|
||||
rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
|
||||
|
||||
|
@ -1119,6 +1119,79 @@ func @combineIfs4(%arg0 : i1, %arg2: i64) {
|
||||
// CHECK-NEXT: "test.secondCodeTrue"() : () -> ()
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-LABEL: @combineIfsUsed
|
||||
// CHECK-SAME: %[[arg0:.+]]: i1
|
||||
func @combineIfsUsed(%arg0 : i1, %arg2: i64) -> (i32, i32) {
|
||||
%res = scf.if %arg0 -> i32 {
|
||||
%v = "test.firstCodeTrue"() : () -> i32
|
||||
scf.yield %v : i32
|
||||
} else {
|
||||
%v2 = "test.firstCodeFalse"() : () -> i32
|
||||
scf.yield %v2 : i32
|
||||
}
|
||||
%res2 = scf.if %arg0 -> i32 {
|
||||
%v = "test.secondCodeTrue"(%res) : (i32) -> i32
|
||||
scf.yield %v : i32
|
||||
} else {
|
||||
%v2 = "test.secondCodeFalse"(%res) : (i32) -> i32
|
||||
scf.yield %v2 : i32
|
||||
}
|
||||
return %res, %res2 : i32, i32
|
||||
}
|
||||
// CHECK-NEXT: %[[res:.+]]:2 = scf.if %[[arg0]] -> (i32, i32) {
|
||||
// CHECK-NEXT: %[[tval0:.+]] = "test.firstCodeTrue"() : () -> i32
|
||||
// CHECK-NEXT: %[[tval:.+]] = "test.secondCodeTrue"(%[[tval0]]) : (i32) -> i32
|
||||
// CHECK-NEXT: scf.yield %[[tval0]], %[[tval]] : i32, i32
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %[[fval0:.+]] = "test.firstCodeFalse"() : () -> i32
|
||||
// CHECK-NEXT: %[[fval:.+]] = "test.secondCodeFalse"(%[[fval0]]) : (i32) -> i32
|
||||
// CHECK-NEXT: scf.yield %[[fval0]], %[[fval]] : i32, i32
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %[[res]]#0, %[[res]]#1 : i32, i32
|
||||
|
||||
// CHECK-LABEL: @combineIfsNot
|
||||
// CHECK-SAME: %[[arg0:.+]]: i1
|
||||
func @combineIfsNot(%arg0 : i1, %arg2: i64) {
|
||||
%true = arith.constant true
|
||||
%not = arith.xori %arg0, %true : i1
|
||||
scf.if %arg0 {
|
||||
"test.firstCodeTrue"() : () -> ()
|
||||
scf.yield
|
||||
}
|
||||
scf.if %not {
|
||||
"test.secondCodeTrue"() : () -> ()
|
||||
scf.yield
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-NEXT: scf.if %[[arg0]] {
|
||||
// CHECK-NEXT: "test.firstCodeTrue"() : () -> ()
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: "test.secondCodeTrue"() : () -> ()
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-LABEL: @combineIfsNot2
|
||||
// CHECK-SAME: %[[arg0:.+]]: i1
|
||||
func @combineIfsNot2(%arg0 : i1, %arg2: i64) {
|
||||
%true = arith.constant true
|
||||
%not = arith.xori %arg0, %true : i1
|
||||
scf.if %not {
|
||||
"test.firstCodeTrue"() : () -> ()
|
||||
scf.yield
|
||||
}
|
||||
scf.if %arg0 {
|
||||
"test.secondCodeTrue"() : () -> ()
|
||||
scf.yield
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-NEXT: scf.if %[[arg0]] {
|
||||
// CHECK-NEXT: "test.secondCodeTrue"() : () -> ()
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: "test.firstCodeTrue"() : () -> ()
|
||||
// CHECK-NEXT: }
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @propagate_into_execute_region
|
||||
|
Loading…
x
Reference in New Issue
Block a user