[mlir] Add a pattern to bufferize std.index_cast.

Differential Revision: https://reviews.llvm.org/D102088
This commit is contained in:
Alexander Belyaev 2021-05-07 21:20:55 +02:00
parent a3f22d020b
commit 3444996b4c
2 changed files with 34 additions and 7 deletions

View File

@ -34,9 +34,22 @@ public:
return success();
}
};
} // namespace
namespace {
class BufferizeIndexCastOp : public OpConversionPattern<IndexCastOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(IndexCastOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
IndexCastOp::Adaptor adaptor(operands);
auto tensorType = op.getType().cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<IndexCastOp>(
op, adaptor.in(),
MemRefType::get(tensorType.getShape(), tensorType.getElementType()));
return success();
}
};
class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
public:
using OpConversionPattern::OpConversionPattern;
@ -56,8 +69,8 @@ public:
void mlir::populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<BufferizeDimOp, BufferizeSelectOp>(typeConverter,
patterns.getContext());
patterns.add<BufferizeDimOp, BufferizeSelectOp, BufferizeIndexCastOp>(
typeConverter, patterns.getContext());
}
namespace {
@ -68,14 +81,15 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
RewritePatternSet patterns(context);
ConversionTarget target(*context);
target.addLegalDialect<memref::MemRefDialect>();
target.addLegalDialect<StandardOpsDialect>();
target.addLegalDialect<scf::SCFDialect>();
target.addLegalDialect<scf::SCFDialect, StandardOpsDialect,
memref::MemRefDialect>();
populateStdBufferizePatterns(typeConverter, patterns);
// We only bufferize the case of tensor selected type and scalar condition,
// as that boils down to a select over memref descriptors (don't need to
// touch the data).
target.addDynamicallyLegalOp<IndexCastOp>(
[&](IndexCastOp op) { return typeConverter.isLegal(op.getType()); });
target.addDynamicallyLegalOp<SelectOp>([&](SelectOp op) {
return typeConverter.isLegal(op.getType()) ||
!op.condition().getType().isa<IntegerType>();

View File

@ -24,3 +24,16 @@ func @select(%arg0: i1, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
%0 = select %arg0, %arg1, %arg2 : tensor<f32>
return %0 : tensor<f32>
}
// CHECK-LABEL: func @index_cast(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<i32>, %[[SCALAR:.*]]: i32
func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, index) {
%index_tensor = index_cast %tensor : tensor<i32> to tensor<index>
%index_scalar = index_cast %scalar : i32 to index
return %index_tensor, %index_scalar : tensor<index>, index
}
// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]] : memref<i32>
// CHECK-NEXT: %[[INDEX_MEMREF:.*]] = index_cast %[[MEMREF]]
// CHECK-SAME: memref<i32> to memref<index>
// CHECK-NEXT: %[[INDEX_TENSOR:.*]] = memref.tensor_load %[[INDEX_MEMREF]]
// CHECK: return %[[INDEX_TENSOR]]