diff --git a/lib/Analysis/ConstantFolding.cpp b/lib/Analysis/ConstantFolding.cpp index 254f2d9f50b..08b60671a63 100644 --- a/lib/Analysis/ConstantFolding.cpp +++ b/lib/Analysis/ConstantFolding.cpp @@ -1327,28 +1327,19 @@ static double getValueAsDouble(ConstantFP *Op) { return APF.convertToDouble(); } -/// ConstantFoldCall - Attempt to constant fold a call to the specified function -/// with the specified arguments, returning null if unsuccessful. -Constant * -llvm::ConstantFoldCall(Function *F, ArrayRef Operands, - const TargetLibraryInfo *TLI) { - if (!F->hasName()) - return 0; - StringRef Name = F->getName(); - - Type *Ty = F->getReturnType(); +static Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID, + Type *Ty, ArrayRef Operands, + const TargetLibraryInfo *TLI) { if (Operands.size() == 1) { if (ConstantFP *Op = dyn_cast(Operands[0])) { - if (F->getIntrinsicID() == Intrinsic::convert_to_fp16) { + if (IntrinsicID == Intrinsic::convert_to_fp16) { APFloat Val(Op->getValueAPF()); bool lost = false; Val.convert(APFloat::IEEEhalf, APFloat::rmNearestTiesToEven, &lost); - return ConstantInt::get(F->getContext(), Val.bitcastToAPInt()); + return ConstantInt::get(Ty->getContext(), Val.bitcastToAPInt()); } - if (!TLI) - return 0; if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy()) return 0; @@ -1365,7 +1356,7 @@ llvm::ConstantFoldCall(Function *F, ArrayRef Operands, /// f(arg). Long double not supported yet. double V = getValueAsDouble(Op); - switch (F->getIntrinsicID()) { + switch (IntrinsicID) { default: break; case Intrinsic::fabs: return ConstantFoldFP(fabs, V, Ty); @@ -1393,6 +1384,9 @@ llvm::ConstantFoldCall(Function *F, ArrayRef Operands, return ConstantFoldFP(floor, V, Ty); } + if (!TLI) + return 0; + switch (Name[0]) { case 'a': if (Name == "acos" && TLI->has(LibFunc::acos)) @@ -1433,7 +1427,7 @@ llvm::ConstantFoldCall(Function *F, ArrayRef Operands, return ConstantFoldFP(log, V, Ty); else if (Name == "log10" && V > 0 && TLI->has(LibFunc::log10)) return ConstantFoldFP(log10, V, Ty); - else if (F->getIntrinsicID() == Intrinsic::sqrt && + else if (IntrinsicID == Intrinsic::sqrt && (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy())) { if (V >= -0.0) return ConstantFoldFP(sqrt, V, Ty); @@ -1466,9 +1460,9 @@ llvm::ConstantFoldCall(Function *F, ArrayRef Operands, } if (ConstantInt *Op = dyn_cast(Operands[0])) { - switch (F->getIntrinsicID()) { + switch (IntrinsicID) { case Intrinsic::bswap: - return ConstantInt::get(F->getContext(), Op->getValue().byteSwap()); + return ConstantInt::get(Ty->getContext(), Op->getValue().byteSwap()); case Intrinsic::ctpop: return ConstantInt::get(Ty, Op->getValue().countPopulation()); case Intrinsic::convert_from_fp16: { @@ -1483,7 +1477,7 @@ llvm::ConstantFoldCall(Function *F, ArrayRef Operands, assert(status == APFloat::opOK && !lost && "Precision lost during fp16 constfolding"); - return ConstantFP::get(F->getContext(), Val); + return ConstantFP::get(Ty->getContext(), Val); } default: return 0; @@ -1494,7 +1488,7 @@ llvm::ConstantFoldCall(Function *F, ArrayRef Operands, if (isa(Operands[0]) || isa(Operands[0])) { Constant *Op = cast(Operands[0]); - switch (F->getIntrinsicID()) { + switch (IntrinsicID) { default: break; case Intrinsic::x86_sse_cvtss2si: case Intrinsic::x86_sse_cvtss2si64: @@ -1516,7 +1510,7 @@ llvm::ConstantFoldCall(Function *F, ArrayRef Operands, } if (isa(Operands[0])) { - if (F->getIntrinsicID() == Intrinsic::bswap) + if (IntrinsicID == Intrinsic::bswap) return Operands[0]; return 0; } @@ -1535,7 +1529,7 @@ llvm::ConstantFoldCall(Function *F, ArrayRef Operands, return 0; double Op2V = getValueAsDouble(Op2); - if (F->getIntrinsicID() == Intrinsic::pow) { + if (IntrinsicID == Intrinsic::pow) { return ConstantFoldBinaryFP(pow, Op1V, Op2V, Ty); } if (!TLI) @@ -1547,16 +1541,16 @@ llvm::ConstantFoldCall(Function *F, ArrayRef Operands, if (Name == "atan2" && TLI->has(LibFunc::atan2)) return ConstantFoldBinaryFP(atan2, Op1V, Op2V, Ty); } else if (ConstantInt *Op2C = dyn_cast(Operands[1])) { - if (F->getIntrinsicID() == Intrinsic::powi && Ty->isHalfTy()) - return ConstantFP::get(F->getContext(), + if (IntrinsicID == Intrinsic::powi && Ty->isHalfTy()) + return ConstantFP::get(Ty->getContext(), APFloat((float)std::pow((float)Op1V, (int)Op2C->getZExtValue()))); - if (F->getIntrinsicID() == Intrinsic::powi && Ty->isFloatTy()) - return ConstantFP::get(F->getContext(), + if (IntrinsicID == Intrinsic::powi && Ty->isFloatTy()) + return ConstantFP::get(Ty->getContext(), APFloat((float)std::pow((float)Op1V, (int)Op2C->getZExtValue()))); - if (F->getIntrinsicID() == Intrinsic::powi && Ty->isDoubleTy()) - return ConstantFP::get(F->getContext(), + if (IntrinsicID == Intrinsic::powi && Ty->isDoubleTy()) + return ConstantFP::get(Ty->getContext(), APFloat((double)std::pow((double)Op1V, (int)Op2C->getZExtValue()))); } @@ -1565,7 +1559,7 @@ llvm::ConstantFoldCall(Function *F, ArrayRef Operands, if (ConstantInt *Op1 = dyn_cast(Operands[0])) { if (ConstantInt *Op2 = dyn_cast(Operands[1])) { - switch (F->getIntrinsicID()) { + switch (IntrinsicID) { default: break; case Intrinsic::sadd_with_overflow: case Intrinsic::uadd_with_overflow: @@ -1575,7 +1569,7 @@ llvm::ConstantFoldCall(Function *F, ArrayRef Operands, case Intrinsic::umul_with_overflow: { APInt Res; bool Overflow; - switch (F->getIntrinsicID()) { + switch (IntrinsicID) { default: llvm_unreachable("Invalid case"); case Intrinsic::sadd_with_overflow: Res = Op1->getValue().sadd_ov(Op2->getValue(), Overflow); @@ -1597,10 +1591,10 @@ llvm::ConstantFoldCall(Function *F, ArrayRef Operands, break; } Constant *Ops[] = { - ConstantInt::get(F->getContext(), Res), - ConstantInt::get(Type::getInt1Ty(F->getContext()), Overflow) + ConstantInt::get(Ty->getContext(), Res), + ConstantInt::get(Type::getInt1Ty(Ty->getContext()), Overflow) }; - return ConstantStruct::get(cast(F->getReturnType()), Ops); + return ConstantStruct::get(cast(Ty), Ops); } case Intrinsic::cttz: if (Op2->isOne() && Op1->isZero()) // cttz(0, 1) is undef. @@ -1624,7 +1618,7 @@ llvm::ConstantFoldCall(Function *F, ArrayRef Operands, if (const ConstantFP *Op1 = dyn_cast(Operands[0])) { if (const ConstantFP *Op2 = dyn_cast(Operands[1])) { if (const ConstantFP *Op3 = dyn_cast(Operands[2])) { - switch (F->getIntrinsicID()) { + switch (IntrinsicID) { default: break; case Intrinsic::fma: case Intrinsic::fmuladd: { @@ -1644,3 +1638,48 @@ llvm::ConstantFoldCall(Function *F, ArrayRef Operands, return 0; } + +static Constant *ConstantFoldVectorCall(StringRef Name, unsigned IntrinsicID, + VectorType *VTy, + ArrayRef Operands, + const TargetLibraryInfo *TLI) { + SmallVector Result(VTy->getNumElements()); + SmallVector Lane(Operands.size()); + Type *Ty = VTy->getElementType(); + + for (unsigned I = 0, E = VTy->getNumElements(); I != E; ++I) { + // Gather a column of constants. + for (unsigned J = 0, JE = Operands.size(); J != JE; ++J) { + Constant *Agg = Operands[J]->getAggregateElement(I); + if (!Agg) + return nullptr; + + Lane[J] = Agg; + } + + // Use the regular scalar folding to simplify this column. + Constant *Folded = ConstantFoldScalarCall(Name, IntrinsicID, Ty, Lane, TLI); + if (!Folded) + return nullptr; + Result[I] = Folded; + } + + return ConstantVector::get(Result); +} + +/// ConstantFoldCall - Attempt to constant fold a call to the specified function +/// with the specified arguments, returning null if unsuccessful. +Constant * +llvm::ConstantFoldCall(Function *F, ArrayRef Operands, + const TargetLibraryInfo *TLI) { + if (!F->hasName()) + return 0; + StringRef Name = F->getName(); + + Type *Ty = F->getReturnType(); + + if (VectorType *VTy = dyn_cast(Ty)) + return ConstantFoldVectorCall(Name, F->getIntrinsicID(), VTy, Operands, TLI); + + return ConstantFoldScalarCall(Name, F->getIntrinsicID(), Ty, Operands, TLI); +} diff --git a/test/Transforms/InstCombine/constant-fold-math.ll b/test/Transforms/InstCombine/constant-fold-math.ll index 00fceb19151..14377df3729 100644 --- a/test/Transforms/InstCombine/constant-fold-math.ll +++ b/test/Transforms/InstCombine/constant-fold-math.ll @@ -2,6 +2,7 @@ declare float @llvm.fma.f32(float, float, float) #0 declare float @llvm.fmuladd.f32(float, float, float) #0 +declare <4 x float> @llvm.fma.v4f32(<4 x float>, <4 x float>, <4 x float>) #0 declare double @llvm.fma.f64(double, double, double) #0 declare double @llvm.fmuladd.f64(double, double, double) #0 @@ -15,6 +16,13 @@ define float @constant_fold_fma_f32() #0 { ret float %x } +; CHECK-LABEL: @constant_fold_fma_v4f32 +; CHECK-NEXT: ret <4 x float> +define <4 x float> @constant_fold_fma_v4f32() #0 { + %x = call <4 x float> @llvm.fma.v4f32(<4 x float> , <4 x float> , <4 x float> ) + ret <4 x float> %x +} + ; CHECK-LABEL: @constant_fold_fmuladd_f32 ; CHECK-NEXT: ret float 6.000000e+00 define float @constant_fold_fmuladd_f32() #0 {