mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-02-22 19:24:13 +00:00
[mlir] Add a pattern to bufferize std.index_cast.
Differential Revision: https://reviews.llvm.org/D102088
This commit is contained in:
parent
a3f22d020b
commit
3444996b4c
@ -34,9 +34,22 @@ public:
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class BufferizeIndexCastOp : public OpConversionPattern<IndexCastOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(IndexCastOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
IndexCastOp::Adaptor adaptor(operands);
|
||||
auto tensorType = op.getType().cast<RankedTensorType>();
|
||||
rewriter.replaceOpWithNewOp<IndexCastOp>(
|
||||
op, adaptor.in(),
|
||||
MemRefType::get(tensorType.getShape(), tensorType.getElementType()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
@ -56,8 +69,8 @@ public:
|
||||
|
||||
void mlir::populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<BufferizeDimOp, BufferizeSelectOp>(typeConverter,
|
||||
patterns.getContext());
|
||||
patterns.add<BufferizeDimOp, BufferizeSelectOp, BufferizeIndexCastOp>(
|
||||
typeConverter, patterns.getContext());
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -68,14 +81,15 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
|
||||
RewritePatternSet patterns(context);
|
||||
ConversionTarget target(*context);
|
||||
|
||||
target.addLegalDialect<memref::MemRefDialect>();
|
||||
target.addLegalDialect<StandardOpsDialect>();
|
||||
target.addLegalDialect<scf::SCFDialect>();
|
||||
target.addLegalDialect<scf::SCFDialect, StandardOpsDialect,
|
||||
memref::MemRefDialect>();
|
||||
|
||||
populateStdBufferizePatterns(typeConverter, patterns);
|
||||
// We only bufferize the case of tensor selected type and scalar condition,
|
||||
// as that boils down to a select over memref descriptors (don't need to
|
||||
// touch the data).
|
||||
target.addDynamicallyLegalOp<IndexCastOp>(
|
||||
[&](IndexCastOp op) { return typeConverter.isLegal(op.getType()); });
|
||||
target.addDynamicallyLegalOp<SelectOp>([&](SelectOp op) {
|
||||
return typeConverter.isLegal(op.getType()) ||
|
||||
!op.condition().getType().isa<IntegerType>();
|
||||
|
@ -24,3 +24,16 @@ func @select(%arg0: i1, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
|
||||
%0 = select %arg0, %arg1, %arg2 : tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @index_cast(
|
||||
// CHECK-SAME: %[[TENSOR:.*]]: tensor<i32>, %[[SCALAR:.*]]: i32
|
||||
func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, index) {
|
||||
%index_tensor = index_cast %tensor : tensor<i32> to tensor<index>
|
||||
%index_scalar = index_cast %scalar : i32 to index
|
||||
return %index_tensor, %index_scalar : tensor<index>, index
|
||||
}
|
||||
// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]] : memref<i32>
|
||||
// CHECK-NEXT: %[[INDEX_MEMREF:.*]] = index_cast %[[MEMREF]]
|
||||
// CHECK-SAME: memref<i32> to memref<index>
|
||||
// CHECK-NEXT: %[[INDEX_TENSOR:.*]] = memref.tensor_load %[[INDEX_MEMREF]]
|
||||
// CHECK: return %[[INDEX_TENSOR]]
|
||||
|
Loading…
x
Reference in New Issue
Block a user