[MLIR,OpenMP] Lowering of parallel operation: proc_bind clause 2/n

This patch adds the translation of the proc_bind clause in a
parallel operation.

The values that can be specified for the proc_bind clause are
specified in the OMP.td tablegen file in the llvm/Frontend/OpenMP
directory. From this single source of truth enumeration for
proc_bind is generated in llvm and mlir (used in specification of
the parallel Operation in the OpenMP dialect). A function to return
the enum value from the string representation is also generated.
A new header file (DirectiveEmitter.h) containing definitions of
classes directive, clause, clauseval etc is created so that it can
be used in mlir as well.

Reviewers: clementval, jdoerfert, DavidTruby

Differential Revision: https://reviews.llvm.org/D84347
This commit is contained in:
Kiran Chandramohan 2020-07-22 17:28:04 +01:00
parent acc3d72e97
commit f33b4004b5
6 changed files with 350 additions and 173 deletions

View File

@ -51,6 +51,21 @@ class DirectiveLanguage {
string flangClauseBaseClass = "";
}
// Information about values accepted by enum-like clauses
class ClauseVal<string n, int v, bit uv> {
// Name of the clause value.
string name = n;
// Integer value of the clause.
int value = v;
// Can user specify this value?
bit isUserValue = uv;
// Set clause value used by default when unknown.
bit isDefault = 0;
}
// Information about a specific clause.
class Clause<string c> {
// Name of the clause.
@ -75,11 +90,17 @@ class Clause<string c> {
// If set to 1, value is optional. Not optional by default.
bit isValueOptional = 0;
// Name of enum when there is a list of allowed clause values.
string enumClauseValue = "";
// List of allowed clause values
list<ClauseVal> allowedClauseValues = [];
// Is clause implicit? If clause is set as implicit, the default kind will
// be return in get<LanguageName>ClauseKind instead of their own kind.
bit isImplicit = 0;
// Set directive used by default when unknown. Function returning the kind
// Set clause used by default when unknown. Function returning the kind
// of enumeration will use this clause as the default.
bit isDefault = 0;
}

View File

@ -99,9 +99,22 @@ def OMPC_CopyPrivate : Clause<"copyprivate"> {
let clangClass = "OMPCopyprivateClause";
let flangClassValue = "OmpObjectList";
}
def OMP_PROC_BIND_master : ClauseVal<"master",2,1> {}
def OMP_PROC_BIND_close : ClauseVal<"close",3,1> {}
def OMP_PROC_BIND_spread : ClauseVal<"spread",4,1> {}
def OMP_PROC_BIND_default : ClauseVal<"default",5,0> {}
def OMP_PROC_BIND_unknown : ClauseVal<"unknown",6,0> { let isDefault = 1; }
def OMPC_ProcBind : Clause<"proc_bind"> {
let clangClass = "OMPProcBindClause";
let flangClass = "OmpProcBindClause";
let enumClauseValue = "ProcBindKind";
let allowedClauseValues = [
OMP_PROC_BIND_master,
OMP_PROC_BIND_close,
OMP_PROC_BIND_spread,
OMP_PROC_BIND_default,
OMP_PROC_BIND_unknown
];
}
def OMPC_Schedule : Clause<"schedule"> {
let clangClass = "OMPScheduleClause";

View File

@ -68,16 +68,6 @@ enum class DefaultKind {
constexpr auto Enum = omp::DefaultKind::Enum;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
/// IDs for the different proc bind kinds.
enum class ProcBindKind {
#define OMP_PROC_BIND_KIND(Enum, Str, Value) Enum = Value,
#include "llvm/Frontend/OpenMP/OMPKinds.def"
};
#define OMP_PROC_BIND_KIND(Enum, ...) \
constexpr auto Enum = omp::ProcBindKind::Enum;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
/// IDs for all omp runtime library ident_t flag encodings (see
/// their defintion in openmp/runtime/src/kmp.h).
enum class IdentFlag {

View File

@ -0,0 +1,188 @@
#ifndef LLVM_TABLEGEN_DIRECTIVEEMITTER_H
#define LLVM_TABLEGEN_DIRECTIVEEMITTER_H
#include "llvm/ADT/StringExtras.h"
#include "llvm/TableGen/Record.h"
namespace llvm {
// Wrapper class that contains DirectiveLanguage's information defined in
// DirectiveBase.td and provides helper methods for accessing it.
class DirectiveLanguage {
public:
explicit DirectiveLanguage(const llvm::Record *Def) : Def(Def) {}
StringRef getName() const { return Def->getValueAsString("name"); }
StringRef getCppNamespace() const {
return Def->getValueAsString("cppNamespace");
}
StringRef getDirectivePrefix() const {
return Def->getValueAsString("directivePrefix");
}
StringRef getClausePrefix() const {
return Def->getValueAsString("clausePrefix");
}
StringRef getIncludeHeader() const {
return Def->getValueAsString("includeHeader");
}
StringRef getClauseEnumSetClass() const {
return Def->getValueAsString("clauseEnumSetClass");
}
StringRef getFlangClauseBaseClass() const {
return Def->getValueAsString("flangClauseBaseClass");
}
bool hasMakeEnumAvailableInNamespace() const {
return Def->getValueAsBit("makeEnumAvailableInNamespace");
}
bool hasEnableBitmaskEnumInNamespace() const {
return Def->getValueAsBit("enableBitmaskEnumInNamespace");
}
private:
const llvm::Record *Def;
};
// Base record class used for Directive and Clause class defined in
// DirectiveBase.td.
class BaseRecord {
public:
explicit BaseRecord(const llvm::Record *Def) : Def(Def) {}
StringRef getName() const { return Def->getValueAsString("name"); }
StringRef getAlternativeName() const {
return Def->getValueAsString("alternativeName");
}
// Returns the name of the directive formatted for output. Whitespace are
// replaced with underscores.
std::string getFormattedName() {
StringRef Name = Def->getValueAsString("name");
std::string N = Name.str();
std::replace(N.begin(), N.end(), ' ', '_');
return N;
}
bool isDefault() const { return Def->getValueAsBit("isDefault"); }
protected:
const llvm::Record *Def;
};
// Wrapper class that contains a Directive's information defined in
// DirectiveBase.td and provides helper methods for accessing it.
class Directive : public BaseRecord {
public:
explicit Directive(const llvm::Record *Def) : BaseRecord(Def) {}
std::vector<Record *> getAllowedClauses() const {
return Def->getValueAsListOfDefs("allowedClauses");
}
std::vector<Record *> getAllowedOnceClauses() const {
return Def->getValueAsListOfDefs("allowedOnceClauses");
}
std::vector<Record *> getAllowedExclusiveClauses() const {
return Def->getValueAsListOfDefs("allowedExclusiveClauses");
}
std::vector<Record *> getRequiredClauses() const {
return Def->getValueAsListOfDefs("requiredClauses");
}
};
// Wrapper class that contains Clause's information defined in DirectiveBase.td
// and provides helper methods for accessing it.
class Clause : public BaseRecord {
public:
explicit Clause(const llvm::Record *Def) : BaseRecord(Def) {}
// Optional field.
StringRef getClangClass() const {
return Def->getValueAsString("clangClass");
}
// Optional field.
StringRef getFlangClass() const {
return Def->getValueAsString("flangClass");
}
// Optional field.
StringRef getFlangClassValue() const {
return Def->getValueAsString("flangClassValue");
}
// Get the formatted name for Flang parser class. The generic formatted class
// name is constructed from the name were the first letter of each word is
// captitalized and the underscores are removed.
// ex: async -> Async
// num_threads -> NumThreads
std::string getFormattedParserClassName() {
StringRef Name = Def->getValueAsString("name");
std::string N = Name.str();
bool Cap = true;
std::transform(N.begin(), N.end(), N.begin(), [&Cap](unsigned char C) {
if (Cap == true) {
C = llvm::toUpper(C);
Cap = false;
} else if (C == '_') {
Cap = true;
}
return C;
});
N.erase(std::remove(N.begin(), N.end(), '_'), N.end());
return N;
}
// Optional field.
StringRef getEnumName() const {
return Def->getValueAsString("enumClauseValue");
}
std::vector<Record *> getClauseVals() const {
return Def->getValueAsListOfDefs("allowedClauseValues");
}
bool isValueOptional() const { return Def->getValueAsBit("isValueOptional"); }
bool isImplict() const { return Def->getValueAsBit("isImplicit"); }
};
// Wrapper class that contains VersionedClause's information defined in
// DirectiveBase.td and provides helper methods for accessing it.
class VersionedClause {
public:
explicit VersionedClause(const llvm::Record *Def) : Def(Def) {}
// Return the specific clause record wrapped in the Clause class.
Clause getClause() const { return Clause{Def->getValueAsDef("clause")}; }
int64_t getMinVersion() const { return Def->getValueAsInt("minVersion"); }
int64_t getMaxVersion() const { return Def->getValueAsInt("maxVersion"); }
private:
const llvm::Record *Def;
};
class ClauseVal : public BaseRecord {
public:
explicit ClauseVal(const llvm::Record *Def) : BaseRecord(Def) {}
int getValue() const { return Def->getValueAsInt("value"); }
bool isUserVisible() const { return Def->getValueAsBit("isUserValue"); }
};
} // namespace llvm
#endif

View File

@ -15,9 +15,20 @@ def TestDirectiveLanguage : DirectiveLanguage {
let flangClauseBaseClass = "TdlClause";
}
def TDLCV_vala : ClauseVal<"vala",1,1> {}
def TDLCV_valb : ClauseVal<"valb",2,1> {}
def TDLCV_valc : ClauseVal<"valc",3,0> { let isDefault = 1; }
def TDLC_ClauseA : Clause<"clausea"> {
let flangClass = "TdlClauseA";
let enumClauseValue = "AKind";
let allowedClauseValues = [
TDLCV_vala,
TDLCV_valb,
TDLCV_valc
];
}
def TDLC_ClauseB : Clause<"clauseb"> {
let flangClassValue = "IntExpr";
let isValueOptional = 1;
@ -61,6 +72,16 @@ def TDL_DirA : Directive<"dira"> {
// CHECK-NEXT: constexpr auto TDLC_clausea = llvm::tdl::Clause::TDLC_clausea;
// CHECK-NEXT: constexpr auto TDLC_clauseb = llvm::tdl::Clause::TDLC_clauseb;
// CHECK-EMPTY:
// CHECK-NEXT: enum class AKind {
// CHECK-NEXT: TDLCV_vala=1,
// CHECK-NEXT: TDLCV_valb=2,
// CHECK-NEXT: TDLCV_valc=3,
// CHECK-NEXT: };
// CHECK-EMPTY:
// CHECK-NEXT: constexpr auto TDLCV_vala = llvm::tdl::AKind::TDLCV_vala;
// CHECK-NEXT: constexpr auto TDLCV_valb = llvm::tdl::AKind::TDLCV_valb;
// CHECK-NEXT: constexpr auto TDLCV_valc = llvm::tdl::AKind::TDLCV_valc;
// CHECK-EMPTY:
// CHECK-NEXT: // Enumeration helper functions
// CHECK-NEXT: Directive getTdlDirectiveKind(llvm::StringRef Str);
// CHECK-EMPTY:
@ -73,6 +94,8 @@ def TDL_DirA : Directive<"dira"> {
// CHECK-NEXT: /// Return true if \p C is a valid clause for \p D in version \p Version.
// CHECK-NEXT: bool isAllowedClauseForDirective(Directive D, Clause C, unsigned Version);
// CHECK-EMPTY:
// CHECK-NEXT: AKind getAKind(StringRef);
// CHECK-EMPTY:
// CHECK-NEXT: } // namespace tdl
// CHECK-NEXT: } // namespace llvm
// CHECK-NEXT: #endif // LLVM_Tdl_INC
@ -116,6 +139,14 @@ def TDL_DirA : Directive<"dira"> {
// IMPL-NEXT: llvm_unreachable("Invalid Tdl Clause kind");
// IMPL-NEXT: }
// IMPL-EMPTY:
// IMPL-NEXT: AKind llvm::tdl::getAKind(llvm::StringRef Str) {
// IMPL-NEXT: return llvm::StringSwitch<AKind>(Str)
// IMPL-NEXT: .Case("vala",TDLCV_vala)
// IMPL-NEXT: .Case("valb",TDLCV_valb)
// IMPL-NEXT: .Case("valc",TDLCV_valc)
// IMPL-NEXT: .Default(TDLCV_valc);
// IMPL-NEXT: }
// IMPL-EMPTY:
// IMPL-NEXT: bool llvm::tdl::isAllowedClauseForDirective(Directive D, Clause C, unsigned Version) {
// IMPL-NEXT: assert(unsigned(D) <= llvm::tdl::Directive_enumSize);
// IMPL-NEXT: assert(unsigned(C) <= llvm::tdl::Clause_enumSize);

View File

@ -11,15 +11,14 @@
//
//===----------------------------------------------------------------------===//
#include "llvm/TableGen/DirectiveEmitter.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using namespace llvm;
namespace {
@ -41,165 +40,6 @@ private:
namespace llvm {
// Wrapper class that contains DirectiveLanguage's information defined in
// DirectiveBase.td and provides helper methods for accessing it.
class DirectiveLanguage {
public:
explicit DirectiveLanguage(const llvm::Record *Def) : Def(Def) {}
StringRef getName() const { return Def->getValueAsString("name"); }
StringRef getCppNamespace() const {
return Def->getValueAsString("cppNamespace");
}
StringRef getDirectivePrefix() const {
return Def->getValueAsString("directivePrefix");
}
StringRef getClausePrefix() const {
return Def->getValueAsString("clausePrefix");
}
StringRef getIncludeHeader() const {
return Def->getValueAsString("includeHeader");
}
StringRef getClauseEnumSetClass() const {
return Def->getValueAsString("clauseEnumSetClass");
}
StringRef getFlangClauseBaseClass() const {
return Def->getValueAsString("flangClauseBaseClass");
}
bool hasMakeEnumAvailableInNamespace() const {
return Def->getValueAsBit("makeEnumAvailableInNamespace");
}
bool hasEnableBitmaskEnumInNamespace() const {
return Def->getValueAsBit("enableBitmaskEnumInNamespace");
}
private:
const llvm::Record *Def;
};
// Base record class used for Directive and Clause class defined in
// DirectiveBase.td.
class BaseRecord {
public:
explicit BaseRecord(const llvm::Record *Def) : Def(Def) {}
StringRef getName() const { return Def->getValueAsString("name"); }
StringRef getAlternativeName() const {
return Def->getValueAsString("alternativeName");
}
// Returns the name of the directive formatted for output. Whitespace are
// replaced with underscores.
std::string getFormattedName() {
StringRef Name = Def->getValueAsString("name");
std::string N = Name.str();
std::replace(N.begin(), N.end(), ' ', '_');
return N;
}
bool isDefault() const { return Def->getValueAsBit("isDefault"); }
protected:
const llvm::Record *Def;
};
// Wrapper class that contains a Directive's information defined in
// DirectiveBase.td and provides helper methods for accessing it.
class Directive : public BaseRecord {
public:
explicit Directive(const llvm::Record *Def) : BaseRecord(Def) {}
std::vector<Record *> getAllowedClauses() const {
return Def->getValueAsListOfDefs("allowedClauses");
}
std::vector<Record *> getAllowedOnceClauses() const {
return Def->getValueAsListOfDefs("allowedOnceClauses");
}
std::vector<Record *> getAllowedExclusiveClauses() const {
return Def->getValueAsListOfDefs("allowedExclusiveClauses");
}
std::vector<Record *> getRequiredClauses() const {
return Def->getValueAsListOfDefs("requiredClauses");
}
};
// Wrapper class that contains Clause's information defined in DirectiveBase.td
// and provides helper methods for accessing it.
class Clause : public BaseRecord {
public:
explicit Clause(const llvm::Record *Def) : BaseRecord(Def) {}
// Optional field.
StringRef getClangClass() const {
return Def->getValueAsString("clangClass");
}
// Optional field.
StringRef getFlangClass() const {
return Def->getValueAsString("flangClass");
}
// Optional field.
StringRef getFlangClassValue() const {
return Def->getValueAsString("flangClassValue");
}
// Get the formatted name for Flang parser class. The generic formatted class
// name is constructed from the name were the first letter of each word is
// captitalized and the underscores are removed.
// ex: async -> Async
// num_threads -> NumThreads
std::string getFormattedParserClassName() {
StringRef Name = Def->getValueAsString("name");
std::string N = Name.str();
bool Cap = true;
std::transform(N.begin(), N.end(), N.begin(), [&Cap](unsigned char C) {
if (Cap == true) {
C = llvm::toUpper(C);
Cap = false;
} else if (C == '_') {
Cap = true;
}
return C;
});
N.erase(std::remove(N.begin(), N.end(), '_'), N.end());
return N;
}
bool isValueOptional() const { return Def->getValueAsBit("isValueOptional"); }
bool isImplict() const { return Def->getValueAsBit("isImplicit"); }
};
// Wrapper class that contains VersionedClause's information defined in
// DirectiveBase.td and provides helper methods for accessing it.
class VersionedClause {
public:
explicit VersionedClause(const llvm::Record *Def) : Def(Def) {}
// Return the specific clause record wrapped in the Clause class.
Clause getClause() const { return Clause{Def->getValueAsDef("clause")}; }
int64_t getMinVersion() const { return Def->getValueAsInt("minVersion"); }
int64_t getMaxVersion() const { return Def->getValueAsInt("maxVersion"); }
private:
const llvm::Record *Def;
};
// Generate enum class
void GenerateEnumClass(const std::vector<Record *> &Records, raw_ostream &OS,
StringRef Enum, StringRef Prefix,
@ -231,6 +71,46 @@ void GenerateEnumClass(const std::vector<Record *> &Records, raw_ostream &OS,
}
}
// Generate enums for values that clauses can take.
// Also generate function declarations for get<Enum>Name(StringRef Str).
void GenerateEnumClauseVal(const std::vector<Record *> &Records,
raw_ostream &OS, DirectiveLanguage &DirLang,
std::string &EnumHelperFuncs) {
for (const auto &R : Records) {
Clause C{R};
const auto &ClauseVals = C.getClauseVals();
if (ClauseVals.size() <= 0)
continue;
const auto &EnumName = C.getEnumName();
if (EnumName.size() == 0) {
PrintError("enumClauseValue field not set in Clause" +
C.getFormattedName() + ".");
return;
}
OS << "\n";
OS << "enum class " << EnumName << " {\n";
for (const auto &CV : ClauseVals) {
ClauseVal CVal{CV};
OS << " " << CV->getName() << "=" << CVal.getValue() << ",\n";
}
OS << "};\n";
if (DirLang.hasMakeEnumAvailableInNamespace()) {
OS << "\n";
for (const auto &CV : ClauseVals) {
OS << "constexpr auto " << CV->getName() << " = "
<< "llvm::" << DirLang.getCppNamespace() << "::" << EnumName
<< "::" << CV->getName() << ";\n";
}
EnumHelperFuncs += (llvm::Twine(EnumName) + llvm::Twine(" get") +
llvm::Twine(EnumName) + llvm::Twine("(StringRef);\n"))
.str();
}
}
}
// Generate the declaration section for the enumeration in the directive
// language
void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
@ -273,6 +153,10 @@ void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
const auto &Clauses = Records.getAllDerivedDefinitions("Clause");
GenerateEnumClass(Clauses, OS, "Clause", DirLang.getClausePrefix(), DirLang);
// Emit ClauseVal enumeration
std::string EnumHelperFuncs;
GenerateEnumClauseVal(Clauses, OS, DirLang, EnumHelperFuncs);
// Generic function signatures
OS << "\n";
OS << "// Enumeration helper functions\n";
@ -292,6 +176,10 @@ void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
OS << "bool isAllowedClauseForDirective(Directive D, "
<< "Clause C, unsigned Version);\n";
OS << "\n";
if (EnumHelperFuncs.length() > 0) {
OS << EnumHelperFuncs;
OS << "\n";
}
// Closing namespaces
for (auto Ns : llvm::reverse(Namespaces))
@ -336,7 +224,7 @@ void GenerateGetKind(const std::vector<Record *> &Records, raw_ostream &OS,
});
if (DefaultIt == Records.end()) {
PrintError("A least one " + Enum + " must be defined as default.");
PrintError("At least one " + Enum + " must be defined as default.");
return;
}
@ -361,6 +249,49 @@ void GenerateGetKind(const std::vector<Record *> &Records, raw_ostream &OS,
OS << "}\n";
}
// Generate function implementation for get<ClauseVal>Kind(StringRef Str)
void GenerateGetKindClauseVal(const std::vector<Record *> &Records,
raw_ostream &OS, StringRef Namespace) {
for (const auto &R : Records) {
Clause C{R};
const auto &ClauseVals = C.getClauseVals();
if (ClauseVals.size() <= 0)
continue;
auto DefaultIt =
std::find_if(ClauseVals.begin(), ClauseVals.end(), [](Record *CV) {
return CV->getValueAsBit("isDefault") == true;
});
if (DefaultIt == ClauseVals.end()) {
PrintError("At least one val in Clause " + C.getFormattedName() +
" must be defined as default.");
return;
}
const auto DefaultName = (*DefaultIt)->getName();
const auto &EnumName = C.getEnumName();
if (EnumName.size() == 0) {
PrintError("enumClauseValue field not set in Clause" +
C.getFormattedName() + ".");
return;
}
OS << "\n";
OS << EnumName << " llvm::" << Namespace << "::get" << EnumName
<< "(llvm::StringRef Str) {\n";
OS << " return llvm::StringSwitch<" << EnumName << ">(Str)\n";
for (const auto &CV : ClauseVals) {
ClauseVal CVal{CV};
OS << " .Case(\"" << CVal.getFormattedName() << "\"," << CV->getName()
<< ")\n";
}
OS << " .Default(" << DefaultName << ");\n";
OS << "}\n";
}
}
void GenerateCaseForVersionedClauses(const std::vector<Record *> &Clauses,
raw_ostream &OS, StringRef DirectiveName,
DirectiveLanguage &DirLang,
@ -672,6 +603,9 @@ void EmitDirectivesImpl(RecordKeeper &Records, raw_ostream &OS) {
// getClauseName(Clause Kind)
GenerateGetName(Clauses, OS, "Clause", DirLang, DirLang.getClausePrefix());
// get<ClauseVal>Kind(StringRef Str)
GenerateGetKindClauseVal(Clauses, OS, DirLang.getCppNamespace());
// isAllowedClauseForDirective(Directive D, Clause C, unsigned Version)
GenerateIsAllowedClause(Directives, OS, DirLang);
}