mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-02-12 04:43:48 +00:00
[flang][hlfir] fix missing conversion in transpose simplification
It seems just replacing the operation was not replacing all of the uses when the types of the expression before and after this pass differ (due to differing shape information). Now the shape information is always kept the same. This fixes https://github.com/llvm/llvm-project/issues/63399 Differential Revision: https://reviews.llvm.org/D153333
This commit is contained in:
parent
9d796d05a1
commit
74adc3e0eb
@ -363,11 +363,13 @@ using ElementalKernelGenerator = std::function<hlfir::Entity(
|
||||
mlir::Location, fir::FirOpBuilder &, mlir::ValueRange)>;
|
||||
/// Generate an hlfir.elementalOp given call back to generate the element
|
||||
/// value at for each iteration.
|
||||
/// If exprType is specified, this will be the return type of the elemental op
|
||||
hlfir::ElementalOp genElementalOp(mlir::Location loc,
|
||||
fir::FirOpBuilder &builder,
|
||||
mlir::Type elementType, mlir::Value shape,
|
||||
mlir::ValueRange typeParams,
|
||||
const ElementalKernelGenerator &genKernel);
|
||||
const ElementalKernelGenerator &genKernel,
|
||||
mlir::Type exprType = mlir::Type{});
|
||||
|
||||
/// Structure to describe a loop nest.
|
||||
struct LoopNest {
|
||||
|
@ -722,12 +722,12 @@ static hlfir::ExprType getArrayExprType(mlir::Type elementType,
|
||||
isPolymorphic);
|
||||
}
|
||||
|
||||
hlfir::ElementalOp
|
||||
hlfir::genElementalOp(mlir::Location loc, fir::FirOpBuilder &builder,
|
||||
mlir::Type elementType, mlir::Value shape,
|
||||
mlir::ValueRange typeParams,
|
||||
const ElementalKernelGenerator &genKernel) {
|
||||
mlir::Type exprType = getArrayExprType(elementType, shape, false);
|
||||
hlfir::ElementalOp hlfir::genElementalOp(
|
||||
mlir::Location loc, fir::FirOpBuilder &builder, mlir::Type elementType,
|
||||
mlir::Value shape, mlir::ValueRange typeParams,
|
||||
const ElementalKernelGenerator &genKernel, mlir::Type exprType) {
|
||||
if (!exprType)
|
||||
exprType = getArrayExprType(elementType, shape, false);
|
||||
auto elementalOp =
|
||||
builder.create<hlfir::ElementalOp>(loc, exprType, shape, typeParams);
|
||||
auto insertPt = builder.saveInsertionPoint();
|
||||
|
@ -59,9 +59,16 @@ public:
|
||||
return val;
|
||||
};
|
||||
hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
|
||||
loc, builder, elementType, resultShape, typeParams, genKernel);
|
||||
loc, builder, elementType, resultShape, typeParams, genKernel,
|
||||
transpose.getResult().getType());
|
||||
|
||||
rewriter.replaceOp(transpose, elementalOp.getResult());
|
||||
// it wouldn't be safe to replace block arguments with a different
|
||||
// hlfir.expr type. Types can differ due to differing amounts of shape
|
||||
// information
|
||||
assert(elementalOp.getResult().getType() ==
|
||||
transpose.getResult().getType());
|
||||
|
||||
rewriter.replaceOp(transpose, elementalOp);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
|
@ -93,3 +93,94 @@ func.func @transpose3(%arg0: !hlfir.expr<?x2xi32>) {
|
||||
// CHECK: }
|
||||
// CHECK: return
|
||||
// CHECK: }
|
||||
|
||||
// expr with multiple uses
|
||||
func.func @transpose4(%arg0: !hlfir.expr<2x2xf32>, %arg1: !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) {
|
||||
%0 = hlfir.transpose %arg0 : (!hlfir.expr<2x2xf32>) -> !hlfir.expr<2x2xf32>
|
||||
%1 = hlfir.shape_of %0 : (!hlfir.expr<2x2xf32>) -> !fir.shape<2>
|
||||
%2 = hlfir.elemental %1 : (!fir.shape<2>) -> !hlfir.expr<2x2xf32> {
|
||||
^bb0(%arg2: index, %arg3: index):
|
||||
%3 = hlfir.apply %0, %arg2, %arg3 : (!hlfir.expr<2x2xf32>, index, index) -> f32
|
||||
%4 = math.cos %3 fastmath<contract> : f32
|
||||
hlfir.yield_element %4 : f32
|
||||
}
|
||||
hlfir.assign %2 to %arg1 realloc : !hlfir.expr<2x2xf32>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
|
||||
hlfir.destroy %2 : !hlfir.expr<2x2xf32>
|
||||
hlfir.destroy %0 : !hlfir.expr<2x2xf32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func.func @transpose4(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !hlfir.expr<2x2xf32>
|
||||
// CHECK-SAME: %[[ARG1:.*]]:
|
||||
// CHECK: %[[SHAPE0:.*]] = fir.shape
|
||||
// CHECK: %[[TRANSPOSE:.*]] = hlfir.elemental %[[SHAPE0]] : (!fir.shape<2>) -> !hlfir.expr<2x2xf32> {
|
||||
// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
|
||||
// CHECK: %[[ELE:.*]] = hlfir.apply %[[ARG0]], %[[J]], %[[I]] : (!hlfir.expr<2x2xf32>, index, index) -> f32
|
||||
// CHECK: hlfir.yield_element %[[ELE]] : f32
|
||||
// CHECK: }
|
||||
// CHECK: %[[SHAPE1:.*]] = hlfir.shape_of %[[TRANSPOSE]] : (!hlfir.expr<2x2xf32>) -> !fir.shape<2>
|
||||
// CHECK: %[[COS:.*]] = hlfir.elemental %[[SHAPE1]] : (!fir.shape<2>) -> !hlfir.expr<2x2xf32> {
|
||||
// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
|
||||
// CHECK: %[[ELE:.*]] = hlfir.apply %[[TRANSPOSE]], %[[I]], %[[J]] : (!hlfir.expr<2x2xf32>, index, index) -> f32
|
||||
// CHECK: %[[COS_ELE:.*]] = math.cos %[[ELE]] fastmath<contract> : f32
|
||||
// CHECK: hlfir.yield_element %[[COS_ELE]] : f32
|
||||
// CHECK: }
|
||||
// CHECK: hlfir.assign %[[COS]] to %[[ARG1]] realloc
|
||||
// CHECK: hlfir.destroy %[[COS]] : !hlfir.expr<2x2xf32>
|
||||
// CHECK: hlfir.destroy %[[TRANSPOSE]] : !hlfir.expr<2x2xf32>
|
||||
// CHECK: return
|
||||
// CHECK: }
|
||||
|
||||
// regression test
|
||||
func.func @transpose5(%arg0: !fir.ref<tuple<!fir.box<!fir.array<2x2xf64>>, !fir.box<!fir.array<2x2xf64>>>> {fir.host_assoc}) attributes {fir.internal_proc} {
|
||||
%0 = fir.address_of(@_QFEb) : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>
|
||||
%1:2 = hlfir.declare %0 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFEb"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>)
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
%2 = fir.coordinate_of %arg0, %c0_i32 : (!fir.ref<tuple<!fir.box<!fir.array<2x2xf64>>, !fir.box<!fir.array<2x2xf64>>>>, i32) -> !fir.ref<!fir.box<!fir.array<2x2xf64>>>
|
||||
%3 = fir.load %2 : !fir.ref<!fir.box<!fir.array<2x2xf64>>>
|
||||
%4 = fir.box_addr %3 : (!fir.box<!fir.array<2x2xf64>>) -> !fir.ref<!fir.array<2x2xf64>>
|
||||
%c0 = arith.constant 0 : index
|
||||
%5:3 = fir.box_dims %3, %c0 : (!fir.box<!fir.array<2x2xf64>>, index) -> (index, index, index)
|
||||
%c1 = arith.constant 1 : index
|
||||
%6:3 = fir.box_dims %3, %c1 : (!fir.box<!fir.array<2x2xf64>>, index) -> (index, index, index)
|
||||
%7 = fir.shape %5#1, %6#1 : (index, index) -> !fir.shape<2>
|
||||
%8:2 = hlfir.declare %4(%7) {uniq_name = "_QFEa"} : (!fir.ref<!fir.array<2x2xf64>>, !fir.shape<2>) -> (!fir.ref<!fir.array<2x2xf64>>, !fir.ref<!fir.array<2x2xf64>>)
|
||||
%c1_i32 = arith.constant 1 : i32
|
||||
%9 = fir.coordinate_of %arg0, %c1_i32 : (!fir.ref<tuple<!fir.box<!fir.array<2x2xf64>>, !fir.box<!fir.array<2x2xf64>>>>, i32) -> !fir.ref<!fir.box<!fir.array<2x2xf64>>>
|
||||
%10 = fir.load %9 : !fir.ref<!fir.box<!fir.array<2x2xf64>>>
|
||||
%11 = fir.box_addr %10 : (!fir.box<!fir.array<2x2xf64>>) -> !fir.ref<!fir.array<2x2xf64>>
|
||||
%c0_0 = arith.constant 0 : index
|
||||
%12:3 = fir.box_dims %10, %c0_0 : (!fir.box<!fir.array<2x2xf64>>, index) -> (index, index, index)
|
||||
%c1_1 = arith.constant 1 : index
|
||||
%13:3 = fir.box_dims %10, %c1_1 : (!fir.box<!fir.array<2x2xf64>>, index) -> (index, index, index)
|
||||
%14 = fir.shape %12#1, %13#1 : (index, index) -> !fir.shape<2>
|
||||
%15:2 = hlfir.declare %11(%14) {uniq_name = "_QFEc"} : (!fir.ref<!fir.array<2x2xf64>>, !fir.shape<2>) -> (!fir.ref<!fir.array<2x2xf64>>, !fir.ref<!fir.array<2x2xf64>>)
|
||||
%16 = hlfir.transpose %8#0 : (!fir.ref<!fir.array<2x2xf64>>) -> !hlfir.expr<2x2xf64>
|
||||
%17 = hlfir.shape_of %16 : (!hlfir.expr<2x2xf64>) -> !fir.shape<2>
|
||||
%18 = hlfir.elemental %17 : (!fir.shape<2>) -> !hlfir.expr<?x?xf64> {
|
||||
^bb0(%arg1: index, %arg2: index):
|
||||
%19 = hlfir.apply %16, %arg1, %arg2 : (!hlfir.expr<2x2xf64>, index, index) -> f64
|
||||
%20 = math.cos %19 fastmath<contract> : f64
|
||||
hlfir.yield_element %20 : f64
|
||||
}
|
||||
hlfir.assign %18 to %1#0 realloc : !hlfir.expr<?x?xf64>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>
|
||||
hlfir.destroy %18 : !hlfir.expr<?x?xf64>
|
||||
hlfir.destroy %16 : !hlfir.expr<2x2xf64>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func.func @transpose5(
|
||||
// ...
|
||||
// CHECK: %[[TRANSPOSE:.*]] = hlfir.elemental %[[SHAPE0:.*]]
|
||||
// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
|
||||
// CHECK: %[[ELE:.*]] = hlfir.designate %[[ARRAY:.*]] (%[[J]], %[[I]])
|
||||
// CHECK: %[[LOAD:.*]] = fir.load %[[ELE]]
|
||||
// CHECK: hlfir.yield_element %[[LOAD]]
|
||||
// CHECK: }
|
||||
// CHECK: %[[SHAPE1:.*]] = hlfir.shape_of %[[TRANSPOSE]]
|
||||
// CHECK: %[[COS:.*]] = hlfir.elemental %[[SHAPE1]]
|
||||
// ...
|
||||
// CHECK: hlfir.assign %[[COS]] to %{{.*}} realloc
|
||||
// CHECK: hlfir.destroy %[[COS]]
|
||||
// CHECK: hlfir.destroy %[[TRANSPOSE]]
|
||||
// CHECK: return
|
||||
// CHECK: }
|
||||
|
Loading…
x
Reference in New Issue
Block a user