mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-02-26 05:18:46 +00:00
[mlir] NFC: Move the state for managing aliases out of ModuleState and into a new class AliasState.
Summary: This reduces the complexity of ModuleState and simplifies the code. A future revision will mold ModuleState into something that can be used by users for caching of printer state, as well as for implementing printAsOperand style methods. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D72292
This commit is contained in:
parent
766ce87e9b
commit
659f7d463b
@ -155,98 +155,46 @@ bool OpPrintingFlags::shouldPrintGenericOpForm() const {
|
||||
bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ModuleState
|
||||
// AliasState
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// A special index constant used for non-kind attribute aliases.
|
||||
static constexpr int kNonAttrKindAlias = -1;
|
||||
|
||||
class ModuleState {
|
||||
/// This class manages the state for type and attribute aliases.
|
||||
class AliasState {
|
||||
public:
|
||||
explicit ModuleState(MLIRContext *context) : interfaces(context) {}
|
||||
void initialize(Operation *op);
|
||||
// Initialize the internal aliases.
|
||||
void
|
||||
initialize(Operation *op,
|
||||
DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
|
||||
|
||||
Twine getAttributeAlias(Attribute attr) const {
|
||||
auto alias = attrToAlias.find(attr);
|
||||
if (alias == attrToAlias.end())
|
||||
return Twine();
|
||||
/// Return a name used for an attribute alias, or empty if there is no alias.
|
||||
Twine getAttributeAlias(Attribute attr) const;
|
||||
|
||||
// Return the alias for this attribute, along with the index if this was
|
||||
// generated by a kind alias.
|
||||
int kindIndex = alias->second.second;
|
||||
return alias->second.first +
|
||||
(kindIndex == kNonAttrKindAlias ? Twine() : Twine(kindIndex));
|
||||
}
|
||||
/// Print all of the referenced attribute aliases.
|
||||
void printAttributeAliases(raw_ostream &os) const;
|
||||
|
||||
void printAttributeAliases(raw_ostream &os) const {
|
||||
auto printAlias = [&](StringRef alias, Attribute attr, int index) {
|
||||
os << '#' << alias;
|
||||
if (index != kNonAttrKindAlias)
|
||||
os << index;
|
||||
os << " = " << attr << '\n';
|
||||
};
|
||||
/// Return a string to use as an alias for the given type, or empty if there
|
||||
/// is no alias recorded.
|
||||
StringRef getTypeAlias(Type ty) const;
|
||||
|
||||
// Print all of the attribute kind aliases.
|
||||
for (auto &kindAlias : attrKindToAlias) {
|
||||
for (unsigned i = 0, e = kindAlias.second.second.size(); i != e; ++i)
|
||||
printAlias(kindAlias.second.first, kindAlias.second.second[i], i);
|
||||
os << "\n";
|
||||
}
|
||||
|
||||
// In a second pass print all of the remaining attribute aliases that aren't
|
||||
// kind aliases.
|
||||
for (Attribute attr : usedAttributes) {
|
||||
auto alias = attrToAlias.find(attr);
|
||||
if (alias != attrToAlias.end() &&
|
||||
alias->second.second == kNonAttrKindAlias)
|
||||
printAlias(alias->second.first, attr, alias->second.second);
|
||||
}
|
||||
}
|
||||
|
||||
StringRef getTypeAlias(Type ty) const { return typeToAlias.lookup(ty); }
|
||||
|
||||
void printTypeAliases(raw_ostream &os) const {
|
||||
for (Type type : usedTypes) {
|
||||
auto alias = typeToAlias.find(type);
|
||||
if (alias != typeToAlias.end())
|
||||
os << '!' << alias->second << " = type " << type << '\n';
|
||||
}
|
||||
}
|
||||
|
||||
/// Get an instance of the OpAsmDialectInterface for the given dialect, or
|
||||
/// null if one wasn't registered.
|
||||
const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) {
|
||||
return interfaces.getInterfaceFor(dialect);
|
||||
}
|
||||
/// Print all of the referenced type aliases.
|
||||
void printTypeAliases(raw_ostream &os) const;
|
||||
|
||||
private:
|
||||
void recordAttributeReference(Attribute attr) {
|
||||
// Don't recheck attributes that have already been seen or those that
|
||||
// already have an alias.
|
||||
if (!usedAttributes.insert(attr) || attrToAlias.count(attr))
|
||||
return;
|
||||
/// A special index constant used for non-kind attribute aliases.
|
||||
enum { NonAttrKindAlias = -1 };
|
||||
|
||||
// If this attribute kind has an alias, then record one for this attribute.
|
||||
auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind()));
|
||||
if (alias == attrKindToAlias.end())
|
||||
return;
|
||||
std::pair<StringRef, int> attrAlias(alias->second.first,
|
||||
alias->second.second.size());
|
||||
attrToAlias.insert({attr, attrAlias});
|
||||
alias->second.second.push_back(attr);
|
||||
}
|
||||
/// Record a reference to the given attribute.
|
||||
void recordAttributeReference(Attribute attr);
|
||||
|
||||
void recordTypeReference(Type ty) { usedTypes.insert(ty); }
|
||||
/// Record a reference to the given type.
|
||||
void recordTypeReference(Type ty);
|
||||
|
||||
// Visit functions.
|
||||
void visitOperation(Operation *op);
|
||||
void visitType(Type type);
|
||||
void visitAttribute(Attribute attr);
|
||||
|
||||
// Initialize symbol aliases.
|
||||
void initializeSymbolAliases();
|
||||
|
||||
/// Set of attributes known to be used within the module.
|
||||
llvm::SetVector<Attribute> usedAttributes;
|
||||
|
||||
@ -265,59 +213,9 @@ private:
|
||||
|
||||
/// A mapping between a type and a given alias.
|
||||
DenseMap<Type, StringRef> typeToAlias;
|
||||
|
||||
/// Collection of OpAsm interfaces implemented in the context.
|
||||
DialectInterfaceCollection<OpAsmDialectInterface> interfaces;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
// TODO Support visiting other types/operations when implemented.
|
||||
void ModuleState::visitType(Type type) {
|
||||
recordTypeReference(type);
|
||||
if (auto funcType = type.dyn_cast<FunctionType>()) {
|
||||
// Visit input and result types for functions.
|
||||
for (auto input : funcType.getInputs())
|
||||
visitType(input);
|
||||
for (auto result : funcType.getResults())
|
||||
visitType(result);
|
||||
return;
|
||||
}
|
||||
if (auto memref = type.dyn_cast<MemRefType>()) {
|
||||
// Visit affine maps in memref type.
|
||||
for (auto map : memref.getAffineMaps())
|
||||
recordAttributeReference(AffineMapAttr::get(map));
|
||||
}
|
||||
if (auto shapedType = type.dyn_cast<ShapedType>()) {
|
||||
visitType(shapedType.getElementType());
|
||||
}
|
||||
}
|
||||
|
||||
void ModuleState::visitAttribute(Attribute attr) {
|
||||
recordAttributeReference(attr);
|
||||
if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
|
||||
for (auto elt : arrayAttr.getValue())
|
||||
visitAttribute(elt);
|
||||
} else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
|
||||
visitType(typeAttr.getValue());
|
||||
}
|
||||
}
|
||||
|
||||
void ModuleState::visitOperation(Operation *op) {
|
||||
// Visit all the types used in the operation.
|
||||
for (auto type : op->getOperandTypes())
|
||||
visitType(type);
|
||||
for (auto type : op->getResultTypes())
|
||||
visitType(type);
|
||||
for (auto ®ion : op->getRegions())
|
||||
for (auto &block : region)
|
||||
for (auto arg : block.getArguments())
|
||||
visitType(arg->getType());
|
||||
|
||||
// Visit each of the attributes.
|
||||
for (auto elt : op->getAttrs())
|
||||
visitAttribute(elt.second);
|
||||
}
|
||||
|
||||
// Utility to generate a function to register a symbol alias.
|
||||
static bool canRegisterAlias(StringRef name, llvm::StringSet<> &usedAliases) {
|
||||
assert(!name.empty() && "expected alias name to be non-empty");
|
||||
@ -329,7 +227,9 @@ static bool canRegisterAlias(StringRef name, llvm::StringSet<> &usedAliases) {
|
||||
return !name.contains('.') && usedAliases.insert(name).second;
|
||||
}
|
||||
|
||||
void ModuleState::initializeSymbolAliases() {
|
||||
void AliasState::initialize(
|
||||
Operation *op,
|
||||
DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
|
||||
// Track the identifiers in use for each symbol so that the same identifier
|
||||
// isn't used twice.
|
||||
llvm::StringSet<> usedAliases;
|
||||
@ -374,7 +274,7 @@ void ModuleState::initializeSymbolAliases() {
|
||||
for (auto &attrAliasPair : attributeAliases) {
|
||||
std::tie(attr, alias) = attrAliasPair;
|
||||
if (!reservedAttrNames.match(alias) && canRegisterAlias(alias, usedAliases))
|
||||
attrToAlias.insert({attr, {alias, kNonAttrKindAlias}});
|
||||
attrToAlias.insert({attr, {alias, NonAttrKindAlias}});
|
||||
}
|
||||
|
||||
// Clear the set of used identifiers as types can have the same identifiers as
|
||||
@ -385,16 +285,164 @@ void ModuleState::initializeSymbolAliases() {
|
||||
for (auto &typeAliasPair : typeAliases)
|
||||
if (canRegisterAlias(typeAliasPair.second, usedAliases))
|
||||
typeToAlias.insert(typeAliasPair);
|
||||
}
|
||||
|
||||
void ModuleState::initialize(Operation *op) {
|
||||
// Initialize the symbol aliases.
|
||||
initializeSymbolAliases();
|
||||
|
||||
// Visit each of the nested operations.
|
||||
// Traverse the given IR to generate the set of used attributes/types.
|
||||
op->walk([&](Operation *op) { visitOperation(op); });
|
||||
}
|
||||
|
||||
/// Return a name used for an attribute alias, or empty if there is no alias.
|
||||
Twine AliasState::getAttributeAlias(Attribute attr) const {
|
||||
auto alias = attrToAlias.find(attr);
|
||||
if (alias == attrToAlias.end())
|
||||
return Twine();
|
||||
|
||||
// Return the alias for this attribute, along with the index if this was
|
||||
// generated by a kind alias.
|
||||
int kindIndex = alias->second.second;
|
||||
return alias->second.first +
|
||||
(kindIndex == NonAttrKindAlias ? Twine() : Twine(kindIndex));
|
||||
}
|
||||
|
||||
/// Print all of the referenced attribute aliases.
|
||||
void AliasState::printAttributeAliases(raw_ostream &os) const {
|
||||
auto printAlias = [&](StringRef alias, Attribute attr, int index) {
|
||||
os << '#' << alias;
|
||||
if (index != NonAttrKindAlias)
|
||||
os << index;
|
||||
os << " = " << attr << '\n';
|
||||
};
|
||||
|
||||
// Print all of the attribute kind aliases.
|
||||
for (auto &kindAlias : attrKindToAlias) {
|
||||
auto &aliasAttrsPair = kindAlias.second;
|
||||
for (unsigned i = 0, e = aliasAttrsPair.second.size(); i != e; ++i)
|
||||
printAlias(aliasAttrsPair.first, aliasAttrsPair.second[i], i);
|
||||
os << "\n";
|
||||
}
|
||||
|
||||
// In a second pass print all of the remaining attribute aliases that aren't
|
||||
// kind aliases.
|
||||
for (Attribute attr : usedAttributes) {
|
||||
auto alias = attrToAlias.find(attr);
|
||||
if (alias != attrToAlias.end() && alias->second.second == NonAttrKindAlias)
|
||||
printAlias(alias->second.first, attr, alias->second.second);
|
||||
}
|
||||
}
|
||||
|
||||
/// Return a string to use as an alias for the given type, or empty if there
|
||||
/// is no alias recorded.
|
||||
StringRef AliasState::getTypeAlias(Type ty) const {
|
||||
return typeToAlias.lookup(ty);
|
||||
}
|
||||
|
||||
/// Print all of the referenced type aliases.
|
||||
void AliasState::printTypeAliases(raw_ostream &os) const {
|
||||
for (Type type : usedTypes) {
|
||||
auto alias = typeToAlias.find(type);
|
||||
if (alias != typeToAlias.end())
|
||||
os << '!' << alias->second << " = type " << type << '\n';
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a reference to the given attribute.
|
||||
void AliasState::recordAttributeReference(Attribute attr) {
|
||||
// Don't recheck attributes that have already been seen or those that
|
||||
// already have an alias.
|
||||
if (!usedAttributes.insert(attr) || attrToAlias.count(attr))
|
||||
return;
|
||||
|
||||
// If this attribute kind has an alias, then record one for this attribute.
|
||||
auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind()));
|
||||
if (alias == attrKindToAlias.end())
|
||||
return;
|
||||
std::pair<StringRef, int> attrAlias(alias->second.first,
|
||||
alias->second.second.size());
|
||||
attrToAlias.insert({attr, attrAlias});
|
||||
alias->second.second.push_back(attr);
|
||||
}
|
||||
|
||||
/// Record a reference to the given type.
|
||||
void AliasState::recordTypeReference(Type ty) { usedTypes.insert(ty); }
|
||||
|
||||
// TODO Support visiting other types/operations when implemented.
|
||||
void AliasState::visitType(Type type) {
|
||||
recordTypeReference(type);
|
||||
|
||||
if (auto funcType = type.dyn_cast<FunctionType>()) {
|
||||
// Visit input and result types for functions.
|
||||
for (auto input : funcType.getInputs())
|
||||
visitType(input);
|
||||
for (auto result : funcType.getResults())
|
||||
visitType(result);
|
||||
} else if (auto shapedType = type.dyn_cast<ShapedType>()) {
|
||||
visitType(shapedType.getElementType());
|
||||
|
||||
// Visit affine maps in memref type.
|
||||
if (auto memref = type.dyn_cast<MemRefType>())
|
||||
for (auto map : memref.getAffineMaps())
|
||||
recordAttributeReference(AffineMapAttr::get(map));
|
||||
}
|
||||
}
|
||||
|
||||
void AliasState::visitAttribute(Attribute attr) {
|
||||
recordAttributeReference(attr);
|
||||
|
||||
if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
|
||||
for (auto elt : arrayAttr.getValue())
|
||||
visitAttribute(elt);
|
||||
} else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
|
||||
visitType(typeAttr.getValue());
|
||||
}
|
||||
}
|
||||
|
||||
void AliasState::visitOperation(Operation *op) {
|
||||
// Visit all the types used in the operation.
|
||||
for (auto type : op->getOperandTypes())
|
||||
visitType(type);
|
||||
for (auto type : op->getResultTypes())
|
||||
visitType(type);
|
||||
for (auto ®ion : op->getRegions())
|
||||
for (auto &block : region)
|
||||
for (auto arg : block.getArguments())
|
||||
visitType(arg->getType());
|
||||
|
||||
// Visit each of the attributes.
|
||||
for (auto elt : op->getAttrs())
|
||||
visitAttribute(elt.second);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ModuleState
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class ModuleState {
|
||||
public:
|
||||
explicit ModuleState(MLIRContext *context) : interfaces(context) {}
|
||||
|
||||
/// Initialize the alias state to enable the printing of aliases.
|
||||
void initializeAliases(Operation *op) {
|
||||
aliasState.initialize(op, interfaces);
|
||||
}
|
||||
|
||||
/// Get an instance of the OpAsmDialectInterface for the given dialect, or
|
||||
/// null if one wasn't registered.
|
||||
const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) {
|
||||
return interfaces.getInterfaceFor(dialect);
|
||||
}
|
||||
|
||||
/// Get the state used for aliases.
|
||||
AliasState &getAliasState() { return aliasState; }
|
||||
|
||||
private:
|
||||
/// Collection of OpAsm interfaces implemented in the context.
|
||||
DialectInterfaceCollection<OpAsmDialectInterface> interfaces;
|
||||
|
||||
/// The state used for attribute and type aliases.
|
||||
AliasState aliasState;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ModulePrinter
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -745,7 +793,7 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
|
||||
|
||||
// Check for an alias for this attribute.
|
||||
if (state) {
|
||||
Twine alias = state->getAttributeAlias(attr);
|
||||
Twine alias = state->getAliasState().getAttributeAlias(attr);
|
||||
if (!alias.isTriviallyEmpty()) {
|
||||
os << '#' << alias;
|
||||
return;
|
||||
@ -975,7 +1023,7 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
|
||||
void ModulePrinter::printType(Type type) {
|
||||
// Check for an alias for this type.
|
||||
if (state) {
|
||||
StringRef alias = state->getTypeAlias(type);
|
||||
StringRef alias = state->getAliasState().getTypeAlias(type);
|
||||
if (!alias.empty()) {
|
||||
os << '!' << alias;
|
||||
return;
|
||||
@ -1997,8 +2045,8 @@ void OperationPrinter::printSuccessorAndUseList(Operation *term,
|
||||
void ModulePrinter::print(ModuleOp module) {
|
||||
// Output the aliases at the top level.
|
||||
if (state) {
|
||||
state->printAttributeAliases(os);
|
||||
state->printTypeAliases(os);
|
||||
state->getAliasState().printAttributeAliases(os);
|
||||
state->getAliasState().printTypeAliases(os);
|
||||
}
|
||||
|
||||
// Print the module.
|
||||
@ -2136,9 +2184,9 @@ void Block::printAsOperand(raw_ostream &os, bool printType) {
|
||||
|
||||
void ModuleOp::print(raw_ostream &os, OpPrintingFlags flags) {
|
||||
ModuleState state(getContext());
|
||||
// Skip initializing in local scope to avoid populating aliases.
|
||||
// Don't populate aliases when printing at local scope.
|
||||
if (!flags.shouldUseLocalScope())
|
||||
state.initialize(*this);
|
||||
state.initializeAliases(*this);
|
||||
ModulePrinter(os, flags, &state).print(*this);
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user