mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-12-01 18:12:44 +00:00
[MLIR][TOSA] Lower tosa.reshape to linalg.reshape
Lowering from the tosa.reshape op to linalg.reshape. For same-rank or non-collapsed/expanded cases two linalg.reshapes are inserted. Differential Revision: https://reviews.llvm.org/D97439
This commit is contained in:
parent
799c50fe93
commit
caccddc52a
@ -16,8 +16,11 @@
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
#include <numeric>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
|
||||
@ -339,6 +342,106 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class ReshapeOpConverter : public OpConversionPattern<tosa::ReshapeOp> {
|
||||
public:
|
||||
using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(tosa::ReshapeOp reshape, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
typename tosa::ReshapeOp::Adaptor operands(args);
|
||||
|
||||
ShapedType operandTy = operands.input1().getType().cast<ShapedType>();
|
||||
ShapedType resultTy = reshape.getType().template cast<ShapedType>();
|
||||
|
||||
if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
// Compute the reassociation maps for the linalg operation.
|
||||
ArrayRef<int64_t> expandedShape =
|
||||
(operandTy.getRank() > resultTy.getRank() ? operandTy.getShape()
|
||||
: resultTy.getShape());
|
||||
ArrayRef<int64_t> collapsedShape =
|
||||
(operandTy.getRank() > resultTy.getRank() ? resultTy.getShape()
|
||||
: operandTy.getShape());
|
||||
unsigned currSrcDim = 0, currDstDim = 0;
|
||||
SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
|
||||
collapsedShape.size());
|
||||
|
||||
// First scan all dimensions in the source shapes to see whether we have a
|
||||
// perfect case where consecutive dimensions in source are collapsed. For
|
||||
// such case we can just generate one single linalg.reshape.
|
||||
bool isCollapsingSource = true;
|
||||
while (currSrcDim < expandedShape.size() &&
|
||||
currDstDim < collapsedShape.size()) {
|
||||
int64_t dstSize = collapsedShape[currDstDim];
|
||||
int64_t srcSize = expandedShape[currSrcDim];
|
||||
while (srcSize < dstSize && currSrcDim < expandedShape.size()) {
|
||||
reassociationMap[currDstDim].push_back(
|
||||
rewriter.getAffineDimExpr(currSrcDim++));
|
||||
srcSize *= expandedShape[currSrcDim];
|
||||
}
|
||||
if (srcSize == dstSize) {
|
||||
reassociationMap[currDstDim].push_back(
|
||||
rewriter.getAffineDimExpr(currSrcDim++));
|
||||
// If the next dim in collapsedShape is not 1, treat subsequent dims in
|
||||
// expandedShape which are 1 to be collapsed.
|
||||
if (currDstDim == collapsedShape.size() - 1 ||
|
||||
collapsedShape[currDstDim + 1] != 1) {
|
||||
while (currSrcDim < expandedShape.size() &&
|
||||
expandedShape[currSrcDim] == 1) {
|
||||
reassociationMap[currDstDim].push_back(
|
||||
rewriter.getAffineDimExpr(currSrcDim++));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
isCollapsingSource = false;
|
||||
break;
|
||||
}
|
||||
currDstDim++;
|
||||
}
|
||||
if (currSrcDim != expandedShape.size() ||
|
||||
currDstDim != collapsedShape.size())
|
||||
isCollapsingSource = false;
|
||||
|
||||
// Otherwise, we need to first reduce all source dimensions into one and
|
||||
// then expand to the destination dimensions.
|
||||
if (!isCollapsingSource) {
|
||||
auto getIdentityExprs = [&rewriter](int n) {
|
||||
SmallVector<AffineExpr, 4> exprs;
|
||||
for (int i = 0; i < n; ++i)
|
||||
exprs.push_back(rewriter.getAffineDimExpr(i));
|
||||
return exprs;
|
||||
};
|
||||
Location loc = reshape.getLoc();
|
||||
int64_t totalElems =
|
||||
std::accumulate(expandedShape.begin(), expandedShape.end(), 1,
|
||||
std::multiplies<int64_t>());
|
||||
auto elemTy = operandTy.getElementType();
|
||||
SmallVector<linalg::ReassociationExprs, 4> collapsingMap = {
|
||||
// Use operandTy here because we need to collapse all operands
|
||||
// dimensions.
|
||||
getIdentityExprs(operandTy.getShape().size())};
|
||||
SmallVector<linalg::ReassociationExprs, 4> expandingMap = {
|
||||
// Use resultTy here because we need to expand to all result
|
||||
// dimensions.
|
||||
getIdentityExprs(resultTy.getShape().size())};
|
||||
|
||||
auto collapsedTy = RankedTensorType::get({totalElems}, elemTy);
|
||||
Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>(
|
||||
loc, collapsedTy, args[0], collapsingMap);
|
||||
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
|
||||
reshape, resultTy, collapsedOp, expandingMap);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
|
||||
reshape, resultTy, args[0], reassociationMap);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
|
||||
@ -358,6 +461,6 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
|
||||
PointwiseConverter<tosa::GreaterEqualOp>,
|
||||
PointwiseConverter<tosa::MaximumOp>, PointwiseConverter<tosa::MinimumOp>,
|
||||
PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
|
||||
PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>>(
|
||||
context);
|
||||
PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>,
|
||||
ReshapeOpConverter>(context);
|
||||
}
|
||||
|
@ -258,3 +258,49 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK-LABEL: @test_reshape_downrank
|
||||
func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
|
||||
// CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]]]
|
||||
%0 = "tosa.reshape"(%arg0) {new_shape = [6]} : (tensor<2x3xf32>) -> tensor<6xf32>
|
||||
// CHECK: return [[RESHAPE]]
|
||||
return %0 : tensor<6xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK-LABEL: @test_reshape_uprank
|
||||
func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
|
||||
// CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]]]
|
||||
%0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<6xf32>) -> tensor<2x3xf32>
|
||||
// CHECK: return [[RESHAPE]]
|
||||
return %0 : tensor<2x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK-LABEL: @test_reshape_samerank
|
||||
func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
|
||||
// CHECK: [[RESHAPE1:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]]]
|
||||
// CHECK: [[RESHAPE2:%.+]] = linalg.tensor_reshape [[RESHAPE1]] [#[[$MAP0]]]
|
||||
%0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<3x2xf32>) -> tensor<2x3xf32>
|
||||
// CHECK: return [[RESHAPE2]]
|
||||
return %0 : tensor<2x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
|
||||
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>
|
||||
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
|
||||
|
||||
// CHECK-LABEL: @test_reshape_downrank_6D
|
||||
func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
|
||||
// CHECK: linalg.tensor_reshape %arg0 [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
|
||||
%0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
|
||||
return %0 : tensor<6x5x77xf32>
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user