mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-12-14 03:29:57 +00:00
Add a new interface method getAsmBlockName()
on OpAsmOpInterface to control block names
This allows operations to control the block ids used by the printer in nested regions. Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D115849
This commit is contained in:
parent
3571bdb4f3
commit
b055e6d313
@ -62,6 +62,36 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
|
||||
),
|
||||
"", "return;"
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
Get the name to use for a given block inside a region attached to this
|
||||
operation.
|
||||
|
||||
For example if this operation has multiple blocks:
|
||||
|
||||
```mlir
|
||||
some.op() ({
|
||||
^bb0:
|
||||
...
|
||||
^bb1:
|
||||
...
|
||||
})
|
||||
```
|
||||
|
||||
the method will be invoked on each of the blocks allowing the op to
|
||||
print:
|
||||
|
||||
```mlir
|
||||
some.op() ({
|
||||
^custom_foo_name:
|
||||
...
|
||||
^custom_bar_name:
|
||||
...
|
||||
})
|
||||
```
|
||||
}],
|
||||
"void", "getAsmBlockNames",
|
||||
(ins "::mlir::OpAsmSetBlockNameFn":$setNameFn), "", ";"
|
||||
>,
|
||||
StaticInterfaceMethod<[{
|
||||
Return the default dialect used when printing/parsing operations in
|
||||
regions nested under this operation. This allows for eliding the dialect
|
||||
|
@ -1322,6 +1322,10 @@ private:
|
||||
/// operation. See 'getAsmResultNames' below for more details.
|
||||
using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
|
||||
|
||||
/// A functor used to set the name of blocks in regions directly nested under
|
||||
/// an operation.
|
||||
using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>;
|
||||
|
||||
class OpAsmDialectInterface
|
||||
: public DialectInterface::Base<OpAsmDialectInterface> {
|
||||
public:
|
||||
|
@ -791,6 +791,13 @@ void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine,
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// Info about block printing: a number which is its position in the visitation
|
||||
/// order, and a name that is used to print reference to it, e.g. ^bb42.
|
||||
struct BlockInfo {
|
||||
int ordering;
|
||||
StringRef name;
|
||||
};
|
||||
|
||||
/// This class manages the state of SSA value names.
|
||||
class SSANameState {
|
||||
public:
|
||||
@ -808,8 +815,8 @@ public:
|
||||
/// operation, or empty if none exist.
|
||||
ArrayRef<int> getOpResultGroups(Operation *op);
|
||||
|
||||
/// Get the ID for the given block.
|
||||
unsigned getBlockID(Block *block);
|
||||
/// Get the info for the given block.
|
||||
BlockInfo getBlockInfo(Block *block);
|
||||
|
||||
/// Renumber the arguments for the specified region to the same names as the
|
||||
/// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
|
||||
@ -846,8 +853,9 @@ private:
|
||||
/// value of this map are the result numbers that start a result group.
|
||||
DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;
|
||||
|
||||
/// This is the block ID for each block in the current.
|
||||
DenseMap<Block *, unsigned> blockIDs;
|
||||
/// This maps blocks to there visitation number in the current region as well
|
||||
/// as the string representing their name.
|
||||
DenseMap<Block *, BlockInfo> blockNames;
|
||||
|
||||
/// This keeps track of all of the non-numeric names that are in flight,
|
||||
/// allowing us to check for duplicates.
|
||||
@ -967,9 +975,10 @@ ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
|
||||
return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
|
||||
}
|
||||
|
||||
unsigned SSANameState::getBlockID(Block *block) {
|
||||
auto it = blockIDs.find(block);
|
||||
return it != blockIDs.end() ? it->second : NameSentinel;
|
||||
BlockInfo SSANameState::getBlockInfo(Block *block) {
|
||||
auto it = blockNames.find(block);
|
||||
BlockInfo invalidBlock{-1, "INVALIDBLOCK"};
|
||||
return it != blockNames.end() ? it->second : invalidBlock;
|
||||
}
|
||||
|
||||
void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) {
|
||||
@ -1021,7 +1030,16 @@ void SSANameState::numberValuesInRegion(Region ®ion) {
|
||||
for (auto &block : region) {
|
||||
// Each block gets a unique ID, and all of the operations within it get
|
||||
// numbered as well.
|
||||
blockIDs[&block] = nextBlockID++;
|
||||
auto blockInfoIt = blockNames.insert({&block, {-1, ""}});
|
||||
if (blockInfoIt.second) {
|
||||
// This block hasn't been named through `getAsmBlockArgumentNames`, use
|
||||
// default `^bbNNN` format.
|
||||
std::string name;
|
||||
llvm::raw_string_ostream(name) << "^bb" << nextBlockID;
|
||||
blockInfoIt.first->second.name = StringRef(name).copy(usedNameAllocator);
|
||||
}
|
||||
blockInfoIt.first->second.ordering = nextBlockID++;
|
||||
|
||||
numberValuesInBlock(block);
|
||||
}
|
||||
}
|
||||
@ -1048,11 +1066,6 @@ void SSANameState::numberValuesInBlock(Block &block) {
|
||||
}
|
||||
|
||||
void SSANameState::numberValuesInOp(Operation &op) {
|
||||
unsigned numResults = op.getNumResults();
|
||||
if (numResults == 0)
|
||||
return;
|
||||
Value resultBegin = op.getResult(0);
|
||||
|
||||
// Function used to set the special result names for the operation.
|
||||
SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
|
||||
auto setResultNameFn = [&](Value result, StringRef name) {
|
||||
@ -1064,11 +1077,34 @@ void SSANameState::numberValuesInOp(Operation &op) {
|
||||
if (int resultNo = result.cast<OpResult>().getResultNumber())
|
||||
resultGroups.push_back(resultNo);
|
||||
};
|
||||
// Operations can customize the printing of block names in OpAsmOpInterface.
|
||||
auto setBlockNameFn = [&](Block *block, StringRef name) {
|
||||
assert(block->getParentOp() == &op &&
|
||||
"getAsmBlockArgumentNames callback invoked on a block not directly "
|
||||
"nested under the current operation");
|
||||
assert(!blockNames.count(block) && "block numbered multiple times");
|
||||
SmallString<16> tmpBuffer{"^"};
|
||||
name = sanitizeIdentifier(name, tmpBuffer);
|
||||
if (name.data() != tmpBuffer.data()) {
|
||||
tmpBuffer.append(name);
|
||||
name = tmpBuffer.str();
|
||||
}
|
||||
name = name.copy(usedNameAllocator);
|
||||
blockNames[block] = {-1, name};
|
||||
};
|
||||
|
||||
if (!printerFlags.shouldPrintGenericOpForm()) {
|
||||
if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op))
|
||||
if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
|
||||
asmInterface.getAsmBlockNames(setBlockNameFn);
|
||||
asmInterface.getAsmResultNames(setResultNameFn);
|
||||
}
|
||||
}
|
||||
|
||||
unsigned numResults = op.getNumResults();
|
||||
if (numResults == 0)
|
||||
return;
|
||||
Value resultBegin = op.getResult(0);
|
||||
|
||||
// If the first result wasn't numbered, give it a default number.
|
||||
if (valueIDs.try_emplace(resultBegin, nextValueID).second)
|
||||
++nextValueID;
|
||||
@ -2609,11 +2645,7 @@ void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
|
||||
}
|
||||
|
||||
void OperationPrinter::printBlockName(Block *block) {
|
||||
auto id = state->getSSANameState().getBlockID(block);
|
||||
if (id != SSANameState::NameSentinel)
|
||||
os << "^bb" << id;
|
||||
else
|
||||
os << "^INVALIDBLOCK";
|
||||
os << state->getSSANameState().getBlockInfo(block).name;
|
||||
}
|
||||
|
||||
void OperationPrinter::print(Block *block, bool printBlockArgs,
|
||||
@ -2647,18 +2679,18 @@ void OperationPrinter::print(Block *block, bool printBlockArgs,
|
||||
os << " // pred: ";
|
||||
printBlockName(pred);
|
||||
} else {
|
||||
// We want to print the predecessors in increasing numeric order, not in
|
||||
// We want to print the predecessors in a stable order, not in
|
||||
// whatever order the use-list is in, so gather and sort them.
|
||||
SmallVector<std::pair<unsigned, Block *>, 4> predIDs;
|
||||
SmallVector<BlockInfo, 4> predIDs;
|
||||
for (auto *pred : block->getPredecessors())
|
||||
predIDs.push_back({state->getSSANameState().getBlockID(pred), pred});
|
||||
llvm::array_pod_sort(predIDs.begin(), predIDs.end());
|
||||
predIDs.push_back(state->getSSANameState().getBlockInfo(pred));
|
||||
llvm::sort(predIDs, [](BlockInfo lhs, BlockInfo rhs) {
|
||||
return lhs.ordering < rhs.ordering;
|
||||
});
|
||||
|
||||
os << " // " << predIDs.size() << " preds: ";
|
||||
|
||||
interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) {
|
||||
printBlockName(pred.second);
|
||||
});
|
||||
interleaveComma(predIDs, [&](BlockInfo pred) { os << pred.name; });
|
||||
}
|
||||
os << newLine;
|
||||
}
|
||||
|
@ -36,7 +36,6 @@ func @pretty_printed_region_op(%arg0 : f32, %arg1 : f32) -> (f32) {
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
func @pretty_printed_region_op_deferred_loc(%arg0 : f32, %arg1 : f32) -> (f32) {
|
||||
// CHECK-LOCATION: "test.pretty_printed_region"(%arg1, %arg0)
|
||||
// CHECK-LOCATION: ^bb0(%arg[[x:[0-9]+]]: f32 loc("foo"), %arg[[y:[0-9]+]]: f32 loc("foo")
|
||||
@ -47,3 +46,29 @@ func @pretty_printed_region_op_deferred_loc(%arg0 : f32, %arg1 : f32) -> (f32) {
|
||||
%res = test.pretty_printed_region %arg1, %arg0 start special.op end : (f32, f32) -> (f32) loc("foo")
|
||||
return %res : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// This tests the behavior of custom block names:
|
||||
// operations like `test.block_names` can define custom names for blocks in
|
||||
// nested regions.
|
||||
// CHECK-CUSTOM-LABEL: func @block_names
|
||||
func @block_names(%bool : i1) {
|
||||
// CHECK: test.block_names
|
||||
test.block_names {
|
||||
// CHECK-CUSTOM: br ^foo1
|
||||
// CHECK-GENERIC: cf.br{{.*}}^bb1
|
||||
cf.br ^foo1
|
||||
// CHECK-CUSTOM: ^foo1:
|
||||
// CHECK-GENERIC: ^bb1:
|
||||
^foo1:
|
||||
// CHECK-CUSTOM: br ^foo2
|
||||
// CHECK-GENERIC: cf.br{{.*}}^bb2
|
||||
cf.br ^foo2
|
||||
// CHECK-CUSTOM: ^foo2:
|
||||
// CHECK-GENERIC: ^bb2:
|
||||
^foo2:
|
||||
"test.return"() : () -> ()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -660,6 +660,24 @@ def DefaultDialectOp : TEST_Op<"default_dialect", [OpAsmOpInterface]> {
|
||||
let assemblyFormat = "regions attr-dict-with-keyword";
|
||||
}
|
||||
|
||||
// This is used to test the OpAsmOpInterface::getAsmBlockName() feature:
|
||||
// blocks nested in a region under this op will have a name defined by the
|
||||
// interface.
|
||||
def AsmBlockNameOp : TEST_Op<"block_names", [OpAsmOpInterface]> {
|
||||
let regions = (region AnyRegion:$body);
|
||||
let extraClassDeclaration = [{
|
||||
void getAsmBlockNames(mlir::OpAsmSetBlockNameFn setNameFn) {
|
||||
std::string name;
|
||||
int count = 0;
|
||||
for (::mlir::Block &block : getRegion().getBlocks()) {
|
||||
name = "foo" + std::to_string(count++);
|
||||
setNameFn(&block, name);
|
||||
}
|
||||
}
|
||||
}];
|
||||
let assemblyFormat = "regions attr-dict-with-keyword";
|
||||
}
|
||||
|
||||
// This operation requires its return type to have the trait 'TestTypeTrait'.
|
||||
def ResultTypeWithTraitOp : TEST_Op<"result_type_with_trait", []> {
|
||||
let results = (outs AnyType);
|
||||
|
Loading…
Reference in New Issue
Block a user