[mlir] LLVM dialect: Generate conversions between EnumAttrCase and LLVM API

Summary:
MLIR materializes various enumeration-based LLVM IR operands as enumeration
attributes using ODS. This requires bidirectional conversion between different
but very similar enums, currently hardcoded. Extend the ODS modeling of
LLVM-specific enumeration attributes to include the name of the corresponding
enum in the LLVM C++ API as well as the names of specific enumerants. Use this
new information to automatically generate the conversion functions between enum
attributes and LLVM API enums in the two-way conversion between the LLVM
dialect and LLVM IR proper.

Differential Revision: https://reviews.llvm.org/D73468
This commit is contained in:
Alex Zinenko 2020-01-27 14:49:34 +01:00
parent 5be2ca2921
commit eb67bd78dc
6 changed files with 182 additions and 80 deletions

View File

@ -10,6 +10,8 @@ add_mlir_dialect(ROCDLOps ROCDLOps)
set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions)
mlir_tablegen(LLVMConversionEnumsToLLVM.inc -gen-enum-to-llvmir-conversions)
mlir_tablegen(LLVMConversionEnumsFromLLVM.inc -gen-enum-from-llvmir-conversions)
add_public_tablegen_target(MLIRLLVMConversionsIncGen)
set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions)

View File

@ -64,13 +64,36 @@ class LLVM_IntrOp<string mnemonic, list<OpTrait> traits = []> :
// Case of the LLVM enum attribute backed by I64Attr with customized string
// representation that corresponds to what is visible in the textual IR form.
class LLVM_EnumAttrCase<string cppSym, string irSym, int val> :
I64EnumAttrCase<cppSym, val, irSym>;
// The parameters are as follows:
// - `cppSym`: name of the C++ enumerant for this case in MLIR API;
// - `irSym`: keyword used in the custom form of MLIR operation;
// - `llvmSym`: name of the C++ enumerant for this case in LLVM API.
// For example, `LLVM_EnumAttrCase<"Weak", "weak", "WeakAnyLinkage">` is usable
// as `<MlirEnumName>::Weak` in MLIR API, `WeakAnyLinkage` in LLVM API and
// is printed/parsed as `weak` in MLIR custom textual format.
class LLVM_EnumAttrCase<string cppSym, string irSym, string llvmSym, int val> :
I64EnumAttrCase<cppSym, val, irSym> {
// The name of the equivalent enumerant in LLVM.
string llvmEnumerant = llvmSym;
}
// LLVM enum attribute backed by I64Attr with string representation
// corresponding to what is visible in the textual IR form.
class LLVM_EnumAttr<string name, string description,
// The parameters are as follows:
// - `name`: name of the C++ enum class in MLIR API;
// - `llvmName`: name of the C++ enum in LLVM API;
// - `description`: textual description for documentation purposes;
// - `cases`: list of enum cases.
// For example, `LLVM_EnumAttr<Linkage, "::llvm::GlobalValue::LinkageTypes`
// produces `mlir::LLVM::Linkage` enum class in MLIR API that corresponds to (a
// subset of) values in the `llvm::GlobalValue::LinkageTypes` in LLVM API.
class LLVM_EnumAttr<string name, string llvmName, string description,
list<LLVM_EnumAttrCase> cases> :
I64EnumAttr<name, description, cases>;
I64EnumAttr<name, description, cases> {
// The equivalent enum class name in LLVM.
string llvmClassName = llvmName;
}
#endif // LLVMIR_OP_BASE

View File

@ -495,22 +495,33 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> {
// https://llvm.org/docs/LangRef.html#linkage-types. The names are equivalent to
// visible names in the IR rather than to enum values names in llvm::GlobalValue
// since the latter is easier to change.
def LinkagePrivate : LLVM_EnumAttrCase<"Private", "private", 0>;
def LinkageInternal : LLVM_EnumAttrCase<"Internal", "internal", 1>;
def LinkageAvailableExternally : LLVM_EnumAttrCase<"AvailableExternally",
"available_externally", 2>;
def LinkageLinkonce : LLVM_EnumAttrCase<"Linkonce", "linkonce", 3>;
def LinkageWeak : LLVM_EnumAttrCase<"Weak", "weak", 4>;
def LinkageCommon : LLVM_EnumAttrCase<"Common", "common", 5>;
def LinkageAppending : LLVM_EnumAttrCase<"Appending", "appending", 6>;
def LinkageExternWeak : LLVM_EnumAttrCase<"ExternWeak",
"extern_weak", 7>;
def LinkageLinkonceODR : LLVM_EnumAttrCase<"LinkonceODR",
"linkonce_odr", 8>;
def LinkageWeakODR : LLVM_EnumAttrCase<"WeakODR", "weak_odr", 9>;
def LinkageExternal : LLVM_EnumAttrCase<"External", "external", 10>;
def LinkagePrivate
: LLVM_EnumAttrCase<"Private", "private", "PrivateLinkage", 0>;
def LinkageInternal
: LLVM_EnumAttrCase<"Internal", "internal", "InternalLinkage", 1>;
def LinkageAvailableExternally
: LLVM_EnumAttrCase<"AvailableExternally", "available_externally",
"AvailableExternallyLinkage", 2>;
def LinkageLinkonce
: LLVM_EnumAttrCase<"Linkonce", "linkonce", "LinkOnceAnyLinkage", 3>;
def LinkageWeak
: LLVM_EnumAttrCase<"Weak", "weak", "WeakAnyLinkage", 4>;
def LinkageCommon
: LLVM_EnumAttrCase<"Common", "common", "CommonLinkage", 5>;
def LinkageAppending
: LLVM_EnumAttrCase<"Appending", "appending", "AppendingLinkage", 6>;
def LinkageExternWeak
: LLVM_EnumAttrCase<"ExternWeak", "extern_weak", "ExternalWeakLinkage", 7>;
def LinkageLinkonceODR
: LLVM_EnumAttrCase<"LinkonceODR", "linkonce_odr", "LinkOnceODRLinkage", 8>;
def LinkageWeakODR
: LLVM_EnumAttrCase<"WeakODR", "weak_odr", "WeakODRLinkage", 9>;
def LinkageExternal
: LLVM_EnumAttrCase<"External", "external", "ExternalLinkage", 10>;
def Linkage : LLVM_EnumAttr<
"Linkage",
"::llvm::GlobalValue::LinkageTypes",
"LLVM linkage types",
[LinkagePrivate, LinkageInternal, LinkageAvailableExternally,
LinkageLinkonce, LinkageWeak, LinkageCommon, LinkageAppending,

View File

@ -30,6 +30,8 @@
using namespace mlir;
using namespace mlir::LLVM;
#include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
// Utility to print an LLVM value as a string for passing to emitError().
// FIXME: Diagnostic should be able to natively handle types that have
// operator << (raw_ostream&) defined.
@ -363,37 +365,6 @@ Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
return nullptr;
}
/// Converts LLVM global variable linkage type into the LLVM dialect predicate.
static LLVM::Linkage
processLinkage(llvm::GlobalVariable::LinkageTypes linkage) {
switch (linkage) {
case llvm::GlobalValue::PrivateLinkage:
return LLVM::Linkage::Private;
case llvm::GlobalValue::InternalLinkage:
return LLVM::Linkage::Internal;
case llvm::GlobalValue::AvailableExternallyLinkage:
return LLVM::Linkage::AvailableExternally;
case llvm::GlobalValue::LinkOnceAnyLinkage:
return LLVM::Linkage::Linkonce;
case llvm::GlobalValue::WeakAnyLinkage:
return LLVM::Linkage::Weak;
case llvm::GlobalValue::CommonLinkage:
return LLVM::Linkage::Common;
case llvm::GlobalValue::AppendingLinkage:
return LLVM::Linkage::Appending;
case llvm::GlobalValue::ExternalWeakLinkage:
return LLVM::Linkage::ExternWeak;
case llvm::GlobalValue::LinkOnceODRLinkage:
return LLVM::Linkage::LinkonceODR;
case llvm::GlobalValue::WeakODRLinkage:
return LLVM::Linkage::WeakODR;
case llvm::GlobalValue::ExternalLinkage:
return LLVM::Linkage::External;
}
llvm_unreachable("unhandled linkage type");
}
GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
auto it = globals.find(GV);
if (it != globals.end())
@ -408,7 +379,7 @@ GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
return nullptr;
GlobalOp op = b.create<GlobalOp>(
UnknownLoc::get(context), type, GV->isConstant(),
processLinkage(GV->getLinkage()), GV->getName(), valueAttr);
convertLinkageFromLLVM(GV->getLinkage()), GV->getName(), valueAttr);
if (GV->hasInitializer() && !valueAttr) {
Region &r = op.getInitializerRegion();
currentEntryBlock = b.createBlock(&r);

View File

@ -31,6 +31,8 @@
using namespace mlir;
using namespace mlir::LLVM;
#include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
/// Builds a constant of a sequential LLVM type `type`, potentially containing
/// other sequential types recursively, from the individual constant values
/// provided in `constants`. `shape` contains the number of elements in nested
@ -400,35 +402,6 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) {
return success();
}
/// Convert the LLVM dialect linkage type to LLVM IR linkage type.
llvm::GlobalVariable::LinkageTypes convertLinkageType(LLVM::Linkage linkage) {
switch (linkage) {
case LLVM::Linkage::Private:
return llvm::GlobalValue::PrivateLinkage;
case LLVM::Linkage::Internal:
return llvm::GlobalValue::InternalLinkage;
case LLVM::Linkage::AvailableExternally:
return llvm::GlobalValue::AvailableExternallyLinkage;
case LLVM::Linkage::Linkonce:
return llvm::GlobalValue::LinkOnceAnyLinkage;
case LLVM::Linkage::Weak:
return llvm::GlobalValue::WeakAnyLinkage;
case LLVM::Linkage::Common:
return llvm::GlobalValue::CommonLinkage;
case LLVM::Linkage::Appending:
return llvm::GlobalValue::AppendingLinkage;
case LLVM::Linkage::ExternWeak:
return llvm::GlobalValue::ExternalWeakLinkage;
case LLVM::Linkage::LinkonceODR:
return llvm::GlobalValue::LinkOnceODRLinkage;
case LLVM::Linkage::WeakODR:
return llvm::GlobalValue::WeakODRLinkage;
case LLVM::Linkage::External:
return llvm::GlobalValue::ExternalLinkage;
}
llvm_unreachable("unknown linkage type");
}
/// Create named global variables that correspond to llvm.mlir.global
/// definitions.
void ModuleTranslation::convertGlobals() {
@ -458,7 +431,7 @@ void ModuleTranslation::convertGlobals() {
cst = cast<llvm::Constant>(valueMapping.lookup(ret.getOperand(0)));
}
auto linkage = convertLinkageType(op.linkage());
auto linkage = convertLinkageToLLVM(op.linkage());
bool anyExternalLinkage =
(linkage == llvm::GlobalVariable::ExternalLinkage ||
linkage == llvm::GlobalVariable::ExternalWeakLinkage);

View File

@ -11,6 +11,8 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Support/LogicalResult.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
@ -171,6 +173,126 @@ static bool emitBuilders(const RecordKeeper &recordKeeper, raw_ostream &os) {
return false;
}
namespace {
// Wrapper class around a Tablegen definition of an LLVM enum attribute case.
class LLVMEnumAttrCase : public tblgen::EnumAttrCase {
public:
using tblgen::EnumAttrCase::EnumAttrCase;
// Constructs a case from a non LLVM-specific enum attribute case.
explicit LLVMEnumAttrCase(const tblgen::EnumAttrCase &other)
: tblgen::EnumAttrCase(&other.getDef()) {}
// Returns the C++ enumerant for the LLVM API.
StringRef getLLVMEnumerant() const {
return def->getValueAsString("llvmEnumerant");
}
};
// Wraper class around a Tablegen definition of an LLVM enum attribute.
class LLVMEnumAttr : public tblgen::EnumAttr {
public:
using tblgen::EnumAttr::EnumAttr;
// Returns the C++ enum name for the LLVM API.
StringRef getLLVMClassName() const {
return def->getValueAsString("llvmClassName");
}
// Returns all associated cases viewed as LLVM-specific enum cases.
std::vector<LLVMEnumAttrCase> getAllCases() const {
std::vector<LLVMEnumAttrCase> cases;
for (auto &c : tblgen::EnumAttr::getAllCases())
cases.push_back(LLVMEnumAttrCase(c));
return cases;
}
};
} // namespace
// Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing
// switch-based logic to convert from the MLIR LLVM dialect enum attribute case
// (Enum) to the corresponding LLVM API enumerant
static void emitOneEnumToConversion(const llvm::Record *record,
raw_ostream &os) {
LLVMEnumAttr enumAttr(record);
StringRef llvmClass = enumAttr.getLLVMClassName();
StringRef cppClassName = enumAttr.getEnumClassName();
StringRef cppNamespace = enumAttr.getCppNamespace();
// Emit the function converting the enum attribute to its LLVM counterpart.
os << formatv("static {0} convert{1}ToLLVM({2}::{1} value) {{\n", llvmClass,
cppClassName, cppNamespace);
os << " switch (value) {\n";
for (const auto &enumerant : enumAttr.getAllCases()) {
StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
StringRef cppEnumerant = enumerant.getSymbol();
os << formatv(" case {0}::{1}::{2}:\n", cppNamespace, cppClassName,
cppEnumerant);
os << formatv(" return {0}::{1};\n", llvmClass, llvmEnumerant);
}
os << " }\n";
os << formatv(" llvm_unreachable(\"unknown {0} type\");\n",
enumAttr.getEnumClassName());
os << "}\n\n";
}
// Emits conversion function "Enum convertEnumFromLLVM(LLVMClass)" and
// containing switch-based logic to convert from the LLVM API enumerant to MLIR
// LLVM dialect enum attribute (Enum).
static void emitOneEnumFromConversion(const llvm::Record *record,
raw_ostream &os) {
LLVMEnumAttr enumAttr(record);
StringRef llvmClass = enumAttr.getLLVMClassName();
StringRef cppClassName = enumAttr.getEnumClassName();
StringRef cppNamespace = enumAttr.getCppNamespace();
// Emit the function converting the enum attribute from its LLVM counterpart.
os << formatv("static {0}::{1} convert{1}FromLLVM({2} value) {{\n",
cppNamespace, cppClassName, llvmClass);
os << " switch (value) {\n";
for (const auto &enumerant : enumAttr.getAllCases()) {
StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
StringRef cppEnumerant = enumerant.getSymbol();
os << formatv(" case {0}::{1}:\n", llvmClass, llvmEnumerant);
os << formatv(" return {0}::{1}::{2};\n", cppNamespace, cppClassName,
cppEnumerant);
}
os << " }\n";
os << formatv(" llvm_unreachable(\"unknown {0} type\");",
enumAttr.getLLVMClassName());
os << "}\n\n";
}
// Emits conversion functions between MLIR enum attribute case and corresponding
// LLVM API enumerants for all registered LLVM dialect enum attributes.
template <bool ConvertTo>
static bool emitEnumConversionDefs(const RecordKeeper &recordKeeper,
raw_ostream &os) {
for (const auto *def : recordKeeper.getAllDerivedDefinitions("LLVM_EnumAttr"))
if (ConvertTo)
emitOneEnumToConversion(def, os);
else
emitOneEnumFromConversion(def, os);
return false;
}
static mlir::GenRegistration
genLLVMIRConversions("gen-llvmir-conversions",
"Generate LLVM IR conversions", emitBuilders);
static mlir::GenRegistration
genEnumToLLVMConversion("gen-enum-to-llvmir-conversions",
"Generate conversions of EnumAttrs to LLVM IR",
emitEnumConversionDefs</*ConvertTo=*/true>);
static mlir::GenRegistration
genEnumFromLLVMConversion("gen-enum-from-llvmir-conversions",
"Generate conversions of EnumAttrs from LLVM IR",
emitEnumConversionDefs</*ConvertTo=*/false>);