mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-24 06:10:12 +00:00
Revert "Revert "[mlir][py] Enable building ops with raw inputs""
Fix Python 3.6.9 issue encountered due to type checking here. Will
add back in follow up.
This reverts commit 1f47fee294
.
This commit is contained in:
parent
02f4cfa33d
commit
b57acb9a40
@ -743,6 +743,34 @@ with Context():
|
||||
dictionary = DictAttr.get({"array": array, "unit": UnitAttr.get()})
|
||||
```
|
||||
|
||||
Custom builders for Attributes to be used during Operation creation can be
|
||||
registered by way of the `register_attribute_builder`. In particular the
|
||||
following is how a custom builder is registered for `I32Attr`:
|
||||
|
||||
```python
|
||||
@register_attribute_builder("I32Attr")
|
||||
def _i32Attr(x: int, context: Context):
|
||||
return IntegerAttr.get(
|
||||
IntegerType.get_signless(32, context=context), x)
|
||||
```
|
||||
|
||||
This allows to invoke op creation of an op with a `I32Attr` with
|
||||
|
||||
```python
|
||||
foo.Op(30)
|
||||
```
|
||||
|
||||
The registration is based on the ODS name but registry is via pure python
|
||||
method. Only single custom builder is allowed to be registered per ODS attribute
|
||||
type (e.g., I32Attr can have only one, which can correspond to multiple of the
|
||||
underlying IntegerAttr type).
|
||||
|
||||
instead of
|
||||
|
||||
```python
|
||||
foo.Op(IntegerAttr.get(IndexType.get_signless(32, context=context), 30))
|
||||
```
|
||||
|
||||
## Style
|
||||
|
||||
In general, for the core parts of MLIR, the Python bindings should be largely
|
||||
|
@ -58,6 +58,12 @@ public:
|
||||
/// have a DIALECT_NAMESPACE attribute.
|
||||
pybind11::object registerDialectDecorator(pybind11::object pyClass);
|
||||
|
||||
/// Adds a user-friendly Attribute builder.
|
||||
/// Raises an exception if the mapping already exists.
|
||||
/// This is intended to be called by implementation code.
|
||||
void registerAttributeBuilder(const std::string &attributeKind,
|
||||
pybind11::function pyFunc);
|
||||
|
||||
/// Adds a concrete implementation dialect class.
|
||||
/// Raises an exception if the mapping already exists.
|
||||
/// This is intended to be called by implementation code.
|
||||
@ -71,6 +77,10 @@ public:
|
||||
pybind11::object pyClass,
|
||||
pybind11::object rawOpViewClass);
|
||||
|
||||
/// Returns the custom Attribute builder for Attribute kind.
|
||||
std::optional<pybind11::function>
|
||||
lookupAttributeBuilder(const std::string &attributeKind);
|
||||
|
||||
/// Looks up a registered dialect class by namespace. Note that this may
|
||||
/// trigger loading of the defining module and can arbitrarily re-enter.
|
||||
llvm::Optional<pybind11::object>
|
||||
@ -92,6 +102,8 @@ private:
|
||||
/// Map of operation name to custom subclass that directly initializes
|
||||
/// the OpView base class (bypassing the user class constructor).
|
||||
llvm::StringMap<pybind11::object> rawOpViewClassMap;
|
||||
/// Map of attribute ODS name to custom builder.
|
||||
llvm::StringMap<pybind11::function> attributeBuilderMap;
|
||||
|
||||
/// Set of dialect namespaces that we have attempted to import implementation
|
||||
/// modules for.
|
||||
|
@ -194,6 +194,29 @@ struct PyGlobalDebugFlag {
|
||||
}
|
||||
};
|
||||
|
||||
struct PyAttrBuilderMap {
|
||||
static bool dunderContains(const std::string &attributeKind) {
|
||||
return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
|
||||
}
|
||||
static py::function dundeGetItemNamed(const std::string &attributeKind) {
|
||||
auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
|
||||
if (!builder)
|
||||
throw py::key_error();
|
||||
return *builder;
|
||||
}
|
||||
static void dundeSetItemNamed(const std::string &attributeKind,
|
||||
py::function func) {
|
||||
PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func));
|
||||
}
|
||||
|
||||
static void bind(py::module &m) {
|
||||
py::class_<PyAttrBuilderMap>(m, "AttrBuilder", py::module_local())
|
||||
.def_static("contains", &PyAttrBuilderMap::dunderContains)
|
||||
.def_static("get", &PyAttrBuilderMap::dundeGetItemNamed)
|
||||
.def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed);
|
||||
}
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Collections.
|
||||
//------------------------------------------------------------------------------
|
||||
@ -3283,4 +3306,7 @@ void mlir::python::populateIRCore(py::module &m) {
|
||||
|
||||
// Debug bindings.
|
||||
PyGlobalDebugFlag::bind(m);
|
||||
|
||||
// Attribute builder getter.
|
||||
PyAttrBuilderMap::bind(m);
|
||||
}
|
||||
|
@ -60,6 +60,17 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
|
||||
loadedDialectModulesCache.insert(dialectNamespace);
|
||||
}
|
||||
|
||||
void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
|
||||
py::function pyFunc) {
|
||||
py::function &found = attributeBuilderMap[attributeKind];
|
||||
if (found) {
|
||||
throw std::runtime_error((llvm::Twine("Attribute builder for '") +
|
||||
attributeKind + "' is already registered")
|
||||
.str());
|
||||
}
|
||||
found = std::move(pyFunc);
|
||||
}
|
||||
|
||||
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
|
||||
py::object pyClass) {
|
||||
py::object &found = dialectClassMap[dialectNamespace];
|
||||
@ -84,6 +95,22 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
|
||||
rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
|
||||
}
|
||||
|
||||
std::optional<py::function>
|
||||
PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
|
||||
// Fast match against the class map first (common case).
|
||||
const auto foundIt = attributeBuilderMap.find(attributeKind);
|
||||
if (foundIt != attributeBuilderMap.end()) {
|
||||
if (foundIt->second.is_none())
|
||||
return std::nullopt;
|
||||
assert(foundIt->second && "py::function is defined");
|
||||
return foundIt->second;
|
||||
}
|
||||
|
||||
// Not found and loading did not yield a registration. Negative cache.
|
||||
attributeBuilderMap[attributeKind] = py::none();
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
llvm::Optional<py::object>
|
||||
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
|
||||
loadDialectModule(dialectNamespace);
|
||||
|
@ -4,3 +4,44 @@
|
||||
|
||||
from ._mlir_libs._mlir.ir import *
|
||||
from ._mlir_libs._mlir.ir import _GlobalDebug
|
||||
|
||||
|
||||
# Convenience decorator for registering user-friendly Attribute builders.
|
||||
def register_attribute_builder(kind):
|
||||
def decorator_builder(func):
|
||||
AttrBuilder.insert(kind, func)
|
||||
return func
|
||||
return decorator_builder
|
||||
|
||||
|
||||
@register_attribute_builder("BoolAttr")
|
||||
def _boolAttr(x, context):
|
||||
return BoolAttr.get(x, context=context)
|
||||
|
||||
@register_attribute_builder("IndexAttr")
|
||||
def _indexAttr(x, context):
|
||||
return IntegerAttr.get(IndexType.get(context=context), x)
|
||||
|
||||
@register_attribute_builder("I32Attr")
|
||||
def _i32Attr(x, context):
|
||||
return IntegerAttr.get(
|
||||
IntegerType.get_signless(32, context=context), x)
|
||||
|
||||
@register_attribute_builder("I64Attr")
|
||||
def _i64Attr(x, context):
|
||||
return IntegerAttr.get(
|
||||
IntegerType.get_signless(64, context=context), x)
|
||||
|
||||
@register_attribute_builder("SymbolNameAttr")
|
||||
def _symbolNameAttr(x, context):
|
||||
return StringAttr.get(x, context=context)
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
@register_attribute_builder("IndexElementsAttr")
|
||||
def _indexElementsAttr(x, context):
|
||||
return DenseElementsAttr.get(
|
||||
np.array(x, dtype=np.int64), type=IndexType.get(context=context),
|
||||
context=context)
|
||||
except ImportError:
|
||||
pass
|
||||
|
@ -115,11 +115,14 @@ def AttributedOp : TestOp<"attributed_op"> {
|
||||
// CHECK: results = []
|
||||
// CHECK: attributes = {}
|
||||
// CHECK: regions = None
|
||||
// CHECK: attributes["i32attr"] = i32attr
|
||||
// CHECK: if optionalF32Attr is not None: attributes["optionalF32Attr"] = optionalF32Attr
|
||||
// CHECK: attributes["i32attr"] = (i32attr if (
|
||||
// CHECK-NEXT: issubclass(type(i32attr), _ods_ir.Attribute) or
|
||||
// CHECK-NEXT: not _ods_ir.AttrBuilder.contains('I32Attr')
|
||||
// CHECK-NEXT: _ods_ir.AttrBuilder.get('I32Attr')(i32attr, context=_ods_context)
|
||||
// CHECK: if optionalF32Attr is not None: attributes["optionalF32Attr"] = (optionalF32Attr
|
||||
// CHECK: if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get(
|
||||
// CHECK: _ods_get_default_loc_context(loc))
|
||||
// CHECK: attributes["in"] = in_
|
||||
// CHECK: attributes["in"] = (in_
|
||||
// CHECK: _ods_successors = None
|
||||
// CHECK: super().__init__(self.build_generic(
|
||||
// CHECK: attributes=attributes, results=results, operands=operands,
|
||||
@ -161,7 +164,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
|
||||
// CHECK: operands.append(_get_op_result_or_value(_gen_arg_2))
|
||||
// CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
|
||||
// CHECK: _ods_get_default_loc_context(loc))
|
||||
// CHECK: if is_ is not None: attributes["is"] = is_
|
||||
// CHECK: if is_ is not None: attributes["is"] = (is_
|
||||
// CHECK: _ods_successors = None
|
||||
// CHECK: super().__init__(self.build_generic(
|
||||
// CHECK: attributes=attributes, results=results, operands=operands,
|
||||
@ -188,8 +191,8 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
|
||||
// CHECK: results = []
|
||||
// CHECK: attributes = {}
|
||||
// CHECK: regions = None
|
||||
// CHECK: if arr is not None: attributes["arr"] = arr
|
||||
// CHECK: if unsupported is not None: attributes["unsupported"] = unsupported
|
||||
// CHECK: if arr is not None: attributes["arr"] = (arr
|
||||
// CHECK: if unsupported is not None: attributes["unsupported"] = (unsupported
|
||||
// CHECK: _ods_successors = None
|
||||
// CHECK: super().__init__(self.build_generic(
|
||||
// CHECK: attributes=attributes, results=results, operands=operands,
|
||||
@ -202,7 +205,7 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
|
||||
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
|
||||
def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
|
||||
// CHECK: def __init__(self, type, *, loc=None, ip=None):
|
||||
// CHECK: def __init__(self, type_, *, loc=None, ip=None):
|
||||
// CHECK: operands = []
|
||||
// CHECK: results = []
|
||||
// CHECK: _ods_result_type_source_attr = attributes["type"]
|
||||
@ -217,7 +220,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
|
||||
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
|
||||
def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
|
||||
// CHECK: def __init__(self, res, _gen_res_1, type, *, loc=None, ip=None):
|
||||
// CHECK: def __init__(self, res, _gen_res_1, type_, *, loc=None, ip=None):
|
||||
let arguments = (ins TypeAttr:$type);
|
||||
let results = (outs AnyType:$res, Variadic<AnyType>);
|
||||
}
|
||||
|
@ -22,9 +22,18 @@ def testConstShape():
|
||||
@func.FuncOp.from_py_func(
|
||||
RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32))
|
||||
def const_shape_tensor(arg):
|
||||
shape.ConstWitnessOp(False)
|
||||
shape.ConstSizeOp(30)
|
||||
shape.ConstSizeOp(IntegerAttr.get(IndexType.get(), 40))
|
||||
shape.ConstShapeOp([1, 2])
|
||||
return shape.ConstShapeOp(
|
||||
DenseElementsAttr.get(np.array([10, 20], dtype=np.int64), type=IndexType.get()))
|
||||
DenseElementsAttr.get(
|
||||
np.array([3, 4], dtype=np.int64), type=IndexType.get()))
|
||||
|
||||
# CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>)
|
||||
# CHECK: shape.const_shape [10, 20] : tensor<2xindex>
|
||||
# CHECK-DAG: shape.const_witness false
|
||||
# CHECK-DAG: shape.const_size 30
|
||||
# CHECK-DAG: shape.const_size 40
|
||||
# CHECK-DAG: shape.const_shape [1, 2] : tensor<2xindex>
|
||||
# CHECK-DAG: shape.const_shape [3, 4] : tensor<2xindex>
|
||||
print(module)
|
||||
|
@ -280,15 +280,16 @@ static llvm::cl::opt<std::string> clDialectExtensionName(
|
||||
|
||||
using AttributeClasses = DenseMap<StringRef, StringRef>;
|
||||
|
||||
/// Checks whether `str` is a Python keyword.
|
||||
static bool isPythonKeyword(StringRef str) {
|
||||
static llvm::StringSet<> keywords(
|
||||
{"and", "as", "assert", "break", "class", "continue",
|
||||
"def", "del", "elif", "else", "except", "finally",
|
||||
"for", "from", "global", "if", "import", "in",
|
||||
"is", "lambda", "nonlocal", "not", "or", "pass",
|
||||
"raise", "return", "try", "while", "with", "yield"});
|
||||
return keywords.contains(str);
|
||||
/// Checks whether `str` is a Python keyword or would shadow builtin function.
|
||||
static bool isPythonReserved(StringRef str) {
|
||||
static llvm::StringSet<> reserved(
|
||||
{"and", "as", "assert", "break", "callable", "class",
|
||||
"continue", "def", "del", "elif", "else", "except",
|
||||
"finally", "for", "from", "global", "if", "import",
|
||||
"in", "is", "lambda", "nonlocal", "not", "or",
|
||||
"pass", "raise", "return", "issubclass", "try", "type",
|
||||
"while", "with", "yield"});
|
||||
return reserved.contains(str);
|
||||
}
|
||||
|
||||
/// Checks whether `str` would shadow a generated variable or attribute
|
||||
@ -306,7 +307,7 @@ static bool isODSReserved(StringRef str) {
|
||||
/// (does not change the `name` if it already is suitable) and returns the
|
||||
/// modified version.
|
||||
static std::string sanitizeName(StringRef name) {
|
||||
if (isPythonKeyword(name) || isODSReserved(name))
|
||||
if (isPythonReserved(name) || isODSReserved(name))
|
||||
return (name + "_").str();
|
||||
return name.str();
|
||||
}
|
||||
@ -531,16 +532,30 @@ constexpr const char *multiOperandAppendPackTemplate =
|
||||
"operands.append(_get_op_results_or_values({0}))";
|
||||
constexpr const char *multiResultAppendTemplate = "results.extend({0})";
|
||||
|
||||
/// Template for setting an attribute in the operation builder.
|
||||
/// {0} is the attribute name;
|
||||
/// {1} is the builder argument name.
|
||||
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";
|
||||
/// Template for attribute builder from raw input in the operation builder.
|
||||
/// {0} is the builder argument name;
|
||||
/// {1} is the attribute builder from raw;
|
||||
/// {2} is the attribute builder from raw.
|
||||
/// Use the value the user passed in if either it is already an Attribute or
|
||||
/// there is no method registered to make it an Attribute.
|
||||
constexpr const char *initAttributeWithBuilderTemplate =
|
||||
R"Py(attributes["{1}"] = ({0} if (
|
||||
issubclass(type({0}), _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('{2}')) else
|
||||
_ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
|
||||
|
||||
/// Template for setting an optional attribute in the operation builder.
|
||||
/// {0} is the attribute name;
|
||||
/// {1} is the builder argument name.
|
||||
constexpr const char *initOptionalAttributeTemplate =
|
||||
R"Py(if {1} is not None: attributes["{0}"] = {1})Py";
|
||||
/// Template for attribute builder from raw input for optional attribute in the
|
||||
/// operation builder.
|
||||
/// {0} is the builder argument name;
|
||||
/// {1} is the attribute builder from raw;
|
||||
/// {2} is the attribute builder from raw.
|
||||
/// Use the value the user passed in if either it is already an Attribute or
|
||||
/// there is no method registered to make it an Attribute.
|
||||
constexpr const char *initOptionalAttributeWithBuilderTemplate =
|
||||
R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
|
||||
issubclass(type({0}), _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('{2}')) else
|
||||
_ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
|
||||
|
||||
constexpr const char *initUnitAttributeTemplate =
|
||||
R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
|
||||
@ -656,6 +671,7 @@ static void
|
||||
populateBuilderLinesAttr(const Operator &op,
|
||||
llvm::ArrayRef<std::string> argNames,
|
||||
llvm::SmallVectorImpl<std::string> &builderLines) {
|
||||
builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)");
|
||||
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
|
||||
Argument arg = op.getArg(i);
|
||||
auto *attribute = arg.dyn_cast<NamedAttribute *>();
|
||||
@ -670,10 +686,10 @@ populateBuilderLinesAttr(const Operator &op,
|
||||
}
|
||||
|
||||
builderLines.push_back(llvm::formatv(
|
||||
(attribute->attr.isOptional() || attribute->attr.hasDefaultValue())
|
||||
? initOptionalAttributeTemplate
|
||||
: initAttributeTemplate,
|
||||
attribute->name, argNames[i]));
|
||||
attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
|
||||
? initOptionalAttributeWithBuilderTemplate
|
||||
: initAttributeWithBuilderTemplate,
|
||||
argNames[i], attribute->name, attribute->attr.getAttrDefName()));
|
||||
}
|
||||
}
|
||||
|
||||
@ -753,8 +769,7 @@ constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
|
||||
/// corresponding interface:
|
||||
/// - {0} is the name of the class for which the types are inferred.
|
||||
constexpr const char *inferTypeInterfaceTemplate =
|
||||
R"PY(_ods_context = _ods_get_default_loc_context(loc)
|
||||
results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
|
||||
R"PY(results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
|
||||
operands=operands,
|
||||
attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
|
||||
context=_ods_context,
|
||||
|
Loading…
Reference in New Issue
Block a user