[C API] Add getters and setters for fast-math flags on relevant instructions (#75123)

These flags are usable on floating point arithmetic, as well as call,
select, and phi instructions whose resulting type is floating point, or
a vector of, or an array of, a valid type. Whether or not the flags are
valid for a given instruction can be checked with the new
LLVMCanValueUseFastMathFlags function.

These are exposed using a new LLVMFastMathFlags type, which is an alias
for unsigned. An anonymous enum defines the bit values for it.

Tests are added in echo.ll for select/phil/call, and the floating point
types in the new float_ops.ll bindings test.

Select and the floating point arithmetic instructions were not
implemented in llvm-c-test/echo.cpp, so they were added as well.
This commit is contained in:
Benji Smith 2023-12-12 11:15:05 -05:00 committed by GitHub
parent ed210f9f5a
commit d5c95302b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 350 additions and 0 deletions

View File

@ -226,6 +226,10 @@ Changes to the C API
* ``LLVMGetOperandBundleArgAtIndex`` * ``LLVMGetOperandBundleArgAtIndex``
* ``LLVMGetOperandBundleTag`` * ``LLVMGetOperandBundleTag``
* Added ``LLVMGetFastMathFlags`` and ``LLVMSetFastMathFlags`` for getting/setting
the fast-math flags of an instruction, as well as ``LLVMCanValueUseFastMathFlags``
for checking if an instruction can use such flags
Changes to the CodeGen infrastructure Changes to the CodeGen infrastructure
------------------------------------- -------------------------------------

View File

@ -483,6 +483,29 @@ typedef enum {
typedef unsigned LLVMAttributeIndex; typedef unsigned LLVMAttributeIndex;
enum {
LLVMFastMathAllowReassoc = (1 << 0),
LLVMFastMathNoNaNs = (1 << 1),
LLVMFastMathNoInfs = (1 << 2),
LLVMFastMathNoSignedZeros = (1 << 3),
LLVMFastMathAllowReciprocal = (1 << 4),
LLVMFastMathAllowContract = (1 << 5),
LLVMFastMathApproxFunc = (1 << 6),
LLVMFastMathNone = 0,
LLVMFastMathAll = LLVMFastMathAllowReassoc | LLVMFastMathNoNaNs |
LLVMFastMathNoInfs | LLVMFastMathNoSignedZeros |
LLVMFastMathAllowReciprocal | LLVMFastMathAllowContract |
LLVMFastMathApproxFunc,
};
/**
* Flags to indicate what fast-math-style optimizations are allowed
* on operations.
*
* See https://llvm.org/docs/LangRef.html#fast-math-flags
*/
typedef unsigned LLVMFastMathFlags;
/** /**
* @} * @}
*/ */
@ -4075,6 +4098,33 @@ LLVMBool LLVMGetNNeg(LLVMValueRef NonNegInst);
*/ */
void LLVMSetNNeg(LLVMValueRef NonNegInst, LLVMBool IsNonNeg); void LLVMSetNNeg(LLVMValueRef NonNegInst, LLVMBool IsNonNeg);
/**
* Get the flags for which fast-math-style optimizations are allowed for this
* value.
*
* Only valid on floating point instructions.
* @see LLVMCanValueUseFastMathFlags
*/
LLVMFastMathFlags LLVMGetFastMathFlags(LLVMValueRef FPMathInst);
/**
* Sets the flags for which fast-math-style optimizations are allowed for this
* value.
*
* Only valid on floating point instructions.
* @see LLVMCanValueUseFastMathFlags
*/
void LLVMSetFastMathFlags(LLVMValueRef FPMathInst, LLVMFastMathFlags FMF);
/**
* Check if a given value can potentially have fast math flags.
*
* Will return true for floating point arithmetic instructions, and for select,
* phi, and call instructions whose type is a floating point type, or a vector
* or array thereof. See https://llvm.org/docs/LangRef.html#fast-math-flags
*/
LLVMBool LLVMCanValueUseFastMathFlags(LLVMValueRef Inst);
/** /**
* Gets whether the instruction has the disjoint flag set. * Gets whether the instruction has the disjoint flag set.
* Only valid for or instructions. * Only valid for or instructions.

View File

@ -3319,6 +3319,39 @@ void LLVMSetArgOperand(LLVMValueRef Funclet, unsigned i, LLVMValueRef value) {
/*--.. Arithmetic ..........................................................--*/ /*--.. Arithmetic ..........................................................--*/
static FastMathFlags mapFromLLVMFastMathFlags(LLVMFastMathFlags FMF) {
FastMathFlags NewFMF;
NewFMF.setAllowReassoc((FMF & LLVMFastMathAllowReassoc) != 0);
NewFMF.setNoNaNs((FMF & LLVMFastMathNoNaNs) != 0);
NewFMF.setNoInfs((FMF & LLVMFastMathNoInfs) != 0);
NewFMF.setNoSignedZeros((FMF & LLVMFastMathNoSignedZeros) != 0);
NewFMF.setAllowReciprocal((FMF & LLVMFastMathAllowReciprocal) != 0);
NewFMF.setAllowContract((FMF & LLVMFastMathAllowContract) != 0);
NewFMF.setApproxFunc((FMF & LLVMFastMathApproxFunc) != 0);
return NewFMF;
}
static LLVMFastMathFlags mapToLLVMFastMathFlags(FastMathFlags FMF) {
LLVMFastMathFlags NewFMF = LLVMFastMathNone;
if (FMF.allowReassoc())
NewFMF |= LLVMFastMathAllowReassoc;
if (FMF.noNaNs())
NewFMF |= LLVMFastMathNoNaNs;
if (FMF.noInfs())
NewFMF |= LLVMFastMathNoInfs;
if (FMF.noSignedZeros())
NewFMF |= LLVMFastMathNoSignedZeros;
if (FMF.allowReciprocal())
NewFMF |= LLVMFastMathAllowReciprocal;
if (FMF.allowContract())
NewFMF |= LLVMFastMathAllowContract;
if (FMF.approxFunc())
NewFMF |= LLVMFastMathApproxFunc;
return NewFMF;
}
LLVMValueRef LLVMBuildAdd(LLVMBuilderRef B, LLVMValueRef LHS, LLVMValueRef RHS, LLVMValueRef LLVMBuildAdd(LLVMBuilderRef B, LLVMValueRef LHS, LLVMValueRef RHS,
const char *Name) { const char *Name) {
return wrap(unwrap(B)->CreateAdd(unwrap(LHS), unwrap(RHS), Name)); return wrap(unwrap(B)->CreateAdd(unwrap(LHS), unwrap(RHS), Name));
@ -3518,6 +3551,22 @@ void LLVMSetNNeg(LLVMValueRef NonNegInst, LLVMBool IsNonNeg) {
cast<Instruction>(P)->setNonNeg(IsNonNeg); cast<Instruction>(P)->setNonNeg(IsNonNeg);
} }
LLVMFastMathFlags LLVMGetFastMathFlags(LLVMValueRef FPMathInst) {
Value *P = unwrap<Value>(FPMathInst);
FastMathFlags FMF = cast<Instruction>(P)->getFastMathFlags();
return mapToLLVMFastMathFlags(FMF);
}
void LLVMSetFastMathFlags(LLVMValueRef FPMathInst, LLVMFastMathFlags FMF) {
Value *P = unwrap<Value>(FPMathInst);
cast<Instruction>(P)->setFastMathFlags(mapFromLLVMFastMathFlags(FMF));
}
LLVMBool LLVMCanValueUseFastMathFlags(LLVMValueRef V) {
Value *Val = unwrap<Value>(V);
return isa<FPMathOperator>(Val);
}
LLVMBool LLVMGetIsDisjoint(LLVMValueRef Inst) { LLVMBool LLVMGetIsDisjoint(LLVMValueRef Inst) {
Value *P = unwrap<Value>(Inst); Value *P = unwrap<Value>(Inst);
return cast<PossiblyDisjointInst>(P)->isDisjoint(); return cast<PossiblyDisjointInst>(P)->isDisjoint();

View File

@ -299,6 +299,41 @@ entry:
ret void ret void
} }
define void @test_fast_math_flags(i1 %c, float %a, float %b) {
entry:
%select.f.1 = select i1 %c, float %a, float %b
%select.f.2 = select nsz i1 %c, float %a, float %b
%select.f.3 = select fast i1 %c, float %a, float %b
%select.f.4 = select nnan arcp afn i1 %c, float %a, float %b
br i1 %c, label %choose_a, label %choose_b
choose_a:
br label %final
choose_b:
br label %final
final:
%phi.f.1 = phi float [ %a, %choose_a ], [ %b, %choose_b ]
%phi.f.2 = phi nsz float [ %a, %choose_a ], [ %b, %choose_b ]
%phi.f.3 = phi fast float [ %a, %choose_a ], [ %b, %choose_b ]
%phi.f.4 = phi nnan arcp afn float [ %a, %choose_a ], [ %b, %choose_b ]
ret void
}
define float @test_fast_math_flags_call_inner(float %a) {
ret float %a
}
define void @test_fast_math_flags_call_outer(float %a) {
%a.1 = call float @test_fast_math_flags_call_inner(float %a)
%a.2 = call nsz float @test_fast_math_flags_call_inner(float %a)
%a.3 = call fast float @test_fast_math_flags_call_inner(float %a)
%a.4 = call nnan arcp afn float @test_fast_math_flags_call_inner(float %a)
ret void
}
!llvm.dbg.cu = !{!0, !2} !llvm.dbg.cu = !{!0, !2}
!llvm.module.flags = !{!3} !llvm.module.flags = !{!3}

View File

@ -0,0 +1,156 @@
; RUN: llvm-as < %s | llvm-dis > %t.orig
; RUN: llvm-as < %s | llvm-c-test --echo > %t.echo
; RUN: diff -w %t.orig %t.echo
;
source_filename = "/test/Bindings/float_ops.ll"
target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
define float @float_ops_f32(float %a, float %b) {
%1 = fneg float %a
%2 = fadd float %a, %b
%3 = fsub float %a, %b
%4 = fmul float %a, %b
%5 = fdiv float %a, %b
%6 = frem float %a, %b
ret float %1
}
define double @float_ops_f64(double %a, double %b) {
%1 = fneg double %a
%2 = fadd double %a, %b
%3 = fsub double %a, %b
%4 = fmul double %a, %b
%5 = fdiv double %a, %b
%6 = frem double %a, %b
ret double %1
}
define void @float_cmp_f32(float %a, float %b) {
%1 = fcmp oeq float %a, %b
%2 = fcmp ogt float %a, %b
%3 = fcmp olt float %a, %b
%4 = fcmp ole float %a, %b
%5 = fcmp one float %a, %b
%6 = fcmp ueq float %a, %b
%7 = fcmp ugt float %a, %b
%8 = fcmp ult float %a, %b
%9 = fcmp ule float %a, %b
%10 = fcmp une float %a, %b
%11 = fcmp ord float %a, %b
%12 = fcmp false float %a, %b
%13 = fcmp true float %a, %b
ret void
}
define void @float_cmp_f64(double %a, double %b) {
%1 = fcmp oeq double %a, %b
%2 = fcmp ogt double %a, %b
%3 = fcmp olt double %a, %b
%4 = fcmp ole double %a, %b
%5 = fcmp one double %a, %b
%6 = fcmp ueq double %a, %b
%7 = fcmp ugt double %a, %b
%8 = fcmp ult double %a, %b
%9 = fcmp ule double %a, %b
%10 = fcmp une double %a, %b
%11 = fcmp ord double %a, %b
%12 = fcmp false double %a, %b
%13 = fcmp true double %a, %b
ret void
}
define void @float_cmp_fast_f32(float %a, float %b) {
%1 = fcmp fast oeq float %a, %b
%2 = fcmp nsz ogt float %a, %b
%3 = fcmp nsz nnan olt float %a, %b
%4 = fcmp contract ole float %a, %b
%5 = fcmp nnan one float %a, %b
%6 = fcmp nnan ninf nsz ueq float %a, %b
%7 = fcmp arcp ugt float %a, %b
%8 = fcmp fast ult float %a, %b
%9 = fcmp fast ule float %a, %b
%10 = fcmp fast une float %a, %b
%11 = fcmp fast ord float %a, %b
%12 = fcmp nnan ninf false float %a, %b
%13 = fcmp nnan ninf true float %a, %b
ret void
}
define void @float_cmp_fast_f64(double %a, double %b) {
%1 = fcmp fast oeq double %a, %b
%2 = fcmp nsz ogt double %a, %b
%3 = fcmp nsz nnan olt double %a, %b
%4 = fcmp contract ole double %a, %b
%5 = fcmp nnan one double %a, %b
%6 = fcmp nnan ninf nsz ueq double %a, %b
%7 = fcmp arcp ugt double %a, %b
%8 = fcmp fast ult double %a, %b
%9 = fcmp fast ule double %a, %b
%10 = fcmp fast une double %a, %b
%11 = fcmp fast ord double %a, %b
%12 = fcmp nnan ninf false double %a, %b
%13 = fcmp nnan ninf true double %a, %b
ret void
}
define float @float_ops_fast_f32(float %a, float %b) {
%1 = fneg nnan float %a
%2 = fadd ninf float %a, %b
%3 = fsub nsz float %a, %b
%4 = fmul arcp float %a, %b
%5 = fdiv contract float %a, %b
%6 = frem afn float %a, %b
%7 = fadd reassoc float %a, %b
%8 = fadd reassoc float %7, %b
%9 = fadd fast float %a, %b
%10 = fadd nnan nsz float %a, %b
%11 = frem nnan nsz float %a, %b
%12 = fdiv nnan nsz arcp float %a, %b
%13 = fmul nnan nsz ninf contract float %a, %b
%14 = fmul nnan nsz ninf arcp contract afn reassoc float %a, %b
ret float %1
}
define double @float_ops_fast_f64(double %a, double %b) {
%1 = fneg nnan double %a
%2 = fadd ninf double %a, %b
%3 = fsub nsz double %a, %b
%4 = fmul arcp double %a, %b
%5 = fdiv contract double %a, %b
%6 = frem afn double %a, %b
%7 = fadd reassoc double %a, %b
%8 = fadd reassoc double %7, %b
%9 = fadd fast double %a, %b
%10 = fadd nnan nsz double %a, %b
%11 = frem nnan nsz double %a, %b
%12 = fdiv nnan nsz arcp double %a, %b
%13 = fmul nnan nsz ninf contract double %a, %b
%14 = fmul nnan nsz ninf arcp contract afn reassoc double %a, %b
ret double %1
}

View File

@ -770,8 +770,18 @@ struct FunCloner {
} }
LLVMAddIncoming(Dst, Values.data(), Blocks.data(), IncomingCount); LLVMAddIncoming(Dst, Values.data(), Blocks.data(), IncomingCount);
// Copy fast math flags here since we return early
if (LLVMCanValueUseFastMathFlags(Src))
LLVMSetFastMathFlags(Dst, LLVMGetFastMathFlags(Src));
return Dst; return Dst;
} }
case LLVMSelect: {
LLVMValueRef If = CloneValue(LLVMGetOperand(Src, 0));
LLVMValueRef Then = CloneValue(LLVMGetOperand(Src, 1));
LLVMValueRef Else = CloneValue(LLVMGetOperand(Src, 2));
Dst = LLVMBuildSelect(Builder, If, Then, Else, Name);
break;
}
case LLVMCall: { case LLVMCall: {
SmallVector<LLVMValueRef, 8> Args; SmallVector<LLVMValueRef, 8> Args;
SmallVector<LLVMOperandBundleRef, 8> Bundles; SmallVector<LLVMOperandBundleRef, 8> Bundles;
@ -930,6 +940,48 @@ struct FunCloner {
LLVMSetNNeg(Dst, NNeg); LLVMSetNNeg(Dst, NNeg);
break; break;
} }
case LLVMFAdd: {
LLVMValueRef LHS = CloneValue(LLVMGetOperand(Src, 0));
LLVMValueRef RHS = CloneValue(LLVMGetOperand(Src, 1));
Dst = LLVMBuildFAdd(Builder, LHS, RHS, Name);
break;
}
case LLVMFSub: {
LLVMValueRef LHS = CloneValue(LLVMGetOperand(Src, 0));
LLVMValueRef RHS = CloneValue(LLVMGetOperand(Src, 1));
Dst = LLVMBuildFSub(Builder, LHS, RHS, Name);
break;
}
case LLVMFMul: {
LLVMValueRef LHS = CloneValue(LLVMGetOperand(Src, 0));
LLVMValueRef RHS = CloneValue(LLVMGetOperand(Src, 1));
Dst = LLVMBuildFMul(Builder, LHS, RHS, Name);
break;
}
case LLVMFDiv: {
LLVMValueRef LHS = CloneValue(LLVMGetOperand(Src, 0));
LLVMValueRef RHS = CloneValue(LLVMGetOperand(Src, 1));
Dst = LLVMBuildFDiv(Builder, LHS, RHS, Name);
break;
}
case LLVMFRem: {
LLVMValueRef LHS = CloneValue(LLVMGetOperand(Src, 0));
LLVMValueRef RHS = CloneValue(LLVMGetOperand(Src, 1));
Dst = LLVMBuildFRem(Builder, LHS, RHS, Name);
break;
}
case LLVMFNeg: {
LLVMValueRef Val = CloneValue(LLVMGetOperand(Src, 0));
Dst = LLVMBuildFNeg(Builder, Val, Name);
break;
}
case LLVMFCmp: {
LLVMRealPredicate Pred = LLVMGetFCmpPredicate(Src);
LLVMValueRef LHS = CloneValue(LLVMGetOperand(Src, 0));
LLVMValueRef RHS = CloneValue(LLVMGetOperand(Src, 1));
Dst = LLVMBuildFCmp(Builder, Pred, LHS, RHS, Name);
break;
}
default: default:
break; break;
} }
@ -939,6 +991,10 @@ struct FunCloner {
exit(-1); exit(-1);
} }
// Copy fast-math flags on instructions that support them
if (LLVMCanValueUseFastMathFlags(Src))
LLVMSetFastMathFlags(Dst, LLVMGetFastMathFlags(Src));
auto Ctx = LLVMGetModuleContext(M); auto Ctx = LLVMGetModuleContext(M);
size_t NumMetadataEntries; size_t NumMetadataEntries;
auto *AllMetadata = auto *AllMetadata =