[mlir][tensor][bufferize] Support tensor.rank in BufferizableOpInterfaceImpl

This is the only op that is not yet supported via `BufferizableOpInterfaceImpl` bufferization. Once this op is supported, we can switch `tensor-bufferize` over to the new unified bufferization.

Differential Revision: https://reviews.llvm.org/D117985
This commit is contained in:
Matthias Springer 2022-01-25 00:09:36 +09:00
parent d193f7be78
commit fc08d1c294
3 changed files with 54 additions and 7 deletions

View File

@@ -457,16 +457,22 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
Value memref = frontBlock.addArgument(memrefType, bbArg.getLoc());
OpBuilder b(funcOp->getContext());
b.setInsertionPointToStart(&frontBlock);
// Replace all uses of bbArg through a ToMemRefOp by a memref::CastOp.
// Replace all uses of bbArg through a ToMemRefOp.
for (auto &use : llvm::make_early_inc_range(bbArg.getUses())) {
if (auto toMemrefOp =
dyn_cast<bufferization::ToMemrefOp>(use.getOwner())) {
assert(memref::CastOp::areCastCompatible(
memref.getType(), toMemrefOp.memref().getType()) &&
"bufferizeFuncOpBoundary: cast incompatible");
auto castOp = b.create<memref::CastOp>(
funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
toMemrefOp.memref().replaceAllUsesWith(castOp);
if (memref.getType() != toMemrefOp.memref().getType()) {
// Type has changed, insert a cast.
assert(memref::CastOp::areCastCompatible(
memref.getType(), toMemrefOp.memref().getType()) &&
"bufferizeFuncOpBoundary: cast incompatible");
auto castOp = b.create<memref::CastOp>(
funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
toMemrefOp.memref().replaceAllUsesWith(castOp);
} else {
// Type did not change, replace directly.
toMemrefOp.memref().replaceAllUsesWith(memref);
}
}
}
// Replace all remaining uses by a to_tensor.

View File

@@ -463,6 +463,35 @@ struct InsertSliceOpInterface
}
};
/// Bufferization of tensor.rank. Replace with memref.rank.
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  /// tensor.rank only inspects its operand's metadata, but conservatively
  /// report a read so the analysis keeps the source buffer alive.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  /// The op never writes into its operand's buffer.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  /// The index result does not alias any buffer.
  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return OpResult();
  }

  /// Rewrite tensor.rank to memref.rank on the bufferized source.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    // getBuffer returns FailureOr<Value>; propagate failure instead of
    // dereferencing an empty result.
    FailureOr<Value> v =
        state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 *v);
    return success();
  }
};
} // namespace
} // namespace tensor
} // namespace mlir
@@ -475,4 +504,5 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
registry.addOpInterface<ExtractOp, ExtractOpInterface>();
registry.addOpInterface<InsertOp, InsertOpInterface>();
registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
registry.addOpInterface<RankOp, RankOpInterface>();
}

View File

@@ -1348,3 +1348,14 @@ func @write_after_select_read_one(
// CHECK: return %[[f]], %[[select]]
return %f, %w : f32, tensor<?xf32>
}
// -----
// Ensure tensor.rank on an unranked tensor bufferizes to memref.rank on the
// equivalent unranked-memref function argument.
// CHECK-LABEL: func @tensor_rank(
// CHECK-SAME: %[[arg0:.*]]: memref<*xf32>
func @tensor_rank(%arg0: tensor<*xf32>) -> index {
// CHECK: %[[r:.*]] = memref.rank %[[arg0]]
%0 = tensor.rank %arg0 : tensor<*xf32>
// CHECK: return %[[r]] : index
return %0 : index
}