[flang][openacc][openmp] Update stride computation for bounds (#72168)

This patch updates the stride computation for the outer dimensions of
multidimensional arrays whose bounds are generated from a descriptor.
For the innermost dimension, the stride is read from the descriptor and
is the element size in bytes. For each outer dimension, the stride is the
previous dimension's stride multiplied by the previous dimension's extent.
This commit is contained in:
Valentin Clement (バレンタイン クレメン) 2023-11-14 13:55:39 -08:00 committed by GitHub
parent a40900211a
commit 447af1ce99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 3 deletions

View File

@ -597,6 +597,7 @@ genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
assert(box.getType().isa<fir::BaseBoxType>() && assert(box.getType().isa<fir::BaseBoxType>() &&
"expect fir.box or fir.class"); "expect fir.box or fir.class");
mlir::Value byteStride;
for (unsigned dim = 0; dim < dataExv.rank(); ++dim) { for (unsigned dim = 0; dim < dataExv.rank(); ++dim) {
mlir::Value d = builder.createIntegerConstant(loc, idxTy, dim); mlir::Value d = builder.createIntegerConstant(loc, idxTy, dim);
mlir::Value baseLb = mlir::Value baseLb =
@ -606,9 +607,13 @@ genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value lb = builder.createIntegerConstant(loc, idxTy, 0); mlir::Value lb = builder.createIntegerConstant(loc, idxTy, 0);
mlir::Value ub = mlir::Value ub =
builder.create<mlir::arith::SubIOp>(loc, dimInfo.getExtent(), one); builder.create<mlir::arith::SubIOp>(loc, dimInfo.getExtent(), one);
mlir::Value bound = if (dim == 0) // First stride is the element size.
builder.create<BoundsOp>(loc, boundTy, lb, ub, mlir::Value(), byteStride = dimInfo.getByteStride();
dimInfo.getByteStride(), true, baseLb); mlir::Value bound = builder.create<BoundsOp>(
loc, boundTy, lb, ub, mlir::Value(), byteStride, true, baseLb);
// Compute the stride for the next dimension.
byteStride = builder.create<mlir::arith::MulIOp>(loc, byteStride,
dimInfo.getExtent());
bounds.push_back(bound); bounds.push_back(bound);
} }
return bounds; return bounds;

View File

@ -102,4 +102,26 @@ contains
! HLFIR: %[[PRESENT:.*]] = acc.present varPtr(%[[DECL_ARG0]]#1 : !fir.ref<!fir.array<?xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<?xf32>> {name = "a"} ! HLFIR: %[[PRESENT:.*]] = acc.present varPtr(%[[DECL_ARG0]]#1 : !fir.ref<!fir.array<?xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<?xf32>> {name = "a"}
! CHECK: acc.kernels dataOperands(%[[PRESENT]] : !fir.ref<!fir.array<?xf32>>) ! CHECK: acc.kernels dataOperands(%[[PRESENT]] : !fir.ref<!fir.array<?xf32>>)
! Lowering test input: passes a rank-3 assumed-shape real array to the
! present clause of an empty OpenACC kernels region. The expectations that
! follow this subroutine verify one acc.bounds per dimension, with the byte
! stride of each outer dimension built from the previous stride and extent.
subroutine acc_multi_strides(a)
real, dimension(:,:,:) :: a
!$acc kernels present(a)
!$acc end kernels
end subroutine
! CHECK-LABEL: func.func @_QMopenacc_boundsPacc_multi_strides(
! CHECK-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?x?x?xf32>> {fir.bindc_name = "a"})
! HLFIR: %[[DECL_ARG0:.*]]:2 = hlfir.declare %[[ARG0]] {uniq_name = "_QMopenacc_boundsFacc_multi_stridesEa"} : (!fir.box<!fir.array<?x?x?xf32>>) -> (!fir.box<!fir.array<?x?x?xf32>>, !fir.box<!fir.array<?x?x?xf32>>)
! HLFIR: %[[BOX_DIMS0:.*]]:3 = fir.box_dims %[[DECL_ARG0]]#1, %c0{{.*}} : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
! HLFIR: %[[BOUNDS0:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) stride(%[[BOX_DIMS0]]#2 : index) startIdx(%{{.*}} : index) {strideInBytes = true}
! HLFIR: %[[STRIDE1:.*]] = arith.muli %[[BOX_DIMS0]]#2, %[[BOX_DIMS0]]#1 : index
! HLFIR: %[[BOX_DIMS1:.*]]:3 = fir.box_dims %[[DECL_ARG0]]#1, %c1{{.*}} : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
! HLFIR: %[[BOUNDS1:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) stride(%[[STRIDE1]] : index) startIdx(%{{.*}} : index) {strideInBytes = true}
! HLFIR: %[[STRIDE2:.*]] = arith.muli %[[STRIDE1]], %[[BOX_DIMS1]]#1 : index
! HLFIR: %[[BOX_DIMS2:.*]]:3 = fir.box_dims %[[DECL_ARG0]]#1, %c2{{.*}} : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
! HLFIR: %[[BOUNDS2:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) stride(%[[STRIDE2]] : index) startIdx(%{{.*}} : index) {strideInBytes = true}
! HLFIR: %[[BOX_ADDR:.*]] = fir.box_addr %[[DECL_ARG0]]#1 : (!fir.box<!fir.array<?x?x?xf32>>) -> !fir.ref<!fir.array<?x?x?xf32>>
! HLFIR: %[[PRESENT:.*]] = acc.present varPtr(%[[BOX_ADDR]] : !fir.ref<!fir.array<?x?x?xf32>>) bounds(%[[BOUNDS0]], %[[BOUNDS1]], %[[BOUNDS2]]) -> !fir.ref<!fir.array<?x?x?xf32>> {name = "a"}
! HLFIR: acc.kernels dataOperands(%[[PRESENT]] : !fir.ref<!fir.array<?x?x?xf32>>) {
end module end module