From f372d2334fa933b1b1aa7ba87698ae0dc1953ce8 Mon Sep 17 00:00:00 2001
From: Sebastian Pop
Date: Fri, 30 Nov 2012 19:08:04 +0000
Subject: [PATCH] Codegen failure for vmull with small vectors

Codegen was failing with an assertion because of unexpected vector
operands when legalizing the selection DAG for a MUL instruction.

The asserting code was legalizing multiplies for vectors of size 128
bits. It uses a custom lowering to try to detect cases where it can
use a VMULL instruction instead of a VMOVL + VMUL. The code was
looking for input operands to the MUL that had been sign or zero
extended. If it found the extended operands, it would drop the
sign/zero extension and use the original vector size as input to a
VMULL instruction.

The code assumed that the original input vector was 64 bits, so that
after dropping the extension it would fit directly into a D register
and could be used as an operand of a VMULL instruction.

The input code that triggered the failure used a vector of <4 x i8>
that was sign extended to <4 x i32>. It was not safe to drop the sign
extension in this case because the original vector is only 32 bits
wide. The fix is to insert a sign extension for the vector to reach
the required 64-bit size. In this particular example, the vector would
need to be sign extended to a <4 x i16>.

llvm-svn: 169024
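For reference, a minimal IR reproducer, mirroring the sextload_v4i8_c
test added below (the constant operand is only illustrative):

    define void @sextload_v4i8_c(<4 x i8>* %v) nounwind {
    entry:
      %0 = load <4 x i8>* %v, align 8
      ; <4 x i8> only fills 32 bits of a D register
      %v0 = sext <4 x i8> %0 to <4 x i32>
      ; custom lowering of this MUL looks through the sext to form VMULL
      %v1 = mul <4 x i32> %v0, <i32 3, i32 3, i32 3, i32 3>
      store <4 x i32> %v1, <4 x i32>* undef, align 8
      ret void
    }

Before this patch, llc -march=arm -mattr=+neon asserted on this input;
with the patch the <4 x i8> value is first sign extended to <4 x i16>
and the multiply is selected as a vmull.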
---
 lib/Target/ARM/ARMISelLowering.cpp            |  87 ++++++++--
 test/CodeGen/ARM/2012-08-23-legalize-vmull.ll | 150 ++++++++++++++++++
 2 files changed, 224 insertions(+), 13 deletions(-)
 create mode 100644 test/CodeGen/ARM/2012-08-23-legalize-vmull.ll

diff --git a/lib/Target/ARM/ARMISelLowering.cpp b/lib/Target/ARM/ARMISelLowering.cpp
index d139a568a74..8ccb3c3eb76 100644
--- a/lib/Target/ARM/ARMISelLowering.cpp
+++ b/lib/Target/ARM/ARMISelLowering.cpp
@@ -4939,16 +4939,76 @@ static bool isZeroExtended(SDNode *N, SelectionDAG &DAG) {
   return false;
 }
 
-/// SkipExtension - For a node that is a SIGN_EXTEND, ZERO_EXTEND, extending
-/// load, or BUILD_VECTOR with extended elements, return the unextended value.
-static SDValue SkipExtension(SDNode *N, SelectionDAG &DAG) {
+/// AddRequiredExtensionForVMULL - Add a sign/zero extension to extend the total
+/// value size to 64 bits. We need a 64-bit D register as an operand to VMULL.
+/// We insert the required extension here to get the vector to fill a D register.
+static SDValue AddRequiredExtensionForVMULL(SDValue N, SelectionDAG &DAG,
+                                            const EVT &OrigTy,
+                                            const EVT &ExtTy,
+                                            unsigned ExtOpcode) {
+  // The vector originally had a size of OrigTy. It was then extended to ExtTy.
+  // We expect the ExtTy to be 128-bits total. If the OrigTy is less than
+  // 64-bits we need to insert a new extension so that it will be 64-bits.
+  assert(ExtTy.is128BitVector() && "Unexpected extension size");
+  if (OrigTy.getSizeInBits() >= 64)
+    return N;
+
+  // Must extend size to at least 64 bits to be used as an operand for VMULL.
+  MVT::SimpleValueType OrigSimpleTy = OrigTy.getSimpleVT().SimpleTy;
+  EVT NewVT;
+  switch (OrigSimpleTy) {
+  default: llvm_unreachable("Unexpected Orig Vector Type");
+  case MVT::v2i8:
+  case MVT::v2i16:
+    NewVT = MVT::v2i32;
+    break;
+  case MVT::v4i8:
+    NewVT = MVT::v4i16;
+    break;
+  }
+  return DAG.getNode(ExtOpcode, N->getDebugLoc(), NewVT, N);
+}
+
+/// SkipLoadExtensionForVMULL - return a load of the original vector size that
+/// does not do any sign/zero extension. If the original vector is less
+/// than 64 bits, an appropriate extension will be added after the load to
+/// reach a total size of 64 bits. We have to add the extension separately
+/// because ARM does not have a sign/zero extending load for vectors.
+static SDValue SkipLoadExtensionForVMULL(LoadSDNode *LD, SelectionDAG& DAG) {
+  SDValue NonExtendingLoad =
+    DAG.getLoad(LD->getMemoryVT(), LD->getDebugLoc(), LD->getChain(),
+                LD->getBasePtr(), LD->getPointerInfo(), LD->isVolatile(),
+                LD->isNonTemporal(), LD->isInvariant(),
+                LD->getAlignment());
+  unsigned ExtOp = 0;
+  switch (LD->getExtensionType()) {
+  default: llvm_unreachable("Unexpected LoadExtType");
+  case ISD::EXTLOAD:
+  case ISD::SEXTLOAD: ExtOp = ISD::SIGN_EXTEND; break;
+  case ISD::ZEXTLOAD: ExtOp = ISD::ZERO_EXTEND; break;
+  }
+  MVT::SimpleValueType MemType = LD->getMemoryVT().getSimpleVT().SimpleTy;
+  MVT::SimpleValueType ExtType = LD->getValueType(0).getSimpleVT().SimpleTy;
+  return AddRequiredExtensionForVMULL(NonExtendingLoad, DAG,
+                                      MemType, ExtType, ExtOp);
+}
+
+/// SkipExtensionForVMULL - For a node that is a SIGN_EXTEND, ZERO_EXTEND,
+/// extending load, or BUILD_VECTOR with extended elements, return the
+/// unextended value. The unextended vector should be 64 bits so that it can
+/// be used as an operand to a VMULL instruction. If the original vector size
+/// before extension is less than 64 bits we add an extension to resize
+/// the vector to 64 bits.
+static SDValue SkipExtensionForVMULL(SDNode *N, SelectionDAG &DAG) {
   if (N->getOpcode() == ISD::SIGN_EXTEND || N->getOpcode() == ISD::ZERO_EXTEND)
-    return N->getOperand(0);
+    return AddRequiredExtensionForVMULL(N->getOperand(0), DAG,
+                                        N->getOperand(0)->getValueType(0),
+                                        N->getValueType(0),
+                                        N->getOpcode());
+
   if (LoadSDNode *LD = dyn_cast<LoadSDNode>(N))
-    return DAG.getLoad(LD->getMemoryVT(), N->getDebugLoc(), LD->getChain(),
-                       LD->getBasePtr(), LD->getPointerInfo(), LD->isVolatile(),
-                       LD->isNonTemporal(), LD->isInvariant(),
-                       LD->getAlignment());
+    return SkipLoadExtensionForVMULL(LD, DAG);
+
   // Otherwise, the value must be a BUILD_VECTOR. For v2i64, it will
   // have been legalized as a BITCAST from v4i32.
   if (N->getOpcode() == ISD::BITCAST) {
@@ -5003,7 +5063,8 @@ static SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) {
   // Multiplications are only custom-lowered for 128-bit vectors so that
   // VMULL can be detected. Otherwise v2i64 multiplications are not legal.
   EVT VT = Op.getValueType();
-  assert(VT.is128BitVector() && "unexpected type for custom-lowering ISD::MUL");
+  assert(VT.is128BitVector() && VT.isInteger() &&
+         "unexpected type for custom-lowering ISD::MUL");
   SDNode *N0 = Op.getOperand(0).getNode();
   SDNode *N1 = Op.getOperand(1).getNode();
   unsigned NewOpc = 0;
@@ -5046,9 +5107,9 @@ static SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) {
   // Legalize to a VMULL instruction.
   DebugLoc DL = Op.getDebugLoc();
   SDValue Op0;
-  SDValue Op1 = SkipExtension(N1, DAG);
+  SDValue Op1 = SkipExtensionForVMULL(N1, DAG);
   if (!isMLA) {
-    Op0 = SkipExtension(N0, DAG);
+    Op0 = SkipExtensionForVMULL(N0, DAG);
     assert(Op0.getValueType().is64BitVector() &&
            Op1.getValueType().is64BitVector() &&
            "unexpected types for extended operands to VMULL");
@@ -5063,8 +5124,8 @@ static SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) {
   //   vaddl q0, d4, d5
   //   vmovl q1, d6
   //   vmul  q0, q0, q1
-  SDValue N00 = SkipExtension(N0->getOperand(0).getNode(), DAG);
-  SDValue N01 = SkipExtension(N0->getOperand(1).getNode(), DAG);
+  SDValue N00 = SkipExtensionForVMULL(N0->getOperand(0).getNode(), DAG);
+  SDValue N01 = SkipExtensionForVMULL(N0->getOperand(1).getNode(), DAG);
   EVT Op1VT = Op1.getValueType();
   return DAG.getNode(N0->getOpcode(), DL, VT,
                      DAG.getNode(NewOpc, DL, VT,
diff --git a/test/CodeGen/ARM/2012-08-23-legalize-vmull.ll b/test/CodeGen/ARM/2012-08-23-legalize-vmull.ll
new file mode 100644
index 00000000000..2f55204aa40
--- /dev/null
+++ b/test/CodeGen/ARM/2012-08-23-legalize-vmull.ll
@@ -0,0 +1,150 @@
+; RUN: llc < %s -march=arm -mattr=+neon | FileCheck %s
+
+; PR12281
+; Test generation of code for the vmull instruction when multiplying 128-bit
+; vectors that were created by sign-extending smaller vector sizes.
+;
+; The vmull operation requires 64-bit vectors, so we must extend the original
+; vector size to 64 bits for the vmull operation.
+; Previously failed with an assertion because the <4 x i8> vector was too small
+; for vmull.

+; Vector x Constant
+; v4i8
+;
+define void @sextload_v4i8_c(<4 x i8>* %v) nounwind {
+;CHECK: sextload_v4i8_c:
+entry:
+  %0 = load <4 x i8>* %v, align 8
+  %v0 = sext <4 x i8> %0 to <4 x i32>
+;CHECK: vmull
+  %v1 = mul <4 x i32> %v0, <i32 3, i32 3, i32 3, i32 3>
+  store <4 x i32> %v1, <4 x i32>* undef, align 8
+  ret void;
+}
+
+; v2i8
+;
+define void @sextload_v2i8_c(<2 x i8>* %v) nounwind {
+;CHECK: sextload_v2i8_c:
+entry:
+  %0 = load <2 x i8>* %v, align 8
+  %v0 = sext <2 x i8> %0 to <2 x i64>
+;CHECK: vmull
+  %v1 = mul <2 x i64> %v0, <i64 3, i64 3>
+  store <2 x i64> %v1, <2 x i64>* undef, align 8
+  ret void;
+}
+
+; v2i16
+;
+define void @sextload_v2i16_c(<2 x i16>* %v) nounwind {
+;CHECK: sextload_v2i16_c:
+entry:
+  %0 = load <2 x i16>* %v, align 8
+  %v0 = sext <2 x i16> %0 to <2 x i64>
+;CHECK: vmull
+  %v1 = mul <2 x i64> %v0, <i64 3, i64 3>
+  store <2 x i64> %v1, <2 x i64>* undef, align 8
+  ret void;
+}
+
+
+; Vector x Vector
+; v4i8
+;
+define void @sextload_v4i8_v(<4 x i8>* %v, <4 x i8>* %p) nounwind {
+;CHECK: sextload_v4i8_v:
+entry:
+  %0 = load <4 x i8>* %v, align 8
+  %v0 = sext <4 x i8> %0 to <4 x i32>
+
+  %1 = load <4 x i8>* %p, align 8
+  %v2 = sext <4 x i8> %1 to <4 x i32>
+;CHECK: vmull
+  %v1 = mul <4 x i32> %v0, %v2
+  store <4 x i32> %v1, <4 x i32>* undef, align 8
+  ret void;
+}
+
+; v2i8
+;
+define void @sextload_v2i8_v(<2 x i8>* %v, <2 x i8>* %p) nounwind {
+;CHECK: sextload_v2i8_v:
+entry:
+  %0 = load <2 x i8>* %v, align 8
+  %v0 = sext <2 x i8> %0 to <2 x i64>
+
+  %1 = load <2 x i8>* %p, align 8
+  %v2 = sext <2 x i8> %1 to <2 x i64>
+;CHECK: vmull
+  %v1 = mul <2 x i64> %v0, %v2
+  store <2 x i64> %v1, <2 x i64>* undef, align 8
+  ret void;
+}
+
+; v2i16
+;
+define void @sextload_v2i16_v(<2 x i16>* %v, <2 x i16>* %p) nounwind {
+;CHECK: sextload_v2i16_v:
+entry:
+  %0 = load <2 x i16>* %v, align 8
+  %v0 = sext <2 x i16> %0 to <2 x i64>
+
+  %1 = load <2 x i16>* %p, align 8
+  %v2 = sext <2 x i16> %1 to <2 x i64>
+;CHECK: vmull
+  %v1 = mul <2 x i64> %v0, %v2
+  store <2 x i64> %v1, <2 x i64>* undef, align 8
+  ret void;
+}
+
+
+; Vector(small) x Vector(big)
+; v4i8 x v4i16
+;
+define void @sextload_v4i8_vs(<4 x i8>* %v, <4 x i16>* %p) nounwind {
+;CHECK: sextload_v4i8_vs:
+entry:
+  %0 = load <4 x i8>* %v, align 8
+  %v0 = sext <4 x i8> %0 to <4 x i32>
+
+  %1 = load <4 x i16>* %p, align 8
+  %v2 = sext <4 x i16> %1 to <4 x i32>
+;CHECK: vmull
+  %v1 = mul <4 x i32> %v0, %v2
+  store <4 x i32> %v1, <4 x i32>* undef, align 8
+  ret void;
+}
+
+; v2i8
+; v2i8 x v2i16
+define void @sextload_v2i8_vs(<2 x i8>* %v, <2 x i16>* %p) nounwind {
+;CHECK: sextload_v2i8_vs:
+entry:
+  %0 = load <2 x i8>* %v, align 8
+  %v0 = sext <2 x i8> %0 to <2 x i64>
+
+  %1 = load <2 x i16>* %p, align 8
+  %v2 = sext <2 x i16> %1 to <2 x i64>
+;CHECK: vmull
+  %v1 = mul <2 x i64> %v0, %v2
+  store <2 x i64> %v1, <2 x i64>* undef, align 8
+  ret void;
+}
+
+; v2i16
+; v2i16 x v2i32
+define void @sextload_v2i16_vs(<2 x i16>* %v, <2 x i32>* %p) nounwind {
+;CHECK: sextload_v2i16_vs:
+entry:
+  %0 = load <2 x i16>* %v, align 8
+  %v0 = sext <2 x i16> %0 to <2 x i64>
+
+  %1 = load <2 x i32>* %p, align 8
+  %v2 = sext <2 x i32> %1 to <2 x i64>
+;CHECK: vmull
+  %v1 = mul <2 x i64> %v0, %v2
+  store <2 x i64> %v1, <2 x i64>* undef, align 8
+  ret void;
+}