[Polly] Partially refactoring of IslAstInfo and IslNodeBuilder to use isl++. NFC.

Polly use algorithms from the Integer Set Library (isl), which is a library written in C and which is incompatible with the rest of the LLVM as it is written in C++.

Changes made:
 - Refactoring the following methods of class IslAstInfo
   - isParallel() isExecutedInParallel() isReductionParallel() getSchedule() getMinimalDependenceDistance() getBrokenReductions()
 - Refactoring the following methods of class IslNodeBuilder
   - getReferencesInSubtree() getScheduleForAstNode()
 - Refactoring function getBrokenReductionsStr()
 - Fixed the mismatching function declaration for getScheduleForAstNode()

Reviewed By: Meinersbur

Differential Revision: https://reviews.llvm.org/D99971
This commit is contained in:
patacca 2021-04-10 16:25:05 -05:00 committed by Michael Kruse
parent 9829f5e6b1
commit 82fbc5d45b
4 changed files with 55 additions and 56 deletions

View File

@ -142,7 +142,7 @@ public:
static bool isInnermost(const isl::ast_node &Node);
/// Is this loop a parallel loop?
static bool isParallel(__isl_keep isl_ast_node *Node);
static bool isParallel(const isl::ast_node &Node);
/// Is this loop an outermost parallel loop?
static bool isOutermostParallel(const isl::ast_node &Node);
@ -151,20 +151,19 @@ public:
static bool isInnermostParallel(const isl::ast_node &Node);
/// Is this loop a reduction parallel loop?
static bool isReductionParallel(__isl_keep isl_ast_node *Node);
static bool isReductionParallel(const isl::ast_node &Node);
/// Will the loop be run as thread parallel?
static bool isExecutedInParallel(__isl_keep isl_ast_node *Node);
static bool isExecutedInParallel(const isl::ast_node &Node);
/// Get the nodes schedule or a nullptr if not available.
static __isl_give isl_union_map *getSchedule(__isl_keep isl_ast_node *Node);
static isl::union_map getSchedule(const isl::ast_node &Node);
/// Get minimal dependence distance or nullptr if not available.
static __isl_give isl_pw_aff *
getMinimalDependenceDistance(__isl_keep isl_ast_node *Node);
static isl::pw_aff getMinimalDependenceDistance(const isl::ast_node &Node);
/// Get the nodes broken reductions or a nullptr if not available.
static MemoryAccessSet *getBrokenReductions(__isl_keep isl_ast_node *Node);
static MemoryAccessSet *getBrokenReductions(const isl::ast_node &Node);
/// Get the nodes build context or a nullptr if not available.
static __isl_give isl_ast_build *getBuild(__isl_keep isl_ast_node *Node);

View File

@ -248,7 +248,7 @@ protected:
/// this subtree.
/// @param Loops A vector that will be filled with the Loops referenced in
/// this subtree.
void getReferencesInSubtree(__isl_keep isl_ast_node *For,
void getReferencesInSubtree(const isl::ast_node &For,
SetVector<Value *> &Values,
SetVector<const Loop *> &Loops);
@ -398,8 +398,7 @@ protected:
/// below this ast node to the scheduling vectors used to enumerate
/// them.
///
virtual __isl_give isl_union_map *
getScheduleForAstNode(__isl_take isl_ast_node *Node);
virtual isl::union_map getScheduleForAstNode(const isl::ast_node &Node);
private:
/// Create code for a copy statement.

View File

@ -140,7 +140,7 @@ static isl_printer *printLine(__isl_take isl_printer *Printer,
}
/// Return all broken reductions as a string of clauses (OpenMP style).
static const std::string getBrokenReductionsStr(__isl_keep isl_ast_node *Node) {
static const std::string getBrokenReductionsStr(const isl::ast_node &Node) {
IslAstInfo::MemoryAccessSet *BrokenReductions;
std::string str;
@ -171,25 +171,26 @@ static const std::string getBrokenReductionsStr(__isl_keep isl_ast_node *Node) {
static isl_printer *cbPrintFor(__isl_take isl_printer *Printer,
__isl_take isl_ast_print_options *Options,
__isl_keep isl_ast_node *Node, void *) {
isl_pw_aff *DD = IslAstInfo::getMinimalDependenceDistance(Node);
const std::string BrokenReductionsStr = getBrokenReductionsStr(Node);
isl::pw_aff DD =
IslAstInfo::getMinimalDependenceDistance(isl::manage_copy(Node));
const std::string BrokenReductionsStr =
getBrokenReductionsStr(isl::manage_copy(Node));
const std::string KnownParallelStr = "#pragma known-parallel";
const std::string DepDisPragmaStr = "#pragma minimal dependence distance: ";
const std::string SimdPragmaStr = "#pragma simd";
const std::string OmpPragmaStr = "#pragma omp parallel for";
if (DD)
Printer = printLine(Printer, DepDisPragmaStr, DD);
if (!DD.is_null())
Printer = printLine(Printer, DepDisPragmaStr, DD.get());
if (IslAstInfo::isInnermostParallel(isl::manage_copy(Node)))
Printer = printLine(Printer, SimdPragmaStr + BrokenReductionsStr);
if (IslAstInfo::isExecutedInParallel(Node))
if (IslAstInfo::isExecutedInParallel(isl::manage_copy(Node)))
Printer = printLine(Printer, OmpPragmaStr);
else if (IslAstInfo::isOutermostParallel(isl::manage_copy(Node)))
Printer = printLine(Printer, KnownParallelStr + BrokenReductionsStr);
isl_pw_aff_free(DD);
return isl_ast_node_for_print(Node, Printer, Options);
}
@ -472,15 +473,15 @@ static void walkAstForStatistics(__isl_keep isl_ast_node *Ast) {
switch (isl_ast_node_get_type(Node)) {
case isl_ast_node_for:
NumForLoops++;
if (IslAstInfo::isParallel(Node))
if (IslAstInfo::isParallel(isl::manage_copy(Node)))
NumParallel++;
if (IslAstInfo::isInnermostParallel(isl::manage_copy(Node)))
NumInnermostParallel++;
if (IslAstInfo::isOutermostParallel(isl::manage_copy(Node)))
NumOutermostParallel++;
if (IslAstInfo::isReductionParallel(Node))
if (IslAstInfo::isReductionParallel(isl::manage_copy(Node)))
NumReductionParallel++;
if (IslAstInfo::isExecutedInParallel(Node))
if (IslAstInfo::isExecutedInParallel(isl::manage_copy(Node)))
NumExecutedInParallel++;
break;
@ -593,9 +594,9 @@ bool IslAstInfo::isInnermost(const isl::ast_node &Node) {
return Payload && Payload->IsInnermost;
}
bool IslAstInfo::isParallel(__isl_keep isl_ast_node *Node) {
return IslAstInfo::isInnermostParallel(isl::manage_copy(Node)) ||
IslAstInfo::isOutermostParallel(isl::manage_copy(Node));
bool IslAstInfo::isParallel(const isl::ast_node &Node) {
return IslAstInfo::isInnermostParallel(Node) ||
IslAstInfo::isOutermostParallel(Node);
}
bool IslAstInfo::isInnermostParallel(const isl::ast_node &Node) {
@ -608,12 +609,12 @@ bool IslAstInfo::isOutermostParallel(const isl::ast_node &Node) {
return Payload && Payload->IsOutermostParallel;
}
bool IslAstInfo::isReductionParallel(__isl_keep isl_ast_node *Node) {
IslAstUserPayload *Payload = getNodePayload(isl::manage_copy(Node));
bool IslAstInfo::isReductionParallel(const isl::ast_node &Node) {
IslAstUserPayload *Payload = getNodePayload(Node);
return Payload && Payload->IsReductionParallel;
}
bool IslAstInfo::isExecutedInParallel(__isl_keep isl_ast_node *Node) {
bool IslAstInfo::isExecutedInParallel(const isl::ast_node &Node) {
if (!PollyParallel)
return false;
@ -626,28 +627,30 @@ bool IslAstInfo::isExecutedInParallel(__isl_keep isl_ast_node *Node) {
// executed. This can possibly require run-time checks, which again
// raises the question of both run-time check overhead and code size
// costs.
if (!PollyParallelForce && isInnermost(isl::manage_copy(Node)))
if (!PollyParallelForce && isInnermost(Node))
return false;
return isOutermostParallel(isl::manage_copy(Node)) &&
!isReductionParallel(Node);
return isOutermostParallel(Node) && !isReductionParallel(Node);
}
__isl_give isl_union_map *
IslAstInfo::getSchedule(__isl_keep isl_ast_node *Node) {
IslAstUserPayload *Payload = getNodePayload(isl::manage_copy(Node));
return Payload ? isl_ast_build_get_schedule(Payload->Build) : nullptr;
isl::union_map IslAstInfo::getSchedule(const isl::ast_node &Node) {
IslAstUserPayload *Payload = getNodePayload(Node);
if (!Payload)
return nullptr;
isl::ast_build Build = isl::manage_copy(Payload->Build);
return Build.get_schedule();
}
__isl_give isl_pw_aff *
IslAstInfo::getMinimalDependenceDistance(__isl_keep isl_ast_node *Node) {
IslAstUserPayload *Payload = getNodePayload(isl::manage_copy(Node));
return Payload ? Payload->MinimalDependenceDistance.copy() : nullptr;
isl::pw_aff
IslAstInfo::getMinimalDependenceDistance(const isl::ast_node &Node) {
IslAstUserPayload *Payload = getNodePayload(Node);
return Payload ? Payload->MinimalDependenceDistance : nullptr;
}
IslAstInfo::MemoryAccessSet *
IslAstInfo::getBrokenReductions(__isl_keep isl_ast_node *Node) {
IslAstUserPayload *Payload = getNodePayload(isl::manage_copy(Node));
IslAstInfo::getBrokenReductions(const isl::ast_node &Node) {
IslAstUserPayload *Payload = getNodePayload(Node);
return Payload ? &Payload->BrokenReductions : nullptr;
}

View File

@ -300,12 +300,12 @@ addReferencesFromStmtUnionSet(isl::union_set USet,
addReferencesFromStmtSet(Set, &References);
}
__isl_give isl_union_map *
IslNodeBuilder::getScheduleForAstNode(__isl_keep isl_ast_node *For) {
return IslAstInfo::getSchedule(For);
isl::union_map
IslNodeBuilder::getScheduleForAstNode(const isl::ast_node &Node) {
return IslAstInfo::getSchedule(Node);
}
void IslNodeBuilder::getReferencesInSubtree(__isl_keep isl_ast_node *For,
void IslNodeBuilder::getReferencesInSubtree(const isl::ast_node &For,
SetVector<Value *> &Values,
SetVector<const Loop *> &Loops) {
SetVector<const SCEV *> SCEVs;
@ -319,8 +319,7 @@ void IslNodeBuilder::getReferencesInSubtree(__isl_keep isl_ast_node *For,
for (const auto &I : OutsideLoopIterations)
Values.insert(cast<SCEVUnknown>(I.second)->getValue());
isl::union_set Schedule =
isl::manage(isl_union_map_domain(getScheduleForAstNode(For)));
isl::union_set Schedule = getScheduleForAstNode(For).domain();
addReferencesFromStmtUnionSet(Schedule, References);
for (const SCEV *Expr : SCEVs) {
@ -476,22 +475,22 @@ void IslNodeBuilder::createForVector(__isl_take isl_ast_node *For,
for (int i = 1; i < VectorWidth; i++)
IVS[i] = Builder.CreateAdd(IVS[i - 1], ValueInc, "p_vector_iv");
isl_union_map *Schedule = getScheduleForAstNode(For);
assert(Schedule && "For statement annotation does not contain its schedule");
isl::union_map Schedule = getScheduleForAstNode(isl::manage_copy(For));
assert(!Schedule.is_null() &&
"For statement annotation does not contain its schedule");
IDToValue[IteratorID] = ValueLB;
switch (isl_ast_node_get_type(Body)) {
case isl_ast_node_user:
createUserVector(Body, IVS, isl_id_copy(IteratorID),
isl_union_map_copy(Schedule));
createUserVector(Body, IVS, isl_id_copy(IteratorID), Schedule.copy());
break;
case isl_ast_node_block: {
isl_ast_node_list *List = isl_ast_node_block_get_children(Body);
for (int i = 0; i < isl_ast_node_list_n_ast_node(List); ++i)
createUserVector(isl_ast_node_list_get_ast_node(List, i), IVS,
isl_id_copy(IteratorID), isl_union_map_copy(Schedule));
isl_id_copy(IteratorID), Schedule.copy());
isl_ast_node_free(Body);
isl_ast_node_list_free(List);
@ -504,7 +503,6 @@ void IslNodeBuilder::createForVector(__isl_take isl_ast_node *For,
IDToValue.erase(IDToValue.find(IteratorID));
isl_id_free(IteratorID);
isl_union_map_free(Schedule);
isl_ast_node_free(For);
isl_ast_expr_free(Iterator);
@ -685,7 +683,7 @@ void IslNodeBuilder::createForParallel(__isl_take isl_ast_node *For) {
SetVector<Value *> SubtreeValues;
SetVector<const Loop *> Loops;
getReferencesInSubtree(For, SubtreeValues, Loops);
getReferencesInSubtree(isl::manage_copy(For), SubtreeValues, Loops);
// Create for all loops we depend on values that contain the current loop
// iteration. These values are necessary to generate code for SCEVs that
@ -783,7 +781,7 @@ void IslNodeBuilder::createFor(__isl_take isl_ast_node *For) {
bool Vector = PollyVectorizerChoice == VECTORIZER_POLLY;
if (Vector && IslAstInfo::isInnermostParallel(isl::manage_copy(For)) &&
!IslAstInfo::isReductionParallel(For)) {
!IslAstInfo::isReductionParallel(isl::manage_copy(For))) {
int VectorWidth = getNumberOfIterations(isl::manage_copy(For));
if (1 < VectorWidth && VectorWidth <= 16 && !hasPartialAccesses(For)) {
createForVector(For, VectorWidth);
@ -791,12 +789,12 @@ void IslNodeBuilder::createFor(__isl_take isl_ast_node *For) {
}
}
if (IslAstInfo::isExecutedInParallel(For)) {
if (IslAstInfo::isExecutedInParallel(isl::manage_copy(For))) {
createForParallel(For);
return;
}
bool Parallel =
(IslAstInfo::isParallel(For) && !IslAstInfo::isReductionParallel(For));
bool Parallel = (IslAstInfo::isParallel(isl::manage_copy(For)) &&
!IslAstInfo::isReductionParallel(isl::manage_copy(For)));
createForSequential(isl::manage(For), Parallel);
}