diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 82cee72d812e..188a808f569f 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -16551,13 +16551,20 @@ static SDValue lower1BitShuffle(const SDLoc &DL, ArrayRef Mask, assert(Subtarget.hasAVX512() && "Cannot lower 512-bit vectors w/o basic ISA!"); - unsigned NumElts = Mask.size(); + int NumElts = Mask.size(); // Try to recognize shuffles that are just padding a subvector with zeros. - unsigned SubvecElts = 0; - for (int i = 0; i != (int)NumElts; ++i) { - if (Mask[i] >= 0 && Mask[i] != i) - break; + int SubvecElts = 0; + int Src = -1; + for (int i = 0; i != NumElts; ++i) { + if (Mask[i] >= 0) { + // Grab the source from the first valid mask. All subsequent elements need + // to use this same source. + if (Src < 0) + Src = Mask[i] / NumElts; + if (Src != (Mask[i] / NumElts) || (Mask[i] % NumElts) != i) + break; + } ++SubvecElts; } @@ -16568,10 +16575,12 @@ static SDValue lower1BitShuffle(const SDLoc &DL, ArrayRef Mask, // Make sure the number of zeroable bits in the top at least covers the bits // not covered by the subvector. - if (Zeroable.countLeadingOnes() >= (NumElts - SubvecElts)) { + if ((int)Zeroable.countLeadingOnes() >= (NumElts - SubvecElts)) { + assert(Src >= 0 && "Expected a source!"); MVT ExtractVT = MVT::getVectorVT(MVT::i1, SubvecElts); SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtractVT, - V1, DAG.getIntPtrConstant(0, DL)); + Src == 0 ? V1 : V2, + DAG.getIntPtrConstant(0, DL)); return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getConstant(0, DL, VT), Extract, DAG.getIntPtrConstant(0, DL)); diff --git a/llvm/test/CodeGen/X86/avx512-skx-insert-subvec.ll b/llvm/test/CodeGen/X86/avx512-skx-insert-subvec.ll index 76c983e6708c..a24c1d8c2fcc 100644 --- a/llvm/test/CodeGen/X86/avx512-skx-insert-subvec.ll +++ b/llvm/test/CodeGen/X86/avx512-skx-insert-subvec.ll @@ -205,12 +205,8 @@ define i8 @test15(<2 x i64> %x) { ; CHECK-LABEL: test15: ; CHECK: # %bb.0: ; CHECK-NEXT: vptestnmq %xmm0, %xmm0, %k0 -; CHECK-NEXT: vpmovm2d %k0, %ymm0 -; CHECK-NEXT: vmovq {{.*#+}} xmm0 = xmm0[0],zero -; CHECK-NEXT: vpmovd2m %ymm0, %k0 ; CHECK-NEXT: kmovd %k0, %eax ; CHECK-NEXT: # kill: def $al killed $al killed $eax -; CHECK-NEXT: vzeroupper ; CHECK-NEXT: retq %a = icmp eq <2 x i64> %x, zeroinitializer %b = shufflevector <2 x i1> %a, <2 x i1> , <8 x i32>