[mlir] Implement replacement of SymbolRefAttrs in Dialect attributes using SubElementAttr interface

This patch extends the SubElementAttr interface to allow replacing a contained sub attribute. The attribute that should be replaced is identified by an index which denotes the n-th element returned by the accompanying walkImmediateSubElements method.

Using this addition the patch implements replacing SymbolRefAttrs contained within any dialect attributes.

Differential Revision: https://reviews.llvm.org/D111357
This commit is contained in:
Markus Böck 2021-10-28 19:08:10 +02:00
parent b437aaa672
commit 10a80c4413
7 changed files with 180 additions and 47 deletions

View File

@ -71,7 +71,8 @@ def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap", [
//===----------------------------------------------------------------------===//
def Builtin_ArrayAttr : Builtin_Attr<"Array", [
DeclareAttrInterfaceMethods<SubElementAttrInterface>
DeclareAttrInterfaceMethods<SubElementAttrInterface,
["replaceImmediateSubAttribute"]>
]> {
let summary = "A collection of other Attribute values";
let description = [{
@ -345,7 +346,8 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
//===----------------------------------------------------------------------===//
def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary", [
DeclareAttrInterfaceMethods<SubElementAttrInterface>
DeclareAttrInterfaceMethods<SubElementAttrInterface,
["replaceImmediateSubAttribute"]>
]> {
let summary = "An dictionary of named Attribute values";
let description = [{
@ -954,10 +956,11 @@ def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef"> {
symbol nested within a different symbol table.
This attribute can only be held internally by
[array attributes](#array-attribute) and
[array attributes](#array-attribute),
[dictionary attributes](#dictionary-attribute)(including the top-level
operation attribute dictionary), i.e. no other attribute kinds such as
Locations or extended attribute kinds.
operation attribute dictionary) as well as attributes exposing it via
the `SubElementAttrInterface` interface. Symbol reference attributes
nested in types are currently not supported.
**Rationale:** Identifying accesses to global data is critical to
enabling efficient multi-threaded compilation. Restricting global

View File

@ -33,6 +33,20 @@ class SubElementInterfaceBase<string interfaceName, string derivedValue> {
(ins "llvm::function_ref<void(mlir::Attribute)>":$walkAttrsFn,
"llvm::function_ref<void(mlir::Type)>":$walkTypesFn)
>,
InterfaceMethod<
/*desc=*/[{
Replace the attributes identified by the indices with the corresponding
value. The index is derived from the order of the attributes returned by
the attribute callback of `walkImmediateSubElements`. An index of 0 would
replace the very first attribute given by `walkImmediateSubElements`.
The new instance with the values replaced is returned.
}], cppNamespace # "::" # interfaceName, "replaceImmediateSubAttribute",
(ins "::llvm::ArrayRef<std::pair<size_t, ::mlir::Attribute>>":$replacements),
[{}],
/*defaultImplementation=*/[{
llvm_unreachable("Attribute or Type does not support replacing attributes");
}]
>,
];
code extraClassDeclaration = [{

View File

@ -53,6 +53,15 @@ void ArrayAttr::walkImmediateSubElements(
walkAttrsFn(attr);
}
SubElementAttrInterface ArrayAttr::replaceImmediateSubAttribute(
ArrayRef<std::pair<size_t, Attribute>> replacements) const {
std::vector<Attribute> vector = getValue().vec();
for (auto &it : replacements) {
vector[it.first] = it.second;
}
return get(getContext(), vector);
}
//===----------------------------------------------------------------------===//
// DictionaryAttr
//===----------------------------------------------------------------------===//
@ -217,6 +226,17 @@ void DictionaryAttr::walkImmediateSubElements(
walkAttrsFn(attr);
}
SubElementAttrInterface DictionaryAttr::replaceImmediateSubAttribute(
ArrayRef<std::pair<size_t, Attribute>> replacements) const {
std::vector<NamedAttribute> vec = getValue().vec();
for (auto &it : replacements) {
vec[it.first].second = it.second;
}
// The above only modifies the mapped value, but not the key, and therefore
// not the order of the elements. It remains sorted
return getWithSorted(getContext(), vec);
}
//===----------------------------------------------------------------------===//
// StringAttr
//===----------------------------------------------------------------------===//

View File

@ -485,16 +485,30 @@ static WalkResult walkSymbolRefs(
// A worklist of a container attribute and the current index into the held
// attribute list.
SmallVector<Attribute, 1> attrWorklist(1, attrDict);
struct WorklistItem {
SubElementAttrInterface container;
SmallVector<Attribute> immediateSubElements;
explicit WorklistItem(SubElementAttrInterface container) {
SmallVector<Attribute> subElements;
container.walkImmediateSubElements(
[&](Attribute attr) { subElements.push_back(attr); }, [](Type) {});
immediateSubElements = std::move(subElements);
}
};
SmallVector<WorklistItem, 1> attrWorklist(1, WorklistItem(attrDict));
SmallVector<int, 1> curAccessChain(1, /*Value=*/-1);
// Process the symbol references within the given nested attribute range.
auto processAttrs = [&](int &index, auto attrRange) -> WalkResult {
for (Attribute attr : llvm::drop_begin(attrRange, index)) {
auto processAttrs = [&](int &index,
WorklistItem &worklistItem) -> WalkResult {
for (Attribute attr :
llvm::drop_begin(worklistItem.immediateSubElements, index)) {
/// Check for a nested container attribute, these will also need to be
/// walked.
if (attr.isa<ArrayAttr, DictionaryAttr>()) {
attrWorklist.push_back(attr);
if (auto interface = attr.dyn_cast<SubElementAttrInterface>()) {
attrWorklist.emplace_back(interface);
curAccessChain.push_back(-1);
return WalkResult::advance();
}
@ -517,15 +531,12 @@ static WalkResult walkSymbolRefs(
WalkResult result = WalkResult::advance();
do {
Attribute attr = attrWorklist.back();
WorklistItem &item = attrWorklist.back();
int &index = curAccessChain.back();
++index;
// Process the given attribute, which is guaranteed to be a container.
if (auto dict = attr.dyn_cast<DictionaryAttr>())
result = processAttrs(index, make_second_range(dict.getValue()));
else
result = processAttrs(index, attr.cast<ArrayAttr>().getValue());
result = processAttrs(index, item);
} while (!attrWorklist.empty() && !result.wasInterrupted());
return result;
}
@ -811,48 +822,46 @@ bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) {
/// Rebuild the given attribute container after replacing all references to a
/// symbol with the updated attribute in 'accesses'.
static Attribute rebuildAttrAfterRAUW(
Attribute container,
static SubElementAttrInterface rebuildAttrAfterRAUW(
SubElementAttrInterface container,
ArrayRef<std::pair<SmallVector<int, 1>, SymbolRefAttr>> accesses,
unsigned depth) {
// Given a range of Attributes, update the ones referred to by the given
// access chains to point to the new symbol attribute.
auto updateAttrs = [&](auto &&attrRange) {
auto attrBegin = std::begin(attrRange);
for (unsigned i = 0, e = accesses.size(); i != e;) {
ArrayRef<int> access = accesses[i].first;
Attribute &attr = *std::next(attrBegin, access[depth]);
// Check to see if this is a leaf access, i.e. a SymbolRef.
if (access.size() == depth + 1) {
attr = accesses[i].second;
++i;
continue;
}
SmallVector<std::pair<size_t, Attribute>> replacements;
// Otherwise, this is a container. Collect all of the accesses for this
// index and recurse. The recursion here is bounded by the size of the
// largest access array.
auto nestedAccesses = accesses.drop_front(i).take_while([&](auto &it) {
ArrayRef<int> nextAccess = it.first;
return nextAccess.size() > depth + 1 &&
nextAccess[depth] == access[depth];
});
attr = rebuildAttrAfterRAUW(attr, nestedAccesses, depth + 1);
SmallVector<Attribute> subElements;
container.walkImmediateSubElements(
[&](Attribute attribute) { subElements.push_back(attribute); },
[](Type) {});
for (unsigned i = 0, e = accesses.size(); i != e;) {
ArrayRef<int> access = accesses[i].first;
// Skip over all of the accesses that refer to the nested container.
i += nestedAccesses.size();
// Check to see if this is a leaf access, i.e. a SymbolRef.
if (access.size() == depth + 1) {
replacements.emplace_back(access.back(), accesses[i].second);
++i;
continue;
}
};
if (auto dictAttr = container.dyn_cast<DictionaryAttr>()) {
auto newAttrs = llvm::to_vector<4>(dictAttr.getValue());
updateAttrs(make_second_range(newAttrs));
return DictionaryAttr::get(dictAttr.getContext(), newAttrs);
// Otherwise, this is a container. Collect all of the accesses for this
// index and recurse. The recursion here is bounded by the size of the
// largest access array.
auto nestedAccesses = accesses.drop_front(i).take_while([&](auto &it) {
ArrayRef<int> nextAccess = it.first;
return nextAccess.size() > depth + 1 &&
nextAccess[depth] == access[depth];
});
auto result = rebuildAttrAfterRAUW(subElements[access[depth]],
nestedAccesses, depth + 1);
replacements.emplace_back(access[depth], result);
// Skip over all of the accesses that refer to the nested container.
i += nestedAccesses.size();
}
auto newAttrs = llvm::to_vector<4>(container.cast<ArrayAttr>().getValue());
updateAttrs(newAttrs);
return ArrayAttr::get(container.getContext(), newAttrs);
return container.replaceImmediateSubAttribute(replacements);
}
/// Generates a new symbol reference attribute with a new leaf reference.

View File

@ -73,3 +73,24 @@ module {
"foo.possibly_unknown_symbol_table"() ({
}) : () -> ()
}
// -----
// Check that replacement works in any implementations of SubElementsAttrInterface
module {
// CHECK: func private @replaced_foo
func private @symbol_foo() attributes {sym.new_name = "replaced_foo" }
// CHECK: func @symbol_bar
func @symbol_bar() {
// CHECK: foo.op
// CHECK-SAME: non_symbol_attr,
// CHECK-SAME: use = [#test.sub_elements_access<[@replaced_foo], @symbol_bar, @replaced_foo>],
// CHECK-SAME: z_non_symbol_attr_3
"foo.op"() {
non_symbol_attr,
use = [#test.sub_elements_access<[@symbol_foo],@symbol_bar,@symbol_foo>],
z_non_symbol_attr_3
} : () -> ()
}
}

View File

@ -16,6 +16,7 @@
// To get the test dialect definition.
include "TestOps.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/SubElementInterfaces.td"
// All of the attributes will extend this class.
class Test_Attr<string name, list<Trait> traits = []>
@ -101,4 +102,18 @@ def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [
let genVerifyDecl = 1;
}
def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
DeclareAttrInterfaceMethods<SubElementAttrInterface,
["replaceImmediateSubAttribute"]>
]> {
let mnemonic = "sub_elements_access";
let parameters = (ins
"::mlir::Attribute":$first,
"::mlir::Attribute":$second,
"::mlir::Attribute":$third
);
}
#endif // TEST_ATTRDEFS

View File

@ -127,6 +127,57 @@ TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
//===----------------------------------------------------------------------===//
// TestSubElementsAccessAttr
//===----------------------------------------------------------------------===//
Attribute TestSubElementsAccessAttr::parse(::mlir::DialectAsmParser &parser,
::mlir::Type type) {
Attribute first, second, third;
if (parser.parseLess() || parser.parseAttribute(first) ||
parser.parseComma() || parser.parseAttribute(second) ||
parser.parseComma() || parser.parseAttribute(third) ||
parser.parseGreater()) {
return {};
}
return get(parser.getContext(), first, second, third);
}
void TestSubElementsAccessAttr::print(
::mlir::DialectAsmPrinter &printer) const {
printer << getMnemonic() << "<" << getFirst() << ", " << getSecond() << ", "
<< getThird() << ">";
}
void TestSubElementsAccessAttr::walkImmediateSubElements(
llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
llvm::function_ref<void(mlir::Type)> walkTypesFn) const {
walkAttrsFn(getFirst());
walkAttrsFn(getSecond());
walkAttrsFn(getThird());
}
SubElementAttrInterface TestSubElementsAccessAttr::replaceImmediateSubAttribute(
ArrayRef<std::pair<size_t, Attribute>> replacements) const {
Attribute first = getFirst();
Attribute second = getSecond();
Attribute third = getThird();
for (auto &it : replacements) {
switch (it.first) {
case 0:
first = it.second;
break;
case 1:
second = it.second;
break;
case 2:
third = it.second;
break;
}
}
return get(getContext(), first, second, third);
}
//===----------------------------------------------------------------------===//
// Tablegen Generated Definitions
//===----------------------------------------------------------------------===//