From aed4e7449f23144f60be43fe6e6eb43ce36bdf0c Mon Sep 17 00:00:00 2001
From: Simon Pilgrim
Date: Wed, 15 Sep 2021 10:20:54 +0100
Subject: [PATCH] [X86] combineX86ShuffleChain - ensure we only peek through
 bitcasts to vectors (PR51858)

When searching for hidden identity shuffles (added at
rG41146bfe82aecc79961c3de898cda02998172e4b), only peek through bitcasts to
the source operand if it is a vector type as well.

(cherry picked from commit dcba99418438ec1d624ad207674234bd2e9e3394)
---
 lib/Target/X86/X86ISelLowering.cpp    |  2 +-
 test/CodeGen/X86/vector-reduce-mul.ll | 58 +++++++++++++++++++++++++++
 2 files changed, 59 insertions(+), 1 deletion(-)

diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index a6985089643..032db2a80a7 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -35823,7 +35823,7 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
 
   // See if the shuffle is a hidden identity shuffle - repeated args in HOPs
   // etc. can be simplified.
-  if (VT1 == VT2 && VT1.getSizeInBits() == RootSizeInBits) {
+  if (VT1 == VT2 && VT1.getSizeInBits() == RootSizeInBits && VT1.isVector()) {
     SmallVector<int> ScaledMask, IdentityMask;
     unsigned NumElts = VT1.getVectorNumElements();
     if (BaseMask.size() <= NumElts &&
diff --git a/test/CodeGen/X86/vector-reduce-mul.ll b/test/CodeGen/X86/vector-reduce-mul.ll
index 5484eeeff45..fc1beb69d5d 100644
--- a/test/CodeGen/X86/vector-reduce-mul.ll
+++ b/test/CodeGen/X86/vector-reduce-mul.ll
@@ -2344,6 +2344,64 @@ define i8 @illegal_v8i8(i8 %a0, <8 x i8>* %a1) {
   ret i8 %mul
 }
 
+define i8 @PR51858(i128 %arg) {
+; SSE2-LABEL: PR51858:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movq %rdi, %xmm0
+; SSE2-NEXT:    movq %rsi, %xmm1
+; SSE2-NEXT:    punpcklbw {{.*#+}} xmm1 = xmm1[0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7]
+; SSE2-NEXT:    punpcklbw {{.*#+}} xmm0 = xmm0[0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7]
+; SSE2-NEXT:    pmullw %xmm1, %xmm0
+; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[2,3,2,3]
+; SSE2-NEXT:    pmullw %xmm0, %xmm1
+; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[1,1,1,1]
+; SSE2-NEXT:    pmullw %xmm1, %xmm0
+; SSE2-NEXT:    movdqa %xmm0, %xmm1
+; SSE2-NEXT:    psrld $16, %xmm1
+; SSE2-NEXT:    pmullw %xmm0, %xmm1
+; SSE2-NEXT:    movd %xmm1, %eax
+; SSE2-NEXT:    # kill: def $al killed $al killed $eax
+; SSE2-NEXT:    retq
+;
+; SSE41-LABEL: PR51858:
+; SSE41:       # %bb.0:
+; SSE41-NEXT:    movq %rdi, %xmm0
+; SSE41-NEXT:    movq %rsi, %xmm1
+; SSE41-NEXT:    pmovzxbw {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero,xmm1[2],zero,xmm1[3],zero,xmm1[4],zero,xmm1[5],zero,xmm1[6],zero,xmm1[7],zero
+; SSE41-NEXT:    pmovzxbw {{.*#+}} xmm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero
+; SSE41-NEXT:    pmullw %xmm1, %xmm0
+; SSE41-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[2,3,2,3]
+; SSE41-NEXT:    pmullw %xmm0, %xmm1
+; SSE41-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[1,1,1,1]
+; SSE41-NEXT:    pmullw %xmm1, %xmm0
+; SSE41-NEXT:    movdqa %xmm0, %xmm1
+; SSE41-NEXT:    psrld $16, %xmm1
+; SSE41-NEXT:    pmullw %xmm0, %xmm1
+; SSE41-NEXT:    movd %xmm1, %eax
+; SSE41-NEXT:    # kill: def $al killed $al killed $eax
+; SSE41-NEXT:    retq
+;
+; AVX-LABEL: PR51858:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vmovq %rdi, %xmm0
+; AVX-NEXT:    vmovq %rsi, %xmm1
+; AVX-NEXT:    vpmovzxbw {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero,xmm1[2],zero,xmm1[3],zero,xmm1[4],zero,xmm1[5],zero,xmm1[6],zero,xmm1[7],zero
+; AVX-NEXT:    vpmovzxbw {{.*#+}} xmm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero
+; AVX-NEXT:    vpmullw %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3]
+; AVX-NEXT:    vpmullw %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; AVX-NEXT:    vpmullw %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    vpsrld $16, %xmm0, %xmm1
+; AVX-NEXT:    vpmullw %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    vmovd %xmm0, %eax
+; AVX-NEXT:    # kill: def $al killed $al killed $eax
+; AVX-NEXT:    retq
+  %vec = bitcast i128 %arg to <16 x i8>
+  %red = tail call i8 @llvm.vector.reduce.mul.v16i8(<16 x i8> %vec)
+  ret i8 %red
+}
+
 declare i64 @llvm.vector.reduce.mul.v2i64(<2 x i64>)
 declare i64 @llvm.vector.reduce.mul.v4i64(<4 x i64>)
 declare i64 @llvm.vector.reduce.mul.v8i64(<8 x i64>)