Transform code like this

 %V = mul i64 %N, 4
 %t = getelementptr i8* bitcast (i32* %arr to i8*), i64 %V
into
 %t1 = getelementptr i32* %arr, i64 %N
 %t = bitcast i32* %t1 to i8*
incorporating the multiplication into the getelementptr.
This happens all the time in dragonegg, for example for
  int foo(int *A, int N) {
    return A[N];
  }
because gcc turns this into byte pointer arithmetic before it hits the plugin:
  D.1590_2 = (long unsigned int) N_1(D);
  D.1591_3 = D.1590_2 * 4;
  D.1592_5 = A_4(D) + D.1591_3;
  D.1589_6 = *D.1592_5;
  return D.1589_6;
The D.1592_5 line is a POINTER_PLUS_EXPR, which is turned into a getelementptr
on a bitcast of A_4 to i8*, so this becomes exactly the kind of IR that the
transform fires on.

An analogous transform (with no testcases!) already existed for bitcasts of
arrays, so I rewrote it to share code with this one.

llvm-svn: 166474
This commit is contained in:
Duncan Sands 2012-10-23 08:28:26 +00:00
parent 1190b1a97e
commit 6ce2ce7ed1
3 changed files with 501 additions and 51 deletions

View File

@ -367,6 +367,10 @@ private:
Value *EvaluateInDifferentType(Value *V, Type *Ty, bool isSigned);
/// Descale - Return a value X such that Val = X * Scale, or null if none. If
/// the multiplication is known not to overflow then NoSignedWrap is set.
/// Scale must have the same bit width as Val's integer type. On success the
/// expression feeding Val may have been rewritten in place, in which case Val
/// itself is returned. NoSignedWrap is only meaningful when the result is
/// non-null.
Value *Descale(Value *Val, APInt Scale, bool &NoSignedWrap);
};

View File

@ -805,6 +805,244 @@ static bool shouldMergeGEPs(GEPOperator &GEP, GEPOperator &Src) {
return true;
}
/// Descale - Return a value X such that Val = X * Scale, or null if none. If
/// the multiplication is known not to overflow then NoSignedWrap is set.
///
/// \param Val    Integer-typed value to decompose.
/// \param Scale  Factor to divide out; must have the same bit width as Val.
/// \param NoSignedWrap  Output. Set on every successful return: 'true' when
///               X * Scale is known not to overflow as a signed multiply.
///               Its value is unspecified when null is returned.
/// \returns The descaled value, or null if Val could not be decomposed.
///
/// No IR is modified unless the whole analysis succeeds. On success the
/// deepest term of the expression is rewritten in place, nsw flags on the
/// instructions between that term and Val are corrected, and every touched
/// instruction is added to the worklist.
Value *InstCombiner::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) {
  assert(isa<IntegerType>(Val->getType()) && "Can only descale integers!");
  assert(cast<IntegerType>(Val->getType())->getBitWidth() ==
         Scale.getBitWidth() && "Scale not compatible with value!");

  // If Val is zero or Scale is one then Val = Val * Scale.
  if (match(Val, m_Zero()) || Scale == 1) {
    NoSignedWrap = true;
    return Val;
  }

  // If Scale is zero then it does not divide Val.
  if (Scale.isMinValue())
    return 0;

  // Look through chains of multiplications, searching for a constant that is
  // divisible by Scale.  For example, descaling X*(Y*(Z*4)) by a factor of 4
  // will find the constant factor 4 and produce X*(Y*Z).  Descaling X*(Y*8) by
  // a factor of 4 will produce X*(Y*2).  The principle of operation is to bore
  // down from Val:
  //
  //     Val = M1 * X          ||   Analysis starts here and works down
  //     M1 = M2 * Y           ||   Doesn't descend into terms with more
  //     M2 = Z * 4            \/   than one use
  //
  // Then to modify a term at the bottom:
  //
  //     Val = M1 * X
  //     M1 = Z * Y            ||   Replaced M2 with Z
  //
  // Then to work back up correcting nsw flags.

  // Op - the term we are currently analyzing.  Starts at Val then drills down.
  // Replaced with its descaled value before exiting from the drill down loop.
  Value *Op = Val;

  // Parent - initially null, but after drilling down notes where Op came from.
  // In the example above, Parent is (Val, 0) when Op is M1, because M1 is the
  // 0'th operand of Val.
  std::pair<Instruction*, unsigned> Parent;

  // RequireNoSignedWrap - Set if the transform requires a descaling at deeper
  // levels that doesn't overflow.
  bool RequireNoSignedWrap = false;

  // logScale - log base 2 of the scale.  Negative if not a power of 2.
  int32_t logScale = Scale.exactLogBase2();

  for (;; Op = Parent.first->getOperand(Parent.second)) { // Drill down

    if (ConstantInt *CI = dyn_cast<ConstantInt>(Op)) {
      // If Op is a constant divisible by Scale then descale to the quotient.
      APInt Quotient(Scale), Remainder(Scale); // Init ensures right bitwidth.
      APInt::sdivrem(CI->getValue(), Scale, Quotient, Remainder);
      if (!Remainder.isMinValue())
        // Not divisible by Scale.
        return 0;
      // Replace with the quotient in the parent.
      Op = ConstantInt::get(CI->getType(), Quotient);
      NoSignedWrap = true;
      break;
    }

    if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op)) {

      if (BO->getOpcode() == Instruction::Mul) {
        // Multiplication.
        NoSignedWrap = BO->hasNoSignedWrap();
        if (RequireNoSignedWrap && !NoSignedWrap)
          return 0;

        // There are three cases for multiplication: multiplication by exactly
        // the scale, multiplication by a constant different to the scale, and
        // multiplication by something else.
        Value *LHS = BO->getOperand(0);
        Value *RHS = BO->getOperand(1);

        if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
          // Multiplication by a constant.
          if (CI->getValue() == Scale) {
            // Multiplication by exactly the scale, replace the multiplication
            // by its left-hand side in the parent.
            Op = LHS;
            break;
          }

          // Otherwise drill down into the constant.
          if (!Op->hasOneUse())
            return 0;

          Parent = std::make_pair(BO, 1);
          continue;
        }

        // Multiplication by something else.  Drill down into the left-hand
        // side since that's where the reassociate pass puts the good stuff.
        if (!Op->hasOneUse())
          return 0;

        Parent = std::make_pair(BO, 0);
        continue;
      }

      if (logScale > 0 && BO->getOpcode() == Instruction::Shl &&
          isa<ConstantInt>(BO->getOperand(1))) {
        // Multiplication by a power of 2.
        NoSignedWrap = BO->hasNoSignedWrap();
        if (RequireNoSignedWrap && !NoSignedWrap)
          return 0;

        Value *LHS = BO->getOperand(0);
        int32_t Amt = cast<ConstantInt>(BO->getOperand(1))->
          getLimitedValue(Scale.getBitWidth());
        // Op = LHS << Amt.

        if (Amt == logScale) {
          // Multiplication by exactly the scale, replace the multiplication
          // by its left-hand side in the parent.
          Op = LHS;
          break;
        }
        if (Amt < logScale || !Op->hasOneUse())
          return 0;

        // Multiplication by more than the scale.  Reduce the multiplying amount
        // by the scale in the parent.
        Parent = std::make_pair(BO, 1);
        Op = ConstantInt::get(BO->getType(), Amt - logScale);
        break;
      }
    }

    // Anything below this point is drilled through, which is only safe if the
    // term has no other users that would observe the rewritten operand.
    if (!Op->hasOneUse())
      return 0;

    if (CastInst *Cast = dyn_cast<CastInst>(Op)) {
      if (Cast->getOpcode() == Instruction::SExt) {
        // Op is sign-extended from a smaller type, descale in the smaller type.
        unsigned SmallSize = Cast->getSrcTy()->getPrimitiveSizeInBits();
        APInt SmallScale = Scale.trunc(SmallSize);
        // Suppose Op = sext X, and we descale X as Y * SmallScale.  We want to
        // descale Op as (sext Y) * Scale.  In order to have
        //   sext (Y * SmallScale) = (sext Y) * Scale
        // some conditions need to hold however: SmallScale must sign-extend to
        // Scale and the multiplication Y * SmallScale should not overflow.
        if (SmallScale.sext(Scale.getBitWidth()) != Scale)
          // SmallScale does not sign-extend to Scale.
          return 0;
        assert(SmallScale.exactLogBase2() == logScale);
        // Require that Y * SmallScale must not overflow.
        RequireNoSignedWrap = true;

        // Drill down through the cast.
        Parent = std::make_pair(Cast, 0);
        Scale = SmallScale;
        continue;
      }

      // Only a genuine truncation may take this path.  Testing the operand
      // pointer (which is never null) would let zext, fptosi, ptrtoint, etc.
      // through, and Scale.sext() below is invalid unless the source type is
      // strictly wider than the current scale.
      if (Cast->getOpcode() == Instruction::Trunc) {
        // Op is truncated from a larger type, descale in the larger type.
        // Suppose Op = trunc X, and we descale X as Y * sext Scale.  Then
        //  trunc (Y * sext Scale) = (trunc Y) * Scale
        // always holds.  However (trunc Y) * Scale may overflow even if
        // trunc (Y * sext Scale) does not, so nsw flags need to be cleared
        // from this point up in the expression (see later).
        if (RequireNoSignedWrap)
          return 0;

        // Drill down through the cast.
        unsigned LargeSize = Cast->getSrcTy()->getPrimitiveSizeInBits();
        Parent = std::make_pair(Cast, 0);
        Scale = Scale.sext(LargeSize);
        // If Scale was the sign bit of the narrow type, sign extension turns
        // it into a value that is no longer a power of two.
        if (logScale + 1 == (int32_t)Cast->getType()->getPrimitiveSizeInBits())
          logScale = -1;
        assert(Scale.exactLogBase2() == logScale);
        continue;
      }
    }

    // Unsupported expression, bail out.
    return 0;
  }

  // We know that we can successfully descale, so from here on we can safely
  // modify the IR.  Op holds the descaled version of the deepest term in the
  // expression.  NoSignedWrap is 'true' if multiplying Op by Scale is known
  // not to overflow.

  if (!Parent.first)
    // The expression only had one term.
    return Op;

  // Rewrite the parent using the descaled version of its operand.
  assert(Parent.first->hasOneUse() && "Drilled down when more than one use!");
  assert(Op != Parent.first->getOperand(Parent.second) &&
         "Descaling was a no-op?");
  Parent.first->setOperand(Parent.second, Op);
  Worklist.Add(Parent.first);

  // Now work back up the expression correcting nsw flags.  The logic is based
  // on the following observation: if X * Y is known not to overflow as a signed
  // multiplication, and Y is replaced by a value Z with smaller absolute value,
  // then X * Z will not overflow as a signed multiplication either.  As we work
  // our way up, having NoSignedWrap 'true' means that the descaled value at the
  // current level has strictly smaller absolute value than the original.
  Instruction *Ancestor = Parent.first;
  do {
    if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Ancestor)) {
      // If the multiplication wasn't nsw then we can't say anything about the
      // value of the descaled multiplication, and we have to clear nsw flags
      // from this point on up.
      bool OpNoSignedWrap = BO->hasNoSignedWrap();
      NoSignedWrap &= OpNoSignedWrap;
      if (NoSignedWrap != OpNoSignedWrap) {
        BO->setHasNoSignedWrap(NoSignedWrap);
        Worklist.Add(Ancestor);
      }
    } else if (Ancestor->getOpcode() == Instruction::Trunc) {
      // The fact that the descaled input to the trunc has smaller absolute
      // value than the original input doesn't tell us anything useful about
      // the absolute values of the truncations.
      NoSignedWrap = false;
    }
    assert((Ancestor->getOpcode() != Instruction::SExt || NoSignedWrap) &&
           "Failed to keep proper track of nsw flags while drilling down?");

    if (Ancestor == Val)
      // Got to the top, all done!
      return Val;

    // Move up one level in the expression.
    assert(Ancestor->hasOneUse() && "Drilled down when more than one use!");
    Ancestor = Ancestor->use_back();
  } while (1);
}
Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
SmallVector<Value*, 8> Ops(GEP.op_begin(), GEP.op_end());
@ -855,7 +1093,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
if (!shouldMergeGEPs(*cast<GEPOperator>(&GEP), *Src))
return 0;
// Note that if our source is a gep chain itself that we wait for that
// Note that if our source is a gep chain itself then we wait for that
// chain to be resolved before we perform this transformation. This
// avoids us creating a TON of code in some cases.
if (GEPOperator *SrcGEP =
@ -987,63 +1225,74 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
}
// Transform things like:
// %V = mul i64 %N, 4
// %t = getelementptr i8* bitcast (i32* %arr to i8*), i32 %V
// into: %t1 = getelementptr i32* %arr, i32 %N; bitcast
if (TD && ResElTy->isSized() && SrcElTy->isSized()) {
// Check that changing the type amounts to dividing the index by a scale
// factor.
uint64_t ResSize = TD->getTypeAllocSize(ResElTy);
uint64_t SrcSize = TD->getTypeAllocSize(SrcElTy);
if (ResSize && SrcSize % ResSize == 0) {
Value *Idx = GEP.getOperand(1);
unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits();
uint64_t Scale = SrcSize / ResSize;
// Earlier transforms ensure that the index has type IntPtrType, which
// considerably simplifies the logic by eliminating implicit casts.
assert(Idx->getType() == TD->getIntPtrType(GEP.getContext()) &&
"Index not cast to pointer width?");
bool NSW;
if (Value *NewIdx = Descale(Idx, APInt(BitWidth, Scale), NSW)) {
// Successfully decomposed Idx as NewIdx * Scale, form a new GEP.
// If the multiplication NewIdx * Scale may overflow then the new
// GEP may not be "inbounds".
Value *NewGEP = GEP.isInBounds() && NSW ?
Builder->CreateInBoundsGEP(StrippedPtr, NewIdx, GEP.getName()) :
Builder->CreateGEP(StrippedPtr, NewIdx, GEP.getName());
// The NewGEP must be pointer typed, so must the old one -> BitCast
return new BitCastInst(NewGEP, GEP.getType());
}
}
}
// Similarly, transform things like:
// getelementptr i8* bitcast ([100 x double]* X to i8*), i32 %tmp
// (where tmp = 8*tmp2) into:
// getelementptr [100 x double]* %arr, i32 0, i32 %tmp2; bitcast
if (TD && SrcElTy->isArrayTy() && ResElTy->isIntegerTy(8)) {
if (TD && ResElTy->isSized() && SrcElTy->isSized() &&
SrcElTy->isArrayTy()) {
// Check that changing to the array element type amounts to dividing the
// index by a scale factor.
uint64_t ResSize = TD->getTypeAllocSize(ResElTy);
uint64_t ArrayEltSize =
TD->getTypeAllocSize(cast<ArrayType>(SrcElTy)->getElementType());
TD->getTypeAllocSize(cast<ArrayType>(SrcElTy)->getElementType());
if (ResSize && ArrayEltSize % ResSize == 0) {
Value *Idx = GEP.getOperand(1);
unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits();
uint64_t Scale = ArrayEltSize / ResSize;
// Check to see if "tmp" is a scale by a multiple of ArrayEltSize. We
// allow either a mul, shift, or constant here.
Value *NewIdx = 0;
ConstantInt *Scale = 0;
if (ArrayEltSize == 1) {
NewIdx = GEP.getOperand(1);
Scale = ConstantInt::get(cast<IntegerType>(NewIdx->getType()), 1);
} else if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP.getOperand(1))) {
NewIdx = ConstantInt::get(CI->getType(), 1);
Scale = CI;
} else if (Instruction *Inst =dyn_cast<Instruction>(GEP.getOperand(1))){
if (Inst->getOpcode() == Instruction::Shl &&
isa<ConstantInt>(Inst->getOperand(1))) {
ConstantInt *ShAmt = cast<ConstantInt>(Inst->getOperand(1));
uint32_t ShAmtVal = ShAmt->getLimitedValue(64);
Scale = ConstantInt::get(cast<IntegerType>(Inst->getType()),
1ULL << ShAmtVal);
NewIdx = Inst->getOperand(0);
} else if (Inst->getOpcode() == Instruction::Mul &&
isa<ConstantInt>(Inst->getOperand(1))) {
Scale = cast<ConstantInt>(Inst->getOperand(1));
NewIdx = Inst->getOperand(0);
// Earlier transforms ensure that the index has type IntPtrType, which
// considerably simplifies the logic by eliminating implicit casts.
assert(Idx->getType() == TD->getIntPtrType(GEP.getContext()) &&
"Index not cast to pointer width?");
bool NSW;
if (Value *NewIdx = Descale(Idx, APInt(BitWidth, Scale), NSW)) {
// Successfully decomposed Idx as NewIdx * Scale, form a new GEP.
// If the multiplication NewIdx * Scale may overflow then the new
// GEP may not be "inbounds".
Value *Off[2];
Off[0] = Constant::getNullValue(Type::getInt32Ty(GEP.getContext()));
Off[1] = NewIdx;
Value *NewGEP = GEP.isInBounds() && NSW ?
Builder->CreateInBoundsGEP(StrippedPtr, Off, GEP.getName()) :
Builder->CreateGEP(StrippedPtr, Off, GEP.getName());
// The NewGEP must be pointer typed, so must the old one -> BitCast
return new BitCastInst(NewGEP, GEP.getType());
}
}
// If the index will be to exactly the right offset with the scale taken
// out, perform the transformation. Note, we don't know whether Scale is
// signed or not. We'll use unsigned version of division/modulo
// operation after making sure Scale doesn't have the sign bit set.
if (ArrayEltSize && Scale && Scale->getSExtValue() >= 0LL &&
Scale->getZExtValue() % ArrayEltSize == 0) {
Scale = ConstantInt::get(Scale->getType(),
Scale->getZExtValue() / ArrayEltSize);
if (Scale->getZExtValue() != 1) {
Constant *C = ConstantExpr::getIntegerCast(Scale, NewIdx->getType(),
false /*ZExt*/);
NewIdx = Builder->CreateMul(NewIdx, C, "idxscale");
}
// Insert the new GEP instruction.
Value *Idx[2];
Idx[0] = Constant::getNullValue(Type::getInt32Ty(GEP.getContext()));
Idx[1] = NewIdx;
Value *NewGEP = GEP.isInBounds() ?
Builder->CreateInBoundsGEP(StrippedPtr, Idx, GEP.getName()):
Builder->CreateGEP(StrippedPtr, Idx, GEP.getName());
// The NewGEP must be pointer typed, so must the old one -> BitCast
return new BitCastInst(NewGEP, GEP.getType());
}
}
}
}

View File

@ -694,3 +694,200 @@ define i1 @test67(i1 %a, i32 %b) {
; CHECK: @test67
; CHECK: ret i1 false
}
%s = type { i32, i32, i32 }
; test68: the i8* GEP index is %i * 12, computed by a mul without nsw. The
; expected output indexes %s* directly and the new GEP is not marked inbounds.
define %s @test68(%s *%p, i64 %i) {
; CHECK: @test68
%o = mul i64 %i, 12
%q = bitcast %s* %p to i8*
%pp = getelementptr inbounds i8* %q, i64 %o
; CHECK-NEXT: getelementptr %s*
%r = bitcast i8* %pp to %s*
%l = load %s* %r
; CHECK-NEXT: load %s*
ret %s %l
; CHECK-NEXT: ret %s
}
; test69: the index is %i shifted left by 3 with nsw, used by an inbounds i8*
; GEP. The expected output is an inbounds GEP on double*.
define double @test69(double *%p, i64 %i) {
; CHECK: @test69
%o = shl nsw i64 %i, 3
%q = bitcast double* %p to i8*
%pp = getelementptr inbounds i8* %q, i64 %o
; CHECK-NEXT: getelementptr inbounds double*
%r = bitcast i8* %pp to double*
%l = load double* %r
; CHECK-NEXT: load double*
ret double %l
; CHECK-NEXT: ret double
}
; test70: the index is %i * 36 with nsw. The expected output keeps a
; 'mul nsw' by 3 and uses it in an inbounds GEP on %s*.
define %s @test70(%s *%p, i64 %i) {
; CHECK: @test70
%o = mul nsw i64 %i, 36
; CHECK-NEXT: mul nsw i64 %i, 3
%q = bitcast %s* %p to i8*
%pp = getelementptr inbounds i8* %q, i64 %o
; CHECK-NEXT: getelementptr inbounds %s*
%r = bitcast i8* %pp to %s*
%l = load %s* %r
; CHECK-NEXT: load %s*
ret %s %l
; CHECK-NEXT: ret %s
}
; test71: the index is %i shifted left by 5 (no nsw) and the i8* GEP is not
; inbounds. The expected output shifts by 2 instead and indexes double*.
define double @test71(double *%p, i64 %i) {
; CHECK: @test71
%o = shl i64 %i, 5
; CHECK-NEXT: shl i64 %i, 2
%q = bitcast double* %p to i8*
%pp = getelementptr i8* %q, i64 %o
; CHECK-NEXT: getelementptr double*
%r = bitcast i8* %pp to double*
%l = load double* %r
; CHECK-NEXT: load double*
ret double %l
; CHECK-NEXT: ret double
}
; test72: the index is a sext to i64 of an i32 'mul nsw' by 8. The expected
; output sign-extends %i directly and uses it in an inbounds GEP on double*.
define double @test72(double *%p, i32 %i) {
; CHECK: @test72
%so = mul nsw i32 %i, 8
%o = sext i32 %so to i64
; CHECK-NEXT: sext i32 %i to i64
%q = bitcast double* %p to i8*
%pp = getelementptr inbounds i8* %q, i64 %o
; CHECK-NEXT: getelementptr inbounds double*
%r = bitcast i8* %pp to double*
%l = load double* %r
; CHECK-NEXT: load double*
ret double %l
; CHECK-NEXT: ret double
}
; test73: the index is a trunc to i64 of an i128 'mul nsw' by 8. The expected
; output truncates %i directly; the GEP on double* is not marked inbounds even
; though the original i8* GEP was.
define double @test73(double *%p, i128 %i) {
; CHECK: @test73
%lo = mul nsw i128 %i, 8
%o = trunc i128 %lo to i64
; CHECK-NEXT: trunc i128 %i to i64
%q = bitcast double* %p to i8*
%pp = getelementptr inbounds i8* %q, i64 %o
; CHECK-NEXT: getelementptr double*
%r = bitcast i8* %pp to double*
%l = load double* %r
; CHECK-NEXT: load double*
ret double %l
; CHECK-NEXT: ret double
}
; test74: GEP through a bitcast from double* to i64* with an unscaled index
; %i. The expected output is an inbounds GEP directly on double*.
; NOTE(review): presumably exercises a scale factor of one (double and i64
; having equal alloc size) - confirm against the test's datalayout.
define double @test74(double *%p, i64 %i) {
; CHECK: @test74
%q = bitcast double* %p to i64*
%pp = getelementptr inbounds i64* %q, i64 %i
; CHECK-NEXT: getelementptr inbounds double*
%r = bitcast i64* %pp to double*
%l = load double* %r
; CHECK-NEXT: load double*
ret double %l
; CHECK-NEXT: ret double
}
; test75: the index is a sext of an i32 shl by 3 that carries no nsw flag.
; The expected output keeps both the shl by 3 and the sext of %y unchanged.
define i32* @test75(i32* %p, i32 %x) {
; CHECK: @test75
%y = shl i32 %x, 3
; CHECK-NEXT: shl i32 %x, 3
%z = sext i32 %y to i64
; CHECK-NEXT: sext i32 %y to i64
%q = bitcast i32* %p to i8*
%r = getelementptr i8* %q, i64 %z
%s = bitcast i8* %r to i32*
ret i32* %s
}
; test76: the index is (%i * 12) * %j where only the outer mul has nsw. The
; expected output multiplies %i by %j with the nsw flag dropped, and the GEP on
; %s* is not marked inbounds.
define %s @test76(%s *%p, i64 %i, i64 %j) {
; CHECK: @test76
%o = mul i64 %i, 12
%o2 = mul nsw i64 %o, %j
; CHECK-NEXT: %o2 = mul i64 %i, %j
%q = bitcast %s* %p to i8*
%pp = getelementptr inbounds i8* %q, i64 %o2
; CHECK-NEXT: getelementptr %s* %p, i64 %o2
%r = bitcast i8* %pp to %s*
%l = load %s* %r
; CHECK-NEXT: load %s*
ret %s %l
; CHECK-NEXT: ret %s
}
; test77: the index is (%i * 36) * %j with nsw on both muls. The expected
; output rewrites the inner mul to multiply by 3, keeps nsw on both muls, and
; produces an inbounds GEP on %s*.
define %s @test77(%s *%p, i64 %i, i64 %j) {
; CHECK: @test77
%o = mul nsw i64 %i, 36
%o2 = mul nsw i64 %o, %j
; CHECK-NEXT: %o = mul nsw i64 %i, 3
; CHECK-NEXT: %o2 = mul nsw i64 %o, %j
%q = bitcast %s* %p to i8*
%pp = getelementptr inbounds i8* %q, i64 %o2
; CHECK-NEXT: getelementptr inbounds %s* %p, i64 %o2
%r = bitcast i8* %pp to %s*
%l = load %s* %r
; CHECK-NEXT: load %s*
ret %s %l
; CHECK-NEXT: ret %s
}
; test78: a long chain - i32 nsw muls (the first by 36), sext to i128, i128
; muls, trunc to i64, then i64 nsw muls. The expected output changes the first
; mul to multiply by 3, leaves the flags on everything up to the trunc as they
; were, drops nsw from the two muls after the trunc (%g and %h), and produces a
; GEP on %s* that is not marked inbounds.
define %s @test78(%s *%p, i64 %i, i64 %j, i32 %k, i32 %l, i128 %m, i128 %n) {
; CHECK: @test78
%a = mul nsw i32 %k, 36
; CHECK-NEXT: mul nsw i32 %k, 3
%b = mul nsw i32 %a, %l
; CHECK-NEXT: mul nsw i32 %a, %l
%c = sext i32 %b to i128
; CHECK-NEXT: sext i32 %b to i128
%d = mul nsw i128 %c, %m
; CHECK-NEXT: mul nsw i128 %c, %m
%e = mul i128 %d, %n
; CHECK-NEXT: mul i128 %d, %n
%f = trunc i128 %e to i64
; CHECK-NEXT: trunc i128 %e to i64
%g = mul nsw i64 %f, %i
; CHECK-NEXT: mul i64 %f, %i
%h = mul nsw i64 %g, %j
; CHECK-NEXT: mul i64 %g, %j
%q = bitcast %s* %p to i8*
%pp = getelementptr inbounds i8* %q, i64 %h
; CHECK-NEXT: getelementptr %s* %p, i64 %h
%r = bitcast i8* %pp to %s*
%load = load %s* %r
; CHECK-NEXT: load %s*
ret %s %load
; CHECK-NEXT: ret %s
}
; test79: the i32 GEP index is (trunc (%i * 36)) * %j, where the outer i32 mul
; has no nsw flag. The expected output still contains the 'mul nsw' by 36 and a
; bitcast, i.e. the multiplication is not folded into the GEP.
; NOTE(review): presumably the i32 index is first sign-extended to pointer
; width, and descaling through that sext needs an nsw mul - confirm.
define %s @test79(%s *%p, i64 %i, i32 %j) {
; CHECK: @test79
%a = mul nsw i64 %i, 36
; CHECK: mul nsw i64 %i, 36
%b = trunc i64 %a to i32
%c = mul i32 %b, %j
%q = bitcast %s* %p to i8*
; CHECK: bitcast
%pp = getelementptr inbounds i8* %q, i32 %c
%r = bitcast i8* %pp to %s*
%l = load %s* %r
ret %s %l
}
; test80: array variant - the i8* GEP is through a bitcast of
; [100 x double]* and the i32 index is %i * 8 with nsw. The expected output
; sign-extends %i to i64 and indexes [100 x double]* directly.
define double @test80([100 x double]* %p, i32 %i) {
; CHECK: @test80
%tmp = mul nsw i32 %i, 8
; CHECK-NEXT: sext i32 %i to i64
%q = bitcast [100 x double]* %p to i8*
%pp = getelementptr i8* %q, i32 %tmp
; CHECK-NEXT: getelementptr [100 x double]*
%r = bitcast i8* %pp to double*
%l = load double* %r
; CHECK-NEXT: load double*
ret double %l
; CHECK-NEXT: ret double
}