[MLIR][TOSA] Lower tosa.reshape to linalg.reshape

Lowering from the tosa.reshape op to linalg.tensor_reshape. For same-rank
reshapes, and for any reshape that is not a pure collapse or expansion of
consecutive dimensions, two linalg.tensor_reshape ops are inserted: a
collapse to a rank-1 tensor followed by an expansion to the result shape.
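For example, a same-rank reshape cannot be expressed as a single collapse
or expansion, so it goes through a rank-1 intermediate. A rough sketch of
the expected lowering (the exact IR is exercised by the tests below):

  %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]}
      : (tensor<3x2xf32>) -> tensor<2x3xf32>

becomes approximately:

  %collapsed = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
      : tensor<3x2xf32> into tensor<6xf32>
  %expanded = linalg.tensor_reshape %collapsed [affine_map<(d0, d1) -> (d0, d1)>]
      : tensor<6xf32> into tensor<2x3xf32>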

Differential Revision: https://reviews.llvm.org/D97439
Rob Suderman 2021-02-24 14:12:03 -08:00
parent 799c50fe93
commit caccddc52a
2 changed files with 151 additions and 2 deletions


@@ -16,8 +16,11 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <numeric>
using namespace mlir;
static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
@@ -339,6 +342,106 @@ public:
}
};
class ReshapeOpConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(tosa::ReshapeOp reshape, ArrayRef<Value> args,
ConversionPatternRewriter &rewriter) const final {
tosa::ReshapeOp::Adaptor operands(args);
ShapedType operandTy = operands.input1().getType().cast<ShapedType>();
ShapedType resultTy = reshape.getType().cast<ShapedType>();
if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
return failure();
// Compute the reassociation maps for the linalg operation.
ArrayRef<int64_t> expandedShape =
(operandTy.getRank() > resultTy.getRank() ? operandTy.getShape()
: resultTy.getShape());
ArrayRef<int64_t> collapsedShape =
(operandTy.getRank() > resultTy.getRank() ? resultTy.getShape()
: operandTy.getShape());
unsigned currSrcDim = 0, currDstDim = 0;
SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
collapsedShape.size());
// First scan all dimensions in the source shape to see whether we have a
// perfect case where consecutive dimensions in the source are collapsed.
// In that case we can generate a single linalg.tensor_reshape.
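// For example, collapsing tensor<1x2x3x5x7x11xf32> into tensor<6x5x77xf32>
// groups the source dims as {d0, d1, d2} -> 6, {d3} -> 5, {d4, d5} -> 77,
// i.e. the reassociation [[d0, d1, d2], [d3], [d4, d5]] (see the 6D test
// below).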
bool isCollapsingSource = true;
while (currSrcDim < expandedShape.size() &&
currDstDim < collapsedShape.size()) {
int64_t dstSize = collapsedShape[currDstDim];
int64_t srcSize = expandedShape[currSrcDim];
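// Accumulate consecutive source dims until their product reaches the
// current destination dim size. The guard uses currSrcDim + 1 because the
// loop body reads expandedShape[currSrcDim] after the increment.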
while (srcSize < dstSize && currSrcDim + 1 < expandedShape.size()) {
reassociationMap[currDstDim].push_back(
rewriter.getAffineDimExpr(currSrcDim++));
srcSize *= expandedShape[currSrcDim];
}
if (srcSize == dstSize) {
reassociationMap[currDstDim].push_back(
rewriter.getAffineDimExpr(currSrcDim++));
// If the next dim in collapsedShape is not 1, treat subsequent dims in
// expandedShape which are 1 as collapsed into the current dim.
if (currDstDim == collapsedShape.size() - 1 ||
collapsedShape[currDstDim + 1] != 1) {
while (currSrcDim < expandedShape.size() &&
expandedShape[currSrcDim] == 1) {
reassociationMap[currDstDim].push_back(
rewriter.getAffineDimExpr(currSrcDim++));
}
}
} else {
isCollapsingSource = false;
break;
}
currDstDim++;
}
if (currSrcDim != expandedShape.size() ||
currDstDim != collapsedShape.size())
isCollapsingSource = false;
// Otherwise, we need to first collapse all source dimensions into one and
// then expand to the destination dimensions.
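// For example, tensor<3x2xf32> -> tensor<2x3xf32> is lowered by collapsing
// to tensor<6xf32> and then expanding to tensor<2x3xf32>.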
if (!isCollapsingSource) {
auto getIdentityExprs = [&rewriter](int n) {
SmallVector<AffineExpr, 4> exprs;
for (int i = 0; i < n; ++i)
exprs.push_back(rewriter.getAffineDimExpr(i));
return exprs;
};
Location loc = reshape.getLoc();
int64_t totalElems =
std::accumulate(expandedShape.begin(), expandedShape.end(), int64_t{1},
std::multiplies<int64_t>());
auto elemTy = operandTy.getElementType();
SmallVector<linalg::ReassociationExprs, 4> collapsingMap = {
// Use operandTy here because we need to collapse all operand
// dimensions.
getIdentityExprs(operandTy.getShape().size())};
SmallVector<linalg::ReassociationExprs, 4> expandingMap = {
// Use resultTy here because we need to expand to all result
// dimensions.
getIdentityExprs(resultTy.getShape().size())};
auto collapsedTy = RankedTensorType::get({totalElems}, elemTy);
Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>(
loc, collapsedTy, args[0], collapsingMap);
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
reshape, resultTy, collapsedOp, expandingMap);
return success();
}
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
reshape, resultTy, args[0], reassociationMap);
return success();
}
};
} // namespace
void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
@@ -358,6 +461,6 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
PointwiseConverter<tosa::GreaterEqualOp>,
PointwiseConverter<tosa::MaximumOp>, PointwiseConverter<tosa::MinimumOp>,
PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
- PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>>(
- context);
+ PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>,
+ ReshapeOpConverter>(context);
}
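The checks below come from the lowering test file. A hypothetical RUN line
for it, assuming the patterns are exercised by a tosa-to-linalg-on-tensors
pass (the flag name is an assumption here), would look like:

  // RUN: mlir-opt --split-input-file --tosa-to-linalg-on-tensors %s | FileCheck %s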


@@ -258,3 +258,49 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
return
}
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: @test_reshape_downrank
func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
// CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]]]
%0 = "tosa.reshape"(%arg0) {new_shape = [6]} : (tensor<2x3xf32>) -> tensor<6xf32>
// CHECK: return [[RESHAPE]]
return %0 : tensor<6xf32>
}
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: @test_reshape_uprank
func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
// CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]]]
%0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<6xf32>) -> tensor<2x3xf32>
// CHECK: return [[RESHAPE]]
return %0 : tensor<2x3xf32>
}
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: @test_reshape_samerank
func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
// CHECK: [[RESHAPE1:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]]]
// CHECK: [[RESHAPE2:%.+]] = linalg.tensor_reshape [[RESHAPE1]] [#[[$MAP0]]]
%0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<3x2xf32>) -> tensor<2x3xf32>
// CHECK: return [[RESHAPE2]]
return %0 : tensor<2x3xf32>
}
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// CHECK-LABEL: @test_reshape_downrank_6D
func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
// CHECK: linalg.tensor_reshape %arg0 [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
%0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
return %0 : tensor<6x5x77xf32>
}