[mlir][shape] Lower Shape ConstSizeOp to Standard ConstantOp.

Differential Revision: https://reviews.llvm.org/D81735
This commit is contained in:
Alexander Belyaev 2020-06-12 15:50:03 +02:00
parent 2596da3174
commit cd320446f4
2 changed files with 43 additions and 22 deletions

View File

@ -36,61 +36,72 @@ public:
};
class FromExtentTensorOpConversion
: public OpConversionPattern<shape::FromExtentTensorOp> {
: public OpConversionPattern<FromExtentTensorOp> {
public:
using OpConversionPattern<shape::FromExtentTensorOp>::OpConversionPattern;
using OpConversionPattern<FromExtentTensorOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(shape::FromExtentTensorOp op, ArrayRef<Value> operands,
matchAndRewrite(FromExtentTensorOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
shape::FromExtentTensorOpOperandAdaptor transformed(operands);
FromExtentTensorOpOperandAdaptor transformed(operands);
rewriter.replaceOp(op.getOperation(), transformed.input());
return success();
}
};
class IndexToSizeOpConversion
: public OpConversionPattern<shape::IndexToSizeOp> {
class IndexToSizeOpConversion : public OpConversionPattern<IndexToSizeOp> {
public:
using OpConversionPattern<shape::IndexToSizeOp>::OpConversionPattern;
using OpConversionPattern<IndexToSizeOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(shape::IndexToSizeOp op, ArrayRef<Value> operands,
matchAndRewrite(IndexToSizeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
shape::IndexToSizeOpOperandAdaptor transformed(operands);
IndexToSizeOpOperandAdaptor transformed(operands);
rewriter.replaceOp(op.getOperation(), transformed.arg());
return success();
}
};
class SizeToIndexOpConversion
: public OpConversionPattern<shape::SizeToIndexOp> {
class SizeToIndexOpConversion : public OpConversionPattern<SizeToIndexOp> {
public:
using OpConversionPattern<shape::SizeToIndexOp>::OpConversionPattern;
using OpConversionPattern<SizeToIndexOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(shape::SizeToIndexOp op, ArrayRef<Value> operands,
matchAndRewrite(SizeToIndexOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
shape::SizeToIndexOpOperandAdaptor transformed(operands);
SizeToIndexOpOperandAdaptor transformed(operands);
rewriter.replaceOp(op.getOperation(), transformed.arg());
return success();
}
};
class ToExtentTensorOpConversion
: public OpConversionPattern<shape::ToExtentTensorOp> {
: public OpConversionPattern<ToExtentTensorOp> {
public:
using OpConversionPattern<shape::ToExtentTensorOp>::OpConversionPattern;
using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(shape::ToExtentTensorOp op, ArrayRef<Value> operands,
matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
shape::ToExtentTensorOpOperandAdaptor transformed(operands);
ToExtentTensorOpOperandAdaptor transformed(operands);
rewriter.replaceOp(op.getOperation(), transformed.input());
return success();
}
};
class ConstSizeOpConverter : public OpConversionPattern<ConstSizeOp> {
public:
using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(),
op.value().getSExtValue());
return success();
}
};
/// Type conversions.
class ShapeTypeConverter : public TypeConverter {
public:
@ -100,8 +111,8 @@ public:
// Add default pass-through conversion.
addConversion([&](Type type) { return type; });
addConversion([ctx](shape::SizeType type) { return IndexType::get(ctx); });
addConversion([ctx](shape::ShapeType type) {
addConversion([ctx](SizeType type) { return IndexType::get(ctx); });
addConversion([ctx](ShapeType type) {
return RankedTensorType::get({ShapedType::kDynamicSize},
IndexType::get(ctx));
});
@ -111,9 +122,7 @@ public:
/// Conversion pass.
class ConvertShapeToStandardPass
: public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
void runOnOperation() override {
// Setup type conversion.
MLIRContext &ctx = getContext();
ShapeTypeConverter typeConverter(&ctx);
@ -146,6 +155,7 @@ void mlir::populateShapeToStandardConversionPatterns(
patterns.insert<
BinaryOpConversion<AddOp, AddIOp>,
BinaryOpConversion<MulOp, MulIOp>,
ConstSizeOpConverter,
FromExtentTensorOpConversion,
IndexToSizeOpConversion,
SizeToIndexOpConversion,

View File

@ -75,3 +75,14 @@ func @binary_ops(%lhs : !shape.size, %rhs : !shape.size) {
// CHECK-NEXT: muli %[[LHS]], %[[RHS]] : index
return
}
// -----
// Convert `const_size` to `constant` op.
// CHECK-LABEL: @size_const
func @size_const() -> !shape.size {
%c1 = shape.const_size 1
return %c1 : !shape.size
}
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: return %[[C1]] : index