[flang] Don't fold operation when shapes differ

When folding a binary operation between two array constructors, it
is necessary to check if each value contained in the left operand
has the same rank and shape as the one on the right.
Otherwise, lowering would end up with an operation between values
of different ranks/shapes, which could result in a crash.

For instance, the following code was crashing the compiler:
integer :: x(4), y(2, 2), z(4)

z = (/x/) + (/y/)

Fixes #60229

Reviewed By: klausler, jeanPerier

Differential Revision: https://reviews.llvm.org/D147181
This commit is contained in:
Leandro Lupori 2023-03-28 14:28:56 +00:00
parent 942b403ff1
commit 5e5176adb1
3 changed files with 164 additions and 4 deletions

View File

@ -1381,6 +1381,28 @@ ArrayConstructor<RESULT> ArrayConstructorFromMold(
return result;
}
template <typename LEFT, typename RIGHT>
bool ShapesMatch(FoldingContext &context,
const ArrayConstructor<LEFT> &leftArrConst,
const ArrayConstructor<RIGHT> &rightArrConst) {
auto rightIter{rightArrConst.begin()};
for (auto &leftValue : leftArrConst) {
CHECK(rightIter != rightArrConst.end());
auto &leftExpr{std::get<Expr<LEFT>>(leftValue.u)};
auto &rightExpr{std::get<Expr<RIGHT>>(rightIter->u)};
if (leftExpr.Rank() != rightExpr.Rank()) {
return false;
}
std::optional<Shape> leftShape{GetShape(context, leftExpr)};
std::optional<Shape> rightShape{GetShape(context, rightExpr)};
if (!leftShape || !rightShape || *leftShape != *rightShape) {
return false;
}
++rightIter;
}
return true;
}
// array * array case
template <typename RESULT, typename LEFT, typename RIGHT>
auto MapOperation(FoldingContext &context,
@ -1391,11 +1413,14 @@ auto MapOperation(FoldingContext &context,
auto result{ArrayConstructorFromMold<RESULT>(leftValues, std::move(length))};
auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
if constexpr (common::HasMember<RIGHT, AllIntrinsicCategoryTypes>) {
common::visit(
[&](auto &&kindExpr) {
bool mapped{common::visit(
[&](auto &&kindExpr) -> bool {
using kindType = ResultType<decltype(kindExpr)>;
auto &rightArrConst{std::get<ArrayConstructor<kindType>>(kindExpr.u)};
if (!ShapesMatch(context, leftArrConst, rightArrConst)) {
return false;
}
auto rightIter{rightArrConst.begin()};
for (auto &leftValue : leftArrConst) {
CHECK(rightIter != rightArrConst.end());
@ -1405,10 +1430,17 @@ auto MapOperation(FoldingContext &context,
f(std::move(leftScalar), Expr<RIGHT>{std::move(rightScalar)})));
++rightIter;
}
return true;
},
std::move(rightValues.u));
std::move(rightValues.u))};
if (!mapped) {
return std::nullopt;
}
} else {
auto &rightArrConst{std::get<ArrayConstructor<RIGHT>>(rightValues.u)};
if (!ShapesMatch(context, leftArrConst, rightArrConst)) {
return std::nullopt;
}
auto rightIter{rightArrConst.begin()};
for (auto &leftValue : leftArrConst) {
CHECK(rightIter != rightArrConst.end());

View File

@ -196,7 +196,9 @@ subroutine associate_tests(p)
end subroutine
!CHECK-LABEL: array_constructor
subroutine array_constructor()
subroutine array_constructor(a, u, v, w, x, y, z)
real :: a(4)
integer :: u(:), v(1), w(2), x(4), y(4), z(2, 2)
interface
function return_allocatable()
real, allocatable :: return_allocatable(:)
@ -204,6 +206,28 @@ subroutine array_constructor()
end interface
!CHECK: PRINT *, size([REAL(4)::return_allocatable(),return_allocatable()])
print *, size([return_allocatable(), return_allocatable()])
!CHECK: PRINT *, [INTEGER(4)::x+y]
print *, (/x/) + (/y/)
!CHECK: PRINT *, [INTEGER(4)::x]+[INTEGER(4)::z]
print *, (/x/) + (/z/)
!CHECK: PRINT *, [INTEGER(4)::x+y,x+y]
print *, (/x, x/) + (/y, y/)
!CHECK: PRINT *, [INTEGER(4)::x,x]+[INTEGER(4)::x,z]
print *, (/x, x/) + (/x, z/)
!CHECK: PRINT *, [INTEGER(4)::x,w,w]+[INTEGER(4)::w,w,x]
print *, (/x, w, w/) + (/w, w, x/)
!CHECK: PRINT *, [INTEGER(4)::x]+[INTEGER(4)::1_4,2_4,3_4,4_4]
print *, (/x/) + (/1, 2, 3, 4/)
!CHECK: PRINT *, [INTEGER(4)::v]+[INTEGER(4)::1_4]
print *, (/v/) + (/1/)
!CHECK: PRINT *, [INTEGER(4)::x]+[INTEGER(4)::u]
print *, (/x/) + (/u/)
!CHECK: PRINT *, [INTEGER(4)::u]+[INTEGER(4)::u]
print *, (/u/) + (/u/)
!CHECK: PRINT *, [REAL(4)::a**x]
print *, (/a/) ** (/x/)
!CHECK: PRINT *, [REAL(4)::a]**[INTEGER(4)::z]
print *, (/a/) ** (/z/)
end subroutine
!CHECK-LABEL: array_ctor_implied_do_index

View File

@ -1158,4 +1158,108 @@ subroutine test_elemental_character_intrinsic(c1, c2)
print *, scan(c1, c2)
end subroutine
! Check that the expression is folded, with the first operation being an add
! between x and y, resulting in a new temporary array.
!
! CHECK-LABEL: func @_QPtest20a(
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<4xi32>> {{.*}}, %[[ARG1:.*]]: !fir.ref<!fir.array<4xi32>> {{.*}}, %[[ARG2:.*]]: !fir.ref<!fir.array<4xi32>>
! CHECK: %[[Z:.*]] = fir.array_load %[[ARG2]]({{.*}}) : (!fir.ref<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
! CHECK: %[[X:.*]] = fir.array_load %[[ARG0]]({{.*}}) : (!fir.ref<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
! CHECK: %[[Y:.*]] = fir.array_load %[[ARG1]]({{.*}}) : (!fir.ref<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
! CHECK: %[[TEMP:.*]] = fir.allocmem !fir.array<4xi32>
! CHECK: %[[TEMP2:.*]] = fir.array_load %[[TEMP]]({{.*}}) : (!fir.heap<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
! CHECK: {{.*}} = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[TEMP3:.*]] = %[[TEMP2]]) -> (!fir.array<4xi32>) {
! CHECK: %[[XI:.*]] = fir.array_fetch %[[X]], %[[I]] : (!fir.array<4xi32>, index) -> i32
! CHECK: %[[YI:.*]] = fir.array_fetch %[[Y]], %[[I]] : (!fir.array<4xi32>, index) -> i32
! CHECK: %[[ADD:.*]] = arith.addi %[[XI]], %[[YI]] : i32
! CHECK: {{.*}} = fir.array_update %[[TEMP3]], %[[ADD]], %[[I]] : (!fir.array<4xi32>, i32, index) -> !fir.array<4xi32>
! CHECK: }
subroutine test20a(x, y, z)
integer :: x(4), y(4), z(4)
z = (/x/) + (/y/)
end subroutine
! Check that the expression is not folded, with the first operations being
! array constructions from x and y.
!
! CHECK-LABEL: func @_QPtest20b(
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<4xi32>> {{.*}}, %[[ARG1:.*]]: !fir.ref<!fir.array<2x2xi32>> {{.*}}, %[[ARG2:.*]]: !fir.ref<!fir.array<4xi32>>
! CHECK: %[[Z:.*]] = fir.array_load %[[ARG2]]({{.*}}) : (!fir.ref<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
! CHECK: %[[X:.*]] = fir.array_load %[[ARG0]]({{.*}}) : (!fir.ref<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
! CHECK: %[[TEMP:.*]] = fir.allocmem !fir.array<4xi32>
! CHECK: %[[TEMP2:.*]] = fir.array_load %[[TEMP]]({{.*}}) : (!fir.heap<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
! CHECK: {{.*}} = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[TEMP3:.*]] = %[[TEMP2]]) -> (!fir.array<4xi32>) {
! CHECK: %[[XI:.*]] = fir.array_fetch %[[X]], %[[I]] : (!fir.array<4xi32>, index) -> i32
! CHECK: {{.*}} = fir.array_update %[[TEMP3]], %[[XI]], %[[I]] : (!fir.array<4xi32>, i32, index) -> !fir.array<4xi32>
! CHECK: }
! CHECK: %[[Y:.*]] = fir.array_load %[[ARG1]]({{.*}}) : (!fir.ref<!fir.array<2x2xi32>>, !fir.shape<2>) -> !fir.array<2x2xi32>
! CHECK: %[[TEMP4:.*]] = fir.allocmem !fir.array<2x2xi32>
! CHECK: %[[TEMP5:.*]] = fir.array_load %[[TEMP4]]({{.*}}) : (!fir.heap<!fir.array<2x2xi32>>, !fir.shape<2>) -> !fir.array<2x2xi32>
! CHECK: {{.*}} = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[TEMP6:.*]] = %[[TEMP5]]) -> (!fir.array<2x2xi32>) {
! CHECK: {{.*}} = fir.do_loop %[[J:.*]] = {{.*}} iter_args(%[[TEMP7:.*]] = %[[TEMP6]]) -> (!fir.array<2x2xi32>) {
! CHECK: %[[YJI:.*]] = fir.array_fetch %[[Y]], %[[J]], %[[I]] : (!fir.array<2x2xi32>, index, index) -> i32
! CHECK: {{.*}} = fir.array_update %[[TEMP7]], %[[YJI]], %[[J]], %[[I]] : (!fir.array<2x2xi32>, i32, index, index) -> !fir.array<2x2xi32>
! CHECK: }
! CHECK: }
subroutine test20b(x, y, z)
integer :: x(4), y(2, 2), z(4)
z = (/x/) + (/y/)
end subroutine
! CHECK-LABEL: func @_QPtest20c(
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<4xi32>> {{.*}}, %[[ARG1:.*]]: !fir.ref<!fir.array<2x2xi32>> {{.*}}
! (/x/)
! CHECK: %[[X:.*]] = fir.array_load %[[ARG0]]({{.*}}) : (!fir.ref<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
! CHECK: %[[ACX_MEM:.*]] = fir.allocmem !fir.array<4xi32>
! CHECK: %[[ACX:.*]] = fir.array_load %[[ACX_MEM]]({{.*}}) : (!fir.heap<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
! CHECK: {{.*}} = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[TEMP:.*]] = %[[ACX]]) -> (!fir.array<4xi32>) {
! CHECK: %[[XI:.*]] = fir.array_fetch %[[X]], %[[I]] : (!fir.array<4xi32>, index) -> i32
! CHECK: {{.*}} = fir.array_update %[[TEMP]], %[[XI]], %[[I]] : (!fir.array<4xi32>, i32, index) -> !fir.array<4xi32>
! CHECK: }
! CHECK: %[[T:.*]] = fir.coordinate_of %[[ACX_MEM2:.*]], %{{.*}} : (!fir.heap<!fir.array<4xi32>>, index) -> !fir.ref<i32>
! CHECK: %[[T1:.*]] = fir.convert %[[T]] : (!fir.ref<i32>) -> !fir.ref<i8>
! CHECK: %[[T2:.*]] = fir.convert %[[ACX_MEM]] : (!fir.heap<!fir.array<4xi32>>) -> !fir.ref<i8>
! CHECK: fir.call @llvm.memcpy.p0.p0.i64(%[[T1]], %[[T2]], {{.*}})
! CHECK: %[[ACX2:.*]] = fir.array_load %[[ACX_MEM2]]({{.*}}) : (!fir.heap<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
! (/y/)
! CHECK: %[[Y:.*]] = fir.array_load %[[ARG1]]({{.*}}) : (!fir.ref<!fir.array<2x2xi32>>, !fir.shape<2>) -> !fir.array<2x2xi32>
! CHECK: %[[ACY_MEM:.*]] = fir.allocmem !fir.array<2x2xi32>
! CHECK: %[[ACY:.*]] = fir.array_load %[[ACY_MEM]]({{.*}}) : (!fir.heap<!fir.array<2x2xi32>>, !fir.shape<2>) -> !fir.array<2x2xi32>
! CHECK: {{.*}} = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[TEMP:.*]] = %[[ACY]]) -> (!fir.array<2x2xi32>) {
! CHECK: {{.*}} = fir.do_loop %[[J:.*]] = {{.*}} iter_args(%[[TEMP2:.*]] = %[[TEMP]]) -> (!fir.array<2x2xi32>) {
! CHECK: %[[YJI:.*]] = fir.array_fetch %[[Y]], %[[J]], %[[I]] : (!fir.array<2x2xi32>, index, index) -> i32
! CHECK: {{.*}} = fir.array_update %[[TEMP2]], %[[YJI]], %[[J]], %[[I]] : (!fir.array<2x2xi32>, i32, index, index) -> !fir.array<2x2xi32>
! CHECK: }
! CHECK: }
! CHECK: %[[T:.*]] = fir.coordinate_of %[[ACY_MEM2:.*]], {{.*}} : (!fir.heap<!fir.array<4xi32>>, index) -> !fir.ref<i32>
! CHECK: %[[T1:.*]] = fir.convert %[[T]] : (!fir.ref<i32>) -> !fir.ref<i8>
! CHECK: %[[T2:.*]] = fir.convert %[[ACY_MEM]] : (!fir.heap<!fir.array<2x2xi32>>) -> !fir.ref<i8>
! CHECK: fir.call @llvm.memcpy.p0.p0.i64(%[[T1]], %[[T2]], {{.*}})
! CHECK: %[[ACY2:.*]] = fir.array_load %[[ACY_MEM2]]({{.*}}) : (!fir.heap<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
! (/x/) /= (/y/)
! CHECK: %[[RES_MEM:.*]] = fir.allocmem !fir.array<4x!fir.logical<4>>
! CHECK: %[[RES:.*]] = fir.array_load %[[RES_MEM]]({{.*}}) : (!fir.heap<!fir.array<4x!fir.logical<4>>>, !fir.shape<1>) -> !fir.array<4x!fir.logical<4>>
! CHECK: %{{.*}} = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[TEMP:.*]] = %[[RES]]) -> (!fir.array<4x!fir.logical<4>>) {
! CHECK: %[[XI:.*]] = fir.array_fetch %[[ACX2]], %[[I]] : (!fir.array<4xi32>, index) -> i32
! CHECK: %[[YI:.*]] = fir.array_fetch %[[ACY2]], %[[I]] : (!fir.array<4xi32>, index) -> i32
! CHECK: %[[T1:.*]] = arith.cmpi ne, %[[XI]], %[[YI]] : i32
! CHECK: %[[T2:.*]] = fir.convert %[[T1]] : (i1) -> !fir.logical<4>
! CHECK: {{.*}} = fir.array_update %[[TEMP]], %[[T2]], %[[I]] : (!fir.array<4x!fir.logical<4>>, !fir.logical<4>, index) -> !fir.array<4x!fir.logical<4>>
! CHECK: }
! any((/x/) /= (/y/))
! CHECK: %[[T1:.*]] = fir.embox %[[RES_MEM]]({{.*}}) : (!fir.heap<!fir.array<4x!fir.logical<4>>>, !fir.shape<1>) -> !fir.box<!fir.array<4x!fir.logical<4>>>
! CHECK: %[[T2:.*]] = fir.convert %[[T1]] : (!fir.box<!fir.array<4x!fir.logical<4>>>) -> !fir.box<none>
! CHECK: fir.call @_FortranAAny(%[[T2]], {{.*}}){{.*}} : (!fir.box<none>, !fir.ref<i8>, i32, i32) -> i1
subroutine test20c(x, y)
integer :: x(4), y(2, 2)
if (any((/x/) /= (/y/))) print *, "different"
end subroutine
! CHECK: func private @_QPbar(