From 6d67ce109bff25fa19074622dc06722cbc00990f Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Thu, 14 Jul 2016 06:58:37 +0000 Subject: [PATCH] Simplify llvm.masked.load w/ undef masks We can always pick the passthru value if the mask is undef: we are permitted to treat the mask as-if it were filled with zeros. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@275379 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Analysis/ConstantFolding.cpp | 35 +++++++++++++++++----------- lib/Analysis/InstructionSimplify.cpp | 26 +++++++++++++++++---- test/Transforms/InstSimplify/call.ll | 7 ++++++ 3 files changed, 49 insertions(+), 19 deletions(-) diff --git a/lib/Analysis/ConstantFolding.cpp b/lib/Analysis/ConstantFolding.cpp index 96a2d02ed5b..6c471ab4504 100644 --- a/lib/Analysis/ConstantFolding.cpp +++ b/lib/Analysis/ConstantFolding.cpp @@ -1854,32 +1854,39 @@ Constant *ConstantFoldVectorCall(StringRef Name, unsigned IntrinsicID, auto *SrcPtr = Operands[0]; auto *Mask = Operands[2]; auto *Passthru = Operands[3]; + Constant *VecData = ConstantFoldLoadFromConstPtr(SrcPtr, VTy, DL); - if (!VecData) - return nullptr; SmallVector NewElements; for (unsigned I = 0, E = VTy->getNumElements(); I != E; ++I) { - auto *MaskElt = - dyn_cast_or_null(Mask->getAggregateElement(I)); + auto *MaskElt = Mask->getAggregateElement(I); if (!MaskElt) break; - if (MaskElt->isZero()) { - auto *PassthruElt = Passthru->getAggregateElement(I); + auto *PassthruElt = Passthru->getAggregateElement(I); + auto *VecElt = VecData ? VecData->getAggregateElement(I) : nullptr; + if (isa(MaskElt)) { + if (PassthruElt) + NewElements.push_back(PassthruElt); + else if (VecElt) + NewElements.push_back(VecElt); + else + return nullptr; + } + if (MaskElt->isNullValue()) { if (!PassthruElt) - break; + return nullptr; NewElements.push_back(PassthruElt); - } else { - assert(MaskElt->isOne()); - auto *VecElt = VecData->getAggregateElement(I); + } else if (MaskElt->isOneValue()) { if (!VecElt) - break; + return nullptr; NewElements.push_back(VecElt); + } else { + return nullptr; } } - if (NewElements.size() == VTy->getNumElements()) - return ConstantVector::get(NewElements); - return nullptr; + if (NewElements.size() != VTy->getNumElements()) + return nullptr; + return ConstantVector::get(NewElements); } for (unsigned I = 0, E = VTy->getNumElements(); I != E; ++I) { diff --git a/lib/Analysis/InstructionSimplify.cpp b/lib/Analysis/InstructionSimplify.cpp index 609cd26bcd0..0cb2c78afb4 100644 --- a/lib/Analysis/InstructionSimplify.cpp +++ b/lib/Analysis/InstructionSimplify.cpp @@ -3944,6 +3944,22 @@ static Value *SimplifyRelativeLoad(Constant *Ptr, Constant *Offset, return ConstantExpr::getBitCast(LoadedLHSPtr, Int8PtrTy); } +static bool maskIsAllZeroOrUndef(Value *Mask) { + auto *ConstMask = dyn_cast(Mask); + if (!ConstMask) + return false; + if (ConstMask->isNullValue() || isa(ConstMask)) + return true; + for (unsigned I = 0, E = ConstMask->getType()->getVectorNumElements(); I != E; + ++I) { + if (auto *MaskElt = ConstMask->getAggregateElement(I)) + if (MaskElt->isNullValue() || isa(MaskElt)) + continue; + return false; + } + return true; +} + template static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, const Query &Q, unsigned MaxRecurse) { @@ -3993,11 +4009,11 @@ static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, // Simplify calls to llvm.masked.load.* if (IID == Intrinsic::masked_load) { - IterTy MaskArg = ArgBegin + 2; - // If the mask is all zeros, the "passthru" argument is the result. - if (auto *ConstMask = dyn_cast(*MaskArg)) - if (ConstMask->isNullValue()) - return ArgBegin[3]; + Value *MaskArg = ArgBegin[2]; + Value *PassthruArg = ArgBegin[3]; + // If the mask is all zeros or undef, the "passthru" argument is the result. + if (maskIsAllZeroOrUndef(MaskArg)) + return PassthruArg; } // Perform idempotent optimizations diff --git a/test/Transforms/InstSimplify/call.ll b/test/Transforms/InstSimplify/call.ll index e0a071a3bb1..988ec2b71c5 100644 --- a/test/Transforms/InstSimplify/call.ll +++ b/test/Transforms/InstSimplify/call.ll @@ -213,6 +213,13 @@ define <8 x i32> @partial_masked_load() { ret <8 x i32> %masked.load } +define <8 x i32> @masked_load_undef_mask(<8 x i32>* %V) { +; CHECK-LABEL: @masked_load_undef_mask( +; CHECK: ret <8 x i32> + %masked.load = call <8 x i32> @llvm.masked.load.v8i32.p0v8i32(<8 x i32>* %V, i32 4, <8 x i1> undef, <8 x i32> ) + ret <8 x i32> %masked.load +} + declare noalias i8* @malloc(i64) declare <8 x i32> @llvm.masked.load.v8i32.p0v8i32(<8 x i32>*, i32, <8 x i1>, <8 x i32>)