[mlir][bufferization] Implement BufferDeallocationopInterface for scf.forall.in_parallel (#66351)

The scf.forall.in_parallel terminator operation has a nested graph region with the NoTerminator trait. Such regions are not supported by the default implementations. Therefore, this commit adds a specialized implementation for
this operation which only covers the case where the nested region is empty.
This is because after bufferization, ops like tensor.parallel_insert_slice were already converted to memref operations residing int the scf.forall only and the nested region of scf.forall.in_parallel ends up empty.
This commit is contained in:
Martin Erhart 2023-09-14 16:20:24 +02:00 committed by GitHub
parent 9e739fdb85
commit 66aa9a2517
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 137 additions and 0 deletions

View File

@ -0,0 +1,22 @@
//===- BufferDeallocationOpInterfaceImpl.h ----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SCF_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
#define MLIR_DIALECT_SCF_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
namespace mlir {
class DialectRegistry;
namespace scf {
void registerBufferDeallocationOpInterfaceExternalModels(
DialectRegistry &registry);
} // namespace scf
} // namespace mlir
#endif // MLIR_DIALECT_SCF_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H

View File

@ -60,6 +60,8 @@
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h"
#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
@ -149,6 +151,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
memref::registerValueBoundsOpInterfaceExternalModels(registry);
memref::registerMemorySlotExternalModels(registry);
scf::registerBufferDeallocationOpInterfaceExternalModels(registry);
scf::registerBufferizableOpInterfaceExternalModels(registry);
scf::registerValueBoundsOpInterfaceExternalModels(registry);
shape::registerBufferizableOpInterfaceExternalModels(registry);

View File

@ -0,0 +1,87 @@
//===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
using namespace mlir;
using namespace mlir::bufferization;
namespace {
/// The `scf.forall.in_parallel` terminator is special in a few ways:
/// * It does not implement the BranchOpInterface or
/// RegionBranchTerminatorOpInterface, but the ParallelCombiningOpInterface
/// which is not supported by BufferDeallocation.
/// * It has a graph-like region which only allows one specific tensor op
/// * After bufferization the nested region is always empty
/// For these reasons we provide custom deallocation logic via this external
/// model.
///
/// Example:
/// ```mlir
/// scf.forall (%arg1) in (%arg0) {
/// %alloc = memref.alloc() : memref<2xf32>
/// ...
/// <implicit in_parallel terminator here>
/// }
/// ```
/// gets transformed to
/// ```mlir
/// scf.forall (%arg1) in (%arg0) {
/// %alloc = memref.alloc() : memref<2xf32>
/// ...
/// bufferization.dealloc (%alloc : memref<2xf32>) if (%true)
/// <implicit in_parallel terminator here>
/// }
/// ```
struct InParallelOpInterface
: public BufferDeallocationOpInterface::ExternalModel<InParallelOpInterface,
scf::InParallelOp> {
FailureOr<Operation *> process(Operation *op, DeallocationState &state,
const DeallocationOptions &options) const {
auto inParallelOp = cast<scf::InParallelOp>(op);
OpBuilder builder(op);
if (!inParallelOp.getBody()->empty())
return op->emitError("only supported when nested region is empty");
// Collect the values to deallocate and retain and use them to create the
// dealloc operation.
Block *block = op->getBlock();
SmallVector<Value> memrefs, conditions, toRetain;
if (failed(state.getMemrefsAndConditionsToDeallocate(
builder, op->getLoc(), block, memrefs, conditions)))
return failure();
state.getMemrefsToRetain(block, /*toBlock=*/nullptr, {}, toRetain);
if (memrefs.empty() && toRetain.empty())
return op;
auto deallocOp = builder.create<bufferization::DeallocOp>(
op->getLoc(), memrefs, conditions, toRetain);
// We want to replace the current ownership of the retained values with the
// result values of the dealloc operation as they are always unique.
state.resetOwnerships(deallocOp.getRetained(), block);
for (auto [retained, ownership] :
llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions()))
state.updateOwnership(retained, ownership, block);
return op;
}
};
} // namespace
void mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, SCFDialect *dialect) {
InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
});
}

View File

@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRSCFTransforms
BufferDeallocationOpInterfaceImpl.cpp
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
ForToWhile.cpp

View File

@ -0,0 +1,24 @@
// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation \
// RUN: -buffer-deallocation-simplification -split-input-file %s | FileCheck %s
func.func @parallel_insert_slice(%arg0: index) {
%c0 = arith.constant 0 : index
%alloc = memref.alloc() : memref<2xf32>
scf.forall (%arg1) in (%arg0) {
%alloc0 = memref.alloc() : memref<2xf32>
%0 = memref.load %alloc[%c0] : memref<2xf32>
linalg.fill ins(%0 : f32) outs(%alloc0 : memref<2xf32>)
}
return
}
// CHECK-LABEL: func @parallel_insert_slice
// CHECK-SAME: (%arg0: index)
// CHECK: [[ALLOC0:%.+]] = memref.alloc(
// CHECK: scf.forall
// CHECK: [[ALLOC1:%.+]] = memref.alloc(
// CHECK: bufferization.dealloc ([[ALLOC1]] : memref<2xf32>) if (%true
// CHECK-NOT: retain
// CHECK: }
// CHECK: bufferization.dealloc ([[ALLOC0]] : memref<2xf32>) if (%true
// CHECK-NOT: retain