From 3f8711fc5e01685f0a751ef296d16cf9a1f4fd4d Mon Sep 17 00:00:00 2001 From: Yingwei Zheng Date: Tue, 5 Mar 2024 22:34:04 +0800 Subject: [PATCH] [InstCombine] Fix miscompilation in PR83947 (#83993) https://github.com/llvm/llvm-project/blob/762f762504967efbe159db5c737154b989afc9bb/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp#L394-L407 Comment from @topperc: > This transforms assumes the mask is a non-zero splat. We only know its a splat and not provably all 0s. The mask is a constexpr that includes the address of the global variable. We can't resolve the constant expression to an exact value. Fixes #83947. --- llvm/include/llvm/Analysis/VectorUtils.h | 5 ++ llvm/lib/Analysis/VectorUtils.cpp | 25 +++++++ .../InstCombine/InstCombineCalls.cpp | 13 ++-- .../InstCombine/masked_intrinsics.ll | 6 +- llvm/test/Transforms/InstCombine/pr83947.ll | 67 +++++++++++++++++++ 5 files changed, 110 insertions(+), 6 deletions(-) create mode 100644 llvm/test/Transforms/InstCombine/pr83947.ll diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h index 7a92e62b53c5..c6eb66cc9660 100644 --- a/llvm/include/llvm/Analysis/VectorUtils.h +++ b/llvm/include/llvm/Analysis/VectorUtils.h @@ -406,6 +406,11 @@ bool maskIsAllZeroOrUndef(Value *Mask); /// lanes can be assumed active. bool maskIsAllOneOrUndef(Value *Mask); +/// Given a mask vector of i1, Return true if any of the elements of this +/// predicate mask are known to be true or undef. That is, return true if at +/// least one lane can be assumed active. +bool maskContainsAllOneOrUndef(Value *Mask); + /// Given a mask vector of the form , return an APInt (of bitwidth Y) /// for each lane which may be active. APInt possiblyDemandedEltsInMask(Value *Mask); diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp index 73facc76a92b..bf7bc0ba84a0 100644 --- a/llvm/lib/Analysis/VectorUtils.cpp +++ b/llvm/lib/Analysis/VectorUtils.cpp @@ -1012,6 +1012,31 @@ bool llvm::maskIsAllOneOrUndef(Value *Mask) { return true; } +bool llvm::maskContainsAllOneOrUndef(Value *Mask) { + assert(isa(Mask->getType()) && + isa(Mask->getType()->getScalarType()) && + cast(Mask->getType()->getScalarType())->getBitWidth() == + 1 && + "Mask must be a vector of i1"); + + auto *ConstMask = dyn_cast(Mask); + if (!ConstMask) + return false; + if (ConstMask->isAllOnesValue() || isa(ConstMask)) + return true; + if (isa(ConstMask->getType())) + return false; + for (unsigned + I = 0, + E = cast(ConstMask->getType())->getNumElements(); + I != E; ++I) { + if (auto *MaskElt = ConstMask->getAggregateElement(I)) + if (MaskElt->isAllOnesValue() || isa(MaskElt)) + return true; + } + return false; +} + /// TODO: This is a lot like known bits, but for /// vectors. Is there something we can common this with? APInt llvm::possiblyDemandedEltsInMask(Value *Mask) { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index a647be2d26c7..bc43edb5e620 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -412,11 +412,14 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) { if (auto *SplatPtr = getSplatValue(II.getArgOperand(1))) { // scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) { - Align Alignment = cast(II.getArgOperand(2))->getAlignValue(); - StoreInst *S = - new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment); - S->copyMetadata(II); - return S; + if (maskContainsAllOneOrUndef(ConstMask)) { + Align Alignment = + cast(II.getArgOperand(2))->getAlignValue(); + StoreInst *S = new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, + Alignment); + S->copyMetadata(II); + return S; + } } // scatter(vector, splat(ptr), splat(true)) -> store extract(vector, // lastlane), ptr diff --git a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll index 2704905f7a35..c87c1199f727 100644 --- a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll +++ b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll @@ -292,7 +292,11 @@ entry: define void @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(ptr %dst, i16 %val) { ; CHECK-LABEL: @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask( ; CHECK-NEXT: entry: -; CHECK-NEXT: store i16 [[VAL:%.*]], ptr [[DST:%.*]], align 2 +; CHECK-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement poison, ptr [[DST:%.*]], i64 0 +; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector [[BROADCAST_SPLATINSERT]], poison, zeroinitializer +; CHECK-NEXT: [[BROADCAST_VALUE:%.*]] = insertelement poison, i16 [[VAL:%.*]], i64 0 +; CHECK-NEXT: [[BROADCAST_SPLATVALUE:%.*]] = shufflevector [[BROADCAST_VALUE]], poison, zeroinitializer +; CHECK-NEXT: call void @llvm.masked.scatter.nxv4i16.nxv4p0( [[BROADCAST_SPLATVALUE]], [[BROADCAST_SPLAT]], i32 2, shufflevector ( insertelement ( zeroinitializer, i1 true, i32 0), zeroinitializer, zeroinitializer)) ; CHECK-NEXT: ret void ; entry: diff --git a/llvm/test/Transforms/InstCombine/pr83947.ll b/llvm/test/Transforms/InstCombine/pr83947.ll new file mode 100644 index 000000000000..c1d601ff6371 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/pr83947.ll @@ -0,0 +1,67 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4 +; RUN: opt -S -passes=instcombine < %s | FileCheck %s + +@c = global i32 0, align 4 +@b = global i32 0, align 4 + +define void @masked_scatter1() { +; CHECK-LABEL: define void @masked_scatter1() { +; CHECK-NEXT: call void @llvm.masked.scatter.nxv4i32.nxv4p0( zeroinitializer, shufflevector ( insertelement ( poison, ptr @c, i64 0), poison, zeroinitializer), i32 4, shufflevector ( insertelement ( poison, i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c), i64 0), poison, zeroinitializer)) +; CHECK-NEXT: ret void +; + call void @llvm.masked.scatter.nxv4i32.nxv4p0( zeroinitializer, splat (ptr @c), i32 4, splat (i1 icmp eq (ptr getelementptr (i32, ptr @b, i64 1), ptr @c))) + ret void +} + +define void @masked_scatter2() { +; CHECK-LABEL: define void @masked_scatter2() { +; CHECK-NEXT: store i32 0, ptr @c, align 4 +; CHECK-NEXT: ret void +; + call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 true)) + ret void +} + +define void @masked_scatter3() { +; CHECK-LABEL: define void @masked_scatter3() { +; CHECK-NEXT: store i32 0, ptr @c, align 4 +; CHECK-NEXT: ret void +; + call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> undef) + ret void +} + +define void @masked_scatter4() { +; CHECK-LABEL: define void @masked_scatter4() { +; CHECK-NEXT: ret void +; + call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 false)) + ret void +} + +define void @masked_scatter5() { +; CHECK-LABEL: define void @masked_scatter5() { +; CHECK-NEXT: store i32 0, ptr @c, align 4 +; CHECK-NEXT: ret void +; + call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> ) + ret void +} + +define void @masked_scatter6() { +; CHECK-LABEL: define void @masked_scatter6() { +; CHECK-NEXT: store i32 0, ptr @c, align 4 +; CHECK-NEXT: ret void +; + call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> ) + ret void +} + +define void @masked_scatter7() { +; CHECK-LABEL: define void @masked_scatter7() { +; CHECK-NEXT: call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> , i32 4, <2 x i1> ) +; CHECK-NEXT: ret void +; + call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 icmp eq (ptr getelementptr (i32, ptr @b, i64 1), ptr @c))) + ret void +}