mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-12-15 12:09:51 +00:00
Add SymbolRefAttr to python bindings
Differential Revision: https://reviews.llvm.org/D154541
This commit is contained in:
parent
e8ed6e35bd
commit
4eee9ef976
@ -283,9 +283,6 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx,
|
||||
MLIR_CAPI_EXPORTED MlirStringRef
|
||||
mlirFlatSymbolRefAttrGetValue(MlirAttribute attr);
|
||||
|
||||
/// Returns the typeID of an FlatSymbolRef attribute.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -442,14 +442,59 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
|
||||
static constexpr const char *pyClassName = "SymbolRefAttr";
|
||||
using PyConcreteAttribute::PyConcreteAttribute;
|
||||
|
||||
static MlirAttribute fromList(const std::vector<std::string> &symbols,
|
||||
PyMlirContext &context) {
|
||||
if (symbols.empty())
|
||||
throw std::runtime_error("SymbolRefAttr must be composed of at least "
|
||||
"one symbol.");
|
||||
MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
|
||||
SmallVector<MlirAttribute, 3> referenceAttrs;
|
||||
for (size_t i = 1; i < symbols.size(); ++i) {
|
||||
referenceAttrs.push_back(
|
||||
mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
|
||||
}
|
||||
return mlirSymbolRefAttrGet(context.get(), rootSymbol,
|
||||
referenceAttrs.size(), referenceAttrs.data());
|
||||
}
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static(
|
||||
"get",
|
||||
[](const std::vector<std::string> &symbols,
|
||||
DefaultingPyMlirContext context) {
|
||||
return PySymbolRefAttribute::fromList(symbols, context.resolve());
|
||||
},
|
||||
py::arg("symbols"), py::arg("context") = py::none(),
|
||||
"Gets a uniqued SymbolRef attribute from a list of symbol names");
|
||||
c.def_property_readonly(
|
||||
"value",
|
||||
[](PySymbolRefAttribute &self) {
|
||||
std::vector<std::string> symbols = {
|
||||
unwrap(mlirSymbolRefAttrGetRootReference(self)).str()};
|
||||
for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self);
|
||||
++i)
|
||||
symbols.push_back(
|
||||
unwrap(mlirSymbolRefAttrGetRootReference(
|
||||
mlirSymbolRefAttrGetNestedReference(self, i)))
|
||||
.str());
|
||||
return symbols;
|
||||
},
|
||||
"Returns the value of the SymbolRef attribute as a list[str]");
|
||||
}
|
||||
};
|
||||
|
||||
class PyFlatSymbolRefAttribute
|
||||
: public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
|
||||
static constexpr const char *pyClassName = "FlatSymbolRefAttr";
|
||||
using PyConcreteAttribute::PyConcreteAttribute;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirFlatSymbolRefAttrGetTypeID;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static(
|
||||
@ -1167,6 +1212,16 @@ py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
|
||||
throw py::cast_error(msg);
|
||||
}
|
||||
|
||||
py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
|
||||
if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute))
|
||||
return py::cast(PyFlatSymbolRefAttribute(pyAttribute));
|
||||
if (PySymbolRefAttribute::isaFunction(pyAttribute))
|
||||
return py::cast(PySymbolRefAttribute(pyAttribute));
|
||||
std::string msg = std::string("Can't cast unknown SymbolRef attribute (") +
|
||||
std::string(py::repr(py::cast(pyAttribute))) + ")";
|
||||
throw py::cast_error(msg);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::python::populateIRAttributes(py::module &m) {
|
||||
@ -1201,6 +1256,11 @@ void mlir::python::populateIRAttributes(py::module &m) {
|
||||
pybind11::cpp_function(denseIntOrFPElementsAttributeCaster));
|
||||
|
||||
PyDictAttribute::bind(m);
|
||||
PySymbolRefAttribute::bind(m);
|
||||
PyGlobals::get().registerTypeCaster(
|
||||
mlirSymbolRefAttrGetTypeID(),
|
||||
pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster));
|
||||
|
||||
PyFlatSymbolRefAttribute::bind(m);
|
||||
PyOpaqueAttribute::bind(m);
|
||||
PyFloatAttribute::bind(m);
|
||||
|
@ -3131,13 +3131,13 @@ void mlir::python::populateIRCore(py::module &m) {
|
||||
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
|
||||
.def_static(
|
||||
"parse",
|
||||
[](std::string attrSpec, DefaultingPyMlirContext context) {
|
||||
[](const std::string &attrSpec, DefaultingPyMlirContext context) {
|
||||
PyMlirContext::ErrorCapture errors(context->getRef());
|
||||
MlirAttribute type = mlirAttributeParseGet(
|
||||
MlirAttribute attr = mlirAttributeParseGet(
|
||||
context->get(), toMlirStringRef(attrSpec));
|
||||
if (mlirAttributeIsNull(type))
|
||||
if (mlirAttributeIsNull(attr))
|
||||
throw MLIRError("Unable to parse attribute", errors.take());
|
||||
return type;
|
||||
return attr;
|
||||
},
|
||||
py::arg("asm"), py::arg("context") = py::none(),
|
||||
"Parses an attribute from an assembly form. Raises an MLIRError on "
|
||||
|
@ -305,10 +305,6 @@ MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
|
||||
return wrap(llvm::cast<FlatSymbolRefAttr>(unwrap(attr)).getValue());
|
||||
}
|
||||
|
||||
MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void) {
|
||||
return wrap(FlatSymbolRefAttr::getTypeID());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -73,6 +73,14 @@ def _symbolNameAttr(x, context):
|
||||
|
||||
@register_attribute_builder("SymbolRefAttr")
|
||||
def _symbolRefAttr(x, context):
|
||||
if isinstance(x, list):
|
||||
return SymbolRefAttr.get(x, context=context)
|
||||
else:
|
||||
return FlatSymbolRefAttr.get(x, context=context)
|
||||
|
||||
|
||||
@register_attribute_builder("FlatSymbolRefAttr")
|
||||
def _flatSymbolRefAttr(x, context):
|
||||
return FlatSymbolRefAttr.get(x, context=context)
|
||||
|
||||
|
||||
@ -105,6 +113,7 @@ def _f64ArrayAttr(x, context):
|
||||
def _denseI64ArrayAttr(x, context):
|
||||
return DenseI64ArrayAttr.get(x, context=context)
|
||||
|
||||
|
||||
@register_attribute_builder("DenseBoolArrayAttr")
|
||||
def _denseBoolArrayAttr(x, context):
|
||||
return DenseBoolArrayAttr.get(x, context=context)
|
||||
|
@ -2,7 +2,7 @@
|
||||
# This is just a smoke test that the dialect is functional.
|
||||
|
||||
from mlir.ir import *
|
||||
from mlir.dialects import ml_program
|
||||
from mlir.dialects import ml_program, arith, builtin
|
||||
|
||||
|
||||
def constructAndPrintInModule(f):
|
||||
@ -26,3 +26,21 @@ def testFuncOp():
|
||||
with InsertionPoint(block):
|
||||
# CHECK: ml_program.return
|
||||
ml_program.ReturnOp([block.arguments[0]])
|
||||
|
||||
|
||||
# CHECK-LABEL: testGlobalStoreOp
|
||||
@constructAndPrintInModule
|
||||
def testGlobalStoreOp():
|
||||
# CHECK: %cst = arith.constant 4.242000e+01 : f32
|
||||
cst = arith.ConstantOp(value=42.42, result=F32Type.get())
|
||||
|
||||
m = builtin.ModuleOp()
|
||||
m.sym_name = StringAttr.get("symbol1")
|
||||
m.sym_visibility = StringAttr.get("public")
|
||||
# CHECK: module @symbol1 attributes {sym_visibility = "public"} {
|
||||
# CHECK: ml_program.global public mutable @symbol2 : f32
|
||||
# CHECK: }
|
||||
with InsertionPoint(m.body):
|
||||
ml_program.GlobalOp("symbol2", F32Type.get(), is_mutable=True)
|
||||
# CHECK: ml_program.global_store @symbol1::@symbol2 = %cst : f32
|
||||
ml_program.GlobalStoreOp(["symbol1", "symbol2"], cst)
|
||||
|
@ -228,7 +228,7 @@ def testBoolAttr():
|
||||
@run
|
||||
def testFlatSymbolRefAttr():
|
||||
with Context() as ctx:
|
||||
sattr = FlatSymbolRefAttr(Attribute.parse("@symbol"))
|
||||
sattr = Attribute.parse("@symbol")
|
||||
# CHECK: symattr value: symbol
|
||||
print("symattr value:", sattr.value)
|
||||
|
||||
@ -237,6 +237,21 @@ def testFlatSymbolRefAttr():
|
||||
print("default_get:", FlatSymbolRefAttr.get("foobar"))
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testSymbolRefAttr
|
||||
@run
|
||||
def testSymbolRefAttr():
|
||||
with Context() as ctx:
|
||||
sattr = Attribute.parse("@symbol1::@symbol2")
|
||||
# CHECK: symattr value: ['symbol1', 'symbol2']
|
||||
print("symattr value:", sattr.value)
|
||||
|
||||
# CHECK: default_get: @symbol1::@symbol2
|
||||
print("default_get:", SymbolRefAttr.get(["symbol1", "symbol2"]))
|
||||
|
||||
# CHECK: default_get: @"@symbol1"::@"@symbol2"
|
||||
print("default_get:", SymbolRefAttr.get(["@symbol1", "@symbol2"]))
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testOpaqueAttr
|
||||
@run
|
||||
def testOpaqueAttr():
|
||||
|
Loading…
Reference in New Issue
Block a user