Fold full-size subview of static shapes.

Differential Revision: https://reviews.llvm.org/D97429
This commit is contained in:
Ahmed Taei 2021-02-24 17:24:14 -08:00
parent f21d78633a
commit da1e37a8b0
2 changed files with 17 additions and 2 deletions

View File

@ -3496,9 +3496,13 @@ void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
}
OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
if (getResult().getType().cast<ShapedType>().getRank() == 0 &&
source().getType().cast<ShapedType>().getRank() == 0)
auto resultShapedType = getResult().getType().cast<ShapedType>();
auto sourceShapedType = source().getType().cast<ShapedType>();
if (resultShapedType.hasStaticShape() &&
resultShapedType == sourceShapedType) {
return getViewSource();
}
return {};
}

View File

@ -204,6 +204,17 @@ func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
// -----
// CHECK-LABEL: func @subview_of_static_full_size
// CHECK-SAME: %[[ARG0:.+]]: memref<4x6x16x32xi8>
// CHECK-NOT: subview
// CHECK: return %[[ARG0]] : memref<4x6x16x32xi8>
func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4x6x16x32xi8> {
%0 = subview %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<4x6x16x32xi8>
return %0 : memref<4x6x16x32xi8>
}
// -----
// CHECK-LABEL: func @trivial_subtensor
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
// CHECK-NOT: subtensor