[mlir][bufferization] Improve analysis for element-wise operations
Before this change, two equivalent operands that bufferize to a memory read and a memory write, respectively, were always considered conflicting. This change improves the analysis for ops that bufferize to element-wise access: such ops can bufferize in-place, because an original element value is no longer needed once the updated element value for the same position has been computed and written. This allows ops such as the following to bufferize in-place:

```
%0 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
       ins(%a, %b : tensor<5xf32>, tensor<5xf32>)
       outs(%a : tensor<5xf32>) -> tensor<5xf32>
```

Differential Revision: https://reviews.llvm.org/D156887
commit 5468340553
parent 3feb63e112
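For illustration only (not part of this commit): once the analysis decides that the op above can bufferize in-place, one-shot bufferization can reuse the buffer of %a for the result instead of allocating a copy. A rough sketch of the resulting buffer form, assuming function-boundary bufferization and omitting the strided layout attributes that would normally appear on the function arguments:

```
// Hedged sketch, not taken from the patch: approximate in-place result.
// The "outs" buffer %a is updated directly; no temporary copy of %a is
// needed, because each element of %a is read before the same position is
// overwritten.
func.func @sketch(%a: memref<5xf32>, %b: memref<5xf32>) {
  linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
      ins(%a, %b : memref<5xf32>, memref<5xf32>)
      outs(%a : memref<5xf32>)
  return
}
```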
@@ -91,6 +91,56 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           llvm_unreachable("bufferizesToMemoryWrite not implemented");
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return `true` if the operation bufferizes to IR that performs only
+          element-wise accesses on all tensor operands. (All operands must have
+          the same shape.) The `bufferize` method must be implemented in such a
+          way that it is free of loop-carried dependences. I.e., all loads at a
+          position appear before all stores at the same position.
+
+          Example: Consider a hypothetical element-wise op, where the "ins"
+          bufferize to a memory read and the "outs" bufferize to a memory write.
+          ```
+          test.element_wise ins(%0), outs(%1) : tensor<3xf32>
+          ```
+
+          The following is a valid access pattern:
+          ```
+          load(%0[1])
+          store(%1[1])
+          load(%0[2])
+          store(%1[2])
+          load(%0[0])
+          store(%1[0])
+          ```
+
+          The following would be an invalid (not element-wise) access pattern:
+          ```
+          load(%0[1])
+          store(%0[1])
+          load(%0[1])
+          ...
+          ```
+
+          Element-wise ops can sometimes bufferize more efficiently: a RaW
+          conflict between two operands of the same op can be avoided if it is
+          guaranteed that an original element value is no longer needed after
+          writing a computed element value at the same location. E.g., such an
+          optimization is possible in the above example if %0 and %1 are
+          equivalent tensors. (It is not possible if %0 and %1 are merely
+          aliasing. It is not necessary if %0 and %1 are not aliasing at all,
+          because there would be no conflict anyway.)
+        }],
+        /*retType=*/"bool",
+        /*methodName=*/"bufferizesToElementwiseAccess",
+        /*args=*/(ins "const ::mlir::bufferization::AnalysisState &":$state),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          // It is always safe to assume that the op is not element-wise.
+          return false;
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Return `true` if the given OpResult bufferizes to a memory write.
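To make the equivalence requirement in the description above concrete, here is an illustrative snippet that is not part of the patch (it reuses the hypothetical `test.element_wise` op from the description): two overlapping slices of the same tensor alias each other but are not equivalent, so the element-wise exception must not be applied and the usual RaW conflict handling still kicks in.

```
// Illustrative only: %s0 and %s1 alias (both view %t) but are not equivalent.
// An in-place store to %s1[i] targets %t[i], which a load of %s0[i - 1] (a
// different position) may still need; the interface only orders loads before
// stores at the same position, so this remains a potential RaW conflict.
%s0 = tensor.extract_slice %t[1] [3] [1] : tensor<5xf32> to tensor<3xf32>
%s1 = tensor.extract_slice %t[0] [3] [1] : tensor<5xf32> to tensor<3xf32>
%r = test.element_wise ins(%s0), outs(%s1) : tensor<3xf32>
```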
@@ -542,6 +542,22 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
       }
     }
 
+    // Two equivalent operands of the same op are not conflicting if the op
+    // bufferizes to element-wise access. I.e., all loads at a position happen
+    // before all stores to the same position.
+    if (conflictingWritingOp == readingOp &&
+        state.areEquivalentBufferizedValues(uRead->get(),
+                                            uConflictingWrite->get())) {
+      if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
+        if (bufferizableOp.bufferizesToElementwiseAccess(state)) {
+          LLVM_DEBUG(
+              llvm::dbgs()
+              << " no conflict: op bufferizes to element-wise access\n");
+          continue;
+        }
+      }
+    }
+
     // No conflict if the op interface says so.
     if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
       if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) {
@@ -95,8 +95,8 @@ struct LinalgOpInterface
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
     // Operand is read if it is used in the computation.
-    auto genericOp = cast<linalg::LinalgOp>(op);
-    return genericOp.payloadUsesValueFromOperand(&opOperand);
+    auto linalgOp = cast<linalg::LinalgOp>(op);
+    return linalgOp.payloadUsesValueFromOperand(&opOperand);
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
@@ -106,6 +106,33 @@ struct LinalgOpInterface
     return dpsOp.isDpsInit(&opOperand);
   }
 
+  bool bufferizesToElementwiseAccess(Operation *op,
+                                     const AnalysisState &state) const {
+    auto linalgOp = cast<linalg::LinalgOp>(op);
+
+    // All loops must be parallel.
+    if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
+      return false;
+
+    // All index maps of tensors must be identity maps.
+    SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+    assert(linalgOp->getNumOperands() == indexingMaps.size() &&
+           "unexpected number of indexing maps");
+    for (auto [operand, map] :
+         llvm::zip(linalgOp->getOperands(), indexingMaps)) {
+      // Non-tensors do not participate in bufferization, so they can be
+      // ignored.
+      if (!isa<RankedTensorType, MemRefType>(operand.getType()))
+        continue;
+      // TODO: This could be generalized to other indexing maps. (All indexing
+      // must be the same.)
+      if (!map.isIdentity())
+        return false;
+    }
+
+    return true;
+  }
+
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     return bufferizeDestinationStyleOpInterface(
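As an illustration of the two checks above (not from the patch, written in the same style as the test file below): a `linalg.generic` with a reduction iterator fails the all-parallel-loops check, so `bufferizesToElementwiseAccess` returns false for it and the element-wise exception is never applied.

```
// Hypothetical example: the reduction loop makes this op non-element-wise,
// so the analysis falls back to the regular RaW conflict rules for it.
%sum = linalg.generic
    { iterator_types = ["reduction"],
      indexing_maps = [ affine_map<(d0) -> (d0)>,
                        affine_map<(d0) -> ()> ] }
    ins(%v : tensor<5xf32>) outs(%acc : tensor<f32>) {
  ^bb0(%x: f32, %a: f32):
    %s = arith.addf %x, %a : f32
    linalg.yield %s : f32
} -> tensor<f32>
```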
mlir/test/Dialect/Linalg/one-shot-bufferize-analysis.mlir (new file, 59 lines)
@@ -0,0 +1,59 @@
+// RUN: mlir-opt %s -one-shot-bufferize="allow-return-allocs bufferize-function-boundaries test-analysis-only" -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @elementwise_no_conflict
+func.func @elementwise_no_conflict(%a: tensor<5xf32>,
+                                   %b: tensor<5xf32>) -> tensor<5xf32> {
+  // CHECK: linalg.elemwise_binary
+  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "true"], fun = #linalg.binary_fn<add>}
+  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
+      ins(%a, %b : tensor<5xf32>, tensor<5xf32>)
+      outs(%a : tensor<5xf32>) -> tensor<5xf32>
+  return %0 : tensor<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @elementwise_no_conflict_2
+func.func @elementwise_no_conflict_2(%a: tensor<5xf32>) -> tensor<5xf32> {
+  // CHECK: linalg.elemwise_binary
+  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "true"], fun = #linalg.binary_fn<add>}
+  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
+      ins(%a, %a : tensor<5xf32>, tensor<5xf32>)
+      outs(%a : tensor<5xf32>) -> tensor<5xf32>
+  return %0 : tensor<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @elementwise_no_conflict_3
+func.func @elementwise_no_conflict_3(%a: tensor<5xf32>) -> tensor<5xf32> {
+  %c0f = arith.constant 1.0 : f32
+  // CHECK: linalg.elemwise_binary
+  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none", "true"], fun = #linalg.binary_fn<add>}
+  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
+      ins(%a, %c0f : tensor<5xf32>, f32)
+      outs(%a : tensor<5xf32>) -> tensor<5xf32>
+  return %0 : tensor<5xf32>
+}
+
+// -----
+
+func.func @not_elementwise(%a: tensor<5x6xf32>) -> tensor<5x6xf32> {
+  %cst = arith.constant 5.0 : f32
+  // CHECK: tensor.extract_slice
+  // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
+  %b = tensor.extract_slice %a[0, 0] [1, 6] [1, 1]
+      : tensor<5x6xf32> to tensor<6xf32>
+  // CHECK: linalg.generic
+  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
+  %0 = linalg.generic
+      { iterator_types = ["parallel", "parallel"],
+        indexing_maps = [ affine_map<(d0, d1) -> (d1)>,
+                          affine_map<(d0, d1) -> (d0, d1)>] }
+      ins(%b: tensor<6xf32>) outs(%a: tensor<5x6xf32>) {
+    ^bb0(%arg0: f32, %arg1: f32):
+      %r = arith.addf %arg0, %arg1 : f32
+      linalg.yield %r : f32
+  } -> tensor<5x6xf32>
+  return %0 : tensor<5x6xf32>
+}