Add SymbolRefAttr to python bindings

Differential Revision: https://reviews.llvm.org/D154541
This commit is contained in:
max 2023-07-05 15:02:59 -05:00
parent e8ed6e35bd
commit 4eee9ef976
7 changed files with 110 additions and 15 deletions

View File

@ -283,9 +283,6 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx,
MLIR_CAPI_EXPORTED MlirStringRef
mlirFlatSymbolRefAttrGetValue(MlirAttribute attr);
/// Returns the typeID of an FlatSymbolRef attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void);
//===----------------------------------------------------------------------===//
// Type attribute.
//===----------------------------------------------------------------------===//

View File

@ -442,14 +442,59 @@ public:
}
};
class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
static constexpr const char *pyClassName = "SymbolRefAttr";
using PyConcreteAttribute::PyConcreteAttribute;
static MlirAttribute fromList(const std::vector<std::string> &symbols,
PyMlirContext &context) {
if (symbols.empty())
throw std::runtime_error("SymbolRefAttr must be composed of at least "
"one symbol.");
MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
SmallVector<MlirAttribute, 3> referenceAttrs;
for (size_t i = 1; i < symbols.size(); ++i) {
referenceAttrs.push_back(
mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
}
return mlirSymbolRefAttrGet(context.get(), rootSymbol,
referenceAttrs.size(), referenceAttrs.data());
}
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](const std::vector<std::string> &symbols,
DefaultingPyMlirContext context) {
return PySymbolRefAttribute::fromList(symbols, context.resolve());
},
py::arg("symbols"), py::arg("context") = py::none(),
"Gets a uniqued SymbolRef attribute from a list of symbol names");
c.def_property_readonly(
"value",
[](PySymbolRefAttribute &self) {
std::vector<std::string> symbols = {
unwrap(mlirSymbolRefAttrGetRootReference(self)).str()};
for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self);
++i)
symbols.push_back(
unwrap(mlirSymbolRefAttrGetRootReference(
mlirSymbolRefAttrGetNestedReference(self, i)))
.str());
return symbols;
},
"Returns the value of the SymbolRef attribute as a list[str]");
}
};
class PyFlatSymbolRefAttribute
: public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
static constexpr const char *pyClassName = "FlatSymbolRefAttr";
using PyConcreteAttribute::PyConcreteAttribute;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFlatSymbolRefAttrGetTypeID;
static void bindDerived(ClassTy &c) {
c.def_static(
@ -1167,6 +1212,16 @@ py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
throw py::cast_error(msg);
}
py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute))
return py::cast(PyFlatSymbolRefAttribute(pyAttribute));
if (PySymbolRefAttribute::isaFunction(pyAttribute))
return py::cast(PySymbolRefAttribute(pyAttribute));
std::string msg = std::string("Can't cast unknown SymbolRef attribute (") +
std::string(py::repr(py::cast(pyAttribute))) + ")";
throw py::cast_error(msg);
}
} // namespace
void mlir::python::populateIRAttributes(py::module &m) {
@ -1201,6 +1256,11 @@ void mlir::python::populateIRAttributes(py::module &m) {
pybind11::cpp_function(denseIntOrFPElementsAttributeCaster));
PyDictAttribute::bind(m);
PySymbolRefAttribute::bind(m);
PyGlobals::get().registerTypeCaster(
mlirSymbolRefAttrGetTypeID(),
pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster));
PyFlatSymbolRefAttribute::bind(m);
PyOpaqueAttribute::bind(m);
PyFloatAttribute::bind(m);

View File

@ -3131,13 +3131,13 @@ void mlir::python::populateIRCore(py::module &m) {
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
.def_static(
"parse",
[](std::string attrSpec, DefaultingPyMlirContext context) {
[](const std::string &attrSpec, DefaultingPyMlirContext context) {
PyMlirContext::ErrorCapture errors(context->getRef());
MlirAttribute type = mlirAttributeParseGet(
MlirAttribute attr = mlirAttributeParseGet(
context->get(), toMlirStringRef(attrSpec));
if (mlirAttributeIsNull(type))
if (mlirAttributeIsNull(attr))
throw MLIRError("Unable to parse attribute", errors.take());
return type;
return attr;
},
py::arg("asm"), py::arg("context") = py::none(),
"Parses an attribute from an assembly form. Raises an MLIRError on "

View File

@ -305,10 +305,6 @@ MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
return wrap(llvm::cast<FlatSymbolRefAttr>(unwrap(attr)).getValue());
}
MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void) {
return wrap(FlatSymbolRefAttr::getTypeID());
}
//===----------------------------------------------------------------------===//
// Type attribute.
//===----------------------------------------------------------------------===//

View File

@ -73,6 +73,14 @@ def _symbolNameAttr(x, context):
@register_attribute_builder("SymbolRefAttr")
def _symbolRefAttr(x, context):
if isinstance(x, list):
return SymbolRefAttr.get(x, context=context)
else:
return FlatSymbolRefAttr.get(x, context=context)
@register_attribute_builder("FlatSymbolRefAttr")
def _flatSymbolRefAttr(x, context):
return FlatSymbolRefAttr.get(x, context=context)
@ -105,6 +113,7 @@ def _f64ArrayAttr(x, context):
def _denseI64ArrayAttr(x, context):
return DenseI64ArrayAttr.get(x, context=context)
@register_attribute_builder("DenseBoolArrayAttr")
def _denseBoolArrayAttr(x, context):
return DenseBoolArrayAttr.get(x, context=context)

View File

@ -2,7 +2,7 @@
# This is just a smoke test that the dialect is functional.
from mlir.ir import *
from mlir.dialects import ml_program
from mlir.dialects import ml_program, arith, builtin
def constructAndPrintInModule(f):
@ -26,3 +26,21 @@ def testFuncOp():
with InsertionPoint(block):
# CHECK: ml_program.return
ml_program.ReturnOp([block.arguments[0]])
# CHECK-LABEL: testGlobalStoreOp
@constructAndPrintInModule
def testGlobalStoreOp():
# CHECK: %cst = arith.constant 4.242000e+01 : f32
cst = arith.ConstantOp(value=42.42, result=F32Type.get())
m = builtin.ModuleOp()
m.sym_name = StringAttr.get("symbol1")
m.sym_visibility = StringAttr.get("public")
# CHECK: module @symbol1 attributes {sym_visibility = "public"} {
# CHECK: ml_program.global public mutable @symbol2 : f32
# CHECK: }
with InsertionPoint(m.body):
ml_program.GlobalOp("symbol2", F32Type.get(), is_mutable=True)
# CHECK: ml_program.global_store @symbol1::@symbol2 = %cst : f32
ml_program.GlobalStoreOp(["symbol1", "symbol2"], cst)

View File

@ -228,7 +228,7 @@ def testBoolAttr():
@run
def testFlatSymbolRefAttr():
with Context() as ctx:
sattr = FlatSymbolRefAttr(Attribute.parse("@symbol"))
sattr = Attribute.parse("@symbol")
# CHECK: symattr value: symbol
print("symattr value:", sattr.value)
@ -237,6 +237,21 @@ def testFlatSymbolRefAttr():
print("default_get:", FlatSymbolRefAttr.get("foobar"))
# CHECK-LABEL: TEST: testSymbolRefAttr
@run
def testSymbolRefAttr():
with Context() as ctx:
sattr = Attribute.parse("@symbol1::@symbol2")
# CHECK: symattr value: ['symbol1', 'symbol2']
print("symattr value:", sattr.value)
# CHECK: default_get: @symbol1::@symbol2
print("default_get:", SymbolRefAttr.get(["symbol1", "symbol2"]))
# CHECK: default_get: @"@symbol1"::@"@symbol2"
print("default_get:", SymbolRefAttr.get(["@symbol1", "@symbol2"]))
# CHECK-LABEL: TEST: testOpaqueAttr
@run
def testOpaqueAttr():