[mlir] Update LLVMIR Fastmath flags use of MLIR BitEnum functionality

This diff updates the LLVMIR dialect Fastmath flags attribute to use recently
added features of `BitEnum` attributes. Specifically, this diff uses the bit
enum "group" case to represent the `fast` value as an alias for a combination
of other values (`ninf`, `nnan`, ...), instead of using a separate integer
value. (This is in line with LLVM's fastmath flags representation.) This diff
also leverages the `printBitEnumPrimaryGroups` `tblgen` field for concise
enum printing.

The `BitEnum` features were developed for an upcoming diff that adds `fastmath`
support to the arithmetic dialect. This diff simply applies some of the relevant
new features to the LLVM dialect attribute.

Reviewed By: ftynse, Mogball

Differential Revision: https://reviews.llvm.org/D124720
This commit is contained in:
jfurtek 2022-05-17 18:18:52 +00:00 committed by Mogball
parent 127a1492d7
commit 5c3b20520b
10 changed files with 39 additions and 41 deletions

View File

@ -28,14 +28,17 @@ def FMFarcp : I32BitEnumAttrCaseBit<"arcp", 3>;
def FMFcontract : I32BitEnumAttrCaseBit<"contract", 4>;
def FMFafn : I32BitEnumAttrCaseBit<"afn", 5>;
def FMFreassoc : I32BitEnumAttrCaseBit<"reassoc", 6>;
def FMFfast : I32BitEnumAttrCaseBit<"fast", 7>;
def FMFfast : I32BitEnumAttrCaseGroup<"fast",
[ FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn, FMFreassoc]>;
def FastmathFlags : I32BitEnumAttr<
"FastmathFlags",
"LLVM fastmath flags",
[FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn, FMFreassoc, FMFfast
]> {
let separator = ", ";
let cppNamespace = "::mlir::LLVM";
let printBitEnumPrimaryGroups = 1;
}
def LLVM_FMFAttr : DialectAttr<

View File

@ -267,13 +267,20 @@ class BitEnumAttr<I intType, string name, string summary,
// bits together.
let symbolToStringFnRetType = "std::string";
// The delimiter used to separate bit enum cases in strings.
// The delimiter used to separate bit enum cases in strings. Only "|" and
// "," (along with optional spaces) are supported due to the use of the
// parseSeparatorFn in parameterParser below.
// Spaces in the separator string are used for printing, but will be optional
// for parsing.
string separator = "|";
assert !or(!ge(!find(separator, "|"), 0), !ge(!find(separator, ","), 0)),
"separator must contain '|' or ',' for parameter parsing";
// Parsing function that corresponds to the enum separator. Only
// "," and "|" are supported by this definition.
string parseSeparatorFn = !if(!eq(separator,"|"),"parseOptionalVerticalBar",
"parseOptionalComma");
string parseSeparatorFn = !if(!ge(!find(separator, "|"), 0),
"parseOptionalVerticalBar",
"parseOptionalComma");
// Parse a keyword and pass it to `stringToSymbol`. Emit an error if a the
// symbol is not valid.

View File

@ -2884,26 +2884,9 @@ bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
op->hasTrait<OpTrait::IsIsolatedFromAbove>();
}
static constexpr const FastmathFlags fastmathFlagsList[] = {
// clang-format off
FastmathFlags::nnan,
FastmathFlags::ninf,
FastmathFlags::nsz,
FastmathFlags::arcp,
FastmathFlags::contract,
FastmathFlags::afn,
FastmathFlags::reassoc,
FastmathFlags::fast,
// clang-format on
};
void FMFAttr::print(AsmPrinter &printer) const {
printer << "<";
auto flags = llvm::make_filter_range(fastmathFlagsList, [&](auto flag) {
return bitEnumContains(this->getFlags(), flag);
});
llvm::interleaveComma(flags, printer,
[&](auto flag) { printer << stringifyEnum(flag); });
printer << stringifyFastmathFlags(this->getFlags());
printer << ">";
}

View File

@ -157,7 +157,6 @@ static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
{FastmathFlags::contract, &llvmFMF::setAllowContract},
{FastmathFlags::afn, &llvmFMF::setApproxFunc},
{FastmathFlags::reassoc, &llvmFMF::setAllowReassoc},
{FastmathFlags::fast, &llvmFMF::setFast},
// clang-format on
};
llvm::FastMathFlags ret;

View File

@ -445,7 +445,7 @@ func.func @fastmathFlags(%arg0: f32, %arg1: f32, %arg2: i32) {
// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 : f32
%8 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<>} : f32
// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : f32
%9 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : f32
%9 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<nnan,ninf>} : f32
// CHECK: {{.*}} = llvm.fneg %arg0 : f32
%10 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<>} : f32

View File

@ -413,9 +413,9 @@ func.func @disallowed_case7_fail() {
// CHECK-LABEL: func @allowed_cases_pass
func.func @allowed_cases_pass() {
// CHECK: test.op_with_bit_enum <read,write>
// CHECK: test.op_with_bit_enum <read, write>
"test.op_with_bit_enum"() {value = #test.bit_enum<read, write>} : () -> ()
// CHECK: test.op_with_bit_enum <read,execute>
// CHECK: test.op_with_bit_enum <read, execute>
test.op_with_bit_enum <read,execute>
return
}
@ -424,11 +424,11 @@ func.func @allowed_cases_pass() {
// CHECK-LABEL: func @allowed_cases_pass
func.func @allowed_cases_pass() {
// CHECK: test.op_with_bit_enum_vbar <user|group>
// CHECK: test.op_with_bit_enum_vbar <user | group>
"test.op_with_bit_enum_vbar"() {
value = #test.bit_enum_vbar<user|group>
} : () -> ()
// CHECK: test.op_with_bit_enum_vbar <user|group|other>
// CHECK: test.op_with_bit_enum_vbar <user | group | other>
test.op_with_bit_enum_vbar <user | group | other>
return
}

View File

@ -324,7 +324,7 @@ def TestBitEnum
]> {
let genSpecializedAttr = 0;
let cppNamespace = "test";
let separator = ",";
let separator = ", ";
}
// Define the enum attribute.
@ -347,7 +347,7 @@ def TestBitEnumVerticalBar
]> {
let genSpecializedAttr = 0;
let cppNamespace = "test";
let separator = "|";
let separator = " | ";
}
def TestBitEnumVerticalBarAttr

View File

@ -277,6 +277,7 @@ static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
std::string underlyingType = std::string(enumAttr.getUnderlyingType());
StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
StringRef separator = enumDef.getValueAsString("separator");
StringRef separatorTrimmed = separator.trim();
auto enumerants = enumAttr.getAllCases();
auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
@ -292,15 +293,16 @@ static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
// Split the string to get symbols for all the bits.
os << " ::llvm::SmallVector<::llvm::StringRef, 2> symbols;\n";
os << formatv(" str.split(symbols, \"{0}\");\n\n", separator);
// Remove whitespace from the separator string when parsing.
os << formatv(" str.split(symbols, \"{0}\");\n\n", separatorTrimmed);
os << formatv(" {0} val = 0;\n", underlyingType);
os << " for (auto symbol : symbols) {\n";
// Convert each symbol to the bit ordinal and set the corresponding bit.
os << formatv(
" auto bit = llvm::StringSwitch<::llvm::Optional<{0}>>(symbol)\n",
underlyingType);
os << formatv(" auto bit = "
"llvm::StringSwitch<::llvm::Optional<{0}>>(symbol.trim())\n",
underlyingType);
for (const auto &enumerant : enumerants) {
// Skip the special enumerant for None.
if (auto val = enumerant.getValue())

View File

@ -80,7 +80,7 @@ TEST(EnumsGenTest, GeneratedSymbolToStringFnForBitEnum) {
EXPECT_EQ(stringifyBitEnumWithNone(BitEnumWithNone::Bit3), "Bit3");
EXPECT_EQ(
stringifyBitEnumWithNone(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3),
"Bit0|Bit3");
"Bit0 | Bit3");
EXPECT_EQ(stringifyBitEnum64_Test(BitEnum64_Test::Bit1), "Bit1");
EXPECT_EQ(
@ -96,7 +96,7 @@ TEST(EnumsGenTest, GeneratedStringToSymbolForBitEnum) {
BitEnumWithNone::Bit3 | BitEnumWithNone::Bit0);
EXPECT_EQ(symbolizeBitEnumWithNone("Bit2"), llvm::None);
EXPECT_EQ(symbolizeBitEnumWithNone("Bit3|Bit4"), llvm::None);
EXPECT_EQ(symbolizeBitEnumWithNone("Bit3 | Bit4"), llvm::None);
EXPECT_EQ(symbolizeBitEnumWithoutNone("None"), llvm::None);
}
@ -129,11 +129,11 @@ TEST(EnumsGenTest, GeneratedSymbolToStringFnForPrimaryGroupBitEnum) {
EXPECT_EQ(stringifyBitEnumPrimaryGroup(BitEnumPrimaryGroup::Bit0 |
BitEnumPrimaryGroup::Bit2 |
BitEnumPrimaryGroup::Bit3),
"Bit0,Bit2,Bit3");
"Bit0, Bit2, Bit3");
EXPECT_EQ(stringifyBitEnumPrimaryGroup(BitEnumPrimaryGroup::Bit0 |
BitEnumPrimaryGroup::Bit4 |
BitEnumPrimaryGroup::Bit5),
"Bits4And5,Bit0");
"Bits4And5, Bit0");
EXPECT_EQ(stringifyBitEnumPrimaryGroup(
BitEnumPrimaryGroup::Bit0 | BitEnumPrimaryGroup::Bit1 |
BitEnumPrimaryGroup::Bit2 | BitEnumPrimaryGroup::Bit3 |

View File

@ -33,7 +33,9 @@ def Bit4 : I32BitEnumAttrCaseBit<"Bit4", 4>;
def Bit5 : I32BitEnumAttrCaseBit<"Bit5", 5>;
def BitEnumWithNone : I32BitEnumAttr<"BitEnumWithNone", "A test enum",
[NoBits, Bit0, Bit3]>;
[NoBits, Bit0, Bit3]> {
let separator = " | ";
}
def BitEnumWithoutNone : I32BitEnumAttr<"BitEnumWithoutNone", "A test enum",
[Bit0, Bit3]>;
@ -46,12 +48,14 @@ def Bits0To5 : I32BitEnumAttrCaseGroup<"Bits0To5",
[Bits0To3, Bits4And5]>;
def BitEnumWithGroup : I32BitEnumAttr<"BitEnumWithGroup", "A test enum",
[Bit0, Bit1, Bit2, Bit3, Bit4, Bits0To3]>;
[Bit0, Bit1, Bit2, Bit3, Bit4, Bits0To3]> {
let separator = "|";
}
def BitEnumPrimaryGroup : I32BitEnumAttr<"BitEnumPrimaryGroup", "test enum",
[Bit0, Bit1, Bit2, Bit3, Bit4, Bit5,
Bits0To3, Bits4And5, Bits0To5]> {
let separator = ",";
let separator = ", ";
let printBitEnumPrimaryGroups = 1;
}