[mlir] Make OpBuilder::createOperation to accept raw inputs

This provides a way to create an operation without manipulating
OperationState directly. This is useful for creating unregistered ops.

Reviewed By: rriddle, mehdi_amini

Differential Revision: https://reviews.llvm.org/D120787
This commit is contained in:
Chia-hung Duan 2022-03-23 21:37:26 +00:00
parent 00cc73044d
commit 14ecafd0bd
20 changed files with 73 additions and 76 deletions

View File

@ -1035,7 +1035,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
$_op->getAttrs());
for (Region &r : $_op->getRegions())
r.cloneInto(state.addRegion(), bvm);
return b.createOperation(state);
return b.create(state);
}]
>,
InterfaceMethod<
@ -1056,7 +1056,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
$_op->getAttrs());
for (Region &r : $_op->getRegions())
r.cloneInto(state.addRegion(), bvm);
return b.createOperation(state);
return b.create(state);
}]
>,
InterfaceMethod<
@ -1077,7 +1077,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
$_op->getAttrs());
for (size_t cnt = 0, e = $_op->getNumRegions(); cnt < e; ++cnt)
state.addRegion();
return b.createOperation(state);
return b.create(state);
}]
>,
StaticInterfaceMethod<

View File

@ -405,7 +405,13 @@ public:
Operation *insert(Operation *op);
/// Creates an operation given the fields represented as an OperationState.
Operation *createOperation(const OperationState &state);
Operation *create(const OperationState &state);
/// Creates an operation with the given fields.
Operation *create(Location loc, StringAttr opName, ValueRange operands,
TypeRange types, ArrayRef<NamedAttribute> attributes = {},
BlockRange successors = {},
MutableArrayRef<std::unique_ptr<Region>> regions = {});
private:
/// Helper for sanity checking preconditions for create* methods below.
@ -431,7 +437,7 @@ public:
OperationState state(location,
getCheckRegisteredInfo<OpTy>(location.getContext()));
OpTy::build(*this, state, std::forward<Args>(args)...);
auto *op = createOperation(state);
auto *op = create(state);
auto result = dyn_cast<OpTy>(op);
assert(result && "builder didn't return the right type");
return result;
@ -443,7 +449,7 @@ public:
template <typename OpTy, typename... Args>
void createOrFold(SmallVectorImpl<Value> &results, Location location,
Args &&...args) {
// Create the operation without using 'createOperation' as we don't want to
// Create the operation without using 'create' as we don't want to
// insert it yet.
OperationState state(location,
getCheckRegisteredInfo<OpTy>(location.getContext()));

View File

@ -11,6 +11,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
using namespace mlir;
@ -319,11 +320,9 @@ LogicalResult LLVM::detail::oneToOneRewrite(
}
// Create the operation through state since we don't know its C++ type.
OperationState state(op->getLoc(), targetOp);
state.addTypes(packedType);
state.addOperands(operands);
state.addAttributes(op->getAttrs());
Operation *newOp = rewriter.createOperation(state);
Operation *newOp =
rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
packedType, op->getAttrs());
// If the operation produced 0 or 1 result, return them immediately.
if (numResults == 0)

View File

@ -130,11 +130,10 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy,
ValueRange operands) {
OperationState state(op->getLoc(), targetOp);
state.addTypes(llvm1DVectorTy);
state.addOperands(operands);
state.addAttributes(op->getAttrs());
return rewriter.createOperation(state)->getResult(0);
return rewriter
.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
llvm1DVectorTy, op->getAttrs())
->getResult(0);
};
return handleMultidimensionalVectors(op, operands, typeConverter, callback,

View File

@ -1408,10 +1408,9 @@ static Operation *widenOp(Operation *op, VectorizationState &state) {
// name that works both in scalar mode and vector mode.
// TODO: Is it worth considering an Operation.clone operation which
// changes the type so we can promote an Operation with less boilerplate?
OperationState vecOpState(op->getLoc(), op->getName(), vectorOperands,
vectorTypes, op->getAttrs(), /*successors=*/{},
/*regions=*/{});
Operation *vecOp = state.builder.createOperation(vecOpState);
Operation *vecOp =
state.builder.create(op->getLoc(), op->getName().getIdentifier(),
vectorOperands, vectorTypes, op->getAttrs());
state.registerOpVectorReplacement(op, vecOp);
return vecOp;
}

View File

@ -1242,7 +1242,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
}
// Create the new operation.
auto *repOp = builder.createOperation(state);
auto *repOp = builder.create(state);
op->replaceAllUsesWith(repOp);
op->erase();

View File

@ -98,16 +98,14 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
/*iteratorTypes=*/iteratorTypes,
/*bodyBuilder=*/
[&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
OperationState state(loc, op->getName());
state.addAttributes(op->getAttrs());
// Only take the input operands in the cloned elementwise op.
state.addOperands(regionArgs.take_front(op->getNumOperands()));
auto resultTypes = llvm::to_vector<6>(
llvm::map_range(op->getResultTypes(), [](Type type) {
return type.cast<TensorType>().getElementType();
}));
state.addTypes(resultTypes);
auto *scalarOp = builder.createOperation(state);
auto *scalarOp =
builder.create(loc, op->getName().getIdentifier(),
regionArgs.take_front(op->getNumOperands()),
resultTypes, op->getAttrs());
builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
});
return success();

View File

@ -299,17 +299,6 @@ static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
}
/// Create a new vectorized verstion of `op` with the given operands and types.
static Operation *createVectorizedOp(OpBuilder &b, Operation *op,
ValueRange newOperands,
ArrayRef<Type> types) {
OperationState state(op->getLoc(), op->getName());
state.addAttributes(op->getAttrs());
state.addOperands(newOperands);
state.addTypes(types);
return b.createOperation(state);
}
/// Emit reduction operations if the shapes of the value to reduce is different
/// that the result shape.
static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
@ -326,7 +315,9 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
return nullptr;
SmallVector<bool> reductionMask = getReductionMask(linalgOp);
Value reduce = buildMultiDimReduce(b, op, reduceVec, reductionMask);
return createVectorizedOp(b, op, {reduce, outputVec}, reduce.getType());
return b.create(op->getLoc(), op->getName().getIdentifier(),
/*operands=*/{reduce, outputVec}, reduce.getType(),
op->getAttrs());
}
/// Generic vectorization for a single operation `op`, given already vectorized
@ -420,8 +411,9 @@ vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op,
// Build and return the new op.
return VectorizationResult{
VectorizationStatus::NewOp,
createVectorizedOp(b, op, llvm::to_vector<4>(vectorizedOperands),
llvm::to_vector<4>(returnTypes))};
b.create(op->getLoc(), op->getName().getIdentifier(),
llvm::to_vector<4>(vectorizedOperands),
llvm::to_vector<4>(returnTypes), op->getAttrs())};
}
/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.

View File

@ -517,7 +517,7 @@ Operation *NormalizeMemRefs::createOpResultsNormalized(FuncOp funcOp,
Region *newRegion = result.addRegion();
newRegion->takeBody(oldRegion);
}
return bb.createOperation(result);
return bb.create(result);
}
return oldOp;
}

View File

@ -368,11 +368,9 @@ public:
newOperands.push_back(operand);
}
}
OperationState state(op->getLoc(), op->getName());
state.addAttributes(op->getAttrs());
state.addOperands(newOperands);
state.addTypes(newVecType);
Operation *newOp = rewriter.createOperation(state);
Operation *newOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(),
newOperands, newVecType, op->getAttrs());
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
newOp->getResult(0));
return success();

View File

@ -1028,9 +1028,8 @@ struct ReorderCastOpsOnBroadcast
Type castResTy = getElementTypeOrSelf(op->getResult(0));
if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
castResTy = VectorType::get(vecTy.getShape(), castResTy);
OperationState state(op->getLoc(), op->getName(), bcastOp.source(),
castResTy, op->getAttrs());
auto castOp = rewriter.createOperation(state);
auto castOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(),
bcastOp.source(), castResTy, op->getAttrs());
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
op, op->getResult(0).getType(), castOp->getResult(0));
return success();
@ -1068,9 +1067,8 @@ struct ReorderCastOpsOnTranspose
auto castResTy = transpOp.getVectorType();
castResTy = VectorType::get(castResTy.getShape(),
getElementTypeOrSelf(op->getResult(0)));
OperationState state(op->getLoc(), op->getName(), transpOp.vector(),
castResTy, op->getAttrs());
auto castOp = rewriter.createOperation(state);
auto castOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(),
transpOp.vector(), castResTy, op->getAttrs());
rewriter.replaceOpWithNewOp<vector::TransposeOp>(
op, op->getResult(0).getType(), castOp->getResult(0),
transpOp.getTransp());

View File

@ -70,8 +70,8 @@ static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
Operation *op,
ArrayRef<Value> operands,
ArrayRef<Type> resultTypes) {
OperationState res(loc, op->getName(), operands, resultTypes, op->getAttrs());
return builder.createOperation(res);
return builder.create(loc, op->getName().getIdentifier(), operands,
resultTypes, op->getAttrs());
}
/// Return the target shape for unrolling for the given `op`. Return llvm::None

View File

@ -377,10 +377,21 @@ Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes,
}
/// Create an operation given the fields represented as an OperationState.
Operation *OpBuilder::createOperation(const OperationState &state) {
Operation *OpBuilder::create(const OperationState &state) {
return insert(Operation::create(state));
}
/// Creates an operation with the given fields.
Operation *OpBuilder::create(Location loc, StringAttr opName,
ValueRange operands, TypeRange types,
ArrayRef<NamedAttribute> attributes,
BlockRange successors,
MutableArrayRef<std::unique_ptr<Region>> regions) {
OperationState state(loc, opName, operands, types, attributes, successors,
regions);
return create(state);
}
/// Attempts to fold the given operation and places new results within
/// 'results'. Returns success if the operation was folded, failure otherwise.
/// Note: This function does not erase the operation on a successful fold.

View File

@ -1151,7 +1151,7 @@ Operation *OperationParser::parseGenericOperation() {
return nullptr;
// Create the operation and try to parse a location for it.
Operation *op = opBuilder.createOperation(result);
Operation *op = opBuilder.create(result);
if (parseTrailingLocationSpecifier(op))
return nullptr;
return op;
@ -1756,7 +1756,7 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
return nullptr;
// Otherwise, create the operation and try to parse a location for it.
Operation *op = opBuilder.createOperation(opState);
Operation *op = opBuilder.create(opState);
if (parseTrailingLocationSpecifier(op))
return nullptr;
return op;

View File

@ -1552,7 +1552,7 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
break;
}
Operation *resultOp = rewriter.createOperation(state);
Operation *resultOp = rewriter.create(state);
memory[memIndex] = resultOp;
LLVM_DEBUG({

View File

@ -509,12 +509,12 @@ Value Importer::processValue(llvm::Value *value) {
// We don't expect to see instructions in dominator order. If we haven't seen
// this instruction yet, create an unknown op and remap it later.
if (isa<llvm::Instruction>(value)) {
OperationState state(UnknownLoc::get(context), "llvm.unknown");
Type type = processType(value->getType());
if (!type)
return nullptr;
state.addTypes(type);
unknownInstMap[value] = b.createOperation(state);
unknownInstMap[value] =
b.create(UnknownLoc::get(context), b.getStringAttr("llvm.unknown"),
/*operands=*/{}, type);
return unknownInstMap[value]->getResult(0);
}
@ -705,7 +705,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
return failure();
state.addTypes(type);
}
Operation *op = b.createOperation(state);
Operation *op = b.create(state);
if (!inst->getType()->isVoidTy())
v = op->getResult(0);
return success();
@ -747,7 +747,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
b.getI32VectorAttr(operandSegmentSizes));
}
b.createOperation(state);
b.create(state);
return success();
}
case llvm::Instruction::PHI: {

View File

@ -283,7 +283,7 @@ LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
if (hasResult)
opState.addTypes(resultTypes);
opState.addAttributes(attributes);
Operation *op = opBuilder.createOperation(opState);
Operation *op = opBuilder.create(opState);
if (hasResult)
valueMap[valueID] = op->getResult(0);

View File

@ -789,7 +789,7 @@ ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
if (failed(parseOpNameInfo))
return failure();
StringRef innerOpName = parseOpNameInfo->getStringRef();
StringAttr innerOpName = parseOpNameInfo->getIdentifier();
FunctionType opFntype;
Optional<Location> explicitLoc;
@ -823,12 +823,8 @@ ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
OpBuilder builder(parser.getBuilder().getContext());
builder.setInsertionPointToStart(&block);
OperationState innerOpState(opLoc, innerOpName);
innerOpState.operands.push_back(lhs);
innerOpState.operands.push_back(rhs);
innerOpState.addTypes(innerOpType);
Operation *innerOp = builder.createOperation(innerOpState);
Operation *innerOp =
builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType);
// Insert a return statement in the block returning the inner-op's result.
builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());

View File

@ -168,7 +168,7 @@ static void invokeCreateWithInferredReturnType(Operation *op) {
OperationState state(location, OpTy::getOperationName());
// TODO: Expand to regions.
OpTy::build(b, state, values, op->getAttrs());
(void)b.createOperation(state);
(void)b.create(state);
}
}
}
@ -295,7 +295,7 @@ struct TestRegionRewriteUndo : public RewritePattern {
// Create the region operation with an entry block containing arguments.
OperationState newRegion(op->getLoc(), "test.region");
newRegion.addRegion();
auto *regionOp = rewriter.createOperation(newRegion);
auto *regionOp = rewriter.create(newRegion);
auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
entryBlock->addArgument(rewriter.getIntegerType(64),
rewriter.getUnknownLoc());

View File

@ -38,9 +38,10 @@ customMultiEntityVariadicConstraint(ArrayRef<PDLValue> values,
// Custom creator invoked from PDL.
static void customCreate(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
PDLResultList &results) {
results.push_back(rewriter.createOperation(
results.push_back(rewriter.create(
OperationState(args[0].cast<Operation *>()->getLoc(), "test.success")));
}
static void customVariadicResultCreate(ArrayRef<PDLValue> args,
PatternRewriter &rewriter,
PDLResultList &results) {
@ -59,7 +60,7 @@ static void customRewriter(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
Operation *root = args[0].cast<Operation *>();
OperationState successOpState(root->getLoc(), "test.success");
successOpState.addOperands(args[1].cast<Value>());
rewriter.createOperation(successOpState);
rewriter.create(successOpState);
rewriter.eraseOp(root);
}