mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-01-13 19:32:41 +00:00
Determination of statements that contain matrix multiplication
Add determination of statements that contain, in particular, matrix multiplications and can be optimized with [1] to try to get close-to-peak performance. It can be enabled via polly-pm-based-opts, which is false by default. Refs: [1] - http://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf Contributed-by: Roman Gareev <gareevroman@gmail.com> Reviewed-by: Tobias Grosser <tobias@grosser.es> Differential Revision: http://reviews.llvm.org/D20575 llvm-svn: 271128
This commit is contained in:
parent
395eca8d26
commit
9c3eb5960a
@ -147,8 +147,45 @@ private:
|
||||
/// - if vectorization is enabled
|
||||
///
|
||||
/// @param Node The schedule node to (possibly) optimize.
|
||||
/// @param User A pointer to forward some use information (currently unused).
|
||||
/// @param User A pointer to forward some use information
|
||||
/// (currently unused).
|
||||
static isl_schedule_node *optimizeBand(isl_schedule_node *Node, void *User);
|
||||
|
||||
/// @brief Apply additional optimizations on the bands in the schedule tree.
|
||||
///
|
||||
/// We apply the following
|
||||
/// transformations:
|
||||
///
|
||||
/// - Tile the band
|
||||
/// - Prevectorize the schedule of the band (or the point loop in case of
|
||||
/// tiling).
|
||||
/// - if vectorization is enabled
|
||||
///
|
||||
/// @param Node The schedule node to (possibly) optimize.
|
||||
/// @param User A pointer to forward some use information
|
||||
/// (currently unused).
|
||||
static isl_schedule_node *standardBandOpts(__isl_take isl_schedule_node *Node,
|
||||
void *User);
|
||||
|
||||
/// @brief Check if this node contains a partial schedule that could
|
||||
/// probably be optimized with analytical modeling.
|
||||
///
|
||||
/// isMatrMultPattern tries to determine whether the following conditions
|
||||
/// are true:
|
||||
/// 1. the partial schedule contains only one statement.
|
||||
/// 2. there are exactly three input dimensions.
|
||||
/// 3. all memory accesses of the statement will have stride 0 or 1, if we
|
||||
/// interchange loops (switch the variable used in the inner loop to
|
||||
/// the outer loop).
|
||||
/// 4. all memory accesses of the statement except from the last one, are
|
||||
/// read memory access and the last one is write memory access.
|
||||
/// 5. all subscripts of the last memory access of the statement don’t
|
||||
/// contain the variable used in the inner loop.
|
||||
/// If this is the case, we could try to use an approach that is similar to
|
||||
/// the one used to get close-to-peak performance of matrix multiplications.
|
||||
///
|
||||
/// @param Node The node to check.
|
||||
static bool isMatrMultPattern(__isl_keep isl_schedule_node *Node);
|
||||
};
|
||||
|
||||
#endif
|
||||
|
@ -166,6 +166,11 @@ static cl::list<int>
|
||||
cl::Hidden, cl::ZeroOrMore, cl::CommaSeparated,
|
||||
cl::cat(PollyCategory));
|
||||
|
||||
static cl::opt<bool>
|
||||
PMBasedOpts("polly-pattern-matching-based-opts",
|
||||
cl::desc("Perform optimizations based on pattern matching"),
|
||||
cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
|
||||
|
||||
/// @brief Create an isl_union_set, which describes the isolate option based
|
||||
/// on IsoalteDomain.
|
||||
///
|
||||
@ -359,11 +364,8 @@ bool ScheduleTreeOptimizer::isTileableBandNode(
|
||||
}
|
||||
|
||||
__isl_give isl_schedule_node *
|
||||
ScheduleTreeOptimizer::optimizeBand(__isl_take isl_schedule_node *Node,
|
||||
void *User) {
|
||||
if (!isTileableBandNode(Node))
|
||||
return Node;
|
||||
|
||||
ScheduleTreeOptimizer::standardBandOpts(__isl_take isl_schedule_node *Node,
|
||||
void *User) {
|
||||
if (FirstLevelTiling)
|
||||
Node = tileNode(Node, "1st level tiling", FirstLevelTileSizes,
|
||||
FirstLevelDefaultTileSize);
|
||||
@ -396,6 +398,110 @@ ScheduleTreeOptimizer::optimizeBand(__isl_take isl_schedule_node *Node,
|
||||
return Node;
|
||||
}
|
||||
|
||||
/// @brief Check whether output dimensions of the map rely on the specified
|
||||
/// input dimension.
|
||||
///
|
||||
/// @param IslMap The isl map to be considered.
|
||||
/// @param DimNum The number of an input dimension to be checked.
|
||||
static bool isInputDimUsed(__isl_take isl_map *IslMap, unsigned DimNum) {
|
||||
auto *CheckedAccessRelation =
|
||||
isl_map_project_out(isl_map_copy(IslMap), isl_dim_in, DimNum, 1);
|
||||
CheckedAccessRelation =
|
||||
isl_map_insert_dims(CheckedAccessRelation, isl_dim_in, DimNum, 1);
|
||||
auto *InputDimsId = isl_map_get_tuple_id(IslMap, isl_dim_in);
|
||||
CheckedAccessRelation =
|
||||
isl_map_set_tuple_id(CheckedAccessRelation, isl_dim_in, InputDimsId);
|
||||
InputDimsId = isl_map_get_tuple_id(IslMap, isl_dim_out);
|
||||
CheckedAccessRelation =
|
||||
isl_map_set_tuple_id(CheckedAccessRelation, isl_dim_out, InputDimsId);
|
||||
auto res = !isl_map_is_equal(CheckedAccessRelation, IslMap);
|
||||
isl_map_free(CheckedAccessRelation);
|
||||
isl_map_free(IslMap);
|
||||
return res;
|
||||
}
|
||||
|
||||
/// @brief Check if the SCoP statement could probably be optimized with
|
||||
/// analytical modeling.
|
||||
///
|
||||
/// containsMatrMult tries to determine whether the following conditions
|
||||
/// are true:
|
||||
/// 1. all memory accesses of the statement will have stride 0 or 1,
|
||||
/// if we interchange loops (switch the variable used in the inner
|
||||
/// loop to the outer loop).
|
||||
/// 2. all memory accesses of the statement except from the last one, are
|
||||
/// read memory access and the last one is write memory access.
|
||||
/// 3. all subscripts of the last memory access of the statement don’t contain
|
||||
/// the variable used in the inner loop.
|
||||
///
|
||||
/// @param PartialSchedule The PartialSchedule that contains a SCoP statement
|
||||
/// to check.
|
||||
static bool containsMatrMult(__isl_keep isl_map *PartialSchedule) {
|
||||
auto InputDimsId = isl_map_get_tuple_id(PartialSchedule, isl_dim_in);
|
||||
auto *ScpStmt = static_cast<ScopStmt *>(isl_id_get_user(InputDimsId));
|
||||
isl_id_free(InputDimsId);
|
||||
if (ScpStmt->size() <= 1)
|
||||
return false;
|
||||
auto MemA = ScpStmt->begin();
|
||||
for (unsigned i = 0; i < ScpStmt->size() - 2 && MemA != ScpStmt->end();
|
||||
i++, MemA++)
|
||||
if (!(*MemA)->isRead() or
|
||||
((*MemA)->isArrayKind() and
|
||||
!((*MemA)->isStrideOne(isl_map_copy(PartialSchedule)) or
|
||||
(*MemA)->isStrideZero(isl_map_copy(PartialSchedule)))))
|
||||
return false;
|
||||
MemA++;
|
||||
if (!(*MemA)->isWrite() or !(*MemA)->isArrayKind() or
|
||||
!((*MemA)->isStrideOne(isl_map_copy(PartialSchedule)) or
|
||||
(*MemA)->isStrideZero(isl_map_copy(PartialSchedule))))
|
||||
return false;
|
||||
auto DimNum = isl_map_dim(PartialSchedule, isl_dim_in);
|
||||
return !isInputDimUsed((*MemA)->getAccessRelation(), DimNum - 1);
|
||||
}
|
||||
|
||||
/// @brief Circular shift of output dimensions of the integer map.
|
||||
///
|
||||
/// @param IslMap The isl map to be modified.
|
||||
static __isl_give isl_map *circularShiftOutputDims(__isl_take isl_map *IslMap) {
|
||||
auto InputDimsId = isl_map_get_tuple_id(IslMap, isl_dim_in);
|
||||
auto DimNum = isl_map_dim(IslMap, isl_dim_out);
|
||||
IslMap = isl_map_move_dims(IslMap, isl_dim_in, 0, isl_dim_out, DimNum - 1, 1);
|
||||
IslMap = isl_map_move_dims(IslMap, isl_dim_out, 0, isl_dim_in, 0, 1);
|
||||
return isl_map_set_tuple_id(IslMap, isl_dim_in, InputDimsId);
|
||||
}
|
||||
|
||||
bool ScheduleTreeOptimizer::isMatrMultPattern(
|
||||
__isl_keep isl_schedule_node *Node) {
|
||||
auto *PartialSchedule =
|
||||
isl_schedule_node_band_get_partial_schedule_union_map(Node);
|
||||
if (isl_union_map_n_map(PartialSchedule) != 1)
|
||||
return false;
|
||||
auto *NewPartialSchedule = isl_map_from_union_map(PartialSchedule);
|
||||
auto DimNum = isl_map_dim(NewPartialSchedule, isl_dim_in);
|
||||
if (DimNum != 3) {
|
||||
isl_map_free(NewPartialSchedule);
|
||||
return false;
|
||||
}
|
||||
NewPartialSchedule = circularShiftOutputDims(NewPartialSchedule);
|
||||
if (containsMatrMult(NewPartialSchedule)) {
|
||||
isl_map_free(NewPartialSchedule);
|
||||
return true;
|
||||
}
|
||||
isl_map_free(NewPartialSchedule);
|
||||
return false;
|
||||
}
|
||||
|
||||
__isl_give isl_schedule_node *
|
||||
ScheduleTreeOptimizer::optimizeBand(__isl_take isl_schedule_node *Node,
|
||||
void *User) {
|
||||
if (!isTileableBandNode(Node))
|
||||
return Node;
|
||||
|
||||
if (PMBasedOpts && isMatrMultPattern(Node))
|
||||
DEBUG(dbgs() << "The matrix multiplication pattern was detected\n");
|
||||
|
||||
return standardBandOpts(Node, User);
|
||||
}
|
||||
|
||||
__isl_give isl_schedule *
|
||||
ScheduleTreeOptimizer::optimizeSchedule(__isl_take isl_schedule *Schedule) {
|
||||
isl_schedule_node *Root = isl_schedule_get_root(Schedule);
|
||||
|
65
polly/test/ScheduleOptimizer/pattern-matching-based-opts.ll
Normal file
65
polly/test/ScheduleOptimizer/pattern-matching-based-opts.ll
Normal file
@ -0,0 +1,65 @@
|
||||
; RUN: opt %loadPolly -polly-opt-isl -debug < %s 2>&1| FileCheck %s
|
||||
; RUN: opt %loadPolly -polly-opt-isl -polly-pattern-matching-based-opts=true -debug < %s 2>&1| FileCheck %s --check-prefix=PATTERN-MATCHING-OPTS
|
||||
; REQUIRES: asserts
|
||||
; CHECK-NOT: The matrix multiplication pattern was detected
|
||||
; PATTERN-MATCHING-OPTS: The matrix multiplication pattern was detected
|
||||
|
||||
define internal void @kernel_gemm(i32 %arg, i32 %arg1, i32 %arg2, double %arg3, double %arg4, [1056 x double]* %arg5, [1024 x double]* %arg6, [1056 x double]* %arg7) {
|
||||
bb:
|
||||
br label %bb8
|
||||
|
||||
bb8: ; preds = %bb39, %bb
|
||||
%tmp = phi i32 [ 0, %bb ], [ %tmp40, %bb39 ]
|
||||
%tmp9 = icmp slt i32 %tmp, 1056
|
||||
br i1 %tmp9, label %bb10, label %bb41
|
||||
|
||||
bb10: ; preds = %bb8
|
||||
br label %bb11
|
||||
|
||||
bb11: ; preds = %bb37, %bb10
|
||||
%tmp12 = phi i32 [ 0, %bb10 ], [ %tmp38, %bb37 ]
|
||||
%tmp13 = icmp slt i32 %tmp12, 1056
|
||||
br i1 %tmp13, label %bb14, label %bb39
|
||||
|
||||
bb14: ; preds = %bb11
|
||||
%tmp15 = sext i32 %tmp12 to i64
|
||||
%tmp16 = sext i32 %tmp to i64
|
||||
%tmp17 = getelementptr inbounds [1056 x double], [1056 x double]* %arg5, i64 %tmp16
|
||||
%tmp18 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp17, i64 0, i64 %tmp15
|
||||
%tmp19 = load double, double* %tmp18, align 8
|
||||
%tmp20 = fmul double %tmp19, %arg4
|
||||
store double %tmp20, double* %tmp18, align 8
|
||||
br label %bb21
|
||||
|
||||
bb21: ; preds = %bb24, %bb14
|
||||
%tmp22 = phi i32 [ 0, %bb14 ], [ %tmp36, %bb24 ]
|
||||
%tmp23 = icmp slt i32 %tmp22, 1024
|
||||
br i1 %tmp23, label %bb24, label %bb37
|
||||
|
||||
bb24: ; preds = %bb21
|
||||
%tmp25 = sext i32 %tmp22 to i64
|
||||
%tmp26 = getelementptr inbounds [1024 x double], [1024 x double]* %arg6, i64 %tmp16
|
||||
%tmp27 = getelementptr inbounds [1024 x double], [1024 x double]* %tmp26, i64 0, i64 %tmp25
|
||||
%tmp28 = load double, double* %tmp27, align 8
|
||||
%tmp29 = fmul double %arg3, %tmp28
|
||||
%tmp30 = getelementptr inbounds [1056 x double], [1056 x double]* %arg7, i64 %tmp25
|
||||
%tmp31 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp30, i64 0, i64 %tmp15
|
||||
%tmp32 = load double, double* %tmp31, align 8
|
||||
%tmp33 = fmul double %tmp29, %tmp32
|
||||
%tmp34 = load double, double* %tmp18, align 8
|
||||
%tmp35 = fadd double %tmp34, %tmp33
|
||||
store double %tmp35, double* %tmp18, align 8
|
||||
%tmp36 = add nsw i32 %tmp22, 1
|
||||
br label %bb21
|
||||
|
||||
bb37: ; preds = %bb21
|
||||
%tmp38 = add nsw i32 %tmp12, 1
|
||||
br label %bb11
|
||||
|
||||
bb39: ; preds = %bb11
|
||||
%tmp40 = add nsw i32 %tmp, 1
|
||||
br label %bb8
|
||||
|
||||
bb41: ; preds = %bb8
|
||||
ret void
|
||||
}
|
@ -0,0 +1,63 @@
|
||||
; RUN: opt %loadPolly -polly-opt-isl -polly-pattern-matching-based-opts=true -debug < %s 2>&1 | FileCheck %s
|
||||
; REQUIRES: asserts
|
||||
; CHECK-NOT: The matrix multiplication pattern was detected
|
||||
|
||||
define internal void @kernel_gemm(i32 %arg, i32 %arg1, i32 %arg2, double %arg3, double %arg4, [1056 x double]* %arg5, [1024 x double]* %arg6, [1056 x double]* %arg7) {
|
||||
bb:
|
||||
br label %bb8
|
||||
|
||||
bb8: ; preds = %bb39, %bb
|
||||
%tmp = phi i32 [ 0, %bb ], [ %tmp40, %bb39 ]
|
||||
%tmp9 = icmp slt i32 %tmp, 1056
|
||||
br i1 %tmp9, label %bb10, label %bb41
|
||||
|
||||
bb10: ; preds = %bb8
|
||||
br label %bb11
|
||||
|
||||
bb11: ; preds = %bb37, %bb10
|
||||
%tmp12 = phi i32 [ 0, %bb10 ], [ %tmp38, %bb37 ]
|
||||
%tmp13 = icmp slt i32 %tmp12, 1056
|
||||
br i1 %tmp13, label %bb14, label %bb39
|
||||
|
||||
bb14: ; preds = %bb11
|
||||
%tmp15 = sext i32 %tmp12 to i64
|
||||
%tmp16 = sext i32 %tmp to i64
|
||||
%tmp17 = getelementptr inbounds [1056 x double], [1056 x double]* %arg5, i64 %tmp16
|
||||
%tmp18 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp17, i64 0, i64 %tmp15
|
||||
%tmp19 = load double, double* %tmp18, align 8
|
||||
%tmp20 = fmul double %tmp19, %arg4
|
||||
store double %tmp20, double* %tmp18, align 8
|
||||
br label %bb21
|
||||
|
||||
bb21: ; preds = %bb24, %bb14
|
||||
%tmp22 = phi i32 [ 0, %bb14 ], [ %tmp36, %bb24 ]
|
||||
%tmp23 = icmp slt i32 %tmp22, 1024
|
||||
br i1 %tmp23, label %bb24, label %bb37
|
||||
|
||||
bb24: ; preds = %bb21
|
||||
%tmp25 = sext i32 %tmp22 to i64
|
||||
%tmp26 = getelementptr inbounds [1024 x double], [1024 x double]* %arg6, i64 %tmp16
|
||||
%tmp27 = getelementptr inbounds [1024 x double], [1024 x double]* %tmp26, i64 0, i64 %tmp25
|
||||
%tmp28 = load double, double* %tmp27, align 8
|
||||
%tmp29 = fmul double %arg3, %tmp28
|
||||
%tmp30 = getelementptr inbounds [1056 x double], [1056 x double]* %arg7, i64 %tmp25
|
||||
%tmp31 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp30, i64 0, i64 %tmp15
|
||||
%tmp32 = load double, double* %tmp31, align 8
|
||||
%tmp33 = fmul double %tmp29, %tmp32
|
||||
%tmp34 = load double, double* %tmp18, align 8
|
||||
%tmp35 = fadd double %tmp34, %tmp33
|
||||
store double %tmp35, double* %tmp18, align 8
|
||||
%tmp36 = add nsw i32 %tmp22, 1
|
||||
br label %bb21
|
||||
|
||||
bb37: ; preds = %bb21
|
||||
%tmp38 = add nsw i32 %tmp12, 2
|
||||
br label %bb11
|
||||
|
||||
bb39: ; preds = %bb11
|
||||
%tmp40 = add nsw i32 %tmp, 1
|
||||
br label %bb8
|
||||
|
||||
bb41: ; preds = %bb8
|
||||
ret void
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user