Add a new interface method getAsmBlockName() on OpAsmOpInterface to control block names

This allows operations to control the block ids used by the printer in nested regions.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D115849
This commit is contained in:
Mehdi Amini 2022-02-07 21:10:18 +00:00
parent 3571bdb4f3
commit b055e6d313
5 changed files with 136 additions and 27 deletions

View File

@ -62,6 +62,36 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
),
"", "return;"
>,
InterfaceMethod<[{
Get the name to use for a given block inside a region attached to this
operation.
For example if this operation has multiple blocks:
```mlir
some.op() ({
^bb0:
...
^bb1:
...
})
```
the method will be invoked on each of the blocks allowing the op to
print:
```mlir
some.op() ({
^custom_foo_name:
...
^custom_bar_name:
...
})
```
}],
"void", "getAsmBlockNames",
(ins "::mlir::OpAsmSetBlockNameFn":$setNameFn), "", ";"
>,
StaticInterfaceMethod<[{
Return the default dialect used when printing/parsing operations in
regions nested under this operation. This allows for eliding the dialect

View File

@ -1322,6 +1322,10 @@ private:
/// operation. See 'getAsmResultNames' below for more details.
using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
/// A functor used to set the name of blocks in regions directly nested under
/// an operation.
using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>;
class OpAsmDialectInterface
: public DialectInterface::Base<OpAsmDialectInterface> {
public:

View File

@ -791,6 +791,13 @@ void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine,
//===----------------------------------------------------------------------===//
namespace {
/// Info about block printing: a number which is its position in the visitation
/// order, and a name that is used to print reference to it, e.g. ^bb42.
struct BlockInfo {
int ordering;
StringRef name;
};
/// This class manages the state of SSA value names.
class SSANameState {
public:
@ -808,8 +815,8 @@ public:
/// operation, or empty if none exist.
ArrayRef<int> getOpResultGroups(Operation *op);
/// Get the ID for the given block.
unsigned getBlockID(Block *block);
/// Get the info for the given block.
BlockInfo getBlockInfo(Block *block);
/// Renumber the arguments for the specified region to the same names as the
/// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
@ -846,8 +853,9 @@ private:
/// value of this map are the result numbers that start a result group.
DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;
/// This is the block ID for each block in the current.
DenseMap<Block *, unsigned> blockIDs;
/// This maps blocks to there visitation number in the current region as well
/// as the string representing their name.
DenseMap<Block *, BlockInfo> blockNames;
/// This keeps track of all of the non-numeric names that are in flight,
/// allowing us to check for duplicates.
@ -967,9 +975,10 @@ ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
}
unsigned SSANameState::getBlockID(Block *block) {
auto it = blockIDs.find(block);
return it != blockIDs.end() ? it->second : NameSentinel;
BlockInfo SSANameState::getBlockInfo(Block *block) {
auto it = blockNames.find(block);
BlockInfo invalidBlock{-1, "INVALIDBLOCK"};
return it != blockNames.end() ? it->second : invalidBlock;
}
void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
@ -1021,7 +1030,16 @@ void SSANameState::numberValuesInRegion(Region &region) {
for (auto &block : region) {
// Each block gets a unique ID, and all of the operations within it get
// numbered as well.
blockIDs[&block] = nextBlockID++;
auto blockInfoIt = blockNames.insert({&block, {-1, ""}});
if (blockInfoIt.second) {
// This block hasn't been named through `getAsmBlockArgumentNames`, use
// default `^bbNNN` format.
std::string name;
llvm::raw_string_ostream(name) << "^bb" << nextBlockID;
blockInfoIt.first->second.name = StringRef(name).copy(usedNameAllocator);
}
blockInfoIt.first->second.ordering = nextBlockID++;
numberValuesInBlock(block);
}
}
@ -1048,11 +1066,6 @@ void SSANameState::numberValuesInBlock(Block &block) {
}
void SSANameState::numberValuesInOp(Operation &op) {
unsigned numResults = op.getNumResults();
if (numResults == 0)
return;
Value resultBegin = op.getResult(0);
// Function used to set the special result names for the operation.
SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
auto setResultNameFn = [&](Value result, StringRef name) {
@ -1064,11 +1077,34 @@ void SSANameState::numberValuesInOp(Operation &op) {
if (int resultNo = result.cast<OpResult>().getResultNumber())
resultGroups.push_back(resultNo);
};
// Operations can customize the printing of block names in OpAsmOpInterface.
auto setBlockNameFn = [&](Block *block, StringRef name) {
assert(block->getParentOp() == &op &&
"getAsmBlockArgumentNames callback invoked on a block not directly "
"nested under the current operation");
assert(!blockNames.count(block) && "block numbered multiple times");
SmallString<16> tmpBuffer{"^"};
name = sanitizeIdentifier(name, tmpBuffer);
if (name.data() != tmpBuffer.data()) {
tmpBuffer.append(name);
name = tmpBuffer.str();
}
name = name.copy(usedNameAllocator);
blockNames[block] = {-1, name};
};
if (!printerFlags.shouldPrintGenericOpForm()) {
if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op))
if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
asmInterface.getAsmBlockNames(setBlockNameFn);
asmInterface.getAsmResultNames(setResultNameFn);
}
}
unsigned numResults = op.getNumResults();
if (numResults == 0)
return;
Value resultBegin = op.getResult(0);
// If the first result wasn't numbered, give it a default number.
if (valueIDs.try_emplace(resultBegin, nextValueID).second)
++nextValueID;
@ -2609,11 +2645,7 @@ void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
}
void OperationPrinter::printBlockName(Block *block) {
auto id = state->getSSANameState().getBlockID(block);
if (id != SSANameState::NameSentinel)
os << "^bb" << id;
else
os << "^INVALIDBLOCK";
os << state->getSSANameState().getBlockInfo(block).name;
}
void OperationPrinter::print(Block *block, bool printBlockArgs,
@ -2647,18 +2679,18 @@ void OperationPrinter::print(Block *block, bool printBlockArgs,
os << " // pred: ";
printBlockName(pred);
} else {
// We want to print the predecessors in increasing numeric order, not in
// We want to print the predecessors in a stable order, not in
// whatever order the use-list is in, so gather and sort them.
SmallVector<std::pair<unsigned, Block *>, 4> predIDs;
SmallVector<BlockInfo, 4> predIDs;
for (auto *pred : block->getPredecessors())
predIDs.push_back({state->getSSANameState().getBlockID(pred), pred});
llvm::array_pod_sort(predIDs.begin(), predIDs.end());
predIDs.push_back(state->getSSANameState().getBlockInfo(pred));
llvm::sort(predIDs, [](BlockInfo lhs, BlockInfo rhs) {
return lhs.ordering < rhs.ordering;
});
os << " // " << predIDs.size() << " preds: ";
interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) {
printBlockName(pred.second);
});
interleaveComma(predIDs, [&](BlockInfo pred) { os << pred.name; });
}
os << newLine;
}

View File

@ -36,7 +36,6 @@ func @pretty_printed_region_op(%arg0 : f32, %arg1 : f32) -> (f32) {
// -----
func @pretty_printed_region_op_deferred_loc(%arg0 : f32, %arg1 : f32) -> (f32) {
// CHECK-LOCATION: "test.pretty_printed_region"(%arg1, %arg0)
// CHECK-LOCATION: ^bb0(%arg[[x:[0-9]+]]: f32 loc("foo"), %arg[[y:[0-9]+]]: f32 loc("foo")
@ -47,3 +46,29 @@ func @pretty_printed_region_op_deferred_loc(%arg0 : f32, %arg1 : f32) -> (f32) {
%res = test.pretty_printed_region %arg1, %arg0 start special.op end : (f32, f32) -> (f32) loc("foo")
return %res : f32
}
// -----
// This tests the behavior of custom block names:
// operations like `test.block_names` can define custom names for blocks in
// nested regions.
// CHECK-CUSTOM-LABEL: func @block_names
func @block_names(%bool : i1) {
// CHECK: test.block_names
test.block_names {
// CHECK-CUSTOM: br ^foo1
// CHECK-GENERIC: cf.br{{.*}}^bb1
cf.br ^foo1
// CHECK-CUSTOM: ^foo1:
// CHECK-GENERIC: ^bb1:
^foo1:
// CHECK-CUSTOM: br ^foo2
// CHECK-GENERIC: cf.br{{.*}}^bb2
cf.br ^foo2
// CHECK-CUSTOM: ^foo2:
// CHECK-GENERIC: ^bb2:
^foo2:
"test.return"() : () -> ()
}
return
}

View File

@ -660,6 +660,24 @@ def DefaultDialectOp : TEST_Op<"default_dialect", [OpAsmOpInterface]> {
let assemblyFormat = "regions attr-dict-with-keyword";
}
// This is used to test the OpAsmOpInterface::getAsmBlockName() feature:
// blocks nested in a region under this op will have a name defined by the
// interface.
def AsmBlockNameOp : TEST_Op<"block_names", [OpAsmOpInterface]> {
let regions = (region AnyRegion:$body);
let extraClassDeclaration = [{
void getAsmBlockNames(mlir::OpAsmSetBlockNameFn setNameFn) {
std::string name;
int count = 0;
for (::mlir::Block &block : getRegion().getBlocks()) {
name = "foo" + std::to_string(count++);
setNameFn(&block, name);
}
}
}];
let assemblyFormat = "regions attr-dict-with-keyword";
}
// This operation requires its return type to have the trait 'TestTypeTrait'.
def ResultTypeWithTraitOp : TEST_Op<"result_type_with_trait", []> {
let results = (outs AnyType);