mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-02-05 00:49:43 +00:00
[flang] Correct folding of SPREAD() for higher ranks
The construction of the dimension order vector used to populate the result array was incorrect, leading to a scrambled-looking result for rank-3 and higher results. Fix, and extend tests. Differential Revision: https://reviews.llvm.org/D125113
This commit is contained in:
parent
9641b9be9d
commit
85fdbc1569
@ -890,9 +890,9 @@ template <typename T> Expr<T> Folder<T>::SPREAD(FunctionRef<T> &&funcRef) {
|
||||
Constant<T> spread{source->Reshape(std::move(shape))};
|
||||
std::vector<int> dimOrder;
|
||||
for (int j{0}; j < sourceRank; ++j) {
|
||||
dimOrder.push_back(j);
|
||||
dimOrder.push_back(j < *dim - 1 ? j : j + 1);
|
||||
}
|
||||
dimOrder.insert(dimOrder.begin() + *dim - 1, sourceRank);
|
||||
dimOrder.push_back(*dim - 1);
|
||||
ConstantSubscripts at{spread.lbounds()}; // all 1
|
||||
spread.CopyFrom(*source, TotalElementCount(spread.shape()), at, &dimOrder);
|
||||
return Expr<T>{std::move(spread)};
|
||||
|
@ -5,9 +5,11 @@ module m1
|
||||
logical, parameter :: test_stov = all(spread(1, 1, 2) == [1, 1])
|
||||
logical, parameter :: test_vtom1 = all(spread([1, 2], 1, 3) == reshape([1, 1, 1, 2, 2, 2], [3, 2]))
|
||||
logical, parameter :: test_vtom2 = all(spread([1, 2], 2, 3) == reshape([1, 2, 1, 2, 1, 2], [2, 3]))
|
||||
logical, parameter :: test_vtom3 = all(spread([1, 2], 2, 3) == reshape([1, 2, 1, 2, 1, 2], [2, 3]))
|
||||
logical, parameter :: test_vtom3 = all(spread([1, 2], 1, 3) == reshape([1, 1, 1, 2, 2, 2], [3, 2]))
|
||||
logical, parameter :: test_log1 = all(all(spread([.false., .true.], 1, 2), dim=2) .eqv. [.false., .false.])
|
||||
logical, parameter :: test_log2 = all(all(spread([.false., .true.], 2, 2), dim=2) .eqv. [.false., .true.])
|
||||
logical, parameter :: test_log3 = all(any(spread([.false., .true.], 1, 2), dim=2) .eqv. [.true., .true.])
|
||||
logical, parameter :: test_log4 = all(any(spread([.false., .true.], 2, 2), dim=2) .eqv. [.false., .true.])
|
||||
logical, parameter :: test_m2toa3 = all(spread(reshape([(j,j=1,6)],[2,3]),1,4) == &
|
||||
reshape([((j,k=1,4),j=1,6)],[4,2,3]))
|
||||
end module
|
||||
|
Loading…
x
Reference in New Issue
Block a user