[mlir] Lower Shape binary ops (AddOp, MulOp) to Standard.

Differential Revision: https://reviews.llvm.org/D81344
This commit is contained in:
Alexander Belyaev 2020-06-08 17:48:01 +02:00
parent 936ec89e91
commit 80be54c08f
2 changed files with 31 additions and 0 deletions

View File

@ -15,10 +15,26 @@
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace mlir::shape;
namespace {
/// Conversion patterns.
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
using OpConversionPattern<SrcOpTy>::OpConversionPattern;
LogicalResult
matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
typename SrcOpTy::OperandAdaptor adaptor(operands);
rewriter.replaceOpWithNewOp<DstOpTy>(op.getOperation(), adaptor.lhs(),
adaptor.rhs());
return success();
}
};
class FromExtentTensorOpConversion
: public OpConversionPattern<shape::FromExtentTensorOp> {
public:
@ -128,6 +144,8 @@ void mlir::populateShapeToStandardConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
// clang-format off
patterns.insert<
BinaryOpConversion<AddOp, AddIOp>,
BinaryOpConversion<MulOp, MulIOp>,
FromExtentTensorOpConversion,
IndexToSizeOpConversion,
SizeToIndexOpConversion,

View File

@ -62,3 +62,16 @@ func @from_extent_tensor(%tensor : tensor<?xindex>) -> !shape.shape {
: (tensor<?xindex>) -> !shape.shape
return %shape : !shape.shape
}
// -----
// Lower binary ops.
// CHECK-LABEL: @binary_ops
// CHECK-SAME: (%[[LHS:.*]]: index, %[[RHS:.*]]: index)
func @binary_ops(%lhs : !shape.size, %rhs : !shape.size) {
%sum = "shape.add"(%lhs, %rhs) : (!shape.size, !shape.size) -> !shape.size
// CHECK-NEXT: addi %[[LHS]], %[[RHS]] : index
%product = shape.mul %lhs, %rhs
// CHECK-NEXT: muli %[[LHS]], %[[RHS]] : index
return
}