From 3dd08734c1812e47ae5f6aceba15f28865f75943 Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Sat, 28 Aug 2010 01:20:38 +0000 Subject: [PATCH] optimize bitcasts from large integers to vector into vector element insertion from the pieces that feed into the vector. This handles a pattern that occurs frequently due to code generated for the x86-64 abi. We now compile something like this: struct S { float A, B, C, D; }; struct S g; struct S bar() { struct S A = g; ++A.A; ++A.C; return A; } into all nice vector operations: _bar: ## @bar ## BB#0: ## %entry movq _g@GOTPCREL(%rip), %rax movss LCPI1_0(%rip), %xmm1 movss (%rax), %xmm0 addss %xmm1, %xmm0 pshufd $16, %xmm0, %xmm0 movss 4(%rax), %xmm2 movss 12(%rax), %xmm3 pshufd $16, %xmm2, %xmm2 unpcklps %xmm2, %xmm0 addss 8(%rax), %xmm1 pshufd $16, %xmm1, %xmm1 pshufd $16, %xmm3, %xmm2 unpcklps %xmm2, %xmm1 ret instead of icky integer operations: _bar: ## @bar movq _g@GOTPCREL(%rip), %rax movss LCPI1_0(%rip), %xmm1 movss (%rax), %xmm0 addss %xmm1, %xmm0 movd %xmm0, %ecx movl 4(%rax), %edx movl 12(%rax), %esi shlq $32, %rdx addq %rcx, %rdx movd %rdx, %xmm0 addss 8(%rax), %xmm1 movd %xmm1, %eax shlq $32, %rsi addq %rax, %rsi movd %rsi, %xmm1 ret This resolves rdar://8360454 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@112343 91177308-0d34-0410-b5e6-96231b3b80d8 --- .../InstCombine/InstCombineCasts.cpp | 136 ++++++++++++++++-- .../InstCombine/InstCombineShifts.cpp | 4 +- test/Transforms/InstCombine/bitcast.ll | 31 ++++ 3 files changed, 160 insertions(+), 11 deletions(-) diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp index 10eeab2f9f4..27eab7516f1 100644 --- a/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -1362,6 +1362,116 @@ static Instruction *OptimizeVectorResize(Value *InVal, const VectorType *DestTy, return new ShuffleVectorInst(InVal, V2, Mask); } +static bool isMultipleOfTypeSize(unsigned Value, const Type *Ty) { + return Value % Ty->getPrimitiveSizeInBits() == 0; +} + +static bool getTypeSizeIndex(unsigned Value, const Type *Ty) { + return Value / Ty->getPrimitiveSizeInBits(); +} + +/// CollectInsertionElements - V is a value which is inserted into a vector of +/// VecEltTy. Look through the value to see if we can decompose it into +/// insertions into the vector. See the example in the comment for +/// OptimizeIntegerToVectorInsertions for the pattern this handles. +/// The type of V is always a non-zero multiple of VecEltTy's size. +/// +/// This returns false if the pattern can't be matched or true if it can, +/// filling in Elements with the elements found here. +static bool CollectInsertionElements(Value *V, unsigned ElementIndex, + SmallVectorImpl &Elements, + const Type *VecEltTy) { + // If we got down to a value of the right type, we win, try inserting into the + // right element. + if (V->getType() == VecEltTy) { + // Fail if multiple elements are inserted into this slot. + if (ElementIndex >= Elements.size() || Elements[ElementIndex] != 0) + return false; + + Elements[ElementIndex] = V; + return true; + } + + //if (Constant *C = dyn_cast(V)) { + // Figure out the # elements this provides, and bitcast it or slice it up + // as required. + //} + + if (!V->hasOneUse()) return false; + + Instruction *I = dyn_cast(V); + if (I == 0) return false; + switch (I->getOpcode()) { + default: return false; // Unhandled case. + case Instruction::BitCast: + return CollectInsertionElements(I->getOperand(0), ElementIndex, + Elements, VecEltTy); + case Instruction::ZExt: + if (!isMultipleOfTypeSize( + I->getOperand(0)->getType()->getPrimitiveSizeInBits(), + VecEltTy)) + return false; + return CollectInsertionElements(I->getOperand(0), ElementIndex, + Elements, VecEltTy); + case Instruction::Or: + return CollectInsertionElements(I->getOperand(0), ElementIndex, + Elements, VecEltTy) && + CollectInsertionElements(I->getOperand(1), ElementIndex, + Elements, VecEltTy); + case Instruction::Shl: { + // Must be shifting by a constant that is a multiple of the element size. + ConstantInt *CI = dyn_cast(I->getOperand(1)); + if (CI == 0) return false; + if (!isMultipleOfTypeSize(CI->getZExtValue(), VecEltTy)) return false; + unsigned IndexShift = getTypeSizeIndex(CI->getZExtValue(), VecEltTy); + + return CollectInsertionElements(I->getOperand(0), ElementIndex+IndexShift, + Elements, VecEltTy); + } + + } +} + + +/// OptimizeIntegerToVectorInsertions - If the input is an 'or' instruction, we +/// may be doing shifts and ors to assemble the elements of the vector manually. +/// Try to rip the code out and replace it with insertelements. This is to +/// optimize code like this: +/// +/// %tmp37 = bitcast float %inc to i32 +/// %tmp38 = zext i32 %tmp37 to i64 +/// %tmp31 = bitcast float %inc5 to i32 +/// %tmp32 = zext i32 %tmp31 to i64 +/// %tmp33 = shl i64 %tmp32, 32 +/// %ins35 = or i64 %tmp33, %tmp38 +/// %tmp43 = bitcast i64 %ins35 to <2 x float> +/// +/// Into two insertelements that do "buildvector{%inc, %inc5}". +static Value *OptimizeIntegerToVectorInsertions(BitCastInst &CI, + InstCombiner &IC) { + const VectorType *DestVecTy = cast(CI.getType()); + Value *IntInput = CI.getOperand(0); + + SmallVector Elements(DestVecTy->getNumElements()); + if (!CollectInsertionElements(IntInput, 0, Elements, + DestVecTy->getElementType())) + return 0; + + // If we succeeded, we know that all of the element are specified by Elements + // or are zero if Elements has a null entry. Recast this as a set of + // insertions. + Value *Result = Constant::getNullValue(CI.getType()); + for (unsigned i = 0, e = Elements.size(); i != e; ++i) { + if (Elements[i] == 0) continue; // Unset element. + + Result = IC.Builder->CreateInsertElement(Result, Elements[i], + IC.Builder->getInt32(i)); + } + + return Result; +} + + /// OptimizeIntToFloatBitCast - See if we can optimize an integer->float/double /// bitcast. The various long double bitcasts can't get in here. static Instruction *OptimizeIntToFloatBitCast(BitCastInst &CI,InstCombiner &IC){ @@ -1478,16 +1588,24 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { // FIXME: Canonicalize bitcast(insertelement) -> insertelement(bitcast) } - // If this is a cast from an integer to vector, check to see if the input - // is a trunc or zext of a bitcast from vector. If so, we can replace all - // the casts with a shuffle and (potentially) a bitcast. - if (isa(SrcTy) && (isa(Src) || isa(Src))){ - CastInst *SrcCast = cast(Src); - if (BitCastInst *BCIn = dyn_cast(SrcCast->getOperand(0))) - if (isa(BCIn->getOperand(0)->getType())) - if (Instruction *I = OptimizeVectorResize(BCIn->getOperand(0), + if (isa(SrcTy)) { + // If this is a cast from an integer to vector, check to see if the input + // is a trunc or zext of a bitcast from vector. If so, we can replace all + // the casts with a shuffle and (potentially) a bitcast. + if (isa(Src) || isa(Src)) { + CastInst *SrcCast = cast(Src); + if (BitCastInst *BCIn = dyn_cast(SrcCast->getOperand(0))) + if (isa(BCIn->getOperand(0)->getType())) + if (Instruction *I = OptimizeVectorResize(BCIn->getOperand(0), cast(DestTy), *this)) - return I; + return I; + } + + // If the input is an 'or' instruction, we may be doing shifts and ors to + // assemble the elements of the vector manually. Try to rip the code out + // and replace it with insertelements. + if (Value *V = OptimizeIntegerToVectorInsertions(CI, *this)) + return ReplaceInstUsesWith(CI, V); } } diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index 270f489f682..27716b886a2 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -312,8 +312,8 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1, // cast of lshr(shl(x,c1),c2) as well as other more complex cases. if (I.getOpcode() != Instruction::AShr && CanEvaluateShifted(Op0, Op1->getZExtValue(), isLeftShift, *this)) { - DEBUG(dbgs() << "ICE: GetShiftedValue propagatin shift through expression" - " to eliminate shift:\n IN: " << *Op0 << "\nSH: " << I << "\n"); + DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression" + " to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n"); return ReplaceInstUsesWith(I, GetShiftedValue(Op0, Op1->getZExtValue(), isLeftShift, *this)); diff --git a/test/Transforms/InstCombine/bitcast.ll b/test/Transforms/InstCombine/bitcast.ll index 10898397b98..87e413ea27f 100644 --- a/test/Transforms/InstCombine/bitcast.ll +++ b/test/Transforms/InstCombine/bitcast.ll @@ -60,3 +60,34 @@ define float @test3(<2 x float> %A, <2 x i64> %B) { ; CHECK-NEXT: %add = fadd float %tmp24, %tmp4 ; CHECK-NEXT: ret float %add } + + +define <2 x i32> @test4(i32 %A, i32 %B){ + %tmp38 = zext i32 %A to i64 + %tmp32 = zext i32 %B to i64 + %tmp33 = shl i64 %tmp32, 32 + %ins35 = or i64 %tmp33, %tmp38 + %tmp43 = bitcast i64 %ins35 to <2 x i32> + ret <2 x i32> %tmp43 + ; CHECK: @test4 + ; CHECK-NEXT: insertelement <2 x i32> undef, i32 %A, i32 0 + ; CHECK-NEXT: insertelement <2 x i32> {{.*}}, i32 %B, i32 1 + ; CHECK-NEXT: ret <2 x i32> + +} + +; rdar://8360454 +define <2 x float> @test5(float %A, float %B) { + %tmp37 = bitcast float %A to i32 + %tmp38 = zext i32 %tmp37 to i64 + %tmp31 = bitcast float %B to i32 + %tmp32 = zext i32 %tmp31 to i64 + %tmp33 = shl i64 %tmp32, 32 + %ins35 = or i64 %tmp33, %tmp38 + %tmp43 = bitcast i64 %ins35 to <2 x float> + ret <2 x float> %tmp43 + ; CHECK: @test5 + ; CHECK-NEXT: insertelement <2 x float> undef, float %A, i32 0 + ; CHECK-NEXT: insertelement <2 x float> {{.*}}, float %B, i32 1 + ; CHECK-NEXT: ret <2 x float> +}