[mlir][shape] Migrate bufferization to BufferizableOpInterface

Differential Revision: https://reviews.llvm.org/D121043
Matthias Springer 2022-03-07 21:27:53 +09:00
parent df6c26fd34
commit 93e663273b
7 changed files with 208 additions and 98 deletions
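
With this change, shape.assuming / shape.assuming_yield bufferize via BufferizableOpInterface external models (added below) instead of the dialect-conversion-based patterns. As a rough sketch of the effect (hypothetical IR; exact memref layouts depend on the bufferization options), partial bufferization rewrites

    %0 = shape.assuming %w -> (tensor<?xf32>) {
      %1 = "test.source"() : () -> tensor<?xf32>
      shape.assuming_yield %1 : tensor<?xf32>
    }

into roughly

    %0 = shape.assuming %w -> (memref<?xf32>) {
      %1 = "test.source"() : () -> tensor<?xf32>
      %m = bufferization.to_memref %1 : memref<?xf32>
      shape.assuming_yield %m : memref<?xf32>
    }
    %t = bufferization.to_tensor %0 : memref<?xf32>

where the remaining to_memref/to_tensor materializations are cleaned up by later bufferization passes.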

mlir/include/mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h

@@ -0,0 +1,20 @@
//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SHAPE_BUFFERIZABLEOPINTERFACEIMPL_H
#define MLIR_DIALECT_SHAPE_BUFFERIZABLEOPINTERFACEIMPL_H
namespace mlir {
class DialectRegistry;
namespace shape {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace shape
} // namespace mlir
#endif // MLIR_DIALECT_SHAPE_BUFFERIZABLEOPINTERFACEIMPL_H

mlir/include/mlir/Dialect/Shape/Transforms/Passes.h

@@ -40,21 +40,6 @@ void populateShapeRewritePatterns(RewritePatternSet &patterns);
 void populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns);
 std::unique_ptr<OperationPass<FuncOp>> createRemoveShapeConstraintsPass();
 
-/// Populates patterns for shape dialect structural type conversions and sets up
-/// the provided ConversionTarget with the appropriate legality configuration
-/// for the ops to get converted properly.
-///
-/// A "structural" type conversion is one where the underlying ops are
-/// completely agnostic to the actual types involved and simply need to update
-/// their types consistently. An example of this is shape.assuming -- the
-/// shape.assuming op and the corresponding shape.assuming_yield op need to have
-/// consistent types, but the exact types don't matter. So all that we need to
-/// do for a structural type conversion is to update both of their types
-/// consistently to the new types prescribed by the TypeConverter.
-void populateShapeStructuralTypeConversionsAndLegality(
-    TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target);
-
 // Bufferizes shape dialect ops.
 //
 // Note that most shape dialect ops must be converted to std before
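
Downstream users of the removed populateShapeStructuralTypeConversionsAndLegality entry point now register the external models instead. A minimal sketch (the registration function is added by this commit; the surrounding setup is hypothetical):

    // Instead of populating structural type conversion patterns, register the
    // BufferizableOpInterface models when setting up the DialectRegistry:
    DialectRegistry registry;
    registry.insert<shape::ShapeDialect>();
    shape::registerBufferizableOpInterfaceExternalModels(registry);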

mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp

@@ -0,0 +1,169 @@
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::shape;
namespace mlir {
namespace shape {
namespace {
/// Bufferization of shape.assuming.
struct AssumingOpInterface
    : public BufferizableOpInterface::ExternalModel<AssumingOpInterface,
                                                    shape::AssumingOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const BufferizationState &state) const {
    // AssumingOps do not have tensor OpOperands. The yielded value can be any
    // SSA value that is in scope. To allow for use-def chain traversal through
    // AssumingOps in the analysis, the corresponding yield value is considered
    // to be aliasing with the result.
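    // Example (hypothetical IR): %0 aliases the operand of the
    // shape.assuming_yield terminator that produces it:
    //
    //   %0 = shape.assuming %w -> (tensor<?xf32>) {
    //     %1 = "test.source"() : () -> tensor<?xf32>
    //     shape.assuming_yield %1 : tensor<?xf32>
    //   }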
    auto assumingOp = cast<shape::AssumingOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
        assumingOp.getDoRegion().front().getTerminator());
    assert(yieldOp && "expected shape.assuming_yield terminator");
    return {&yieldOp->getOpOperand(resultNum)};
  }
  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const BufferizationState &state) const {
    // Similar to scf.if, results of this op are always considered memory
    // writes in the analysis. This is a useful pattern for all ops that have
    // tensor OpResults but no tensor OpOperands. By default, `isMemoryWrite`
    // is implemented in terms of `bufferizesToMemoryWrite`, which does not
    // work on ops without OpOperands.
    return true;
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto assumingOp = cast<shape::AssumingOp>(op);

    // Compute new result types.
    SmallVector<Type> newResultTypes;
    for (Type type : assumingOp->getResultTypes()) {
      if (auto tensorType = type.dyn_cast<TensorType>()) {
        newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newResultTypes.push_back(type);
      }
    }
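    // Sketch: with the default partial-bufferization options, a type such as
    // tensor<?xf32> typically becomes a memref with a fully dynamic layout;
    // the exact layout is determined by the BufferizationOptions.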
    // Create new op and move over region.
    auto newOp = rewriter.create<shape::AssumingOp>(
        op->getLoc(), newResultTypes, assumingOp.getWitness());
    newOp.getDoRegion().takeBody(assumingOp.getRegion());

    // Update terminator.
    assert(newOp.getDoRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    Block *newBlock = &newOp.getDoRegion().front();
    auto yieldOp = cast<shape::AssumingYieldOp>(newBlock->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldValues;
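    // Sketch (hypothetical names): each yielded tensor is cast to the new
    // memref result type right before the terminator:
    //   %m = bufferization.to_memref %t : memref<?xf32>
    //   shape.assuming_yield %m : memref<?xf32>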
    for (const auto &it : llvm::enumerate(yieldOp.operands())) {
      Value val = it.value();
      if (val.getType().isa<TensorType>()) {
        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
            yieldOp.getLoc(), newResultTypes[it.index()], val));
      } else {
        newYieldValues.push_back(val);
      }
    }
    rewriter.replaceOpWithNewOp<shape::AssumingYieldOp>(yieldOp,
                                                        newYieldValues);

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
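    // Sketch (hypothetical names): each new memref result is wrapped in a
    // bufferization.to_tensor op so existing tensor users keep type-checking:
    //   %t2 = bufferization.to_tensor %new_result : memref<?xf32>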
    for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            assumingOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(assumingOp, newResults);
    return success();
  }
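  // Each result shares the buffer of the corresponding yielded value; no copy
  // is inserted, so the relation between the two is "Equivalent".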
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }
};
/// Bufferization of shape.assuming_yield. Bufferized as part of its enclosing
/// op, so this is for analysis only.
struct AssumingYieldOpInterface
    : public BufferizableOpInterface::ExternalModel<AssumingYieldOpInterface,
                                                    shape::AssumingYieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    assert(isa<shape::AssumingOp>(op->getParentOp()) &&
           "expected that parent is an AssumingOp");
    return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
  }
  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const BufferizationState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
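    // (Sketch of the alternative: an out-of-place yield operand would require
    // allocating and copying right before the terminator, e.g. a memref.alloc
    // followed by a memref.copy, and the block would then yield the fresh
    // allocation.)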
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    // Op is bufferized as part of AssumingOp.
    return failure();
  }
};
} // namespace
} // namespace shape
} // namespace mlir
void mlir::shape::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addOpInterface<shape::AssumingOp, AssumingOpInterface>();
  registry.addOpInterface<shape::AssumingYieldOp, AssumingYieldOpInterface>();
}
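
// Typical client-side registration, mirroring the pass update below (sketch;
// the registry setup around the call is hypothetical):
//   DialectRegistry registry;
//   registry.insert<shape::ShapeDialect>();
//   shape::registerBufferizableOpInterfaceExternalModels(registry);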

mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp

@@ -8,30 +8,32 @@
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 
 #include "PassDetail.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
 #include "mlir/Pass/Pass.h"
 
 using namespace mlir;
+using namespace bufferization;
 
 namespace {
 struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> {
   void runOnOperation() override {
-    MLIRContext &ctx = getContext();
-
-    RewritePatternSet patterns(&ctx);
-    bufferization::BufferizeTypeConverter typeConverter;
-    ConversionTarget target(ctx);
-    bufferization::populateBufferizeMaterializationLegality(target);
-    populateShapeStructuralTypeConversionsAndLegality(typeConverter, patterns,
-                                                      target);
-
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
+    BufferizationOptions options = getPartialBufferizationOptions();
+    options.allowDialectInFilter<shape::ShapeDialect>();
+
+    if (failed(bufferizeOp(getOperation(), options)))
       signalPassFailure();
   }
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
                     shape::ShapeDialect>();
+    shape::registerBufferizableOpInterfaceExternalModels(registry);
   }
 };
 } // namespace
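
The pass now runs the interface-based bufferizeOp() driver with partial-bufferization options instead of a dialect conversion. Invocation should be unchanged; assuming the flag behind ShapeBufferizeBase is still shape-bufferize, a typical run is:

    mlir-opt input.mlir -shape-bufferize

Because this remains a partial bufferization, bufferization.to_tensor / bufferization.to_memref casts may be left at the boundaries to not-yet-bufferized ops; a later pass (e.g. -finalizing-bufferize) removes them.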

mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt

@@ -1,8 +1,8 @@
 add_mlir_dialect_library(MLIRShapeOpsTransforms
+  BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   RemoveShapeConstraints.cpp
   ShapeToShapeLowering.cpp
-  StructuralTypeConversions.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ShapeOps/Transforms
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRShapeOpsTransforms
 target_link_libraries(MLIRShapeOpsTransforms
   PUBLIC
   MLIRArithmetic
+  MLIRBufferization
   MLIRBufferizationTransforms
   MLIRIR
   MLIRMemRef

mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp (deleted)

@@ -1,70 +0,0 @@
//===- StructuralTypeConversions.cpp - Shape structural type conversions --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace mlir::shape;
namespace {
class ConvertAssumingOpTypes : public OpConversionPattern<AssumingOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AssumingOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    SmallVector<Type, 2> newResultTypes;
    newResultTypes.reserve(op.getNumResults());
    for (auto result : op.getResults()) {
      auto originalType = result.getType();
      Type convertedType = getTypeConverter()->convertType(originalType);
      newResultTypes.push_back(convertedType);
    }

    auto newAssumingOp = rewriter.create<AssumingOp>(
        op.getLoc(), newResultTypes, op.getWitness());

    rewriter.inlineRegionBefore(op.getDoRegion(), newAssumingOp.getDoRegion(),
                                newAssumingOp.getDoRegion().end());
    rewriter.replaceOp(op, newAssumingOp.getResults());

    return success();
  }
};
} // namespace
namespace {
class ConvertAssumingYieldOpTypes
    : public OpConversionPattern<AssumingYieldOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AssumingYieldOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.replaceOpWithNewOp<AssumingYieldOp>(op, adaptor.getOperands());
    return success();
  }
};
} // namespace
void mlir::populateShapeStructuralTypeConversionsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target) {
  patterns.add<ConvertAssumingOpTypes, ConvertAssumingYieldOpTypes>(
      typeConverter, patterns.getContext());

  target.addDynamicallyLegalOp<AssumingOp>([&](AssumingOp op) {
    return typeConverter.isLegal(op.getResultTypes());
  });
  target.addDynamicallyLegalOp<AssumingYieldOp>([&](AssumingYieldOp op) {
    return typeConverter.isLegal(op.getOperandTypes());
  });
}

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

@@ -2702,7 +2702,10 @@ cc_library(
         "lib/Dialect/Shape/Transforms/*.cpp",
         "lib/Dialect/Shape/Transforms/*.h",
     ]),
-    hdrs = ["include/mlir/Dialect/Shape/Transforms/Passes.h"],
+    hdrs = [
+        "include/mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h",
+        "include/mlir/Dialect/Shape/Transforms/Passes.h",
+    ],
     includes = ["include"],
     deps = [
         ":ArithmeticDialect",