[mlir][linalg] Fix bufferize_to_allocation error checking

`bufferize_to_allocation` does not support ops with regions, unless `bufferize_destination_only` is set. In that case, only the operand is replaced with an allocation and wrapped in a `to_tensor` op. The error checking was too strict.

Differential Revision: https://reviews.llvm.org/D159420
This commit is contained in:
Matthias Springer 2023-09-04 10:40:23 +02:00
parent cb54522853
commit b76a180d3b
2 changed files with 39 additions and 13 deletions

View File

@@ -461,18 +461,20 @@ Value linalg::bufferizeToAllocation(
AnalysisState state(bufferizationOptions);
#ifndef NDEBUG
// Ops with nested tensor ops are not supported yet. At the moment, this
// function just bufferizes the given op itself, but not its body.
op->walk([&](Operation *nestedOp) {
if (op == nestedOp)
return;
if (llvm::any_of(nestedOp->getOperands(),
[](Value v) { return v.getType().isa<TensorType>(); }))
llvm_unreachable("ops with nested tensor ops are not supported yet");
if (llvm::any_of(nestedOp->getResults(),
[](Value v) { return v.getType().isa<TensorType>(); }))
llvm_unreachable("ops with nested tensor ops are not supported yet");
});
if (!options.bufferizeDestinationOnly) {
// Ops with nested tensor ops are not supported yet. At the moment, this
// function just bufferizes the given op itself, but not its body.
op->walk([&](Operation *nestedOp) {
if (op == nestedOp)
return;
if (llvm::any_of(nestedOp->getOperands(),
[](Value v) { return v.getType().isa<TensorType>(); }))
llvm_unreachable("ops with nested tensor ops are not supported yet");
if (llvm::any_of(nestedOp->getResults(),
[](Value v) { return v.getType().isa<TensorType>(); }))
llvm_unreachable("ops with nested tensor ops are not supported yet");
});
}
#endif // NDEBUG
// Gather tensor results.
@@ -509,7 +511,7 @@ Value linalg::bufferizeToAllocation(
if (!state.bufferizesToMemoryWrite(operand))
continue;
if (!isa<RankedTensorType>(operand.get().getType()))
return nullptr;
continue;
addOutOfPlaceOperand(&operand);
}
// TODO: Support multiple buffers.

View File

@@ -218,3 +218,27 @@ transform.sequence failures(propagate) {
%0 = transform.structured.match ops{["tensor.insert"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%2, %new = transform.structured.bufferize_to_allocation %0 {memory_space = 4, bufferize_destination_only} : !transform.any_op
}
// -----
// CHECK-LABEL: func @scf_for_destination(
// CHECK-SAME: %[[t:.*]]: tensor<?x10xindex>
// CHECK: %[[alloc:.*]] = memref.alloc(%{{.*}}) : memref<?x10xindex, 4>
// CHECK: memref.tensor_store %[[t]], %[[alloc]]
// CHECK: %[[t2:.*]] = bufferization.to_tensor %[[alloc]] restrict writable
// CHECK: %[[for:.*]] = scf.for {{.*}} iter_args(%{{.*}} = %[[t2]])
// CHECK: memref.dealloc %[[alloc]]
// CHECK: return %[[for]]
// Regression test for `bufferize_to_allocation` with `bufferize_destination_only`:
// the payload op (scf.for) has a region with nested tensor ops, which is only
// supported in destination-only mode. The transform below replaces the loop's
// destination operand (%t) with an allocation; the loop body itself is left
// unbufferized.
func.func @scf_for_destination(%t: tensor<?x10xindex>, %lb: index, %ub: index, %step: index) -> tensor<?x10xindex> {
%r = scf.for %iv = %lb to %ub step %step iter_args(%a = %t) -> tensor<?x10xindex> {
// Opaque op standing in for an arbitrary tensor computation on the iter_arg.
%b = "test.foo"(%a) : (tensor<?x10xindex>) -> (tensor<?x10xindex>)
scf.yield %b : tensor<?x10xindex>
}
return %r : tensor<?x10xindex>
}
// Transform script: match the scf.for op above and bufferize only its
// destination operand into a new allocation in memory space 4. With
// `bufferize_destination_only` set, ops with regions are accepted (see the
// relaxed error check in linalg::bufferizeToAllocation).
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// Returns the new allocation (%2) and the replacement to_tensor value (%new).
%2, %new = transform.structured.bufferize_to_allocation %0 {memory_space = 4, bufferize_destination_only} : !transform.any_op
}