mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-02-01 22:53:29 +00:00
[mlir][bufferization] Implement BufferDeallocationOpInterface for scf.forall.in_parallel (#66351)
The scf.forall.in_parallel terminator operation has a nested graph region with the NoTerminator trait. Such regions are not supported by the default implementations. Therefore, this commit adds a specialized implementation for this operation which only covers the case where the nested region is empty. This is because after bufferization, ops like tensor.parallel_insert_slice were already converted to memref operations residing in the scf.forall only, and the nested region of scf.forall.in_parallel ends up empty.
This commit is contained in:
parent
9e739fdb85
commit
66aa9a2517
@ -0,0 +1,22 @@
|
||||
//===- BufferDeallocationOpInterfaceImpl.h ----------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_SCF_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
|
||||
#define MLIR_DIALECT_SCF_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class DialectRegistry;
|
||||
|
||||
namespace scf {
|
||||
void registerBufferDeallocationOpInterfaceExternalModels(
|
||||
DialectRegistry ®istry);
|
||||
} // namespace scf
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_SCF_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
|
@ -60,6 +60,8 @@
|
||||
#include "mlir/Dialect/Quant/QuantOps.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
|
||||
#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
@ -149,6 +151,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
|
||||
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
|
||||
memref::registerValueBoundsOpInterfaceExternalModels(registry);
|
||||
memref::registerMemorySlotExternalModels(registry);
|
||||
scf::registerBufferDeallocationOpInterfaceExternalModels(registry);
|
||||
scf::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
scf::registerValueBoundsOpInterfaceExternalModels(registry);
|
||||
shape::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
|
@ -0,0 +1,87 @@
|
||||
//===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::bufferization;
|
||||
|
||||
namespace {
|
||||
/// The `scf.forall.in_parallel` terminator is special in a few ways:
|
||||
/// * It does not implement the BranchOpInterface or
|
||||
/// RegionBranchTerminatorOpInterface, but the ParallelCombiningOpInterface
|
||||
/// which is not supported by BufferDeallocation.
|
||||
/// * It has a graph-like region which only allows one specific tensor op
|
||||
/// * After bufferization the nested region is always empty
|
||||
/// For these reasons we provide custom deallocation logic via this external
|
||||
/// model.
|
||||
///
|
||||
/// Example:
|
||||
/// ```mlir
|
||||
/// scf.forall (%arg1) in (%arg0) {
|
||||
/// %alloc = memref.alloc() : memref<2xf32>
|
||||
/// ...
|
||||
/// <implicit in_parallel terminator here>
|
||||
/// }
|
||||
/// ```
|
||||
/// gets transformed to
|
||||
/// ```mlir
|
||||
/// scf.forall (%arg1) in (%arg0) {
|
||||
/// %alloc = memref.alloc() : memref<2xf32>
|
||||
/// ...
|
||||
/// bufferization.dealloc (%alloc : memref<2xf32>) if (%true)
|
||||
/// <implicit in_parallel terminator here>
|
||||
/// }
|
||||
/// ```
|
||||
struct InParallelOpInterface
|
||||
: public BufferDeallocationOpInterface::ExternalModel<InParallelOpInterface,
|
||||
scf::InParallelOp> {
|
||||
FailureOr<Operation *> process(Operation *op, DeallocationState &state,
|
||||
const DeallocationOptions &options) const {
|
||||
auto inParallelOp = cast<scf::InParallelOp>(op);
|
||||
OpBuilder builder(op);
|
||||
if (!inParallelOp.getBody()->empty())
|
||||
return op->emitError("only supported when nested region is empty");
|
||||
|
||||
// Collect the values to deallocate and retain and use them to create the
|
||||
// dealloc operation.
|
||||
Block *block = op->getBlock();
|
||||
SmallVector<Value> memrefs, conditions, toRetain;
|
||||
if (failed(state.getMemrefsAndConditionsToDeallocate(
|
||||
builder, op->getLoc(), block, memrefs, conditions)))
|
||||
return failure();
|
||||
|
||||
state.getMemrefsToRetain(block, /*toBlock=*/nullptr, {}, toRetain);
|
||||
if (memrefs.empty() && toRetain.empty())
|
||||
return op;
|
||||
|
||||
auto deallocOp = builder.create<bufferization::DeallocOp>(
|
||||
op->getLoc(), memrefs, conditions, toRetain);
|
||||
|
||||
// We want to replace the current ownership of the retained values with the
|
||||
// result values of the dealloc operation as they are always unique.
|
||||
state.resetOwnerships(deallocOp.getRetained(), block);
|
||||
for (auto [retained, ownership] :
|
||||
llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions()))
|
||||
state.updateOwnership(retained, ownership, block);
|
||||
|
||||
return op;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(
|
||||
DialectRegistry ®istry) {
|
||||
registry.addExtension(+[](MLIRContext *ctx, SCFDialect *dialect) {
|
||||
InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
|
||||
});
|
||||
}
|
@ -1,4 +1,5 @@
|
||||
add_mlir_dialect_library(MLIRSCFTransforms
|
||||
BufferDeallocationOpInterfaceImpl.cpp
|
||||
BufferizableOpInterfaceImpl.cpp
|
||||
Bufferize.cpp
|
||||
ForToWhile.cpp
|
||||
|
24
mlir/test/Dialect/SCF/buffer-deallocation.mlir
Normal file
24
mlir/test/Dialect/SCF/buffer-deallocation.mlir
Normal file
@ -0,0 +1,24 @@
|
||||
// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation \
|
||||
// RUN: -buffer-deallocation-simplification -split-input-file %s | FileCheck %s
|
||||
|
||||
// Test input: one memref is allocated outside and one inside an scf.forall
// whose body has no explicit in_parallel terminator. The CHECK lines that
// follow expect a bufferization.dealloc for each of the two allocations.
func.func @parallel_insert_slice(%arg0: index) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%alloc = memref.alloc() : memref<2xf32>
|
||||
scf.forall (%arg1) in (%arg0) {
|
||||
%alloc0 = memref.alloc() : memref<2xf32>
|
||||
%0 = memref.load %alloc[%c0] : memref<2xf32>
|
||||
linalg.fill ins(%0 : f32) outs(%alloc0 : memref<2xf32>)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @parallel_insert_slice
|
||||
// CHECK-SAME: (%arg0: index)
|
||||
// CHECK: [[ALLOC0:%.+]] = memref.alloc(
|
||||
// CHECK: scf.forall
|
||||
// CHECK: [[ALLOC1:%.+]] = memref.alloc(
|
||||
// CHECK: bufferization.dealloc ([[ALLOC1]] : memref<2xf32>) if (%true
|
||||
// CHECK-NOT: retain
|
||||
// CHECK: }
|
||||
// CHECK: bufferization.dealloc ([[ALLOC0]] : memref<2xf32>) if (%true
|
||||
// CHECK-NOT: retain
|
Loading…
x
Reference in New Issue
Block a user