Revert "[mlir] FunctionOpInterface: arg and result attrs dispatch to interface"

and "[flang] Fix flang after MLIR update"

This reverts commit dd74e6b6f4 and
1897b67ae8 due to ongoing test failures on flang
bots e.g. https://lab.llvm.org/buildbot/#/builders/179/builds/5050
This commit is contained in:
David Spickett 2022-12-09 14:49:44 +00:00
parent 05ff7606c9
commit f3379feabe
37 changed files with 282 additions and 383 deletions

View File

@ -501,8 +501,7 @@ public:
// correctly.
for (auto e : llvm::enumerate(funcTy.getInputs())) {
unsigned index = e.index();
llvm::ArrayRef<mlir::NamedAttribute> attrs =
mlir::function_interface_impl::getArgAttrs(func, index);
llvm::ArrayRef<mlir::NamedAttribute> attrs = func.getArgAttrs(index);
for (mlir::NamedAttribute attr : attrs) {
savedAttrs.push_back({index, attr});
}

View File

@ -134,9 +134,7 @@ def FuncOp : Toy_Op<"func", [
let arguments = (ins
SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$function_type,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs
TypeAttrOf<FunctionType>:$function_type
);
let regions = (region AnyRegion:$body);

View File

@ -219,9 +219,8 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
void FuncOp::print(mlir::OpAsmPrinter &p) {
// Dispatch to the FunctionOpInterface provided utility method that prints the
// function operation.
mlir::function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
mlir::function_interface_impl::printFunctionOp(p, *this,
/*isVariadic=*/false);
}
//===----------------------------------------------------------------------===//

View File

@ -133,9 +133,7 @@ def FuncOp : Toy_Op<"func", [
let arguments = (ins
SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$function_type,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs
TypeAttrOf<FunctionType>:$function_type
);
let regions = (region AnyRegion:$body);

View File

@ -206,9 +206,8 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
void FuncOp::print(mlir::OpAsmPrinter &p) {
// Dispatch to the FunctionOpInterface provided utility method that prints the
// function operation.
mlir::function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
mlir::function_interface_impl::printFunctionOp(p, *this,
/*isVariadic=*/false);
}
//===----------------------------------------------------------------------===//

View File

@ -163,9 +163,7 @@ def FuncOp : Toy_Op<"func", [
let arguments = (ins
SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$function_type,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs
TypeAttrOf<FunctionType>:$function_type
);
let regions = (region AnyRegion:$body);

View File

@ -295,9 +295,8 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
void FuncOp::print(mlir::OpAsmPrinter &p) {
// Dispatch to the FunctionOpInterface provided utility method that prints the
// function operation.
mlir::function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
mlir::function_interface_impl::printFunctionOp(p, *this,
/*isVariadic=*/false);
}
/// Returns the region on the function operation that is callable.

View File

@ -163,9 +163,7 @@ def FuncOp : Toy_Op<"func", [
let arguments = (ins
SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$function_type,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs
TypeAttrOf<FunctionType>:$function_type
);
let regions = (region AnyRegion:$body);

View File

@ -295,9 +295,8 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
void FuncOp::print(mlir::OpAsmPrinter &p) {
// Dispatch to the FunctionOpInterface provided utility method that prints the
// function operation.
mlir::function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
mlir::function_interface_impl::printFunctionOp(p, *this,
/*isVariadic=*/false);
}
/// Returns the region on the function operation that is callable.

View File

@ -163,9 +163,7 @@ def FuncOp : Toy_Op<"func", [
let arguments = (ins
SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$function_type,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs
TypeAttrOf<FunctionType>:$function_type
);
let regions = (region AnyRegion:$body);

View File

@ -295,9 +295,8 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
void FuncOp::print(mlir::OpAsmPrinter &p) {
// Dispatch to the FunctionOpInterface provided utility method that prints the
// function operation.
mlir::function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
mlir::function_interface_impl::printFunctionOp(p, *this,
/*isVariadic=*/false);
}
/// Returns the region on the function operation that is callable.

View File

@ -186,9 +186,7 @@ def FuncOp : Toy_Op<"func", [
let arguments = (ins
SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$function_type,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs
TypeAttrOf<FunctionType>:$function_type
);
let regions = (region AnyRegion:$body);

View File

@ -322,9 +322,8 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
void FuncOp::print(mlir::OpAsmPrinter &p) {
// Dispatch to the FunctionOpInterface provided utility method that prints the
// function operation.
mlir::function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
mlir::function_interface_impl::printFunctionOp(p, *this,
/*isVariadic=*/false);
}
/// Returns the region on the function operation that is callable.

View File

@ -140,9 +140,7 @@ def Async_FuncOp : Async_Op<"func",
let arguments = (ins SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$function_type,
OptionalAttr<StrAttr>:$sym_visibility,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs);
OptionalAttr<StrAttr>:$sym_visibility);
let regions = (region AnyRegion:$body);

View File

@ -251,9 +251,7 @@ def FuncOp : Func_Op<"func", [
let arguments = (ins SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$function_type,
OptionalAttr<StrAttr>:$sym_visibility,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs);
OptionalAttr<StrAttr>:$sym_visibility);
let regions = (region AnyRegion:$body);
let builders = [OpBuilder<(ins

View File

@ -242,9 +242,7 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
attribution.
}];
let arguments = (ins TypeAttrOf<FunctionType>:$function_type,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs);
let arguments = (ins TypeAttrOf<FunctionType>:$function_type);
let regions = (region AnyRegion:$body);
let skipDefaultBuilders = 1;

View File

@ -1311,9 +1311,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
DefaultValuedAttr<CConv, "CConv::C">:$CConv,
OptionalAttr<FlatSymbolRefAttr>:$personality,
OptionalAttr<StrAttr>:$garbageCollector,
OptionalAttr<ArrayAttr>:$passthrough,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs
OptionalAttr<ArrayAttr>:$passthrough
);
let regions = (region AnyRegion:$body);

View File

@ -52,8 +52,6 @@ def MLProgram_FuncOp : MLProgram_Op<"func", [
let arguments = (ins SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$function_type,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs,
OptionalAttr<StrAttr>:$sym_visibility);
let regions = (region AnyRegion:$body);
@ -403,8 +401,6 @@ def MLProgram_SubgraphOp : MLProgram_Op<"subgraph", [
let arguments = (ins SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$function_type,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs,
OptionalAttr<StrAttr>:$sym_visibility);
let regions = (region AnyRegion:$body);

View File

@ -652,9 +652,7 @@ def PDLInterp_FuncOp : PDLInterp_Op<"func", [
let arguments = (ins
SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$function_type,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs
TypeAttrOf<FunctionType>:$function_type
);
let regions = (region MinSizedRegion<1>:$body);

View File

@ -291,8 +291,6 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [
let arguments = (ins
TypeAttrOf<FunctionType>:$function_type,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs,
StrAttr:$sym_name,
SPIRV_FunctionControlAttr:$function_control
);

View File

@ -1107,8 +1107,6 @@ def Shape_FuncOp : Shape_Op<"func",
let arguments = (ins SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$function_type,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs,
OptionalAttr<StrAttr>:$sym_visibility);
let regions = (region AnyRegion:$body);

View File

@ -39,12 +39,10 @@ private:
/// with special names given by getResultAttrName, getArgumentAttrName.
void addArgAndResultAttrs(Builder &builder, OperationState &result,
ArrayRef<DictionaryAttr> argAttrs,
ArrayRef<DictionaryAttr> resultAttrs,
StringAttr argAttrsName, StringAttr resAttrsName);
ArrayRef<DictionaryAttr> resultAttrs);
void addArgAndResultAttrs(Builder &builder, OperationState &result,
ArrayRef<OpAsmParser::Argument> args,
ArrayRef<DictionaryAttr> resultAttrs,
StringAttr argAttrsName, StringAttr resAttrsName);
ArrayRef<OpAsmParser::Argument> argAttrs,
ArrayRef<DictionaryAttr> resultAttrs);
/// Callback type for `parseFunctionOp`, the callback should produce the
/// type that will be associated with a function-like operation from lists of
@ -79,17 +77,15 @@ Type getFunctionType(Builder &builder, ArrayRef<OpAsmParser::Argument> argAttrs,
/// type, report the error or delegate the reporting to the op's verifier.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result,
bool allowVariadic, StringAttr typeAttrName,
FuncTypeBuilder funcTypeBuilder,
StringAttr argAttrsName, StringAttr resAttrsName);
FuncTypeBuilder funcTypeBuilder);
/// Printer implementation for function-like operations.
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
StringRef typeAttrName, StringAttr argAttrsName,
StringAttr resAttrsName);
StringRef typeAttrName);
/// Prints the signature of the function-like operation `op`. Assumes `op` has
/// is a FunctionOpInterface and has passed verification.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op,
void printFunctionSignature(OpAsmPrinter &p, Operation *op,
ArrayRef<Type> argTypes, bool isVariadic,
ArrayRef<Type> resultTypes);

View File

@ -26,30 +26,48 @@ class FunctionOpInterface;
namespace function_interface_impl {
/// Return the name of the attribute used for function argument attributes.
inline StringRef getArgDictAttrName() { return "arg_attrs"; }
/// Return the name of the attribute used for function argument attributes.
inline StringRef getResultDictAttrName() { return "res_attrs"; }
/// Returns the dictionary attribute corresponding to the argument at 'index'.
/// If there are no argument attributes at 'index', a null attribute is
/// returned.
DictionaryAttr getArgAttrDict(FunctionOpInterface op, unsigned index);
DictionaryAttr getArgAttrDict(Operation *op, unsigned index);
/// Returns the dictionary attribute corresponding to the result at 'index'.
/// If there are no result attributes at 'index', a null attribute is
/// returned.
DictionaryAttr getResultAttrDict(FunctionOpInterface op, unsigned index);
DictionaryAttr getResultAttrDict(Operation *op, unsigned index);
/// Return all of the attributes for the argument at 'index'.
ArrayRef<NamedAttribute> getArgAttrs(FunctionOpInterface op, unsigned index);
/// Return all of the attributes for the result at 'index'.
ArrayRef<NamedAttribute> getResultAttrs(FunctionOpInterface op, unsigned index);
namespace detail {
/// Update the given index into an argument or result attribute dictionary.
void setArgResAttrDict(Operation *op, StringRef attrName,
unsigned numTotalIndices, unsigned index,
DictionaryAttr attrs);
} // namespace detail
/// Set all of the argument or result attribute dictionaries for a function. The
/// size of `attrs` is expected to match the number of arguments/results of the
/// given `op`.
void setAllArgAttrDicts(FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs);
void setAllArgAttrDicts(FunctionOpInterface op, ArrayRef<Attribute> attrs);
void setAllResultAttrDicts(FunctionOpInterface op,
ArrayRef<DictionaryAttr> attrs);
void setAllResultAttrDicts(FunctionOpInterface op, ArrayRef<Attribute> attrs);
void setAllArgAttrDicts(Operation *op, ArrayRef<DictionaryAttr> attrs);
void setAllArgAttrDicts(Operation *op, ArrayRef<Attribute> attrs);
void setAllResultAttrDicts(Operation *op, ArrayRef<DictionaryAttr> attrs);
void setAllResultAttrDicts(Operation *op, ArrayRef<Attribute> attrs);
/// Return all of the attributes for the argument at 'index'.
inline ArrayRef<NamedAttribute> getArgAttrs(Operation *op, unsigned index) {
auto argDict = getArgAttrDict(op, index);
return argDict ? argDict.getValue() : std::nullopt;
}
/// Return all of the attributes for the result at 'index'.
inline ArrayRef<NamedAttribute> getResultAttrs(Operation *op, unsigned index) {
auto resultDict = getResultAttrDict(op, index);
return resultDict ? resultDict.getValue() : std::nullopt;
}
/// Insert the specified arguments and update the function type attribute.
void insertFunctionArguments(FunctionOpInterface op,
@ -92,10 +110,20 @@ TypeRange filterTypesOut(TypeRange types, const BitVector &indices,
//===----------------------------------------------------------------------===//
/// Set the attributes held by the argument at 'index'.
void setArgAttrs(FunctionOpInterface op, unsigned index,
ArrayRef<NamedAttribute> attributes);
void setArgAttrs(FunctionOpInterface op, unsigned index,
DictionaryAttr attributes);
template <typename ConcreteType>
void setArgAttrs(ConcreteType op, unsigned index,
ArrayRef<NamedAttribute> attributes) {
assert(index < op.getNumArguments() && "invalid argument number");
return detail::setArgResAttrDict(
op, getArgDictAttrName(), op.getNumArguments(), index,
DictionaryAttr::get(op->getContext(), attributes));
}
template <typename ConcreteType>
void setArgAttrs(ConcreteType op, unsigned index, DictionaryAttr attributes) {
return detail::setArgResAttrDict(
op, getArgDictAttrName(), op.getNumArguments(), index,
attributes ? attributes : DictionaryAttr::get(op->getContext()));
}
/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
@ -129,10 +157,23 @@ Attribute removeArgAttr(ConcreteType op, unsigned index, StringAttr name) {
//===----------------------------------------------------------------------===//
/// Set the attributes held by the result at 'index'.
void setResultAttrs(FunctionOpInterface op, unsigned index,
ArrayRef<NamedAttribute> attributes);
void setResultAttrs(FunctionOpInterface op, unsigned index,
DictionaryAttr attributes);
template <typename ConcreteType>
void setResultAttrs(ConcreteType op, unsigned index,
ArrayRef<NamedAttribute> attributes) {
assert(index < op.getNumResults() && "invalid result number");
return detail::setArgResAttrDict(
op, getResultDictAttrName(), op.getNumResults(), index,
DictionaryAttr::get(op->getContext(), attributes));
}
template <typename ConcreteType>
void setResultAttrs(ConcreteType op, unsigned index,
DictionaryAttr attributes) {
assert(index < op.getNumResults() && "invalid result number");
return detail::setArgResAttrDict(
op, getResultDictAttrName(), op.getNumResults(), index,
attributes ? attributes : DictionaryAttr::get(op->getContext()));
}
/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
@ -172,8 +213,9 @@ LogicalResult verifyTrait(ConcreteOp op) {
unsigned numArgs = op.getNumArguments();
if (allArgAttrs.size() != numArgs) {
return op.emitOpError()
<< "expects argument attribute array to have the same number of "
"elements as the number of function arguments, got "
<< "expects argument attribute array `" << getArgDictAttrName()
<< "` to have the same number of elements as the number of "
"function arguments, got "
<< allArgAttrs.size() << ", but expected " << numArgs;
}
for (unsigned i = 0; i != numArgs; ++i) {
@ -203,8 +245,9 @@ LogicalResult verifyTrait(ConcreteOp op) {
unsigned numResults = op.getNumResults();
if (allResultAttrs.size() != numResults) {
return op.emitOpError()
<< "expects result attribute array to have the same number of "
"elements as the number of function results, got "
<< "expects result attribute array `" << getResultDictAttrName()
<< "` to have the same number of elements as the number of "
"function results, got "
<< allResultAttrs.size() << ", but expected " << numResults;
}
for (unsigned i = 0; i != numResults; ++i) {

View File

@ -59,42 +59,6 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
result attributes.
}],
"void", "setFunctionTypeAttr", (ins "::mlir::TypeAttr":$type)>,
InterfaceMethod<[{
Get the array of argument attribute dictionaries. The method should return
an array attribute containing only dictionary attributes equal in number
to the number of function arguments. Alternatively, the method can return
null to indicate that the function has no argument attributes.
}],
"::mlir::ArrayAttr", "getArgAttrsAttr">,
InterfaceMethod<[{
Get the array of result attribute dictionaries. The method should return
an array attribute containing only dictionary attributes equal in number
to the number of function results. Alternatively, the method can return
null to indicate that the function has no result attributes.
}],
"::mlir::ArrayAttr", "getResAttrsAttr">,
InterfaceMethod<[{
Set the array of argument attribute dictionaries.
}],
"void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>,
InterfaceMethod<[{
Set the array of result attribute dictionaries.
}],
"void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>,
InterfaceMethod<[{
Remove the array of argument attribute dictionaries. This is the same as
setting all argument attributes to an empty dictionary. The method should
return the removed attribute.
}],
"::mlir::Attribute", "removeArgAttrsAttr">,
InterfaceMethod<[{
Remove the array of result attribute dictionaries. This is the same as
setting all result attributes to an empty dictionary. The method should
return the removed attribute.
}],
"::mlir::Attribute", "removeResAttrsAttr">,
InterfaceMethod<[{
Returns the function argument types based exclusively on
the type (to allow for this method may be called on function
@ -286,6 +250,20 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
function_interface_impl::setFunctionType(this->getOperation(), newType);
}
// FIXME: These functions should be removed in favor of just forwarding to
// the derived operation, which should already have these defined
// (via ODS).
/// Returns the name of the attribute used for function argument attributes.
static StringRef getArgDictAttrName() {
return function_interface_impl::getArgDictAttrName();
}
/// Returns the name of the attribute used for function argument attributes.
static StringRef getResultDictAttrName() {
return function_interface_impl::getResultDictAttrName();
}
//===------------------------------------------------------------------===//
// Argument and Result Handling
//===------------------------------------------------------------------===//
@ -427,8 +405,10 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
/// Return an ArrayAttr containing all argument attribute dictionaries of
/// this function, or nullptr if no arguments have attributes.
ArrayAttr getAllArgAttrs() { return $_op.getArgAttrsAttr(); }
ArrayAttr getAllArgAttrs() {
return this->getOperation()->template getAttrOfType<ArrayAttr>(
getArgDictAttrName());
}
/// Return all argument attributes of this function.
void getAllArgAttrs(SmallVectorImpl<DictionaryAttr> &result) {
if (ArrayAttr argAttrs = getAllArgAttrs()) {
@ -480,7 +460,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
}
void setAllArgAttrs(ArrayAttr attributes) {
assert(attributes.size() == $_op.getNumArguments());
$_op.setArgAttrsAttr(attributes);
this->getOperation()->setAttr(getArgDictAttrName(), attributes);
}
/// If the an attribute exists with the specified name, change it to the new
@ -516,8 +496,10 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
/// Return an ArrayAttr containing all result attribute dictionaries of this
/// function, or nullptr if no result have attributes.
ArrayAttr getAllResultAttrs() { return $_op.getResAttrsAttr(); }
ArrayAttr getAllResultAttrs() {
return this->getOperation()->template getAttrOfType<ArrayAttr>(
getResultDictAttrName());
}
/// Return all result attributes of this function.
void getAllResultAttrs(SmallVectorImpl<DictionaryAttr> &result) {
if (ArrayAttr argAttrs = getAllResultAttrs()) {
@ -571,7 +553,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
}
void setAllResultAttrs(ArrayAttr attributes) {
assert(attributes.size() == $_op.getNumResults());
$_op.setResAttrsAttr(attributes);
this->getOperation()->setAttr(getResultDictAttrName(), attributes);
}
/// If the an attribute exists with the specified name, change it to the new

View File

@ -1524,8 +1524,6 @@ def TypeArrayAttr : TypedArrayAttrBase<TypeAttr, "type array attribute"> {
}
def IndexListArrayAttr :
TypedArrayAttrBase<I64ArrayAttr, "Array of 64-bit integer array attributes">;
def DictArrayAttr :
TypedArrayAttrBase<DictionaryAttr, "Array of dictionary attributes">;
// Attributes containing symbol references.
def SymbolRefAttr : Attr<CPred<"$_self.isa<::mlir::SymbolRefAttr>()">,

View File

@ -66,8 +66,8 @@ static void filterFuncAttributes(func::FuncOp func, bool filterArgAndResAttrs,
attr.getName() == func.getFunctionTypeAttrName() ||
attr.getName() == "func.varargs" ||
(filterArgAndResAttrs &&
(attr.getName() == func.getArgAttrsAttrName() ||
attr.getName() == func.getResAttrsAttrName())))
(attr.getName() == FunctionOpInterface::getArgDictAttrName() ||
attr.getName() == FunctionOpInterface::getResultDictAttrName())))
continue;
result.push_back(attr);
}
@ -90,19 +90,18 @@ static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
static void
prependResAttrsToArgAttrs(OpBuilder &builder,
SmallVectorImpl<NamedAttribute> &attributes,
func::FuncOp func) {
size_t numArguments = func.getNumArguments();
size_t numArguments) {
auto allAttrs = SmallVector<Attribute>(
numArguments + 1, DictionaryAttr::get(builder.getContext()));
NamedAttribute *argAttrs = nullptr;
for (auto *it = attributes.begin(); it != attributes.end();) {
if (it->getName() == func.getArgAttrsAttrName()) {
if (it->getName() == FunctionOpInterface::getArgDictAttrName()) {
auto arrayAttrs = it->getValue().cast<ArrayAttr>();
assert(arrayAttrs.size() == numArguments &&
"Number of arg attrs and args should match");
std::copy(arrayAttrs.begin(), arrayAttrs.end(), allAttrs.begin() + 1);
argAttrs = it;
} else if (it->getName() == func.getResAttrsAttrName()) {
} else if (it->getName() == FunctionOpInterface::getResultDictAttrName()) {
auto arrayAttrs = it->getValue().cast<ArrayAttr>();
assert(!arrayAttrs.empty() && "expected array to be non-empty");
allAttrs[0] = (arrayAttrs.size() == 1)
@ -114,8 +113,9 @@ prependResAttrsToArgAttrs(OpBuilder &builder,
it++;
}
auto newArgAttrs = builder.getNamedAttr(func.getArgAttrsAttrName(),
builder.getArrayAttr(allAttrs));
auto newArgAttrs =
builder.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
builder.getArrayAttr(allAttrs));
if (!argAttrs) {
attributes.emplace_back(newArgAttrs);
return;
@ -141,7 +141,7 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
auto [wrapperFuncType, resultIsNowArg] =
typeConverter.convertFunctionTypeCWrapper(type);
if (resultIsNowArg)
prependResAttrsToArgAttrs(rewriter, attributes, funcOp);
prependResAttrsToArgAttrs(rewriter, attributes, funcOp.getNumArguments());
auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false,
@ -205,7 +205,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes);
if (resultIsNowArg)
prependResAttrsToArgAttrs(builder, attributes, funcOp);
prependResAttrsToArgAttrs(builder, attributes, funcOp.getNumArguments());
// Create the auxiliary function.
auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
@ -309,8 +309,8 @@ protected:
? resAttrDicts
: rewriter.getArrayAttr(
{wrapAsStructAttrs(rewriter, resAttrDicts)});
attributes.push_back(
rewriter.getNamedAttr(funcOp.getResAttrsAttrName(), newResAttrDicts));
attributes.push_back(rewriter.getNamedAttr(
FunctionOpInterface::getResultDictAttrName(), newResAttrDicts));
}
if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
SmallVector<Attribute, 4> newArgAttrs(
@ -353,8 +353,9 @@ protected:
newArgAttrs[mapping->inputNo + j] =
DictionaryAttr::get(rewriter.getContext(), convertedAttrs);
}
attributes.push_back(rewriter.getNamedAttr(
funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(newArgAttrs)));
attributes.push_back(
rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
rewriter.getArrayAttr(newArgAttrs)));
}
for (const auto &pair : llvm::enumerate(attributes)) {
if (pair.value().getName() == "llvm.linkage") {

View File

@ -340,9 +340,8 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
if (argAttrs.empty())
return;
assert(type.getNumInputs() == argAttrs.size());
function_interface_impl::addArgAndResultAttrs(
builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
/*resultAttrs=*/std::nullopt);
}
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
@ -353,14 +352,12 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
return function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false,
getFunctionTypeAttrName(result.name), buildFuncType,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
getFunctionTypeAttrName(result.name), buildFuncType);
}
void FuncOp::print(OpAsmPrinter &p) {
function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
getFunctionTypeAttrName());
}
/// Check that the result type of async.func is not void and must be

View File

@ -251,9 +251,8 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
if (argAttrs.empty())
return;
assert(type.getNumInputs() == argAttrs.size());
function_interface_impl::addArgAndResultAttrs(
builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
/*resultAttrs=*/std::nullopt);
}
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
@ -264,14 +263,12 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
return function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false,
getFunctionTypeAttrName(result.name), buildFuncType,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
getFunctionTypeAttrName(result.name), buildFuncType);
}
void FuncOp::print(OpAsmPrinter &p) {
function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
getFunctionTypeAttrName());
}
/// Clone the internal blocks from this function into dest and all attributes

View File

@ -934,9 +934,8 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
result.addAttribute(getFunctionTypeAttrName(result.name),
TypeAttr::get(type));
function_interface_impl::addArgAndResultAttrs(
builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
getResAttrsAttrName(result.name));
function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs,
resultAttrs);
// Parse workgroup memory attributions.
if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
@ -997,8 +996,7 @@ void GPUFuncOp::print(OpAsmPrinter &p) {
function_interface_impl::printFunctionAttributes(
p, *this,
{getNumWorkgroupAttributionsAttrName(),
GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName()});
GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName()});
p << ' ';
p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}

View File

@ -2006,9 +2006,8 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() &&
"expected as many argument attribute lists as arguments");
function_interface_impl::addArgAndResultAttrs(
builder, result, argAttrs, /*resultAttrs=*/std::nullopt,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs,
/*resultAttrs=*/std::nullopt);
}
// Builds an LLVM function type from the given lists of input and output types.
@ -2096,9 +2095,8 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
return failure();
function_interface_impl::addArgAndResultAttrs(
parser.getBuilder(), result, entryArgs, resultAttrs,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result,
entryArgs, resultAttrs);
auto *body = result.addRegion();
OptionalParseResult parseResult =
@ -2133,8 +2131,7 @@ void LLVMFuncOp::print(OpAsmPrinter &p) {
isVarArg(), resTypes);
function_interface_impl::printFunctionAttributes(
p, *this,
{getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
getLinkageAttrName(), getCConvAttrName()});
{getFunctionTypeAttrName(), getLinkageAttrName(), getCConvAttrName()});
// Print the body if this is not an external function.
Region &body = getBody();

View File

@ -153,14 +153,12 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
return function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false,
getFunctionTypeAttrName(result.name), buildFuncType,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
getFunctionTypeAttrName(result.name), buildFuncType);
}
void FuncOp::print(OpAsmPrinter &p) {
function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
getFunctionTypeAttrName());
}
//===----------------------------------------------------------------------===//
@ -318,14 +316,12 @@ ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
return function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false,
getFunctionTypeAttrName(result.name), buildFuncType,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
getFunctionTypeAttrName(result.name), buildFuncType);
}
void SubgraphOp::print(OpAsmPrinter &p) {
function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
getFunctionTypeAttrName());
}
//===----------------------------------------------------------------------===//

View File

@ -221,14 +221,12 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
return function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false,
getFunctionTypeAttrName(result.name), buildFuncType,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
getFunctionTypeAttrName(result.name), buildFuncType);
}
void FuncOp::print(OpAsmPrinter &p) {
function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
getFunctionTypeAttrName());
}
//===----------------------------------------------------------------------===//

View File

@ -2396,9 +2396,8 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
// Add the attributes to the function arguments.
assert(resultAttrs.size() == resultTypes.size());
function_interface_impl::addArgAndResultAttrs(
builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
getResAttrsAttrName(result.name));
function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs,
resultAttrs);
// Parse the optional function body.
auto *body = result.addRegion();
@ -2420,8 +2419,7 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) {
function_interface_impl::printFunctionAttributes(
printer, *this,
{spirv::attributeName<spirv::FunctionControl>(),
getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
getFunctionControlAttrName()});
getFunctionTypeAttrName(), getFunctionControlAttrName()});
// Print the body if this is not an external function.
Region &body = this->getBody();

View File

@ -1300,9 +1300,8 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
if (argAttrs.empty())
return;
assert(type.getNumInputs() == argAttrs.size());
function_interface_impl::addArgAndResultAttrs(
builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
/*resultAttrs=*/std::nullopt);
}
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
@ -1313,14 +1312,12 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
return function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false,
getFunctionTypeAttrName(result.name), buildFuncType,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
getFunctionTypeAttrName(result.name), buildFuncType);
}
void FuncOp::print(OpAsmPrinter &p) {
function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
getFunctionTypeAttrName());
}
//===----------------------------------------------------------------------===//

View File

@ -113,7 +113,7 @@ parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
return parser.parseRParen();
}
ParseResult function_interface_impl::parseFunctionSignature(
ParseResult mlir::function_interface_impl::parseFunctionSignature(
OpAsmParser &parser, bool allowVariadic,
SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
SmallVectorImpl<Type> &resultTypes,
@ -125,10 +125,9 @@ ParseResult function_interface_impl::parseFunctionSignature(
return success();
}
void function_interface_impl::addArgAndResultAttrs(
void mlir::function_interface_impl::addArgAndResultAttrs(
Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
ArrayRef<DictionaryAttr> resultAttrs, StringAttr argAttrsName,
StringAttr resAttrsName) {
ArrayRef<DictionaryAttr> resultAttrs) {
auto nonEmptyAttrsFn = [](DictionaryAttr attrs) {
return attrs && !attrs.empty();
};
@ -143,28 +142,28 @@ void function_interface_impl::addArgAndResultAttrs(
// Add the attributes to the function arguments.
if (llvm::any_of(argAttrs, nonEmptyAttrsFn))
result.addAttribute(argAttrsName, getArrayAttr(argAttrs));
result.addAttribute(function_interface_impl::getArgDictAttrName(),
getArrayAttr(argAttrs));
// Add the attributes to the function results.
if (llvm::any_of(resultAttrs, nonEmptyAttrsFn))
result.addAttribute(resAttrsName, getArrayAttr(resultAttrs));
result.addAttribute(function_interface_impl::getResultDictAttrName(),
getArrayAttr(resultAttrs));
}
void function_interface_impl::addArgAndResultAttrs(
void mlir::function_interface_impl::addArgAndResultAttrs(
Builder &builder, OperationState &result,
ArrayRef<OpAsmParser::Argument> args, ArrayRef<DictionaryAttr> resultAttrs,
StringAttr argAttrsName, StringAttr resAttrsName) {
ArrayRef<OpAsmParser::Argument> args,
ArrayRef<DictionaryAttr> resultAttrs) {
SmallVector<DictionaryAttr> argAttrs;
for (const auto &arg : args)
argAttrs.push_back(arg.attrs);
addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName,
resAttrsName);
addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);
}
ParseResult function_interface_impl::parseFunctionOp(
ParseResult mlir::function_interface_impl::parseFunctionOp(
OpAsmParser &parser, OperationState &result, bool allowVariadic,
StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder,
StringAttr argAttrsName, StringAttr resAttrsName) {
StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder) {
SmallVector<OpAsmParser::Argument> entryArgs;
SmallVector<DictionaryAttr> resultAttrs;
SmallVector<Type> resultTypes;
@ -221,8 +220,7 @@ ParseResult function_interface_impl::parseFunctionOp(
// Add the attributes to the function arguments.
assert(resultAttrs.size() == resultTypes.size());
addArgAndResultAttrs(builder, result, entryArgs, resultAttrs, argAttrsName,
resAttrsName);
addArgAndResultAttrs(builder, result, entryArgs, resultAttrs);
// Parse the optional function body. The printer will not print the body if
// its empty, so disallow parsing of empty body in the parser.
@ -263,14 +261,14 @@ static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types,
os << ')';
}
void function_interface_impl::printFunctionSignature(
OpAsmPrinter &p, FunctionOpInterface op, ArrayRef<Type> argTypes,
bool isVariadic, ArrayRef<Type> resultTypes) {
void mlir::function_interface_impl::printFunctionSignature(
OpAsmPrinter &p, Operation *op, ArrayRef<Type> argTypes, bool isVariadic,
ArrayRef<Type> resultTypes) {
Region &body = op->getRegion(0);
bool isExternal = body.empty();
p << '(';
ArrayAttr argAttrs = op.getArgAttrsAttr();
ArrayAttr argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
if (i > 0)
p << ", ";
@ -297,23 +295,26 @@ void function_interface_impl::printFunctionSignature(
if (!resultTypes.empty()) {
p.getStream() << " -> ";
auto resultAttrs = op.getResAttrsAttr();
auto resultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
printFunctionResultList(p, resultTypes, resultAttrs);
}
}
void function_interface_impl::printFunctionAttributes(
void mlir::function_interface_impl::printFunctionAttributes(
OpAsmPrinter &p, Operation *op, ArrayRef<StringRef> elided) {
// Print out function attributes, if present.
SmallVector<StringRef, 8> ignoredAttrs = {SymbolTable::getSymbolAttrName()};
SmallVector<StringRef, 2> ignoredAttrs = {SymbolTable::getSymbolAttrName(),
getArgDictAttrName(),
getResultDictAttrName()};
ignoredAttrs.append(elided.begin(), elided.end());
p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs);
}
void function_interface_impl::printFunctionOp(
OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName) {
void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p,
FunctionOpInterface op,
bool isVariadic,
StringRef typeAttrName) {
// Print the operation and the function name.
auto funcName =
op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
@ -328,8 +329,7 @@ void function_interface_impl::printFunctionOp(
ArrayRef<Type> argTypes = op.getArgumentTypes();
ArrayRef<Type> resultTypes = op.getResultTypes();
printFunctionSignature(p, op, argTypes, isVariadic, resultTypes);
printFunctionAttributes(
p, op, {visibilityAttrName, typeAttrName, argAttrsName, resAttrsName});
printFunctionAttributes(p, op, {visibilityAttrName, typeAttrName});
// Print the body if this is not an external function.
Region &body = op->getRegion(0);
if (!body.empty()) {

View File

@ -24,104 +24,27 @@ static bool isEmptyAttrDict(Attribute attr) {
return attr.cast<DictionaryAttr>().empty();
}
DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op,
unsigned index) {
ArrayAttr attrs = op.getArgAttrsAttr();
DictionaryAttr mlir::function_interface_impl::getArgAttrDict(Operation *op,
unsigned index) {
ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
DictionaryAttr argAttrs =
attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
return argAttrs;
}
DictionaryAttr
function_interface_impl::getResultAttrDict(FunctionOpInterface op,
unsigned index) {
ArrayAttr attrs = op.getResAttrsAttr();
mlir::function_interface_impl::getResultAttrDict(Operation *op,
unsigned index) {
ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
DictionaryAttr resAttrs =
attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
return resAttrs;
}
ArrayRef<NamedAttribute>
function_interface_impl::getArgAttrs(FunctionOpInterface op, unsigned index) {
auto argDict = getArgAttrDict(op, index);
return argDict ? argDict.getValue() : std::nullopt;
}
ArrayRef<NamedAttribute>
function_interface_impl::getResultAttrs(FunctionOpInterface op,
unsigned index) {
auto resultDict = getResultAttrDict(op, index);
return resultDict ? resultDict.getValue() : std::nullopt;
}
/// Get either the argument or result attributes array.
template <bool isArg>
static ArrayAttr getArgResAttrs(FunctionOpInterface op) {
if constexpr (isArg)
return op.getArgAttrsAttr();
else
return op.getResAttrsAttr();
}
/// Set either the argument or result attributes array.
template <bool isArg>
static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs) {
if constexpr (isArg)
op.setArgAttrsAttr(attrs);
else
op.setResAttrsAttr(attrs);
}
/// Erase either the argument or result attributes array.
template <bool isArg>
static void removeArgResAttrs(FunctionOpInterface op) {
if constexpr (isArg)
op.removeArgAttrsAttr();
else
op.removeResAttrsAttr();
}
/// Set all of the argument or result attribute dictionaries for a function.
template <bool isArg>
static void setAllArgResAttrDicts(FunctionOpInterface op,
ArrayRef<Attribute> attrs) {
if (llvm::all_of(attrs, isEmptyAttrDict))
removeArgResAttrs<isArg>(op);
else
setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), attrs));
}
void function_interface_impl::setAllArgAttrDicts(
FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) {
setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
}
void function_interface_impl::setAllArgAttrDicts(FunctionOpInterface op,
ArrayRef<Attribute> attrs) {
auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
return !attr ? DictionaryAttr::get(op->getContext()) : attr;
});
setAllArgResAttrDicts</*isArg=*/true>(op, llvm::to_vector<8>(wrappedAttrs));
}
void function_interface_impl::setAllResultAttrDicts(
FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) {
setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
}
void function_interface_impl::setAllResultAttrDicts(FunctionOpInterface op,
ArrayRef<Attribute> attrs) {
auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
return !attr ? DictionaryAttr::get(op->getContext()) : attr;
});
setAllArgResAttrDicts</*isArg=*/false>(op, llvm::to_vector<8>(wrappedAttrs));
}
/// Update the given index into an argument or result attribute dictionary.
template <bool isArg>
static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices,
unsigned index, DictionaryAttr attrs) {
ArrayAttr allAttrs = getArgResAttrs<isArg>(op);
void mlir::function_interface_impl::detail::setArgResAttrDict(
Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index,
DictionaryAttr attrs) {
ArrayAttr allAttrs = op->getAttrOfType<ArrayAttr>(attrName);
if (!allAttrs) {
if (attrs.empty())
return;
@ -130,7 +53,7 @@ static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices,
SmallVector<Attribute, 8> newAttrs(numTotalIndices,
DictionaryAttr::get(op->getContext()));
newAttrs[index] = attrs;
setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs));
op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
return;
}
// Check to see if the attribute is different from what we already have.
@ -142,51 +65,53 @@ static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices,
ArrayRef<Attribute> rawAttrArray = allAttrs.getValue();
if (attrs.empty() &&
llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) &&
llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict))
return removeArgResAttrs<isArg>(op);
llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) {
op->removeAttr(attrName);
return;
}
// Otherwise, create a new attribute array with the updated dictionary.
SmallVector<Attribute, 8> newAttrs(rawAttrArray.begin(), rawAttrArray.end());
newAttrs[index] = attrs;
setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs));
op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
}
void function_interface_impl::setArgAttrs(FunctionOpInterface op,
unsigned index,
ArrayRef<NamedAttribute> attributes) {
assert(index < op.getNumArguments() && "invalid argument number");
return setArgResAttrDict</*isArg=*/true>(
op, op.getNumArguments(), index,
DictionaryAttr::get(op->getContext(), attributes));
/// Set all of the argument or result attribute dictionaries for a function.
static void setAllArgResAttrDicts(Operation *op, StringRef attrName,
ArrayRef<Attribute> attrs) {
if (llvm::all_of(attrs, isEmptyAttrDict))
op->removeAttr(attrName);
else
op->setAttr(attrName, ArrayAttr::get(op->getContext(), attrs));
}
void function_interface_impl::setArgAttrs(FunctionOpInterface op,
unsigned index,
DictionaryAttr attributes) {
return setArgResAttrDict</*isArg=*/true>(
op, op.getNumArguments(), index,
attributes ? attributes : DictionaryAttr::get(op->getContext()));
void mlir::function_interface_impl::setAllArgAttrDicts(
Operation *op, ArrayRef<DictionaryAttr> attrs) {
setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
}
void mlir::function_interface_impl::setAllArgAttrDicts(
Operation *op, ArrayRef<Attribute> attrs) {
auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
return !attr ? DictionaryAttr::get(op->getContext()) : attr;
});
setAllArgResAttrDicts(op, getArgDictAttrName(),
llvm::to_vector<8>(wrappedAttrs));
}
void function_interface_impl::setResultAttrs(
FunctionOpInterface op, unsigned index,
ArrayRef<NamedAttribute> attributes) {
assert(index < op.getNumResults() && "invalid result number");
return setArgResAttrDict</*isArg=*/false>(
op, op.getNumResults(), index,
DictionaryAttr::get(op->getContext(), attributes));
void mlir::function_interface_impl::setAllResultAttrDicts(
Operation *op, ArrayRef<DictionaryAttr> attrs) {
setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
}
void mlir::function_interface_impl::setAllResultAttrDicts(
Operation *op, ArrayRef<Attribute> attrs) {
auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
return !attr ? DictionaryAttr::get(op->getContext()) : attr;
});
setAllArgResAttrDicts(op, getResultDictAttrName(),
llvm::to_vector<8>(wrappedAttrs));
}
void function_interface_impl::setResultAttrs(FunctionOpInterface op,
unsigned index,
DictionaryAttr attributes) {
assert(index < op.getNumResults() && "invalid result number");
return setArgResAttrDict</*isArg=*/false>(
op, op.getNumResults(), index,
attributes ? attributes : DictionaryAttr::get(op->getContext()));
}
void function_interface_impl::insertFunctionArguments(
void mlir::function_interface_impl::insertFunctionArguments(
FunctionOpInterface op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Location> argLocs,
unsigned originalNumArgs, Type newType) {
@ -203,7 +128,7 @@ void function_interface_impl::insertFunctionArguments(
Block &entry = op->getRegion(0).front();
// Update the argument attributes of the function.
ArrayAttr oldArgAttrs = op.getArgAttrsAttr();
auto oldArgAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
if (oldArgAttrs || !argAttrs.empty()) {
SmallVector<DictionaryAttr, 4> newArgAttrs;
newArgAttrs.reserve(originalNumArgs + argIndices.size());
@ -232,7 +157,7 @@ void function_interface_impl::insertFunctionArguments(
entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]);
}
void function_interface_impl::insertFunctionResults(
void mlir::function_interface_impl::insertFunctionResults(
FunctionOpInterface op, ArrayRef<unsigned> resultIndices,
TypeRange resultTypes, ArrayRef<DictionaryAttr> resultAttrs,
unsigned originalNumResults, Type newType) {
@ -246,7 +171,7 @@ void function_interface_impl::insertFunctionResults(
// - Result attrs.
// Update the result attributes of the function.
ArrayAttr oldResultAttrs = op.getResAttrsAttr();
auto oldResultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
if (oldResultAttrs || !resultAttrs.empty()) {
SmallVector<DictionaryAttr, 4> newResultAttrs;
newResultAttrs.reserve(originalNumResults + resultIndices.size());
@ -274,7 +199,7 @@ void function_interface_impl::insertFunctionResults(
op.setFunctionTypeAttr(TypeAttr::get(newType));
}
void function_interface_impl::eraseFunctionArguments(
void mlir::function_interface_impl::eraseFunctionArguments(
FunctionOpInterface op, const BitVector &argIndices, Type newType) {
// There are 3 things that need to be updated:
// - Function type.
@ -283,7 +208,7 @@ void function_interface_impl::eraseFunctionArguments(
Block &entry = op->getRegion(0).front();
// Update the argument attributes of the function.
if (ArrayAttr argAttrs = op.getArgAttrsAttr()) {
if (auto argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName())) {
SmallVector<DictionaryAttr, 4> newArgAttrs;
newArgAttrs.reserve(argAttrs.size());
for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
@ -297,14 +222,14 @@ void function_interface_impl::eraseFunctionArguments(
entry.eraseArguments(argIndices);
}
void function_interface_impl::eraseFunctionResults(
void mlir::function_interface_impl::eraseFunctionResults(
FunctionOpInterface op, const BitVector &resultIndices, Type newType) {
// There are 2 things that need to be updated:
// - Function type.
// - Result attrs.
// Update the result attributes of the function.
if (ArrayAttr resAttrs = op.getResAttrsAttr()) {
if (auto resAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName())) {
SmallVector<DictionaryAttr, 4> newResultAttrs;
newResultAttrs.reserve(resAttrs.size());
for (unsigned i = 0, e = resultIndices.size(); i < e; ++i)
@ -317,7 +242,7 @@ void function_interface_impl::eraseFunctionResults(
op.setFunctionTypeAttr(TypeAttr::get(newType));
}
TypeRange function_interface_impl::insertTypesInto(
TypeRange mlir::function_interface_impl::insertTypesInto(
TypeRange oldTypes, ArrayRef<unsigned> indices, TypeRange newTypes,
SmallVectorImpl<Type> &storage) {
assert(indices.size() == newTypes.size() &&
@ -336,7 +261,7 @@ TypeRange function_interface_impl::insertTypesInto(
return storage;
}
TypeRange function_interface_impl::filterTypesOut(
TypeRange mlir::function_interface_impl::filterTypesOut(
TypeRange types, const BitVector &indices, SmallVectorImpl<Type> &storage) {
if (indices.none())
return types;
@ -351,8 +276,8 @@ TypeRange function_interface_impl::filterTypesOut(
// Function type signature.
//===----------------------------------------------------------------------===//
void function_interface_impl::setFunctionType(FunctionOpInterface op,
Type newType) {
void mlir::function_interface_impl::setFunctionType(FunctionOpInterface op,
Type newType) {
unsigned oldNumArgs = op.getNumArguments();
unsigned oldNumResults = op.getNumResults();
op.setFunctionTypeAttr(TypeAttr::get(newType));
@ -360,31 +285,35 @@ void function_interface_impl::setFunctionType(FunctionOpInterface op,
unsigned newNumResults = op.getNumResults();
// Functor used to update the argument and result attributes of the function.
auto updateAttrFn = [&](auto isArg, unsigned oldCount, unsigned newCount) {
constexpr bool isArgVal = std::is_same_v<decltype(isArg), std::true_type>;
auto updateAttrFn = [&](StringRef attrName, unsigned oldCount,
unsigned newCount, auto setAttrFn) {
if (oldCount == newCount)
return;
// The new type has no arguments/results, just drop the attribute.
if (newCount == 0)
return removeArgResAttrs<isArgVal>(op);
ArrayAttr attrs = getArgResAttrs<isArgVal>(op);
if (newCount == 0) {
op->removeAttr(attrName);
return;
}
ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(attrName);
if (!attrs)
return;
// The new type has less arguments/results, take the first N attributes.
if (newCount < oldCount)
return setAllArgResAttrDicts<isArgVal>(
op, attrs.getValue().take_front(newCount));
return setAttrFn(op, attrs.getValue().take_front(newCount));
// Otherwise, the new type has more arguments/results. Initialize the new
// arguments/results with empty attributes.
SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end());
newAttrs.resize(newCount);
setAllArgResAttrDicts<isArgVal>(op, newAttrs);
setAttrFn(op, newAttrs);
};
// Update the argument and result attributes.
updateAttrFn(std::true_type{}, oldNumArgs, newNumArgs);
updateAttrFn(std::false_type{}, oldNumResults, newNumResults);
updateAttrFn(
getArgDictAttrName(), oldNumArgs, newNumArgs,
[&](Operation *op, auto &&attrs) { setAllArgAttrDicts(op, attrs); });
updateAttrFn(
getResultDictAttrName(), oldNumResults, newNumResults,
[&](Operation *op, auto &&attrs) { setAllResultAttrDicts(op, attrs); });
}

View File

@ -96,11 +96,20 @@ func.func private @invalid_symbol_type_attr() attributes { function_type = "x" }
// -----
// expected-error@+1 {{argument attribute array to have the same number of elements as the number of function arguments}}
// expected-error@+1 {{argument attribute array `arg_attrs` to have the same number of elements as the number of function arguments}}
func.func private @invalid_arg_attrs() attributes { arg_attrs = [{}] }
// -----
// expected-error@+1 {{expects argument attribute dictionary to be a DictionaryAttr, but got `10 : i64`}}
func.func private @invalid_arg_attrs(i32) attributes { arg_attrs = [10] }
// expected-error@+1 {{result attribute array to have the same number of elements as the number of function results}}
// -----
// expected-error@+1 {{result attribute array `res_attrs` to have the same number of elements as the number of function results}}
func.func private @invalid_res_attrs() attributes { res_attrs = [{}] }
// -----
// expected-error@+1 {{expects result attribute dictionary to be a DictionaryAttr, but got `10 : i64`}}
func.func private @invalid_res_attrs() -> i32 attributes { res_attrs = [10] }