diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index b760dd0cdb9a..63198192453e 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -283,9 +283,6 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MLIR_CAPI_EXPORTED MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr); -/// Returns the typeID of an FlatSymbolRef attribute. -MLIR_CAPI_EXPORTED MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void); - //===----------------------------------------------------------------------===// // Type attribute. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 99881b35c96d..4ee06fa7a6d7 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -442,14 +442,59 @@ public: } }; +class PySymbolRefAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef; + static constexpr const char *pyClassName = "SymbolRefAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static MlirAttribute fromList(const std::vector &symbols, + PyMlirContext &context) { + if (symbols.empty()) + throw std::runtime_error("SymbolRefAttr must be composed of at least " + "one symbol."); + MlirStringRef rootSymbol = toMlirStringRef(symbols[0]); + SmallVector referenceAttrs; + for (size_t i = 1; i < symbols.size(); ++i) { + referenceAttrs.push_back( + mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i]))); + } + return mlirSymbolRefAttrGet(context.get(), rootSymbol, + referenceAttrs.size(), referenceAttrs.data()); + } + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const std::vector &symbols, + DefaultingPyMlirContext context) { + return PySymbolRefAttribute::fromList(symbols, context.resolve()); + }, + py::arg("symbols"), py::arg("context") = py::none(), + "Gets a uniqued SymbolRef attribute from a list of symbol names"); + c.def_property_readonly( + "value", + [](PySymbolRefAttribute &self) { + std::vector symbols = { + unwrap(mlirSymbolRefAttrGetRootReference(self)).str()}; + for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self); + ++i) + symbols.push_back( + unwrap(mlirSymbolRefAttrGetRootReference( + mlirSymbolRefAttrGetNestedReference(self, i))) + .str()); + return symbols; + }, + "Returns the value of the SymbolRef attribute as a list[str]"); + } +}; + class PyFlatSymbolRefAttribute : public PyConcreteAttribute { public: static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; static constexpr const char *pyClassName = "FlatSymbolRefAttr"; using PyConcreteAttribute::PyConcreteAttribute; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFlatSymbolRefAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -1167,6 +1212,16 @@ py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { throw py::cast_error(msg); } +py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { + if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute)) + return py::cast(PyFlatSymbolRefAttribute(pyAttribute)); + if (PySymbolRefAttribute::isaFunction(pyAttribute)) + return py::cast(PySymbolRefAttribute(pyAttribute)); + std::string msg = std::string("Can't cast unknown SymbolRef attribute (") + + std::string(py::repr(py::cast(pyAttribute))) + ")"; + throw py::cast_error(msg); +} + } // namespace void mlir::python::populateIRAttributes(py::module &m) { @@ -1201,6 +1256,11 @@ void mlir::python::populateIRAttributes(py::module &m) { pybind11::cpp_function(denseIntOrFPElementsAttributeCaster)); PyDictAttribute::bind(m); + PySymbolRefAttribute::bind(m); + PyGlobals::get().registerTypeCaster( + mlirSymbolRefAttrGetTypeID(), + pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster)); + PyFlatSymbolRefAttribute::bind(m); PyOpaqueAttribute::bind(m); PyFloatAttribute::bind(m); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index da8a58de775a..3ab6d57b4169 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3131,13 +3131,13 @@ void mlir::python::populateIRCore(py::module &m) { .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) .def_static( "parse", - [](std::string attrSpec, DefaultingPyMlirContext context) { + [](const std::string &attrSpec, DefaultingPyMlirContext context) { PyMlirContext::ErrorCapture errors(context->getRef()); - MlirAttribute type = mlirAttributeParseGet( + MlirAttribute attr = mlirAttributeParseGet( context->get(), toMlirStringRef(attrSpec)); - if (mlirAttributeIsNull(type)) + if (mlirAttributeIsNull(attr)) throw MLIRError("Unable to parse attribute", errors.take()); - return type; + return attr; }, py::arg("asm"), py::arg("context") = py::none(), "Parses an attribute from an assembly form. Raises an MLIRError on " diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 289913d4f548..de221ddbfa7a 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -305,10 +305,6 @@ MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) { return wrap(llvm::cast(unwrap(attr)).getValue()); } -MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void) { - return wrap(FlatSymbolRefAttr::getTypeID()); -} - //===----------------------------------------------------------------------===// // Type attribute. //===----------------------------------------------------------------------===// diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 76077acb6a57..e36736f2974f 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -73,6 +73,14 @@ def _symbolNameAttr(x, context): @register_attribute_builder("SymbolRefAttr") def _symbolRefAttr(x, context): + if isinstance(x, list): + return SymbolRefAttr.get(x, context=context) + else: + return FlatSymbolRefAttr.get(x, context=context) + + +@register_attribute_builder("FlatSymbolRefAttr") +def _flatSymbolRefAttr(x, context): return FlatSymbolRefAttr.get(x, context=context) @@ -105,6 +113,7 @@ def _f64ArrayAttr(x, context): def _denseI64ArrayAttr(x, context): return DenseI64ArrayAttr.get(x, context=context) + @register_attribute_builder("DenseBoolArrayAttr") def _denseBoolArrayAttr(x, context): return DenseBoolArrayAttr.get(x, context=context) diff --git a/mlir/test/python/dialects/ml_program.py b/mlir/test/python/dialects/ml_program.py index f16de2add379..edffcfbf0138 100644 --- a/mlir/test/python/dialects/ml_program.py +++ b/mlir/test/python/dialects/ml_program.py @@ -2,7 +2,7 @@ # This is just a smoke test that the dialect is functional. from mlir.ir import * -from mlir.dialects import ml_program +from mlir.dialects import ml_program, arith, builtin def constructAndPrintInModule(f): @@ -26,3 +26,21 @@ def testFuncOp(): with InsertionPoint(block): # CHECK: ml_program.return ml_program.ReturnOp([block.arguments[0]]) + + +# CHECK-LABEL: testGlobalStoreOp +@constructAndPrintInModule +def testGlobalStoreOp(): + # CHECK: %cst = arith.constant 4.242000e+01 : f32 + cst = arith.ConstantOp(value=42.42, result=F32Type.get()) + + m = builtin.ModuleOp() + m.sym_name = StringAttr.get("symbol1") + m.sym_visibility = StringAttr.get("public") + # CHECK: module @symbol1 attributes {sym_visibility = "public"} { + # CHECK: ml_program.global public mutable @symbol2 : f32 + # CHECK: } + with InsertionPoint(m.body): + ml_program.GlobalOp("symbol2", F32Type.get(), is_mutable=True) + # CHECK: ml_program.global_store @symbol1::@symbol2 = %cst : f32 + ml_program.GlobalStoreOp(["symbol1", "symbol2"], cst) diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py index 221c186ae7d5..28729e86ccd4 100644 --- a/mlir/test/python/ir/attributes.py +++ b/mlir/test/python/ir/attributes.py @@ -228,7 +228,7 @@ def testBoolAttr(): @run def testFlatSymbolRefAttr(): with Context() as ctx: - sattr = FlatSymbolRefAttr(Attribute.parse("@symbol")) + sattr = Attribute.parse("@symbol") # CHECK: symattr value: symbol print("symattr value:", sattr.value) @@ -237,6 +237,21 @@ def testFlatSymbolRefAttr(): print("default_get:", FlatSymbolRefAttr.get("foobar")) +# CHECK-LABEL: TEST: testSymbolRefAttr +@run +def testSymbolRefAttr(): + with Context() as ctx: + sattr = Attribute.parse("@symbol1::@symbol2") + # CHECK: symattr value: ['symbol1', 'symbol2'] + print("symattr value:", sattr.value) + + # CHECK: default_get: @symbol1::@symbol2 + print("default_get:", SymbolRefAttr.get(["symbol1", "symbol2"])) + + # CHECK: default_get: @"@symbol1"::@"@symbol2" + print("default_get:", SymbolRefAttr.get(["@symbol1", "@symbol2"])) + + # CHECK-LABEL: TEST: testOpaqueAttr @run def testOpaqueAttr():