mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-02-26 05:18:46 +00:00
[flang][openmp] Parallel reduction FIR lowering
This patch extends the logic for lowering loop construct reductions to parallel block reductions.

Reviewed By: kiranchandramohan

Differential Revision: https://reviews.llvm.org/D154182
This commit is contained in:
parent
60bb4bafe2
commit
9bf5093623
@ -2172,6 +2172,8 @@ private:
|
||||
|
||||
const Fortran::parser::OpenMPLoopConstruct *ompLoop =
|
||||
std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u);
|
||||
const Fortran::parser::OpenMPBlockConstruct *ompBlock =
|
||||
std::get_if<Fortran::parser::OpenMPBlockConstruct>(&omp.u);
|
||||
|
||||
// If loop is part of an OpenMP Construct then the OpenMP dialect
|
||||
// workshare loop operation has already been created. Only the
|
||||
@ -2196,8 +2198,15 @@ private:
|
||||
for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
|
||||
genFIR(e);
|
||||
|
||||
if (ompLoop)
|
||||
if (ompLoop) {
|
||||
genOpenMPReduction(*this, *loopOpClauseList);
|
||||
} else if (ompBlock) {
|
||||
const auto &blockStart =
|
||||
std::get<Fortran::parser::OmpBeginBlockDirective>(ompBlock->t);
|
||||
const auto &blockClauses =
|
||||
std::get<Fortran::parser::OmpClauseList>(blockStart.t);
|
||||
genOpenMPReduction(*this, blockClauses);
|
||||
}
|
||||
|
||||
localSymbols.popScope();
|
||||
builder->restoreInsertionPoint(insertPt);
|
||||
|
@ -1154,209 +1154,6 @@ createCombinedParallelOp(Fortran::lower::AbstractConverter &converter,
|
||||
/*isCombined=*/true);
|
||||
}
|
||||
|
||||
static void
|
||||
genOMP(Fortran::lower::AbstractConverter &converter,
|
||||
Fortran::lower::pft::Evaluation &eval,
|
||||
const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
|
||||
const auto &beginBlockDirective =
|
||||
std::get<Fortran::parser::OmpBeginBlockDirective>(blockConstruct.t);
|
||||
const auto &blockDirective =
|
||||
std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t);
|
||||
const auto &endBlockDirective =
|
||||
std::get<Fortran::parser::OmpEndBlockDirective>(blockConstruct.t);
|
||||
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
||||
mlir::Location currentLocation = converter.genLocation(blockDirective.source);
|
||||
|
||||
Fortran::lower::StatementContext stmtCtx;
|
||||
llvm::ArrayRef<mlir::Type> argTy;
|
||||
mlir::Value ifClauseOperand, numThreadsClauseOperand, finalClauseOperand,
|
||||
priorityClauseOperand;
|
||||
mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
|
||||
SmallVector<Value> allocateOperands, allocatorOperands, dependOperands;
|
||||
SmallVector<Attribute> dependTypeOperands;
|
||||
mlir::UnitAttr nowaitAttr, untiedAttr, mergeableAttr;
|
||||
|
||||
const auto &opClauseList =
|
||||
std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t);
|
||||
for (const auto &clause : opClauseList.v) {
|
||||
mlir::Location clauseLocation = converter.genLocation(clause.source);
|
||||
if (const auto &ifClause =
|
||||
std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
|
||||
ifClauseOperand =
|
||||
getIfClauseOperand(converter, stmtCtx, ifClause, clauseLocation);
|
||||
} else if (const auto &numThreadsClause =
|
||||
std::get_if<Fortran::parser::OmpClause::NumThreads>(
|
||||
&clause.u)) {
|
||||
// OMPIRBuilder expects `NUM_THREAD` clause as a `Value`.
|
||||
numThreadsClauseOperand = fir::getBase(converter.genExprValue(
|
||||
*Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx));
|
||||
} else if (const auto &procBindClause =
|
||||
std::get_if<Fortran::parser::OmpClause::ProcBind>(
|
||||
&clause.u)) {
|
||||
procBindKindAttr = genProcBindKindAttr(firOpBuilder, procBindClause);
|
||||
} else if (const auto &allocateClause =
|
||||
std::get_if<Fortran::parser::OmpClause::Allocate>(
|
||||
&clause.u)) {
|
||||
genAllocateClause(converter, allocateClause->v, allocatorOperands,
|
||||
allocateOperands);
|
||||
} else if (std::get_if<Fortran::parser::OmpClause::Private>(&clause.u) ||
|
||||
std::get_if<Fortran::parser::OmpClause::Firstprivate>(
|
||||
&clause.u) ||
|
||||
std::get_if<Fortran::parser::OmpClause::Copyin>(&clause.u)) {
|
||||
// Privatisation and copyin clauses are handled elsewhere.
|
||||
continue;
|
||||
} else if (std::get_if<Fortran::parser::OmpClause::Shared>(&clause.u)) {
|
||||
// Shared is the default behavior in the IR, so no handling is required.
|
||||
continue;
|
||||
} else if (const auto &defaultClause =
|
||||
std::get_if<Fortran::parser::OmpClause::Default>(
|
||||
&clause.u)) {
|
||||
if ((defaultClause->v.v ==
|
||||
Fortran::parser::OmpDefaultClause::Type::Shared) ||
|
||||
(defaultClause->v.v ==
|
||||
Fortran::parser::OmpDefaultClause::Type::None)) {
|
||||
// Default clause with shared or none do not require any handling since
|
||||
// Shared is the default behavior in the IR and None is only required
|
||||
// for semantic checks.
|
||||
continue;
|
||||
}
|
||||
} else if (std::get_if<Fortran::parser::OmpClause::Threads>(&clause.u)) {
|
||||
// Nothing needs to be done for threads clause.
|
||||
continue;
|
||||
} else if (std::get_if<Fortran::parser::OmpClause::Map>(&clause.u)) {
|
||||
// Map clause is exclusive to Target Data directives. It is handled
|
||||
// as part of the TargetOp creation.
|
||||
continue;
|
||||
} else if (std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
|
||||
&clause.u)) {
|
||||
// UseDevicePtr clause is exclusive to Target Data directives. It is
|
||||
// handled as part of the TargetOp creation.
|
||||
continue;
|
||||
} else if (std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
|
||||
&clause.u)) {
|
||||
// UseDeviceAddr clause is exclusive to Target Data directives. It is
|
||||
// handled as part of the TargetOp creation.
|
||||
continue;
|
||||
} else if (std::get_if<Fortran::parser::OmpClause::ThreadLimit>(
|
||||
&clause.u)) {
|
||||
// Handled as part of TargetOp creation.
|
||||
continue;
|
||||
} else if (const auto &finalClause =
|
||||
std::get_if<Fortran::parser::OmpClause::Final>(&clause.u)) {
|
||||
mlir::Value finalVal = fir::getBase(converter.genExprValue(
|
||||
*Fortran::semantics::GetExpr(finalClause->v), stmtCtx));
|
||||
finalClauseOperand = firOpBuilder.createConvert(
|
||||
currentLocation, firOpBuilder.getI1Type(), finalVal);
|
||||
} else if (std::get_if<Fortran::parser::OmpClause::Untied>(&clause.u)) {
|
||||
untiedAttr = firOpBuilder.getUnitAttr();
|
||||
} else if (std::get_if<Fortran::parser::OmpClause::Mergeable>(&clause.u)) {
|
||||
mergeableAttr = firOpBuilder.getUnitAttr();
|
||||
} else if (const auto &priorityClause =
|
||||
std::get_if<Fortran::parser::OmpClause::Priority>(
|
||||
&clause.u)) {
|
||||
priorityClauseOperand = fir::getBase(converter.genExprValue(
|
||||
*Fortran::semantics::GetExpr(priorityClause->v), stmtCtx));
|
||||
} else if (std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) {
|
||||
TODO(currentLocation,
|
||||
"Reduction in OpenMP " +
|
||||
llvm::omp::getOpenMPDirectiveName(blockDirective.v) +
|
||||
" construct");
|
||||
} else if (const auto &dependClause =
|
||||
std::get_if<Fortran::parser::OmpClause::Depend>(&clause.u)) {
|
||||
const std::list<Fortran::parser::Designator> &depVal =
|
||||
std::get<std::list<Fortran::parser::Designator>>(
|
||||
std::get<Fortran::parser::OmpDependClause::InOut>(
|
||||
dependClause->v.u)
|
||||
.t);
|
||||
omp::ClauseTaskDependAttr dependTypeOperand =
|
||||
genDependKindAttr(firOpBuilder, dependClause);
|
||||
dependTypeOperands.insert(dependTypeOperands.end(), depVal.size(),
|
||||
dependTypeOperand);
|
||||
for (const Fortran::parser::Designator &ompObject : depVal) {
|
||||
Fortran::semantics::Symbol *sym = nullptr;
|
||||
std::visit(
|
||||
Fortran::common::visitors{
|
||||
[&](const Fortran::parser::DataRef &designator) {
|
||||
if (const Fortran::parser::Name *name =
|
||||
std::get_if<Fortran::parser::Name>(&designator.u)) {
|
||||
sym = name->symbol;
|
||||
} else if (std::get_if<Fortran::common::Indirection<
|
||||
Fortran::parser::ArrayElement>>(
|
||||
&designator.u)) {
|
||||
TODO(converter.getCurrentLocation(),
|
||||
"array sections not supported for task depend");
|
||||
}
|
||||
},
|
||||
[&](const Fortran::parser::Substring &designator) {
|
||||
TODO(converter.getCurrentLocation(),
|
||||
"substring not supported for task depend");
|
||||
}},
|
||||
(ompObject).u);
|
||||
const mlir::Value variable = converter.getSymbolAddress(*sym);
|
||||
dependOperands.push_back(((variable)));
|
||||
}
|
||||
} else {
|
||||
TODO(converter.getCurrentLocation(), "OpenMP Block construct clause");
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto &clause :
|
||||
std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t).v) {
|
||||
if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u))
|
||||
nowaitAttr = firOpBuilder.getUnitAttr();
|
||||
}
|
||||
|
||||
if (blockDirective.v == llvm::omp::OMPD_parallel) {
|
||||
// Create and insert the operation.
|
||||
auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
|
||||
currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
|
||||
allocateOperands, allocatorOperands, /*reduction_vars=*/ValueRange(),
|
||||
/*reductions=*/nullptr, procBindKindAttr);
|
||||
createBodyOfOp<omp::ParallelOp>(parallelOp, converter, currentLocation,
|
||||
eval, &opClauseList);
|
||||
} else if (blockDirective.v == llvm::omp::OMPD_master) {
|
||||
auto masterOp =
|
||||
firOpBuilder.create<mlir::omp::MasterOp>(currentLocation, argTy);
|
||||
createBodyOfOp<omp::MasterOp>(masterOp, converter, currentLocation, eval);
|
||||
} else if (blockDirective.v == llvm::omp::OMPD_single) {
|
||||
auto singleOp = firOpBuilder.create<mlir::omp::SingleOp>(
|
||||
currentLocation, allocateOperands, allocatorOperands, nowaitAttr);
|
||||
createBodyOfOp<omp::SingleOp>(singleOp, converter, currentLocation, eval,
|
||||
&opClauseList);
|
||||
} else if (blockDirective.v == llvm::omp::OMPD_ordered) {
|
||||
auto orderedOp = firOpBuilder.create<mlir::omp::OrderedRegionOp>(
|
||||
currentLocation, /*simd=*/false);
|
||||
createBodyOfOp<omp::OrderedRegionOp>(orderedOp, converter, currentLocation,
|
||||
eval);
|
||||
} else if (blockDirective.v == llvm::omp::OMPD_task) {
|
||||
auto taskOp = firOpBuilder.create<mlir::omp::TaskOp>(
|
||||
currentLocation, ifClauseOperand, finalClauseOperand, untiedAttr,
|
||||
mergeableAttr, /*in_reduction_vars=*/ValueRange(),
|
||||
/*in_reductions=*/nullptr, priorityClauseOperand,
|
||||
dependTypeOperands.empty()
|
||||
? nullptr
|
||||
: mlir::ArrayAttr::get(firOpBuilder.getContext(),
|
||||
dependTypeOperands),
|
||||
dependOperands, allocateOperands, allocatorOperands);
|
||||
createBodyOfOp(taskOp, converter, currentLocation, eval, &opClauseList);
|
||||
} else if (blockDirective.v == llvm::omp::OMPD_taskgroup) {
|
||||
// TODO: Add task_reduction support
|
||||
auto taskGroupOp = firOpBuilder.create<mlir::omp::TaskGroupOp>(
|
||||
currentLocation, /*task_reduction_vars=*/ValueRange(),
|
||||
/*task_reductions=*/nullptr, allocateOperands, allocatorOperands);
|
||||
createBodyOfOp(taskGroupOp, converter, currentLocation, eval,
|
||||
&opClauseList);
|
||||
} else if (blockDirective.v == llvm::omp::OMPD_target) {
|
||||
createTargetOp(converter, opClauseList, blockDirective.v, currentLocation,
|
||||
&eval);
|
||||
} else if (blockDirective.v == llvm::omp::OMPD_target_data) {
|
||||
createTargetOp(converter, opClauseList, blockDirective.v, currentLocation,
|
||||
&eval);
|
||||
} else {
|
||||
TODO(currentLocation, "Unhandled block directive");
|
||||
}
|
||||
}
|
||||
|
||||
/// This function returns the identity value of the operator \p reductionOpName.
|
||||
/// For example:
|
||||
/// 0 + x = x,
|
||||
@ -1691,6 +1488,97 @@ static std::string getReductionName(
|
||||
return getReductionName(reductionName, ty);
|
||||
}
|
||||
|
||||
/// Creates a reduction declaration and associates it with an
|
||||
/// OpenMP block directive
|
||||
static void
|
||||
addReductionDecl(mlir::Location currentLocation,
|
||||
Fortran::lower::AbstractConverter &converter,
|
||||
const Fortran::parser::OmpReductionClause &reduction,
|
||||
SmallVector<Value> &reductionVars,
|
||||
SmallVector<Attribute> &reductionDeclSymbols) {
|
||||
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
||||
omp::ReductionDeclareOp decl;
|
||||
const auto &redOperator{
|
||||
std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
|
||||
const auto &objectList{std::get<Fortran::parser::OmpObjectList>(reduction.t)};
|
||||
if (const auto &redDefinedOp =
|
||||
std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
|
||||
const auto &intrinsicOp{
|
||||
std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
|
||||
redDefinedOp->u)};
|
||||
switch (intrinsicOp) {
|
||||
case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
|
||||
case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
|
||||
case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
|
||||
case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
|
||||
case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
|
||||
case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
|
||||
break;
|
||||
|
||||
default:
|
||||
TODO(currentLocation,
|
||||
"Reduction of some intrinsic operators is not supported");
|
||||
break;
|
||||
}
|
||||
for (const auto &ompObject : objectList.v) {
|
||||
if (const auto *name{
|
||||
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
|
||||
if (const auto *symbol{name->symbol}) {
|
||||
mlir::Value symVal = converter.getSymbolAddress(*symbol);
|
||||
mlir::Type redType =
|
||||
symVal.getType().cast<fir::ReferenceType>().getEleTy();
|
||||
reductionVars.push_back(symVal);
|
||||
if (redType.isa<fir::LogicalType>())
|
||||
decl = createReductionDecl(
|
||||
firOpBuilder,
|
||||
getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
|
||||
intrinsicOp, redType, currentLocation);
|
||||
else if (redType.isIntOrIndexOrFloat()) {
|
||||
decl = createReductionDecl(firOpBuilder,
|
||||
getReductionName(intrinsicOp, redType),
|
||||
intrinsicOp, redType, currentLocation);
|
||||
} else {
|
||||
TODO(currentLocation, "Reduction of some types is not supported");
|
||||
}
|
||||
reductionDeclSymbols.push_back(
|
||||
SymbolRefAttr::get(firOpBuilder.getContext(), decl.getSymName()));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (auto reductionIntrinsic =
|
||||
std::get_if<Fortran::parser::ProcedureDesignator>(
|
||||
&redOperator.u)) {
|
||||
if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
|
||||
reductionIntrinsic)}) {
|
||||
if ((name->source != "max") && (name->source != "min") &&
|
||||
(name->source != "ior") && (name->source != "ieor") &&
|
||||
(name->source != "iand")) {
|
||||
TODO(currentLocation,
|
||||
"Reduction of intrinsic procedures is not supported");
|
||||
}
|
||||
std::string intrinsicOp = name->ToString();
|
||||
for (const auto &ompObject : objectList.v) {
|
||||
if (const auto *name{
|
||||
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
|
||||
if (const auto *symbol{name->symbol}) {
|
||||
mlir::Value symVal = converter.getSymbolAddress(*symbol);
|
||||
mlir::Type redType =
|
||||
symVal.getType().cast<fir::ReferenceType>().getEleTy();
|
||||
reductionVars.push_back(symVal);
|
||||
assert(redType.isIntOrIndexOrFloat() &&
|
||||
"Unsupported reduction type");
|
||||
decl = createReductionDecl(
|
||||
firOpBuilder, getReductionName(intrinsicOp, redType),
|
||||
*reductionIntrinsic, redType, currentLocation);
|
||||
reductionDeclSymbols.push_back(SymbolRefAttr::get(
|
||||
firOpBuilder.getContext(), decl.getSymName()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void genOMP(Fortran::lower::AbstractConverter &converter,
|
||||
Fortran::lower::pft::Evaluation &eval,
|
||||
const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
|
||||
@ -1786,88 +1674,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
|
||||
} else if (const auto &reductionClause =
|
||||
std::get_if<Fortran::parser::OmpClause::Reduction>(
|
||||
&clause.u)) {
|
||||
omp::ReductionDeclareOp decl;
|
||||
const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>(
|
||||
reductionClause->v.t)};
|
||||
const auto &objectList{
|
||||
std::get<Fortran::parser::OmpObjectList>(reductionClause->v.t)};
|
||||
if (const auto &redDefinedOp =
|
||||
std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
|
||||
const auto &intrinsicOp{
|
||||
std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
|
||||
redDefinedOp->u)};
|
||||
switch (intrinsicOp) {
|
||||
case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
|
||||
case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
|
||||
case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
|
||||
case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
|
||||
case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
|
||||
case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
|
||||
break;
|
||||
|
||||
default:
|
||||
TODO(currentLocation,
|
||||
"Reduction of some intrinsic operators is not supported");
|
||||
break;
|
||||
}
|
||||
for (const auto &ompObject : objectList.v) {
|
||||
if (const auto *name{
|
||||
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
|
||||
if (const auto *symbol{name->symbol}) {
|
||||
mlir::Value symVal = converter.getSymbolAddress(*symbol);
|
||||
mlir::Type redType =
|
||||
symVal.getType().cast<fir::ReferenceType>().getEleTy();
|
||||
reductionVars.push_back(symVal);
|
||||
if (redType.isa<fir::LogicalType>())
|
||||
decl = createReductionDecl(
|
||||
firOpBuilder,
|
||||
getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
|
||||
intrinsicOp, redType, currentLocation);
|
||||
else if (redType.isIntOrIndexOrFloat()) {
|
||||
decl = createReductionDecl(
|
||||
firOpBuilder, getReductionName(intrinsicOp, redType),
|
||||
intrinsicOp, redType, currentLocation);
|
||||
} else {
|
||||
TODO(currentLocation,
|
||||
"Reduction of some types is not supported");
|
||||
}
|
||||
reductionDeclSymbols.push_back(SymbolRefAttr::get(
|
||||
firOpBuilder.getContext(), decl.getSymName()));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (auto reductionIntrinsic =
|
||||
std::get_if<Fortran::parser::ProcedureDesignator>(
|
||||
&redOperator.u)) {
|
||||
if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
|
||||
reductionIntrinsic)}) {
|
||||
if ((name->source != "max") && (name->source != "min") &&
|
||||
(name->source != "ior") && (name->source != "ieor") &&
|
||||
(name->source != "iand")) {
|
||||
TODO(currentLocation,
|
||||
"Reduction of intrinsic procedures is not supported");
|
||||
}
|
||||
std::string intrinsicOp = name->ToString();
|
||||
for (const auto &ompObject : objectList.v) {
|
||||
if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
|
||||
ompObject)}) {
|
||||
if (const auto *symbol{name->symbol}) {
|
||||
mlir::Value symVal = converter.getSymbolAddress(*symbol);
|
||||
mlir::Type redType =
|
||||
symVal.getType().cast<fir::ReferenceType>().getEleTy();
|
||||
reductionVars.push_back(symVal);
|
||||
assert(redType.isIntOrIndexOrFloat() &&
|
||||
"Unsupported reduction type");
|
||||
decl = createReductionDecl(
|
||||
firOpBuilder, getReductionName(intrinsicOp, redType),
|
||||
*reductionIntrinsic, redType, currentLocation);
|
||||
reductionDeclSymbols.push_back(SymbolRefAttr::get(
|
||||
firOpBuilder.getContext(), decl.getSymName()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
addReductionDecl(currentLocation, converter, reductionClause->v,
|
||||
reductionVars, reductionDeclSymbols);
|
||||
} else if (const auto &simdlenClause =
|
||||
std::get_if<Fortran::parser::OmpClause::Simdlen>(
|
||||
&clause.u)) {
|
||||
@ -2003,6 +1811,214 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
|
||||
&loopOpClauseList, iv, /*outer=*/false, &dsp);
|
||||
}
|
||||
|
||||
static void
|
||||
genOMP(Fortran::lower::AbstractConverter &converter,
|
||||
Fortran::lower::pft::Evaluation &eval,
|
||||
const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
|
||||
const auto &beginBlockDirective =
|
||||
std::get<Fortran::parser::OmpBeginBlockDirective>(blockConstruct.t);
|
||||
const auto &blockDirective =
|
||||
std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t);
|
||||
const auto &endBlockDirective =
|
||||
std::get<Fortran::parser::OmpEndBlockDirective>(blockConstruct.t);
|
||||
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
||||
mlir::Location currentLocation = converter.genLocation(blockDirective.source);
|
||||
|
||||
Fortran::lower::StatementContext stmtCtx;
|
||||
llvm::ArrayRef<mlir::Type> argTy;
|
||||
mlir::Value ifClauseOperand, numThreadsClauseOperand, finalClauseOperand,
|
||||
priorityClauseOperand;
|
||||
mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
|
||||
SmallVector<Value> allocateOperands, allocatorOperands, dependOperands,
|
||||
reductionVars;
|
||||
SmallVector<Attribute> dependTypeOperands, reductionDeclSymbols;
|
||||
mlir::UnitAttr nowaitAttr, untiedAttr, mergeableAttr;
|
||||
|
||||
const auto &opClauseList =
|
||||
std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t);
|
||||
for (const auto &clause : opClauseList.v) {
|
||||
mlir::Location clauseLocation = converter.genLocation(clause.source);
|
||||
if (const auto &ifClause =
|
||||
std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
|
||||
ifClauseOperand =
|
||||
getIfClauseOperand(converter, stmtCtx, ifClause, clauseLocation);
|
||||
} else if (const auto &numThreadsClause =
|
||||
std::get_if<Fortran::parser::OmpClause::NumThreads>(
|
||||
&clause.u)) {
|
||||
// OMPIRBuilder expects `NUM_THREAD` clause as a `Value`.
|
||||
numThreadsClauseOperand = fir::getBase(converter.genExprValue(
|
||||
*Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx));
|
||||
} else if (const auto &procBindClause =
|
||||
std::get_if<Fortran::parser::OmpClause::ProcBind>(
|
||||
&clause.u)) {
|
||||
procBindKindAttr = genProcBindKindAttr(firOpBuilder, procBindClause);
|
||||
} else if (const auto &allocateClause =
|
||||
std::get_if<Fortran::parser::OmpClause::Allocate>(
|
||||
&clause.u)) {
|
||||
genAllocateClause(converter, allocateClause->v, allocatorOperands,
|
||||
allocateOperands);
|
||||
} else if (std::get_if<Fortran::parser::OmpClause::Private>(&clause.u) ||
|
||||
std::get_if<Fortran::parser::OmpClause::Firstprivate>(
|
||||
&clause.u) ||
|
||||
std::get_if<Fortran::parser::OmpClause::Copyin>(&clause.u)) {
|
||||
// Privatisation and copyin clauses are handled elsewhere.
|
||||
continue;
|
||||
} else if (std::get_if<Fortran::parser::OmpClause::Shared>(&clause.u)) {
|
||||
// Shared is the default behavior in the IR, so no handling is required.
|
||||
continue;
|
||||
} else if (const auto &defaultClause =
|
||||
std::get_if<Fortran::parser::OmpClause::Default>(
|
||||
&clause.u)) {
|
||||
if ((defaultClause->v.v ==
|
||||
Fortran::parser::OmpDefaultClause::Type::Shared) ||
|
||||
(defaultClause->v.v ==
|
||||
Fortran::parser::OmpDefaultClause::Type::None)) {
|
||||
// Default clause with shared or none do not require any handling since
|
||||
// Shared is the default behavior in the IR and None is only required
|
||||
// for semantic checks.
|
||||
continue;
|
||||
}
|
||||
} else if (std::get_if<Fortran::parser::OmpClause::Threads>(&clause.u)) {
|
||||
// Nothing needs to be done for threads clause.
|
||||
continue;
|
||||
} else if (std::get_if<Fortran::parser::OmpClause::Map>(&clause.u)) {
|
||||
// Map clause is exclusive to Target Data directives. It is handled
|
||||
// as part of the TargetOp creation.
|
||||
continue;
|
||||
} else if (std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
|
||||
&clause.u)) {
|
||||
// UseDevicePtr clause is exclusive to Target Data directives. It is
|
||||
// handled as part of the TargetOp creation.
|
||||
continue;
|
||||
} else if (std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
|
||||
&clause.u)) {
|
||||
// UseDeviceAddr clause is exclusive to Target Data directives. It is
|
||||
// handled as part of the TargetOp creation.
|
||||
continue;
|
||||
} else if (std::get_if<Fortran::parser::OmpClause::ThreadLimit>(
|
||||
&clause.u)) {
|
||||
// Handled as part of TargetOp creation.
|
||||
continue;
|
||||
} else if (const auto &finalClause =
|
||||
std::get_if<Fortran::parser::OmpClause::Final>(&clause.u)) {
|
||||
mlir::Value finalVal = fir::getBase(converter.genExprValue(
|
||||
*Fortran::semantics::GetExpr(finalClause->v), stmtCtx));
|
||||
finalClauseOperand = firOpBuilder.createConvert(
|
||||
currentLocation, firOpBuilder.getI1Type(), finalVal);
|
||||
} else if (std::get_if<Fortran::parser::OmpClause::Untied>(&clause.u)) {
|
||||
untiedAttr = firOpBuilder.getUnitAttr();
|
||||
} else if (std::get_if<Fortran::parser::OmpClause::Mergeable>(&clause.u)) {
|
||||
mergeableAttr = firOpBuilder.getUnitAttr();
|
||||
} else if (const auto &priorityClause =
|
||||
std::get_if<Fortran::parser::OmpClause::Priority>(
|
||||
&clause.u)) {
|
||||
priorityClauseOperand = fir::getBase(converter.genExprValue(
|
||||
*Fortran::semantics::GetExpr(priorityClause->v), stmtCtx));
|
||||
} else if (const auto &reductionClause =
|
||||
std::get_if<Fortran::parser::OmpClause::Reduction>(
|
||||
&clause.u)) {
|
||||
addReductionDecl(currentLocation, converter, reductionClause->v,
|
||||
reductionVars, reductionDeclSymbols);
|
||||
} else if (const auto &dependClause =
|
||||
std::get_if<Fortran::parser::OmpClause::Depend>(&clause.u)) {
|
||||
const std::list<Fortran::parser::Designator> &depVal =
|
||||
std::get<std::list<Fortran::parser::Designator>>(
|
||||
std::get<Fortran::parser::OmpDependClause::InOut>(
|
||||
dependClause->v.u)
|
||||
.t);
|
||||
omp::ClauseTaskDependAttr dependTypeOperand =
|
||||
genDependKindAttr(firOpBuilder, dependClause);
|
||||
dependTypeOperands.insert(dependTypeOperands.end(), depVal.size(),
|
||||
dependTypeOperand);
|
||||
for (const Fortran::parser::Designator &ompObject : depVal) {
|
||||
Fortran::semantics::Symbol *sym = nullptr;
|
||||
std::visit(
|
||||
Fortran::common::visitors{
|
||||
[&](const Fortran::parser::DataRef &designator) {
|
||||
if (const Fortran::parser::Name *name =
|
||||
std::get_if<Fortran::parser::Name>(&designator.u)) {
|
||||
sym = name->symbol;
|
||||
} else if (std::get_if<Fortran::common::Indirection<
|
||||
Fortran::parser::ArrayElement>>(
|
||||
&designator.u)) {
|
||||
TODO(converter.getCurrentLocation(),
|
||||
"array sections not supported for task depend");
|
||||
}
|
||||
},
|
||||
[&](const Fortran::parser::Substring &designator) {
|
||||
TODO(converter.getCurrentLocation(),
|
||||
"substring not supported for task depend");
|
||||
}},
|
||||
(ompObject).u);
|
||||
const mlir::Value variable = converter.getSymbolAddress(*sym);
|
||||
dependOperands.push_back(((variable)));
|
||||
}
|
||||
} else {
|
||||
TODO(converter.getCurrentLocation(), "OpenMP Block construct clause");
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto &clause :
|
||||
std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t).v) {
|
||||
if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u))
|
||||
nowaitAttr = firOpBuilder.getUnitAttr();
|
||||
}
|
||||
|
||||
if (blockDirective.v == llvm::omp::OMPD_parallel) {
|
||||
// Create and insert the operation.
|
||||
auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
|
||||
currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
|
||||
allocateOperands, allocatorOperands, reductionVars,
|
||||
reductionDeclSymbols.empty()
|
||||
? nullptr
|
||||
: mlir::ArrayAttr::get(firOpBuilder.getContext(),
|
||||
reductionDeclSymbols),
|
||||
procBindKindAttr);
|
||||
createBodyOfOp<omp::ParallelOp>(parallelOp, converter, currentLocation,
|
||||
eval, &opClauseList);
|
||||
} else if (blockDirective.v == llvm::omp::OMPD_master) {
|
||||
auto masterOp =
|
||||
firOpBuilder.create<mlir::omp::MasterOp>(currentLocation, argTy);
|
||||
createBodyOfOp<omp::MasterOp>(masterOp, converter, currentLocation, eval);
|
||||
} else if (blockDirective.v == llvm::omp::OMPD_single) {
|
||||
auto singleOp = firOpBuilder.create<mlir::omp::SingleOp>(
|
||||
currentLocation, allocateOperands, allocatorOperands, nowaitAttr);
|
||||
createBodyOfOp<omp::SingleOp>(singleOp, converter, currentLocation, eval,
|
||||
&opClauseList);
|
||||
} else if (blockDirective.v == llvm::omp::OMPD_ordered) {
|
||||
auto orderedOp = firOpBuilder.create<mlir::omp::OrderedRegionOp>(
|
||||
currentLocation, /*simd=*/false);
|
||||
createBodyOfOp<omp::OrderedRegionOp>(orderedOp, converter, currentLocation,
|
||||
eval);
|
||||
} else if (blockDirective.v == llvm::omp::OMPD_task) {
|
||||
auto taskOp = firOpBuilder.create<mlir::omp::TaskOp>(
|
||||
currentLocation, ifClauseOperand, finalClauseOperand, untiedAttr,
|
||||
mergeableAttr, /*in_reduction_vars=*/ValueRange(),
|
||||
/*in_reductions=*/nullptr, priorityClauseOperand,
|
||||
dependTypeOperands.empty()
|
||||
? nullptr
|
||||
: mlir::ArrayAttr::get(firOpBuilder.getContext(),
|
||||
dependTypeOperands),
|
||||
dependOperands, allocateOperands, allocatorOperands);
|
||||
createBodyOfOp(taskOp, converter, currentLocation, eval, &opClauseList);
|
||||
} else if (blockDirective.v == llvm::omp::OMPD_taskgroup) {
|
||||
// TODO: Add task_reduction support
|
||||
auto taskGroupOp = firOpBuilder.create<mlir::omp::TaskGroupOp>(
|
||||
currentLocation, /*task_reduction_vars=*/ValueRange(),
|
||||
/*task_reductions=*/nullptr, allocateOperands, allocatorOperands);
|
||||
createBodyOfOp(taskGroupOp, converter, currentLocation, eval,
|
||||
&opClauseList);
|
||||
} else if (blockDirective.v == llvm::omp::OMPD_target) {
|
||||
createTargetOp(converter, opClauseList, blockDirective.v, currentLocation,
|
||||
&eval);
|
||||
} else if (blockDirective.v == llvm::omp::OMPD_target_data) {
|
||||
createTargetOp(converter, opClauseList, blockDirective.v, currentLocation,
|
||||
&eval);
|
||||
} else {
|
||||
TODO(currentLocation, "Unhandled block directive");
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
genOMP(Fortran::lower::AbstractConverter &converter,
|
||||
Fortran::lower::pft::Evaluation &eval,
|
||||
|
@ -1,11 +0,0 @@
|
||||
! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
|
||||
! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
|
||||
|
||||
! CHECK: not yet implemented: Reduction in OpenMP parallel construct
|
||||
subroutine reduction_parallel
! Input for the "not yet implemented" test: a `+` reduction on an integer
! inside a parallel region, which the expected diagnostic above refers to.
|
||||
integer :: x
|
||||
!$omp parallel reduction(+:x)
|
||||
! NOTE(review): `i` has no declaration in this subroutine, so it relies on
! implicit typing and is never assigned - confirm this is intentional.
x = x + i
|
||||
!$omp end parallel
|
||||
print *, x
|
||||
end subroutine
|
97
flang/test/Lower/OpenMP/parallel-reduction-add.f90
Normal file
97
flang/test/Lower/OpenMP/parallel-reduction-add.f90
Normal file
@ -0,0 +1,97 @@
|
||||
! RUN: bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
|
||||
! RUN: %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
|
||||
|
||||
!CHECK-LABEL: omp.reduction.declare
|
||||
!CHECK-SAME: @[[RED_F32_NAME:.*]] : f32 init {
|
||||
!CHECK: ^bb0(%{{.*}}: f32):
|
||||
!CHECK: %[[C0_1:.*]] = arith.constant 0.000000e+00 : f32
|
||||
!CHECK: omp.yield(%[[C0_1]] : f32)
|
||||
!CHECK: } combiner {
|
||||
!CHECK: ^bb0(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32):
|
||||
!CHECK: %[[RES:.*]] = arith.addf %[[ARG0]], %[[ARG1]] {{.*}}: f32
|
||||
!CHECK: omp.yield(%[[RES]] : f32)
|
||||
!CHECK: }
|
||||
|
||||
!CHECK-LABEL: omp.reduction.declare
|
||||
!CHECK-SAME: @[[RED_I32_NAME:.*]] : i32 init {
|
||||
!CHECK: ^bb0(%{{.*}}: i32):
|
||||
!CHECK: %[[C0_1:.*]] = arith.constant 0 : i32
|
||||
!CHECK: omp.yield(%[[C0_1]] : i32)
|
||||
!CHECK: } combiner {
|
||||
!CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32):
|
||||
!CHECK: %[[RES:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
|
||||
!CHECK: omp.yield(%[[RES]] : i32)
|
||||
!CHECK: }
|
||||
|
||||
!CHECK-LABEL: func.func @_QPsimple_int_add
|
||||
!CHECK: %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimple_int_addEi"}
|
||||
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
|
||||
!CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
|
||||
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref<i32>) {
|
||||
!CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32
|
||||
!CHECK: omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref<i32>
|
||||
!CHECK: omp.terminator
|
||||
!CHECK: }
|
||||
!CHECK: return
|
||||
subroutine simple_int_add
! Lowering input: a `+` reduction on one integer variable in a parallel
! region. The FileCheck directives preceding this subroutine pin the IR.
|
||||
integer :: i
|
||||
! Start value stored before the parallel region is entered.
i = 0
|
||||
|
||||
!$omp parallel reduction(+:i)
|
||||
! Update of the reduction variable inside the region.
i = i + 1
|
||||
!$omp end parallel
|
||||
|
||||
print *, i
|
||||
end subroutine
|
||||
|
||||
!CHECK-LABEL: func.func @_QPsimple_real_add
|
||||
!CHECK: %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFsimple_real_addEr"}
|
||||
!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
|
||||
!CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
|
||||
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref<f32>) {
|
||||
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
|
||||
!CHECK: omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref<f32>
|
||||
!CHECK: omp.terminator
|
||||
!CHECK: }
|
||||
!CHECK: return
|
||||
subroutine simple_real_add
! Lowering input: a `+` reduction on one real variable in a parallel
! region. The FileCheck directives preceding this subroutine pin the IR.
|
||||
real :: r
|
||||
! Start value stored before the parallel region is entered.
r = 0.0
|
||||
|
||||
!$omp parallel reduction(+:r)
|
||||
! Update of the reduction variable inside the region.
r = r + 1.5
|
||||
!$omp end parallel
|
||||
|
||||
print *, r
|
||||
end subroutine
|
||||
|
||||
!CHECK-LABEL: func.func @_QPint_real_add
|
||||
!CHECK: %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFint_real_addEi"}
|
||||
!CHECK: %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFint_real_addEr"}
|
||||
!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
|
||||
!CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
|
||||
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
|
||||
!CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
|
||||
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref<i32>, @[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref<f32>) {
|
||||
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
|
||||
!CHECK: omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref<f32>
|
||||
!CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32
|
||||
!CHECK: omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref<i32>
|
||||
!CHECK: omp.terminator
|
||||
!CHECK: }
|
||||
!CHECK: return
|
||||
subroutine int_real_add
! Lowering input: one `+` reduction clause listing an integer and a real
! variable. The FileCheck directives preceding this subroutine pin the IR.
|
||||
real :: r
|
||||
integer :: i
|
||||
|
||||
! Start values stored before the parallel region is entered.
r = 0.0
|
||||
i = 0
|
||||
|
||||
! The clause lists i before r, while the body updates r before i.
!$omp parallel reduction(+:i,r)
|
||||
r = 1.5 + r
|
||||
i = i + 3
|
||||
!$omp end parallel
|
||||
|
||||
print *, r
|
||||
print *, i
|
||||
end subroutine
|
Loading…
x
Reference in New Issue
Block a user