mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-12-11 17:08:42 +00:00
[mlir][tensor][bufferize] Support tensor.rank in BufferizableOpInterfaceImpl
This is the only op that is not supported via BufferizableOpInterfaceImpl bufferization. Once this op is supported we can switch `tensor-bufferize` over to the new unified bufferization. Differential Revision: https://reviews.llvm.org/D117985
This commit is contained in:
parent
d193f7be78
commit
fc08d1c294
@ -457,16 +457,22 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
|
||||
Value memref = frontBlock.addArgument(memrefType, bbArg.getLoc());
|
||||
OpBuilder b(funcOp->getContext());
|
||||
b.setInsertionPointToStart(&frontBlock);
|
||||
// Replace all uses of bbArg through a ToMemRefOp by a memref::CastOp.
|
||||
// Replace all uses of bbArg through a ToMemRefOp.
|
||||
for (auto &use : llvm::make_early_inc_range(bbArg.getUses())) {
|
||||
if (auto toMemrefOp =
|
||||
dyn_cast<bufferization::ToMemrefOp>(use.getOwner())) {
|
||||
assert(memref::CastOp::areCastCompatible(
|
||||
memref.getType(), toMemrefOp.memref().getType()) &&
|
||||
"bufferizeFuncOpBoundary: cast incompatible");
|
||||
auto castOp = b.create<memref::CastOp>(
|
||||
funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
|
||||
toMemrefOp.memref().replaceAllUsesWith(castOp);
|
||||
if (memref.getType() != toMemrefOp.memref().getType()) {
|
||||
// Type has changed, insert a cast.
|
||||
assert(memref::CastOp::areCastCompatible(
|
||||
memref.getType(), toMemrefOp.memref().getType()) &&
|
||||
"bufferizeFuncOpBoundary: cast incompatible");
|
||||
auto castOp = b.create<memref::CastOp>(
|
||||
funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
|
||||
toMemrefOp.memref().replaceAllUsesWith(castOp);
|
||||
} else {
|
||||
// Type did not change, replace directly.
|
||||
toMemrefOp.memref().replaceAllUsesWith(memref);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Replace all remaining uses by a to_tensor.
|
||||
|
@ -463,6 +463,35 @@ struct InsertSliceOpInterface
|
||||
}
|
||||
};
|
||||
|
||||
/// Bufferization of tensor.rank. Replace with memref.rank.
|
||||
struct RankOpInterface
|
||||
: public BufferizableOpInterface::ExternalModel<RankOpInterface,
|
||||
tensor::RankOp> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
const BufferizationState &state) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
const BufferizationState &state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
const BufferizationState &state) const {
|
||||
return OpResult();
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
const BufferizationState &state) const {
|
||||
auto rankOp = cast<tensor::RankOp>(op);
|
||||
Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
|
||||
replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
|
||||
v);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
} // namespace tensor
|
||||
} // namespace mlir
|
||||
@ -475,4 +504,5 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
|
||||
registry.addOpInterface<ExtractOp, ExtractOpInterface>();
|
||||
registry.addOpInterface<InsertOp, InsertOpInterface>();
|
||||
registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
|
||||
registry.addOpInterface<RankOp, RankOpInterface>();
|
||||
}
|
||||
|
@ -1348,3 +1348,14 @@ func @write_after_select_read_one(
|
||||
// CHECK: return %[[f]], %[[select]]
|
||||
return %f, %w : f32, tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @tensor_rank(
|
||||
// CHECK-SAME: %[[arg0:.*]]: memref<*xf32>
|
||||
func @tensor_rank(%arg0: tensor<*xf32>) -> index {
|
||||
// CHECK: %[[r:.*]] = memref.rank %[[arg0]]
|
||||
%0 = tensor.rank %arg0 : tensor<*xf32>
|
||||
// CHECK: return %[[r]] : index
|
||||
return %0 : index
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user