Perform copying to created arrays according to the packing transformation

This is the fourth patch to apply the BLIS matmul optimization pattern on matmul
kernels (http://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf).
BLIS implements gemm as three nested loops around a macro-kernel, plus two
packing routines. The macro-kernel is implemented in terms of two additional
loops around a micro-kernel. The micro-kernel is a loop around a rank-1
(i.e., outer product) update. In this change we perform copying to created
arrays, which is the last step to implement the packing transformation.

Reviewed-by: Tobias Grosser <tobias@grosser.es>

Differential Revision: https://reviews.llvm.org/D23260

llvm-svn: 281441
This commit is contained in:
Roman Gareev 2016-09-14 06:26:09 +00:00
parent 79e00930e2
commit b3224adfb6
15 changed files with 368 additions and 44 deletions

View File

@ -166,6 +166,17 @@ public:
/// was enabled.
llvm::Value *getOverflowState() const;
/// Create LLVM-IR that computes the memory location of an access expression.
///
/// For a given isl_ast_expr[ession] of type isl_ast_op_access this function
/// creates IR that computes the address the access expression refers to.
///
/// @param Expr The ast expression of type isl_ast_op_access
/// for which we generate LLVM-IR.
///
/// @return The llvm::Value* containing the result of the computation.
llvm::Value *createAccessAddress(__isl_take isl_ast_expr *Expr);
private:
Scop &S;
@ -203,7 +214,6 @@ private:
llvm::Value *createId(__isl_take isl_ast_expr *Expr);
llvm::Value *createInt(__isl_take isl_ast_expr *Expr);
llvm::Value *createOpAddressOf(__isl_take isl_ast_expr *Expr);
llvm::Value *createAccessAddress(__isl_take isl_ast_expr *Expr);
/// Create a binary operation @p Opc and track overflows if requested.
///

View File

@ -375,6 +375,21 @@ protected:
///
virtual __isl_give isl_union_map *
getScheduleForAstNode(__isl_take isl_ast_node *Node);
private:
/// Create code for a copy statement.
///
/// A copy statement is expected to have one read memory access and one write
/// memory access (in this very order). Data is loaded from the location
/// described by the read memory access and written to the location described
/// by the write memory access. @p NewAccesses contains for each access
/// the isl ast expression that describes the location accessed.
///
/// @param Stmt The copy statement that contains the accesses.
/// @param NewAccesses The hash table that contains remappings from memory
/// ids to new access expressions.
void generateCopyStmt(ScopStmt *Stmt,
__isl_keep isl_id_to_ast_expr *NewAccesses);
};
#endif

View File

@ -88,7 +88,7 @@ public:
///
/// @return True, if we believe @p NewSchedule is an improvement for @p S.
static bool isProfitableSchedule(polly::Scop &S,
__isl_keep isl_union_map *NewSchedule);
__isl_keep isl_schedule *NewSchedule);
/// Isolate a set of partial tile prefixes.
///

View File

@ -689,6 +689,19 @@ public:
ArrayRef<const SCEV *> Subscripts, ArrayRef<const SCEV *> Sizes,
Value *AccessValue, ScopArrayInfo::MemoryKind Kind,
StringRef BaseName);
/// Create a new MemoryAccess that corresponds to @p AccRel.
///
/// Along with @p Stmt and @p AccType it uses information about dimension
/// lengths of the accessed array, the type of the accessed array elements,
/// the name of the accessed array that is derived from the object accessible
/// via @p AccRel.
///
/// @param Stmt The parent statement.
/// @param AccType Whether read or write access.
/// @param AccRel The access relation that describes the memory access.
MemoryAccess(ScopStmt *Stmt, AccessType AccType, __isl_take isl_map *AccRel);
~MemoryAccess();
/// Add a new incoming block/value pairs for this PHI/ExitPHI access.
@ -1083,6 +1096,16 @@ public:
/// Create an overapproximating ScopStmt for the region @p R.
ScopStmt(Scop &parent, Region &R);
/// Create a copy statement.
///
/// @param Stmt The parent statement.
/// @param SourceRel The source location.
/// @param TargetRel The target location.
/// @param Domain The original domain under which copy statement whould
/// be executed.
ScopStmt(Scop &parent, __isl_take isl_map *SourceRel,
__isl_take isl_map *TargetRel, __isl_take isl_set *Domain);
/// Initialize members after all MemoryAccesses have been added.
void init(LoopInfo &LI);
@ -1217,10 +1240,14 @@ public:
/// Get the schedule function of this ScopStmt.
///
/// @return The schedule function of this ScopStmt.
/// @return The schedule function of this ScopStmt, if it does not contain
/// extension nodes, and nullptr, otherwise.
__isl_give isl_map *getSchedule() const;
/// Get an isl string representing this schedule.
///
/// @return An isl string representing this schedule, if it does not contain
/// extension nodes, and an empty string, otherwise.
std::string getScheduleStr() const;
/// Get the invalid domain for this statement.
@ -1245,6 +1272,9 @@ public:
/// Return true if this statement represents a single basic block.
bool isBlockStmt() const { return BB != nullptr; }
/// Return true if this is a copy statement.
bool isCopyStmt() const { return BB == nullptr && R == nullptr; }
/// Get the region represented by this ScopStmt (if any).
///
/// @return The region represented by this ScopStmt, or null if the statement
@ -1448,6 +1478,9 @@ private:
/// Max loop depth.
unsigned MaxLoopDepth;
/// Number of copy statements.
unsigned CopyStmtsNum;
typedef std::list<ScopStmt> StmtSet;
/// The statements in this Scop.
StmtSet Stmts;
@ -1615,11 +1648,6 @@ private:
Scop(Region &R, ScalarEvolution &SE, LoopInfo &LI,
ScopDetection::DetectionContext &DC);
/// Add the access function to all MemoryAccess objects of the Scop
/// created in this pass.
void addAccessFunction(MemoryAccess *Access) {
AccessFunctions.emplace_back(Access);
}
//@}
/// Initialize this ScopBuilder.
@ -1927,6 +1955,30 @@ private:
public:
~Scop();
/// Get the count of copy statements added to this Scop.
///
/// @return The count of copy statements added to this Scop.
unsigned getCopyStmtsNum() { return CopyStmtsNum; }
/// Create a new copy statement.
///
/// A new statement will be created and added to the statement vector.
///
/// @param Stmt The parent statement.
/// @param SourceRel The source location.
/// @param TargetRel The target location.
/// @param Domain The original domain under which copy statement whould
/// be executed.
ScopStmt *addScopStmt(__isl_take isl_map *SourceRel,
__isl_take isl_map *TargetRel,
__isl_take isl_set *Domain);
/// Add the access function to all MemoryAccess objects of the Scop
/// created in this pass.
void addAccessFunction(MemoryAccess *Access) {
AccessFunctions.emplace_back(Access);
}
ScalarEvolution *getSE() const;
/// Get the count of parameters used in this Scop.
@ -2349,6 +2401,9 @@ public:
__isl_give isl_union_map *getAccesses();
/// Get the schedule of all the statements in the SCoP.
///
/// @return The schedule of all the statements in the SCoP, if the schedule of
/// the Scop does not contain extension nodes, and nullptr, otherwise.
__isl_give isl_union_map *getSchedule() const;
/// Get a schedule tree describing the schedule of all statements.
@ -2380,6 +2435,11 @@ public:
/// Find the ScopArrayInfo associated with an isl Id
/// that has name @p Name.
ScopArrayInfo *getArrayInfoByName(const std::string BaseName);
/// Check whether @p Schedule contains extension nodes.
///
/// @return true if @p Schedule contains extension nodes.
static bool containsExtensionNode(__isl_keep isl_schedule *Schedule);
};
/// Print Scop scop to raw_ostream O.

View File

@ -153,6 +153,8 @@ static void collectInfo(Scop &S, isl_union_map **Read, isl_union_map **Write,
// to match the new access domains, thus we need
// [Stmt[i0, i1] -> MemAcc_A[i0 + i1]] -> [0, i0, 2, i1, 0]
isl_map *Schedule = Stmt.getSchedule();
assert(Schedule && "Schedules that contain extension nodes require "
"special handling.");
Schedule = isl_map_apply_domain(
Schedule,
isl_map_reverse(isl_map_domain_map(isl_map_copy(accdom))));
@ -162,7 +164,10 @@ static void collectInfo(Scop &S, isl_union_map **Read, isl_union_map **Write,
} else {
accdom = tag(accdom, MA, Level);
if (Level > Dependences::AL_Statement) {
isl_map *Schedule = tag(Stmt.getSchedule(), MA, Level);
auto *StmtScheduleMap = Stmt.getSchedule();
assert(StmtScheduleMap && "Schedules that contain extension nodes "
"require special handling.");
isl_map *Schedule = tag(StmtScheduleMap, MA, Level);
*StmtSchedule = isl_union_map_add_map(*StmtSchedule, Schedule);
}
}
@ -610,6 +615,8 @@ bool Dependences::isValidSchedule(Scop &S,
StmtScat = Stmt.getSchedule();
else
StmtScat = isl_map_copy((*NewSchedule)[&Stmt]);
assert(StmtScat &&
"Schedules that contain extension nodes require special handling.");
if (!ScheduleSpace)
ScheduleSpace = isl_space_range(isl_map_get_space(StmtScat));

View File

@ -134,6 +134,8 @@ __isl_give isl_union_map *PolyhedralInfo::getScheduleForLoop(const Scop *S,
unsigned int MaxDim = SS->getNumIterators();
DEBUG(dbgs() << "Maximum depth of Stmt:\t" << MaxDim << "\n");
auto *ScheduleMap = SS->getSchedule();
assert(ScheduleMap &&
"Schedules that contain extension nodes require special handling.");
ScheduleMap = isl_map_project_out(ScheduleMap, isl_dim_out, CurrDim + 1,
MaxDim - CurrDim - 1);

View File

@ -857,6 +857,28 @@ MemoryAccess::MemoryAccess(ScopStmt *Stmt, Instruction *AccessInst,
Id = isl_id_alloc(Stmt->getParent()->getIslCtx(), IdName.c_str(), this);
}
MemoryAccess::MemoryAccess(ScopStmt *Stmt, AccessType AccType,
__isl_take isl_map *AccRel)
: Kind(ScopArrayInfo::MemoryKind::MK_Array), AccType(AccType),
RedType(RT_NONE), Statement(Stmt), InvalidDomain(nullptr),
AccessInstruction(nullptr), IsAffine(true), AccessRelation(nullptr),
NewAccessRelation(AccRel) {
auto *ArrayInfoId = isl_map_get_tuple_id(NewAccessRelation, isl_dim_out);
auto *SAI = ScopArrayInfo::getFromId(ArrayInfoId);
Sizes.push_back(nullptr);
for (unsigned i = 1; i < SAI->getNumberOfDimensions(); i++)
Sizes.push_back(SAI->getDimensionSize(i));
ElementType = SAI->getElementType();
BaseAddr = SAI->getBasePtr();
BaseName = SAI->getName();
static const std::string TypeStrings[] = {"", "_Read", "_Write", "_MayWrite"};
const std::string Access = TypeStrings[AccType] + utostr(Stmt->size()) + "_";
std::string IdName =
getIslCompatibleName(Stmt->getBaseName(), Access, BaseName);
Id = isl_id_alloc(Stmt->getParent()->getIslCtx(), IdName.c_str(), this);
}
void MemoryAccess::realignParams() {
auto *Ctx = Statement->getParent()->getContext();
InvalidDomain = isl_set_gist_params(InvalidDomain, isl_set_copy(Ctx));
@ -1040,6 +1062,10 @@ __isl_give isl_map *ScopStmt::getSchedule() const {
isl_aff_zero_on_domain(isl_local_space_from_space(getDomainSpace())));
}
auto *Schedule = getParent()->getSchedule();
if (!Schedule) {
isl_set_free(Domain);
return nullptr;
}
Schedule = isl_union_map_intersect_domain(
Schedule, isl_union_set_from_set(isl_set_copy(Domain)));
if (isl_union_map_is_empty(Schedule)) {
@ -1430,6 +1456,25 @@ ScopStmt::ScopStmt(Scop &parent, BasicBlock &bb)
BaseName = getIslCompatibleName("Stmt_", &bb, "");
}
ScopStmt::ScopStmt(Scop &parent, __isl_take isl_map *SourceRel,
__isl_take isl_map *TargetRel, __isl_take isl_set *NewDomain)
: Parent(parent), InvalidDomain(nullptr), Domain(NewDomain), BB(nullptr),
R(nullptr), Build(nullptr) {
BaseName = getIslCompatibleName("CopyStmt_", "",
std::to_string(parent.getCopyStmtsNum()));
auto *Id = isl_id_alloc(getIslCtx(), getBaseName(), this);
Domain = isl_set_set_tuple_id(Domain, isl_id_copy(Id));
TargetRel = isl_map_set_tuple_id(TargetRel, isl_dim_in, Id);
auto *Access =
new MemoryAccess(this, MemoryAccess::AccessType::MUST_WRITE, TargetRel);
parent.addAccessFunction(Access);
addAccess(Access);
SourceRel = isl_map_set_tuple_id(SourceRel, isl_dim_in, isl_id_copy(Id));
Access = new MemoryAccess(this, MemoryAccess::AccessType::READ, SourceRel);
parent.addAccessFunction(Access);
addAccess(Access);
}
void ScopStmt::init(LoopInfo &LI) {
assert(!Domain && "init must be called only once");
@ -1576,6 +1621,8 @@ std::string ScopStmt::getDomainStr() const { return stringFromIslObj(Domain); }
std::string ScopStmt::getScheduleStr() const {
auto *S = getSchedule();
if (!S)
return "";
auto Str = stringFromIslObj(S);
isl_map_free(S);
return Str;
@ -3041,9 +3088,10 @@ Scop::Scop(Region &R, ScalarEvolution &ScalarEvolution, LoopInfo &LI,
ScopDetection::DetectionContext &DC)
: SE(&ScalarEvolution), R(R), IsOptimized(false),
HasSingleExitEdge(R.getExitingBlock()), HasErrorBlock(false),
MaxLoopDepth(0), DC(DC), IslCtx(isl_ctx_alloc(), isl_ctx_free),
Context(nullptr), Affinator(this, LI), AssumedContext(nullptr),
InvalidContext(nullptr), Schedule(nullptr) {
MaxLoopDepth(0), CopyStmtsNum(0), DC(DC),
IslCtx(isl_ctx_alloc(), isl_ctx_free), Context(nullptr),
Affinator(this, LI), AssumedContext(nullptr), InvalidContext(nullptr),
Schedule(nullptr) {
if (IslOnErrorAbort)
isl_options_set_on_error(getIslCtx(), ISL_ON_ERROR_ABORT);
buildContext();
@ -3922,8 +3970,27 @@ __isl_give isl_union_map *Scop::getAccesses() {
return getAccessesOfType([](MemoryAccess &MA) { return true; });
}
// Check whether @p Node is an extension node.
//
// @return true if @p Node is an extension node.
isl_bool isNotExtNode(__isl_keep isl_schedule_node *Node, void *User) {
if (isl_schedule_node_get_type(Node) == isl_schedule_node_extension)
return isl_bool_error;
else
return isl_bool_true;
}
bool Scop::containsExtensionNode(__isl_keep isl_schedule *Schedule) {
return isl_schedule_foreach_schedule_node_top_down(Schedule, isNotExtNode,
nullptr) == isl_stat_error;
}
__isl_give isl_union_map *Scop::getSchedule() const {
auto *Tree = getScheduleTree();
if (containsExtensionNode(Tree)) {
isl_schedule_free(Tree);
return nullptr;
}
auto *S = isl_schedule_get_map(Tree);
isl_schedule_free(Tree);
return S;
@ -4059,6 +4126,14 @@ void Scop::addScopStmt(BasicBlock *BB, Region *R) {
}
}
ScopStmt *Scop::addScopStmt(__isl_take isl_map *SourceRel,
__isl_take isl_map *TargetRel,
__isl_take isl_set *Domain) {
Stmts.emplace_back(*this, SourceRel, TargetRel, Domain);
CopyStmtsNum++;
return &(Stmts.back());
}
void Scop::buildSchedule(LoopInfo &LI) {
Loop *L = getLoopSurroundingScop(*this, LI);
LoopStackTy LoopStack({LoopStackElementTy(L, nullptr, 0)});

View File

@ -681,7 +681,9 @@ void BlockGenerator::createExitPHINodeMerges(Scop &S) {
void BlockGenerator::invalidateScalarEvolution(Scop &S) {
for (auto &Stmt : S)
if (Stmt.isBlockStmt())
if (Stmt.isCopyStmt())
continue;
else if (Stmt.isBlockStmt())
for (auto &Inst : *Stmt.getBasicBlock())
SE.forgetValue(&Inst);
else if (Stmt.isRegionStmt())

View File

@ -61,7 +61,8 @@ void ScopAnnotator::buildAliasScopes(Scop &S) {
SetVector<Value *> BasePtrs;
for (ScopStmt &Stmt : S)
for (MemoryAccess *MA : Stmt)
BasePtrs.insert(MA->getBaseAddr());
if (!Stmt.isCopyStmt())
BasePtrs.insert(MA->getBaseAddr());
std::string AliasScopeStr = "polly.alias.scope.";
for (Value *BasePtr : BasePtrs)

View File

@ -593,8 +593,7 @@ void IslAstInfo::printScop(raw_ostream &OS, Scop &S) const {
P = isl_ast_node_print(RootNode, P, Options);
AstStr = isl_printer_get_str(P);
isl_union_map *Schedule =
isl_union_map_intersect_domain(S.getSchedule(), S.getDomains());
auto *Schedule = S.getScheduleTree();
DEBUG({
dbgs() << S.getContextStr() << "\n";
@ -609,7 +608,7 @@ void IslAstInfo::printScop(raw_ostream &OS, Scop &S) const {
free(AstStr);
isl_ast_expr_free(RunCondition);
isl_union_map_free(Schedule);
isl_schedule_free(Schedule);
isl_ast_node_free(RootNode);
isl_printer_free(P);
}

View File

@ -767,6 +767,23 @@ void IslNodeBuilder::createSubstitutionsVector(
isl_ast_expr_free(Expr);
}
void IslNodeBuilder::generateCopyStmt(
ScopStmt *Stmt, __isl_keep isl_id_to_ast_expr *NewAccesses) {
assert(Stmt->size() == 2);
auto ReadAccess = Stmt->begin();
auto WriteAccess = ReadAccess++;
assert((*ReadAccess)->isRead() && (*WriteAccess)->isMustWrite());
assert((*ReadAccess)->getElementType() == (*WriteAccess)->getElementType() &&
"Accesses use the same data type");
assert((*ReadAccess)->isArrayKind() && (*WriteAccess)->isArrayKind());
auto *AccessExpr =
isl_id_to_ast_expr_get(NewAccesses, (*ReadAccess)->getId());
auto *LoadValue = ExprBuilder.create(AccessExpr);
AccessExpr = isl_id_to_ast_expr_get(NewAccesses, (*WriteAccess)->getId());
auto *StoreAddr = ExprBuilder.createAccessAddress(AccessExpr);
Builder.CreateStore(LoadValue, StoreAddr);
}
void IslNodeBuilder::createUser(__isl_take isl_ast_node *User) {
LoopToScevMapT LTS;
isl_id *Id;
@ -781,12 +798,17 @@ void IslNodeBuilder::createUser(__isl_take isl_ast_node *User) {
Stmt = (ScopStmt *)isl_id_get_user(Id);
auto *NewAccesses = createNewAccesses(Stmt, User);
createSubstitutions(Expr, Stmt, LTS);
if (Stmt->isCopyStmt()) {
generateCopyStmt(Stmt, NewAccesses);
isl_ast_expr_free(Expr);
} else {
createSubstitutions(Expr, Stmt, LTS);
if (Stmt->isBlockStmt())
BlockGen.copyStmt(*Stmt, LTS, NewAccesses);
else
RegionGen.copyStmt(*Stmt, LTS, NewAccesses);
if (Stmt->isBlockStmt())
BlockGen.copyStmt(*Stmt, LTS, NewAccesses);
else
RegionGen.copyStmt(*Stmt, LTS, NewAccesses);
}
isl_id_to_ast_expr_free(NewAccesses);
isl_ast_node_free(User);

View File

@ -294,6 +294,8 @@ bool JSONImporter::importSchedule(Scop &S, Json::Value &JScop,
int Index = 0;
for (ScopStmt &Stmt : S) {
Json::Value Schedule = JScop["statements"][Index]["schedule"];
assert(!Schedule.asString().empty() &&
"Schedules that contain extension nodes require special handling.");
isl_map *Map = isl_map_read_from_str(S.getIslCtx(), Schedule.asCString());
isl_space *Space = Stmt.getDomainSpace();

View File

@ -92,6 +92,8 @@ char DeadCodeElim::ID = 0;
// no point in trying to remove them from the live-out set.
__isl_give isl_union_set *DeadCodeElim::getLiveOut(Scop &S) {
isl_union_map *Schedule = S.getSchedule();
assert(Schedule &&
"Schedules that contain extension nodes require special handling.");
isl_union_map *WriteIterations = isl_union_map_reverse(S.getMustWrites());
isl_union_map *WriteTimes =
isl_union_map_apply_range(WriteIterations, isl_union_map_copy(Schedule));

View File

@ -660,6 +660,76 @@ identifyAccessByAccessRelation(ScopStmt *Stmt,
return IdentifiedAccess;
}
/// Add constrains to @Dim dimension of @p ExtMap.
///
/// If @ExtMap has the following form [O0, O1, O2]->[I1, I2, I3],
/// the following constraint will be added
/// Bound * OM <= IM <= Bound * (OM + 1) - 1,
/// where M is @p Dim and Bound is @p Bound.
///
/// @param ExtMap The isl map to be modified.
/// @param Dim The output dimension to be modfied.
/// @param Bound The value that is used to specify the constraint.
/// @return The modified isl map
__isl_give isl_map *
addExtensionMapMatMulDimConstraint(__isl_take isl_map *ExtMap, unsigned Dim,
unsigned Bound) {
assert(Bound != 0);
auto *ExtMapSpace = isl_map_get_space(ExtMap);
auto *ConstrSpace = isl_local_space_from_space(ExtMapSpace);
auto *Constr =
isl_constraint_alloc_inequality(isl_local_space_copy(ConstrSpace));
Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_out, Dim, 1);
Constr =
isl_constraint_set_coefficient_si(Constr, isl_dim_in, Dim, Bound * (-1));
ExtMap = isl_map_add_constraint(ExtMap, Constr);
Constr = isl_constraint_alloc_inequality(ConstrSpace);
Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_out, Dim, -1);
Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_in, Dim, Bound);
Constr = isl_constraint_set_constant_si(Constr, Bound - 1);
return isl_map_add_constraint(ExtMap, Constr);
}
/// Create an access relation that is specific for matrix multiplication
/// pattern.
///
/// Create an access relation of the following form:
/// { [O0, O1, O2]->[I1, I2, I3] :
/// FirstOutputDimBound * O0 <= I1 <= FirstOutputDimBound * (O0 + 1) - 1
/// and SecondOutputDimBound * O1 <= I2 <= SecondOutputDimBound * (O1 + 1) - 1
/// and ThirdOutputDimBound * O2 <= I3 <= ThirdOutputDimBound * (O2 + 1) - 1}
/// where FirstOutputDimBound is @p FirstOutputDimBound,
/// SecondOutputDimBound is @p SecondOutputDimBound,
/// ThirdOutputDimBound is @p ThirdOutputDimBound
///
/// @param Ctx The isl context.
/// @param FirstOutputDimBound,
/// SecondOutputDimBound,
/// ThirdOutputDimBound The parameters of the access relation.
/// @return The specified access relation.
__isl_give isl_map *getMatMulExt(isl_ctx *Ctx, unsigned FirstOutputDimBound,
unsigned SecondOutputDimBound,
unsigned ThirdOutputDimBound) {
auto *NewRelSpace = isl_space_alloc(Ctx, 0, 3, 3);
auto *extensionMap = isl_map_universe(NewRelSpace);
if (!FirstOutputDimBound)
extensionMap = isl_map_fix_si(extensionMap, isl_dim_out, 0, 0);
else
extensionMap = addExtensionMapMatMulDimConstraint(extensionMap, 0,
FirstOutputDimBound);
if (!SecondOutputDimBound)
extensionMap = isl_map_fix_si(extensionMap, isl_dim_out, 1, 0);
else
extensionMap = addExtensionMapMatMulDimConstraint(extensionMap, 1,
SecondOutputDimBound);
if (!ThirdOutputDimBound)
extensionMap = isl_map_fix_si(extensionMap, isl_dim_out, 2, 0);
else
extensionMap = addExtensionMapMatMulDimConstraint(extensionMap, 2,
ThirdOutputDimBound);
return extensionMap;
}
/// Create an access relation that is specific to the matrix
/// multiplication pattern.
///
@ -758,6 +828,14 @@ __isl_give isl_map *getMatMulAccRel(__isl_take isl_map *MapOldIndVar,
return isl_map_apply_range(MapOldIndVar, AccessRel);
}
__isl_give isl_schedule_node *
createExtensionNode(__isl_take isl_schedule_node *Node,
__isl_take isl_map *ExtensionMap) {
auto *Extension = isl_union_map_from_map(ExtensionMap);
auto *NewNode = isl_schedule_node_from_extension(Extension);
return isl_schedule_node_graft_before(Node, NewNode);
}
/// Apply the packing transformation.
///
/// The packing transformation can be described as a data-layout
@ -772,9 +850,9 @@ __isl_give isl_map *getMatMulAccRel(__isl_take isl_map *MapOldIndVar,
/// @param MicroParams, MacroParams Parameters of the BLIS kernel
/// to be taken into account.
/// @return The optimized schedule node.
static void optimizeDataLayoutMatrMulPattern(__isl_take isl_map *MapOldIndVar,
MicroKernelParamsTy MicroParams,
MacroKernelParamsTy MacroParams) {
static __isl_give isl_schedule_node *optimizeDataLayoutMatrMulPattern(
__isl_take isl_schedule_node *Node, __isl_take isl_map *MapOldIndVar,
MicroKernelParamsTy MicroParams, MacroKernelParamsTy MacroParams) {
auto InputDimsId = isl_map_get_tuple_id(MapOldIndVar, isl_dim_in);
auto *Stmt = static_cast<ScopStmt *>(isl_id_get_user(InputDimsId));
isl_id_free(InputDimsId);
@ -782,8 +860,12 @@ static void optimizeDataLayoutMatrMulPattern(__isl_take isl_map *MapOldIndVar,
MemoryAccess *MemAccessB = identifyAccessB(Stmt);
if (!MemAccessA || !MemAccessB) {
isl_map_free(MapOldIndVar);
return;
return Node;
}
Node = isl_schedule_node_parent(isl_schedule_node_parent(Node));
Node = isl_schedule_node_parent(isl_schedule_node_parent(Node));
Node = isl_schedule_node_parent(Node);
Node = isl_schedule_node_child(isl_schedule_node_band_split(Node, 2), 0);
auto *AccRel =
getMatMulAccRel(isl_map_copy(MapOldIndVar), MacroParams.Kc, 3, 6);
unsigned FirstDimSize = MacroParams.Mc * MacroParams.Kc / MicroParams.Mr;
@ -791,14 +873,34 @@ static void optimizeDataLayoutMatrMulPattern(__isl_take isl_map *MapOldIndVar,
auto *SAI = Stmt->getParent()->createScopArrayInfo(
MemAccessA->getElementType(), "Packed_A", {FirstDimSize, SecondDimSize});
AccRel = isl_map_set_tuple_id(AccRel, isl_dim_out, SAI->getBasePtrId());
auto *OldAcc = MemAccessA->getAccessRelation();
MemAccessA->setNewAccessRelation(AccRel);
auto *ExtMap =
getMatMulExt(Stmt->getIslCtx(), MacroParams.Mc, 0, MacroParams.Kc);
ExtMap = isl_map_project_out(ExtMap, isl_dim_in, 1, 1);
auto *Domain = Stmt->getDomain();
auto *NewStmt = Stmt->getParent()->addScopStmt(
OldAcc, MemAccessA->getAccessRelation(), isl_set_copy(Domain));
ExtMap = isl_map_set_tuple_id(ExtMap, isl_dim_out, NewStmt->getDomainId());
Node = createExtensionNode(Node, ExtMap);
Node = isl_schedule_node_child(Node, 0);
AccRel = getMatMulAccRel(MapOldIndVar, MacroParams.Kc, 4, 7);
FirstDimSize = MacroParams.Nc * MacroParams.Kc / MicroParams.Nr;
SecondDimSize = MicroParams.Nr;
SAI = Stmt->getParent()->createScopArrayInfo(
MemAccessB->getElementType(), "Packed_B", {FirstDimSize, SecondDimSize});
AccRel = isl_map_set_tuple_id(AccRel, isl_dim_out, SAI->getBasePtrId());
OldAcc = MemAccessB->getAccessRelation();
MemAccessB->setNewAccessRelation(AccRel);
ExtMap = getMatMulExt(Stmt->getIslCtx(), 0, MacroParams.Nc, MacroParams.Kc);
isl_map_move_dims(ExtMap, isl_dim_out, 0, isl_dim_in, 1, 1);
isl_map_move_dims(ExtMap, isl_dim_in, 2, isl_dim_out, 0, 1);
NewStmt = Stmt->getParent()->addScopStmt(
OldAcc, MemAccessB->getAccessRelation(), Domain);
ExtMap = isl_map_set_tuple_id(ExtMap, isl_dim_out, NewStmt->getDomainId());
Node = createExtensionNode(Node, ExtMap);
Node = isl_schedule_node_child(isl_schedule_node_child(Node, 0), 0);
return isl_schedule_node_child(isl_schedule_node_child(Node, 0), 0);
}
/// Get a relation mapping induction variables produced by schedule
@ -842,9 +944,8 @@ __isl_give isl_schedule_node *ScheduleTreeOptimizer::optimizeMatMulPattern(
Node, MicroKernelParams, MacroKernelParams);
if (!MapOldIndVar)
return Node;
optimizeDataLayoutMatrMulPattern(MapOldIndVar, MicroKernelParams,
MacroKernelParams);
return Node;
return optimizeDataLayoutMatrMulPattern(Node, MapOldIndVar, MicroKernelParams,
MacroKernelParams);
}
bool ScheduleTreeOptimizer::isMatrMultPattern(
@ -901,7 +1002,7 @@ __isl_give isl_schedule_node *ScheduleTreeOptimizer::optimizeScheduleNode(
}
bool ScheduleTreeOptimizer::isProfitableSchedule(
Scop &S, __isl_keep isl_union_map *NewSchedule) {
Scop &S, __isl_keep isl_schedule *NewSchedule) {
// To understand if the schedule has been optimized we check if the schedule
// has changed at all.
// TODO: We can improve this by tracking if any necessarily beneficial
@ -911,9 +1012,15 @@ bool ScheduleTreeOptimizer::isProfitableSchedule(
// optimizations, by comparing (yet to be defined) performance metrics
// before/after the scheduling optimizer
// (e.g., #stride-one accesses)
if (S.containsExtensionNode(NewSchedule))
return true;
auto *NewScheduleMap = isl_schedule_get_map(NewSchedule);
isl_union_map *OldSchedule = S.getSchedule();
bool changed = !isl_union_map_is_equal(OldSchedule, NewSchedule);
assert(OldSchedule && "Only IslScheduleOptimizer can insert extension nodes "
"that make Scop::getSchedule() return nullptr.");
bool changed = !isl_union_map_is_equal(OldSchedule, NewScheduleMap);
isl_union_map_free(OldSchedule);
isl_union_map_free(NewScheduleMap);
return changed;
}
@ -1090,10 +1197,8 @@ bool IslScheduleOptimizer::runOnScop(Scop &S) {
auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
isl_schedule *NewSchedule =
ScheduleTreeOptimizer::optimizeSchedule(Schedule, TTI);
isl_union_map *NewScheduleMap = isl_schedule_get_map(NewSchedule);
if (!ScheduleTreeOptimizer::isProfitableSchedule(S, NewScheduleMap)) {
isl_union_map_free(NewScheduleMap);
if (!ScheduleTreeOptimizer::isProfitableSchedule(S, NewSchedule)) {
isl_schedule_free(NewSchedule);
return false;
}
@ -1104,7 +1209,6 @@ bool IslScheduleOptimizer::runOnScop(Scop &S) {
if (OptimizedScops)
S.dump();
isl_union_map_free(NewScheduleMap);
return false;
}

View File

@ -12,11 +12,34 @@
; CHECK: double Packed_A[ { [] -> [(1024)] } ][ { [] -> [(4)] } ]; // Element size 8
; CHECK: double Packed_B[ { [] -> [(3072)] } ][ { [] -> [(8)] } ]; // Element size 8
;
; CHECK: { Stmt_bb14[i0, i1, i2] -> MemRef_arg6[i0, i2] };
; CHECK: new: { Stmt_bb14[i0, i1, i2] -> Packed_A[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 4*floor((-i0 + o1)/4) = -i0 + o1 and 0 <= o1 <= 3 and -3 + i0 - 16*floor((i0)/16) <= 4*floor((o0)/256) <= i0 - 16*floor((i0)/16) };
; CHECK: { Stmt_Copy_0[i0, i1, i2] -> MemRef_arg6[i0, i2] };
; CHECK: new: { Stmt_Copy_0[i0, i1, i2] -> Packed_A[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 4*floor((-i0 + o1)/4) = -i0 + o1 and 0 <= o1 <= 3 and -3 + i0 - 16*floor((i0)/16) <= 4*floor((o0)/256) <= i0 - 16*floor((i0)/16) };
;
; CHECK: { Stmt_bb14[i0, i1, i2] -> MemRef_arg7[i2, i1] };
; CHECK: new: { Stmt_bb14[i0, i1, i2] -> Packed_B[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 8*floor((-i1 + o1)/8) = -i1 + o1 and 0 <= o1 <= 7 and -7 + i1 - 96*floor((i1)/96) <= 8*floor((o0)/256) <= i1 - 96*floor((i1)/96) };
; CHECK: { Stmt_Copy_0[i0, i1, i2] -> MemRef_arg7[i2, i1] };
; CHECK: new: { Stmt_Copy_0[i0, i1, i2] -> Packed_B[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 8*floor((-i1 + o1)/8) = -i1 + o1 and 0 <= o1 <= 7 and -7 + i1 - 96*floor((i1)/96) <= 8*floor((o0)/256) <= i1 - 96*floor((i1)/96) };
;
; CHECK: CopyStmt_0
; CHECK: Domain :=
; CHECK: { CopyStmt_0[i0, i1, i2] : 0 <= i0 <= 1055 and 0 <= i1 <= 1055 and 0 <= i2 <= 1023 };
; CHECK: Schedule :=
; CHECK: ;
; CHECK: MustWriteAccess := [Reduction Type: NONE] [Scalar: 0]
; CHECK: null;
; CHECK: new: { CopyStmt_0[i0, i1, i2] -> Packed_A[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 4*floor((-i0 + o1)/4) = -i0 + o1 and 0 <= o1 <= 3 and -3 + i0 - 16*floor((i0)/16) <= 4*floor((o0)/256) <= i0 - 16*floor((i0)/16) };
; CHECK: ReadAccess := [Reduction Type: NONE] [Scalar: 0]
; CHECK: null;
; CHECK: new: { CopyStmt_0[i0, i1, i2] -> MemRef_arg6[i0, i2] };
; CHECK: CopyStmt_1
; CHECK: Domain :=
; CHECK: { CopyStmt_1[i0, i1, i2] : 0 <= i0 <= 1055 and 0 <= i1 <= 1055 and 0 <= i2 <= 1023 };
; CHECK: Schedule :=
; CHECK: ;
; CHECK: MustWriteAccess := [Reduction Type: NONE] [Scalar: 0]
; CHECK: null;
; CHECK: new: { CopyStmt_1[i0, i1, i2] -> Packed_B[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 8*floor((-i1 + o1)/8) = -i1 + o1 and 0 <= o1 <= 7 and -7 + i1 - 96*floor((i1)/96) <= 8*floor((o0)/256) <= i1 - 96*floor((i1)/96) };
; CHECK: ReadAccess := [Reduction Type: NONE] [Scalar: 0]
; CHECK: null;
; CHECK: new: { CopyStmt_1[i0, i1, i2] -> MemRef_arg7[i2, i1] };
;
target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-unknown"
@ -35,10 +58,10 @@ bb9: ; preds = %bb26, %bb8
%tmp12 = load double, double* %tmp11, align 8
%tmp13 = fmul double %tmp12, %arg4
store double %tmp13, double* %tmp11, align 8
br label %bb14
br label %Copy_0
bb14: ; preds = %bb14, %bb9
%tmp15 = phi i64 [ 0, %bb9 ], [ %tmp24, %bb14 ]
Copy_0: ; preds = %Copy_0, %bb9
%tmp15 = phi i64 [ 0, %bb9 ], [ %tmp24, %Copy_0 ]
%tmp16 = getelementptr inbounds [1024 x double], [1024 x double]* %arg6, i64 %tmp, i64 %tmp15
%tmp17 = load double, double* %tmp16, align 8
%tmp18 = fmul double %tmp17, %arg3
@ -50,9 +73,9 @@ bb14: ; preds = %bb14, %bb9
store double %tmp23, double* %tmp11, align 8
%tmp24 = add nuw nsw i64 %tmp15, 1
%tmp25 = icmp ne i64 %tmp24, 1024
br i1 %tmp25, label %bb14, label %bb26
br i1 %tmp25, label %Copy_0, label %bb26
bb26: ; preds = %bb14
bb26: ; preds = %Copy_0
%tmp27 = add nuw nsw i64 %tmp10, 1
%tmp28 = icmp ne i64 %tmp27, 1056
br i1 %tmp28, label %bb9, label %bb29