mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-01-15 04:29:42 +00:00
[mlir][shape] Lower Shape ConstSizeOp
to Standard ConstantOp
.
Differential Revision: https://reviews.llvm.org/D81735
This commit is contained in:
parent
2596da3174
commit
cd320446f4
@ -36,61 +36,72 @@ public:
|
||||
};
|
||||
|
||||
class FromExtentTensorOpConversion
|
||||
: public OpConversionPattern<shape::FromExtentTensorOp> {
|
||||
: public OpConversionPattern<FromExtentTensorOp> {
|
||||
public:
|
||||
using OpConversionPattern<shape::FromExtentTensorOp>::OpConversionPattern;
|
||||
using OpConversionPattern<FromExtentTensorOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(shape::FromExtentTensorOp op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(FromExtentTensorOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
shape::FromExtentTensorOpOperandAdaptor transformed(operands);
|
||||
FromExtentTensorOpOperandAdaptor transformed(operands);
|
||||
rewriter.replaceOp(op.getOperation(), transformed.input());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class IndexToSizeOpConversion
|
||||
: public OpConversionPattern<shape::IndexToSizeOp> {
|
||||
class IndexToSizeOpConversion : public OpConversionPattern<IndexToSizeOp> {
|
||||
public:
|
||||
using OpConversionPattern<shape::IndexToSizeOp>::OpConversionPattern;
|
||||
using OpConversionPattern<IndexToSizeOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(shape::IndexToSizeOp op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(IndexToSizeOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
shape::IndexToSizeOpOperandAdaptor transformed(operands);
|
||||
IndexToSizeOpOperandAdaptor transformed(operands);
|
||||
rewriter.replaceOp(op.getOperation(), transformed.arg());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class SizeToIndexOpConversion
|
||||
: public OpConversionPattern<shape::SizeToIndexOp> {
|
||||
class SizeToIndexOpConversion : public OpConversionPattern<SizeToIndexOp> {
|
||||
public:
|
||||
using OpConversionPattern<shape::SizeToIndexOp>::OpConversionPattern;
|
||||
using OpConversionPattern<SizeToIndexOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(shape::SizeToIndexOp op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(SizeToIndexOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
shape::SizeToIndexOpOperandAdaptor transformed(operands);
|
||||
SizeToIndexOpOperandAdaptor transformed(operands);
|
||||
rewriter.replaceOp(op.getOperation(), transformed.arg());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class ToExtentTensorOpConversion
|
||||
: public OpConversionPattern<shape::ToExtentTensorOp> {
|
||||
: public OpConversionPattern<ToExtentTensorOp> {
|
||||
public:
|
||||
using OpConversionPattern<shape::ToExtentTensorOp>::OpConversionPattern;
|
||||
using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(shape::ToExtentTensorOp op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
shape::ToExtentTensorOpOperandAdaptor transformed(operands);
|
||||
ToExtentTensorOpOperandAdaptor transformed(operands);
|
||||
rewriter.replaceOp(op.getOperation(), transformed.input());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class ConstSizeOpConverter : public OpConversionPattern<ConstSizeOp> {
|
||||
public:
|
||||
using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(),
|
||||
op.value().getSExtValue());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Type conversions.
|
||||
class ShapeTypeConverter : public TypeConverter {
|
||||
public:
|
||||
@ -100,8 +111,8 @@ public:
|
||||
// Add default pass-through conversion.
|
||||
addConversion([&](Type type) { return type; });
|
||||
|
||||
addConversion([ctx](shape::SizeType type) { return IndexType::get(ctx); });
|
||||
addConversion([ctx](shape::ShapeType type) {
|
||||
addConversion([ctx](SizeType type) { return IndexType::get(ctx); });
|
||||
addConversion([ctx](ShapeType type) {
|
||||
return RankedTensorType::get({ShapedType::kDynamicSize},
|
||||
IndexType::get(ctx));
|
||||
});
|
||||
@ -111,9 +122,7 @@ public:
|
||||
/// Conversion pass.
|
||||
class ConvertShapeToStandardPass
|
||||
: public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
|
||||
|
||||
void runOnOperation() override {
|
||||
|
||||
// Setup type conversion.
|
||||
MLIRContext &ctx = getContext();
|
||||
ShapeTypeConverter typeConverter(&ctx);
|
||||
@ -146,6 +155,7 @@ void mlir::populateShapeToStandardConversionPatterns(
|
||||
patterns.insert<
|
||||
BinaryOpConversion<AddOp, AddIOp>,
|
||||
BinaryOpConversion<MulOp, MulIOp>,
|
||||
ConstSizeOpConverter,
|
||||
FromExtentTensorOpConversion,
|
||||
IndexToSizeOpConversion,
|
||||
SizeToIndexOpConversion,
|
||||
|
@ -75,3 +75,14 @@ func @binary_ops(%lhs : !shape.size, %rhs : !shape.size) {
|
||||
// CHECK-NEXT: muli %[[LHS]], %[[RHS]] : index
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Convert `const_size` to `constant` op.
|
||||
// CHECK-LABEL: @size_const
|
||||
func @size_const() -> !shape.size {
|
||||
%c1 = shape.const_size 1
|
||||
return %c1 : !shape.size
|
||||
}
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: return %[[C1]] : index
|
||||
|
Loading…
x
Reference in New Issue
Block a user