mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-12-15 20:12:50 +00:00
[mlir][python] Allow specifying block arg locations
Currently blocks are always created with UnknownLoc's for their arguments. This adds an `arg_locs` argument to all block creation APIs, which takes an optional sequence of locations to use, one per block argument. If no locations are supplied, the current Location context is used. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D150084
This commit is contained in:
parent
07edc1c16f
commit
514dddbeba
@ -194,6 +194,31 @@ static MlirStringRef toMlirStringRef(const std::string &s) {
|
|||||||
return mlirStringRefCreate(s.data(), s.size());
|
return mlirStringRefCreate(s.data(), s.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a block, using the current location context if no locations are
|
||||||
|
/// specified.
|
||||||
|
static MlirBlock createBlock(const py::sequence &pyArgTypes,
|
||||||
|
const std::optional<py::sequence> &pyArgLocs) {
|
||||||
|
SmallVector<MlirType> argTypes;
|
||||||
|
argTypes.reserve(pyArgTypes.size());
|
||||||
|
for (const auto &pyType : pyArgTypes)
|
||||||
|
argTypes.push_back(pyType.cast<PyType &>());
|
||||||
|
|
||||||
|
SmallVector<MlirLocation> argLocs;
|
||||||
|
if (pyArgLocs) {
|
||||||
|
argLocs.reserve(pyArgLocs->size());
|
||||||
|
for (const auto &pyLoc : *pyArgLocs)
|
||||||
|
argLocs.push_back(pyLoc.cast<PyLocation &>());
|
||||||
|
} else if (!argTypes.empty()) {
|
||||||
|
argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (argTypes.size() != argLocs.size())
|
||||||
|
throw py::value_error(("Expected " + Twine(argTypes.size()) +
|
||||||
|
" locations, got: " + Twine(argLocs.size()))
|
||||||
|
.str());
|
||||||
|
return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
|
||||||
|
}
|
||||||
|
|
||||||
/// Wrapper for the global LLVM debugging flag.
|
/// Wrapper for the global LLVM debugging flag.
|
||||||
struct PyGlobalDebugFlag {
|
struct PyGlobalDebugFlag {
|
||||||
static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
|
static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
|
||||||
@ -363,21 +388,10 @@ public:
|
|||||||
throw py::index_error("attempt to access out of bounds block");
|
throw py::index_error("attempt to access out of bounds block");
|
||||||
}
|
}
|
||||||
|
|
||||||
PyBlock appendBlock(const py::args &pyArgTypes) {
|
PyBlock appendBlock(const py::args &pyArgTypes,
|
||||||
|
const std::optional<py::sequence> &pyArgLocs) {
|
||||||
operation->checkValid();
|
operation->checkValid();
|
||||||
llvm::SmallVector<MlirType, 4> argTypes;
|
MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
|
||||||
llvm::SmallVector<MlirLocation, 4> argLocs;
|
|
||||||
argTypes.reserve(pyArgTypes.size());
|
|
||||||
argLocs.reserve(pyArgTypes.size());
|
|
||||||
for (auto &pyArg : pyArgTypes) {
|
|
||||||
argTypes.push_back(pyArg.cast<PyType &>());
|
|
||||||
// TODO: Pass in a proper location here.
|
|
||||||
argLocs.push_back(
|
|
||||||
mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
|
|
||||||
}
|
|
||||||
|
|
||||||
MlirBlock block =
|
|
||||||
mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
|
|
||||||
mlirRegionAppendOwnedBlock(region, block);
|
mlirRegionAppendOwnedBlock(region, block);
|
||||||
return PyBlock(operation, block);
|
return PyBlock(operation, block);
|
||||||
}
|
}
|
||||||
@ -387,7 +401,8 @@ public:
|
|||||||
.def("__getitem__", &PyBlockList::dunderGetItem)
|
.def("__getitem__", &PyBlockList::dunderGetItem)
|
||||||
.def("__iter__", &PyBlockList::dunderIter)
|
.def("__iter__", &PyBlockList::dunderIter)
|
||||||
.def("__len__", &PyBlockList::dunderLen)
|
.def("__len__", &PyBlockList::dunderLen)
|
||||||
.def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
|
.def("append", &PyBlockList::appendBlock, kAppendBlockDocstring,
|
||||||
|
py::arg("arg_locs") = std::nullopt);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -2978,27 +2993,17 @@ void mlir::python::populateIRCore(py::module &m) {
|
|||||||
"Returns a forward-optimized sequence of operations.")
|
"Returns a forward-optimized sequence of operations.")
|
||||||
.def_static(
|
.def_static(
|
||||||
"create_at_start",
|
"create_at_start",
|
||||||
[](PyRegion &parent, py::list pyArgTypes) {
|
[](PyRegion &parent, const py::list &pyArgTypes,
|
||||||
|
const std::optional<py::sequence> &pyArgLocs) {
|
||||||
parent.checkValid();
|
parent.checkValid();
|
||||||
llvm::SmallVector<MlirType, 4> argTypes;
|
MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
|
||||||
llvm::SmallVector<MlirLocation, 4> argLocs;
|
|
||||||
argTypes.reserve(pyArgTypes.size());
|
|
||||||
argLocs.reserve(pyArgTypes.size());
|
|
||||||
for (auto &pyArg : pyArgTypes) {
|
|
||||||
argTypes.push_back(pyArg.cast<PyType &>());
|
|
||||||
// TODO: Pass in a proper location here.
|
|
||||||
argLocs.push_back(
|
|
||||||
mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
|
|
||||||
}
|
|
||||||
|
|
||||||
MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
|
|
||||||
argLocs.data());
|
|
||||||
mlirRegionInsertOwnedBlock(parent, 0, block);
|
mlirRegionInsertOwnedBlock(parent, 0, block);
|
||||||
return PyBlock(parent.getParentOperation(), block);
|
return PyBlock(parent.getParentOperation(), block);
|
||||||
},
|
},
|
||||||
py::arg("parent"), py::arg("arg_types") = py::list(),
|
py::arg("parent"), py::arg("arg_types") = py::list(),
|
||||||
|
py::arg("arg_locs") = std::nullopt,
|
||||||
"Creates and returns a new Block at the beginning of the given "
|
"Creates and returns a new Block at the beginning of the given "
|
||||||
"region (with given argument types).")
|
"region (with given argument types and locations).")
|
||||||
.def(
|
.def(
|
||||||
"append_to",
|
"append_to",
|
||||||
[](PyBlock &self, PyRegion ®ion) {
|
[](PyBlock &self, PyRegion ®ion) {
|
||||||
@ -3010,50 +3015,30 @@ void mlir::python::populateIRCore(py::module &m) {
|
|||||||
"Append this block to a region, transferring ownership if necessary")
|
"Append this block to a region, transferring ownership if necessary")
|
||||||
.def(
|
.def(
|
||||||
"create_before",
|
"create_before",
|
||||||
[](PyBlock &self, py::args pyArgTypes) {
|
[](PyBlock &self, const py::args &pyArgTypes,
|
||||||
|
const std::optional<py::sequence> &pyArgLocs) {
|
||||||
self.checkValid();
|
self.checkValid();
|
||||||
llvm::SmallVector<MlirType, 4> argTypes;
|
MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
|
||||||
llvm::SmallVector<MlirLocation, 4> argLocs;
|
|
||||||
argTypes.reserve(pyArgTypes.size());
|
|
||||||
argLocs.reserve(pyArgTypes.size());
|
|
||||||
for (auto &pyArg : pyArgTypes) {
|
|
||||||
argTypes.push_back(pyArg.cast<PyType &>());
|
|
||||||
// TODO: Pass in a proper location here.
|
|
||||||
argLocs.push_back(
|
|
||||||
mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
|
|
||||||
}
|
|
||||||
|
|
||||||
MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
|
|
||||||
argLocs.data());
|
|
||||||
MlirRegion region = mlirBlockGetParentRegion(self.get());
|
MlirRegion region = mlirBlockGetParentRegion(self.get());
|
||||||
mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
|
mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
|
||||||
return PyBlock(self.getParentOperation(), block);
|
return PyBlock(self.getParentOperation(), block);
|
||||||
},
|
},
|
||||||
|
py::arg("arg_locs") = std::nullopt,
|
||||||
"Creates and returns a new Block before this block "
|
"Creates and returns a new Block before this block "
|
||||||
"(with given argument types).")
|
"(with given argument types and locations).")
|
||||||
.def(
|
.def(
|
||||||
"create_after",
|
"create_after",
|
||||||
[](PyBlock &self, py::args pyArgTypes) {
|
[](PyBlock &self, const py::args &pyArgTypes,
|
||||||
|
const std::optional<py::sequence> &pyArgLocs) {
|
||||||
self.checkValid();
|
self.checkValid();
|
||||||
llvm::SmallVector<MlirType, 4> argTypes;
|
MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
|
||||||
llvm::SmallVector<MlirLocation, 4> argLocs;
|
|
||||||
argTypes.reserve(pyArgTypes.size());
|
|
||||||
argLocs.reserve(pyArgTypes.size());
|
|
||||||
for (auto &pyArg : pyArgTypes) {
|
|
||||||
argTypes.push_back(pyArg.cast<PyType &>());
|
|
||||||
|
|
||||||
// TODO: Pass in a proper location here.
|
|
||||||
argLocs.push_back(
|
|
||||||
mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
|
|
||||||
}
|
|
||||||
MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
|
|
||||||
argLocs.data());
|
|
||||||
MlirRegion region = mlirBlockGetParentRegion(self.get());
|
MlirRegion region = mlirBlockGetParentRegion(self.get());
|
||||||
mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
|
mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
|
||||||
return PyBlock(self.getParentOperation(), block);
|
return PyBlock(self.getParentOperation(), block);
|
||||||
},
|
},
|
||||||
|
py::arg("arg_locs") = std::nullopt,
|
||||||
"Creates and returns a new Block after this block "
|
"Creates and returns a new Block after this block "
|
||||||
"(with given argument types).")
|
"(with given argument types and locations).")
|
||||||
.def(
|
.def(
|
||||||
"__iter__",
|
"__iter__",
|
||||||
[](PyBlock &self) {
|
[](PyBlock &self) {
|
||||||
|
@ -90,7 +90,7 @@ class FuncOp:
|
|||||||
raise IndexError('External function does not have a body')
|
raise IndexError('External function does not have a body')
|
||||||
return self.regions[0].blocks[0]
|
return self.regions[0].blocks[0]
|
||||||
|
|
||||||
def add_entry_block(self):
|
def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
|
||||||
"""
|
"""
|
||||||
Add an entry block to the function body using the function signature to
|
Add an entry block to the function body using the function signature to
|
||||||
infer block arguments.
|
infer block arguments.
|
||||||
@ -98,7 +98,7 @@ class FuncOp:
|
|||||||
"""
|
"""
|
||||||
if not self.is_external:
|
if not self.is_external:
|
||||||
raise IndexError('The function already has an entry block!')
|
raise IndexError('The function already has an entry block!')
|
||||||
self.body.blocks.append(*self.type.inputs)
|
self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
|
||||||
return self.body.blocks[0]
|
return self.body.blocks[0]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -18,28 +18,28 @@ def run(f):
|
|||||||
|
|
||||||
|
|
||||||
# CHECK-LABEL: TEST: testBlockCreation
|
# CHECK-LABEL: TEST: testBlockCreation
|
||||||
# CHECK: func @test(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16)
|
# CHECK: func @test(%[[ARG0:.*]]: i32 loc("arg0"), %[[ARG1:.*]]: i16 loc("arg1"))
|
||||||
# CHECK: cf.br ^bb1(%[[ARG1]] : i16)
|
# CHECK: cf.br ^bb1(%[[ARG1]] : i16)
|
||||||
# CHECK: ^bb1(%[[PHI0:.*]]: i16):
|
# CHECK: ^bb1(%[[PHI0:.*]]: i16 loc("middle")):
|
||||||
# CHECK: cf.br ^bb2(%[[ARG0]] : i32)
|
# CHECK: cf.br ^bb2(%[[ARG0]] : i32)
|
||||||
# CHECK: ^bb2(%[[PHI1:.*]]: i32):
|
# CHECK: ^bb2(%[[PHI1:.*]]: i32 loc("successor")):
|
||||||
# CHECK: return
|
# CHECK: return
|
||||||
@run
|
@run
|
||||||
def testBlockCreation():
|
def testBlockCreation():
|
||||||
with Context() as ctx, Location.unknown():
|
with Context() as ctx, Location.unknown():
|
||||||
module = Module.create()
|
module = builtin.ModuleOp()
|
||||||
with InsertionPoint(module.body):
|
with InsertionPoint(module.body):
|
||||||
f_type = FunctionType.get(
|
f_type = FunctionType.get(
|
||||||
[IntegerType.get_signless(32),
|
[IntegerType.get_signless(32),
|
||||||
IntegerType.get_signless(16)], [])
|
IntegerType.get_signless(16)], [])
|
||||||
f_op = func.FuncOp("test", f_type)
|
f_op = func.FuncOp("test", f_type)
|
||||||
entry_block = f_op.add_entry_block()
|
entry_block = f_op.add_entry_block([Location.name("arg0"), Location.name("arg1")])
|
||||||
i32_arg, i16_arg = entry_block.arguments
|
i32_arg, i16_arg = entry_block.arguments
|
||||||
successor_block = entry_block.create_after(i32_arg.type)
|
successor_block = entry_block.create_after(i32_arg.type, arg_locs=[Location.name("successor")])
|
||||||
with InsertionPoint(successor_block) as successor_ip:
|
with InsertionPoint(successor_block) as successor_ip:
|
||||||
assert successor_ip.block == successor_block
|
assert successor_ip.block == successor_block
|
||||||
func.ReturnOp([])
|
func.ReturnOp([])
|
||||||
middle_block = successor_block.create_before(i16_arg.type)
|
middle_block = successor_block.create_before(i16_arg.type, arg_locs=[Location.name("middle")])
|
||||||
|
|
||||||
with InsertionPoint(entry_block) as entry_ip:
|
with InsertionPoint(entry_block) as entry_ip:
|
||||||
assert entry_ip.block == entry_block
|
assert entry_ip.block == entry_block
|
||||||
@ -48,27 +48,57 @@ def testBlockCreation():
|
|||||||
with InsertionPoint(middle_block) as middle_ip:
|
with InsertionPoint(middle_block) as middle_ip:
|
||||||
assert middle_ip.block == middle_block
|
assert middle_ip.block == middle_block
|
||||||
cf.BranchOp([i32_arg], dest=successor_block)
|
cf.BranchOp([i32_arg], dest=successor_block)
|
||||||
print(module.operation)
|
module.print(enable_debug_info=True)
|
||||||
# Ensure region back references are coherent.
|
# Ensure region back references are coherent.
|
||||||
assert entry_block.region == middle_block.region == successor_block.region
|
assert entry_block.region == middle_block.region == successor_block.region
|
||||||
|
|
||||||
|
|
||||||
|
# CHECK-LABEL: TEST: testBlockCreationArgLocs
|
||||||
|
@run
|
||||||
|
def testBlockCreationArgLocs():
|
||||||
|
with Context() as ctx:
|
||||||
|
ctx.allow_unregistered_dialects = True
|
||||||
|
f32 = F32Type.get()
|
||||||
|
op = Operation.create("test", regions=1, loc=Location.unknown())
|
||||||
|
blocks = op.regions[0].blocks
|
||||||
|
|
||||||
|
with Location.name("default_loc"):
|
||||||
|
blocks.append(f32)
|
||||||
|
blocks.append()
|
||||||
|
# CHECK: ^bb0(%{{.+}}: f32 loc("default_loc")):
|
||||||
|
# CHECK-NEXT: ^bb1:
|
||||||
|
op.print(enable_debug_info=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
blocks.append(f32)
|
||||||
|
except RuntimeError as err:
|
||||||
|
# CHECK: Missing loc: An MLIR function requires a Location but none was provided
|
||||||
|
print("Missing loc:", err)
|
||||||
|
|
||||||
|
try:
|
||||||
|
blocks.append(f32, f32, arg_locs=[Location.unknown()])
|
||||||
|
except ValueError as err:
|
||||||
|
# CHECK: Wrong loc count: Expected 2 locations, got: 1
|
||||||
|
print("Wrong loc count:", err)
|
||||||
|
|
||||||
|
|
||||||
# CHECK-LABEL: TEST: testFirstBlockCreation
|
# CHECK-LABEL: TEST: testFirstBlockCreation
|
||||||
# CHECK: func @test(%{{.*}}: f32)
|
# CHECK: func @test(%{{.*}}: f32 loc("arg_loc"))
|
||||||
# CHECK: return
|
# CHECK: return
|
||||||
@run
|
@run
|
||||||
def testFirstBlockCreation():
|
def testFirstBlockCreation():
|
||||||
with Context() as ctx, Location.unknown():
|
with Context() as ctx, Location.unknown():
|
||||||
module = Module.create()
|
module = builtin.ModuleOp()
|
||||||
f32 = F32Type.get()
|
f32 = F32Type.get()
|
||||||
with InsertionPoint(module.body):
|
with InsertionPoint(module.body):
|
||||||
f = func.FuncOp("test", ([f32], []))
|
f = func.FuncOp("test", ([f32], []))
|
||||||
entry_block = Block.create_at_start(f.operation.regions[0], [f32])
|
entry_block = Block.create_at_start(f.operation.regions[0],
|
||||||
|
[f32], [Location.name("arg_loc")])
|
||||||
with InsertionPoint(entry_block):
|
with InsertionPoint(entry_block):
|
||||||
func.ReturnOp([])
|
func.ReturnOp([])
|
||||||
|
|
||||||
print(module)
|
module.print(enable_debug_info=True)
|
||||||
assert module.operation.verify()
|
assert module.verify()
|
||||||
assert f.body.blocks[0] == entry_block
|
assert f.body.blocks[0] == entry_block
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user