[mlir] NFC: Move the state for managing aliases out of ModuleState and into a new class AliasState.

Summary: This reduces the complexity of ModuleState and simplifies the code. A future revision will mold ModuleState into something that can be used by users for caching of printer state, as well as for implementing printAsOperand style methods.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D72292
This commit is contained in:
River Riddle 2020-01-08 10:11:56 -08:00
parent 766ce87e9b
commit 659f7d463b

View File

@ -155,98 +155,46 @@ bool OpPrintingFlags::shouldPrintGenericOpForm() const {
bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
//===----------------------------------------------------------------------===//
// ModuleState
// AliasState
//===----------------------------------------------------------------------===//
namespace {
/// A special index constant used for non-kind attribute aliases.
static constexpr int kNonAttrKindAlias = -1;
class ModuleState {
/// This class manages the state for type and attribute aliases.
class AliasState {
public:
explicit ModuleState(MLIRContext *context) : interfaces(context) {}
void initialize(Operation *op);
// Initialize the internal aliases.
void
initialize(Operation *op,
DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
Twine getAttributeAlias(Attribute attr) const {
auto alias = attrToAlias.find(attr);
if (alias == attrToAlias.end())
return Twine();
/// Return a name used for an attribute alias, or empty if there is no alias.
Twine getAttributeAlias(Attribute attr) const;
// Return the alias for this attribute, along with the index if this was
// generated by a kind alias.
int kindIndex = alias->second.second;
return alias->second.first +
(kindIndex == kNonAttrKindAlias ? Twine() : Twine(kindIndex));
}
/// Print all of the referenced attribute aliases.
void printAttributeAliases(raw_ostream &os) const;
void printAttributeAliases(raw_ostream &os) const {
auto printAlias = [&](StringRef alias, Attribute attr, int index) {
os << '#' << alias;
if (index != kNonAttrKindAlias)
os << index;
os << " = " << attr << '\n';
};
/// Return a string to use as an alias for the given type, or empty if there
/// is no alias recorded.
StringRef getTypeAlias(Type ty) const;
// Print all of the attribute kind aliases.
for (auto &kindAlias : attrKindToAlias) {
for (unsigned i = 0, e = kindAlias.second.second.size(); i != e; ++i)
printAlias(kindAlias.second.first, kindAlias.second.second[i], i);
os << "\n";
}
// In a second pass print all of the remaining attribute aliases that aren't
// kind aliases.
for (Attribute attr : usedAttributes) {
auto alias = attrToAlias.find(attr);
if (alias != attrToAlias.end() &&
alias->second.second == kNonAttrKindAlias)
printAlias(alias->second.first, attr, alias->second.second);
}
}
StringRef getTypeAlias(Type ty) const { return typeToAlias.lookup(ty); }
void printTypeAliases(raw_ostream &os) const {
for (Type type : usedTypes) {
auto alias = typeToAlias.find(type);
if (alias != typeToAlias.end())
os << '!' << alias->second << " = type " << type << '\n';
}
}
/// Get an instance of the OpAsmDialectInterface for the given dialect, or
/// null if one wasn't registered.
const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) {
return interfaces.getInterfaceFor(dialect);
}
/// Print all of the referenced type aliases.
void printTypeAliases(raw_ostream &os) const;
private:
void recordAttributeReference(Attribute attr) {
// Don't recheck attributes that have already been seen or those that
// already have an alias.
if (!usedAttributes.insert(attr) || attrToAlias.count(attr))
return;
/// A special index constant used for non-kind attribute aliases.
enum { NonAttrKindAlias = -1 };
// If this attribute kind has an alias, then record one for this attribute.
auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind()));
if (alias == attrKindToAlias.end())
return;
std::pair<StringRef, int> attrAlias(alias->second.first,
alias->second.second.size());
attrToAlias.insert({attr, attrAlias});
alias->second.second.push_back(attr);
}
/// Record a reference to the given attribute.
void recordAttributeReference(Attribute attr);
void recordTypeReference(Type ty) { usedTypes.insert(ty); }
/// Record a reference to the given type.
void recordTypeReference(Type ty);
// Visit functions.
void visitOperation(Operation *op);
void visitType(Type type);
void visitAttribute(Attribute attr);
// Initialize symbol aliases.
void initializeSymbolAliases();
/// Set of attributes known to be used within the module.
llvm::SetVector<Attribute> usedAttributes;
@ -265,59 +213,9 @@ private:
/// A mapping between a type and a given alias.
DenseMap<Type, StringRef> typeToAlias;
/// Collection of OpAsm interfaces implemented in the context.
DialectInterfaceCollection<OpAsmDialectInterface> interfaces;
};
} // end anonymous namespace
// TODO Support visiting other types/operations when implemented.
void ModuleState::visitType(Type type) {
recordTypeReference(type);
if (auto funcType = type.dyn_cast<FunctionType>()) {
// Visit input and result types for functions.
for (auto input : funcType.getInputs())
visitType(input);
for (auto result : funcType.getResults())
visitType(result);
return;
}
if (auto memref = type.dyn_cast<MemRefType>()) {
// Visit affine maps in memref type.
for (auto map : memref.getAffineMaps())
recordAttributeReference(AffineMapAttr::get(map));
}
if (auto shapedType = type.dyn_cast<ShapedType>()) {
visitType(shapedType.getElementType());
}
}
void ModuleState::visitAttribute(Attribute attr) {
recordAttributeReference(attr);
if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
for (auto elt : arrayAttr.getValue())
visitAttribute(elt);
} else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
visitType(typeAttr.getValue());
}
}
void ModuleState::visitOperation(Operation *op) {
// Visit all the types used in the operation.
for (auto type : op->getOperandTypes())
visitType(type);
for (auto type : op->getResultTypes())
visitType(type);
for (auto &region : op->getRegions())
for (auto &block : region)
for (auto arg : block.getArguments())
visitType(arg->getType());
// Visit each of the attributes.
for (auto elt : op->getAttrs())
visitAttribute(elt.second);
}
// Utility to generate a function to register a symbol alias.
static bool canRegisterAlias(StringRef name, llvm::StringSet<> &usedAliases) {
assert(!name.empty() && "expected alias name to be non-empty");
@ -329,7 +227,9 @@ static bool canRegisterAlias(StringRef name, llvm::StringSet<> &usedAliases) {
return !name.contains('.') && usedAliases.insert(name).second;
}
void ModuleState::initializeSymbolAliases() {
void AliasState::initialize(
Operation *op,
DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
// Track the identifiers in use for each symbol so that the same identifier
// isn't used twice.
llvm::StringSet<> usedAliases;
@ -374,7 +274,7 @@ void ModuleState::initializeSymbolAliases() {
for (auto &attrAliasPair : attributeAliases) {
std::tie(attr, alias) = attrAliasPair;
if (!reservedAttrNames.match(alias) && canRegisterAlias(alias, usedAliases))
attrToAlias.insert({attr, {alias, kNonAttrKindAlias}});
attrToAlias.insert({attr, {alias, NonAttrKindAlias}});
}
// Clear the set of used identifiers as types can have the same identifiers as
@ -385,16 +285,164 @@ void ModuleState::initializeSymbolAliases() {
for (auto &typeAliasPair : typeAliases)
if (canRegisterAlias(typeAliasPair.second, usedAliases))
typeToAlias.insert(typeAliasPair);
}
void ModuleState::initialize(Operation *op) {
// Initialize the symbol aliases.
initializeSymbolAliases();
// Visit each of the nested operations.
// Traverse the given IR to generate the set of used attributes/types.
op->walk([&](Operation *op) { visitOperation(op); });
}
/// Return a name used for an attribute alias, or empty if there is no alias.
Twine AliasState::getAttributeAlias(Attribute attr) const {
auto alias = attrToAlias.find(attr);
if (alias == attrToAlias.end())
return Twine();
// Return the alias for this attribute, along with the index if this was
// generated by a kind alias.
int kindIndex = alias->second.second;
return alias->second.first +
(kindIndex == NonAttrKindAlias ? Twine() : Twine(kindIndex));
}
/// Print all of the referenced attribute aliases.
void AliasState::printAttributeAliases(raw_ostream &os) const {
auto printAlias = [&](StringRef alias, Attribute attr, int index) {
os << '#' << alias;
if (index != NonAttrKindAlias)
os << index;
os << " = " << attr << '\n';
};
// Print all of the attribute kind aliases.
for (auto &kindAlias : attrKindToAlias) {
auto &aliasAttrsPair = kindAlias.second;
for (unsigned i = 0, e = aliasAttrsPair.second.size(); i != e; ++i)
printAlias(aliasAttrsPair.first, aliasAttrsPair.second[i], i);
os << "\n";
}
// In a second pass print all of the remaining attribute aliases that aren't
// kind aliases.
for (Attribute attr : usedAttributes) {
auto alias = attrToAlias.find(attr);
if (alias != attrToAlias.end() && alias->second.second == NonAttrKindAlias)
printAlias(alias->second.first, attr, alias->second.second);
}
}
/// Return a string to use as an alias for the given type, or empty if there
/// is no alias recorded.
StringRef AliasState::getTypeAlias(Type ty) const {
return typeToAlias.lookup(ty);
}
/// Print all of the referenced type aliases.
void AliasState::printTypeAliases(raw_ostream &os) const {
for (Type type : usedTypes) {
auto alias = typeToAlias.find(type);
if (alias != typeToAlias.end())
os << '!' << alias->second << " = type " << type << '\n';
}
}
/// Record a reference to the given attribute.
void AliasState::recordAttributeReference(Attribute attr) {
// Don't recheck attributes that have already been seen or those that
// already have an alias.
if (!usedAttributes.insert(attr) || attrToAlias.count(attr))
return;
// If this attribute kind has an alias, then record one for this attribute.
auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind()));
if (alias == attrKindToAlias.end())
return;
std::pair<StringRef, int> attrAlias(alias->second.first,
alias->second.second.size());
attrToAlias.insert({attr, attrAlias});
alias->second.second.push_back(attr);
}
/// Record a reference to the given type.
void AliasState::recordTypeReference(Type ty) { usedTypes.insert(ty); }
// TODO Support visiting other types/operations when implemented.
void AliasState::visitType(Type type) {
recordTypeReference(type);
if (auto funcType = type.dyn_cast<FunctionType>()) {
// Visit input and result types for functions.
for (auto input : funcType.getInputs())
visitType(input);
for (auto result : funcType.getResults())
visitType(result);
} else if (auto shapedType = type.dyn_cast<ShapedType>()) {
visitType(shapedType.getElementType());
// Visit affine maps in memref type.
if (auto memref = type.dyn_cast<MemRefType>())
for (auto map : memref.getAffineMaps())
recordAttributeReference(AffineMapAttr::get(map));
}
}
void AliasState::visitAttribute(Attribute attr) {
recordAttributeReference(attr);
if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
for (auto elt : arrayAttr.getValue())
visitAttribute(elt);
} else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
visitType(typeAttr.getValue());
}
}
void AliasState::visitOperation(Operation *op) {
// Visit all the types used in the operation.
for (auto type : op->getOperandTypes())
visitType(type);
for (auto type : op->getResultTypes())
visitType(type);
for (auto &region : op->getRegions())
for (auto &block : region)
for (auto arg : block.getArguments())
visitType(arg->getType());
// Visit each of the attributes.
for (auto elt : op->getAttrs())
visitAttribute(elt.second);
}
//===----------------------------------------------------------------------===//
// ModuleState
//===----------------------------------------------------------------------===//
namespace {
class ModuleState {
public:
explicit ModuleState(MLIRContext *context) : interfaces(context) {}
/// Initialize the alias state to enable the printing of aliases.
void initializeAliases(Operation *op) {
aliasState.initialize(op, interfaces);
}
/// Get an instance of the OpAsmDialectInterface for the given dialect, or
/// null if one wasn't registered.
const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) {
return interfaces.getInterfaceFor(dialect);
}
/// Get the state used for aliases.
AliasState &getAliasState() { return aliasState; }
private:
/// Collection of OpAsm interfaces implemented in the context.
DialectInterfaceCollection<OpAsmDialectInterface> interfaces;
/// The state used for attribute and type aliases.
AliasState aliasState;
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// ModulePrinter
//===----------------------------------------------------------------------===//
@ -745,7 +793,7 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
// Check for an alias for this attribute.
if (state) {
Twine alias = state->getAttributeAlias(attr);
Twine alias = state->getAliasState().getAttributeAlias(attr);
if (!alias.isTriviallyEmpty()) {
os << '#' << alias;
return;
@ -975,7 +1023,7 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
void ModulePrinter::printType(Type type) {
// Check for an alias for this type.
if (state) {
StringRef alias = state->getTypeAlias(type);
StringRef alias = state->getAliasState().getTypeAlias(type);
if (!alias.empty()) {
os << '!' << alias;
return;
@ -1997,8 +2045,8 @@ void OperationPrinter::printSuccessorAndUseList(Operation *term,
void ModulePrinter::print(ModuleOp module) {
// Output the aliases at the top level.
if (state) {
state->printAttributeAliases(os);
state->printTypeAliases(os);
state->getAliasState().printAttributeAliases(os);
state->getAliasState().printTypeAliases(os);
}
// Print the module.
@ -2136,9 +2184,9 @@ void Block::printAsOperand(raw_ostream &os, bool printType) {
void ModuleOp::print(raw_ostream &os, OpPrintingFlags flags) {
ModuleState state(getContext());
// Skip initializing in local scope to avoid populating aliases.
// Don't populate aliases when printing at local scope.
if (!flags.shouldUseLocalScope())
state.initialize(*this);
state.initializeAliases(*this);
ModulePrinter(os, flags, &state).print(*this);
}