[mlir][memref] Fix expanded shape ops memref.cast folding with changed type

`memref.expand_shape` has verification logic to make sure
result dim must be static if all the collapsing src dims are static.

This can be relaxed once expand_shape supports more dynamism.

Differential Revision: https://reviews.llvm.org/D114391
This commit is contained in:
Benjamin Kramer 2021-11-22 22:11:45 +01:00
parent 2dec2aa3ad
commit 966b720983
2 changed files with 12 additions and 2 deletions

View File

@ -1640,8 +1640,6 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
CollapseShapeOpMemRefCastFolder>(context);
}
OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
if (succeeded(foldMemRefCast(*this)))
return getResult();
return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
}
OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {

View File

@ -600,6 +600,18 @@ func @fold_memref_reshape_dynamic(%arg0 : memref<?x?xf32>) -> memref<?x?xf32> {
// -----
func @fold_memref_expand_cast(%arg0 : memref<?x?xf32>) -> memref<2x4x4xf32> {
%0 = memref.cast %arg0 : memref<?x?xf32> to memref<8x4xf32>
%1 = memref.expand_shape %0 [[0, 1], [2]]
: memref<8x4xf32> into memref<2x4x4xf32>
return %1 : memref<2x4x4xf32>
}
// CHECK-LABEL: @fold_memref_expand_cast
// CHECK: memref.expand_shape
// -----
// CHECK-LABEL: func @collapse_after_memref_cast_type_change(
// CHECK-SAME: %[[INPUT:.*]]: memref<?x512x1x1xf32>) -> memref<?x?xf32> {
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]