diff --git a/lib/Transforms/InstCombine/InstCombine.h b/lib/Transforms/InstCombine/InstCombine.h
index 41017c52879..7467eca7ab1 100644
--- a/lib/Transforms/InstCombine/InstCombine.h
+++ b/lib/Transforms/InstCombine/InstCombine.h
@@ -367,6 +367,10 @@ private:
   Value *EvaluateInDifferentType(Value *V, Type *Ty, bool isSigned);
+
+  /// Descale - Return a value X such that Val = X * Scale, or null if none. If
+  /// the multiplication is known not to overflow then NoSignedWrap is set.
+  Value *Descale(Value *Val, APInt Scale, bool &NoSignedWrap);
 };
diff --git a/lib/Transforms/InstCombine/InstructionCombining.cpp b/lib/Transforms/InstCombine/InstructionCombining.cpp
index 5356fdcba7c..390b63c1965 100644
--- a/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -805,6 +805,244 @@ static bool shouldMergeGEPs(GEPOperator &GEP, GEPOperator &Src) {
   return true;
 }
 
+/// Descale - Return a value X such that Val = X * Scale, or null if none. If
+/// the multiplication is known not to overflow then NoSignedWrap is set.
+Value *InstCombiner::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) {
+  assert(isa<IntegerType>(Val->getType()) && "Can only descale integers!");
+  assert(cast<IntegerType>(Val->getType())->getBitWidth() ==
+         Scale.getBitWidth() && "Scale not compatible with value!");
+
+  // If Val is zero or Scale is one then Val = Val * Scale.
+  if (match(Val, m_Zero()) || Scale == 1) {
+    NoSignedWrap = true;
+    return Val;
+  }
+
+  // If Scale is zero then it does not divide Val.
+  if (Scale.isMinValue())
+    return 0;
+
+  // Look through chains of multiplications, searching for a constant that is
+  // divisible by Scale. For example, descaling X*(Y*(Z*4)) by a factor of 4
+  // will find the constant factor 4 and produce X*(Y*Z). Descaling X*(Y*8) by
+  // a factor of 4 will produce X*(Y*2). The principle of operation is to bore
+  // down from Val:
+  //
+  //     Val = M1 * X    ||   Analysis starts here and works down
+  //      M1 = M2 * Y    ||   Doesn't descend into terms with more
+  //      M2 =  Z * 4    \/   than one use
+  //
+  // Then to modify a term at the bottom:
+  //
+  //     Val = M1 * X
+  //      M1 =  Z * Y    ||   Replaced M2 with Z
+  //
+  // Then to work back up correcting nsw flags.
+
+  // Op - the term we are currently analyzing. Starts at Val then drills down.
+  // Replaced with its descaled value before exiting from the drill down loop.
+  Value *Op = Val;
+
+  // Parent - initially null, but after drilling down notes where Op came from.
+  // In the example above, Parent is (Val, 0) when Op is M1, because M1 is the
+  // 0'th operand of Val.
+  std::pair<Instruction*, unsigned> Parent;
+
+  // RequireNoSignedWrap - Set if the transform requires a descaling at deeper
+  // levels that doesn't overflow.
+  bool RequireNoSignedWrap = false;
+
+  // logScale - log base 2 of the scale. Negative if not a power of 2.
+  int32_t logScale = Scale.exactLogBase2();
+
+  for (;; Op = Parent.first->getOperand(Parent.second)) { // Drill down
+
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(Op)) {
+      // If Op is a constant divisible by Scale then descale to the quotient.
+      APInt Quotient(Scale), Remainder(Scale); // Init ensures right bitwidth.
+      APInt::sdivrem(CI->getValue(), Scale, Quotient, Remainder);
+      if (!Remainder.isMinValue())
+        // Not divisible by Scale.
+        return 0;
+      // Replace with the quotient in the parent.
+      Op = ConstantInt::get(CI->getType(), Quotient);
+      NoSignedWrap = true;
+      break;
+    }
+
+    if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op)) {
+
+      if (BO->getOpcode() == Instruction::Mul) {
+        // Multiplication.
+        NoSignedWrap = BO->hasNoSignedWrap();
+        if (RequireNoSignedWrap && !NoSignedWrap)
+          return 0;
+
+        // There are three cases for multiplication: multiplication by exactly
+        // the scale, multiplication by a constant different to the scale, and
+        // multiplication by something else.
+        Value *LHS = BO->getOperand(0);
+        Value *RHS = BO->getOperand(1);
+
+        if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
+          // Multiplication by a constant.
+          if (CI->getValue() == Scale) {
+            // Multiplication by exactly the scale, replace the multiplication
+            // by its left-hand side in the parent.
+            Op = LHS;
+            break;
+          }
+
+          // Otherwise drill down into the constant.
+          if (!Op->hasOneUse())
+            return 0;
+
+          Parent = std::make_pair(BO, 1);
+          continue;
+        }
+
+        // Multiplication by something else. Drill down into the left-hand side
+        // since that's where the reassociate pass puts the good stuff.
+        if (!Op->hasOneUse())
+          return 0;
+
+        Parent = std::make_pair(BO, 0);
+        continue;
+      }
+
+      if (logScale > 0 && BO->getOpcode() == Instruction::Shl &&
+          isa<ConstantInt>(BO->getOperand(1))) {
+        // Multiplication by a power of 2.
+        NoSignedWrap = BO->hasNoSignedWrap();
+        if (RequireNoSignedWrap && !NoSignedWrap)
+          return 0;
+
+        Value *LHS = BO->getOperand(0);
+        int32_t Amt = cast<ConstantInt>(BO->getOperand(1))->
+          getLimitedValue(Scale.getBitWidth());
+        // Op = LHS << Amt.
+
+        if (Amt == logScale) {
+          // Multiplication by exactly the scale, replace the multiplication
+          // by its left-hand side in the parent.
+          Op = LHS;
+          break;
+        }
+        if (Amt < logScale || !Op->hasOneUse())
+          return 0;
+
+        // Multiplication by more than the scale. Reduce the multiplying amount
+        // by the scale in the parent.
+        Parent = std::make_pair(BO, 1);
+        Op = ConstantInt::get(BO->getType(), Amt - logScale);
+        break;
+      }
+    }
+
+    if (!Op->hasOneUse())
+      return 0;
+
+    if (CastInst *Cast = dyn_cast<CastInst>(Op)) {
+      if (Cast->getOpcode() == Instruction::SExt) {
+        // Op is sign-extended from a smaller type, descale in the smaller type.
+        unsigned SmallSize = Cast->getSrcTy()->getPrimitiveSizeInBits();
+        APInt SmallScale = Scale.trunc(SmallSize);
+        // Suppose Op = sext X, and we descale X as Y * SmallScale. We want to
+        // descale Op as (sext Y) * Scale. In order to have
+        //   sext (Y * SmallScale) = (sext Y) * Scale
+        // some conditions need to hold however: SmallScale must sign-extend to
+        // Scale and the multiplication Y * SmallScale should not overflow.
+        if (SmallScale.sext(Scale.getBitWidth()) != Scale)
+          // SmallScale does not sign-extend to Scale.
+          return 0;
+        assert(SmallScale.exactLogBase2() == logScale);
+        // Require that Y * SmallScale must not overflow.
+        RequireNoSignedWrap = true;
+
+        // Drill down through the cast.
+        Parent = std::make_pair(Cast, 0);
+        Scale = SmallScale;
+        continue;
+      }
+
+      if (Cast->getOpcode() == Instruction::Trunc) {
+        // Op is truncated from a larger type, descale in the larger type.
+        // Suppose Op = trunc X, and we descale X as Y * sext Scale. Then
+        //   trunc (Y * sext Scale) = (trunc Y) * Scale
+        // always holds. However (trunc Y) * Scale may overflow even if
+        // trunc (Y * sext Scale) does not, so nsw flags need to be cleared
+        // from this point up in the expression (see later).
+        if (RequireNoSignedWrap)
+          return 0;
+
+        // Drill down through the cast.
+        unsigned LargeSize = Cast->getSrcTy()->getPrimitiveSizeInBits();
+        Parent = std::make_pair(Cast, 0);
+        Scale = Scale.sext(LargeSize);
+        if (logScale + 1 == (int32_t)Cast->getType()->getPrimitiveSizeInBits())
+          logScale = -1;
+        assert(Scale.exactLogBase2() == logScale);
+        continue;
+      }
+    }
+
+    // Unsupported expression, bail out.
+    return 0;
+  }
+
+  // We know that we can successfully descale, so from here on we can safely
+  // modify the IR. Op holds the descaled version of the deepest term in the
+  // expression. NoSignedWrap is 'true' if multiplying Op by Scale is known
+  // not to overflow.
+
+  if (!Parent.first)
+    // The expression only had one term.
+    return Op;
+
+  // Rewrite the parent using the descaled version of its operand.
+  assert(Parent.first->hasOneUse() && "Drilled down when more than one use!");
+  assert(Op != Parent.first->getOperand(Parent.second) &&
+         "Descaling was a no-op?");
+  Parent.first->setOperand(Parent.second, Op);
+  Worklist.Add(Parent.first);
+
+  // Now work back up the expression correcting nsw flags. The logic is based
+  // on the following observation: if X * Y is known not to overflow as a signed
+  // multiplication, and Y is replaced by a value Z with smaller absolute value,
+  // then X * Z will not overflow as a signed multiplication either. As we work
+  // our way up, having NoSignedWrap 'true' means that the descaled value at the
+  // current level has strictly smaller absolute value than the original.
+  Instruction *Ancestor = Parent.first;
+  do {
+    if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Ancestor)) {
+      // If the multiplication wasn't nsw then we can't say anything about the
+      // value of the descaled multiplication, and we have to clear nsw flags
+      // from this point on up.
+      bool OpNoSignedWrap = BO->hasNoSignedWrap();
+      NoSignedWrap &= OpNoSignedWrap;
+      if (NoSignedWrap != OpNoSignedWrap) {
+        BO->setHasNoSignedWrap(NoSignedWrap);
+        Worklist.Add(Ancestor);
+      }
+    } else if (Ancestor->getOpcode() == Instruction::Trunc) {
+      // The fact that the descaled input to the trunc has smaller absolute
+      // value than the original input doesn't tell us anything useful about
+      // the absolute values of the truncations.
+      NoSignedWrap = false;
+    }
+    assert((Ancestor->getOpcode() != Instruction::SExt || NoSignedWrap) &&
+           "Failed to keep proper track of nsw flags while drilling down?");
+
+    if (Ancestor == Val)
+      // Got to the top, all done!
+      return Val;
+
+    // Move up one level in the expression.
+    assert(Ancestor->hasOneUse() && "Drilled down when more than one use!");
+    Ancestor = Ancestor->use_back();
+  } while (1);
+}
+
 Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
   SmallVector<Value*, 8> Ops(GEP.op_begin(), GEP.op_end());
@@ -855,7 +1093,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
     if (!shouldMergeGEPs(*cast<GEPOperator>(&GEP), *Src))
       return 0;
 
-    // Note that if our source is a gep chain itself that we wait for that
+    // Note that if our source is a gep chain itself then we wait for that
    // chain to be resolved before we perform this transformation.  This
    // avoids us creating a TON of code in some cases.
    if (GEPOperator *SrcGEP =
@@ -987,63 +1225,74 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
       }
 
       // Transform things like:
+      //   %V = mul i64 %N, 4
+      //   %t = getelementptr i8* bitcast (i32* %arr to i8*), i32 %V
+      // into:  %t1 = getelementptr i32* %arr, i32 %N; bitcast
+      if (TD && ResElTy->isSized() && SrcElTy->isSized()) {
+        // Check that changing the type amounts to dividing the index by a scale
+        // factor.
+        uint64_t ResSize = TD->getTypeAllocSize(ResElTy);
+        uint64_t SrcSize = TD->getTypeAllocSize(SrcElTy);
+        if (ResSize && SrcSize % ResSize == 0) {
+          Value *Idx = GEP.getOperand(1);
+          unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits();
+          uint64_t Scale = SrcSize / ResSize;
+
+          // Earlier transforms ensure that the index has type IntPtrType, which
+          // considerably simplifies the logic by eliminating implicit casts.
+          assert(Idx->getType() == TD->getIntPtrType(GEP.getContext()) &&
+                 "Index not cast to pointer width?");
+
+          bool NSW;
+          if (Value *NewIdx = Descale(Idx, APInt(BitWidth, Scale), NSW)) {
+            // Successfully decomposed Idx as NewIdx * Scale, form a new GEP.
+            // If the multiplication NewIdx * Scale may overflow then the new
+            // GEP may not be "inbounds".
+            Value *NewGEP = GEP.isInBounds() && NSW ?
+              Builder->CreateInBoundsGEP(StrippedPtr, NewIdx, GEP.getName()) :
+              Builder->CreateGEP(StrippedPtr, NewIdx, GEP.getName());
+            // The NewGEP must be pointer typed, so must the old one -> BitCast
+            return new BitCastInst(NewGEP, GEP.getType());
+          }
+        }
+      }
+
+      // Similarly, transform things like:
       // getelementptr i8* bitcast ([100 x double]* X to i8*), i32 %tmp
       //   (where tmp = 8*tmp2) into:
       // getelementptr [100 x double]* %arr, i32 0, i32 %tmp2; bitcast
-
-      if (TD && SrcElTy->isArrayTy() && ResElTy->isIntegerTy(8)) {
+      if (TD && ResElTy->isSized() && SrcElTy->isSized() &&
+          SrcElTy->isArrayTy()) {
+        // Check that changing to the array element type amounts to dividing the
+        // index by a scale factor.
+        uint64_t ResSize = TD->getTypeAllocSize(ResElTy);
         uint64_t ArrayEltSize =
-          TD->getTypeAllocSize(cast<ArrayType>(SrcElTy)->getElementType());
+            TD->getTypeAllocSize(cast<ArrayType>(SrcElTy)->getElementType());
+        if (ResSize && ArrayEltSize % ResSize == 0) {
+          Value *Idx = GEP.getOperand(1);
+          unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits();
+          uint64_t Scale = ArrayEltSize / ResSize;
 
-        // Check to see if "tmp" is a scale by a multiple of ArrayEltSize.  We
-        // allow either a mul, shift, or constant here.
-        Value *NewIdx = 0;
-        ConstantInt *Scale = 0;
-        if (ArrayEltSize == 1) {
-          NewIdx = GEP.getOperand(1);
-          Scale = ConstantInt::get(cast<IntegerType>(NewIdx->getType()), 1);
-        } else if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP.getOperand(1))) {
-          NewIdx = ConstantInt::get(CI->getType(), 1);
-          Scale = CI;
-        } else if (Instruction *Inst = dyn_cast<Instruction>(GEP.getOperand(1))){
-          if (Inst->getOpcode() == Instruction::Shl &&
-              isa<ConstantInt>(Inst->getOperand(1))) {
-            ConstantInt *ShAmt = cast<ConstantInt>(Inst->getOperand(1));
-            uint32_t ShAmtVal = ShAmt->getLimitedValue(64);
-            Scale = ConstantInt::get(cast<IntegerType>(Inst->getType()),
-                                     1ULL << ShAmtVal);
-            NewIdx = Inst->getOperand(0);
-          } else if (Inst->getOpcode() == Instruction::Mul &&
-                     isa<ConstantInt>(Inst->getOperand(1))) {
-            Scale = cast<ConstantInt>(Inst->getOperand(1));
-            NewIdx = Inst->getOperand(0);
+          // Earlier transforms ensure that the index has type IntPtrType, which
+          // considerably simplifies the logic by eliminating implicit casts.
+          assert(Idx->getType() == TD->getIntPtrType(GEP.getContext()) &&
+                 "Index not cast to pointer width?");
+
+          bool NSW;
+          if (Value *NewIdx = Descale(Idx, APInt(BitWidth, Scale), NSW)) {
+            // Successfully decomposed Idx as NewIdx * Scale, form a new GEP.
+            // If the multiplication NewIdx * Scale may overflow then the new
+            // GEP may not be "inbounds".
+            Value *Off[2];
+            Off[0] = Constant::getNullValue(Type::getInt32Ty(GEP.getContext()));
+            Off[1] = NewIdx;
+            Value *NewGEP = GEP.isInBounds() && NSW ?
+              Builder->CreateInBoundsGEP(StrippedPtr, Off, GEP.getName()) :
+              Builder->CreateGEP(StrippedPtr, Off, GEP.getName());
+            // The NewGEP must be pointer typed, so must the old one -> BitCast
+            return new BitCastInst(NewGEP, GEP.getType());
           }
         }
-
-        // If the index will be to exactly the right offset with the scale taken
-        // out, perform the transformation. Note, we don't know whether Scale is
-        // signed or not. We'll use unsigned version of division/modulo
-        // operation after making sure Scale doesn't have the sign bit set.
-        if (ArrayEltSize && Scale && Scale->getSExtValue() >= 0LL &&
-            Scale->getZExtValue() % ArrayEltSize == 0) {
-          Scale = ConstantInt::get(Scale->getType(),
-                                   Scale->getZExtValue() / ArrayEltSize);
-          if (Scale->getZExtValue() != 1) {
-            Constant *C = ConstantExpr::getIntegerCast(Scale, NewIdx->getType(),
-                                                       false /*ZExt*/);
-            NewIdx = Builder->CreateMul(NewIdx, C, "idxscale");
-          }
-
-          // Insert the new GEP instruction.
-          Value *Idx[2];
-          Idx[0] = Constant::getNullValue(Type::getInt32Ty(GEP.getContext()));
-          Idx[1] = NewIdx;
-          Value *NewGEP = GEP.isInBounds() ?
-            Builder->CreateInBoundsGEP(StrippedPtr, Idx, GEP.getName()):
-            Builder->CreateGEP(StrippedPtr, Idx, GEP.getName());
-          // The NewGEP must be pointer typed, so must the old one -> BitCast
-          return new BitCastInst(NewGEP, GEP.getType());
-        }
       }
     }
   }
diff --git a/test/Transforms/InstCombine/cast.ll b/test/Transforms/InstCombine/cast.ll
index 56e5ca3ff72..899ffddd5bc 100644
--- a/test/Transforms/InstCombine/cast.ll
+++ b/test/Transforms/InstCombine/cast.ll
@@ -694,3 +694,200 @@ define i1 @test67(i1 %a, i32 %b) {
 ; CHECK: @test67
 ; CHECK: ret i1 false
 }
+
+%s = type { i32, i32, i32 }
+
+define %s @test68(%s *%p, i64 %i) {
+; CHECK: @test68
+  %o = mul i64 %i, 12
+  %q = bitcast %s* %p to i8*
+  %pp = getelementptr inbounds i8* %q, i64 %o
+; CHECK-NEXT: getelementptr %s*
+  %r = bitcast i8* %pp to %s*
+  %l = load %s* %r
+; CHECK-NEXT: load %s*
+  ret %s %l
+; CHECK-NEXT: ret %s
+}
+
+define double @test69(double *%p, i64 %i) {
+; CHECK: @test69
+  %o = shl nsw i64 %i, 3
+  %q = bitcast double* %p to i8*
+  %pp = getelementptr inbounds i8* %q, i64 %o
+; CHECK-NEXT: getelementptr inbounds double*
+  %r = bitcast i8* %pp to double*
+  %l = load double* %r
+; CHECK-NEXT: load double*
+  ret double %l
+; CHECK-NEXT: ret double
+}
+
+define %s @test70(%s *%p, i64 %i) {
+; CHECK: @test70
+  %o = mul nsw i64 %i, 36
+; CHECK-NEXT: mul nsw i64 %i, 3
+  %q = bitcast %s* %p to i8*
+  %pp = getelementptr inbounds i8* %q, i64 %o
+; CHECK-NEXT: getelementptr inbounds %s*
+  %r = bitcast i8* %pp to %s*
+  %l = load %s* %r
+; CHECK-NEXT: load %s*
+  ret %s %l
+; CHECK-NEXT: ret %s
+}
+
+define double @test71(double *%p, i64 %i) {
+; CHECK: @test71
+  %o = shl i64 %i, 5
+; CHECK-NEXT: shl i64 %i, 2
+  %q = bitcast double* %p to i8*
+  %pp = getelementptr i8* %q, i64 %o
+; CHECK-NEXT: getelementptr double*
+  %r = bitcast i8* %pp to double*
+  %l = load double* %r
+; CHECK-NEXT: load double*
+  ret double %l
+; CHECK-NEXT: ret double
+}
+
+define double @test72(double *%p, i32 %i) {
+; CHECK: @test72
+  %so = mul nsw i32 %i, 8
+  %o = sext i32 %so to i64
+; CHECK-NEXT: sext i32 %i to i64
+  %q = bitcast double* %p to i8*
+  %pp = getelementptr inbounds i8* %q, i64 %o
+; CHECK-NEXT: getelementptr inbounds double*
+  %r = bitcast i8* %pp to double*
+  %l = load double* %r
+; CHECK-NEXT: load double*
+  ret double %l
+; CHECK-NEXT: ret double
+}
+
+define double @test73(double *%p, i128 %i) {
+; CHECK: @test73
+  %lo = mul nsw i128 %i, 8
+  %o = trunc i128 %lo to i64
+; CHECK-NEXT: trunc i128 %i to i64
+  %q = bitcast double* %p to i8*
+  %pp = getelementptr inbounds i8* %q, i64 %o
+; CHECK-NEXT: getelementptr double*
+  %r = bitcast i8* %pp to double*
+  %l = load double* %r
+; CHECK-NEXT: load double*
+  ret double %l
+; CHECK-NEXT: ret double
+}
+
+define double @test74(double *%p, i64 %i) {
+; CHECK: @test74
+  %q = bitcast double* %p to i64*
+  %pp = getelementptr inbounds i64* %q, i64 %i
+; CHECK-NEXT: getelementptr inbounds double*
+  %r = bitcast i64* %pp to double*
+  %l = load double* %r
+; CHECK-NEXT: load double*
+  ret double %l
+; CHECK-NEXT: ret double
+}
+
+define i32* @test75(i32* %p, i32 %x) {
+; CHECK: @test75
+  %y = shl i32 %x, 3
+; CHECK-NEXT: shl i32 %x, 3
+  %z = sext i32 %y to i64
+; CHECK-NEXT: sext i32 %y to i64
+  %q = bitcast i32* %p to i8*
+  %r = getelementptr i8* %q, i64 %z
+  %s = bitcast i8* %r to i32*
+  ret i32* %s
+}
+
+define %s @test76(%s *%p, i64 %i, i64 %j) {
+; CHECK: @test76
+  %o = mul i64 %i, 12
+  %o2 = mul nsw i64 %o, %j
+; CHECK-NEXT: %o2 = mul i64 %i, %j
+  %q = bitcast %s* %p to i8*
+  %pp = getelementptr inbounds i8* %q, i64 %o2
+; CHECK-NEXT: getelementptr %s* %p, i64 %o2
+  %r = bitcast i8* %pp to %s*
+  %l = load %s* %r
+; CHECK-NEXT: load %s*
+  ret %s %l
+; CHECK-NEXT: ret %s
+}
+
+define %s @test77(%s *%p, i64 %i, i64 %j) {
+; CHECK: @test77
+  %o = mul nsw i64 %i, 36
+  %o2 = mul nsw i64 %o, %j
+; CHECK-NEXT: %o = mul nsw i64 %i, 3
+; CHECK-NEXT: %o2 = mul nsw i64 %o, %j
+  %q = bitcast %s* %p to i8*
+  %pp = getelementptr inbounds i8* %q, i64 %o2
+; CHECK-NEXT: getelementptr inbounds %s* %p, i64 %o2
+  %r = bitcast i8* %pp to %s*
+  %l = load %s* %r
+; CHECK-NEXT: load %s*
+  ret %s %l
+; CHECK-NEXT: ret %s
+}
+
+define %s @test78(%s *%p, i64 %i, i64 %j, i32 %k, i32 %l, i128 %m, i128 %n) {
+; CHECK: @test78
+  %a = mul nsw i32 %k, 36
+; CHECK-NEXT: mul nsw i32 %k, 3
+  %b = mul nsw i32 %a, %l
+; CHECK-NEXT: mul nsw i32 %a, %l
+  %c = sext i32 %b to i128
+; CHECK-NEXT: sext i32 %b to i128
+  %d = mul nsw i128 %c, %m
+; CHECK-NEXT: mul nsw i128 %c, %m
+  %e = mul i128 %d, %n
+; CHECK-NEXT: mul i128 %d, %n
+  %f = trunc i128 %e to i64
+; CHECK-NEXT: trunc i128 %e to i64
+  %g = mul nsw i64 %f, %i
+; CHECK-NEXT: mul i64 %f, %i
+  %h = mul nsw i64 %g, %j
+; CHECK-NEXT: mul i64 %g, %j
+  %q = bitcast %s* %p to i8*
+  %pp = getelementptr inbounds i8* %q, i64 %h
+; CHECK-NEXT: getelementptr %s* %p, i64 %h
+  %r = bitcast i8* %pp to %s*
+  %load = load %s* %r
+; CHECK-NEXT: load %s*
+  ret %s %load
+; CHECK-NEXT: ret %s
+}
+
+define %s @test79(%s *%p, i64 %i, i32 %j) {
+; CHECK: @test79
+  %a = mul nsw i64 %i, 36
+; CHECK: mul nsw i64 %i, 36
+  %b = trunc i64 %a to i32
+  %c = mul i32 %b, %j
+  %q = bitcast %s* %p to i8*
+; CHECK: bitcast
+  %pp = getelementptr inbounds i8* %q, i32 %c
+  %r = bitcast i8* %pp to %s*
+  %l = load %s* %r
+  ret %s %l
+}
+
+define double @test80([100 x double]* %p, i32 %i) {
+; CHECK: @test80
+  %tmp = mul nsw i32 %i, 8
+; CHECK-NEXT: sext i32 %i to i64
+  %q = bitcast [100 x double]* %p to i8*
+  %pp = getelementptr i8* %q, i32 %tmp
+; CHECK-NEXT: getelementptr [100 x double]*
+  %r = bitcast i8* %pp to double*
+  %l = load double* %r
+; CHECK-NEXT: load double*
+  ret double %l
+; CHECK-NEXT: ret double
+}