mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-04-01 12:43:47 +00:00
[flang] Don't fold operation when shapes differ
When folding a binary operation between two array constructors, it is necessary to check if each value contained in the left operand has the same rank and shape as the one on the right. Otherwise, lowering would end up with an operation between values of different ranks/shapes, which could result in a crash. For instance, the following code was crashing the compiler: integer :: x(4), y(2, 2), z(4) z = (/x/) + (/y/) Fixes #60229 Reviewed By: klausler, jeanPerier Differential Revision: https://reviews.llvm.org/D147181
This commit is contained in:
parent
942b403ff1
commit
5e5176adb1
@ -1381,6 +1381,28 @@ ArrayConstructor<RESULT> ArrayConstructorFromMold(
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename LEFT, typename RIGHT>
|
||||
bool ShapesMatch(FoldingContext &context,
|
||||
const ArrayConstructor<LEFT> &leftArrConst,
|
||||
const ArrayConstructor<RIGHT> &rightArrConst) {
|
||||
auto rightIter{rightArrConst.begin()};
|
||||
for (auto &leftValue : leftArrConst) {
|
||||
CHECK(rightIter != rightArrConst.end());
|
||||
auto &leftExpr{std::get<Expr<LEFT>>(leftValue.u)};
|
||||
auto &rightExpr{std::get<Expr<RIGHT>>(rightIter->u)};
|
||||
if (leftExpr.Rank() != rightExpr.Rank()) {
|
||||
return false;
|
||||
}
|
||||
std::optional<Shape> leftShape{GetShape(context, leftExpr)};
|
||||
std::optional<Shape> rightShape{GetShape(context, rightExpr)};
|
||||
if (!leftShape || !rightShape || *leftShape != *rightShape) {
|
||||
return false;
|
||||
}
|
||||
++rightIter;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// array * array case
|
||||
template <typename RESULT, typename LEFT, typename RIGHT>
|
||||
auto MapOperation(FoldingContext &context,
|
||||
@ -1391,11 +1413,14 @@ auto MapOperation(FoldingContext &context,
|
||||
auto result{ArrayConstructorFromMold<RESULT>(leftValues, std::move(length))};
|
||||
auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
|
||||
if constexpr (common::HasMember<RIGHT, AllIntrinsicCategoryTypes>) {
|
||||
common::visit(
|
||||
[&](auto &&kindExpr) {
|
||||
bool mapped{common::visit(
|
||||
[&](auto &&kindExpr) -> bool {
|
||||
using kindType = ResultType<decltype(kindExpr)>;
|
||||
|
||||
auto &rightArrConst{std::get<ArrayConstructor<kindType>>(kindExpr.u)};
|
||||
if (!ShapesMatch(context, leftArrConst, rightArrConst)) {
|
||||
return false;
|
||||
}
|
||||
auto rightIter{rightArrConst.begin()};
|
||||
for (auto &leftValue : leftArrConst) {
|
||||
CHECK(rightIter != rightArrConst.end());
|
||||
@ -1405,10 +1430,17 @@ auto MapOperation(FoldingContext &context,
|
||||
f(std::move(leftScalar), Expr<RIGHT>{std::move(rightScalar)})));
|
||||
++rightIter;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
std::move(rightValues.u));
|
||||
std::move(rightValues.u))};
|
||||
if (!mapped) {
|
||||
return std::nullopt;
|
||||
}
|
||||
} else {
|
||||
auto &rightArrConst{std::get<ArrayConstructor<RIGHT>>(rightValues.u)};
|
||||
if (!ShapesMatch(context, leftArrConst, rightArrConst)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
auto rightIter{rightArrConst.begin()};
|
||||
for (auto &leftValue : leftArrConst) {
|
||||
CHECK(rightIter != rightArrConst.end());
|
||||
|
@ -196,7 +196,9 @@ subroutine associate_tests(p)
|
||||
end subroutine
|
||||
|
||||
!CHECK-LABEL: array_constructor
|
||||
subroutine array_constructor()
|
||||
subroutine array_constructor(a, u, v, w, x, y, z)
|
||||
real :: a(4)
|
||||
integer :: u(:), v(1), w(2), x(4), y(4), z(2, 2)
|
||||
interface
|
||||
function return_allocatable()
|
||||
real, allocatable :: return_allocatable(:)
|
||||
@ -204,6 +206,28 @@ subroutine array_constructor()
|
||||
end interface
|
||||
!CHECK: PRINT *, size([REAL(4)::return_allocatable(),return_allocatable()])
|
||||
print *, size([return_allocatable(), return_allocatable()])
|
||||
!CHECK: PRINT *, [INTEGER(4)::x+y]
|
||||
print *, (/x/) + (/y/)
|
||||
!CHECK: PRINT *, [INTEGER(4)::x]+[INTEGER(4)::z]
|
||||
print *, (/x/) + (/z/)
|
||||
!CHECK: PRINT *, [INTEGER(4)::x+y,x+y]
|
||||
print *, (/x, x/) + (/y, y/)
|
||||
!CHECK: PRINT *, [INTEGER(4)::x,x]+[INTEGER(4)::x,z]
|
||||
print *, (/x, x/) + (/x, z/)
|
||||
!CHECK: PRINT *, [INTEGER(4)::x,w,w]+[INTEGER(4)::w,w,x]
|
||||
print *, (/x, w, w/) + (/w, w, x/)
|
||||
!CHECK: PRINT *, [INTEGER(4)::x]+[INTEGER(4)::1_4,2_4,3_4,4_4]
|
||||
print *, (/x/) + (/1, 2, 3, 4/)
|
||||
!CHECK: PRINT *, [INTEGER(4)::v]+[INTEGER(4)::1_4]
|
||||
print *, (/v/) + (/1/)
|
||||
!CHECK: PRINT *, [INTEGER(4)::x]+[INTEGER(4)::u]
|
||||
print *, (/x/) + (/u/)
|
||||
!CHECK: PRINT *, [INTEGER(4)::u]+[INTEGER(4)::u]
|
||||
print *, (/u/) + (/u/)
|
||||
!CHECK: PRINT *, [REAL(4)::a**x]
|
||||
print *, (/a/) ** (/x/)
|
||||
!CHECK: PRINT *, [REAL(4)::a]**[INTEGER(4)::z]
|
||||
print *, (/a/) ** (/z/)
|
||||
end subroutine
|
||||
|
||||
!CHECK-LABEL: array_ctor_implied_do_index
|
||||
|
@ -1158,4 +1158,108 @@ subroutine test_elemental_character_intrinsic(c1, c2)
|
||||
print *, scan(c1, c2)
|
||||
end subroutine
|
||||
|
||||
! Check that the expression is folded, with the first operation being an add
|
||||
! between x and y, resulting in a new temporary array.
|
||||
!
|
||||
! CHECK-LABEL: func @_QPtest20a(
|
||||
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<4xi32>> {{.*}}, %[[ARG1:.*]]: !fir.ref<!fir.array<4xi32>> {{.*}}, %[[ARG2:.*]]: !fir.ref<!fir.array<4xi32>>
|
||||
! CHECK: %[[Z:.*]] = fir.array_load %[[ARG2]]({{.*}}) : (!fir.ref<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
|
||||
! CHECK: %[[X:.*]] = fir.array_load %[[ARG0]]({{.*}}) : (!fir.ref<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
|
||||
! CHECK: %[[Y:.*]] = fir.array_load %[[ARG1]]({{.*}}) : (!fir.ref<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
|
||||
! CHECK: %[[TEMP:.*]] = fir.allocmem !fir.array<4xi32>
|
||||
! CHECK: %[[TEMP2:.*]] = fir.array_load %[[TEMP]]({{.*}}) : (!fir.heap<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
|
||||
! CHECK: {{.*}} = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[TEMP3:.*]] = %[[TEMP2]]) -> (!fir.array<4xi32>) {
|
||||
! CHECK: %[[XI:.*]] = fir.array_fetch %[[X]], %[[I]] : (!fir.array<4xi32>, index) -> i32
|
||||
! CHECK: %[[YI:.*]] = fir.array_fetch %[[Y]], %[[I]] : (!fir.array<4xi32>, index) -> i32
|
||||
! CHECK: %[[ADD:.*]] = arith.addi %[[XI]], %[[YI]] : i32
|
||||
! CHECK: {{.*}} = fir.array_update %[[TEMP3]], %[[ADD]], %[[I]] : (!fir.array<4xi32>, i32, index) -> !fir.array<4xi32>
|
||||
! CHECK: }
|
||||
subroutine test20a(x, y, z)
|
||||
integer :: x(4), y(4), z(4)
|
||||
|
||||
z = (/x/) + (/y/)
|
||||
end subroutine
|
||||
|
||||
! Check that the expression is not folded, with the first operations being
|
||||
! array constructions from x and y.
|
||||
!
|
||||
! CHECK-LABEL: func @_QPtest20b(
|
||||
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<4xi32>> {{.*}}, %[[ARG1:.*]]: !fir.ref<!fir.array<2x2xi32>> {{.*}}, %[[ARG2:.*]]: !fir.ref<!fir.array<4xi32>>
|
||||
! CHECK: %[[Z:.*]] = fir.array_load %[[ARG2]]({{.*}}) : (!fir.ref<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
|
||||
! CHECK: %[[X:.*]] = fir.array_load %[[ARG0]]({{.*}}) : (!fir.ref<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
|
||||
! CHECK: %[[TEMP:.*]] = fir.allocmem !fir.array<4xi32>
|
||||
! CHECK: %[[TEMP2:.*]] = fir.array_load %[[TEMP]]({{.*}}) : (!fir.heap<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
|
||||
! CHECK: {{.*}} = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[TEMP3:.*]] = %[[TEMP2]]) -> (!fir.array<4xi32>) {
|
||||
! CHECK: %[[XI:.*]] = fir.array_fetch %[[X]], %[[I]] : (!fir.array<4xi32>, index) -> i32
|
||||
! CHECK: {{.*}} = fir.array_update %[[TEMP3]], %[[XI]], %[[I]] : (!fir.array<4xi32>, i32, index) -> !fir.array<4xi32>
|
||||
! CHECK: }
|
||||
! CHECK: %[[Y:.*]] = fir.array_load %[[ARG1]]({{.*}}) : (!fir.ref<!fir.array<2x2xi32>>, !fir.shape<2>) -> !fir.array<2x2xi32>
|
||||
! CHECK: %[[TEMP4:.*]] = fir.allocmem !fir.array<2x2xi32>
|
||||
! CHECK: %[[TEMP5:.*]] = fir.array_load %[[TEMP4]]({{.*}}) : (!fir.heap<!fir.array<2x2xi32>>, !fir.shape<2>) -> !fir.array<2x2xi32>
|
||||
! CHECK: {{.*}} = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[TEMP6:.*]] = %[[TEMP5]]) -> (!fir.array<2x2xi32>) {
|
||||
! CHECK: {{.*}} = fir.do_loop %[[J:.*]] = {{.*}} iter_args(%[[TEMP7:.*]] = %[[TEMP6]]) -> (!fir.array<2x2xi32>) {
|
||||
! CHECK: %[[YJI:.*]] = fir.array_fetch %[[Y]], %[[J]], %[[I]] : (!fir.array<2x2xi32>, index, index) -> i32
|
||||
! CHECK: {{.*}} = fir.array_update %[[TEMP7]], %[[YJI]], %[[J]], %[[I]] : (!fir.array<2x2xi32>, i32, index, index) -> !fir.array<2x2xi32>
|
||||
! CHECK: }
|
||||
! CHECK: }
|
||||
subroutine test20b(x, y, z)
|
||||
integer :: x(4), y(2, 2), z(4)
|
||||
|
||||
z = (/x/) + (/y/)
|
||||
end subroutine
|
||||
|
||||
! CHECK-LABEL: func @_QPtest20c(
|
||||
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<4xi32>> {{.*}}, %[[ARG1:.*]]: !fir.ref<!fir.array<2x2xi32>> {{.*}}
|
||||
|
||||
! (/x/)
|
||||
! CHECK: %[[X:.*]] = fir.array_load %[[ARG0]]({{.*}}) : (!fir.ref<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
|
||||
! CHECK: %[[ACX_MEM:.*]] = fir.allocmem !fir.array<4xi32>
|
||||
! CHECK: %[[ACX:.*]] = fir.array_load %[[ACX_MEM]]({{.*}}) : (!fir.heap<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
|
||||
! CHECK: {{.*}} = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[TEMP:.*]] = %[[ACX]]) -> (!fir.array<4xi32>) {
|
||||
! CHECK: %[[XI:.*]] = fir.array_fetch %[[X]], %[[I]] : (!fir.array<4xi32>, index) -> i32
|
||||
! CHECK: {{.*}} = fir.array_update %[[TEMP]], %[[XI]], %[[I]] : (!fir.array<4xi32>, i32, index) -> !fir.array<4xi32>
|
||||
! CHECK: }
|
||||
! CHECK: %[[T:.*]] = fir.coordinate_of %[[ACX_MEM2:.*]], %{{.*}} : (!fir.heap<!fir.array<4xi32>>, index) -> !fir.ref<i32>
|
||||
! CHECK: %[[T1:.*]] = fir.convert %[[T]] : (!fir.ref<i32>) -> !fir.ref<i8>
|
||||
! CHECK: %[[T2:.*]] = fir.convert %[[ACX_MEM]] : (!fir.heap<!fir.array<4xi32>>) -> !fir.ref<i8>
|
||||
! CHECK: fir.call @llvm.memcpy.p0.p0.i64(%[[T1]], %[[T2]], {{.*}})
|
||||
! CHECK: %[[ACX2:.*]] = fir.array_load %[[ACX_MEM2]]({{.*}}) : (!fir.heap<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
|
||||
|
||||
! (/y/)
|
||||
! CHECK: %[[Y:.*]] = fir.array_load %[[ARG1]]({{.*}}) : (!fir.ref<!fir.array<2x2xi32>>, !fir.shape<2>) -> !fir.array<2x2xi32>
|
||||
! CHECK: %[[ACY_MEM:.*]] = fir.allocmem !fir.array<2x2xi32>
|
||||
! CHECK: %[[ACY:.*]] = fir.array_load %[[ACY_MEM]]({{.*}}) : (!fir.heap<!fir.array<2x2xi32>>, !fir.shape<2>) -> !fir.array<2x2xi32>
|
||||
! CHECK: {{.*}} = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[TEMP:.*]] = %[[ACY]]) -> (!fir.array<2x2xi32>) {
|
||||
! CHECK: {{.*}} = fir.do_loop %[[J:.*]] = {{.*}} iter_args(%[[TEMP2:.*]] = %[[TEMP]]) -> (!fir.array<2x2xi32>) {
|
||||
! CHECK: %[[YJI:.*]] = fir.array_fetch %[[Y]], %[[J]], %[[I]] : (!fir.array<2x2xi32>, index, index) -> i32
|
||||
! CHECK: {{.*}} = fir.array_update %[[TEMP2]], %[[YJI]], %[[J]], %[[I]] : (!fir.array<2x2xi32>, i32, index, index) -> !fir.array<2x2xi32>
|
||||
! CHECK: }
|
||||
! CHECK: }
|
||||
! CHECK: %[[T:.*]] = fir.coordinate_of %[[ACY_MEM2:.*]], {{.*}} : (!fir.heap<!fir.array<4xi32>>, index) -> !fir.ref<i32>
|
||||
! CHECK: %[[T1:.*]] = fir.convert %[[T]] : (!fir.ref<i32>) -> !fir.ref<i8>
|
||||
! CHECK: %[[T2:.*]] = fir.convert %[[ACY_MEM]] : (!fir.heap<!fir.array<2x2xi32>>) -> !fir.ref<i8>
|
||||
! CHECK: fir.call @llvm.memcpy.p0.p0.i64(%[[T1]], %[[T2]], {{.*}})
|
||||
! CHECK: %[[ACY2:.*]] = fir.array_load %[[ACY_MEM2]]({{.*}}) : (!fir.heap<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
|
||||
|
||||
! (/x/) /= (/y/)
|
||||
! CHECK: %[[RES_MEM:.*]] = fir.allocmem !fir.array<4x!fir.logical<4>>
|
||||
! CHECK: %[[RES:.*]] = fir.array_load %[[RES_MEM]]({{.*}}) : (!fir.heap<!fir.array<4x!fir.logical<4>>>, !fir.shape<1>) -> !fir.array<4x!fir.logical<4>>
|
||||
! CHECK: %{{.*}} = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[TEMP:.*]] = %[[RES]]) -> (!fir.array<4x!fir.logical<4>>) {
|
||||
! CHECK: %[[XI:.*]] = fir.array_fetch %[[ACX2]], %[[I]] : (!fir.array<4xi32>, index) -> i32
|
||||
! CHECK: %[[YI:.*]] = fir.array_fetch %[[ACY2]], %[[I]] : (!fir.array<4xi32>, index) -> i32
|
||||
! CHECK: %[[T1:.*]] = arith.cmpi ne, %[[XI]], %[[YI]] : i32
|
||||
! CHECK: %[[T2:.*]] = fir.convert %[[T1]] : (i1) -> !fir.logical<4>
|
||||
! CHECK: {{.*}} = fir.array_update %[[TEMP]], %[[T2]], %[[I]] : (!fir.array<4x!fir.logical<4>>, !fir.logical<4>, index) -> !fir.array<4x!fir.logical<4>>
|
||||
! CHECK: }
|
||||
|
||||
! any((/x/) /= (/y/))
|
||||
! CHECK: %[[T1:.*]] = fir.embox %[[RES_MEM]]({{.*}}) : (!fir.heap<!fir.array<4x!fir.logical<4>>>, !fir.shape<1>) -> !fir.box<!fir.array<4x!fir.logical<4>>>
|
||||
! CHECK: %[[T2:.*]] = fir.convert %[[T1]] : (!fir.box<!fir.array<4x!fir.logical<4>>>) -> !fir.box<none>
|
||||
! CHECK: fir.call @_FortranAAny(%[[T2]], {{.*}}){{.*}} : (!fir.box<none>, !fir.ref<i8>, i32, i32) -> i1
|
||||
subroutine test20c(x, y)
|
||||
integer :: x(4), y(2, 2)
|
||||
|
||||
if (any((/x/) /= (/y/))) print *, "different"
|
||||
end subroutine
|
||||
|
||||
! CHECK: func private @_QPbar(
|
||||
|
Loading…
x
Reference in New Issue
Block a user