[MLIR][Shape] Lower shape.shape_of to standard dialect

Lower `shape.shape_of` to the standard dialect. This lowering supports
statically and dynamically shaped tensors. Support for unranked tensors
will be added as part of the lowering to `scf`.

Differential Revision: https://reviews.llvm.org/D82098
parent a3adfb400e
commit ac3e5c4d93
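As a concrete sketch of this lowering (read off the new FileCheck tests at the end of this diff; SSA names are illustrative, not what the pass emits):

    // Input:
    func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
      %shape = shape.shape_of %arg : tensor<1x2x3xf32>
      return
    }

    // After the conversion, every extent is a constant and the shape is
    // materialized as an extent tensor:
    func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
      %c1 = constant 1 : index
      %c2 = constant 2 : index
      %c3 = constant 3 : index
      %shape = tensor_from_elements(%c1, %c2, %c3) : tensor<3xindex>
      return
    }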
@@ -38,6 +38,45 @@ public:
   }
 };
 
+class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
+public:
+  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    ShapeOfOp::Adaptor transformed(operands);
+    auto loc = op.getLoc();
+    auto tensorVal = transformed.arg();
+    auto tensorTy = tensorVal.getType();
+
+    // For unranked tensors `shape_of` lowers to `scf` and the pattern can be
+    // found in the corresponding pass.
+    if (tensorTy.isa<UnrankedTensorType>())
+      return failure();
+
+    // Build values for individual dimensions.
+    SmallVector<Value, 8> dimValues;
+    auto rankedTensorTy = tensorTy.cast<RankedTensorType>();
+    int64_t rank = rankedTensorTy.getRank();
+    for (int64_t i = 0; i < rank; i++) {
+      if (rankedTensorTy.isDynamicDim(i)) {
+        auto dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
+        dimValues.push_back(dimVal);
+      } else {
+        int64_t dim = rankedTensorTy.getDimSize(i);
+        auto dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
+        dimValues.push_back(dimVal);
+      }
+    }
+
+    // Materialize shape as ranked tensor.
+    rewriter.replaceOpWithNewOp<TensorFromElementsOp>(op.getOperation(),
+                                                      dimValues);
+    return success();
+  }
+};
+
 class ConstSizeOpConverter : public OpConversionPattern<ConstSizeOp> {
 public:
   using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
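Note the early `return failure()` above: for unranked operands the pattern deliberately does not match, so a hypothetical input like the following (not part of this commit's tests) passes through unchanged and is handled later by the `scf` lowering:

    func @shape_of_unranked(%arg : tensor<*xf32>) {
      // Still `shape.shape_of` after this conversion.
      %shape = shape.shape_of %arg : tensor<*xf32>
      return
    }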
@@ -107,7 +146,8 @@ void mlir::populateShapeToStandardConversionPatterns(
   patterns.insert<
       BinaryOpConversion<AddOp, AddIOp>,
       BinaryOpConversion<MulOp, MulIOp>,
-      ConstSizeOpConverter>(ctx);
+      ConstSizeOpConverter,
+      ShapeOfOpConversion>(ctx);
   // clang-format on
 }
 
@@ -86,3 +86,32 @@ func @size_const() -> !shape.size {
 }
 // CHECK: %[[C1:.*]] = constant 1 : index
 // CHECK: return %[[C1]] : index
+
+// -----
+
+// Lower `shape_of` for statically shaped tensor.
+// CHECK-LABEL: @shape_of_stat
+// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
+func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
+  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+  // CHECK-DAG: %[[C2:.*]] = constant 2 : index
+  // CHECK-DAG: %[[C3:.*]] = constant 3 : index
+  // CHECK-DAG: %[[SHAPE:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex>
+  %shape = shape.shape_of %arg : tensor<1x2x3xf32>
+  return
+}
+
+// -----
+
+// Lower `shape_of` for dynamically shaped tensor.
+// CHECK-LABEL: @shape_of_dyn
+// CHECK-SAME: (%[[ARG:.*]]: tensor<1x5x?xf32>)
+func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
+  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+  // CHECK-DAG: %[[C5:.*]] = constant 5 : index
+  // CHECK-DAG: %[[C2:.*]] = constant 2 : index
+  // CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32>
+  // CHECK-DAG: %[[SHAPE:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex>
+  %shape = shape.shape_of %arg : tensor<1x5x?xf32>
+  return
+}
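Reading the CHECK-DAG lines of the dynamic test as plain IR, the expected output is approximately (SSA names illustrative):

    func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
      %c1 = constant 1 : index
      %c5 = constant 5 : index
      %c2 = constant 2 : index
      // Only the dynamic extent requires a runtime `dim`.
      %dyn = dim %arg, %c2 : tensor<1x5x?xf32>
      %shape = tensor_from_elements(%c1, %c5, %dyn) : tensor<3xindex>
      return
    }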