[Loop Predication] Teach LP about reverse loops

Summary:
Currently, we only support predication for forward loops with step
of 1.  This patch enables loop predication for reverse or
countdownLoops, which satisfy the following conditions:
   1. The step of the IV is -1.
   2. The loop has a singe latch as B(X) = X <pred>
latchLimit with pred as s> or u>
   3. The IV of the guard is the decrement
IV of the latch condition (Guard is: G(X) = X-1 u< guardLimit).

This patch was downstream for a while and is the last series of patches
that's from our LP implementation downstream.

Reviewers: apilipenko, mkazantsev, sanjoy

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D40353

llvm-svn: 319659
This commit is contained in:
Anna Thomas 2017-12-04 15:11:48 +00:00
parent d141e4806b
commit 7b360434ff
2 changed files with 271 additions and 54 deletions

View File

@ -98,60 +98,79 @@
// Note that we can use anything stronger than M, i.e. any condition which
// implies M.
//
// For now the transformation is limited to the following case:
// When S = 1 (i.e. forward iterating loop), the transformation is supported
// when:
// * The loop has a single latch with the condition of the form:
// B(X) = latchStart + X <pred> latchLimit,
// where <pred> is u<, u<=, s<, or s<=.
// * The step of the IV used in the latch condition is 1.
// * The guard condition is of the form
// G(X) = guardStart + X u< guardLimit
//
// For the ult latch comparison case M is:
// forall X . guardStart + X u< guardLimit && latchStart + X <u latchLimit =>
// guardStart + X + 1 u< guardLimit
// For the ult latch comparison case M is:
// forall X . guardStart + X u< guardLimit && latchStart + X <u latchLimit =>
// guardStart + X + 1 u< guardLimit
//
// The only way the antecedent can be true and the consequent can be false is
// if
// X == guardLimit - 1 - guardStart
// (and guardLimit is non-zero, but we won't use this latter fact).
// If X == guardLimit - 1 - guardStart then the second half of the antecedent is
// latchStart + guardLimit - 1 - guardStart u< latchLimit
// and its negation is
// latchStart + guardLimit - 1 - guardStart u>= latchLimit
// The only way the antecedent can be true and the consequent can be false is
// if
// X == guardLimit - 1 - guardStart
// (and guardLimit is non-zero, but we won't use this latter fact).
// If X == guardLimit - 1 - guardStart then the second half of the antecedent is
// latchStart + guardLimit - 1 - guardStart u< latchLimit
// and its negation is
// latchStart + guardLimit - 1 - guardStart u>= latchLimit
//
// In other words, if
// latchLimit u<= latchStart + guardLimit - 1 - guardStart
// then:
// (the ranges below are written in ConstantRange notation, where [A, B) is the
// set for (I = A; I != B; I++ /*maywrap*/) yield(I);)
// In other words, if
// latchLimit u<= latchStart + guardLimit - 1 - guardStart
// then:
// (the ranges below are written in ConstantRange notation, where [A, B) is the
// set for (I = A; I != B; I++ /*maywrap*/) yield(I);)
//
// forall X . guardStart + X u< guardLimit &&
// latchStart + X u< latchLimit =>
// guardStart + X + 1 u< guardLimit
// == forall X . guardStart + X u< guardLimit &&
// latchStart + X u< latchStart + guardLimit - 1 - guardStart =>
// guardStart + X + 1 u< guardLimit
// == forall X . (guardStart + X) in [0, guardLimit) &&
// (latchStart + X) in [0, latchStart + guardLimit - 1 - guardStart) =>
// (guardStart + X + 1) in [0, guardLimit)
// == forall X . X in [-guardStart, guardLimit - guardStart) &&
// X in [-latchStart, guardLimit - 1 - guardStart) =>
// X in [-guardStart - 1, guardLimit - guardStart - 1)
// == true
// forall X . guardStart + X u< guardLimit &&
// latchStart + X u< latchLimit =>
// guardStart + X + 1 u< guardLimit
// == forall X . guardStart + X u< guardLimit &&
// latchStart + X u< latchStart + guardLimit - 1 - guardStart =>
// guardStart + X + 1 u< guardLimit
// == forall X . (guardStart + X) in [0, guardLimit) &&
// (latchStart + X) in [0, latchStart + guardLimit - 1 - guardStart) =>
// (guardStart + X + 1) in [0, guardLimit)
// == forall X . X in [-guardStart, guardLimit - guardStart) &&
// X in [-latchStart, guardLimit - 1 - guardStart) =>
// X in [-guardStart - 1, guardLimit - guardStart - 1)
// == true
//
// So the widened condition is:
// guardStart u< guardLimit &&
// latchStart + guardLimit - 1 - guardStart u>= latchLimit
// Similarly for ule condition the widened condition is:
// guardStart u< guardLimit &&
// latchStart + guardLimit - 1 - guardStart u> latchLimit
// For slt condition the widened condition is:
// guardStart u< guardLimit &&
// latchStart + guardLimit - 1 - guardStart s>= latchLimit
// For sle condition the widened condition is:
// guardStart u< guardLimit &&
// latchStart + guardLimit - 1 - guardStart s> latchLimit
// So the widened condition is:
// guardStart u< guardLimit &&
// latchStart + guardLimit - 1 - guardStart u>= latchLimit
// Similarly for ule condition the widened condition is:
// guardStart u< guardLimit &&
// latchStart + guardLimit - 1 - guardStart u> latchLimit
// For slt condition the widened condition is:
// guardStart u< guardLimit &&
// latchStart + guardLimit - 1 - guardStart s>= latchLimit
// For sle condition the widened condition is:
// guardStart u< guardLimit &&
// latchStart + guardLimit - 1 - guardStart s> latchLimit
//
// When S = -1 (i.e. reverse iterating loop), the transformation is supported
// when:
// * The loop has a single latch with the condition of the form:
// B(X) = X <pred> latchLimit, where <pred> is u> or s>.
// * The guard condition is of the form
// G(X) = X - 1 u< guardLimit
//
// For the ugt latch comparison case M is:
// forall X. X-1 u< guardLimit and X u> latchLimit => X-2 u< guardLimit
//
// The only way the antecedent can be true and the consequent can be false is if
// X == 1.
// If X == 1 then the second half of the antecedent is
// 1 u> latchLimit, and its negation is latchLimit u>= 1.
//
// So the widened condition is:
// guardStart u< guardLimit && latchLimit u>= 1.
// Similarly for sgt condition the widened condition is:
// guardStart u< guardLimit && latchLimit s>= 1.
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Scalar/LoopPredication.h"
@ -177,6 +196,8 @@ using namespace llvm;
static cl::opt<bool> EnableIVTruncation("loop-predication-enable-iv-truncation",
cl::Hidden, cl::init(true));
static cl::opt<bool> EnableCountDownLoop("loop-predication-enable-count-down-loop",
cl::Hidden, cl::init(true));
namespace {
class LoopPredication {
/// Represents an induction variable check:
@ -223,7 +244,10 @@ class LoopPredication {
LoopICmp RangeCheck,
SCEVExpander &Expander,
IRBuilder<> &Builder);
Optional<Value *> widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck,
LoopICmp RangeCheck,
SCEVExpander &Expander,
IRBuilder<> &Builder);
bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander);
// When the IV type is wider than the range operand type, we can still do loop
@ -360,7 +384,7 @@ LoopPredication::generateLoopLatchCheck(Type *RangeCheckType) {
}
bool LoopPredication::isSupportedStep(const SCEV* Step) {
return Step->isOne();
return Step->isOne() || (Step->isAllOnesValue() && EnableCountDownLoop);
}
bool LoopPredication::CanExpand(const SCEV* S) {
@ -420,6 +444,44 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(
GuardStart, GuardLimit, InsertAt);
return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
}
Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
LoopPredication::LoopICmp LatchCheck, LoopPredication::LoopICmp RangeCheck,
SCEVExpander &Expander, IRBuilder<> &Builder) {
auto *Ty = RangeCheck.IV->getType();
const SCEV *GuardStart = RangeCheck.IV->getStart();
const SCEV *GuardLimit = RangeCheck.Limit;
const SCEV *LatchLimit = LatchCheck.Limit;
if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) ||
!CanExpand(LatchLimit)) {
DEBUG(dbgs() << "Can't expand limit check!\n");
return None;
}
// The decrement of the latch check IV should be the same as the
// rangeCheckIV.
auto *PostDecLatchCheckIV = LatchCheck.IV->getPostIncExpr(*SE);
if (RangeCheck.IV != PostDecLatchCheckIV) {
DEBUG(dbgs() << "Not the same. PostDecLatchCheckIV: "
<< *PostDecLatchCheckIV
<< " and RangeCheckIV: " << *RangeCheck.IV << "\n");
return None;
}
// Generate the widened condition for CountDownLoop:
// guardStart u< guardLimit &&
// latchLimit <pred> 1.
// See the header comment for reasoning of the checks.
Instruction *InsertAt = Preheader->getTerminator();
auto LimitCheckPred = ICmpInst::isSigned(LatchCheck.Pred)
? ICmpInst::ICMP_SGE
: ICmpInst::ICMP_UGE;
auto *FirstIterationCheck = expandCheck(Expander, Builder, ICmpInst::ICMP_ULT,
GuardStart, GuardLimit, InsertAt);
auto *LimitCheck = expandCheck(Expander, Builder, LimitCheckPred, LatchLimit,
SE->getOne(Ty), InsertAt);
return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
}
/// If ICI can be widened to a loop invariant condition emits the loop
/// invariant condition in the loop preheader and return it, otherwise
/// returns None.
@ -467,13 +529,24 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI,
}
LoopICmp CurrLatchCheck = *CurrLatchCheckOpt;
// At this point the range check step and latch step should have the same
// value and type.
assert(Step == CurrLatchCheck.IV->getStepRecurrence(*SE) &&
"Range and latch should have same step recurrence!");
// At this point, the range and latch step should have the same type, but need
// not have the same value (we support both 1 and -1 steps).
assert(Step->getType() ==
CurrLatchCheck.IV->getStepRecurrence(*SE)->getType() &&
"Range and latch steps should be of same type!");
if (Step != CurrLatchCheck.IV->getStepRecurrence(*SE)) {
DEBUG(dbgs() << "Range and latch have different step values!\n");
return None;
}
return widenICmpRangeCheckIncrementingLoop(CurrLatchCheck, *RangeCheck,
Expander, Builder);
if (Step->isOne())
return widenICmpRangeCheckIncrementingLoop(CurrLatchCheck, *RangeCheck,
Expander, Builder);
else {
assert(Step->isAllOnesValue() && "Step should be -1!");
return widenICmpRangeCheckDecrementingLoop(CurrLatchCheck, *RangeCheck,
Expander, Builder);
}
}
bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,
@ -580,9 +653,13 @@ Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() {
}
auto IsUnsupportedPredicate = [](const SCEV *Step, ICmpInst::Predicate Pred) {
assert(Step->isOne() && "expected Step to be one!");
return Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_SLT &&
Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_SLE;
if (Step->isOne()) {
return Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_SLT &&
Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_SLE;
} else {
assert(Step->isAllOnesValue() && "Step should be -1!");
return Pred != ICmpInst::ICMP_UGT && Pred != ICmpInst::ICMP_SGT;
}
};
if (IsUnsupportedPredicate(Step, Result->Pred)) {

View File

@ -0,0 +1,140 @@
; RUN: opt -S -loop-predication -loop-predication-enable-count-down-loop=true < %s 2>&1 | FileCheck %s
; RUN: opt -S -passes='require<scalar-evolution>,loop(loop-predication)' -loop-predication-enable-count-down-loop=true < %s 2>&1 | FileCheck %s
declare void @llvm.experimental.guard(i1, ...)
define i32 @signed_reverse_loop_n_to_lower_limit(i32* %array, i32 %length, i32 %n, i32 %lowerlimit) {
; CHECK-LABEL: @signed_reverse_loop_n_to_lower_limit(
entry:
%tmp5 = icmp eq i32 %n, 0
br i1 %tmp5, label %exit, label %loop.preheader
; CHECK: loop.preheader:
; CHECK-NEXT: [[range_start:%.*]] = add i32 %n, -1
; CHECK-NEXT: [[first_iteration_check:%.*]] = icmp ult i32 [[range_start]], %length
; CHECK-NEXT: [[no_wrap_check:%.*]] = icmp sge i32 %lowerlimit, 1
; CHECK-NEXT: [[wide_cond:%.*]] = and i1 [[first_iteration_check]], [[no_wrap_check]]
loop.preheader:
br label %loop
; CHECK: loop:
; CHECK: call void (i1, ...) @llvm.experimental.guard(i1 [[wide_cond]], i32 9) [ "deopt"() ]
loop:
%loop.acc = phi i32 [ %loop.acc.next, %loop ], [ 0, %loop.preheader ]
%i = phi i32 [ %i.next, %loop ], [ %n, %loop.preheader ]
%i.next = add nsw i32 %i, -1
%within.bounds = icmp ult i32 %i.next, %length
call void (i1, ...) @llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ]
%i.i64 = zext i32 %i.next to i64
%array.i.ptr = getelementptr inbounds i32, i32* %array, i64 %i.i64
%array.i = load i32, i32* %array.i.ptr, align 4
%loop.acc.next = add i32 %loop.acc, %array.i
%continue = icmp sgt i32 %i, %lowerlimit
br i1 %continue, label %loop, label %exit
exit:
%result = phi i32 [ 0, %entry ], [ %loop.acc.next, %loop ]
ret i32 %result
}
define i32 @unsigned_reverse_loop_n_to_lower_limit(i32* %array, i32 %length, i32 %n, i32 %lowerlimit) {
; CHECK-LABEL: @unsigned_reverse_loop_n_to_lower_limit(
entry:
%tmp5 = icmp eq i32 %n, 0
br i1 %tmp5, label %exit, label %loop.preheader
; CHECK: loop.preheader:
; CHECK-NEXT: [[range_start:%.*]] = add i32 %n, -1
; CHECK-NEXT: [[first_iteration_check:%.*]] = icmp ult i32 [[range_start]], %length
; CHECK-NEXT: [[no_wrap_check:%.*]] = icmp uge i32 %lowerlimit, 1
; CHECK-NEXT: [[wide_cond:%.*]] = and i1 [[first_iteration_check]], [[no_wrap_check]]
loop.preheader:
br label %loop
; CHECK: loop:
; CHECK: call void (i1, ...) @llvm.experimental.guard(i1 [[wide_cond]], i32 9) [ "deopt"() ]
loop:
%loop.acc = phi i32 [ %loop.acc.next, %loop ], [ 0, %loop.preheader ]
%i = phi i32 [ %i.next, %loop ], [ %n, %loop.preheader ]
%i.next = add nsw i32 %i, -1
%within.bounds = icmp ult i32 %i.next, %length
call void (i1, ...) @llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ]
%i.i64 = zext i32 %i.next to i64
%array.i.ptr = getelementptr inbounds i32, i32* %array, i64 %i.i64
%array.i = load i32, i32* %array.i.ptr, align 4
%loop.acc.next = add i32 %loop.acc, %array.i
%continue = icmp ugt i32 %i, %lowerlimit
br i1 %continue, label %loop, label %exit
exit:
%result = phi i32 [ 0, %entry ], [ %loop.acc.next, %loop ]
ret i32 %result
}
; if we predicated the loop, the guard will definitely fail and we will
; deoptimize early on.
define i32 @unsigned_reverse_loop_n_to_0(i32* %array, i32 %length, i32 %n, i32 %lowerlimit) {
; CHECK-LABEL: @unsigned_reverse_loop_n_to_0(
entry:
%tmp5 = icmp eq i32 %n, 0
br i1 %tmp5, label %exit, label %loop.preheader
; CHECK: loop.preheader:
; CHECK-NEXT: [[range_start:%.*]] = add i32 %n, -1
; CHECK-NEXT: [[first_iteration_check:%.*]] = icmp ult i32 [[range_start]], %length
; CHECK-NEXT: [[wide_cond:%.*]] = and i1 [[first_iteration_check]], false
loop.preheader:
br label %loop
; CHECK: loop:
; CHECK: call void (i1, ...) @llvm.experimental.guard(i1 [[wide_cond]], i32 9) [ "deopt"() ]
loop:
%loop.acc = phi i32 [ %loop.acc.next, %loop ], [ 0, %loop.preheader ]
%i = phi i32 [ %i.next, %loop ], [ %n, %loop.preheader ]
%i.next = add nsw i32 %i, -1
%within.bounds = icmp ult i32 %i.next, %length
call void (i1, ...) @llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ]
%i.i64 = zext i32 %i.next to i64
%array.i.ptr = getelementptr inbounds i32, i32* %array, i64 %i.i64
%array.i = load i32, i32* %array.i.ptr, align 4
%loop.acc.next = add i32 %loop.acc, %array.i
%continue = icmp ugt i32 %i, 0
br i1 %continue, label %loop, label %exit
exit:
%result = phi i32 [ 0, %entry ], [ %loop.acc.next, %loop ]
ret i32 %result
}
; do not loop predicate when the range has step -1 and latch has step 1.
define i32 @reverse_loop_range_step_increment(i32 %n, i32* %array, i32 %length) {
; CHECK-LABEL: @reverse_loop_range_step_increment(
entry:
%tmp5 = icmp eq i32 %n, 0
br i1 %tmp5, label %exit, label %loop.preheader
loop.preheader:
br label %loop
; CHECK: loop:
; CHECK: llvm.experimental.guard(i1 %within.bounds, i32 9)
loop:
%loop.acc = phi i32 [ %loop.acc.next, %loop ], [ 0, %loop.preheader ]
%i = phi i32 [ %i.next, %loop ], [ %n, %loop.preheader ]
%irc = phi i32 [ %i.inc, %loop ], [ 1, %loop.preheader ]
%i.inc = add nuw nsw i32 %irc, 1
%within.bounds = icmp ult i32 %irc, %length
call void (i1, ...) @llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ]
%i.i64 = zext i32 %irc to i64
%array.i.ptr = getelementptr inbounds i32, i32* %array, i64 %i.i64
%array.i = load i32, i32* %array.i.ptr, align 4
%i.next = add nsw i32 %i, -1
%loop.acc.next = add i32 %loop.acc, %array.i
%continue = icmp ugt i32 %i, 65534
br i1 %continue, label %loop, label %exit
exit:
%result = phi i32 [ 0, %entry ], [ %loop.acc.next, %loop ]
ret i32 %result
}