From e456689fb3d6dd785202cd25f89e9443e5ad7d1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valentin=20Clement=20=28=E3=83=90=E3=83=AC=E3=83=B3?= =?UTF-8?q?=E3=82=BF=E3=82=A4=E3=83=B3=20=E3=82=AF=E3=83=AC=E3=83=A1?= =?UTF-8?q?=E3=83=B3=29?= Date: Thu, 4 Jan 2024 16:33:33 -0800 Subject: [PATCH] [mlir][flang][openacc] Support device_type on loop construct (#76892) This is adding support for `device_type` clause representation in the OpenACC MLIR dialect on the acc.loop operation and adjust flang to lower correctly to the new representation. Each "value" that can be impacted by a `device_type` clause is now associated with an array attribute that carry this information. This includes: - `worker` clause information - `gang` clause information - `vector` clause information - `collapse` clause information - `tile` clause information The representation of the `gang` clause information has been updated and all values are now carried in a single operand segment. This segment is then subdivided by `device_type`. Each value in a segment is also associated with a `GangArgType` so it can be differentiated (num/dim/static). This simplify the handling of gang values an limit the number of new attributes needed. When the clause can be associated with the operation without any value (`gang`, `vector`, `worker`). These are represented by a dedicated attributes with device_type information. Extra getter functions are provided to make it easier to retrieve a value based on a device_type. --- flang/lib/Lower/OpenACC.cpp | 183 +++++--- flang/test/Lower/OpenACC/acc-kernels-loop.f90 | 36 +- flang/test/Lower/OpenACC/acc-loop.f90 | 41 +- .../test/Lower/OpenACC/acc-parallel-loop.f90 | 36 +- flang/test/Lower/OpenACC/acc-reduction.f90 | 6 +- flang/test/Lower/OpenACC/acc-serial-loop.f90 | 36 +- .../mlir/Dialect/OpenACC/OpenACCOps.td | 141 +++++- mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 433 ++++++++++++++---- mlir/test/Dialect/OpenACC/invalid.mlir | 57 ++- mlir/test/Dialect/OpenACC/ops.mlir | 96 ++-- 10 files changed, 748 insertions(+), 317 deletions(-) diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index d10e56e5d117..d24c369d81be 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -1593,67 +1593,89 @@ createLoopOp(Fortran::lower::AbstractConverter &converter, const Fortran::parser::AccClauseList &accClauseList, bool needEarlyReturnHandling = false) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); - - mlir::Value workerNum; - mlir::Value vectorNum; - mlir::Value gangNum; - mlir::Value gangDim; - mlir::Value gangStatic; llvm::SmallVector tileOperands, privateOperands, - reductionOperands, cacheOperands; + reductionOperands, cacheOperands, vectorOperands, workerNumOperands, + gangOperands; llvm::SmallVector privatizations, reductionRecipes; - bool hasGang = false, hasVector = false, hasWorker = false; + llvm::SmallVector tileOperandsSegments, gangOperandsSegments; + llvm::SmallVector collapseValues; + + llvm::SmallVector gangArgTypes; + llvm::SmallVector seqDeviceTypes, independentDeviceTypes, + autoDeviceTypes, vectorOperandsDeviceTypes, workerNumOperandsDeviceTypes, + vectorDeviceTypes, workerNumDeviceTypes, tileOperandsDeviceTypes, + collapseDeviceTypes, gangDeviceTypes, gangOperandsDeviceTypes; + + // device_type attribute is set to `none` until a device_type clause is + // encountered. + auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get( + builder.getContext(), mlir::acc::DeviceType::None); for (const Fortran::parser::AccClause &clause : accClauseList.v) { mlir::Location clauseLocation = converter.genLocation(clause.source); if (const auto *gangClause = std::get_if(&clause.u)) { if (gangClause->v) { + auto crtGangOperands = gangOperands.size(); const Fortran::parser::AccGangArgList &x = *gangClause->v; for (const Fortran::parser::AccGangArg &gangArg : x.v) { if (const auto *num = std::get_if(&gangArg.u)) { - gangNum = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(num->v), stmtCtx)); + gangOperands.push_back(fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(num->v), stmtCtx))); + gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get( + builder.getContext(), mlir::acc::GangArgType::Num)); } else if (const auto *staticArg = std::get_if( &gangArg.u)) { const Fortran::parser::AccSizeExpr &sizeExpr = staticArg->v; if (sizeExpr.v) { - gangStatic = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*sizeExpr.v), stmtCtx)); + gangOperands.push_back(fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(*sizeExpr.v), stmtCtx))); } else { // * was passed as value and will be represented as a special // constant. - gangStatic = builder.createIntegerConstant( - clauseLocation, builder.getIndexType(), starCst); + gangOperands.push_back(builder.createIntegerConstant( + clauseLocation, builder.getIndexType(), starCst)); } + gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get( + builder.getContext(), mlir::acc::GangArgType::Static)); } else if (const auto *dim = std::get_if( &gangArg.u)) { - gangDim = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(dim->v), stmtCtx)); + gangOperands.push_back(fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(dim->v), stmtCtx))); + gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get( + builder.getContext(), mlir::acc::GangArgType::Dim)); } } + gangOperandsSegments.push_back(gangOperands.size() - crtGangOperands); + gangOperandsDeviceTypes.push_back(crtDeviceTypeAttr); + } else { + gangDeviceTypes.push_back(crtDeviceTypeAttr); } - hasGang = true; } else if (const auto *workerClause = std::get_if(&clause.u)) { if (workerClause->v) { - workerNum = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*workerClause->v), stmtCtx)); + workerNumOperands.push_back(fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(*workerClause->v), stmtCtx))); + workerNumOperandsDeviceTypes.push_back(crtDeviceTypeAttr); + } else { + workerNumDeviceTypes.push_back(crtDeviceTypeAttr); } - hasWorker = true; } else if (const auto *vectorClause = std::get_if(&clause.u)) { if (vectorClause->v) { - vectorNum = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx)); + vectorOperands.push_back(fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx))); + vectorOperandsDeviceTypes.push_back(crtDeviceTypeAttr); + } else { + vectorDeviceTypes.push_back(crtDeviceTypeAttr); } - hasVector = true; } else if (const auto *tileClause = std::get_if(&clause.u)) { const Fortran::parser::AccTileExprList &accTileExprList = tileClause->v; + auto crtTileOperands = tileOperands.size(); for (const auto &accTileExpr : accTileExprList.v) { const auto &expr = std::get>( @@ -1669,6 +1691,8 @@ createLoopOp(Fortran::lower::AbstractConverter &converter, tileOperands.push_back(tileStar); } } + tileOperandsDeviceTypes.push_back(crtDeviceTypeAttr); + tileOperandsSegments.push_back(tileOperands.size() - crtTileOperands); } else if (const auto *privateClause = std::get_if( &clause.u)) { @@ -1680,17 +1704,46 @@ createLoopOp(Fortran::lower::AbstractConverter &converter, &clause.u)) { genReductions(reductionClause->v, converter, semanticsContext, stmtCtx, reductionOperands, reductionRecipes); + } else if (std::get_if(&clause.u)) { + seqDeviceTypes.push_back(crtDeviceTypeAttr); + } else if (std::get_if( + &clause.u)) { + independentDeviceTypes.push_back(crtDeviceTypeAttr); + } else if (std::get_if(&clause.u)) { + autoDeviceTypes.push_back(crtDeviceTypeAttr); + } else if (const auto *deviceTypeClause = + std::get_if( + &clause.u)) { + const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList = + deviceTypeClause->v; + assert(deviceTypeExprList.v.size() == 1 && + "expect only one device_type expr"); + crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get( + builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v)); + } else if (const auto *collapseClause = + std::get_if( + &clause.u)) { + const Fortran::parser::AccCollapseArg &arg = collapseClause->v; + const auto &force = std::get(arg.t); + if (force) + TODO(clauseLocation, "OpenACC collapse force modifier"); + const auto &intExpr = + std::get(arg.t); + const auto *expr = Fortran::semantics::GetExpr(intExpr); + const std::optional collapseValue = + Fortran::evaluate::ToInt64(*expr); + assert(collapseValue && "expect integer value for the collapse clause"); + collapseValues.push_back(*collapseValue); + collapseDeviceTypes.push_back(crtDeviceTypeAttr); } } // Prepare the operand segment size attribute and the operands value range. llvm::SmallVector operands; llvm::SmallVector operandSegments; - addOperand(operands, operandSegments, gangNum); - addOperand(operands, operandSegments, gangDim); - addOperand(operands, operandSegments, gangStatic); - addOperand(operands, operandSegments, workerNum); - addOperand(operands, operandSegments, vectorNum); + addOperands(operands, operandSegments, gangOperands); + addOperands(operands, operandSegments, workerNumOperands); + addOperands(operands, operandSegments, vectorOperands); addOperands(operands, operandSegments, tileOperands); addOperands(operands, operandSegments, cacheOperands); addOperands(operands, operandSegments, privateOperands); @@ -1708,12 +1761,42 @@ createLoopOp(Fortran::lower::AbstractConverter &converter, builder, currentLocation, eval, operands, operandSegments, /*outerCombined=*/false, retTy, yieldValue); - if (hasGang) - loopOp.setHasGangAttr(builder.getUnitAttr()); - if (hasWorker) - loopOp.setHasWorkerAttr(builder.getUnitAttr()); - if (hasVector) - loopOp.setHasVectorAttr(builder.getUnitAttr()); + if (!gangDeviceTypes.empty()) + loopOp.setGangAttr(builder.getArrayAttr(gangDeviceTypes)); + if (!gangArgTypes.empty()) + loopOp.setGangOperandsArgTypeAttr(builder.getArrayAttr(gangArgTypes)); + if (!gangOperandsSegments.empty()) + loopOp.setGangOperandsSegmentsAttr( + builder.getDenseI32ArrayAttr(gangOperandsSegments)); + if (!gangOperandsDeviceTypes.empty()) + loopOp.setGangOperandsDeviceTypeAttr( + builder.getArrayAttr(gangOperandsDeviceTypes)); + + if (!workerNumDeviceTypes.empty()) + loopOp.setWorkerAttr(builder.getArrayAttr(workerNumDeviceTypes)); + if (!workerNumOperandsDeviceTypes.empty()) + loopOp.setWorkerNumOperandsDeviceTypeAttr( + builder.getArrayAttr(workerNumOperandsDeviceTypes)); + + if (!vectorDeviceTypes.empty()) + loopOp.setVectorAttr(builder.getArrayAttr(vectorDeviceTypes)); + if (!vectorOperandsDeviceTypes.empty()) + loopOp.setVectorOperandsDeviceTypeAttr( + builder.getArrayAttr(vectorOperandsDeviceTypes)); + + if (!tileOperandsDeviceTypes.empty()) + loopOp.setTileOperandsDeviceTypeAttr( + builder.getArrayAttr(tileOperandsDeviceTypes)); + if (!tileOperandsSegments.empty()) + loopOp.setTileOperandsSegmentsAttr( + builder.getDenseI32ArrayAttr(tileOperandsSegments)); + + if (!seqDeviceTypes.empty()) + loopOp.setSeqAttr(builder.getArrayAttr(seqDeviceTypes)); + if (!independentDeviceTypes.empty()) + loopOp.setIndependentAttr(builder.getArrayAttr(independentDeviceTypes)); + if (!autoDeviceTypes.empty()) + loopOp.setAuto_Attr(builder.getArrayAttr(autoDeviceTypes)); if (!privatizations.empty()) loopOp.setPrivatizationsAttr( @@ -1723,33 +1806,11 @@ createLoopOp(Fortran::lower::AbstractConverter &converter, loopOp.setReductionRecipesAttr( mlir::ArrayAttr::get(builder.getContext(), reductionRecipes)); - // Lower clauses mapped to attributes - for (const Fortran::parser::AccClause &clause : accClauseList.v) { - mlir::Location clauseLocation = converter.genLocation(clause.source); - if (const auto *collapseClause = - std::get_if(&clause.u)) { - const Fortran::parser::AccCollapseArg &arg = collapseClause->v; - const auto &force = std::get(arg.t); - if (force) - TODO(clauseLocation, "OpenACC collapse force modifier"); - const auto &intExpr = - std::get(arg.t); - const auto *expr = Fortran::semantics::GetExpr(intExpr); - const std::optional collapseValue = - Fortran::evaluate::ToInt64(*expr); - if (collapseValue) { - loopOp.setCollapseAttr(builder.getI64IntegerAttr(*collapseValue)); - } - } else if (std::get_if(&clause.u)) { - loopOp.setSeqAttr(builder.getUnitAttr()); - } else if (std::get_if( - &clause.u)) { - loopOp.setIndependentAttr(builder.getUnitAttr()); - } else if (std::get_if(&clause.u)) { - loopOp->setAttr(mlir::acc::LoopOp::getAutoAttrStrName(), - builder.getUnitAttr()); - } - } + if (!collapseValues.empty()) + loopOp.setCollapseAttr(builder.getI64ArrayAttr(collapseValues)); + if (!collapseDeviceTypes.empty()) + loopOp.setCollapseDeviceTypeAttr(builder.getArrayAttr(collapseDeviceTypes)); + return loopOp; } diff --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90 index 93bc699031d5..b17f2e2c80b2 100644 --- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 +++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90 @@ -461,7 +461,7 @@ subroutine acc_kernels_loop ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield -! CHECK-NEXT: } attributes {seq} +! CHECK-NEXT: } attributes {seq = [#acc.device_type]} ! CHECK: acc.terminator ! CHECK-NEXT: }{{$}} @@ -474,7 +474,7 @@ subroutine acc_kernels_loop ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield -! CHECK-NEXT: } attributes {auto} +! CHECK-NEXT: } attributes {auto_ = [#acc.device_type]} ! CHECK: acc.terminator ! CHECK-NEXT: }{{$}} @@ -487,7 +487,7 @@ subroutine acc_kernels_loop ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield -! CHECK-NEXT: } attributes {independent} +! CHECK-NEXT: } attributes {independent = [#acc.device_type]} ! CHECK: acc.terminator ! CHECK-NEXT: }{{$}} @@ -497,10 +497,10 @@ subroutine acc_kernels_loop END DO ! CHECK: acc.kernels { -! CHECK: acc.loop gang { +! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield -! CHECK-NEXT: }{{$}} +! CHECK-NEXT: } attributes {gang = [#acc.device_type]}{{$}} ! CHECK: acc.terminator ! CHECK-NEXT: }{{$}} @@ -511,7 +511,7 @@ subroutine acc_kernels_loop ! CHECK: acc.kernels { ! CHECK: [[GANGNUM1:%.*]] = arith.constant 8 : i32 -! CHECK-NEXT: acc.loop gang(num=[[GANGNUM1]] : i32) { +! CHECK-NEXT: acc.loop gang({num=[[GANGNUM1]] : i32}) { ! CHECK: fir.do_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -525,7 +525,7 @@ subroutine acc_kernels_loop ! CHECK: acc.kernels { ! CHECK: [[GANGNUM2:%.*]] = fir.load %{{.*}} : !fir.ref -! CHECK-NEXT: acc.loop gang(num=[[GANGNUM2]] : i32) { +! CHECK-NEXT: acc.loop gang({num=[[GANGNUM2]] : i32}) { ! CHECK: fir.do_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -538,7 +538,7 @@ subroutine acc_kernels_loop END DO ! CHECK: acc.kernels { -! CHECK: acc.loop gang(num=%{{.*}} : i32, static=%{{.*}} : i32) { +! CHECK: acc.loop gang({num=%{{.*}} : i32, static=%{{.*}} : i32}) { ! CHECK: fir.do_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -550,10 +550,10 @@ subroutine acc_kernels_loop a(i) = b(i) END DO ! CHECK: acc.kernels { -! CHECK: acc.loop vector { +! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield -! CHECK-NEXT: }{{$}} +! CHECK-NEXT: } attributes {vector = [#acc.device_type]}{{$}} ! CHECK: acc.terminator ! CHECK-NEXT: }{{$}} @@ -591,10 +591,10 @@ subroutine acc_kernels_loop END DO ! CHECK: acc.kernels { -! CHECK: acc.loop worker { +! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield -! CHECK-NEXT: }{{$}} +! CHECK-NEXT: } attributes {worker = [#acc.device_type]}{{$}} ! CHECK: acc.terminator ! CHECK-NEXT: }{{$}} @@ -624,7 +624,7 @@ subroutine acc_kernels_loop ! CHECK: fir.do_loop ! CHECK: fir.do_loop ! CHECK: acc.yield -! CHECK-NEXT: } attributes {collapse = 2 : i64} +! CHECK-NEXT: } attributes {collapse = [2], collapseDeviceType = [#acc.device_type]} ! CHECK: acc.terminator ! CHECK-NEXT: }{{$}} @@ -655,7 +655,7 @@ subroutine acc_kernels_loop ! CHECK: acc.kernels { ! CHECK: [[TILESIZE:%.*]] = arith.constant 2 : i32 -! CHECK: acc.loop tile([[TILESIZE]] : i32) { +! CHECK: acc.loop tile({[[TILESIZE]] : i32}) { ! CHECK: fir.do_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -669,7 +669,7 @@ subroutine acc_kernels_loop ! CHECK: acc.kernels { ! CHECK: [[TILESIZEM1:%.*]] = arith.constant -1 : i32 -! CHECK: acc.loop tile([[TILESIZEM1]] : i32) { +! CHECK: acc.loop tile({[[TILESIZEM1]] : i32}) { ! CHECK: fir.do_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -686,7 +686,7 @@ subroutine acc_kernels_loop ! CHECK: acc.kernels { ! CHECK: [[TILESIZE1:%.*]] = arith.constant 2 : i32 ! CHECK: [[TILESIZE2:%.*]] = arith.constant 2 : i32 -! CHECK: acc.loop tile([[TILESIZE1]], [[TILESIZE2]] : i32, i32) { +! CHECK: acc.loop tile({[[TILESIZE1]] : i32, [[TILESIZE2]] : i32}) { ! CHECK: fir.do_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -699,7 +699,7 @@ subroutine acc_kernels_loop END DO ! CHECK: acc.kernels { -! CHECK: acc.loop tile(%{{.*}} : i32) { +! CHECK: acc.loop tile({%{{.*}} : i32}) { ! CHECK: fir.do_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -714,7 +714,7 @@ subroutine acc_kernels_loop END DO ! CHECK: acc.kernels { -! CHECK: acc.loop tile(%{{.*}}, %{{.*}} : i32, i32) { +! CHECK: acc.loop tile({%{{.*}} : i32, %{{.*}} : i32}) { ! CHECK: fir.do_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} diff --git a/flang/test/Lower/OpenACC/acc-loop.f90 b/flang/test/Lower/OpenACC/acc-loop.f90 index 924574512da4..e7f65770498f 100644 --- a/flang/test/Lower/OpenACC/acc-loop.f90 +++ b/flang/test/Lower/OpenACC/acc-loop.f90 @@ -1,6 +1,5 @@ ! This test checks lowering of OpenACC loop directive. -! RUN: bbc -fopenacc -emit-fir -hlfir=false %s -o - | FileCheck %s ! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s ! CHECK-LABEL: acc.private.recipe @privatization_ref_10x10xf32 : !fir.ref> init { @@ -41,7 +40,7 @@ program acc_loop !CHECK: acc.loop { !CHECK: fir.do_loop !CHECK: acc.yield -!CHECK-NEXT: } attributes {seq} +!CHECK-NEXT: } attributes {seq = [#acc.device_type]} !$acc loop auto DO i = 1, n @@ -51,7 +50,7 @@ program acc_loop !CHECK: acc.loop { !CHECK: fir.do_loop !CHECK: acc.yield -!CHECK-NEXT: } attributes {auto} +!CHECK-NEXT: } attributes {auto_ = [#acc.device_type]} !$acc loop independent DO i = 1, n @@ -61,17 +60,17 @@ program acc_loop !CHECK: acc.loop { !CHECK: fir.do_loop !CHECK: acc.yield -!CHECK-NEXT: } attributes {independent} +!CHECK-NEXT: } attributes {independent = [#acc.device_type]} !$acc loop gang DO i = 1, n a(i) = b(i) END DO -!CHECK: acc.loop gang { +!CHECK: acc.loop { !CHECK: fir.do_loop !CHECK: acc.yield -!CHECK-NEXT: }{{$}} +!CHECK-NEXT: } attributes {gang = [#acc.device_type]}{{$}} !$acc loop gang(num: 8) DO i = 1, n @@ -79,7 +78,7 @@ program acc_loop END DO !CHECK: [[GANGNUM1:%.*]] = arith.constant 8 : i32 -!CHECK-NEXT: acc.loop gang(num=[[GANGNUM1]] : i32) { +!CHECK-NEXT: acc.loop gang({num=[[GANGNUM1]] : i32}) { !CHECK: fir.do_loop !CHECK: acc.yield !CHECK-NEXT: }{{$}} @@ -90,7 +89,7 @@ program acc_loop END DO !CHECK: [[GANGNUM2:%.*]] = fir.load %{{.*}} : !fir.ref -!CHECK-NEXT: acc.loop gang(num=[[GANGNUM2]] : i32) { +!CHECK-NEXT: acc.loop gang({num=[[GANGNUM2]] : i32}) { !CHECK: fir.do_loop !CHECK: acc.yield !CHECK-NEXT: }{{$}} @@ -100,7 +99,7 @@ program acc_loop a(i) = b(i) END DO -!CHECK: acc.loop gang(num=%{{.*}} : i32, static=%{{.*}} : i32) { +!CHECK: acc.loop gang({num=%{{.*}} : i32, static=%{{.*}} : i32}) { !CHECK: fir.do_loop !CHECK: acc.yield !CHECK-NEXT: }{{$}} @@ -110,10 +109,10 @@ program acc_loop a(i) = b(i) END DO -!CHECK: acc.loop vector { +!CHECK: acc.loop { !CHECK: fir.do_loop !CHECK: acc.yield -!CHECK-NEXT: }{{$}} +!CHECK-NEXT: } attributes {vector = [#acc.device_type]}{{$}} !$acc loop vector(128) DO i = 1, n @@ -142,10 +141,10 @@ program acc_loop a(i) = b(i) END DO -!CHECK: acc.loop worker { +!CHECK: acc.loop { !CHECK: fir.do_loop !CHECK: acc.yield -!CHECK-NEXT: }{{$}} +!CHECK-NEXT: } attributes {worker = [#acc.device_type]}{{$}} !$acc loop worker(128) DO i = 1, n @@ -193,7 +192,7 @@ program acc_loop a(i) = b(i) END DO !CHECK: [[TILESIZE:%.*]] = arith.constant 2 : i32 -!CHECK: acc.loop tile([[TILESIZE]] : i32) { +!CHECK: acc.loop tile({[[TILESIZE]] : i32}) { !CHECK: fir.do_loop !CHECK: acc.yield !CHECK-NEXT: }{{$}} @@ -203,7 +202,7 @@ program acc_loop a(i) = b(i) END DO !CHECK: [[TILESIZEM1:%.*]] = arith.constant -1 : i32 -!CHECK: acc.loop tile([[TILESIZEM1]] : i32) { +!CHECK: acc.loop tile({[[TILESIZEM1]] : i32}) { !CHECK: fir.do_loop !CHECK: acc.yield !CHECK-NEXT: }{{$}} @@ -217,7 +216,7 @@ program acc_loop !CHECK: [[TILESIZE1:%.*]] = arith.constant 2 : i32 !CHECK: [[TILESIZE2:%.*]] = arith.constant 2 : i32 -!CHECK: acc.loop tile([[TILESIZE1]], [[TILESIZE2]] : i32, i32) { +!CHECK: acc.loop tile({[[TILESIZE1]] : i32, [[TILESIZE2]] : i32}) { !CHECK: fir.do_loop !CHECK: acc.yield !CHECK-NEXT: }{{$}} @@ -227,7 +226,7 @@ program acc_loop a(i) = b(i) END DO -!CHECK: acc.loop tile(%{{.*}} : i32) { +!CHECK: acc.loop tile({%{{.*}} : i32}) { !CHECK: fir.do_loop !CHECK: acc.yield !CHECK-NEXT: }{{$}} @@ -239,7 +238,7 @@ program acc_loop END DO END DO -!CHECK: acc.loop tile(%{{.*}}, %{{.*}} : i32, i32) { +!CHECK: acc.loop tile({%{{.*}} : i32, %{{.*}} : i32}) { !CHECK: fir.do_loop !CHECK: acc.yield !CHECK-NEXT: }{{$}} @@ -255,7 +254,7 @@ program acc_loop !CHECK: fir.do_loop !CHECK: fir.do_loop !CHECK: acc.yield -!CHECK-NEXT: } attributes {collapse = 2 : i64} +!CHECK-NEXT: } attributes {collapse = [2], collapseDeviceType = [#acc.device_type]} !$acc loop DO i = 1, n @@ -290,7 +289,7 @@ program acc_loop a(i) = b(i) END DO -!CHECK: acc.loop gang(dim=%{{.*}}, static=%{{.*}} : i32) { +!CHECK: acc.loop gang({dim=%{{.*}}, static=%{{.*}} : i32}) { !CHECK: fir.do_loop !CHECK: acc.yield !CHECK-NEXT: }{{$}} @@ -301,7 +300,7 @@ program acc_loop END DO !CHECK: [[GANGDIM1:%.*]] = arith.constant 1 : i32 -!CHECK-NEXT: acc.loop gang(dim=[[GANGDIM1]] : i32) { +!CHECK-NEXT: acc.loop gang({dim=[[GANGDIM1]] : i32}) { !CHECK: fir.do_loop !CHECK: acc.yield !CHECK-NEXT: }{{$}} diff --git a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 b/flang/test/Lower/OpenACC/acc-parallel-loop.f90 index deee7089033e..e9150a71f382 100644 --- a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 +++ b/flang/test/Lower/OpenACC/acc-parallel-loop.f90 @@ -476,7 +476,7 @@ subroutine acc_parallel_loop ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield -! CHECK-NEXT: } attributes {seq} +! CHECK-NEXT: } attributes {seq = [#acc.device_type]} ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -489,7 +489,7 @@ subroutine acc_parallel_loop ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield -! CHECK-NEXT: } attributes {auto} +! CHECK-NEXT: } attributes {auto_ = [#acc.device_type]} ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -502,7 +502,7 @@ subroutine acc_parallel_loop ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield -! CHECK-NEXT: } attributes {independent} +! CHECK-NEXT: } attributes {independent = [#acc.device_type]} ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -512,10 +512,10 @@ subroutine acc_parallel_loop END DO ! CHECK: acc.parallel { -! CHECK: acc.loop gang { +! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield -! CHECK-NEXT: }{{$}} +! CHECK-NEXT: } attributes {gang = [#acc.device_type]}{{$}} ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -526,7 +526,7 @@ subroutine acc_parallel_loop ! CHECK: acc.parallel { ! CHECK: [[GANGNUM1:%.*]] = arith.constant 8 : i32 -! CHECK-NEXT: acc.loop gang(num=[[GANGNUM1]] : i32) { +! CHECK-NEXT: acc.loop gang({num=[[GANGNUM1]] : i32}) { ! CHECK: fir.do_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -540,7 +540,7 @@ subroutine acc_parallel_loop ! CHECK: acc.parallel { ! CHECK: [[GANGNUM2:%.*]] = fir.load %{{.*}} : !fir.ref -! CHECK-NEXT: acc.loop gang(num=[[GANGNUM2]] : i32) { +! CHECK-NEXT: acc.loop gang({num=[[GANGNUM2]] : i32}) { ! CHECK: fir.do_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -553,7 +553,7 @@ subroutine acc_parallel_loop END DO ! CHECK: acc.parallel { -! CHECK: acc.loop gang(num=%{{.*}} : i32, static=%{{.*}} : i32) { +! CHECK: acc.loop gang({num=%{{.*}} : i32, static=%{{.*}} : i32}) { ! CHECK: fir.do_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -565,10 +565,10 @@ subroutine acc_parallel_loop a(i) = b(i) END DO ! CHECK: acc.parallel { -! CHECK: acc.loop vector { +! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield -! CHECK-NEXT: }{{$}} +! CHECK-NEXT: } attributes {vector = [#acc.device_type]}{{$}} ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -606,10 +606,10 @@ subroutine acc_parallel_loop END DO ! CHECK: acc.parallel { -! CHECK: acc.loop worker { +! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield -! CHECK-NEXT: }{{$}} +! CHECK-NEXT: } attributes {worker = [#acc.device_type]}{{$}} ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -639,7 +639,7 @@ subroutine acc_parallel_loop ! CHECK: fir.do_loop ! CHECK: fir.do_loop ! CHECK: acc.yield -! CHECK-NEXT: } attributes {collapse = 2 : i64} +! CHECK-NEXT: } attributes {collapse = [2], collapseDeviceType = [#acc.device_type]} ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -670,7 +670,7 @@ subroutine acc_parallel_loop ! CHECK: acc.parallel { ! CHECK: [[TILESIZE:%.*]] = arith.constant 2 : i32 -! CHECK: acc.loop tile([[TILESIZE]] : i32) { +! CHECK: acc.loop tile({[[TILESIZE]] : i32}) { ! CHECK: fir.do_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -684,7 +684,7 @@ subroutine acc_parallel_loop ! CHECK: acc.parallel { ! CHECK: [[TILESIZEM1:%.*]] = arith.constant -1 : i32 -! CHECK: acc.loop tile([[TILESIZEM1]] : i32) { +! CHECK: acc.loop tile({[[TILESIZEM1]] : i32}) { ! CHECK: fir.do_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -701,7 +701,7 @@ subroutine acc_parallel_loop ! CHECK: acc.parallel { ! CHECK: [[TILESIZE1:%.*]] = arith.constant 2 : i32 ! CHECK: [[TILESIZE2:%.*]] = arith.constant 2 : i32 -! CHECK: acc.loop tile([[TILESIZE1]], [[TILESIZE2]] : i32, i32) { +! CHECK: acc.loop tile({[[TILESIZE1]] : i32, [[TILESIZE2]] : i32}) { ! CHECK: fir.do_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -714,7 +714,7 @@ subroutine acc_parallel_loop END DO ! CHECK: acc.parallel { -! CHECK: acc.loop tile(%{{.*}} : i32) { +! CHECK: acc.loop tile({%{{.*}} : i32}) { ! CHECK: fir.do_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -729,7 +729,7 @@ subroutine acc_parallel_loop END DO ! CHECK: acc.parallel { -! CHECK: acc.loop tile(%{{.*}}, %{{.*}} : i32, i32) { +! CHECK: acc.loop tile({%{{.*}} : i32, %{{.*}} : i32}) { ! CHECK: fir.do_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} diff --git a/flang/test/Lower/OpenACC/acc-reduction.f90 b/flang/test/Lower/OpenACC/acc-reduction.f90 index a8f7e1fa81ef..dcfa77c9f97d 100644 --- a/flang/test/Lower/OpenACC/acc-reduction.f90 +++ b/flang/test/Lower/OpenACC/acc-reduction.f90 @@ -743,7 +743,7 @@ end subroutine ! FIR: %[[RED_ARG1:.*]] = acc.reduction varPtr(%[[ARG1]] : !fir.ref>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref> {name = "b"} ! HLFIR: %[[RED_ARG1:.*]] = acc.reduction varPtr(%[[DECLARG1]]#1 : !fir.ref>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref> {name = "b"} ! CHECK: acc.loop reduction(@reduction_add_section_ext100xext10_ref_100x10xi32 -> %[[RED_ARG1]] : !fir.ref>) { -! CHECK: } attributes {collapse = 2 : i64} +! CHECK: } attributes {collapse = [2], collapseDeviceType = [#acc.device_type]} subroutine acc_reduction_add_int_array_3d(a, b) integer :: a(100, 10, 2), b(100, 10, 2) @@ -765,7 +765,7 @@ end subroutine ! FIR: %[[RED_ARG1:.*]] = acc.reduction varPtr(%[[ARG1]] : !fir.ref>) bounds(%{{.*}}, %{{.*}}, %{{.*}}) -> !fir.ref> {name = "b"} ! HLFIR: %[[RED_ARG1:.*]] = acc.reduction varPtr(%[[DECLARG1]]#1 : !fir.ref>) bounds(%{{.*}}, %{{.*}}, %{{.*}}) -> !fir.ref> {name = "b"} ! CHECK: acc.loop reduction(@reduction_add_section_ext100xext10xext2_ref_100x10x2xi32 -> %[[RED_ARG1]] : !fir.ref>) -! CHECK: } attributes {collapse = 3 : i64} +! CHECK: } attributes {collapse = [3], collapseDeviceType = [#acc.device_type]} subroutine acc_reduction_add_float(a, b) real :: a(100), b @@ -938,7 +938,7 @@ end subroutine ! FIR: %[[RED_ARG1:.*]] = acc.reduction varPtr(%[[ARG1]] : !fir.ref>) bounds(%3, %5) -> !fir.ref> {name = "b"} ! HLFIR: %[[RED_ARG1:.*]] = acc.reduction varPtr(%[[DECLARG1]]#1 : !fir.ref>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref> {name = "b"} ! CHECK: acc.loop reduction(@reduction_min_section_ext100xext10_ref_100x10xf32 -> %[[RED_ARG1]] : !fir.ref>) -! CHECK: attributes {collapse = 2 : i64} +! CHECK: attributes {collapse = [2], collapseDeviceType = [#acc.device_type]} subroutine acc_reduction_max_int(a, b) integer :: a(100) diff --git a/flang/test/Lower/OpenACC/acc-serial-loop.f90 b/flang/test/Lower/OpenACC/acc-serial-loop.f90 index 712bfc80ce38..6041e7fb1b49 100644 --- a/flang/test/Lower/OpenACC/acc-serial-loop.f90 +++ b/flang/test/Lower/OpenACC/acc-serial-loop.f90 @@ -411,7 +411,7 @@ subroutine acc_serial_loop ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield -! CHECK-NEXT: } attributes {seq} +! CHECK-NEXT: } attributes {seq = [#acc.device_type]} ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -424,7 +424,7 @@ subroutine acc_serial_loop ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield -! CHECK-NEXT: } attributes {auto} +! CHECK-NEXT: } attributes {auto_ = [#acc.device_type]} ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -437,7 +437,7 @@ subroutine acc_serial_loop ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield -! CHECK-NEXT: } attributes {independent} +! CHECK-NEXT: } attributes {independent = [#acc.device_type]} ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -447,10 +447,10 @@ subroutine acc_serial_loop END DO ! CHECK: acc.serial { -! CHECK: acc.loop gang { +! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield -! CHECK-NEXT: }{{$}} +! CHECK-NEXT: } attributes {gang = [#acc.device_type]}{{$}} ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -461,7 +461,7 @@ subroutine acc_serial_loop ! CHECK: acc.serial { ! CHECK: [[GANGNUM1:%.*]] = arith.constant 8 : i32 -! CHECK-NEXT: acc.loop gang(num=[[GANGNUM1]] : i32) { +! CHECK-NEXT: acc.loop gang({num=[[GANGNUM1]] : i32}) { ! CHECK: fir.do_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -475,7 +475,7 @@ subroutine acc_serial_loop ! CHECK: acc.serial { ! CHECK: [[GANGNUM2:%.*]] = fir.load %{{.*}} : !fir.ref -! CHECK-NEXT: acc.loop gang(num=[[GANGNUM2]] : i32) { +! CHECK-NEXT: acc.loop gang({num=[[GANGNUM2]] : i32}) { ! CHECK: fir.do_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -488,7 +488,7 @@ subroutine acc_serial_loop END DO ! CHECK: acc.serial { -! CHECK: acc.loop gang(num=%{{.*}} : i32, static=%{{.*}} : i32) { +! CHECK: acc.loop gang({num=%{{.*}} : i32, static=%{{.*}} : i32}) { ! CHECK: fir.do_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -500,10 +500,10 @@ subroutine acc_serial_loop a(i) = b(i) END DO ! CHECK: acc.serial { -! CHECK: acc.loop vector { +! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield -! CHECK-NEXT: }{{$}} +! CHECK-NEXT: } attributes {vector = [#acc.device_type]}{{$}} ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -541,10 +541,10 @@ subroutine acc_serial_loop END DO ! CHECK: acc.serial { -! CHECK: acc.loop worker { +! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield -! CHECK-NEXT: }{{$}} +! CHECK-NEXT: } attributes {worker = [#acc.device_type]}{{$}} ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -574,7 +574,7 @@ subroutine acc_serial_loop ! CHECK: fir.do_loop ! CHECK: fir.do_loop ! CHECK: acc.yield -! CHECK-NEXT: } attributes {collapse = 2 : i64} +! CHECK-NEXT: } attributes {collapse = [2], collapseDeviceType = [#acc.device_type]} ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -605,7 +605,7 @@ subroutine acc_serial_loop ! CHECK: acc.serial { ! CHECK: [[TILESIZE:%.*]] = arith.constant 2 : i32 -! CHECK: acc.loop tile([[TILESIZE]] : i32) { +! CHECK: acc.loop tile({[[TILESIZE]] : i32}) { ! CHECK: fir.do_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -619,7 +619,7 @@ subroutine acc_serial_loop ! CHECK: acc.serial { ! CHECK: [[TILESIZEM1:%.*]] = arith.constant -1 : i32 -! CHECK: acc.loop tile([[TILESIZEM1]] : i32) { +! CHECK: acc.loop tile({[[TILESIZEM1]] : i32}) { ! CHECK: fir.do_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -636,7 +636,7 @@ subroutine acc_serial_loop ! CHECK: acc.serial { ! CHECK: [[TILESIZE1:%.*]] = arith.constant 2 : i32 ! CHECK: [[TILESIZE2:%.*]] = arith.constant 2 : i32 -! CHECK: acc.loop tile([[TILESIZE1]], [[TILESIZE2]] : i32, i32) { +! CHECK: acc.loop tile({[[TILESIZE1]] : i32, [[TILESIZE2]] : i32}) { ! CHECK: fir.do_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -649,7 +649,7 @@ subroutine acc_serial_loop END DO ! CHECK: acc.serial { -! CHECK: acc.loop tile(%{{.*}} : i32) { +! CHECK: acc.loop tile({%{{.*}} : i32}) { ! CHECK: fir.do_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -664,7 +664,7 @@ subroutine acc_serial_loop END DO ! CHECK: acc.serial { -! CHECK: acc.loop tile(%{{.*}}, %{{.*}} : i32, i32) { +! CHECK: acc.loop tile({%{{.*}} : i32, %{{.*}} : i32}) { ! CHECK: fir.do_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 1dd83e933034..e6954062a50e 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -196,6 +196,27 @@ def DeviceTypeArrayAttr : let constBuilderCall = ?; } +// Gang arg type enumeration +def OpenACC_GangArgNum : I32EnumAttrCase<"Num", 0, "Num">; +def OpenACC_GangArgDim : I32EnumAttrCase<"Dim", 1, "Dim">; +def OpenACC_GangArgStatic : I32EnumAttrCase<"Static", 2, "Static">; + +def OpenACC_GangArgType : I32EnumAttr<"GangArgType", + "Differentiate the different gang arg values", + [OpenACC_GangArgNum, OpenACC_GangArgDim, OpenACC_GangArgStatic]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::acc"; +} +def OpenACC_GangArgTypeAttr : EnumAttr { + let assemblyFormat = [{ ```<` $value `>` }]; +} +def GangArgTypeArrayAttr : + TypedArrayAttrBase { + let constBuilderCall = ?; +} + // Define a resource for the OpenACC runtime counters. def OpenACC_RuntimeCounters : Resource<"::mlir::acc::RuntimeCounters">; @@ -1462,7 +1483,7 @@ def OpenACC_LoopOp : OpenACC_Op<"loop", Example: ```mlir - acc.loop gang vector { + acc.loop { scf.for %arg3 = %c0 to %c10 step %c1 { scf.for %arg4 = %c0 to %c10 step %c1 { scf.for %arg5 = %c0 to %c10 step %c1 { @@ -1471,23 +1492,33 @@ def OpenACC_LoopOp : OpenACC_Op<"loop", } } acc.yield - } attributes { collapse = 3 } + } attributes { + collapse = [3], gang = [#acc.device_type], + vector = [#acc.device_type] + } ``` }]; - let arguments = (ins OptionalAttr:$collapse, - Optional:$gangNum, - Optional:$gangDim, - Optional:$gangStatic, - Optional:$workerNum, - Optional:$vectorLength, - UnitAttr:$seq, - UnitAttr:$independent, - UnitAttr:$auto_, - UnitAttr:$hasGang, - UnitAttr:$hasWorker, - UnitAttr:$hasVector, + let arguments = (ins + OptionalAttr:$collapse, + OptionalAttr:$collapseDeviceType, + Variadic:$gangOperands, + OptionalAttr:$gangOperandsArgType, + OptionalAttr:$gangOperandsSegments, + OptionalAttr:$gangOperandsDeviceType, + Variadic:$workerNumOperands, + OptionalAttr:$workerNumOperandsDeviceType, + Variadic:$vectorOperands, + OptionalAttr:$vectorOperandsDeviceType, + OptionalAttr:$seq, + OptionalAttr:$independent, + OptionalAttr:$auto_, + OptionalAttr:$gang, + OptionalAttr:$worker, + OptionalAttr:$vector, Variadic:$tileOperands, + OptionalAttr:$tileOperandsSegments, + OptionalAttr:$tileOperandsDeviceType, Variadic:$cacheOperands, Variadic:$privateOperands, OptionalAttr:$privatizations, @@ -1510,18 +1541,90 @@ def OpenACC_LoopOp : OpenACC_Op<"loop", /// The i-th data operand passed. Value getDataOperand(unsigned i); + + /// Return true if the op has the auto attribute for the + /// mlir::acc::DeviceType::None device_type. + bool hasAuto(); + /// Return true if the op has the auto attribute for the given device_type. + bool hasAuto(mlir::acc::DeviceType deviceType); + /// Return true if the op has the independent attribute for the + /// mlir::acc::DeviceType::None device_type. + bool hasIndependent(); + /// Return true if the op has the independent attribute for the given + /// device_type. + bool hasIndependent(mlir::acc::DeviceType deviceType); + /// Return true if the op has the seq attribute for the + /// mlir::acc::DeviceType::None device_type. + bool hasSeq(); + /// Return true if the op has the seq attribute for the given device_type. + bool hasSeq(mlir::acc::DeviceType deviceType); + + /// Return the value of the vector clause if present. + mlir::Value getVectorValue(); + /// Return the value of the vector clause for the given device_type + /// if present. + mlir::Value getVectorValue(mlir::acc::DeviceType deviceType); + /// Return true if the op has the vector attribute for the + /// mlir::acc::DeviceType::None device_type. + bool hasVector(); + /// Return true if the op has the vector attribute for the given + /// device_type. + bool hasVector(mlir::acc::DeviceType deviceType); + + /// Return the value of the worker clause if present. + mlir::Value getWorkerValue(); + /// Return the value of the worker clause for the given device_type + /// if present. + mlir::Value getWorkerValue(mlir::acc::DeviceType deviceType); + /// Return true if the op has the worker attribute for the + /// mlir::acc::DeviceType::None device_type. + bool hasWorker(); + /// Return true if the op has the worker attribute for the given + /// device_type. + bool hasWorker(mlir::acc::DeviceType deviceType); + + /// Return the values of the tile clause if present. + mlir::Operation::operand_range getTileValues(); + /// Return the values of the tile clause for the given device_type if + /// present. + mlir::Operation::operand_range + getTileValues(mlir::acc::DeviceType deviceType); + + /// Return the value of the collapse clause if present. + std::optional getCollapseValue(); + /// Return the value of the collapse clause for the given device_type + /// if present. + std::optional getCollapseValue(mlir::acc::DeviceType deviceType); + + /// Return true if the op has the gang attribute for the + /// mlir::acc::DeviceType::None device_type. + bool hasGang(); + /// Return true if the op has the gang attribute for the given + /// device_type. + bool hasGang(mlir::acc::DeviceType deviceType); + + /// Return the value of the worker clause if present. + mlir::Value getGangValue(mlir::acc::GangArgType gangArgType); + /// Return the value of the worker clause for the given device_type + /// if present. + mlir::Value getGangValue(mlir::acc::GangArgType gangArgType, mlir::acc::DeviceType deviceType); }]; let hasCustomAssemblyFormat = 1; let assemblyFormat = [{ oilist( - `gang` `` custom($gangNum, type($gangNum), $gangDim, type($gangDim), $gangStatic, type($gangStatic), $hasGang) - | `worker` `` custom($workerNum, type($workerNum), $hasWorker) - | `vector` `` custom($vectorLength, type($vectorLength), $hasVector) + `gang` `` `(` custom($gangOperands, type($gangOperands), + $gangOperandsArgType, $gangOperandsDeviceType, + $gangOperandsSegments) `)` + | `worker` `` `(` custom($workerNumOperands, + type($workerNumOperands), $workerNumOperandsDeviceType) `)` + | `vector` `` `(` custom($vectorOperands, + type($vectorOperands), $vectorOperandsDeviceType) `)` | `private` `(` custom( - $privateOperands, type($privateOperands), $privatizations) + $privateOperands, type($privateOperands), $privatizations) `)` + | `tile` `(` custom($tileOperands, + type($tileOperands), $tileOperandsDeviceType, $tileOperandsSegments) `)` - | `tile` `(` $tileOperands `:` type($tileOperands) `)` | `reduction` `(` custom( $reductionOperands, type($reductionOperands), $reductionRecipes) `)` diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 66605ead0529..c53673fa4260 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -16,6 +16,7 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; @@ -886,6 +887,8 @@ static ParseResult parseDeviceTypeOperandsWithSegment( if (failed(parser.parseLBrace())) return failure(); + int32_t crtOperandsSize = operands.size(); + if (failed(parser.parseCommaSeparatedList( mlir::AsmParser::Delimiter::None, [&]() { if (parser.parseOperand(operands.emplace_back()) || @@ -895,7 +898,7 @@ static ParseResult parseDeviceTypeOperandsWithSegment( }))) return failure(); - seg.push_back(operands.size()); + seg.push_back(operands.size() - crtOperandsSize); if (failed(parser.parseRBrace())) return failure(); @@ -1207,16 +1210,19 @@ void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results, // LoopOp //===----------------------------------------------------------------------===// -static ParseResult -parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, - std::optional &value, - Type &valueType, bool &needComa, bool &newValue) { +static ParseResult parseGangValue( + OpAsmParser &parser, llvm::StringRef keyword, + llvm::SmallVectorImpl &operands, + llvm::SmallVectorImpl &types, + llvm::SmallVector &attributes, GangArgTypeAttr gangArgType, + bool &needComa, bool &newValue) { if (succeeded(parser.parseOptionalKeyword(keyword))) { if (parser.parseEqual()) return failure(); - value = OpAsmParser::UnresolvedOperand{}; - if (parser.parseOperand(*value) || parser.parseColonType(valueType)) + if (parser.parseOperand(operands.emplace_back()) || + parser.parseColonType(types.emplace_back())) return failure(); + attributes.push_back(gangArgType); needComa = true; newValue = true; } @@ -1224,19 +1230,27 @@ parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, } static ParseResult parseGangClause( - OpAsmParser &parser, std::optional &gangNum, - Type &gangNumType, std::optional &gangDim, - Type &gangDimType, - std::optional &gangStatic, - Type &gangStaticType, UnitAttr &hasGang) { - hasGang = UnitAttr::get(parser.getBuilder().getContext()); - gangNum = std::nullopt; - gangDim = std::nullopt; - gangStatic = std::nullopt; + OpAsmParser &parser, + llvm::SmallVectorImpl &gangOperands, + llvm::SmallVectorImpl &gangOperandsType, mlir::ArrayAttr &gangArgType, + mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments) { + llvm::SmallVector attributes; + llvm::SmallVector deviceTypeAttributes; + llvm::SmallVector seg; bool needComa = false; - // optional gang operands - if (succeeded(parser.parseOptionalLParen())) { + auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(), + mlir::acc::GangArgType::Num); + auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(), + mlir::acc::GangArgType::Dim); + auto argStatic = mlir::acc::GangArgTypeAttr::get( + parser.getContext(), mlir::acc::GangArgType::Static); + + do { + if (failed(parser.parseLBrace())) + return failure(); + + int32_t crtOperandsSize = gangOperands.size(); while (true) { bool newValue = false; bool needValue = false; @@ -1247,15 +1261,17 @@ static ParseResult parseGangClause( break; } - if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(), gangNum, - gangNumType, needComa, newValue))) + if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(), + gangOperands, gangOperandsType, attributes, + argNum, needComa, newValue))) return failure(); - if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(), gangDim, - gangDimType, needComa, newValue))) + if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(), + gangOperands, gangOperandsType, attributes, + argDim, needComa, newValue))) return failure(); if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(), - gangStatic, gangStaticType, needComa, - newValue))) + gangOperands, gangOperandsType, attributes, + argStatic, needComa, newValue))) return failure(); if (!newValue && needValue) { @@ -1268,86 +1284,168 @@ static ParseResult parseGangClause( break; } - if (!gangNum && !gangDim && !gangStatic) { - parser.emitError(parser.getCurrentLocation(), - "expect at least one of num, dim or static values"); + if (gangOperands.empty()) + return parser.emitError( + parser.getCurrentLocation(), + "expect at least one of num, dim or static values"); + + if (failed(parser.parseRBrace())) return failure(); + + if (succeeded(parser.parseOptionalLSquare())) { + if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) || + parser.parseRSquare()) + return failure(); + } else { + deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get( + parser.getContext(), mlir::acc::DeviceType::None)); } - if (failed(parser.parseRParen())) - return failure(); - } + seg.push_back(gangOperands.size() - crtOperandsSize); + + } while (succeeded(parser.parseOptionalComma())); + + llvm::SmallVector arrayAttr(attributes.begin(), + attributes.end()); + gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr); + + llvm::SmallVector deviceTypeAttr( + deviceTypeAttributes.begin(), deviceTypeAttributes.end()); + deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttr); + segments = DenseI32ArrayAttr::get(parser.getContext(), seg); return success(); } -void printGangClause(OpAsmPrinter &p, Operation *op, Value gangNum, - Type gangNumType, Value gangDim, Type gangDimType, - Value gangStatic, Type gangStaticType, UnitAttr hasGang) { - if (gangNum || gangStatic || gangDim) { - p << "("; - if (gangNum) { - p << LoopOp::getGangNumKeyword() << "=" << gangNum << " : " - << gangNumType; - if (gangStatic || gangDim) +void printGangClause(OpAsmPrinter &p, Operation *op, + mlir::OperandRange operands, mlir::TypeRange types, + std::optional gangArgTypes, + std::optional deviceTypes, + std::optional segments) { + unsigned opIdx = 0; + for (unsigned i = 0; i < deviceTypes->size(); ++i) { + if (i != 0) + p << ", "; + p << "{"; + for (int32_t j = 0; j < (*segments)[i]; ++j) { + if (j != 0) p << ", "; + auto gangArgTypeAttr = + mlir::dyn_cast((*gangArgTypes)[opIdx]); + if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num) + p << LoopOp::getGangNumKeyword(); + else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim) + p << LoopOp::getGangDimKeyword(); + else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Static) + p << LoopOp::getGangStaticKeyword(); + p << "=" << operands[opIdx] << " : " << operands[opIdx].getType(); + ++opIdx; } - if (gangDim) { - p << LoopOp::getGangDimKeyword() << "=" << gangDim << " : " - << gangDimType; - if (gangStatic) - p << ", "; - } - if (gangStatic) - p << LoopOp::getGangStaticKeyword() << "=" << gangStatic << " : " - << gangStaticType; - p << ")"; + + p << "}"; + auto deviceTypeAttr = + mlir::dyn_cast((*deviceTypes)[i]); + if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None) + p << " [" << (*deviceTypes)[i] << "]"; } } -static ParseResult -parseWorkerClause(OpAsmParser &parser, - std::optional &workerNum, - Type &workerNumType, UnitAttr &hasWorker) { - hasWorker = UnitAttr::get(parser.getBuilder().getContext()); - if (succeeded(parser.parseOptionalLParen())) { - workerNum = OpAsmParser::UnresolvedOperand{}; - if (parser.parseOperand(*workerNum) || - parser.parseColonType(workerNumType) || parser.parseRParen()) +bool hasDuplicateDeviceTypes( + std::optional segments, + llvm::SmallSet &deviceTypes) { + if (!segments) + return false; + for (auto attr : *segments) { + auto deviceTypeAttr = mlir::dyn_cast(attr); + if (deviceTypes.contains(deviceTypeAttr.getValue())) + return true; + deviceTypes.insert(deviceTypeAttr.getValue()); + } + return false; +} + +/// Check for duplicates in the DeviceType array attribute. +LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) { + llvm::SmallSet crtDeviceTypes; + if (!deviceTypes) + return success(); + for (auto attr : deviceTypes) { + auto deviceTypeAttr = + mlir::dyn_cast_or_null(attr); + if (!deviceTypeAttr) return failure(); + if (crtDeviceTypes.contains(deviceTypeAttr.getValue())) + return failure(); + crtDeviceTypes.insert(deviceTypeAttr.getValue()); } return success(); } -void printWorkerClause(OpAsmPrinter &p, Operation *op, Value workerNum, - Type workerNumType, UnitAttr hasWorker) { - if (workerNum) - p << "(" << workerNum << " : " << workerNumType << ")"; -} - -static ParseResult -parseVectorClause(OpAsmParser &parser, - std::optional &vectorLength, - Type &vectorLengthType, UnitAttr &hasVector) { - hasVector = UnitAttr::get(parser.getBuilder().getContext()); - if (succeeded(parser.parseOptionalLParen())) { - vectorLength = OpAsmParser::UnresolvedOperand{}; - if (parser.parseOperand(*vectorLength) || - parser.parseColonType(vectorLengthType) || parser.parseRParen()) - return failure(); - } - return success(); -} - -void printVectorClause(OpAsmPrinter &p, Operation *op, Value vectorLength, - Type vectorLengthType, UnitAttr hasVector) { - if (vectorLength) - p << "(" << vectorLength << " : " << vectorLengthType << ")"; -} - LogicalResult acc::LoopOp::verify() { + // Check collapse + if (getCollapseAttr() && !getCollapseDeviceTypeAttr()) + return emitOpError() << "collapse device_type attr must be define when" + << " collapse attr is present"; + + if (getCollapseAttr() && getCollapseDeviceTypeAttr() && + getCollapseAttr().getValue().size() != + getCollapseDeviceTypeAttr().getValue().size()) + return emitOpError() << "collapse attribute count must match collapse" + << " device_type count"; + if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr()))) + return emitOpError() + << "duplicate device_type found in collapseDeviceType attribute"; + + // Check gang + if (!getGangOperands().empty()) { + if (!getGangOperandsArgType()) + return emitOpError() << "gangOperandsArgType attribute must be defined" + << " when gang operands are present"; + + if (getGangOperands().size() != + getGangOperandsArgTypeAttr().getValue().size()) + return emitOpError() << "gangOperandsArgType attribute count must match" + << " gangOperands count"; + } + if (getGangAttr() && failed(checkDeviceTypes(getGangAttr()))) + return emitOpError() << "duplicate device_type found in gang attribute"; + + if (failed(verifyDeviceTypeAndSegmentCountMatch( + *this, getGangOperands(), getGangOperandsSegmentsAttr(), + getGangOperandsDeviceTypeAttr(), "gang"))) + return failure(); + + // Check worker + if (failed(checkDeviceTypes(getWorkerAttr()))) + return emitOpError() << "duplicate device_type found in worker attribute"; + if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr()))) + return emitOpError() << "duplicate device_type found in " + "workerNumOperandsDeviceType attribute"; + if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(), + getWorkerNumOperandsDeviceTypeAttr(), + "worker"))) + return failure(); + + // Check vector + if (failed(checkDeviceTypes(getVectorAttr()))) + return emitOpError() << "duplicate device_type found in vector attribute"; + if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr()))) + return emitOpError() << "duplicate device_type found in " + "vectorOperandsDeviceType attribute"; + if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(), + getVectorOperandsDeviceTypeAttr(), + "vector"))) + return failure(); + + if (failed(verifyDeviceTypeAndSegmentCountMatch( + *this, getTileOperands(), getTileOperandsSegmentsAttr(), + getTileOperandsDeviceTypeAttr(), "tile"))) + return failure(); + // auto, independent and seq attribute are mutually exclusive. - if ((getAuto_() && (getIndependent() || getSeq())) || - (getIndependent() && getSeq())) { + llvm::SmallSet deviceTypes; + if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) || + hasDuplicateDeviceTypes(getIndependent(), deviceTypes) || + hasDuplicateDeviceTypes(getSeq(), deviceTypes)) { return emitError() << "only one of \"" << acc::LoopOp::getAutoAttrStrName() << "\", " << getIndependentAttrName() << ", " << getSeqAttrName() @@ -1355,8 +1453,24 @@ LogicalResult acc::LoopOp::verify() { } // Gang, worker and vector are incompatible with seq. - if (getSeq() && (getHasGang() || getHasWorker() || getHasVector())) - return emitError("gang, worker or vector cannot appear with the seq attr"); + if (getSeqAttr()) { + for (auto attr : getSeqAttr()) { + auto deviceTypeAttr = mlir::dyn_cast(attr); + if (hasVector(deviceTypeAttr.getValue()) || + getVectorValue(deviceTypeAttr.getValue()) || + hasWorker(deviceTypeAttr.getValue()) || + getWorkerValue(deviceTypeAttr.getValue()) || + hasGang(deviceTypeAttr.getValue()) || + getGangValue(mlir::acc::GangArgType::Num, + deviceTypeAttr.getValue()) || + getGangValue(mlir::acc::GangArgType::Dim, + deviceTypeAttr.getValue()) || + getGangValue(mlir::acc::GangArgType::Static, + deviceTypeAttr.getValue())) + return emitError() + << "gang, worker or vector cannot appear with the seq attr"; + } + } if (failed(checkSymOperandList( *this, getPrivatizations(), getPrivateOperands(), "private", @@ -1380,16 +1494,149 @@ unsigned LoopOp::getNumDataOperands() { } Value LoopOp::getDataOperand(unsigned i) { - unsigned numOptional = getGangNum() ? 1 : 0; - numOptional += getGangDim() ? 1 : 0; - numOptional += getGangStatic() ? 1 : 0; - numOptional += getVectorLength() ? 1 : 0; - numOptional += getWorkerNum() ? 1 : 0; + unsigned numOptional = getGangOperands().size(); + numOptional += getVectorOperands().size(); + numOptional += getWorkerNumOperands().size(); numOptional += getTileOperands().size(); numOptional += getCacheOperands().size(); return getOperand(numOptional + i); } +bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); } + +bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) { + if (auto arrayAttr = getAuto_()) { + if (findSegment(*arrayAttr, deviceType)) + return true; + } + return false; +} + +bool LoopOp::hasIndependent() { + return hasIndependent(mlir::acc::DeviceType::None); +} + +bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) { + if (auto arrayAttr = getIndependent()) { + if (findSegment(*arrayAttr, deviceType)) + return true; + } + return false; +} + +bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); } + +bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) { + if (auto arrayAttr = getSeq()) { + if (findSegment(*arrayAttr, deviceType)) + return true; + } + return false; +} + +mlir::Value LoopOp::getVectorValue() { + return getVectorValue(mlir::acc::DeviceType::None); +} + +mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) { + return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(), + getVectorOperands(), deviceType); +} + +bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); } + +bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) { + if (auto arrayAttr = getVector()) { + if (findSegment(*arrayAttr, deviceType)) + return true; + } + return false; +} + +mlir::Value LoopOp::getWorkerValue() { + return getWorkerValue(mlir::acc::DeviceType::None); +} + +mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) { + return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(), + getWorkerNumOperands(), deviceType); +} + +bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); } + +bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) { + if (auto arrayAttr = getWorker()) { + if (findSegment(*arrayAttr, deviceType)) + return true; + } + return false; +} + +mlir::Operation::operand_range LoopOp::getTileValues() { + return getTileValues(mlir::acc::DeviceType::None); +} + +mlir::Operation::operand_range +LoopOp::getTileValues(mlir::acc::DeviceType deviceType) { + return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(), + getTileOperandsSegments(), deviceType); +} + +std::optional LoopOp::getCollapseValue() { + return getCollapseValue(mlir::acc::DeviceType::None); +} + +std::optional +LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) { + if (!getCollapseAttr()) + return std::nullopt; + if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) { + auto intAttr = + mlir::dyn_cast(getCollapseAttr().getValue()[*pos]); + return intAttr.getValue().getZExtValue(); + } + return std::nullopt; +} + +mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) { + return getGangValue(gangArgType, mlir::acc::DeviceType::None); +} + +mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType, + mlir::acc::DeviceType deviceType) { + if (getGangOperands().empty()) + return {}; + if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) { + int32_t nbOperandsBefore = 0; + for (unsigned i = 0; i < *pos; ++i) + nbOperandsBefore += (*getGangOperandsSegments())[i]; + mlir::Operation::operand_range values = + getGangOperands() + .drop_front(nbOperandsBefore) + .take_front((*getGangOperandsSegments())[*pos]); + + int32_t argTypeIdx = nbOperandsBefore; + for (auto value : values) { + auto gangArgTypeAttr = mlir::dyn_cast( + (*getGangOperandsArgType())[argTypeIdx]); + if (gangArgTypeAttr.getValue() == gangArgType) + return value; + ++argTypeIdx; + } + } + return {}; +} + +bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); } + +bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) { + if (auto arrayAttr = getGang()) { + if (findSegment(*arrayAttr, deviceType)) + return true; + } + return false; +} + //===----------------------------------------------------------------------===// // DataOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir index c18d964b370f..5dcdb3a37e4e 100644 --- a/mlir/test/Dialect/OpenACC/invalid.mlir +++ b/mlir/test/Dialect/OpenACC/invalid.mlir @@ -1,58 +1,58 @@ // RUN: mlir-opt -split-input-file -verify-diagnostics %s // expected-error@+1 {{gang, worker or vector cannot appear with the seq attr}} -acc.loop gang { +acc.loop { "test.openacc_dummy_op"() : () -> () acc.yield -} attributes {seq} +} attributes {seq = [#acc.device_type], gang = [#acc.device_type]} // ----- // expected-error@+1 {{gang, worker or vector cannot appear with the seq attr}} -acc.loop worker { +acc.loop { "test.openacc_dummy_op"() : () -> () acc.yield -} attributes {seq} +} attributes {seq = [#acc.device_type], worker = [#acc.device_type]} // ----- // expected-error@+1 {{gang, worker or vector cannot appear with the seq attr}} -acc.loop vector { +acc.loop { "test.openacc_dummy_op"() : () -> () acc.yield -} attributes {seq} +} attributes {seq = [#acc.device_type], vector = [#acc.device_type]} // ----- // expected-error@+1 {{gang, worker or vector cannot appear with the seq attr}} -acc.loop gang worker { +acc.loop { "test.openacc_dummy_op"() : () -> () acc.yield -} attributes {seq} +} attributes {seq = [#acc.device_type], worker = [#acc.device_type], gang = [#acc.device_type]} // ----- // expected-error@+1 {{gang, worker or vector cannot appear with the seq attr}} -acc.loop gang vector { +acc.loop { "test.openacc_dummy_op"() : () -> () acc.yield -} attributes {seq} +} attributes {seq = [#acc.device_type], vector = [#acc.device_type], gang = [#acc.device_type]} // ----- // expected-error@+1 {{gang, worker or vector cannot appear with the seq attr}} -acc.loop worker vector { +acc.loop { "test.openacc_dummy_op"() : () -> () acc.yield -} attributes {seq} +} attributes {seq = [#acc.device_type], vector = [#acc.device_type], worker = [#acc.device_type]} // ----- // expected-error@+1 {{gang, worker or vector cannot appear with the seq attr}} -acc.loop gang worker vector { +acc.loop { "test.openacc_dummy_op"() : () -> () acc.yield -} attributes {seq} +} attributes {seq = [#acc.device_type], vector = [#acc.device_type], worker = [#acc.device_type], gang = [#acc.device_type]} // ----- @@ -62,10 +62,31 @@ acc.loop { // ----- +// expected-error@+1 {{'acc.loop' op duplicate device_type found in gang attribute}} +acc.loop { + acc.yield +} attributes {gang = [#acc.device_type, #acc.device_type]} + +// ----- + +// expected-error@+1 {{'acc.loop' op duplicate device_type found in worker attribute}} +acc.loop { + acc.yield +} attributes {worker = [#acc.device_type, #acc.device_type]} + +// ----- + +// expected-error@+1 {{'acc.loop' op duplicate device_type found in vector attribute}} +acc.loop { + acc.yield +} attributes {vector = [#acc.device_type, #acc.device_type]} + +// ----- + // expected-error@+1 {{only one of "auto", "independent", "seq" can be present at the same time}} acc.loop { acc.yield -} attributes {auto_, seq} +} attributes {auto_ = [#acc.device_type], seq = [#acc.device_type]} // ----- @@ -368,7 +389,7 @@ acc.firstprivate.recipe @privatization_i32 : i32 init { // ----- // expected-error@+1 {{expected ')'}} -acc.loop gang(static=%i64Value: i64, num=%i64Value: i64 { +acc.loop gang({static=%i64Value: i64, num=%i64Value: i64} { "test.openacc_dummy_op"() : () -> () acc.yield } @@ -437,7 +458,7 @@ acc.reduction.recipe @reduction_i64 : i64 reduction_operator init { // ----- // expected-error@+1 {{new value expected after comma}} -acc.loop gang(static=%i64Value: i64, ) { +acc.loop gang({static=%i64Value: i64, ) { "test.openacc_dummy_op"() : () -> () acc.yield } @@ -454,7 +475,7 @@ func.func @fct1(%0 : !llvm.ptr) -> () { // ----- // expected-error@+1 {{expect at least one of num, dim or static values}} -acc.loop gang() { +acc.loop gang({}) { "test.openacc_dummy_op"() : () -> () acc.yield } diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir index 52375b1af314..ce5bfa490013 100644 --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -11,7 +11,7 @@ func.func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x %async = arith.constant 1 : i64 acc.parallel async(%async: i64) { - acc.loop gang vector { + acc.loop { scf.for %arg3 = %c0 to %c10 step %c1 { scf.for %arg4 = %c0 to %c10 step %c1 { scf.for %arg5 = %c0 to %c10 step %c1 { @@ -25,7 +25,7 @@ func.func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x } } acc.yield - } attributes { collapse = 3 } + } attributes { collapse = [3], collapseDeviceType = [#acc.device_type], vector = [#acc.device_type], gang = [#acc.device_type]} acc.yield } @@ -38,7 +38,7 @@ func.func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x // CHECK-NEXT: %{{.*}} = arith.constant 1 : index // CHECK-NEXT: [[ASYNC:%.*]] = arith.constant 1 : i64 // CHECK-NEXT: acc.parallel async([[ASYNC]] : i64) { -// CHECK-NEXT: acc.loop gang vector { +// CHECK-NEXT: acc.loop { // CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { // CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { // CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { @@ -52,7 +52,7 @@ func.func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: acc.yield -// CHECK-NEXT: } attributes {collapse = 3 : i64} +// CHECK-NEXT: } attributes {collapse = [3], collapseDeviceType = [#acc.device_type], gang = [#acc.device_type], vector = [#acc.device_type]} // CHECK-NEXT: acc.yield // CHECK-NEXT: } // CHECK-NEXT: return %{{.*}} : memref<10x10xf32> @@ -80,7 +80,7 @@ func.func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x } } acc.yield - } attributes {seq} + } attributes {seq = [#acc.device_type]} acc.yield } @@ -106,7 +106,7 @@ func.func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: acc.yield -// CHECK-NEXT: } attributes {seq} +// CHECK-NEXT: } attributes {seq = [#acc.device_type]} // CHECK-NEXT: acc.yield // CHECK-NEXT: } // CHECK-NEXT: return %{{.*}} : memref<10x10xf32> @@ -138,9 +138,9 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x acc.data dataOperands(%pa, %pb, %pc, %pd: memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) { %private = acc.private varPtr(%c : memref<10xf32>) -> memref<10xf32> acc.parallel num_gangs({%numGangs: i64}) num_workers(%numWorkers: i64 [#acc.device_type]) private(@privatization_memref_10_f32 -> %private : memref<10xf32>) { - acc.loop gang { + acc.loop { scf.for %x = %lb to %c10 step %st { - acc.loop worker { + acc.loop { scf.for %y = %lb to %c10 step %st { %axy = memref.load %a[%x, %y] : memref<10x10xf32> %bxy = memref.load %b[%x, %y] : memref<10x10xf32> @@ -148,7 +148,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x memref.store %tmp, %c[%y] : memref<10xf32> } acc.yield - } + } attributes {worker = [#acc.device_type]} acc.loop { // for i = 0 to 10 step 1 @@ -160,10 +160,10 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x memref.store %z, %d[%x] : memref<10xf32> } acc.yield - } attributes {seq} + } attributes {seq = [#acc.device_type]} } acc.yield - } + } attributes {gang = [#acc.device_type]} acc.yield } acc.terminator @@ -181,9 +181,9 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x // CHECK: acc.data dataOperands(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) { // CHECK-NEXT: %[[P_ARG2:.*]] = acc.private varPtr([[ARG2]] : memref<10xf32>) -> memref<10xf32> // CHECK-NEXT: acc.parallel num_gangs({[[NUMGANG]] : i64}) num_workers([[NUMWORKERS]] : i64 [#acc.device_type]) private(@privatization_memref_10_f32 -> %[[P_ARG2]] : memref<10xf32>) { -// CHECK-NEXT: acc.loop gang { +// CHECK-NEXT: acc.loop { // CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] { -// CHECK-NEXT: acc.loop worker { +// CHECK-NEXT: acc.loop { // CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] { // CHECK-NEXT: %{{.*}} = memref.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32> // CHECK-NEXT: %{{.*}} = memref.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32> @@ -191,7 +191,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x // CHECK-NEXT: memref.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: acc.yield -// CHECK-NEXT: } +// CHECK-NEXT: } attributes {worker = [#acc.device_type]} // CHECK-NEXT: acc.loop { // CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] { // CHECK-NEXT: %{{.*}} = memref.load %{{.*}}[%{{.*}}] : memref<10xf32> @@ -200,10 +200,10 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x // CHECK-NEXT: memref.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: acc.yield -// CHECK-NEXT: } attributes {seq} +// CHECK-NEXT: } attributes {seq = [#acc.device_type]} // CHECK-NEXT: } // CHECK-NEXT: acc.yield -// CHECK-NEXT: } +// CHECK-NEXT: } attributes {gang = [#acc.device_type]} // CHECK-NEXT: acc.yield // CHECK-NEXT: } // CHECK-NEXT: acc.terminator @@ -218,15 +218,15 @@ func.func @testloopop(%a : memref<10xf32>) -> () { %i32Value = arith.constant 128 : i32 %idxValue = arith.constant 8 : index - acc.loop gang worker vector { + acc.loop { + "test.openacc_dummy_op"() : () -> () + acc.yield + } attributes {vector = [#acc.device_type], worker = [#acc.device_type], gang = [#acc.device_type]} + acc.loop gang({num=%i64Value: i64}) { "test.openacc_dummy_op"() : () -> () acc.yield } - acc.loop gang(num=%i64Value: i64) { - "test.openacc_dummy_op"() : () -> () - acc.yield - } - acc.loop gang(static=%i64Value: i64) { + acc.loop gang({static=%i64Value: i64}) { "test.openacc_dummy_op"() : () -> () acc.yield } @@ -254,31 +254,31 @@ func.func @testloopop(%a : memref<10xf32>) -> () { "test.openacc_dummy_op"() : () -> () acc.yield } - acc.loop gang(num=%i64Value: i64) worker vector { + acc.loop gang({num=%i64Value: i64}) { + "test.openacc_dummy_op"() : () -> () + acc.yield + } attributes {vector = [#acc.device_type], worker = [#acc.device_type]} + acc.loop gang({num=%i64Value: i64, static=%i64Value: i64}) worker(%i64Value: i64) vector(%i64Value: i64) { "test.openacc_dummy_op"() : () -> () acc.yield } - acc.loop gang(num=%i64Value: i64, static=%i64Value: i64) worker(%i64Value: i64) vector(%i64Value: i64) { + acc.loop gang({num=%i32Value: i32, static=%idxValue: index}) { "test.openacc_dummy_op"() : () -> () acc.yield } - acc.loop gang(num=%i32Value: i32, static=%idxValue: index) { + acc.loop tile({%i64Value : i64, %i64Value : i64}) { "test.openacc_dummy_op"() : () -> () acc.yield } - acc.loop tile(%i64Value, %i64Value : i64, i64) { + acc.loop tile({%i32Value : i32, %i32Value : i32}) { "test.openacc_dummy_op"() : () -> () acc.yield } - acc.loop tile(%i32Value, %i32Value : i32, i32) { + acc.loop gang({static=%i64Value: i64, num=%i64Value: i64}) { "test.openacc_dummy_op"() : () -> () acc.yield } - acc.loop gang(static=%i64Value: i64, num=%i64Value: i64) { - "test.openacc_dummy_op"() : () -> () - acc.yield - } - acc.loop gang(dim=%i64Value : i64, static=%i64Value: i64) { + acc.loop gang({dim=%i64Value : i64, static=%i64Value: i64}) { "test.openacc_dummy_op"() : () -> () acc.yield } @@ -293,15 +293,15 @@ func.func @testloopop(%a : memref<10xf32>) -> () { // CHECK: [[I64VALUE:%.*]] = arith.constant 1 : i64 // CHECK-NEXT: [[I32VALUE:%.*]] = arith.constant 128 : i32 // CHECK-NEXT: [[IDXVALUE:%.*]] = arith.constant 8 : index -// CHECK: acc.loop gang worker vector { +// CHECK: acc.loop { +// CHECK-NEXT: "test.openacc_dummy_op"() : () -> () +// CHECK-NEXT: acc.yield +// CHECK-NEXT: } attributes {gang = [#acc.device_type], vector = [#acc.device_type], worker = [#acc.device_type]} +// CHECK: acc.loop gang({num=[[I64VALUE]] : i64}) { // CHECK-NEXT: "test.openacc_dummy_op"() : () -> () // CHECK-NEXT: acc.yield // CHECK-NEXT: } -// CHECK: acc.loop gang(num=[[I64VALUE]] : i64) { -// CHECK-NEXT: "test.openacc_dummy_op"() : () -> () -// CHECK-NEXT: acc.yield -// CHECK-NEXT: } -// CHECK: acc.loop gang(static=[[I64VALUE]] : i64) { +// CHECK: acc.loop gang({static=[[I64VALUE]] : i64}) { // CHECK-NEXT: "test.openacc_dummy_op"() : () -> () // CHECK-NEXT: acc.yield // CHECK-NEXT: } @@ -329,31 +329,31 @@ func.func @testloopop(%a : memref<10xf32>) -> () { // CHECK-NEXT: "test.openacc_dummy_op"() : () -> () // CHECK-NEXT: acc.yield // CHECK-NEXT: } -// CHECK: acc.loop gang(num=[[I64VALUE]] : i64) worker vector { +// CHECK: acc.loop gang({num=[[I64VALUE]] : i64}) { +// CHECK-NEXT: "test.openacc_dummy_op"() : () -> () +// CHECK-NEXT: acc.yield +// CHECK-NEXT: } attributes {vector = [#acc.device_type], worker = [#acc.device_type]} +// CHECK: acc.loop gang({num=[[I64VALUE]] : i64, static=[[I64VALUE]] : i64}) worker([[I64VALUE]] : i64) vector([[I64VALUE]] : i64) { // CHECK-NEXT: "test.openacc_dummy_op"() : () -> () // CHECK-NEXT: acc.yield // CHECK-NEXT: } -// CHECK: acc.loop gang(num=[[I64VALUE]] : i64, static=[[I64VALUE]] : i64) worker([[I64VALUE]] : i64) vector([[I64VALUE]] : i64) { +// CHECK: acc.loop gang({num=[[I32VALUE]] : i32, static=[[IDXVALUE]] : index}) { // CHECK-NEXT: "test.openacc_dummy_op"() : () -> () // CHECK-NEXT: acc.yield // CHECK-NEXT: } -// CHECK: acc.loop gang(num=[[I32VALUE]] : i32, static=[[IDXVALUE]] : index) { +// CHECK: acc.loop tile({[[I64VALUE]] : i64, [[I64VALUE]] : i64}) { // CHECK-NEXT: "test.openacc_dummy_op"() : () -> () // CHECK-NEXT: acc.yield // CHECK-NEXT: } -// CHECK: acc.loop tile([[I64VALUE]], [[I64VALUE]] : i64, i64) { +// CHECK: acc.loop tile({[[I32VALUE]] : i32, [[I32VALUE]] : i32}) { // CHECK-NEXT: "test.openacc_dummy_op"() : () -> () // CHECK-NEXT: acc.yield // CHECK-NEXT: } -// CHECK: acc.loop tile([[I32VALUE]], [[I32VALUE]] : i32, i32) { +// CHECK: acc.loop gang({static=[[I64VALUE]] : i64, num=[[I64VALUE]] : i64}) { // CHECK-NEXT: "test.openacc_dummy_op"() : () -> () // CHECK-NEXT: acc.yield // CHECK-NEXT: } -// CHECK: acc.loop gang(num=[[I64VALUE]] : i64, static=[[I64VALUE]] : i64) { -// CHECK-NEXT: "test.openacc_dummy_op"() : () -> () -// CHECK-NEXT: acc.yield -// CHECK-NEXT: } -// CHECK: acc.loop gang(dim=[[I64VALUE]] : i64, static=[[I64VALUE]] : i64) { +// CHECK: acc.loop gang({dim=[[I64VALUE]] : i64, static=[[I64VALUE]] : i64}) { // CHECK-NEXT: "test.openacc_dummy_op"() : () -> () // CHECK-NEXT: acc.yield // CHECK-NEXT: }