[MLIR][SCF] Allow combining subsequent if statements that yield & negated condition

This patch extends the existing if combining canonicalization to also handle the case where a value returned by the first if is used within the body of the second if.

This patch also extends if combining to support if's whose conditions are logical negations of each other.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D120924
This commit is contained in:
William S. Moses 2022-03-03 14:04:14 -05:00
parent 0e96d95d13
commit 62f84c73d2
2 changed files with 149 additions and 29 deletions

View File

@ -1519,51 +1519,98 @@ struct CombineIfs : public OpRewritePattern<IfOp> {
if (!prevIf)
return failure();
if (nextIf.getCondition() != prevIf.getCondition())
// Determine the logical then/else blocks when prevIf's
// condition is used. Null means the block does not exist
// in that case (e.g. empty else). If neither of these
// are set, the two conditions cannot be compared.
Block *nextThen = nullptr;
Block *nextElse = nullptr;
if (nextIf.getCondition() == prevIf.getCondition()) {
nextThen = nextIf.thenBlock();
if (!nextIf.getElseRegion().empty())
nextElse = nextIf.elseBlock();
}
if (arith::XOrIOp notv =
nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
if (notv.getLhs() == prevIf.getCondition() &&
matchPattern(notv.getRhs(), m_One())) {
nextElse = nextIf.thenBlock();
if (!nextIf.getElseRegion().empty())
nextThen = nextIf.elseBlock();
}
}
if (arith::XOrIOp notv =
prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
if (notv.getLhs() == nextIf.getCondition() &&
matchPattern(notv.getRhs(), m_One())) {
nextElse = nextIf.thenBlock();
if (!nextIf.getElseRegion().empty())
nextThen = nextIf.elseBlock();
}
}
if (!nextThen && !nextElse)
return failure();
// Don't permit merging if a result of the first if is used
// within the second.
if (llvm::any_of(prevIf->getUsers(),
[&](Operation *user) { return nextIf->isAncestor(user); }))
return failure();
SmallVector<Value> prevElseYielded;
if (!prevIf.getElseRegion().empty())
prevElseYielded = prevIf.elseYield().getOperands();
// Replace all uses of return values of op within nextIf with the
// corresponding yields
for (auto it : llvm::zip(prevIf.getResults(),
prevIf.thenYield().getOperands(), prevElseYielded))
for (OpOperand &use :
llvm::make_early_inc_range(std::get<0>(it).getUses())) {
if (nextThen && nextThen->getParent()->isAncestor(
use.getOwner()->getParentRegion())) {
rewriter.startRootUpdate(use.getOwner());
use.set(std::get<1>(it));
rewriter.finalizeRootUpdate(use.getOwner());
} else if (nextElse && nextElse->getParent()->isAncestor(
use.getOwner()->getParentRegion())) {
rewriter.startRootUpdate(use.getOwner());
use.set(std::get<2>(it));
rewriter.finalizeRootUpdate(use.getOwner());
}
}
SmallVector<Type> mergedTypes(prevIf.getResultTypes());
llvm::append_range(mergedTypes, nextIf.getResultTypes());
IfOp combinedIf = rewriter.create<IfOp>(
nextIf.getLoc(), mergedTypes, nextIf.getCondition(), /*hasElse=*/false);
nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
rewriter.eraseBlock(&combinedIf.getThenRegion().back());
YieldOp thenYield = prevIf.thenYield();
YieldOp thenYield2 = nextIf.thenYield();
rewriter.inlineRegionBefore(prevIf.getThenRegion(),
combinedIf.getThenRegion(),
combinedIf.getThenRegion().begin());
combinedIf.getThenRegion().getBlocks().splice(
combinedIf.getThenRegion().getBlocks().begin(),
prevIf.getThenRegion().getBlocks());
if (nextThen) {
YieldOp thenYield = combinedIf.thenYield();
YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
rewriter.mergeBlocks(nextIf.thenBlock(), combinedIf.thenBlock());
rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
SmallVector<Value> mergedYields(thenYield.getOperands());
llvm::append_range(mergedYields, thenYield2.getOperands());
rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
rewriter.eraseOp(thenYield);
rewriter.eraseOp(thenYield2);
}
SmallVector<Value> mergedYields(thenYield.getOperands());
llvm::append_range(mergedYields, thenYield2.getOperands());
rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
rewriter.eraseOp(thenYield);
rewriter.eraseOp(thenYield2);
rewriter.inlineRegionBefore(prevIf.getElseRegion(),
combinedIf.getElseRegion(),
combinedIf.getElseRegion().begin());
combinedIf.getElseRegion().getBlocks().splice(
combinedIf.getElseRegion().getBlocks().begin(),
prevIf.getElseRegion().getBlocks());
if (!nextIf.getElseRegion().empty()) {
if (nextElse) {
if (combinedIf.getElseRegion().empty()) {
combinedIf.getElseRegion().getBlocks().splice(
combinedIf.getElseRegion().getBlocks().begin(),
nextIf.getElseRegion().getBlocks());
rewriter.inlineRegionBefore(*nextElse->getParent(),
combinedIf.getElseRegion(),
combinedIf.getElseRegion().begin());
} else {
YieldOp elseYield = combinedIf.elseYield();
YieldOp elseYield2 = nextIf.elseYield();
rewriter.mergeBlocks(nextIf.elseBlock(), combinedIf.elseBlock());
YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
rewriter.setInsertionPointToEnd(combinedIf.elseBlock());

View File

@ -1119,6 +1119,79 @@ func @combineIfs4(%arg0 : i1, %arg2: i64) {
// CHECK-NEXT: "test.secondCodeTrue"() : () -> ()
// CHECK-NEXT: }
// CHECK-LABEL: @combineIfsUsed
// CHECK-SAME: %[[arg0:.+]]: i1
func @combineIfsUsed(%arg0 : i1, %arg2: i64) -> (i32, i32) {
%res = scf.if %arg0 -> i32 {
%v = "test.firstCodeTrue"() : () -> i32
scf.yield %v : i32
} else {
%v2 = "test.firstCodeFalse"() : () -> i32
scf.yield %v2 : i32
}
%res2 = scf.if %arg0 -> i32 {
%v = "test.secondCodeTrue"(%res) : (i32) -> i32
scf.yield %v : i32
} else {
%v2 = "test.secondCodeFalse"(%res) : (i32) -> i32
scf.yield %v2 : i32
}
return %res, %res2 : i32, i32
}
// CHECK-NEXT: %[[res:.+]]:2 = scf.if %[[arg0]] -> (i32, i32) {
// CHECK-NEXT: %[[tval0:.+]] = "test.firstCodeTrue"() : () -> i32
// CHECK-NEXT: %[[tval:.+]] = "test.secondCodeTrue"(%[[tval0]]) : (i32) -> i32
// CHECK-NEXT: scf.yield %[[tval0]], %[[tval]] : i32, i32
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[fval0:.+]] = "test.firstCodeFalse"() : () -> i32
// CHECK-NEXT: %[[fval:.+]] = "test.secondCodeFalse"(%[[fval0]]) : (i32) -> i32
// CHECK-NEXT: scf.yield %[[fval0]], %[[fval]] : i32, i32
// CHECK-NEXT: }
// CHECK-NEXT: return %[[res]]#0, %[[res]]#1 : i32, i32
// CHECK-LABEL: @combineIfsNot
// CHECK-SAME: %[[arg0:.+]]: i1
func @combineIfsNot(%arg0 : i1, %arg2: i64) {
%true = arith.constant true
%not = arith.xori %arg0, %true : i1
scf.if %arg0 {
"test.firstCodeTrue"() : () -> ()
scf.yield
}
scf.if %not {
"test.secondCodeTrue"() : () -> ()
scf.yield
}
return
}
// CHECK-NEXT: scf.if %[[arg0]] {
// CHECK-NEXT: "test.firstCodeTrue"() : () -> ()
// CHECK-NEXT: } else {
// CHECK-NEXT: "test.secondCodeTrue"() : () -> ()
// CHECK-NEXT: }
// CHECK-LABEL: @combineIfsNot2
// CHECK-SAME: %[[arg0:.+]]: i1
func @combineIfsNot2(%arg0 : i1, %arg2: i64) {
%true = arith.constant true
%not = arith.xori %arg0, %true : i1
scf.if %not {
"test.firstCodeTrue"() : () -> ()
scf.yield
}
scf.if %arg0 {
"test.secondCodeTrue"() : () -> ()
scf.yield
}
return
}
// CHECK-NEXT: scf.if %[[arg0]] {
// CHECK-NEXT: "test.secondCodeTrue"() : () -> ()
// CHECK-NEXT: } else {
// CHECK-NEXT: "test.firstCodeTrue"() : () -> ()
// CHECK-NEXT: }
// -----
// CHECK-LABEL: func @propagate_into_execute_region