mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-10-08 20:04:02 +00:00
Perform copying to created arrays according to the packing transformation
This is the fourth patch to apply the BLIS matmul optimization pattern on matmul kernels (http://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf). BLIS implements gemm as three nested loops around a macro-kernel, plus two packing routines. The macro-kernel is implemented in terms of two additional loops around a micro-kernel. The micro-kernel is a loop around a rank-1 (i.e., outer product) update. In this change we perform copying to created arrays, which is the last step to implement the packing transformation. Reviewed-by: Tobias Grosser <tobias@grosser.es> Differential Revision: https://reviews.llvm.org/D23260 llvm-svn: 281441
This commit is contained in:
parent
79e00930e2
commit
b3224adfb6
@ -166,6 +166,17 @@ public:
|
||||
/// was enabled.
|
||||
llvm::Value *getOverflowState() const;
|
||||
|
||||
/// Create LLVM-IR that computes the memory location of an access expression.
|
||||
///
|
||||
/// For a given isl_ast_expr[ession] of type isl_ast_op_access this function
|
||||
/// creates IR that computes the address the access expression refers to.
|
||||
///
|
||||
/// @param Expr The ast expression of type isl_ast_op_access
|
||||
/// for which we generate LLVM-IR.
|
||||
///
|
||||
/// @return The llvm::Value* containing the result of the computation.
|
||||
llvm::Value *createAccessAddress(__isl_take isl_ast_expr *Expr);
|
||||
|
||||
private:
|
||||
Scop &S;
|
||||
|
||||
@ -203,7 +214,6 @@ private:
|
||||
llvm::Value *createId(__isl_take isl_ast_expr *Expr);
|
||||
llvm::Value *createInt(__isl_take isl_ast_expr *Expr);
|
||||
llvm::Value *createOpAddressOf(__isl_take isl_ast_expr *Expr);
|
||||
llvm::Value *createAccessAddress(__isl_take isl_ast_expr *Expr);
|
||||
|
||||
/// Create a binary operation @p Opc and track overflows if requested.
|
||||
///
|
||||
|
@ -375,6 +375,21 @@ protected:
|
||||
///
|
||||
virtual __isl_give isl_union_map *
|
||||
getScheduleForAstNode(__isl_take isl_ast_node *Node);
|
||||
|
||||
private:
|
||||
/// Create code for a copy statement.
|
||||
///
|
||||
/// A copy statement is expected to have one read memory access and one write
|
||||
/// memory access (stored write access first, read access second). Data is loaded from the location
|
||||
/// described by the read memory access and written to the location described
|
||||
/// by the write memory access. @p NewAccesses contains for each access
|
||||
/// the isl ast expression that describes the location accessed.
|
||||
///
|
||||
/// @param Stmt The copy statement that contains the accesses.
|
||||
/// @param NewAccesses The hash table that contains remappings from memory
|
||||
/// ids to new access expressions.
|
||||
void generateCopyStmt(ScopStmt *Stmt,
|
||||
__isl_keep isl_id_to_ast_expr *NewAccesses);
|
||||
};
|
||||
|
||||
#endif
|
||||
|
@ -88,7 +88,7 @@ public:
|
||||
///
|
||||
/// @return True, if we believe @p NewSchedule is an improvement for @p S.
|
||||
static bool isProfitableSchedule(polly::Scop &S,
|
||||
__isl_keep isl_union_map *NewSchedule);
|
||||
__isl_keep isl_schedule *NewSchedule);
|
||||
|
||||
/// Isolate a set of partial tile prefixes.
|
||||
///
|
||||
|
@ -689,6 +689,19 @@ public:
|
||||
ArrayRef<const SCEV *> Subscripts, ArrayRef<const SCEV *> Sizes,
|
||||
Value *AccessValue, ScopArrayInfo::MemoryKind Kind,
|
||||
StringRef BaseName);
|
||||
|
||||
/// Create a new MemoryAccess that corresponds to @p AccRel.
|
||||
///
|
||||
/// Along with @p Stmt and @p AccType it uses information about dimension
|
||||
/// lengths of the accessed array, the type of the accessed array elements,
|
||||
/// the name of the accessed array that is derived from the object accessible
|
||||
/// via @p AccRel.
|
||||
///
|
||||
/// @param Stmt The parent statement.
|
||||
/// @param AccType Whether read or write access.
|
||||
/// @param AccRel The access relation that describes the memory access.
|
||||
MemoryAccess(ScopStmt *Stmt, AccessType AccType, __isl_take isl_map *AccRel);
|
||||
|
||||
~MemoryAccess();
|
||||
|
||||
/// Add a new incoming block/value pairs for this PHI/ExitPHI access.
|
||||
@ -1083,6 +1096,16 @@ public:
|
||||
/// Create an overapproximating ScopStmt for the region @p R.
|
||||
ScopStmt(Scop &parent, Region &R);
|
||||
|
||||
/// Create a copy statement.
|
||||
///
|
||||
/// @param parent The SCoP the new copy statement belongs to.
|
||||
/// @param SourceRel The source location.
|
||||
/// @param TargetRel The target location.
|
||||
/// @param Domain The original domain under which the copy statement would
|
||||
/// be executed.
|
||||
ScopStmt(Scop &parent, __isl_take isl_map *SourceRel,
|
||||
__isl_take isl_map *TargetRel, __isl_take isl_set *Domain);
|
||||
|
||||
/// Initialize members after all MemoryAccesses have been added.
|
||||
void init(LoopInfo &LI);
|
||||
|
||||
@ -1217,10 +1240,14 @@ public:
|
||||
|
||||
/// Get the schedule function of this ScopStmt.
|
||||
///
|
||||
/// @return The schedule function of this ScopStmt.
|
||||
/// @return The schedule function of this ScopStmt, if it does not contain
|
||||
/// extension nodes, and nullptr, otherwise.
|
||||
__isl_give isl_map *getSchedule() const;
|
||||
|
||||
/// Get an isl string representing this schedule.
|
||||
///
|
||||
/// @return An isl string representing this schedule, if it does not contain
|
||||
/// extension nodes, and an empty string, otherwise.
|
||||
std::string getScheduleStr() const;
|
||||
|
||||
/// Get the invalid domain for this statement.
|
||||
@ -1245,6 +1272,9 @@ public:
|
||||
/// Return true if this statement represents a single basic block.
|
||||
bool isBlockStmt() const { return BB != nullptr; }
|
||||
|
||||
/// Return true if this is a copy statement.
|
||||
bool isCopyStmt() const { return BB == nullptr && R == nullptr; }
|
||||
|
||||
/// Get the region represented by this ScopStmt (if any).
|
||||
///
|
||||
/// @return The region represented by this ScopStmt, or null if the statement
|
||||
@ -1448,6 +1478,9 @@ private:
|
||||
/// Max loop depth.
|
||||
unsigned MaxLoopDepth;
|
||||
|
||||
/// Number of copy statements.
|
||||
unsigned CopyStmtsNum;
|
||||
|
||||
typedef std::list<ScopStmt> StmtSet;
|
||||
/// The statements in this Scop.
|
||||
StmtSet Stmts;
|
||||
@ -1615,11 +1648,6 @@ private:
|
||||
Scop(Region &R, ScalarEvolution &SE, LoopInfo &LI,
|
||||
ScopDetection::DetectionContext &DC);
|
||||
|
||||
/// Add the access function to all MemoryAccess objects of the Scop
|
||||
/// created in this pass.
|
||||
void addAccessFunction(MemoryAccess *Access) {
|
||||
AccessFunctions.emplace_back(Access);
|
||||
}
|
||||
//@}
|
||||
|
||||
/// Initialize this ScopBuilder.
|
||||
@ -1927,6 +1955,30 @@ private:
|
||||
public:
|
||||
~Scop();
|
||||
|
||||
/// Get the count of copy statements added to this Scop.
|
||||
///
|
||||
/// @return The count of copy statements added to this Scop.
|
||||
unsigned getCopyStmtsNum() { return CopyStmtsNum; }
|
||||
|
||||
/// Create a new copy statement.
|
||||
///
|
||||
/// A new statement will be created and added to the statement vector.
|
||||
///
|
||||
/// @param Stmt The parent statement.
|
||||
/// @param SourceRel The source location.
|
||||
/// @param TargetRel The target location.
|
||||
/// @param Domain The original domain under which the copy statement would
|
||||
/// be executed.
|
||||
ScopStmt *addScopStmt(__isl_take isl_map *SourceRel,
|
||||
__isl_take isl_map *TargetRel,
|
||||
__isl_take isl_set *Domain);
|
||||
|
||||
/// Record @p Access in this Scop by appending it to the AccessFunctions
|
||||
/// container. NOTE(review): presumably the container owns the access - confirm.
|
||||
void addAccessFunction(MemoryAccess *Access) {
|
||||
AccessFunctions.emplace_back(Access);
|
||||
}
|
||||
|
||||
ScalarEvolution *getSE() const;
|
||||
|
||||
/// Get the count of parameters used in this Scop.
|
||||
@ -2349,6 +2401,9 @@ public:
|
||||
__isl_give isl_union_map *getAccesses();
|
||||
|
||||
/// Get the schedule of all the statements in the SCoP.
|
||||
///
|
||||
/// @return The schedule of all the statements in the SCoP, if the schedule of
|
||||
/// the Scop does not contain extension nodes, and nullptr, otherwise.
|
||||
__isl_give isl_union_map *getSchedule() const;
|
||||
|
||||
/// Get a schedule tree describing the schedule of all statements.
|
||||
@ -2380,6 +2435,11 @@ public:
|
||||
/// Find the ScopArrayInfo associated with an isl Id
|
||||
/// that has name @p Name.
|
||||
ScopArrayInfo *getArrayInfoByName(const std::string BaseName);
|
||||
|
||||
/// Check whether @p Schedule contains extension nodes.
|
||||
///
|
||||
/// @return true if @p Schedule contains extension nodes.
|
||||
static bool containsExtensionNode(__isl_keep isl_schedule *Schedule);
|
||||
};
|
||||
|
||||
/// Print Scop scop to raw_ostream O.
|
||||
|
@ -153,6 +153,8 @@ static void collectInfo(Scop &S, isl_union_map **Read, isl_union_map **Write,
|
||||
// to match the new access domains, thus we need
|
||||
// [Stmt[i0, i1] -> MemAcc_A[i0 + i1]] -> [0, i0, 2, i1, 0]
|
||||
isl_map *Schedule = Stmt.getSchedule();
|
||||
assert(Schedule && "Schedules that contain extension nodes require "
|
||||
"special handling.");
|
||||
Schedule = isl_map_apply_domain(
|
||||
Schedule,
|
||||
isl_map_reverse(isl_map_domain_map(isl_map_copy(accdom))));
|
||||
@ -162,7 +164,10 @@ static void collectInfo(Scop &S, isl_union_map **Read, isl_union_map **Write,
|
||||
} else {
|
||||
accdom = tag(accdom, MA, Level);
|
||||
if (Level > Dependences::AL_Statement) {
|
||||
isl_map *Schedule = tag(Stmt.getSchedule(), MA, Level);
|
||||
auto *StmtScheduleMap = Stmt.getSchedule();
|
||||
assert(StmtScheduleMap && "Schedules that contain extension nodes "
|
||||
"require special handling.");
|
||||
isl_map *Schedule = tag(StmtScheduleMap, MA, Level);
|
||||
*StmtSchedule = isl_union_map_add_map(*StmtSchedule, Schedule);
|
||||
}
|
||||
}
|
||||
@ -610,6 +615,8 @@ bool Dependences::isValidSchedule(Scop &S,
|
||||
StmtScat = Stmt.getSchedule();
|
||||
else
|
||||
StmtScat = isl_map_copy((*NewSchedule)[&Stmt]);
|
||||
assert(StmtScat &&
|
||||
"Schedules that contain extension nodes require special handling.");
|
||||
|
||||
if (!ScheduleSpace)
|
||||
ScheduleSpace = isl_space_range(isl_map_get_space(StmtScat));
|
||||
|
@ -134,6 +134,8 @@ __isl_give isl_union_map *PolyhedralInfo::getScheduleForLoop(const Scop *S,
|
||||
unsigned int MaxDim = SS->getNumIterators();
|
||||
DEBUG(dbgs() << "Maximum depth of Stmt:\t" << MaxDim << "\n");
|
||||
auto *ScheduleMap = SS->getSchedule();
|
||||
assert(ScheduleMap &&
|
||||
"Schedules that contain extension nodes require special handling.");
|
||||
|
||||
ScheduleMap = isl_map_project_out(ScheduleMap, isl_dim_out, CurrDim + 1,
|
||||
MaxDim - CurrDim - 1);
|
||||
|
@ -857,6 +857,28 @@ MemoryAccess::MemoryAccess(ScopStmt *Stmt, Instruction *AccessInst,
|
||||
Id = isl_id_alloc(Stmt->getParent()->getIslCtx(), IdName.c_str(), this);
|
||||
}
|
||||
|
||||
/// Build an array-kind memory access for @p Stmt directly from @p AccRel.
///
/// @p AccRel is stored as the new access relation; the original access
/// relation and the access instruction stay null. Dimension sizes, element
/// type, base pointer and base name are taken from the ScopArrayInfo that is
/// attached to the output tuple id of @p AccRel.
MemoryAccess::MemoryAccess(ScopStmt *Stmt, AccessType AccType,
                           __isl_take isl_map *AccRel)
    : Kind(ScopArrayInfo::MemoryKind::MK_Array), AccType(AccType),
      RedType(RT_NONE), Statement(Stmt), InvalidDomain(nullptr),
      AccessInstruction(nullptr), IsAffine(true), AccessRelation(nullptr),
      NewAccessRelation(AccRel) {
  // The accessed array is identified by the id of the relation's range tuple.
  auto *ArrayInfo = ScopArrayInfo::getFromId(
      isl_map_get_tuple_id(NewAccessRelation, isl_dim_out));

  // The outermost dimension has no recorded size; the remaining ones are
  // copied from the array description.
  Sizes.push_back(nullptr);
  const unsigned NumDims = ArrayInfo->getNumberOfDimensions();
  for (unsigned Dim = 1; Dim < NumDims; ++Dim)
    Sizes.push_back(ArrayInfo->getDimensionSize(Dim));

  ElementType = ArrayInfo->getElementType();
  BaseAddr = ArrayInfo->getBasePtr();
  BaseName = ArrayInfo->getName();

  // The id name encodes the access type and the number of accesses the
  // statement already holds.
  static const std::string TypeStrings[] = {"", "_Read", "_Write", "_MayWrite"};
  const std::string AccessTag =
      TypeStrings[AccType] + utostr(Stmt->size()) + "_";
  const std::string IdName =
      getIslCompatibleName(Stmt->getBaseName(), AccessTag, BaseName);
  Id = isl_id_alloc(Stmt->getParent()->getIslCtx(), IdName.c_str(), this);
}
|
||||
|
||||
void MemoryAccess::realignParams() {
|
||||
auto *Ctx = Statement->getParent()->getContext();
|
||||
InvalidDomain = isl_set_gist_params(InvalidDomain, isl_set_copy(Ctx));
|
||||
@ -1040,6 +1062,10 @@ __isl_give isl_map *ScopStmt::getSchedule() const {
|
||||
isl_aff_zero_on_domain(isl_local_space_from_space(getDomainSpace())));
|
||||
}
|
||||
auto *Schedule = getParent()->getSchedule();
|
||||
if (!Schedule) {
|
||||
isl_set_free(Domain);
|
||||
return nullptr;
|
||||
}
|
||||
Schedule = isl_union_map_intersect_domain(
|
||||
Schedule, isl_union_set_from_set(isl_set_copy(Domain)));
|
||||
if (isl_union_map_is_empty(Schedule)) {
|
||||
@ -1430,6 +1456,25 @@ ScopStmt::ScopStmt(Scop &parent, BasicBlock &bb)
|
||||
BaseName = getIslCompatibleName("Stmt_", &bb, "");
|
||||
}
|
||||
|
||||
/// Create a copy statement.
///
/// The statement is named "CopyStmt_" followed by the number of copy
/// statements @p parent already contains. Its id is installed as the tuple id
/// of the domain and as the input tuple id of both relations. Two memory
/// accesses are registered, in this order: a MUST_WRITE access described by
/// @p TargetRel and a READ access described by @p SourceRel.
///
/// @param parent    The SCoP the statement belongs to.
/// @param SourceRel The location the data is read from.
/// @param TargetRel The location the data is written to.
/// @param NewDomain The domain under which the copy statement is executed.
ScopStmt::ScopStmt(Scop &parent, __isl_take isl_map *SourceRel,
                   __isl_take isl_map *TargetRel, __isl_take isl_set *NewDomain)
    : Parent(parent), InvalidDomain(nullptr), Domain(NewDomain), BB(nullptr),
      R(nullptr), Build(nullptr) {
  BaseName = getIslCompatibleName("CopyStmt_", "",
                                  std::to_string(parent.getCopyStmtsNum()));
  auto *Id = isl_id_alloc(getIslCtx(), getBaseName(), this);

  // Every consumer takes ownership of the id it is given. Hand out copies
  // first and give away our own reference last, so that Id is never touched
  // after its ownership has been transferred.
  Domain = isl_set_set_tuple_id(Domain, isl_id_copy(Id));
  TargetRel = isl_map_set_tuple_id(TargetRel, isl_dim_in, isl_id_copy(Id));
  SourceRel = isl_map_set_tuple_id(SourceRel, isl_dim_in, Id);

  // The write access is added first and the read access second.
  auto *Access =
      new MemoryAccess(this, MemoryAccess::AccessType::MUST_WRITE, TargetRel);
  parent.addAccessFunction(Access);
  addAccess(Access);

  Access = new MemoryAccess(this, MemoryAccess::AccessType::READ, SourceRel);
  parent.addAccessFunction(Access);
  addAccess(Access);
}
|
||||
|
||||
void ScopStmt::init(LoopInfo &LI) {
|
||||
assert(!Domain && "init must be called only once");
|
||||
|
||||
@ -1576,6 +1621,8 @@ std::string ScopStmt::getDomainStr() const { return stringFromIslObj(Domain); }
|
||||
|
||||
std::string ScopStmt::getScheduleStr() const {
|
||||
auto *S = getSchedule();
|
||||
if (!S)
|
||||
return "";
|
||||
auto Str = stringFromIslObj(S);
|
||||
isl_map_free(S);
|
||||
return Str;
|
||||
@ -3041,9 +3088,10 @@ Scop::Scop(Region &R, ScalarEvolution &ScalarEvolution, LoopInfo &LI,
|
||||
ScopDetection::DetectionContext &DC)
|
||||
: SE(&ScalarEvolution), R(R), IsOptimized(false),
|
||||
HasSingleExitEdge(R.getExitingBlock()), HasErrorBlock(false),
|
||||
MaxLoopDepth(0), DC(DC), IslCtx(isl_ctx_alloc(), isl_ctx_free),
|
||||
Context(nullptr), Affinator(this, LI), AssumedContext(nullptr),
|
||||
InvalidContext(nullptr), Schedule(nullptr) {
|
||||
MaxLoopDepth(0), CopyStmtsNum(0), DC(DC),
|
||||
IslCtx(isl_ctx_alloc(), isl_ctx_free), Context(nullptr),
|
||||
Affinator(this, LI), AssumedContext(nullptr), InvalidContext(nullptr),
|
||||
Schedule(nullptr) {
|
||||
if (IslOnErrorAbort)
|
||||
isl_options_set_on_error(getIslCtx(), ISL_ON_ERROR_ABORT);
|
||||
buildContext();
|
||||
@ -3922,8 +3970,27 @@ __isl_give isl_union_map *Scop::getAccesses() {
|
||||
return getAccessesOfType([](MemoryAccess &MA) { return true; });
|
||||
}
|
||||
|
||||
// Check whether @p Node is an extension node.
|
||||
//
|
||||
// @return true if @p Node is an extension node.
|
||||
isl_bool isNotExtNode(__isl_keep isl_schedule_node *Node, void *User) {
|
||||
if (isl_schedule_node_get_type(Node) == isl_schedule_node_extension)
|
||||
return isl_bool_error;
|
||||
else
|
||||
return isl_bool_true;
|
||||
}
|
||||
|
||||
bool Scop::containsExtensionNode(__isl_keep isl_schedule *Schedule) {
  // isNotExtNode answers isl_bool_error for an extension node; the walk over
  // the schedule tree then ends with isl_stat_error.
  auto WalkResult = isl_schedule_foreach_schedule_node_top_down(
      Schedule, isNotExtNode, nullptr);
  return WalkResult == isl_stat_error;
}
|
||||
|
||||
__isl_give isl_union_map *Scop::getSchedule() const {
|
||||
auto *Tree = getScheduleTree();
|
||||
if (containsExtensionNode(Tree)) {
|
||||
isl_schedule_free(Tree);
|
||||
return nullptr;
|
||||
}
|
||||
auto *S = isl_schedule_get_map(Tree);
|
||||
isl_schedule_free(Tree);
|
||||
return S;
|
||||
@ -4059,6 +4126,14 @@ void Scop::addScopStmt(BasicBlock *BB, Region *R) {
|
||||
}
|
||||
}
|
||||
|
||||
ScopStmt *Scop::addScopStmt(__isl_take isl_map *SourceRel,
                            __isl_take isl_map *TargetRel,
                            __isl_take isl_set *Domain) {
  // The statement is constructed before the counter is bumped, so its
  // constructor still observes the previous number of copy statements.
  Stmts.emplace_back(*this, SourceRel, TargetRel, Domain);
  ScopStmt &NewCopyStmt = Stmts.back();
  ++CopyStmtsNum;
  return &NewCopyStmt;
}
|
||||
|
||||
void Scop::buildSchedule(LoopInfo &LI) {
|
||||
Loop *L = getLoopSurroundingScop(*this, LI);
|
||||
LoopStackTy LoopStack({LoopStackElementTy(L, nullptr, 0)});
|
||||
|
@ -681,7 +681,9 @@ void BlockGenerator::createExitPHINodeMerges(Scop &S) {
|
||||
|
||||
void BlockGenerator::invalidateScalarEvolution(Scop &S) {
|
||||
for (auto &Stmt : S)
|
||||
if (Stmt.isBlockStmt())
|
||||
if (Stmt.isCopyStmt())
|
||||
continue;
|
||||
else if (Stmt.isBlockStmt())
|
||||
for (auto &Inst : *Stmt.getBasicBlock())
|
||||
SE.forgetValue(&Inst);
|
||||
else if (Stmt.isRegionStmt())
|
||||
|
@ -61,7 +61,8 @@ void ScopAnnotator::buildAliasScopes(Scop &S) {
|
||||
SetVector<Value *> BasePtrs;
|
||||
for (ScopStmt &Stmt : S)
|
||||
for (MemoryAccess *MA : Stmt)
|
||||
BasePtrs.insert(MA->getBaseAddr());
|
||||
if (!Stmt.isCopyStmt())
|
||||
BasePtrs.insert(MA->getBaseAddr());
|
||||
|
||||
std::string AliasScopeStr = "polly.alias.scope.";
|
||||
for (Value *BasePtr : BasePtrs)
|
||||
|
@ -593,8 +593,7 @@ void IslAstInfo::printScop(raw_ostream &OS, Scop &S) const {
|
||||
P = isl_ast_node_print(RootNode, P, Options);
|
||||
AstStr = isl_printer_get_str(P);
|
||||
|
||||
isl_union_map *Schedule =
|
||||
isl_union_map_intersect_domain(S.getSchedule(), S.getDomains());
|
||||
auto *Schedule = S.getScheduleTree();
|
||||
|
||||
DEBUG({
|
||||
dbgs() << S.getContextStr() << "\n";
|
||||
@ -609,7 +608,7 @@ void IslAstInfo::printScop(raw_ostream &OS, Scop &S) const {
|
||||
free(AstStr);
|
||||
|
||||
isl_ast_expr_free(RunCondition);
|
||||
isl_union_map_free(Schedule);
|
||||
isl_schedule_free(Schedule);
|
||||
isl_ast_node_free(RootNode);
|
||||
isl_printer_free(P);
|
||||
}
|
||||
|
@ -767,6 +767,23 @@ void IslNodeBuilder::createSubstitutionsVector(
|
||||
isl_ast_expr_free(Expr);
|
||||
}
|
||||
|
||||
void IslNodeBuilder::generateCopyStmt(
    ScopStmt *Stmt, __isl_keep isl_id_to_ast_expr *NewAccesses) {
  assert(Stmt->size() == 2);

  // The first access of the statement is the write, the second one the read.
  auto WriteIt = Stmt->begin();
  auto ReadIt = WriteIt;
  ++ReadIt;
  auto *Target = *WriteIt;
  auto *Source = *ReadIt;

  assert(Source->isRead() && Target->isMustWrite());
  assert(Source->getElementType() == Target->getElementType() &&
         "Accesses use the same data type");
  assert(Source->isArrayKind() && Target->isArrayKind());

  // Load the value from the location described by the read access ...
  auto *LoadExpr = isl_id_to_ast_expr_get(NewAccesses, Source->getId());
  auto *LoadValue = ExprBuilder.create(LoadExpr);

  // ... and store it to the address described by the write access.
  auto *StoreExpr = isl_id_to_ast_expr_get(NewAccesses, Target->getId());
  auto *StoreAddr = ExprBuilder.createAccessAddress(StoreExpr);
  Builder.CreateStore(LoadValue, StoreAddr);
}
|
||||
|
||||
void IslNodeBuilder::createUser(__isl_take isl_ast_node *User) {
|
||||
LoopToScevMapT LTS;
|
||||
isl_id *Id;
|
||||
@ -781,12 +798,17 @@ void IslNodeBuilder::createUser(__isl_take isl_ast_node *User) {
|
||||
|
||||
Stmt = (ScopStmt *)isl_id_get_user(Id);
|
||||
auto *NewAccesses = createNewAccesses(Stmt, User);
|
||||
createSubstitutions(Expr, Stmt, LTS);
|
||||
if (Stmt->isCopyStmt()) {
|
||||
generateCopyStmt(Stmt, NewAccesses);
|
||||
isl_ast_expr_free(Expr);
|
||||
} else {
|
||||
createSubstitutions(Expr, Stmt, LTS);
|
||||
|
||||
if (Stmt->isBlockStmt())
|
||||
BlockGen.copyStmt(*Stmt, LTS, NewAccesses);
|
||||
else
|
||||
RegionGen.copyStmt(*Stmt, LTS, NewAccesses);
|
||||
if (Stmt->isBlockStmt())
|
||||
BlockGen.copyStmt(*Stmt, LTS, NewAccesses);
|
||||
else
|
||||
RegionGen.copyStmt(*Stmt, LTS, NewAccesses);
|
||||
}
|
||||
|
||||
isl_id_to_ast_expr_free(NewAccesses);
|
||||
isl_ast_node_free(User);
|
||||
|
@ -294,6 +294,8 @@ bool JSONImporter::importSchedule(Scop &S, Json::Value &JScop,
|
||||
int Index = 0;
|
||||
for (ScopStmt &Stmt : S) {
|
||||
Json::Value Schedule = JScop["statements"][Index]["schedule"];
|
||||
assert(!Schedule.asString().empty() &&
|
||||
"Schedules that contain extension nodes require special handling.");
|
||||
isl_map *Map = isl_map_read_from_str(S.getIslCtx(), Schedule.asCString());
|
||||
isl_space *Space = Stmt.getDomainSpace();
|
||||
|
||||
|
@ -92,6 +92,8 @@ char DeadCodeElim::ID = 0;
|
||||
// no point in trying to remove them from the live-out set.
|
||||
__isl_give isl_union_set *DeadCodeElim::getLiveOut(Scop &S) {
|
||||
isl_union_map *Schedule = S.getSchedule();
|
||||
assert(Schedule &&
|
||||
"Schedules that contain extension nodes require special handling.");
|
||||
isl_union_map *WriteIterations = isl_union_map_reverse(S.getMustWrites());
|
||||
isl_union_map *WriteTimes =
|
||||
isl_union_map_apply_range(WriteIterations, isl_union_map_copy(Schedule));
|
||||
|
@ -660,6 +660,76 @@ identifyAccessByAccessRelation(ScopStmt *Stmt,
|
||||
return IdentifiedAccess;
|
||||
}
|
||||
|
||||
/// Add constrains to @Dim dimension of @p ExtMap.
|
||||
///
|
||||
/// If @ExtMap has the following form [O0, O1, O2]->[I1, I2, I3],
|
||||
/// the following constraint will be added
|
||||
/// Bound * OM <= IM <= Bound * (OM + 1) - 1,
|
||||
/// where M is @p Dim and Bound is @p Bound.
|
||||
///
|
||||
/// @param ExtMap The isl map to be modified.
|
||||
/// @param Dim The output dimension to be modfied.
|
||||
/// @param Bound The value that is used to specify the constraint.
|
||||
/// @return The modified isl map
|
||||
__isl_give isl_map *
|
||||
addExtensionMapMatMulDimConstraint(__isl_take isl_map *ExtMap, unsigned Dim,
|
||||
unsigned Bound) {
|
||||
assert(Bound != 0);
|
||||
auto *ExtMapSpace = isl_map_get_space(ExtMap);
|
||||
auto *ConstrSpace = isl_local_space_from_space(ExtMapSpace);
|
||||
auto *Constr =
|
||||
isl_constraint_alloc_inequality(isl_local_space_copy(ConstrSpace));
|
||||
Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_out, Dim, 1);
|
||||
Constr =
|
||||
isl_constraint_set_coefficient_si(Constr, isl_dim_in, Dim, Bound * (-1));
|
||||
ExtMap = isl_map_add_constraint(ExtMap, Constr);
|
||||
Constr = isl_constraint_alloc_inequality(ConstrSpace);
|
||||
Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_out, Dim, -1);
|
||||
Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_in, Dim, Bound);
|
||||
Constr = isl_constraint_set_constant_si(Constr, Bound - 1);
|
||||
return isl_map_add_constraint(ExtMap, Constr);
|
||||
}
|
||||
|
||||
/// Create an access relation that is specific for matrix multiplication
|
||||
/// pattern.
|
||||
///
|
||||
/// Create an access relation of the following form:
|
||||
/// { [O0, O1, O2]->[I1, I2, I3] :
|
||||
/// FirstOutputDimBound * O0 <= I1 <= FirstOutputDimBound * (O0 + 1) - 1
|
||||
/// and SecondOutputDimBound * O1 <= I2 <= SecondOutputDimBound * (O1 + 1) - 1
|
||||
/// and ThirdOutputDimBound * O2 <= I3 <= ThirdOutputDimBound * (O2 + 1) - 1}
|
||||
/// where FirstOutputDimBound is @p FirstOutputDimBound,
|
||||
/// SecondOutputDimBound is @p SecondOutputDimBound,
|
||||
/// ThirdOutputDimBound is @p ThirdOutputDimBound
|
||||
///
|
||||
/// @param Ctx The isl context.
|
||||
/// @param FirstOutputDimBound,
|
||||
/// SecondOutputDimBound,
|
||||
/// ThirdOutputDimBound The parameters of the access relation.
|
||||
/// @return The specified access relation.
|
||||
__isl_give isl_map *getMatMulExt(isl_ctx *Ctx, unsigned FirstOutputDimBound,
|
||||
unsigned SecondOutputDimBound,
|
||||
unsigned ThirdOutputDimBound) {
|
||||
auto *NewRelSpace = isl_space_alloc(Ctx, 0, 3, 3);
|
||||
auto *extensionMap = isl_map_universe(NewRelSpace);
|
||||
if (!FirstOutputDimBound)
|
||||
extensionMap = isl_map_fix_si(extensionMap, isl_dim_out, 0, 0);
|
||||
else
|
||||
extensionMap = addExtensionMapMatMulDimConstraint(extensionMap, 0,
|
||||
FirstOutputDimBound);
|
||||
if (!SecondOutputDimBound)
|
||||
extensionMap = isl_map_fix_si(extensionMap, isl_dim_out, 1, 0);
|
||||
else
|
||||
extensionMap = addExtensionMapMatMulDimConstraint(extensionMap, 1,
|
||||
SecondOutputDimBound);
|
||||
if (!ThirdOutputDimBound)
|
||||
extensionMap = isl_map_fix_si(extensionMap, isl_dim_out, 2, 0);
|
||||
else
|
||||
extensionMap = addExtensionMapMatMulDimConstraint(extensionMap, 2,
|
||||
ThirdOutputDimBound);
|
||||
return extensionMap;
|
||||
}
|
||||
|
||||
/// Create an access relation that is specific to the matrix
|
||||
/// multiplication pattern.
|
||||
///
|
||||
@ -758,6 +828,14 @@ __isl_give isl_map *getMatMulAccRel(__isl_take isl_map *MapOldIndVar,
|
||||
return isl_map_apply_range(MapOldIndVar, AccessRel);
|
||||
}
|
||||
|
||||
/// Graft an extension node before @p Node.
///
/// @p ExtensionMap is wrapped into a union map, turned into an extension
/// schedule node and grafted before @p Node.
///
/// @param Node         The schedule node the extension is grafted before.
/// @param ExtensionMap The map the extension node is created from.
/// @return The node returned by isl_schedule_node_graft_before.
__isl_give isl_schedule_node *
createExtensionNode(__isl_take isl_schedule_node *Node,
                    __isl_take isl_map *ExtensionMap) {
  auto *GraftNode =
      isl_schedule_node_from_extension(isl_union_map_from_map(ExtensionMap));
  return isl_schedule_node_graft_before(Node, GraftNode);
}
|
||||
|
||||
/// Apply the packing transformation.
|
||||
///
|
||||
/// The packing transformation can be described as a data-layout
|
||||
@ -772,9 +850,9 @@ __isl_give isl_map *getMatMulAccRel(__isl_take isl_map *MapOldIndVar,
|
||||
/// @param MicroParams, MacroParams Parameters of the BLIS kernel
|
||||
/// to be taken into account.
|
||||
/// @return The optimized schedule node.
|
||||
static void optimizeDataLayoutMatrMulPattern(__isl_take isl_map *MapOldIndVar,
|
||||
MicroKernelParamsTy MicroParams,
|
||||
MacroKernelParamsTy MacroParams) {
|
||||
static __isl_give isl_schedule_node *optimizeDataLayoutMatrMulPattern(
|
||||
__isl_take isl_schedule_node *Node, __isl_take isl_map *MapOldIndVar,
|
||||
MicroKernelParamsTy MicroParams, MacroKernelParamsTy MacroParams) {
|
||||
auto InputDimsId = isl_map_get_tuple_id(MapOldIndVar, isl_dim_in);
|
||||
auto *Stmt = static_cast<ScopStmt *>(isl_id_get_user(InputDimsId));
|
||||
isl_id_free(InputDimsId);
|
||||
@ -782,8 +860,12 @@ static void optimizeDataLayoutMatrMulPattern(__isl_take isl_map *MapOldIndVar,
|
||||
MemoryAccess *MemAccessB = identifyAccessB(Stmt);
|
||||
if (!MemAccessA || !MemAccessB) {
|
||||
isl_map_free(MapOldIndVar);
|
||||
return;
|
||||
return Node;
|
||||
}
|
||||
Node = isl_schedule_node_parent(isl_schedule_node_parent(Node));
|
||||
Node = isl_schedule_node_parent(isl_schedule_node_parent(Node));
|
||||
Node = isl_schedule_node_parent(Node);
|
||||
Node = isl_schedule_node_child(isl_schedule_node_band_split(Node, 2), 0);
|
||||
auto *AccRel =
|
||||
getMatMulAccRel(isl_map_copy(MapOldIndVar), MacroParams.Kc, 3, 6);
|
||||
unsigned FirstDimSize = MacroParams.Mc * MacroParams.Kc / MicroParams.Mr;
|
||||
@ -791,14 +873,34 @@ static void optimizeDataLayoutMatrMulPattern(__isl_take isl_map *MapOldIndVar,
|
||||
auto *SAI = Stmt->getParent()->createScopArrayInfo(
|
||||
MemAccessA->getElementType(), "Packed_A", {FirstDimSize, SecondDimSize});
|
||||
AccRel = isl_map_set_tuple_id(AccRel, isl_dim_out, SAI->getBasePtrId());
|
||||
auto *OldAcc = MemAccessA->getAccessRelation();
|
||||
MemAccessA->setNewAccessRelation(AccRel);
|
||||
auto *ExtMap =
|
||||
getMatMulExt(Stmt->getIslCtx(), MacroParams.Mc, 0, MacroParams.Kc);
|
||||
ExtMap = isl_map_project_out(ExtMap, isl_dim_in, 1, 1);
|
||||
auto *Domain = Stmt->getDomain();
|
||||
auto *NewStmt = Stmt->getParent()->addScopStmt(
|
||||
OldAcc, MemAccessA->getAccessRelation(), isl_set_copy(Domain));
|
||||
ExtMap = isl_map_set_tuple_id(ExtMap, isl_dim_out, NewStmt->getDomainId());
|
||||
Node = createExtensionNode(Node, ExtMap);
|
||||
Node = isl_schedule_node_child(Node, 0);
|
||||
AccRel = getMatMulAccRel(MapOldIndVar, MacroParams.Kc, 4, 7);
|
||||
FirstDimSize = MacroParams.Nc * MacroParams.Kc / MicroParams.Nr;
|
||||
SecondDimSize = MicroParams.Nr;
|
||||
SAI = Stmt->getParent()->createScopArrayInfo(
|
||||
MemAccessB->getElementType(), "Packed_B", {FirstDimSize, SecondDimSize});
|
||||
AccRel = isl_map_set_tuple_id(AccRel, isl_dim_out, SAI->getBasePtrId());
|
||||
OldAcc = MemAccessB->getAccessRelation();
|
||||
MemAccessB->setNewAccessRelation(AccRel);
|
||||
ExtMap = getMatMulExt(Stmt->getIslCtx(), 0, MacroParams.Nc, MacroParams.Kc);
|
||||
isl_map_move_dims(ExtMap, isl_dim_out, 0, isl_dim_in, 1, 1);
|
||||
isl_map_move_dims(ExtMap, isl_dim_in, 2, isl_dim_out, 0, 1);
|
||||
NewStmt = Stmt->getParent()->addScopStmt(
|
||||
OldAcc, MemAccessB->getAccessRelation(), Domain);
|
||||
ExtMap = isl_map_set_tuple_id(ExtMap, isl_dim_out, NewStmt->getDomainId());
|
||||
Node = createExtensionNode(Node, ExtMap);
|
||||
Node = isl_schedule_node_child(isl_schedule_node_child(Node, 0), 0);
|
||||
return isl_schedule_node_child(isl_schedule_node_child(Node, 0), 0);
|
||||
}
|
||||
|
||||
/// Get a relation mapping induction variables produced by schedule
|
||||
@ -842,9 +944,8 @@ __isl_give isl_schedule_node *ScheduleTreeOptimizer::optimizeMatMulPattern(
|
||||
Node, MicroKernelParams, MacroKernelParams);
|
||||
if (!MapOldIndVar)
|
||||
return Node;
|
||||
optimizeDataLayoutMatrMulPattern(MapOldIndVar, MicroKernelParams,
|
||||
MacroKernelParams);
|
||||
return Node;
|
||||
return optimizeDataLayoutMatrMulPattern(Node, MapOldIndVar, MicroKernelParams,
|
||||
MacroKernelParams);
|
||||
}
|
||||
|
||||
bool ScheduleTreeOptimizer::isMatrMultPattern(
|
||||
@ -901,7 +1002,7 @@ __isl_give isl_schedule_node *ScheduleTreeOptimizer::optimizeScheduleNode(
|
||||
}
|
||||
|
||||
bool ScheduleTreeOptimizer::isProfitableSchedule(
|
||||
Scop &S, __isl_keep isl_union_map *NewSchedule) {
|
||||
Scop &S, __isl_keep isl_schedule *NewSchedule) {
|
||||
// To understand if the schedule has been optimized we check if the schedule
|
||||
// has changed at all.
|
||||
// TODO: We can improve this by tracking if any necessarily beneficial
|
||||
@ -911,9 +1012,15 @@ bool ScheduleTreeOptimizer::isProfitableSchedule(
|
||||
// optimizations, by comparing (yet to be defined) performance metrics
|
||||
// before/after the scheduling optimizer
|
||||
// (e.g., #stride-one accesses)
|
||||
if (S.containsExtensionNode(NewSchedule))
|
||||
return true;
|
||||
auto *NewScheduleMap = isl_schedule_get_map(NewSchedule);
|
||||
isl_union_map *OldSchedule = S.getSchedule();
|
||||
bool changed = !isl_union_map_is_equal(OldSchedule, NewSchedule);
|
||||
assert(OldSchedule && "Only IslScheduleOptimizer can insert extension nodes "
|
||||
"that make Scop::getSchedule() return nullptr.");
|
||||
bool changed = !isl_union_map_is_equal(OldSchedule, NewScheduleMap);
|
||||
isl_union_map_free(OldSchedule);
|
||||
isl_union_map_free(NewScheduleMap);
|
||||
return changed;
|
||||
}
|
||||
|
||||
@ -1090,10 +1197,8 @@ bool IslScheduleOptimizer::runOnScop(Scop &S) {
|
||||
auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
|
||||
isl_schedule *NewSchedule =
|
||||
ScheduleTreeOptimizer::optimizeSchedule(Schedule, TTI);
|
||||
isl_union_map *NewScheduleMap = isl_schedule_get_map(NewSchedule);
|
||||
|
||||
if (!ScheduleTreeOptimizer::isProfitableSchedule(S, NewScheduleMap)) {
|
||||
isl_union_map_free(NewScheduleMap);
|
||||
if (!ScheduleTreeOptimizer::isProfitableSchedule(S, NewSchedule)) {
|
||||
isl_schedule_free(NewSchedule);
|
||||
return false;
|
||||
}
|
||||
@ -1104,7 +1209,6 @@ bool IslScheduleOptimizer::runOnScop(Scop &S) {
|
||||
if (OptimizedScops)
|
||||
S.dump();
|
||||
|
||||
isl_union_map_free(NewScheduleMap);
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -12,11 +12,34 @@
|
||||
; CHECK: double Packed_A[ { [] -> [(1024)] } ][ { [] -> [(4)] } ]; // Element size 8
|
||||
; CHECK: double Packed_B[ { [] -> [(3072)] } ][ { [] -> [(8)] } ]; // Element size 8
|
||||
;
|
||||
; CHECK: { Stmt_bb14[i0, i1, i2] -> MemRef_arg6[i0, i2] };
|
||||
; CHECK: new: { Stmt_bb14[i0, i1, i2] -> Packed_A[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 4*floor((-i0 + o1)/4) = -i0 + o1 and 0 <= o1 <= 3 and -3 + i0 - 16*floor((i0)/16) <= 4*floor((o0)/256) <= i0 - 16*floor((i0)/16) };
|
||||
; CHECK: { Stmt_Copy_0[i0, i1, i2] -> MemRef_arg6[i0, i2] };
|
||||
; CHECK: new: { Stmt_Copy_0[i0, i1, i2] -> Packed_A[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 4*floor((-i0 + o1)/4) = -i0 + o1 and 0 <= o1 <= 3 and -3 + i0 - 16*floor((i0)/16) <= 4*floor((o0)/256) <= i0 - 16*floor((i0)/16) };
|
||||
;
|
||||
; CHECK: { Stmt_bb14[i0, i1, i2] -> MemRef_arg7[i2, i1] };
|
||||
; CHECK: new: { Stmt_bb14[i0, i1, i2] -> Packed_B[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 8*floor((-i1 + o1)/8) = -i1 + o1 and 0 <= o1 <= 7 and -7 + i1 - 96*floor((i1)/96) <= 8*floor((o0)/256) <= i1 - 96*floor((i1)/96) };
|
||||
; CHECK: { Stmt_Copy_0[i0, i1, i2] -> MemRef_arg7[i2, i1] };
|
||||
; CHECK: new: { Stmt_Copy_0[i0, i1, i2] -> Packed_B[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 8*floor((-i1 + o1)/8) = -i1 + o1 and 0 <= o1 <= 7 and -7 + i1 - 96*floor((i1)/96) <= 8*floor((o0)/256) <= i1 - 96*floor((i1)/96) };
|
||||
;
|
||||
; CHECK: CopyStmt_0
|
||||
; CHECK: Domain :=
|
||||
; CHECK: { CopyStmt_0[i0, i1, i2] : 0 <= i0 <= 1055 and 0 <= i1 <= 1055 and 0 <= i2 <= 1023 };
|
||||
; CHECK: Schedule :=
|
||||
; CHECK: ;
|
||||
; CHECK: MustWriteAccess := [Reduction Type: NONE] [Scalar: 0]
|
||||
; CHECK: null;
|
||||
; CHECK: new: { CopyStmt_0[i0, i1, i2] -> Packed_A[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 4*floor((-i0 + o1)/4) = -i0 + o1 and 0 <= o1 <= 3 and -3 + i0 - 16*floor((i0)/16) <= 4*floor((o0)/256) <= i0 - 16*floor((i0)/16) };
|
||||
; CHECK: ReadAccess := [Reduction Type: NONE] [Scalar: 0]
|
||||
; CHECK: null;
|
||||
; CHECK: new: { CopyStmt_0[i0, i1, i2] -> MemRef_arg6[i0, i2] };
|
||||
; CHECK: CopyStmt_1
|
||||
; CHECK: Domain :=
|
||||
; CHECK: { CopyStmt_1[i0, i1, i2] : 0 <= i0 <= 1055 and 0 <= i1 <= 1055 and 0 <= i2 <= 1023 };
|
||||
; CHECK: Schedule :=
|
||||
; CHECK: ;
|
||||
; CHECK: MustWriteAccess := [Reduction Type: NONE] [Scalar: 0]
|
||||
; CHECK: null;
|
||||
; CHECK: new: { CopyStmt_1[i0, i1, i2] -> Packed_B[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 8*floor((-i1 + o1)/8) = -i1 + o1 and 0 <= o1 <= 7 and -7 + i1 - 96*floor((i1)/96) <= 8*floor((o0)/256) <= i1 - 96*floor((i1)/96) };
|
||||
; CHECK: ReadAccess := [Reduction Type: NONE] [Scalar: 0]
|
||||
; CHECK: null;
|
||||
; CHECK: new: { CopyStmt_1[i0, i1, i2] -> MemRef_arg7[i2, i1] };
|
||||
;
|
||||
target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
|
||||
target triple = "x86_64-unknown-unknown"
|
||||
@ -35,10 +58,10 @@ bb9: ; preds = %bb26, %bb8
|
||||
%tmp12 = load double, double* %tmp11, align 8
|
||||
%tmp13 = fmul double %tmp12, %arg4
|
||||
store double %tmp13, double* %tmp11, align 8
|
||||
br label %bb14
|
||||
br label %Copy_0
|
||||
|
||||
bb14: ; preds = %bb14, %bb9
|
||||
%tmp15 = phi i64 [ 0, %bb9 ], [ %tmp24, %bb14 ]
|
||||
Copy_0: ; preds = %Copy_0, %bb9
|
||||
%tmp15 = phi i64 [ 0, %bb9 ], [ %tmp24, %Copy_0 ]
|
||||
%tmp16 = getelementptr inbounds [1024 x double], [1024 x double]* %arg6, i64 %tmp, i64 %tmp15
|
||||
%tmp17 = load double, double* %tmp16, align 8
|
||||
%tmp18 = fmul double %tmp17, %arg3
|
||||
@ -50,9 +73,9 @@ bb14: ; preds = %bb14, %bb9
|
||||
store double %tmp23, double* %tmp11, align 8
|
||||
%tmp24 = add nuw nsw i64 %tmp15, 1
|
||||
%tmp25 = icmp ne i64 %tmp24, 1024
|
||||
br i1 %tmp25, label %bb14, label %bb26
|
||||
br i1 %tmp25, label %Copy_0, label %bb26
|
||||
|
||||
bb26: ; preds = %bb14
|
||||
bb26: ; preds = %Copy_0
|
||||
%tmp27 = add nuw nsw i64 %tmp10, 1
|
||||
%tmp28 = icmp ne i64 %tmp27, 1056
|
||||
br i1 %tmp28, label %bb9, label %bb29
|
||||
|
Loading…
Reference in New Issue
Block a user