From 32fd32bc6f6de46b96c9e09575694d34f248d6ee Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Fri, 29 Mar 2019 22:00:12 +0000 Subject: [PATCH] [SCEV] Check the cache in get{S|U}MaxExpr before doing any work Summary: This lets us avoid e.g. checking if A >=s B in getSMaxExpr(A, B) if we've already established that (A smax B) is the best we can do. Fixes PR41225. Reviewers: asbirlea Subscribers: mcrosier, jlebar, bixia, jdoerfert, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D60010 llvm-svn: 357320 --- llvm/include/llvm/Analysis/ScalarEvolution.h | 10 ++ llvm/lib/Analysis/ScalarEvolution.cpp | 45 +++-- .../ScalarEvolution/max-expr-cache.ll | 156 ++++++++++++++++++ 3 files changed, 199 insertions(+), 12 deletions(-) create mode 100644 llvm/test/Analysis/ScalarEvolution/max-expr-cache.ll diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index 0462d2a13a79..ba8648b4f442 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -1871,6 +1871,16 @@ private: /// Assign A and B to LHS and RHS, respectively. bool matchURem(const SCEV *Expr, const SCEV *&LHS, const SCEV *&RHS); + /// Look for a SCEV expression with type `SCEVType` and operands `Ops` in + /// `UniqueSCEVs`. + /// + /// The first component of the returned tuple is the SCEV if found and null + /// otherwise. The second component is the `FoldingSetNodeID` that was + /// constructed to look up the SCEV and the third component is the insertion + /// point. + std::tuple + findExistingSCEVInCache(int SCEVType, ArrayRef Ops); + FoldingSet UniqueSCEVs; FoldingSet UniquePreds; BumpPtrAllocator SCEVAllocator; diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 920e0dd27823..fd759bd80fc2 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -3523,6 +3523,17 @@ const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, return getSMaxExpr(Ops); } +std::tuple +ScalarEvolution::findExistingSCEVInCache(int SCEVType, + ArrayRef Ops) { + FoldingSetNodeID ID; + void *IP = nullptr; + ID.AddInteger(SCEVType); + for (unsigned i = 0, e = Ops.size(); i != e; ++i) + ID.AddPointer(Ops[i]); + return {UniqueSCEVs.FindNodeOrInsertPos(ID, IP), std::move(ID), IP}; +} + const SCEV * ScalarEvolution::getSMaxExpr(SmallVectorImpl &Ops) { assert(!Ops.empty() && "Cannot get empty smax!"); @@ -3537,6 +3548,11 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl &Ops) { // Sort by complexity, this groups all similar expression types together. GroupByComplexity(Ops, &LI, DT); + // Check if we have created the same SMax expression before. + if (const SCEV *S = std::get<0>(findExistingSCEVInCache(scSMaxExpr, Ops))) { + return S; + } + // If there are any constants, fold them together. unsigned Idx = 0; if (const SCEVConstant *LHSC = dyn_cast(Ops[0])) { @@ -3604,16 +3620,16 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl &Ops) { // Okay, it looks like we really DO need an smax expr. Check to see if we // already have one, otherwise create a new one. + const SCEV *ExistingSCEV; FoldingSetNodeID ID; - ID.AddInteger(scSMaxExpr); - for (unsigned i = 0, e = Ops.size(); i != e; ++i) - ID.AddPointer(Ops[i]); - void *IP = nullptr; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + void *IP; + std::tie(ExistingSCEV, ID, IP) = findExistingSCEVInCache(scSMaxExpr, Ops); + if (ExistingSCEV) + return ExistingSCEV; const SCEV **O = SCEVAllocator.Allocate(Ops.size()); std::uninitialized_copy(Ops.begin(), Ops.end(), O); - SCEV *S = new (SCEVAllocator) SCEVSMaxExpr(ID.Intern(SCEVAllocator), - O, Ops.size()); + SCEV *S = + new (SCEVAllocator) SCEVSMaxExpr(ID.Intern(SCEVAllocator), O, Ops.size()); UniqueSCEVs.InsertNode(S, IP); addToLoopUseLists(S); return S; @@ -3639,6 +3655,11 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl &Ops) { // Sort by complexity, this groups all similar expression types together. GroupByComplexity(Ops, &LI, DT); + // Check if we have created the same UMax expression before. + if (const SCEV *S = std::get<0>(findExistingSCEVInCache(scUMaxExpr, Ops))) { + return S; + } + // If there are any constants, fold them together. unsigned Idx = 0; if (const SCEVConstant *LHSC = dyn_cast(Ops[0])) { @@ -3707,12 +3728,12 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl &Ops) { // Okay, it looks like we really DO need a umax expr. Check to see if we // already have one, otherwise create a new one. + const SCEV *ExistingSCEV; FoldingSetNodeID ID; - ID.AddInteger(scUMaxExpr); - for (unsigned i = 0, e = Ops.size(); i != e; ++i) - ID.AddPointer(Ops[i]); - void *IP = nullptr; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + void *IP; + std::tie(ExistingSCEV, ID, IP) = findExistingSCEVInCache(scUMaxExpr, Ops); + if (ExistingSCEV) + return ExistingSCEV; const SCEV **O = SCEVAllocator.Allocate(Ops.size()); std::uninitialized_copy(Ops.begin(), Ops.end(), O); SCEV *S = new (SCEVAllocator) SCEVUMaxExpr(ID.Intern(SCEVAllocator), diff --git a/llvm/test/Analysis/ScalarEvolution/max-expr-cache.ll b/llvm/test/Analysis/ScalarEvolution/max-expr-cache.ll new file mode 100644 index 000000000000..1c50137398c1 --- /dev/null +++ b/llvm/test/Analysis/ScalarEvolution/max-expr-cache.ll @@ -0,0 +1,156 @@ +; RUN: opt < %s -analyze -scalar-evolution | FileCheck %s + +; SCEV would take a long time to compute SCEV expressions for this IR. If SCEV +; finishes in < 1 second then the bug is fixed. + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64--linux-gnu" + +define void @smax(i32 %tmp3) { + ; CHECK-LABEL: Printing analysis 'Scalar Evolution Analysis' for function 'smax' +entry: + br label %bb4 + +bb4: + %tmp5 = phi i64 [ %tmp62, %bb61 ], [ 0, %entry ] + %tmp6 = trunc i64 %tmp5 to i32 + %tmp7 = shl nsw i32 %tmp6, 8 + %tmp8 = sub nsw i32 %tmp3, %tmp7 + %tmp9 = icmp slt i32 %tmp8, 256 + %tmp10 = select i1 %tmp9, i32 %tmp8, i32 256 + %tmp11 = add nsw i32 %tmp10, 1 + %tmp12 = icmp sgt i32 %tmp8, %tmp11 + %tmp13 = select i1 %tmp12, i32 %tmp11, i32 %tmp8 + %tmp14 = icmp slt i32 %tmp13, 256 + %tmp15 = select i1 %tmp14, i32 %tmp13, i32 256 + %tmp16 = add nsw i32 %tmp15, 1 + %tmp17 = icmp sgt i32 %tmp8, %tmp16 + %tmp18 = select i1 %tmp17, i32 %tmp16, i32 %tmp8 + %tmp19 = icmp slt i32 %tmp18, 256 + %tmp20 = select i1 %tmp19, i32 %tmp18, i32 256 + %tmp21 = add nsw i32 %tmp20, 1 + %tmp22 = icmp sgt i32 %tmp8, %tmp21 + %tmp23 = select i1 %tmp22, i32 %tmp21, i32 %tmp8 + %tmp24 = icmp slt i32 %tmp23, 256 + %tmp25 = select i1 %tmp24, i32 %tmp23, i32 256 + %tmp26 = add nsw i32 %tmp25, 1 + %tmp27 = icmp sgt i32 %tmp8, %tmp26 + %tmp28 = select i1 %tmp27, i32 %tmp26, i32 %tmp8 + %tmp29 = icmp slt i32 %tmp28, 256 + %tmp30 = select i1 %tmp29, i32 %tmp28, i32 256 + %tmp31 = add nsw i32 %tmp30, 1 + %tmp32 = icmp sgt i32 %tmp8, %tmp31 + %tmp33 = select i1 %tmp32, i32 %tmp31, i32 %tmp8 + %tmp34 = icmp slt i32 %tmp33, 256 + %tmp35 = select i1 %tmp34, i32 %tmp33, i32 256 + %tmp36 = add nsw i32 %tmp35, 1 + %tmp37 = icmp sgt i32 %tmp8, %tmp36 + %tmp38 = select i1 %tmp37, i32 %tmp36, i32 %tmp8 + %tmp39 = icmp slt i32 %tmp38, 256 + %tmp40 = select i1 %tmp39, i32 %tmp38, i32 256 + %tmp41 = add nsw i32 %tmp40, 1 + %tmp42 = icmp sgt i32 %tmp8, %tmp41 + %tmp43 = select i1 %tmp42, i32 %tmp41, i32 %tmp8 + %tmp44 = add nsw i32 %tmp10, 7 + %tmp45 = icmp slt i32 %tmp43, 256 + %tmp46 = select i1 %tmp45, i32 %tmp43, i32 256 +; CHECK: %tmp46 = select i1 %tmp45, i32 %tmp43, i32 256 +; CHECK-NEXT: --> (-1 + (-1 * (-257 smax (-1 + (-257 smax (-1 + (-257 smax (-1 + (-257 smax (-1 + (-257 smax (-1 + (-257 smax (-1 + (-257 smax (-1 + (-257 smax {(-1 + (-1 * %tmp3)),+,256}<%bb4>)) smax {(-1 + (-1 * %tmp3)),+,256}<%bb4>)) smax {(-1 + (-1 * %tmp3)),+,256}<%bb4>)) smax {(-1 + (-1 * %tmp3)),+,256}<%bb4>)) smax {(-1 + (-1 * %tmp3)),+,256}<%bb4>)) smax {(-1 + (-1 * %tmp3)),+,256}<%bb4>)) smax {(-1 + (-1 * %tmp3)),+,256}<%bb4>)) smax {(-1 + (-1 * %tmp3)),+,256}<%bb4>))) + %tmp47 = icmp sgt i32 %tmp44, %tmp46 + %tmp48 = select i1 %tmp47, i32 %tmp44, i32 %tmp46 + %tmp49 = ashr i32 %tmp48, 3 + %tmp50 = icmp sgt i32 %tmp49, 0 + %tmp51 = select i1 %tmp50, i32 %tmp49, i32 0 + %tmp52 = zext i32 %tmp51 to i64 + br label %bb53 + +bb53: + %tmp54 = phi i64 [ undef, %bb4 ], [ %tmp59, %bb53 ] + %tmp55 = trunc i64 %tmp54 to i32 + %tmp56 = shl nsw i32 %tmp55, 3 + %tmp57 = sext i32 %tmp56 to i64 + %tmp58 = getelementptr inbounds i8, i8* null, i64 %tmp57 + store i8 undef, i8* %tmp58, align 8 + %tmp59 = add nsw i64 %tmp54, 1 + %tmp60 = icmp eq i64 %tmp59, %tmp52 + br i1 %tmp60, label %bb61, label %bb53 + +bb61: + %tmp62 = add nuw nsw i64 %tmp5, 1 + br label %bb4 +} + + +define void @umax(i32 %tmp3) { +; CHECK-LABEL: Printing analysis 'Scalar Evolution Analysis' for function 'umax' +entry: + br label %bb4 + +bb4: + %tmp5 = phi i64 [ %tmp62, %bb61 ], [ 0, %entry ] + %tmp6 = trunc i64 %tmp5 to i32 + %tmp7 = shl nsw i32 %tmp6, 8 + %tmp8 = sub nsw i32 %tmp3, %tmp7 + %tmp9 = icmp ult i32 %tmp8, 256 + %tmp10 = select i1 %tmp9, i32 %tmp8, i32 256 + %tmp11 = add nsw i32 %tmp10, 1 + %tmp12 = icmp ugt i32 %tmp8, %tmp11 + %tmp13 = select i1 %tmp12, i32 %tmp11, i32 %tmp8 + %tmp14 = icmp ult i32 %tmp13, 256 + %tmp15 = select i1 %tmp14, i32 %tmp13, i32 256 + %tmp16 = add nsw i32 %tmp15, 1 + %tmp17 = icmp ugt i32 %tmp8, %tmp16 + %tmp18 = select i1 %tmp17, i32 %tmp16, i32 %tmp8 + %tmp19 = icmp ult i32 %tmp18, 256 + %tmp20 = select i1 %tmp19, i32 %tmp18, i32 256 + %tmp21 = add nsw i32 %tmp20, 1 + %tmp22 = icmp ugt i32 %tmp8, %tmp21 + %tmp23 = select i1 %tmp22, i32 %tmp21, i32 %tmp8 + %tmp24 = icmp ult i32 %tmp23, 256 + %tmp25 = select i1 %tmp24, i32 %tmp23, i32 256 + %tmp26 = add nsw i32 %tmp25, 1 + %tmp27 = icmp ugt i32 %tmp8, %tmp26 + %tmp28 = select i1 %tmp27, i32 %tmp26, i32 %tmp8 + %tmp29 = icmp ult i32 %tmp28, 256 + %tmp30 = select i1 %tmp29, i32 %tmp28, i32 256 + %tmp31 = add nsw i32 %tmp30, 1 + %tmp32 = icmp ugt i32 %tmp8, %tmp31 + %tmp33 = select i1 %tmp32, i32 %tmp31, i32 %tmp8 + %tmp34 = icmp ult i32 %tmp33, 256 + %tmp35 = select i1 %tmp34, i32 %tmp33, i32 256 + %tmp36 = add nsw i32 %tmp35, 1 + %tmp37 = icmp ugt i32 %tmp8, %tmp36 + %tmp38 = select i1 %tmp37, i32 %tmp36, i32 %tmp8 + %tmp39 = icmp ult i32 %tmp38, 256 + %tmp40 = select i1 %tmp39, i32 %tmp38, i32 256 + %tmp41 = add nsw i32 %tmp40, 1 + %tmp42 = icmp ugt i32 %tmp8, %tmp41 + %tmp43 = select i1 %tmp42, i32 %tmp41, i32 %tmp8 + %tmp44 = add nsw i32 %tmp10, 7 + %tmp45 = icmp ult i32 %tmp43, 256 + %tmp46 = select i1 %tmp45, i32 %tmp43, i32 256 +; CHECK: %tmp46 = select i1 %tmp45, i32 %tmp43, i32 256 +; CHECK-NEXT: --> (-1 + (-1 * (-257 umax (-1 + (-257 umax (-1 + (-257 umax (-1 + (-257 umax (-1 + (-257 umax (-1 + (-257 umax (-1 + (-257 umax (-1 + (-257 umax {(-1 + (-1 * %tmp3)),+,256}<%bb4>)) umax {(-1 + (-1 * %tmp3)),+,256}<%bb4>)) umax {(-1 + (-1 * %tmp3)),+,256}<%bb4>)) umax {(-1 + (-1 * %tmp3)),+,256}<%bb4>)) umax {(-1 + (-1 * %tmp3)),+,256}<%bb4>)) umax {(-1 + (-1 * %tmp3)),+,256}<%bb4>)) umax {(-1 + (-1 * %tmp3)),+,256}<%bb4>)) umax {(-1 + (-1 * %tmp3)),+,256}<%bb4>))) + %tmp47 = icmp ugt i32 %tmp44, %tmp46 + %tmp48 = select i1 %tmp47, i32 %tmp44, i32 %tmp46 + %tmp49 = ashr i32 %tmp48, 3 + %tmp50 = icmp ugt i32 %tmp49, 0 + %tmp51 = select i1 %tmp50, i32 %tmp49, i32 0 + %tmp52 = zext i32 %tmp51 to i64 + br label %bb53 + +bb53: + %tmp54 = phi i64 [ undef, %bb4 ], [ %tmp59, %bb53 ] + %tmp55 = trunc i64 %tmp54 to i32 + %tmp56 = shl nsw i32 %tmp55, 3 + %tmp57 = sext i32 %tmp56 to i64 + %tmp58 = getelementptr inbounds i8, i8* null, i64 %tmp57 + store i8 undef, i8* %tmp58, align 8 + %tmp59 = add nsw i64 %tmp54, 1 + %tmp60 = icmp eq i64 %tmp59, %tmp52 + br i1 %tmp60, label %bb61, label %bb53 + +bb61: + %tmp62 = add nuw nsw i64 %tmp5, 1 + br label %bb4 +}