mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-01-27 11:55:49 +00:00
[mlir] Make OpBuilder::createOperation to accept raw inputs
This provides a way to create an operation without manipulating OperationState directly. This is useful for creating unregistered ops. Reviewed By: rriddle, mehdi_amini Differential Revision: https://reviews.llvm.org/D120787
This commit is contained in:
parent
00cc73044d
commit
14ecafd0bd
@ -1035,7 +1035,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
|
||||
$_op->getAttrs());
|
||||
for (Region &r : $_op->getRegions())
|
||||
r.cloneInto(state.addRegion(), bvm);
|
||||
return b.createOperation(state);
|
||||
return b.create(state);
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
@ -1056,7 +1056,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
|
||||
$_op->getAttrs());
|
||||
for (Region &r : $_op->getRegions())
|
||||
r.cloneInto(state.addRegion(), bvm);
|
||||
return b.createOperation(state);
|
||||
return b.create(state);
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
@ -1077,7 +1077,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
|
||||
$_op->getAttrs());
|
||||
for (size_t cnt = 0, e = $_op->getNumRegions(); cnt < e; ++cnt)
|
||||
state.addRegion();
|
||||
return b.createOperation(state);
|
||||
return b.create(state);
|
||||
}]
|
||||
>,
|
||||
StaticInterfaceMethod<
|
||||
|
@ -405,7 +405,13 @@ public:
|
||||
Operation *insert(Operation *op);
|
||||
|
||||
/// Creates an operation given the fields represented as an OperationState.
|
||||
Operation *createOperation(const OperationState &state);
|
||||
Operation *create(const OperationState &state);
|
||||
|
||||
/// Creates an operation with the given fields.
|
||||
Operation *create(Location loc, StringAttr opName, ValueRange operands,
|
||||
TypeRange types, ArrayRef<NamedAttribute> attributes = {},
|
||||
BlockRange successors = {},
|
||||
MutableArrayRef<std::unique_ptr<Region>> regions = {});
|
||||
|
||||
private:
|
||||
/// Helper for sanity checking preconditions for create* methods below.
|
||||
@ -431,7 +437,7 @@ public:
|
||||
OperationState state(location,
|
||||
getCheckRegisteredInfo<OpTy>(location.getContext()));
|
||||
OpTy::build(*this, state, std::forward<Args>(args)...);
|
||||
auto *op = createOperation(state);
|
||||
auto *op = create(state);
|
||||
auto result = dyn_cast<OpTy>(op);
|
||||
assert(result && "builder didn't return the right type");
|
||||
return result;
|
||||
@ -443,7 +449,7 @@ public:
|
||||
template <typename OpTy, typename... Args>
|
||||
void createOrFold(SmallVectorImpl<Value> &results, Location location,
|
||||
Args &&...args) {
|
||||
// Create the operation without using 'createOperation' as we don't want to
|
||||
// Create the operation without using 'create' as we don't want to
|
||||
// insert it yet.
|
||||
OperationState state(location,
|
||||
getCheckRegisteredInfo<OpTy>(location.getContext()));
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@ -319,11 +320,9 @@ LogicalResult LLVM::detail::oneToOneRewrite(
|
||||
}
|
||||
|
||||
// Create the operation through state since we don't know its C++ type.
|
||||
OperationState state(op->getLoc(), targetOp);
|
||||
state.addTypes(packedType);
|
||||
state.addOperands(operands);
|
||||
state.addAttributes(op->getAttrs());
|
||||
Operation *newOp = rewriter.createOperation(state);
|
||||
Operation *newOp =
|
||||
rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
|
||||
packedType, op->getAttrs());
|
||||
|
||||
// If the operation produced 0 or 1 result, return them immediately.
|
||||
if (numResults == 0)
|
||||
|
@ -130,11 +130,10 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
|
||||
|
||||
auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy,
|
||||
ValueRange operands) {
|
||||
OperationState state(op->getLoc(), targetOp);
|
||||
state.addTypes(llvm1DVectorTy);
|
||||
state.addOperands(operands);
|
||||
state.addAttributes(op->getAttrs());
|
||||
return rewriter.createOperation(state)->getResult(0);
|
||||
return rewriter
|
||||
.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
|
||||
llvm1DVectorTy, op->getAttrs())
|
||||
->getResult(0);
|
||||
};
|
||||
|
||||
return handleMultidimensionalVectors(op, operands, typeConverter, callback,
|
||||
|
@ -1408,10 +1408,9 @@ static Operation *widenOp(Operation *op, VectorizationState &state) {
|
||||
// name that works both in scalar mode and vector mode.
|
||||
// TODO: Is it worth considering an Operation.clone operation which
|
||||
// changes the type so we can promote an Operation with less boilerplate?
|
||||
OperationState vecOpState(op->getLoc(), op->getName(), vectorOperands,
|
||||
vectorTypes, op->getAttrs(), /*successors=*/{},
|
||||
/*regions=*/{});
|
||||
Operation *vecOp = state.builder.createOperation(vecOpState);
|
||||
Operation *vecOp =
|
||||
state.builder.create(op->getLoc(), op->getName().getIdentifier(),
|
||||
vectorOperands, vectorTypes, op->getAttrs());
|
||||
state.registerOpVectorReplacement(op, vecOp);
|
||||
return vecOp;
|
||||
}
|
||||
|
@ -1242,7 +1242,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
|
||||
}
|
||||
|
||||
// Create the new operation.
|
||||
auto *repOp = builder.createOperation(state);
|
||||
auto *repOp = builder.create(state);
|
||||
op->replaceAllUsesWith(repOp);
|
||||
op->erase();
|
||||
|
||||
|
@ -98,16 +98,14 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
|
||||
/*iteratorTypes=*/iteratorTypes,
|
||||
/*bodyBuilder=*/
|
||||
[&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
|
||||
OperationState state(loc, op->getName());
|
||||
state.addAttributes(op->getAttrs());
|
||||
// Only take the input operands in the cloned elementwise op.
|
||||
state.addOperands(regionArgs.take_front(op->getNumOperands()));
|
||||
auto resultTypes = llvm::to_vector<6>(
|
||||
llvm::map_range(op->getResultTypes(), [](Type type) {
|
||||
return type.cast<TensorType>().getElementType();
|
||||
}));
|
||||
state.addTypes(resultTypes);
|
||||
auto *scalarOp = builder.createOperation(state);
|
||||
auto *scalarOp =
|
||||
builder.create(loc, op->getName().getIdentifier(),
|
||||
regionArgs.take_front(op->getNumOperands()),
|
||||
resultTypes, op->getAttrs());
|
||||
builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
|
||||
});
|
||||
return success();
|
||||
|
@ -299,17 +299,6 @@ static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
|
||||
return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
|
||||
}
|
||||
|
||||
/// Create a new vectorized verstion of `op` with the given operands and types.
|
||||
static Operation *createVectorizedOp(OpBuilder &b, Operation *op,
|
||||
ValueRange newOperands,
|
||||
ArrayRef<Type> types) {
|
||||
OperationState state(op->getLoc(), op->getName());
|
||||
state.addAttributes(op->getAttrs());
|
||||
state.addOperands(newOperands);
|
||||
state.addTypes(types);
|
||||
return b.createOperation(state);
|
||||
}
|
||||
|
||||
/// Emit reduction operations if the shapes of the value to reduce is different
|
||||
/// that the result shape.
|
||||
static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
|
||||
@ -326,7 +315,9 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
|
||||
return nullptr;
|
||||
SmallVector<bool> reductionMask = getReductionMask(linalgOp);
|
||||
Value reduce = buildMultiDimReduce(b, op, reduceVec, reductionMask);
|
||||
return createVectorizedOp(b, op, {reduce, outputVec}, reduce.getType());
|
||||
return b.create(op->getLoc(), op->getName().getIdentifier(),
|
||||
/*operands=*/{reduce, outputVec}, reduce.getType(),
|
||||
op->getAttrs());
|
||||
}
|
||||
|
||||
/// Generic vectorization for a single operation `op`, given already vectorized
|
||||
@ -420,8 +411,9 @@ vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op,
|
||||
// Build and return the new op.
|
||||
return VectorizationResult{
|
||||
VectorizationStatus::NewOp,
|
||||
createVectorizedOp(b, op, llvm::to_vector<4>(vectorizedOperands),
|
||||
llvm::to_vector<4>(returnTypes))};
|
||||
b.create(op->getLoc(), op->getName().getIdentifier(),
|
||||
llvm::to_vector<4>(vectorizedOperands),
|
||||
llvm::to_vector<4>(returnTypes), op->getAttrs())};
|
||||
}
|
||||
|
||||
/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
|
||||
|
@ -517,7 +517,7 @@ Operation *NormalizeMemRefs::createOpResultsNormalized(FuncOp funcOp,
|
||||
Region *newRegion = result.addRegion();
|
||||
newRegion->takeBody(oldRegion);
|
||||
}
|
||||
return bb.createOperation(result);
|
||||
return bb.create(result);
|
||||
}
|
||||
return oldOp;
|
||||
}
|
||||
|
@ -368,11 +368,9 @@ public:
|
||||
newOperands.push_back(operand);
|
||||
}
|
||||
}
|
||||
OperationState state(op->getLoc(), op->getName());
|
||||
state.addAttributes(op->getAttrs());
|
||||
state.addOperands(newOperands);
|
||||
state.addTypes(newVecType);
|
||||
Operation *newOp = rewriter.createOperation(state);
|
||||
Operation *newOp =
|
||||
rewriter.create(op->getLoc(), op->getName().getIdentifier(),
|
||||
newOperands, newVecType, op->getAttrs());
|
||||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
|
||||
newOp->getResult(0));
|
||||
return success();
|
||||
|
@ -1028,9 +1028,8 @@ struct ReorderCastOpsOnBroadcast
|
||||
Type castResTy = getElementTypeOrSelf(op->getResult(0));
|
||||
if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
|
||||
castResTy = VectorType::get(vecTy.getShape(), castResTy);
|
||||
OperationState state(op->getLoc(), op->getName(), bcastOp.source(),
|
||||
castResTy, op->getAttrs());
|
||||
auto castOp = rewriter.createOperation(state);
|
||||
auto castOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(),
|
||||
bcastOp.source(), castResTy, op->getAttrs());
|
||||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
|
||||
op, op->getResult(0).getType(), castOp->getResult(0));
|
||||
return success();
|
||||
@ -1068,9 +1067,8 @@ struct ReorderCastOpsOnTranspose
|
||||
auto castResTy = transpOp.getVectorType();
|
||||
castResTy = VectorType::get(castResTy.getShape(),
|
||||
getElementTypeOrSelf(op->getResult(0)));
|
||||
OperationState state(op->getLoc(), op->getName(), transpOp.vector(),
|
||||
castResTy, op->getAttrs());
|
||||
auto castOp = rewriter.createOperation(state);
|
||||
auto castOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(),
|
||||
transpOp.vector(), castResTy, op->getAttrs());
|
||||
rewriter.replaceOpWithNewOp<vector::TransposeOp>(
|
||||
op, op->getResult(0).getType(), castOp->getResult(0),
|
||||
transpOp.getTransp());
|
||||
|
@ -70,8 +70,8 @@ static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
|
||||
Operation *op,
|
||||
ArrayRef<Value> operands,
|
||||
ArrayRef<Type> resultTypes) {
|
||||
OperationState res(loc, op->getName(), operands, resultTypes, op->getAttrs());
|
||||
return builder.createOperation(res);
|
||||
return builder.create(loc, op->getName().getIdentifier(), operands,
|
||||
resultTypes, op->getAttrs());
|
||||
}
|
||||
|
||||
/// Return the target shape for unrolling for the given `op`. Return llvm::None
|
||||
|
@ -377,10 +377,21 @@ Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes,
|
||||
}
|
||||
|
||||
/// Create an operation given the fields represented as an OperationState.
|
||||
Operation *OpBuilder::createOperation(const OperationState &state) {
|
||||
Operation *OpBuilder::create(const OperationState &state) {
|
||||
return insert(Operation::create(state));
|
||||
}
|
||||
|
||||
/// Creates an operation with the given fields.
|
||||
Operation *OpBuilder::create(Location loc, StringAttr opName,
|
||||
ValueRange operands, TypeRange types,
|
||||
ArrayRef<NamedAttribute> attributes,
|
||||
BlockRange successors,
|
||||
MutableArrayRef<std::unique_ptr<Region>> regions) {
|
||||
OperationState state(loc, opName, operands, types, attributes, successors,
|
||||
regions);
|
||||
return create(state);
|
||||
}
|
||||
|
||||
/// Attempts to fold the given operation and places new results within
|
||||
/// 'results'. Returns success if the operation was folded, failure otherwise.
|
||||
/// Note: This function does not erase the operation on a successful fold.
|
||||
|
@ -1151,7 +1151,7 @@ Operation *OperationParser::parseGenericOperation() {
|
||||
return nullptr;
|
||||
|
||||
// Create the operation and try to parse a location for it.
|
||||
Operation *op = opBuilder.createOperation(result);
|
||||
Operation *op = opBuilder.create(result);
|
||||
if (parseTrailingLocationSpecifier(op))
|
||||
return nullptr;
|
||||
return op;
|
||||
@ -1756,7 +1756,7 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
|
||||
return nullptr;
|
||||
|
||||
// Otherwise, create the operation and try to parse a location for it.
|
||||
Operation *op = opBuilder.createOperation(opState);
|
||||
Operation *op = opBuilder.create(opState);
|
||||
if (parseTrailingLocationSpecifier(op))
|
||||
return nullptr;
|
||||
return op;
|
||||
|
@ -1552,7 +1552,7 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
|
||||
break;
|
||||
}
|
||||
|
||||
Operation *resultOp = rewriter.createOperation(state);
|
||||
Operation *resultOp = rewriter.create(state);
|
||||
memory[memIndex] = resultOp;
|
||||
|
||||
LLVM_DEBUG({
|
||||
|
@ -509,12 +509,12 @@ Value Importer::processValue(llvm::Value *value) {
|
||||
// We don't expect to see instructions in dominator order. If we haven't seen
|
||||
// this instruction yet, create an unknown op and remap it later.
|
||||
if (isa<llvm::Instruction>(value)) {
|
||||
OperationState state(UnknownLoc::get(context), "llvm.unknown");
|
||||
Type type = processType(value->getType());
|
||||
if (!type)
|
||||
return nullptr;
|
||||
state.addTypes(type);
|
||||
unknownInstMap[value] = b.createOperation(state);
|
||||
unknownInstMap[value] =
|
||||
b.create(UnknownLoc::get(context), b.getStringAttr("llvm.unknown"),
|
||||
/*operands=*/{}, type);
|
||||
return unknownInstMap[value]->getResult(0);
|
||||
}
|
||||
|
||||
@ -705,7 +705,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
|
||||
return failure();
|
||||
state.addTypes(type);
|
||||
}
|
||||
Operation *op = b.createOperation(state);
|
||||
Operation *op = b.create(state);
|
||||
if (!inst->getType()->isVoidTy())
|
||||
v = op->getResult(0);
|
||||
return success();
|
||||
@ -747,7 +747,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
|
||||
b.getI32VectorAttr(operandSegmentSizes));
|
||||
}
|
||||
|
||||
b.createOperation(state);
|
||||
b.create(state);
|
||||
return success();
|
||||
}
|
||||
case llvm::Instruction::PHI: {
|
||||
|
@ -283,7 +283,7 @@ LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
|
||||
if (hasResult)
|
||||
opState.addTypes(resultTypes);
|
||||
opState.addAttributes(attributes);
|
||||
Operation *op = opBuilder.createOperation(opState);
|
||||
Operation *op = opBuilder.create(opState);
|
||||
if (hasResult)
|
||||
valueMap[valueID] = op->getResult(0);
|
||||
|
||||
|
@ -789,7 +789,7 @@ ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
|
||||
if (failed(parseOpNameInfo))
|
||||
return failure();
|
||||
|
||||
StringRef innerOpName = parseOpNameInfo->getStringRef();
|
||||
StringAttr innerOpName = parseOpNameInfo->getIdentifier();
|
||||
|
||||
FunctionType opFntype;
|
||||
Optional<Location> explicitLoc;
|
||||
@ -823,12 +823,8 @@ ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
|
||||
OpBuilder builder(parser.getBuilder().getContext());
|
||||
builder.setInsertionPointToStart(&block);
|
||||
|
||||
OperationState innerOpState(opLoc, innerOpName);
|
||||
innerOpState.operands.push_back(lhs);
|
||||
innerOpState.operands.push_back(rhs);
|
||||
innerOpState.addTypes(innerOpType);
|
||||
|
||||
Operation *innerOp = builder.createOperation(innerOpState);
|
||||
Operation *innerOp =
|
||||
builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType);
|
||||
|
||||
// Insert a return statement in the block returning the inner-op's result.
|
||||
builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
|
||||
|
@ -168,7 +168,7 @@ static void invokeCreateWithInferredReturnType(Operation *op) {
|
||||
OperationState state(location, OpTy::getOperationName());
|
||||
// TODO: Expand to regions.
|
||||
OpTy::build(b, state, values, op->getAttrs());
|
||||
(void)b.createOperation(state);
|
||||
(void)b.create(state);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -295,7 +295,7 @@ struct TestRegionRewriteUndo : public RewritePattern {
|
||||
// Create the region operation with an entry block containing arguments.
|
||||
OperationState newRegion(op->getLoc(), "test.region");
|
||||
newRegion.addRegion();
|
||||
auto *regionOp = rewriter.createOperation(newRegion);
|
||||
auto *regionOp = rewriter.create(newRegion);
|
||||
auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0));
|
||||
entryBlock->addArgument(rewriter.getIntegerType(64),
|
||||
rewriter.getUnknownLoc());
|
||||
|
@ -38,9 +38,10 @@ customMultiEntityVariadicConstraint(ArrayRef<PDLValue> values,
|
||||
// Custom creator invoked from PDL.
|
||||
static void customCreate(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
|
||||
PDLResultList &results) {
|
||||
results.push_back(rewriter.createOperation(
|
||||
results.push_back(rewriter.create(
|
||||
OperationState(args[0].cast<Operation *>()->getLoc(), "test.success")));
|
||||
}
|
||||
|
||||
static void customVariadicResultCreate(ArrayRef<PDLValue> args,
|
||||
PatternRewriter &rewriter,
|
||||
PDLResultList &results) {
|
||||
@ -59,7 +60,7 @@ static void customRewriter(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
|
||||
Operation *root = args[0].cast<Operation *>();
|
||||
OperationState successOpState(root->getLoc(), "test.success");
|
||||
successOpState.addOperands(args[1].cast<Value>());
|
||||
rewriter.createOperation(successOpState);
|
||||
rewriter.create(successOpState);
|
||||
rewriter.eraseOp(root);
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user