[MatMul] Make MatMul detection independent of internal isl representations.

The pattern recognition for MatMul is restrictive.

The number of "disjuncts" in the isl_map containing constraint
information was previously required to be 1
(as per isl_*_coalesce - which should ideally produce a domain map with
a single disjunct, but does not under some circumstances).

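This requirement has been dropped: each access relation, restricted to the
statement domain, is now compared against the expected matrix-multiplication
access patterns, so detection no longer depends on the number of disjuncts.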

Contributed-by: Annanay Agarwal <cs14btech11001@iith.ac.in>

Differential Revision: https://reviews.llvm.org/D36460

llvm-svn: 311302
This commit is contained in:
Michael Kruse 2017-08-20 21:31:11 +00:00
parent d6491f2c4a
commit d091bf8d8e
4 changed files with 198 additions and 86 deletions

View File

@ -483,61 +483,6 @@ ScheduleTreeOptimizer::standardBandOpts(isl::schedule_node Node, void *User) {
return Node;
}
/// Locate the single dimension of a given type used by a constraint.
///
/// Walk over every dimension of type @p DimType in the isl constraint
/// @p Constraint. At most one of them may carry a non-zero coefficient, and
/// that coefficient has to be one for isl::dim::out and negative one for
/// isl::dim::in.
///
/// @param Constraint The isl constraint to be checked.
/// @param DimType    The type of the dimensions.
/// @return The position of the only dimension with a non-zero coefficient;
///         -1 if no dimension has a non-zero coefficient, if more than one
///         has, or if the coefficient does not have the required value.
static int getMatMulConstraintDim(isl::constraint Constraint,
                                  isl::dim DimType) {
  int FoundPos = -1;
  auto ConstraintSpace = Constraint.get_local_space();
  int NumDims = ConstraintSpace.dim(DimType);
  for (int Pos = 0; Pos < NumDims; ++Pos) {
    auto Coeff = Constraint.get_coefficient_val(DimType, Pos);
    if (Coeff.is_zero())
      continue;
    // A second non-zero coefficient disqualifies the constraint.
    if (FoundPos >= 0)
      return -1;
    // Output dimensions must appear with coefficient one ...
    if (DimType == isl::dim::out && !Coeff.is_one())
      return -1;
    // ... and input dimensions with coefficient negative one.
    if (DimType == isl::dim::in && !Coeff.is_negone())
      return -1;
    FoundPos = Pos;
  }
  return FoundPos;
}
/// Verify that a constraint equates one input with one output dimension.
///
/// The isl constraint @p Constraint is accepted only if it is an equality
/// with a zero constant term in which exactly one input dimension has the
/// coefficient negative one, exactly one output dimension has the
/// coefficient one, and all remaining input and output dimensions have the
/// coefficient zero.
///
/// @param Constraint The isl constraint to be checked.
/// @param DimInPos   Receives the position of the input dimension. It is
///                   overwritten as soon as the equality check passed, even
///                   if the constraint is rejected afterwards.
/// @param DimOutPos  Receives the position of the output dimension. It is
///                   only written once the input dimension was found.
/// @return isl_stat_ok in case the isl constraint satisfies
///         the requirements, isl_stat_error otherwise.
static isl_stat isMatMulOperandConstraint(isl::constraint Constraint,
                                          int &DimInPos, int &DimOutPos) {
  auto ConstantTerm = Constraint.get_constant_val();
  const bool IsHomogeneousEquality =
      isl_constraint_is_equality(Constraint.get()) && ConstantTerm.is_zero();
  if (!IsHomogeneousEquality)
    return isl_stat_error;

  DimInPos = getMatMulConstraintDim(Constraint, isl::dim::in);
  if (DimInPos < 0)
    return isl_stat_error;

  DimOutPos = getMatMulConstraintDim(Constraint, isl::dim::out);
  return DimOutPos < 0 ? isl_stat_error : isl_stat_ok;
}
/// Permute the two dimensions of the isl map.
///
/// Permute @p DstPos and @p SrcPos dimensions of the isl map @p Map that
@ -585,30 +530,49 @@ isl::map permuteDimensions(isl::map Map, isl::dim DimType, unsigned DstPos,
/// second output dimension.
/// @return True in case @p AccMap has the expected form and false,
/// otherwise.
// NOTE(review): this span is a diff rendering without +/- markers. It
// interleaves the previous constraint-based implementation (signature taking
// only AccMap) with the new domain-based one (signature taking Domain and
// AccMap); it is not meant to be read as a single function.
//
// Previous implementation: inspect every basic map of AccMap and require
// exactly two constraints, each equating one input with one of the first two
// output dimensions.
static bool isMatMulOperandAcc(isl::map AccMap, int &FirstPos, int &SecondPos) {
int DimInPos[] = {FirstPos, SecondPos};
auto Lambda = [=, &DimInPos](isl::basic_map BasicMap) -> isl::stat {
auto Constraints = BasicMap.get_constraint_list();
if (isl_constraint_list_n_constraint(Constraints.get()) != 2)
return isl::stat::error;
for (int i = 0; i < 2; i++) {
auto Constraint =
isl::manage(isl_constraint_list_get_constraint(Constraints.get(), i));
int InPos, OutPos;
// Reject the constraint if it has the wrong form, refers to an output
// dimension other than the first two, or contradicts an input position
// that is already fixed (a value >= 0) for that output dimension.
if (isMatMulOperandConstraint(Constraint, InPos, OutPos) ==
isl_stat_error ||
OutPos > 1 || (DimInPos[OutPos] >= 0 && DimInPos[OutPos] != InPos))
return isl::stat::error;
DimInPos[OutPos] = InPos;
}
return isl::stat::ok;
};
if (AccMap.foreach_basic_map(Lambda) != isl::stat::ok || DimInPos[0] < 0 ||
DimInPos[1] < 0)
// New implementation: compare AccMap, restricted to the statement domain
// Domain, against every candidate access pattern instead of inspecting the
// individual constraints of its basic maps.
static bool isMatMulOperandAcc(isl::set Domain, isl::map AccMap, int &FirstPos,
int &SecondPos) {
isl::space Space = AccMap.get_space();
isl::map Universe = isl::map::universe(Space);
// Only accesses with exactly two output dimensions are considered.
if (Space.dim(isl::dim::out) != 2)
return false;
FirstPos = DimInPos[0];
SecondPos = DimInPos[1];
return true;
// MatMul has the form:
// for (i = 0; i < N; i++)
// for (j = 0; j < M; j++)
// for (k = 0; k < P; k++)
// C[i, j] += A[i, k] * B[k, j]
//
// Each of the two output dimensions is equated with one of the three input
// dimensions. FirstDims[i] / SecondDims[i] enumerate the ordered pairs of
// distinct input dimensions: 3 * 2 = 6 possibilities.
int FirstDims[] = {0, 0, 1, 1, 2, 2};
int SecondDims[] = {1, 2, 2, 0, 0, 1};
for (int i = 0; i < 6; i += 1) {
// Candidate relation: in[FirstDims[i]] == out[0] and
// in[SecondDims[i]] == out[1].
auto PossibleMatMul =
Universe.equate(isl::dim::in, FirstDims[i], isl::dim::out, 0)
.equate(isl::dim::in, SecondDims[i], isl::dim::out, 1);
// Restrict both relations to the statement domain before comparing them.
AccMap = AccMap.intersect_domain(Domain);
PossibleMatMul = PossibleMatMul.intersect_domain(Domain);
// If AccMap spans entire domain (Non-partial write),
// compute FirstPos and SecondPos.
// If AccMap != PossibleMatMul here (both maps have been intersected with
// the statement domain at this point), it means that the writes are not
// complete, or in other words, it is a Partial write and Partial writes
// must be rejected.
if (AccMap.is_equal(PossibleMatMul)) {
// A position of -1 is treated as "not fixed yet"; any other value has to
// agree with the dimension used by this candidate.
if (FirstPos != -1 && FirstPos != FirstDims[i])
continue;
FirstPos = FirstDims[i];
if (SecondPos != -1 && SecondPos != SecondDims[i])
continue;
SecondPos = SecondDims[i];
return true;
}
}
return false;
}
/// Does the memory access represent a non-scalar operand of the matrix
@ -627,18 +591,16 @@ static bool isMatMulNonScalarReadAccess(MemoryAccess *MemAccess,
if (!MemAccess->isLatestArrayKind() || !MemAccess->isRead())
return false;
auto AccMap = MemAccess->getLatestAccessRelation();
if (isMatMulOperandAcc(AccMap, MMI.i, MMI.j) && !MMI.ReadFromC &&
isl_map_n_basic_map(AccMap.get()) == 1) {
isl::set StmtDomain = MemAccess->getStatement()->getDomain();
if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.j) && !MMI.ReadFromC) {
MMI.ReadFromC = MemAccess;
return true;
}
if (isMatMulOperandAcc(AccMap, MMI.i, MMI.k) && !MMI.A &&
isl_map_n_basic_map(AccMap.get()) == 1) {
if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.k) && !MMI.A) {
MMI.A = MemAccess;
return true;
}
if (isMatMulOperandAcc(AccMap, MMI.k, MMI.j) && !MMI.B &&
isl_map_n_basic_map(AccMap.get()) == 1) {
if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.k, MMI.j) && !MMI.B) {
MMI.B = MemAccess;
return true;
}
@ -758,8 +720,7 @@ static bool containsMatrMult(isl::map PartialSchedule, const Dependences *D,
if (!MemAccessPtr->isWrite())
return false;
auto AccMap = MemAccessPtr->getLatestAccessRelation();
if (isl_map_n_basic_map(AccMap.get()) != 1 ||
!isMatMulOperandAcc(AccMap, MMI.i, MMI.j))
if (!isMatMulOperandAcc(Stmt->getDomain(), AccMap, MMI.i, MMI.j))
return false;
MMI.WriteToC = MemAccessPtr;
break;

View File

@ -0,0 +1,59 @@
; RUN: opt %loadPolly -polly-import-jscop -polly-import-jscop-postfix=transformed -polly-opt-isl -debug-only=polly-opt-isl -disable-output < %s 2>&1 | FileCheck %s
; REQUIRES: asserts
;
; void pattern_matching_based_opts_splitmap(double C[static const restrict 2][2], double A[static const restrict 2][784], double B[static const restrict 784][2]) {
; for (int i = 0; i < 2; i+=1)
; for (int j = 0; j < 2; j+=1)
; for (int k = 0; k < 784; k+=1)
; C[i][j] += A[i][k] * B[k][j];
;}
;
; Check that the pattern matching detects the matrix multiplication pattern
; when the AccMap cannot be reduced to a single disjunct.
;
; CHECK: The matrix multiplication pattern was detected
;
; ModuleID = 'pattern_matching_based_opts_splitmap.ll'
;
; Function Attrs: noinline nounwind uwtable
; Kernel under test: a three-deep loop nest computing
; C[i][j] += A[i][k] * B[k][j] with i, j in [0, 2) and k in [0, 784).
define void @pattern_matching_based_opts_splitmap([2 x double]* noalias dereferenceable(32) %C, [784 x double]* noalias dereferenceable(12544) %A, [2 x double]* noalias dereferenceable(12544) %B) {
entry:
br label %for.body
; Header of the outermost loop; %i starts at 0 and is advanced in %for.inc21.
for.body: ; preds = %entry, %for.inc21
%i = phi i64 [ 0, %entry ], [ %add22, %for.inc21 ]
br label %for.body3
; Header of the middle loop; %j starts at 0 and is advanced in %for.inc18.
for.body3: ; preds = %for.body, %for.inc18
%j = phi i64 [ 0, %for.body ], [ %add19, %for.inc18 ]
br label %for.body6
; Innermost loop over %k; this block is both the loop body and its latch.
for.body6: ; preds = %for.body3, %for.body6
%k = phi i64 [ 0, %for.body3 ], [ %add17, %for.body6 ]
; Load A[i][k].
%arrayidx8 = getelementptr inbounds [784 x double], [784 x double]* %A, i64 %i, i64 %k
%tmp6 = load double, double* %arrayidx8, align 8
; Load B[k][j].
%arrayidx12 = getelementptr inbounds [2 x double], [2 x double]* %B, i64 %k, i64 %j
%tmp10 = load double, double* %arrayidx12, align 8
%mul = fmul double %tmp6, %tmp10
; Read C[i][j], add the product and store the sum back to the same element.
%arrayidx16 = getelementptr inbounds [2 x double], [2 x double]* %C, i64 %i, i64 %j
%tmp14 = load double, double* %arrayidx16, align 8
%add = fadd double %tmp14, %mul
store double %add, double* %arrayidx16, align 8
; Advance %k and repeat while %k + 1 < 784.
%add17 = add nsw i64 %k, 1
%cmp5 = icmp slt i64 %add17, 784
br i1 %cmp5, label %for.body6, label %for.inc18
; Latch of the middle loop: advance %j and repeat while %j + 1 < 2.
for.inc18: ; preds = %for.body6
%add19 = add nsw i64 %j, 1
%cmp2 = icmp slt i64 %add19, 2
br i1 %cmp2, label %for.body3, label %for.inc21
; Latch of the outermost loop: advance %i and repeat while %i + 1 < 2.
for.inc21: ; preds = %for.inc18
%add22 = add nsw i64 %i, 1
%cmp = icmp slt i64 %add22, 2
br i1 %cmp, label %for.body, label %for.end23
for.end23: ; preds = %for.inc21
ret void
}

View File

@ -0,0 +1,46 @@
{
"arrays" : [
{
"name" : "MemRef_A",
"sizes" : [ "*", "784" ],
"type" : "double"
},
{
"name" : "MemRef_B",
"sizes" : [ "*", "2" ],
"type" : "double"
},
{
"name" : "MemRef_C",
"sizes" : [ "*", "2" ],
"type" : "double"
}
],
"context" : "{ : }",
"name" : "%for.body---%for.end23",
"statements" : [
{
"accesses" : [
{
"kind" : "read",
"relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_A[i0, i2] }"
},
{
"kind" : "read",
"relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_B[i2, i1] }"
},
{
"kind" : "read",
"relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_C[i0, i1] }"
},
{
"kind" : "write",
"relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_C[i0, i1] }"
}
],
"domain" : "{ Stmt_for_body6[i0, i1, i2] : 0 <= i0 <= 1 and 0 <= i1 <= 1 and 0 <= i2 <= 783 }",
"name" : "Stmt_for_body6",
"schedule" : "{ Stmt_for_body6[i0, i1, i2] -> [i0, i1, i2] }"
}
]
}

View File

@ -0,0 +1,46 @@
{
"arrays" : [
{
"name" : "MemRef_A",
"sizes" : [ "*", "784" ],
"type" : "double"
},
{
"name" : "MemRef_B",
"sizes" : [ "*", "2" ],
"type" : "double"
},
{
"name" : "MemRef_C",
"sizes" : [ "*", "2" ],
"type" : "double"
}
],
"context" : "{ : }",
"name" : "%for.body---%for.end23",
"statements" : [
{
"accesses" : [
{
"kind" : "read",
"relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_A[i0, i2] }"
},
{
"kind" : "read",
"relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_B[i2, i1] }"
},
{
"kind" : "read",
"relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_C[i0, i1] }"
},
{
"kind" : "write",
"relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_C[i0, i1] : i2 <= 784 - i0 - i1; Stmt_for_body6[1, 1, 783] -> MemRef_C[1, 1] }"
}
],
"domain" : "{ Stmt_for_body6[i0, i1, i2] : 0 <= i0 <= 1 and 0 <= i1 <= 1 and 0 <= i2 <= 783 }",
"name" : "Stmt_for_body6",
"schedule" : "{ Stmt_for_body6[i0, i1, i2] -> [i0, i1, i2] }"
}
]
}