mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-02-24 20:44:09 +00:00
[MatMul] Make MatMul detection independent of internal isl representations.
The pattern recognition for MatMul is restrictive. The number of "disjuncts" in the isl_map containing constraint information was previously required to be 1 (as per isl_*_coalesce - which should ideally produce a domain map with a single disjunct, but does not under some circumstances). This was changed and made more flexible. Contributed-by: Annanay Agarwal <cs14btech11001@iith.ac.in> Differential Revision: https://reviews.llvm.org/D36460 llvm-svn: 311302
This commit is contained in:
parent
d6491f2c4a
commit
d091bf8d8e
@ -483,61 +483,6 @@ ScheduleTreeOptimizer::standardBandOpts(isl::schedule_node Node, void *User) {
|
||||
return Node;
|
||||
}
|
||||
|
||||
/// Get the position of a dimension with a non-zero coefficient.
|
||||
///
|
||||
/// Check that isl constraint @p Constraint has only one non-zero
|
||||
/// coefficient for dimensions that have type @p DimType. If this is true,
|
||||
/// return the position of the dimension corresponding to the non-zero
|
||||
/// coefficient and negative value, otherwise.
|
||||
///
|
||||
/// @param Constraint The isl constraint to be checked.
|
||||
/// @param DimType The type of the dimensions.
|
||||
/// @return The position of the dimension in case the isl
|
||||
/// constraint satisfies the requirements, a negative
|
||||
/// value, otherwise.
|
||||
static int getMatMulConstraintDim(isl::constraint Constraint,
|
||||
isl::dim DimType) {
|
||||
int DimPos = -1;
|
||||
auto LocalSpace = Constraint.get_local_space();
|
||||
int LocalSpaceDimNum = LocalSpace.dim(DimType);
|
||||
for (int i = 0; i < LocalSpaceDimNum; i++) {
|
||||
auto Val = Constraint.get_coefficient_val(DimType, i);
|
||||
if (Val.is_zero())
|
||||
continue;
|
||||
if (DimPos >= 0 || (DimType == isl::dim::out && !Val.is_one()) ||
|
||||
(DimType == isl::dim::in && !Val.is_negone()))
|
||||
return -1;
|
||||
DimPos = i;
|
||||
}
|
||||
return DimPos;
|
||||
}
|
||||
|
||||
/// Check the form of the isl constraint.
|
||||
///
|
||||
/// Check that the @p DimInPos input dimension of the isl constraint
|
||||
/// @p Constraint has a coefficient that is equal to negative one, the @p
|
||||
/// DimOutPos has a coefficient that is equal to one and others
|
||||
/// have coefficients equal to zero.
|
||||
///
|
||||
/// @param Constraint The isl constraint to be checked.
|
||||
/// @param DimInPos The input dimension of the isl constraint.
|
||||
/// @param DimOutPos The output dimension of the isl constraint.
|
||||
/// @return isl_stat_ok in case the isl constraint satisfies
|
||||
/// the requirements, isl_stat_error otherwise.
|
||||
static isl_stat isMatMulOperandConstraint(isl::constraint Constraint,
|
||||
int &DimInPos, int &DimOutPos) {
|
||||
auto Val = Constraint.get_constant_val();
|
||||
if (!isl_constraint_is_equality(Constraint.get()) || !Val.is_zero())
|
||||
return isl_stat_error;
|
||||
DimInPos = getMatMulConstraintDim(Constraint, isl::dim::in);
|
||||
if (DimInPos < 0)
|
||||
return isl_stat_error;
|
||||
DimOutPos = getMatMulConstraintDim(Constraint, isl::dim::out);
|
||||
if (DimOutPos < 0)
|
||||
return isl_stat_error;
|
||||
return isl_stat_ok;
|
||||
}
|
||||
|
||||
/// Permute the two dimensions of the isl map.
|
||||
///
|
||||
/// Permute @p DstPos and @p SrcPos dimensions of the isl map @p Map that
|
||||
@ -585,30 +530,49 @@ isl::map permuteDimensions(isl::map Map, isl::dim DimType, unsigned DstPos,
|
||||
/// second output dimension.
|
||||
/// @return True in case @p AccMap has the expected form and false,
|
||||
/// otherwise.
|
||||
static bool isMatMulOperandAcc(isl::map AccMap, int &FirstPos, int &SecondPos) {
|
||||
int DimInPos[] = {FirstPos, SecondPos};
|
||||
auto Lambda = [=, &DimInPos](isl::basic_map BasicMap) -> isl::stat {
|
||||
auto Constraints = BasicMap.get_constraint_list();
|
||||
if (isl_constraint_list_n_constraint(Constraints.get()) != 2)
|
||||
return isl::stat::error;
|
||||
for (int i = 0; i < 2; i++) {
|
||||
auto Constraint =
|
||||
isl::manage(isl_constraint_list_get_constraint(Constraints.get(), i));
|
||||
int InPos, OutPos;
|
||||
if (isMatMulOperandConstraint(Constraint, InPos, OutPos) ==
|
||||
isl_stat_error ||
|
||||
OutPos > 1 || (DimInPos[OutPos] >= 0 && DimInPos[OutPos] != InPos))
|
||||
return isl::stat::error;
|
||||
DimInPos[OutPos] = InPos;
|
||||
}
|
||||
return isl::stat::ok;
|
||||
};
|
||||
if (AccMap.foreach_basic_map(Lambda) != isl::stat::ok || DimInPos[0] < 0 ||
|
||||
DimInPos[1] < 0)
|
||||
static bool isMatMulOperandAcc(isl::set Domain, isl::map AccMap, int &FirstPos,
|
||||
int &SecondPos) {
|
||||
|
||||
isl::space Space = AccMap.get_space();
|
||||
isl::map Universe = isl::map::universe(Space);
|
||||
|
||||
if (Space.dim(isl::dim::out) != 2)
|
||||
return false;
|
||||
FirstPos = DimInPos[0];
|
||||
SecondPos = DimInPos[1];
|
||||
return true;
|
||||
|
||||
// MatMul has the form:
|
||||
// for (i = 0; i < N; i++)
|
||||
// for (j = 0; j < M; j++)
|
||||
// for (k = 0; k < P; k++)
|
||||
// C[i, j] += A[i, k] * B[k, j]
|
||||
//
|
||||
// Permutation of three outer loops: 3! = 6 possibilities.
|
||||
int FirstDims[] = {0, 0, 1, 1, 2, 2};
|
||||
int SecondDims[] = {1, 2, 2, 0, 0, 1};
|
||||
for (int i = 0; i < 6; i += 1) {
|
||||
auto PossibleMatMul =
|
||||
Universe.equate(isl::dim::in, FirstDims[i], isl::dim::out, 0)
|
||||
.equate(isl::dim::in, SecondDims[i], isl::dim::out, 1);
|
||||
|
||||
AccMap = AccMap.intersect_domain(Domain);
|
||||
PossibleMatMul = PossibleMatMul.intersect_domain(Domain);
|
||||
|
||||
// If AccMap spans entire domain (Non-partial write),
|
||||
// compute FirstPos and SecondPos.
|
||||
// If AccMap != PossibleMatMul here (the two maps have been gisted at
|
||||
// this point), it means that the writes are not complete, or in other
|
||||
// words, it is a Partial write and Partial writes must be rejected.
|
||||
if (AccMap.is_equal(PossibleMatMul)) {
|
||||
if (FirstPos != -1 && FirstPos != FirstDims[i])
|
||||
continue;
|
||||
FirstPos = FirstDims[i];
|
||||
if (SecondPos != -1 && SecondPos != SecondDims[i])
|
||||
continue;
|
||||
SecondPos = SecondDims[i];
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Does the memory access represent a non-scalar operand of the matrix
|
||||
@ -627,18 +591,16 @@ static bool isMatMulNonScalarReadAccess(MemoryAccess *MemAccess,
|
||||
if (!MemAccess->isLatestArrayKind() || !MemAccess->isRead())
|
||||
return false;
|
||||
auto AccMap = MemAccess->getLatestAccessRelation();
|
||||
if (isMatMulOperandAcc(AccMap, MMI.i, MMI.j) && !MMI.ReadFromC &&
|
||||
isl_map_n_basic_map(AccMap.get()) == 1) {
|
||||
isl::set StmtDomain = MemAccess->getStatement()->getDomain();
|
||||
if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.j) && !MMI.ReadFromC) {
|
||||
MMI.ReadFromC = MemAccess;
|
||||
return true;
|
||||
}
|
||||
if (isMatMulOperandAcc(AccMap, MMI.i, MMI.k) && !MMI.A &&
|
||||
isl_map_n_basic_map(AccMap.get()) == 1) {
|
||||
if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.k) && !MMI.A) {
|
||||
MMI.A = MemAccess;
|
||||
return true;
|
||||
}
|
||||
if (isMatMulOperandAcc(AccMap, MMI.k, MMI.j) && !MMI.B &&
|
||||
isl_map_n_basic_map(AccMap.get()) == 1) {
|
||||
if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.k, MMI.j) && !MMI.B) {
|
||||
MMI.B = MemAccess;
|
||||
return true;
|
||||
}
|
||||
@ -758,8 +720,7 @@ static bool containsMatrMult(isl::map PartialSchedule, const Dependences *D,
|
||||
if (!MemAccessPtr->isWrite())
|
||||
return false;
|
||||
auto AccMap = MemAccessPtr->getLatestAccessRelation();
|
||||
if (isl_map_n_basic_map(AccMap.get()) != 1 ||
|
||||
!isMatMulOperandAcc(AccMap, MMI.i, MMI.j))
|
||||
if (!isMatMulOperandAcc(Stmt->getDomain(), AccMap, MMI.i, MMI.j))
|
||||
return false;
|
||||
MMI.WriteToC = MemAccessPtr;
|
||||
break;
|
||||
|
@ -0,0 +1,59 @@
|
||||
; RUN: opt %loadPolly -polly-import-jscop -polly-import-jscop-postfix=transformed -polly-opt-isl -debug-only=polly-opt-isl -disable-output < %s 2>&1 | FileCheck %s
|
||||
; REQUIRES: asserts
|
||||
;
|
||||
; void pattern_matching_based_opts_splitmap(double C[static const restrict 2][2], double A[static const restrict 2][784], double B[static const restrict 784][2]) {
|
||||
; for (int i = 0; i < 2; i+=1)
|
||||
; for (int j = 0; j < 2; j+=1)
|
||||
; for (int k = 0; k < 784; k+=1)
|
||||
; C[i][j] += A[i][k] * B[k][j];
|
||||
;}
|
||||
;
|
||||
; Check that the pattern matching detects the matrix multiplication pattern
|
||||
; when the AccMap cannot be reduced to a single disjunct.
|
||||
;
|
||||
; CHECK: The matrix multiplication pattern was detected
|
||||
;
|
||||
; ModuleID = 'pattern_matching_based_opts_splitmap.ll'
|
||||
;
|
||||
; Function Attrs: noinline nounwind uwtable
|
||||
define void @pattern_matching_based_opts_splitmap([2 x double]* noalias dereferenceable(32) %C, [784 x double]* noalias dereferenceable(12544) %A, [2 x double]* noalias dereferenceable(12544) %B) {
|
||||
entry:
|
||||
br label %for.body
|
||||
|
||||
for.body: ; preds = %entry, %for.inc21
|
||||
%i = phi i64 [ 0, %entry ], [ %add22, %for.inc21 ]
|
||||
br label %for.body3
|
||||
|
||||
for.body3: ; preds = %for.body, %for.inc18
|
||||
%j = phi i64 [ 0, %for.body ], [ %add19, %for.inc18 ]
|
||||
br label %for.body6
|
||||
|
||||
for.body6: ; preds = %for.body3, %for.body6
|
||||
%k = phi i64 [ 0, %for.body3 ], [ %add17, %for.body6 ]
|
||||
%arrayidx8 = getelementptr inbounds [784 x double], [784 x double]* %A, i64 %i, i64 %k
|
||||
%tmp6 = load double, double* %arrayidx8, align 8
|
||||
%arrayidx12 = getelementptr inbounds [2 x double], [2 x double]* %B, i64 %k, i64 %j
|
||||
%tmp10 = load double, double* %arrayidx12, align 8
|
||||
%mul = fmul double %tmp6, %tmp10
|
||||
%arrayidx16 = getelementptr inbounds [2 x double], [2 x double]* %C, i64 %i, i64 %j
|
||||
%tmp14 = load double, double* %arrayidx16, align 8
|
||||
%add = fadd double %tmp14, %mul
|
||||
store double %add, double* %arrayidx16, align 8
|
||||
%add17 = add nsw i64 %k, 1
|
||||
%cmp5 = icmp slt i64 %add17, 784
|
||||
br i1 %cmp5, label %for.body6, label %for.inc18
|
||||
|
||||
for.inc18: ; preds = %for.body6
|
||||
%add19 = add nsw i64 %j, 1
|
||||
%cmp2 = icmp slt i64 %add19, 2
|
||||
br i1 %cmp2, label %for.body3, label %for.inc21
|
||||
|
||||
for.inc21: ; preds = %for.inc18
|
||||
%add22 = add nsw i64 %i, 1
|
||||
%cmp = icmp slt i64 %add22, 2
|
||||
br i1 %cmp, label %for.body, label %for.end23
|
||||
|
||||
for.end23: ; preds = %for.inc21
|
||||
ret void
|
||||
}
|
||||
|
@ -0,0 +1,46 @@
|
||||
{
|
||||
"arrays" : [
|
||||
{
|
||||
"name" : "MemRef_A",
|
||||
"sizes" : [ "*", "784" ],
|
||||
"type" : "double"
|
||||
},
|
||||
{
|
||||
"name" : "MemRef_B",
|
||||
"sizes" : [ "*", "2" ],
|
||||
"type" : "double"
|
||||
},
|
||||
{
|
||||
"name" : "MemRef_C",
|
||||
"sizes" : [ "*", "2" ],
|
||||
"type" : "double"
|
||||
}
|
||||
],
|
||||
"context" : "{ : }",
|
||||
"name" : "%for.body---%for.end23",
|
||||
"statements" : [
|
||||
{
|
||||
"accesses" : [
|
||||
{
|
||||
"kind" : "read",
|
||||
"relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_A[i0, i2] }"
|
||||
},
|
||||
{
|
||||
"kind" : "read",
|
||||
"relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_B[i2, i1] }"
|
||||
},
|
||||
{
|
||||
"kind" : "read",
|
||||
"relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_C[i0, i1] }"
|
||||
},
|
||||
{
|
||||
"kind" : "write",
|
||||
"relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_C[i0, i1] }"
|
||||
}
|
||||
],
|
||||
"domain" : "{ Stmt_for_body6[i0, i1, i2] : 0 <= i0 <= 1 and 0 <= i1 <= 1 and 0 <= i2 <= 783 }",
|
||||
"name" : "Stmt_for_body6",
|
||||
"schedule" : "{ Stmt_for_body6[i0, i1, i2] -> [i0, i1, i2] }"
|
||||
}
|
||||
]
|
||||
}
|
@ -0,0 +1,46 @@
|
||||
{
|
||||
"arrays" : [
|
||||
{
|
||||
"name" : "MemRef_A",
|
||||
"sizes" : [ "*", "784" ],
|
||||
"type" : "double"
|
||||
},
|
||||
{
|
||||
"name" : "MemRef_B",
|
||||
"sizes" : [ "*", "2" ],
|
||||
"type" : "double"
|
||||
},
|
||||
{
|
||||
"name" : "MemRef_C",
|
||||
"sizes" : [ "*", "2" ],
|
||||
"type" : "double"
|
||||
}
|
||||
],
|
||||
"context" : "{ : }",
|
||||
"name" : "%for.body---%for.end23",
|
||||
"statements" : [
|
||||
{
|
||||
"accesses" : [
|
||||
{
|
||||
"kind" : "read",
|
||||
"relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_A[i0, i2] }"
|
||||
},
|
||||
{
|
||||
"kind" : "read",
|
||||
"relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_B[i2, i1] }"
|
||||
},
|
||||
{
|
||||
"kind" : "read",
|
||||
"relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_C[i0, i1] }"
|
||||
},
|
||||
{
|
||||
"kind" : "write",
|
||||
"relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_C[i0, i1] : i2 <= 784 - i0 - i1; Stmt_for_body6[1, 1, 783] -> MemRef_C[1, 1] }"
|
||||
}
|
||||
],
|
||||
"domain" : "{ Stmt_for_body6[i0, i1, i2] : 0 <= i0 <= 1 and 0 <= i1 <= 1 and 0 <= i2 <= 783 }",
|
||||
"name" : "Stmt_for_body6",
|
||||
"schedule" : "{ Stmt_for_body6[i0, i1, i2] -> [i0, i1, i2] }"
|
||||
}
|
||||
]
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user