[flang][hlfir] fix missing conversion in transpose simplification

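Simply replacing the operation did not appear to replace all of its uses
when the type of the expression differed before and after this pass (due
to differing shape information). The replacement hlfir.elemental is now
created with exactly the same result type as the hlfir.transpose it
replaces, so the shape information is always kept the same.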

This fixes https://github.com/llvm/llvm-project/issues/63399

Differential Revision: https://reviews.llvm.org/D153333
This commit is contained in:
Tom Eccles 2023-06-20 12:40:26 +00:00
parent 9d796d05a1
commit 74adc3e0eb
4 changed files with 109 additions and 9 deletions

View File

@ -363,11 +363,13 @@ using ElementalKernelGenerator = std::function<hlfir::Entity(
mlir::Location, fir::FirOpBuilder &, mlir::ValueRange)>;
/// Generate an hlfir::ElementalOp given a callback that generates the
/// element value for each iteration.
/// If exprType is specified, this will be the return type of the elemental op
hlfir::ElementalOp genElementalOp(mlir::Location loc,
fir::FirOpBuilder &builder,
mlir::Type elementType, mlir::Value shape,
mlir::ValueRange typeParams,
const ElementalKernelGenerator &genKernel);
const ElementalKernelGenerator &genKernel,
mlir::Type exprType = mlir::Type{});
/// Structure to describe a loop nest.
struct LoopNest {

View File

@ -722,12 +722,12 @@ static hlfir::ExprType getArrayExprType(mlir::Type elementType,
isPolymorphic);
}
hlfir::ElementalOp
hlfir::genElementalOp(mlir::Location loc, fir::FirOpBuilder &builder,
mlir::Type elementType, mlir::Value shape,
mlir::ValueRange typeParams,
const ElementalKernelGenerator &genKernel) {
mlir::Type exprType = getArrayExprType(elementType, shape, false);
hlfir::ElementalOp hlfir::genElementalOp(
mlir::Location loc, fir::FirOpBuilder &builder, mlir::Type elementType,
mlir::Value shape, mlir::ValueRange typeParams,
const ElementalKernelGenerator &genKernel, mlir::Type exprType) {
if (!exprType)
exprType = getArrayExprType(elementType, shape, false);
auto elementalOp =
builder.create<hlfir::ElementalOp>(loc, exprType, shape, typeParams);
auto insertPt = builder.saveInsertionPoint();

View File

@ -59,9 +59,16 @@ public:
return val;
};
hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
loc, builder, elementType, resultShape, typeParams, genKernel);
loc, builder, elementType, resultShape, typeParams, genKernel,
transpose.getResult().getType());
rewriter.replaceOp(transpose, elementalOp.getResult());
// It wouldn't be safe to replace block arguments with a different
// hlfir.expr type. Types can differ due to differing amounts of shape
// information.
assert(elementalOp.getResult().getType() ==
transpose.getResult().getType());
rewriter.replaceOp(transpose, elementalOp);
return mlir::success();
}

View File

@ -93,3 +93,94 @@ func.func @transpose3(%arg0: !hlfir.expr<?x2xi32>) {
// CHECK: }
// CHECK: return
// CHECK: }
// expr with multiple uses: the hlfir.transpose result %0 is consumed by
// hlfir.shape_of, by hlfir.apply inside the following hlfir.elemental, and by
// hlfir.destroy. Every !hlfir.expr type in this function is the fully static
// 2x2xf32, so the transpose and its users agree on the expression type.
func.func @transpose4(%arg0: !hlfir.expr<2x2xf32>, %arg1: !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) {
%0 = hlfir.transpose %arg0 : (!hlfir.expr<2x2xf32>) -> !hlfir.expr<2x2xf32>
// First use of %0: the shape of the transposed expression drives the elemental.
%1 = hlfir.shape_of %0 : (!hlfir.expr<2x2xf32>) -> !fir.shape<2>
%2 = hlfir.elemental %1 : (!fir.shape<2>) -> !hlfir.expr<2x2xf32> {
^bb0(%arg2: index, %arg3: index):
// Second use of %0: read the transposed element at (%arg2, %arg3) and take
// its cosine.
%3 = hlfir.apply %0, %arg2, %arg3 : (!hlfir.expr<2x2xf32>, index, index) -> f32
%4 = math.cos %3 fastmath<contract> : f32
hlfir.yield_element %4 : f32
}
// Assign the elemental result to the allocatable box with reallocation.
hlfir.assign %2 to %arg1 realloc : !hlfir.expr<2x2xf32>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
hlfir.destroy %2 : !hlfir.expr<2x2xf32>
// Third use of %0.
hlfir.destroy %0 : !hlfir.expr<2x2xf32>
return
}
// CHECK-LABEL: func.func @transpose4(
// CHECK-SAME: %[[ARG0:.*]]: !hlfir.expr<2x2xf32>
// CHECK-SAME: %[[ARG1:.*]]:
// CHECK: %[[SHAPE0:.*]] = fir.shape
// CHECK: %[[TRANSPOSE:.*]] = hlfir.elemental %[[SHAPE0]] : (!fir.shape<2>) -> !hlfir.expr<2x2xf32> {
// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
// CHECK: %[[ELE:.*]] = hlfir.apply %[[ARG0]], %[[J]], %[[I]] : (!hlfir.expr<2x2xf32>, index, index) -> f32
// CHECK: hlfir.yield_element %[[ELE]] : f32
// CHECK: }
// CHECK: %[[SHAPE1:.*]] = hlfir.shape_of %[[TRANSPOSE]] : (!hlfir.expr<2x2xf32>) -> !fir.shape<2>
// CHECK: %[[COS:.*]] = hlfir.elemental %[[SHAPE1]] : (!fir.shape<2>) -> !hlfir.expr<2x2xf32> {
// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
// CHECK: %[[ELE:.*]] = hlfir.apply %[[TRANSPOSE]], %[[I]], %[[J]] : (!hlfir.expr<2x2xf32>, index, index) -> f32
// CHECK: %[[COS_ELE:.*]] = math.cos %[[ELE]] fastmath<contract> : f32
// CHECK: hlfir.yield_element %[[COS_ELE]] : f32
// CHECK: }
// CHECK: hlfir.assign %[[COS]] to %[[ARG1]] realloc
// CHECK: hlfir.destroy %[[COS]] : !hlfir.expr<2x2xf32>
// CHECK: hlfir.destroy %[[TRANSPOSE]] : !hlfir.expr<2x2xf32>
// CHECK: return
// CHECK: }
// regression test for https://github.com/llvm/llvm-project/issues/63399
// The hlfir.transpose operand is a variable (!fir.ref<!fir.array<2x2xf64>>)
// rather than an hlfir.expr, and its !hlfir.expr<2x2xf64> result has several
// uses (hlfir.shape_of, hlfir.apply, hlfir.destroy). The hlfir.elemental that
// consumes it has the dynamically shaped type !hlfir.expr<?x?xf64>.
func.func @transpose5(%arg0: !fir.ref<tuple<!fir.box<!fir.array<2x2xf64>>, !fir.box<!fir.array<2x2xf64>>>> {fir.host_assoc}) attributes {fir.internal_proc} {
// Allocatable global @_QFEb: the destination of the final assignment.
%0 = fir.address_of(@_QFEb) : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>
%1:2 = hlfir.declare %0 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFEb"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>)
// Load the box stored in tuple element 0 of the fir.host_assoc argument, build
// a shape from its extents, and declare it as _QFEa.
%c0_i32 = arith.constant 0 : i32
%2 = fir.coordinate_of %arg0, %c0_i32 : (!fir.ref<tuple<!fir.box<!fir.array<2x2xf64>>, !fir.box<!fir.array<2x2xf64>>>>, i32) -> !fir.ref<!fir.box<!fir.array<2x2xf64>>>
%3 = fir.load %2 : !fir.ref<!fir.box<!fir.array<2x2xf64>>>
%4 = fir.box_addr %3 : (!fir.box<!fir.array<2x2xf64>>) -> !fir.ref<!fir.array<2x2xf64>>
%c0 = arith.constant 0 : index
%5:3 = fir.box_dims %3, %c0 : (!fir.box<!fir.array<2x2xf64>>, index) -> (index, index, index)
%c1 = arith.constant 1 : index
%6:3 = fir.box_dims %3, %c1 : (!fir.box<!fir.array<2x2xf64>>, index) -> (index, index, index)
%7 = fir.shape %5#1, %6#1 : (index, index) -> !fir.shape<2>
%8:2 = hlfir.declare %4(%7) {uniq_name = "_QFEa"} : (!fir.ref<!fir.array<2x2xf64>>, !fir.shape<2>) -> (!fir.ref<!fir.array<2x2xf64>>, !fir.ref<!fir.array<2x2xf64>>)
// Same sequence for tuple element 1, declared as _QFEc. %15 is not used by the
// operations below.
%c1_i32 = arith.constant 1 : i32
%9 = fir.coordinate_of %arg0, %c1_i32 : (!fir.ref<tuple<!fir.box<!fir.array<2x2xf64>>, !fir.box<!fir.array<2x2xf64>>>>, i32) -> !fir.ref<!fir.box<!fir.array<2x2xf64>>>
%10 = fir.load %9 : !fir.ref<!fir.box<!fir.array<2x2xf64>>>
%11 = fir.box_addr %10 : (!fir.box<!fir.array<2x2xf64>>) -> !fir.ref<!fir.array<2x2xf64>>
%c0_0 = arith.constant 0 : index
%12:3 = fir.box_dims %10, %c0_0 : (!fir.box<!fir.array<2x2xf64>>, index) -> (index, index, index)
%c1_1 = arith.constant 1 : index
%13:3 = fir.box_dims %10, %c1_1 : (!fir.box<!fir.array<2x2xf64>>, index) -> (index, index, index)
%14 = fir.shape %12#1, %13#1 : (index, index) -> !fir.shape<2>
%15:2 = hlfir.declare %11(%14) {uniq_name = "_QFEc"} : (!fir.ref<!fir.array<2x2xf64>>, !fir.shape<2>) -> (!fir.ref<!fir.array<2x2xf64>>, !fir.ref<!fir.array<2x2xf64>>)
// Transpose of _QFEa; %16 is used by shape_of, apply and destroy below.
%16 = hlfir.transpose %8#0 : (!fir.ref<!fir.array<2x2xf64>>) -> !hlfir.expr<2x2xf64>
%17 = hlfir.shape_of %16 : (!hlfir.expr<2x2xf64>) -> !fir.shape<2>
// Note the result type here is ?x?xf64 while %16 is 2x2xf64.
%18 = hlfir.elemental %17 : (!fir.shape<2>) -> !hlfir.expr<?x?xf64> {
^bb0(%arg1: index, %arg2: index):
// Cosine of the transposed element at (%arg1, %arg2).
%19 = hlfir.apply %16, %arg1, %arg2 : (!hlfir.expr<2x2xf64>, index, index) -> f64
%20 = math.cos %19 fastmath<contract> : f64
hlfir.yield_element %20 : f64
}
// Assign the elemental result to _QFEb with reallocation, then release both
// expression values.
hlfir.assign %18 to %1#0 realloc : !hlfir.expr<?x?xf64>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>
hlfir.destroy %18 : !hlfir.expr<?x?xf64>
hlfir.destroy %16 : !hlfir.expr<2x2xf64>
return
}
// CHECK-LABEL: func.func @transpose5(
// ...
// CHECK: %[[TRANSPOSE:.*]] = hlfir.elemental %[[SHAPE0:.*]]
// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
// CHECK: %[[ELE:.*]] = hlfir.designate %[[ARRAY:.*]] (%[[J]], %[[I]])
// CHECK: %[[LOAD:.*]] = fir.load %[[ELE]]
// CHECK: hlfir.yield_element %[[LOAD]]
// CHECK: }
// CHECK: %[[SHAPE1:.*]] = hlfir.shape_of %[[TRANSPOSE]]
// CHECK: %[[COS:.*]] = hlfir.elemental %[[SHAPE1]]
// ...
// CHECK: hlfir.assign %[[COS]] to %{{.*}} realloc
// CHECK: hlfir.destroy %[[COS]]
// CHECK: hlfir.destroy %[[TRANSPOSE]]
// CHECK: return
// CHECK: }