[mlir] Expose printer flags in AsmState

This change exposes printer flags in AsmState and AsmStateImpl. All functions
receiving AsmState as a parameter now use the flags from the AsmState instead of
taking an additional OpPrintingFlags parameter.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D119870
This commit is contained in:
Sergei Grechanik 2022-02-15 17:09:08 -08:00
parent d2a0ef9844
commit 988a3ba0d8
5 changed files with 24 additions and 17 deletions

View File

@ -47,6 +47,9 @@ public:
LocationMap *locationMap = nullptr);
~AsmState();
/// Get the printer flags.
const OpPrintingFlags &getPrinterFlags() const;
/// Return an instance of the internal implementation. Returns nullptr if the
/// state has not been initialized.
detail::AsmStateImpl &getImpl() { return *impl; }

View File

@ -112,9 +112,8 @@ public:
void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) {
state->print(os, flags);
}
void print(raw_ostream &os, AsmState &asmState,
OpPrintingFlags flags = llvm::None) {
state->print(os, asmState, flags);
void print(raw_ostream &os, AsmState &asmState) {
state->print(os, asmState);
}
/// Dump this operation.

View File

@ -192,8 +192,7 @@ public:
bool isBeforeInBlock(Operation *other);
void print(raw_ostream &os, const OpPrintingFlags &flags = llvm::None);
void print(raw_ostream &os, AsmState &state,
const OpPrintingFlags &flags = llvm::None);
void print(raw_ostream &os, AsmState &state);
void dump();
//===--------------------------------------------------------------------===//

View File

@ -1216,6 +1216,9 @@ public:
/// Get the state used for SSA names.
SSANameState &getSSANameState() { return nameState; }
/// Get the printer flags.
const OpPrintingFlags &getPrinterFlags() const { return printerFlags; }
/// Register the location, line and column, within the buffer that the given
/// operation was printed at.
void registerOperationLocation(Operation *op, unsigned line, unsigned col) {
@ -1247,6 +1250,10 @@ AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
: impl(std::make_unique<AsmStateImpl>(op, printerFlags, locationMap)) {}
AsmState::~AsmState() = default;
const OpPrintingFlags &AsmState::getPrinterFlags() const {
return impl->getPrinterFlags();
}
//===----------------------------------------------------------------------===//
// AsmPrinter::Impl
//===----------------------------------------------------------------------===//
@ -2405,9 +2412,9 @@ public:
using Impl = AsmPrinter::Impl;
using Impl::printType;
explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags,
AsmStateImpl &state)
: Impl(os, flags, &state), OpAsmPrinter(static_cast<Impl &>(*this)) {}
explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state)
: Impl(os, state.getPrinterFlags(), &state),
OpAsmPrinter(static_cast<Impl &>(*this)) {}
/// Print the given top-level operation.
void printTopLevelOperation(Operation *op);
@ -2893,7 +2900,7 @@ void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
if (!getParent() && !printerFlags.shouldUseLocalScope()) {
AsmState state(this, printerFlags);
state.getImpl().initializeAliases(this);
print(os, state, printerFlags);
print(os, state);
return;
}
@ -2914,12 +2921,11 @@ void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
} while (true);
AsmState state(op, printerFlags);
print(os, state, printerFlags);
print(os, state);
}
void Operation::print(raw_ostream &os, AsmState &state,
const OpPrintingFlags &flags) {
OperationPrinter printer(os, flags, state.getImpl());
if (!getParent() && !flags.shouldUseLocalScope())
void Operation::print(raw_ostream &os, AsmState &state) {
OperationPrinter printer(os, state.getImpl());
if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope())
printer.printTopLevelOperation(this);
else
printer.print(this);
@ -2944,7 +2950,7 @@ void Block::print(raw_ostream &os) {
print(os, state);
}
void Block::print(raw_ostream &os, AsmState &state) {
OperationPrinter(os, /*flags=*/llvm::None, state.getImpl()).print(this);
OperationPrinter(os, state.getImpl()).print(this);
}
void Block::dump() { print(llvm::errs()); }
@ -2960,6 +2966,6 @@ void Block::printAsOperand(raw_ostream &os, bool printType) {
printAsOperand(os, state);
}
void Block::printAsOperand(raw_ostream &os, AsmState &state) {
OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl());
OperationPrinter printer(os, state.getImpl());
printer.printBlockName(this);
}

View File

@ -27,7 +27,7 @@ static void generateLocationsFromIR(raw_ostream &os, StringRef fileName,
// Print the IR to the stream, and collect the raw line+column information.
AsmState::LocationMap opToLineCol;
AsmState state(op, flags, &opToLineCol);
op->print(os, state, flags);
op->print(os, state);
Builder builder(op->getContext());
Optional<StringAttr> tagIdentifier;