[mlir][OpenMP] Added ReductionClauseInterface
This patch adds the ReductionClauseInterface and also adds reduction support for the `omp.parallel` operation.

Reviewed By: kiranchandramohan

Differential Revision: https://reviews.llvm.org/D122402
commit fcbf00f098
parent 1f52d02ceb
@@ -205,7 +205,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
   // Create and insert the operation.
   auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
       currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
-      ValueRange(), ValueRange(),
+      /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
+      /*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
       procBindClauseOperand.dyn_cast_or_null<omp::ClauseProcBindKindAttr>());
   // Handle attribute based clauses.
   for (const auto &clause : parallelOpClauseList.v) {

@@ -66,7 +66,7 @@ def OpenMP_PointerLikeType : Type<
 def ParallelOp : OpenMP_Op<"parallel", [
                  AutomaticAllocationScope, AttrSizedOperandSegments,
                  DeclareOpInterfaceMethods<OutlineableOpenMPOpInterface>,
-                 RecursiveSideEffects]> {
+                 RecursiveSideEffects, ReductionClauseInterface]> {
   let summary = "parallel construct";
   let description = [{
     The parallel construct includes a region of code which is to be executed
@@ -83,6 +83,18 @@ def ParallelOp : OpenMP_Op<"parallel", [
     The $allocators_vars and $allocate_vars parameters are a variadic list of values
     that specify the memory allocator to be used to obtain storage for private values.

+    Reductions can be performed in a parallel construct by specifying reduction
+    accumulator variables in `reduction_vars` and symbols referring to reduction
+    declarations in the `reductions` attribute. Each reduction is identified
+    by the accumulator it uses and accumulators must not be repeated in the same
+    reduction. The `omp.reduction` operation accepts the accumulator and a
+    partial value which is considered to be produced by the thread for the
+    given reduction. If multiple values are produced for the same accumulator,
+    i.e. there are multiple `omp.reduction`s, the last value is taken. The
+    reduction declaration specifies how to combine the values from each thread
+    into the final value, which is available in the accumulator after all the
+    threads complete.
+
     The optional $proc_bind_val attribute controls the thread affinity for the execution
     of the parallel region.
   }];
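
To illustrate the description added above, here is a minimal sketch of the resulting IR, modeled on the `parallel_reduction` test added further down in this diff. It assumes a reduction declaration `@add_f32 : f32` exists elsewhere in the module, and the function name is only illustrative:

  func @parallel_reduction_sketch() {
    %c1 = arith.constant 1 : i32
    %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
    // Accumulator %0 is bound to the @add_f32 reduction declaration.
    omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr<f32>) {
      %1 = arith.constant 2.0 : f32
      // Each thread contributes its partial value to the accumulator.
      omp.reduction %1, %0 : !llvm.ptr<f32>
      omp.terminator
    }
    return
  }
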
@@ -91,6 +103,8 @@ def ParallelOp : OpenMP_Op<"parallel", [
              Optional<AnyType>:$num_threads_var,
              Variadic<AnyType>:$allocate_vars,
              Variadic<AnyType>:$allocators_vars,
+             Variadic<OpenMP_PointerLikeType>:$reduction_vars,
+             OptionalAttr<SymbolRefArrayAttr>:$reductions,
              OptionalAttr<ProcBindKindAttr>:$proc_bind_val);

   let regions = (region AnyRegion:$region);
@@ -99,7 +113,11 @@ def ParallelOp : OpenMP_Op<"parallel", [
     OpBuilder<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
   ];
   let assemblyFormat = [{
-    oilist( `if` `(` $if_expr_var `:` type($if_expr_var) `)`
+    oilist( `reduction` `(`
+              custom<ReductionVarList>(
+                $reduction_vars, type($reduction_vars), $reductions
+              ) `)`
+          | `if` `(` $if_expr_var `:` type($if_expr_var) `)`
           | `num_threads` `(` $num_threads_var `:` type($num_threads_var) `)`
           | `allocate` `(`
               custom<AllocateAndAllocator>(
@@ -110,6 +128,12 @@ def ParallelOp : OpenMP_Op<"parallel", [
     ) $region attr-dict
   }];
   let hasVerifier = 1;
+  let extraClassDeclaration = [{
+    // TODO: remove this once emitAccessorPrefix is set to
+    // kEmitAccessorPrefix_Prefixed for the dialect.
+    /// Returns the reduction variables
+    operand_range getReductionVars() { return reduction_vars(); }
+  }];
 }

 def TerminatorOp : OpenMP_Op<"terminator", [Terminator]> {
@@ -156,7 +180,8 @@ def SectionOp : OpenMP_Op<"section", [HasParent<"SectionsOp">]> {
   let assemblyFormat = "$region attr-dict";
 }

-def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments]> {
+def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments,
+                                        ReductionClauseInterface]> {
   let summary = "sections construct";
   let description = [{
     The sections construct is a non-iterative worksharing construct that
@@ -207,6 +232,13 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments]> {

   let hasVerifier = 1;
   let hasRegionVerifier = 1;
+
+  let extraClassDeclaration = [{
+    // TODO: remove this once emitAccessorPrefix is set to
+    // kEmitAccessorPrefix_Prefixed for the dialect.
+    /// Returns the reduction variables
+    operand_range getReductionVars() { return reduction_vars(); }
+  }];
 }

 //===----------------------------------------------------------------------===//
@@ -247,7 +279,7 @@ def SingleOp : OpenMP_Op<"single", [AttrSizedOperandSegments]> {

 def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
                          AllTypesMatch<["lowerBound", "upperBound", "step"]>,
-                         RecursiveSideEffects]> {
+                         RecursiveSideEffects, ReductionClauseInterface]> {
   let summary = "workshare loop construct";
   let description = [{
     The workshare loop construct specifies that the iterations of the loop(s)
@@ -338,6 +370,11 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,

     /// Returns the number of reduction variables.
     unsigned getNumReductionVars() { return reduction_vars().size(); }
+
+    // TODO: remove this once emitAccessorPrefix is set to
+    // kEmitAccessorPrefix_Prefixed for the dialect.
+    /// Returns the reduction variables
+    operand_range getReductionVars() { return reduction_vars(); }
   }];
   let hasCustomAssemblyFormat = 1;
   let assemblyFormat = [{

@@ -31,4 +31,18 @@ def OutlineableOpenMPOpInterface : OpInterface<"OutlineableOpenMPOpInterface"> {
   ];
 }

+def ReductionClauseInterface : OpInterface<"ReductionClauseInterface"> {
+  let description = [{
+    OpenMP operations that support reduction clause have this interface.
+  }];
+
+  let cppNamespace = "::mlir::omp";
+
+  let methods = [
+    InterfaceMethod<
+        "Get reduction vars", "::mlir::Operation::operand_range",
+        "getReductionVars">,
+  ];
+}
+
 #endif // OpenMP_OPS_INTERFACES

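
As a usage sketch (not part of this patch), a consumer of the new interface could walk the enclosing reduction-capable constructs roughly as follows; it mirrors the updated `ReductionOp::verify()` later in the diff, and the helper name and the `accumulator` value are illustrative:

  #include "mlir/Dialect/OpenMP/OpenMPDialect.h"

  using namespace mlir;

  // Returns true if `accumulator` is a reduction variable of any enclosing
  // operation that implements ReductionClauseInterface.
  static bool isReductionAccumulator(Operation *from, Value accumulator) {
    Operation *op =
        from->getParentWithTrait<omp::ReductionClauseInterface::Trait>();
    while (op) {
      for (Value var :
           cast<omp::ReductionClauseInterface>(op).getReductionVars())
        if (var == accumulator)
          return true;
      op = op->getParentWithTrait<omp::ReductionClauseInterface::Trait>();
    }
    return false;
  }
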
@@ -27,6 +27,7 @@

 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
+#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"

 using namespace mlir;
@@ -58,19 +59,6 @@ void OpenMPDialect::initialize() {
   MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
 }

-//===----------------------------------------------------------------------===//
-// ParallelOp
-//===----------------------------------------------------------------------===//
-
-void ParallelOp::build(OpBuilder &builder, OperationState &state,
-                       ArrayRef<NamedAttribute> attributes) {
-  ParallelOp::build(
-      builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
-      /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
-      /*proc_bind_val=*/nullptr);
-  state.addAttributes(attributes);
-}
-
 //===----------------------------------------------------------------------===//
 // Parser and printer for Allocate Clause
 //===----------------------------------------------------------------------===//
@@ -142,13 +130,6 @@ void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
   p << stringifyEnum(attr.getValue());
 }

-LogicalResult ParallelOp::verify() {
-  if (allocate_vars().size() != allocators_vars().size())
-    return emitError(
-        "expected equal sizes for allocate and allocator variables");
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // Parser and printer for Linear Clause
 //===----------------------------------------------------------------------===//
@@ -469,6 +450,27 @@ static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
   return success();
 }

+//===----------------------------------------------------------------------===//
+// ParallelOp
+//===----------------------------------------------------------------------===//
+
+void ParallelOp::build(OpBuilder &builder, OperationState &state,
+                       ArrayRef<NamedAttribute> attributes) {
+  ParallelOp::build(
+      builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
+      /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
+      /*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
+      /*proc_bind_val=*/nullptr);
+  state.addAttributes(attributes);
+}
+
+LogicalResult ParallelOp::verify() {
+  if (allocate_vars().size() != allocators_vars().size())
+    return emitError(
+        "expected equal sizes for allocate and allocator variables");
+  return verifyReductionVarList(*this, reductions(), reduction_vars());
+}
+
 //===----------------------------------------------------------------------===//
 // Verifier for SectionsOp
 //===----------------------------------------------------------------------===//
@@ -709,13 +711,17 @@ LogicalResult ReductionDeclareOp::verifyRegions() {
 }

 LogicalResult ReductionOp::verify() {
-  // TODO: generalize this to an op interface when there is more than one op
-  // that supports reductions.
-  auto container = (*this)->getParentOfType<WsLoopOp>();
-  for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
-    if (container.reduction_vars()[i] == accumulator())
-      return success();
-
+  auto *op = (*this)->getParentWithTrait<ReductionClauseInterface::Trait>();
+  if (!op)
+    return emitOpError() << "must be used within an operation supporting "
+                            "reduction clause interface";
+  while (op) {
+    for (const auto &var :
+         cast<ReductionClauseInterface>(op).getReductionVars())
+      if (var == accumulator())
+        return success();
+    op = op->getParentWithTrait<ReductionClauseInterface::Trait>();
+  }
   return emitOpError() << "the accumulator is not used by the parent";
 }

@@ -59,7 +59,7 @@ func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32)
   // CHECK: omp.parallel num_threads(%{{.*}} : si32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
   "omp.parallel"(%num_threads, %data_var, %data_var) ({
     omp.terminator
-  }) {operand_segment_sizes = dense<[0,1,1,1]>: vector<4xi32>} : (si32, memref<i32>, memref<i32>) -> ()
+  }) {operand_segment_sizes = dense<[0,1,1,1,0]> : vector<5xi32>} : (si32, memref<i32>, memref<i32>) -> ()

   // CHECK: omp.barrier
   omp.barrier
@@ -68,22 +68,22 @@ func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32)
   // CHECK: omp.parallel if(%{{.*}}) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
   "omp.parallel"(%if_cond, %data_var, %data_var) ({
     omp.terminator
-  }) {operand_segment_sizes = dense<[1,0,1,1]> : vector<4xi32>} : (i1, memref<i32>, memref<i32>) -> ()
+  }) {operand_segment_sizes = dense<[1,0,1,1,0]> : vector<5xi32>} : (i1, memref<i32>, memref<i32>) -> ()

   // test without allocate
   // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32)
   "omp.parallel"(%if_cond, %num_threads) ({
     omp.terminator
-  }) {operand_segment_sizes = dense<[1,1,0,0]> : vector<4xi32>} : (i1, si32) -> ()
+  }) {operand_segment_sizes = dense<[1,1,0,0,0]> : vector<5xi32>} : (i1, si32) -> ()

     omp.terminator
-  }) {operand_segment_sizes = dense<[1,1,1,1]> : vector<4xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, si32, memref<i32>, memref<i32>) -> ()
+  }) {operand_segment_sizes = dense<[1,1,1,1,0]> : vector<5xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, si32, memref<i32>, memref<i32>) -> ()

   // test with multiple parameters for single variadic argument
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
   "omp.parallel" (%data_var, %data_var) ({
     omp.terminator
-  }) {operand_segment_sizes = dense<[0,0,1,1]> : vector<4xi32>} : (memref<i32>, memref<i32>) -> ()
+  }) {operand_segment_sizes = dense<[0,0,1,1,0]> : vector<5xi32>} : (memref<i32>, memref<i32>) -> ()

   return
 }
@@ -407,7 +407,8 @@ atomic {
   omp.yield
 }

-func @reduction(%lb : index, %ub : index, %step : index) {
+// CHECK-LABEL: func @wsloop_reduction
+func @wsloop_reduction(%lb : index, %ub : index, %step : index) {
   %c1 = arith.constant 1 : i32
   %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
   // CHECK: reduction(@add_f32 -> %{{.+}} : !llvm.ptr<f32>)
@@ -421,6 +422,65 @@ func @reduction(%lb : index, %ub : index, %step : index) {
   return
 }

+// CHECK-LABEL: func @parallel_reduction
+func @parallel_reduction() {
+  %c1 = arith.constant 1 : i32
+  %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+  // CHECK: omp.parallel reduction(@add_f32 -> {{.+}} : !llvm.ptr<f32>)
+  omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr<f32>) {
+    %1 = arith.constant 2.0 : f32
+    // CHECK: omp.reduction %{{.+}}, %{{.+}}
+    omp.reduction %1, %0 : !llvm.ptr<f32>
+    omp.terminator
+  }
+  return
+}
+
+// CHECK: func @parallel_wsloop_reduction
+func @parallel_wsloop_reduction(%lb : index, %ub : index, %step : index) {
+  %c1 = arith.constant 1 : i32
+  %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+  // CHECK: omp.parallel reduction(@add_f32 -> %{{.+}} : !llvm.ptr<f32>) {
+  omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr<f32>) {
+    // CHECK: omp.wsloop for (%{{.+}}) : index = (%{{.+}}) to (%{{.+}}) step (%{{.+}})
+    omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+      %1 = arith.constant 2.0 : f32
+      // CHECK: omp.reduction %{{.+}}, %{{.+}} : !llvm.ptr<f32>
+      omp.reduction %1, %0 : !llvm.ptr<f32>
+      // CHECK: omp.yield
+      omp.yield
+    }
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func @sections_reduction
+func @sections_reduction() {
+  %c1 = arith.constant 1 : i32
+  %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+  // CHECK: omp.sections reduction(@add_f32 -> {{.+}} : !llvm.ptr<f32>)
+  omp.sections reduction(@add_f32 -> %0 : !llvm.ptr<f32>) {
+    // CHECK: omp.section
+    omp.section {
+      %1 = arith.constant 2.0 : f32
+      // CHECK: omp.reduction %{{.+}}, %{{.+}}
+      omp.reduction %1, %0 : !llvm.ptr<f32>
+      omp.terminator
+    }
+    // CHECK: omp.section
+    omp.section {
+      %1 = arith.constant 3.0 : f32
+      // CHECK: omp.reduction %{{.+}}, %{{.+}}
+      omp.reduction %1, %0 : !llvm.ptr<f32>
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
 // CHECK: omp.reduction.declare
 // CHECK-LABEL: @add2_f32
 omp.reduction.declare @add2_f32 : f32
@@ -438,9 +498,10 @@ combiner {
 }
 // CHECK-NOT: atomic

-func @reduction2(%lb : index, %ub : index, %step : index) {
+// CHECK-LABEL: func @wsloop_reduction2
+func @wsloop_reduction2(%lb : index, %ub : index, %step : index) {
   %0 = memref.alloca() : memref<1xf32>
-  // CHECK: reduction
+  // CHECK: omp.wsloop reduction(@add2_f32 -> %{{.+}} : memref<1xf32>)
   omp.wsloop reduction(@add2_f32 -> %0 : memref<1xf32>)
   for (%iv) : index = (%lb) to (%ub) step (%step) {
     %1 = arith.constant 2.0 : f32
@@ -451,6 +512,61 @@ func @reduction2(%lb : index, %ub : index, %step : index) {
   return
 }

+// CHECK-LABEL: func @parallel_reduction2
+func @parallel_reduction2() {
+  %0 = memref.alloca() : memref<1xf32>
+  // CHECK: omp.parallel reduction(@add2_f32 -> %{{.+}} : memref<1xf32>)
+  omp.parallel reduction(@add2_f32 -> %0 : memref<1xf32>) {
+    %1 = arith.constant 2.0 : f32
+    // CHECK: omp.reduction
+    omp.reduction %1, %0 : memref<1xf32>
+    omp.terminator
+  }
+  return
+}
+
+// CHECK: func @parallel_wsloop_reduction2
+func @parallel_wsloop_reduction2(%lb : index, %ub : index, %step : index) {
+  %c1 = arith.constant 1 : i32
+  %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+  // CHECK: omp.parallel reduction(@add2_f32 -> %{{.+}} : !llvm.ptr<f32>) {
+  omp.parallel reduction(@add2_f32 -> %0 : !llvm.ptr<f32>) {
+    // CHECK: omp.wsloop for (%{{.+}}) : index = (%{{.+}}) to (%{{.+}}) step (%{{.+}})
+    omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+      %1 = arith.constant 2.0 : f32
+      // CHECK: omp.reduction %{{.+}}, %{{.+}} : !llvm.ptr<f32>
+      omp.reduction %1, %0 : !llvm.ptr<f32>
+      // CHECK: omp.yield
+      omp.yield
+    }
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func @sections_reduction2
+func @sections_reduction2() {
+  %0 = memref.alloca() : memref<1xf32>
+  // CHECK: omp.sections reduction(@add2_f32 -> %{{.+}} : memref<1xf32>)
+  omp.sections reduction(@add2_f32 -> %0 : memref<1xf32>) {
+    omp.section {
+      %1 = arith.constant 2.0 : f32
+      // CHECK: omp.reduction
+      omp.reduction %1, %0 : memref<1xf32>
+      omp.terminator
+    }
+    omp.section {
+      %1 = arith.constant 2.0 : f32
+      // CHECK: omp.reduction
+      omp.reduction %1, %0 : memref<1xf32>
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
 // CHECK: omp.critical.declare @mutex1 hint(uncontended)
 omp.critical.declare @mutex1 hint(uncontended)
 // CHECK: omp.critical.declare @mutex2 hint(contended)