Revert "Revert "[mlir][py] Enable building ops with raw inputs""

Fix Python 3.6.9 issue encountered due to type checking here. Will
add back in follow up.

This reverts commit 1f47fee294.
This commit is contained in:
Jacques Pienaar 2022-12-21 16:22:39 -08:00
parent 02f4cfa33d
commit b57acb9a40
8 changed files with 196 additions and 35 deletions

View File

@ -743,6 +743,34 @@ with Context():
dictionary = DictAttr.get({"array": array, "unit": UnitAttr.get()})
```
Custom builders for Attributes to be used during Operation creation can be
registered by way of the `register_attribute_builder`. In particular the
following is how a custom builder is registered for `I32Attr`:
```python
@register_attribute_builder("I32Attr")
def _i32Attr(x: int, context: Context):
return IntegerAttr.get(
IntegerType.get_signless(32, context=context), x)
```
This allows to invoke op creation of an op with a `I32Attr` with
```python
foo.Op(30)
```
The registration is based on the ODS name but registry is via pure python
method. Only single custom builder is allowed to be registered per ODS attribute
type (e.g., I32Attr can have only one, which can correspond to multiple of the
underlying IntegerAttr type).
instead of
```python
foo.Op(IntegerAttr.get(IndexType.get_signless(32, context=context), 30))
```
## Style
In general, for the core parts of MLIR, the Python bindings should be largely

View File

@ -58,6 +58,12 @@ public:
/// have a DIALECT_NAMESPACE attribute.
pybind11::object registerDialectDecorator(pybind11::object pyClass);
/// Adds a user-friendly Attribute builder.
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
void registerAttributeBuilder(const std::string &attributeKind,
pybind11::function pyFunc);
/// Adds a concrete implementation dialect class.
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
@ -71,6 +77,10 @@ public:
pybind11::object pyClass,
pybind11::object rawOpViewClass);
/// Returns the custom Attribute builder for Attribute kind.
std::optional<pybind11::function>
lookupAttributeBuilder(const std::string &attributeKind);
/// Looks up a registered dialect class by namespace. Note that this may
/// trigger loading of the defining module and can arbitrarily re-enter.
llvm::Optional<pybind11::object>
@ -92,6 +102,8 @@ private:
/// Map of operation name to custom subclass that directly initializes
/// the OpView base class (bypassing the user class constructor).
llvm::StringMap<pybind11::object> rawOpViewClassMap;
/// Map of attribute ODS name to custom builder.
llvm::StringMap<pybind11::function> attributeBuilderMap;
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.

View File

@ -194,6 +194,29 @@ struct PyGlobalDebugFlag {
}
};
struct PyAttrBuilderMap {
static bool dunderContains(const std::string &attributeKind) {
return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
}
static py::function dundeGetItemNamed(const std::string &attributeKind) {
auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
if (!builder)
throw py::key_error();
return *builder;
}
static void dundeSetItemNamed(const std::string &attributeKind,
py::function func) {
PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func));
}
static void bind(py::module &m) {
py::class_<PyAttrBuilderMap>(m, "AttrBuilder", py::module_local())
.def_static("contains", &PyAttrBuilderMap::dunderContains)
.def_static("get", &PyAttrBuilderMap::dundeGetItemNamed)
.def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed);
}
};
//------------------------------------------------------------------------------
// Collections.
//------------------------------------------------------------------------------
@ -3283,4 +3306,7 @@ void mlir::python::populateIRCore(py::module &m) {
// Debug bindings.
PyGlobalDebugFlag::bind(m);
// Attribute builder getter.
PyAttrBuilderMap::bind(m);
}

View File

@ -60,6 +60,17 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
loadedDialectModulesCache.insert(dialectNamespace);
}
void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
py::function pyFunc) {
py::function &found = attributeBuilderMap[attributeKind];
if (found) {
throw std::runtime_error((llvm::Twine("Attribute builder for '") +
attributeKind + "' is already registered")
.str());
}
found = std::move(pyFunc);
}
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
py::object pyClass) {
py::object &found = dialectClassMap[dialectNamespace];
@ -84,6 +95,22 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
}
std::optional<py::function>
PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
// Fast match against the class map first (common case).
const auto foundIt = attributeBuilderMap.find(attributeKind);
if (foundIt != attributeBuilderMap.end()) {
if (foundIt->second.is_none())
return std::nullopt;
assert(foundIt->second && "py::function is defined");
return foundIt->second;
}
// Not found and loading did not yield a registration. Negative cache.
attributeBuilderMap[attributeKind] = py::none();
return std::nullopt;
}
llvm::Optional<py::object>
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
loadDialectModule(dialectNamespace);

View File

@ -4,3 +4,44 @@
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
# Convenience decorator for registering user-friendly Attribute builders.
def register_attribute_builder(kind):
def decorator_builder(func):
AttrBuilder.insert(kind, func)
return func
return decorator_builder
@register_attribute_builder("BoolAttr")
def _boolAttr(x, context):
return BoolAttr.get(x, context=context)
@register_attribute_builder("IndexAttr")
def _indexAttr(x, context):
return IntegerAttr.get(IndexType.get(context=context), x)
@register_attribute_builder("I32Attr")
def _i32Attr(x, context):
return IntegerAttr.get(
IntegerType.get_signless(32, context=context), x)
@register_attribute_builder("I64Attr")
def _i64Attr(x, context):
return IntegerAttr.get(
IntegerType.get_signless(64, context=context), x)
@register_attribute_builder("SymbolNameAttr")
def _symbolNameAttr(x, context):
return StringAttr.get(x, context=context)
try:
import numpy as np
@register_attribute_builder("IndexElementsAttr")
def _indexElementsAttr(x, context):
return DenseElementsAttr.get(
np.array(x, dtype=np.int64), type=IndexType.get(context=context),
context=context)
except ImportError:
pass

View File

@ -115,11 +115,14 @@ def AttributedOp : TestOp<"attributed_op"> {
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: attributes["i32attr"] = i32attr
// CHECK: if optionalF32Attr is not None: attributes["optionalF32Attr"] = optionalF32Attr
// CHECK: attributes["i32attr"] = (i32attr if (
// CHECK-NEXT: issubclass(type(i32attr), _ods_ir.Attribute) or
// CHECK-NEXT: not _ods_ir.AttrBuilder.contains('I32Attr')
// CHECK-NEXT: _ods_ir.AttrBuilder.get('I32Attr')(i32attr, context=_ods_context)
// CHECK: if optionalF32Attr is not None: attributes["optionalF32Attr"] = (optionalF32Attr
// CHECK: if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get(
// CHECK: _ods_get_default_loc_context(loc))
// CHECK: attributes["in"] = in_
// CHECK: attributes["in"] = (in_
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
@ -161,7 +164,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
// CHECK: operands.append(_get_op_result_or_value(_gen_arg_2))
// CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
// CHECK: _ods_get_default_loc_context(loc))
// CHECK: if is_ is not None: attributes["is"] = is_
// CHECK: if is_ is not None: attributes["is"] = (is_
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
@ -188,8 +191,8 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: if arr is not None: attributes["arr"] = arr
// CHECK: if unsupported is not None: attributes["unsupported"] = unsupported
// CHECK: if arr is not None: attributes["arr"] = (arr
// CHECK: if unsupported is not None: attributes["unsupported"] = (unsupported
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
@ -202,7 +205,7 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
// CHECK: def __init__(self, type, *, loc=None, ip=None):
// CHECK: def __init__(self, type_, *, loc=None, ip=None):
// CHECK: operands = []
// CHECK: results = []
// CHECK: _ods_result_type_source_attr = attributes["type"]
@ -217,7 +220,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
// CHECK: def __init__(self, res, _gen_res_1, type, *, loc=None, ip=None):
// CHECK: def __init__(self, res, _gen_res_1, type_, *, loc=None, ip=None):
let arguments = (ins TypeAttr:$type);
let results = (outs AnyType:$res, Variadic<AnyType>);
}

View File

@ -22,9 +22,18 @@ def testConstShape():
@func.FuncOp.from_py_func(
RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32))
def const_shape_tensor(arg):
shape.ConstWitnessOp(False)
shape.ConstSizeOp(30)
shape.ConstSizeOp(IntegerAttr.get(IndexType.get(), 40))
shape.ConstShapeOp([1, 2])
return shape.ConstShapeOp(
DenseElementsAttr.get(np.array([10, 20], dtype=np.int64), type=IndexType.get()))
DenseElementsAttr.get(
np.array([3, 4], dtype=np.int64), type=IndexType.get()))
# CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>)
# CHECK: shape.const_shape [10, 20] : tensor<2xindex>
# CHECK-DAG: shape.const_witness false
# CHECK-DAG: shape.const_size 30
# CHECK-DAG: shape.const_size 40
# CHECK-DAG: shape.const_shape [1, 2] : tensor<2xindex>
# CHECK-DAG: shape.const_shape [3, 4] : tensor<2xindex>
print(module)

View File

@ -280,15 +280,16 @@ static llvm::cl::opt<std::string> clDialectExtensionName(
using AttributeClasses = DenseMap<StringRef, StringRef>;
/// Checks whether `str` is a Python keyword.
static bool isPythonKeyword(StringRef str) {
static llvm::StringSet<> keywords(
{"and", "as", "assert", "break", "class", "continue",
"def", "del", "elif", "else", "except", "finally",
"for", "from", "global", "if", "import", "in",
"is", "lambda", "nonlocal", "not", "or", "pass",
"raise", "return", "try", "while", "with", "yield"});
return keywords.contains(str);
/// Checks whether `str` is a Python keyword or would shadow builtin function.
static bool isPythonReserved(StringRef str) {
static llvm::StringSet<> reserved(
{"and", "as", "assert", "break", "callable", "class",
"continue", "def", "del", "elif", "else", "except",
"finally", "for", "from", "global", "if", "import",
"in", "is", "lambda", "nonlocal", "not", "or",
"pass", "raise", "return", "issubclass", "try", "type",
"while", "with", "yield"});
return reserved.contains(str);
}
/// Checks whether `str` would shadow a generated variable or attribute
@ -306,7 +307,7 @@ static bool isODSReserved(StringRef str) {
/// (does not change the `name` if it already is suitable) and returns the
/// modified version.
static std::string sanitizeName(StringRef name) {
if (isPythonKeyword(name) || isODSReserved(name))
if (isPythonReserved(name) || isODSReserved(name))
return (name + "_").str();
return name.str();
}
@ -531,16 +532,30 @@ constexpr const char *multiOperandAppendPackTemplate =
"operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";
/// Template for setting an attribute in the operation builder.
/// {0} is the attribute name;
/// {1} is the builder argument name.
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";
/// Template for attribute builder from raw input in the operation builder.
/// {0} is the builder argument name;
/// {1} is the attribute builder from raw;
/// {2} is the attribute builder from raw.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initAttributeWithBuilderTemplate =
R"Py(attributes["{1}"] = ({0} if (
issubclass(type({0}), _ods_ir.Attribute) or
not _ods_ir.AttrBuilder.contains('{2}')) else
_ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
/// Template for setting an optional attribute in the operation builder.
/// {0} is the attribute name;
/// {1} is the builder argument name.
constexpr const char *initOptionalAttributeTemplate =
R"Py(if {1} is not None: attributes["{0}"] = {1})Py";
/// Template for attribute builder from raw input for optional attribute in the
/// operation builder.
/// {0} is the builder argument name;
/// {1} is the attribute builder from raw;
/// {2} is the attribute builder from raw.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initOptionalAttributeWithBuilderTemplate =
R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
issubclass(type({0}), _ods_ir.Attribute) or
not _ods_ir.AttrBuilder.contains('{2}')) else
_ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
constexpr const char *initUnitAttributeTemplate =
R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
@ -656,6 +671,7 @@ static void
populateBuilderLinesAttr(const Operator &op,
llvm::ArrayRef<std::string> argNames,
llvm::SmallVectorImpl<std::string> &builderLines) {
builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)");
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
Argument arg = op.getArg(i);
auto *attribute = arg.dyn_cast<NamedAttribute *>();
@ -670,10 +686,10 @@ populateBuilderLinesAttr(const Operator &op,
}
builderLines.push_back(llvm::formatv(
(attribute->attr.isOptional() || attribute->attr.hasDefaultValue())
? initOptionalAttributeTemplate
: initAttributeTemplate,
attribute->name, argNames[i]));
attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
? initOptionalAttributeWithBuilderTemplate
: initAttributeWithBuilderTemplate,
argNames[i], attribute->name, attribute->attr.getAttrDefName()));
}
}
@ -753,8 +769,7 @@ constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
/// corresponding interface:
/// - {0} is the name of the class for which the types are inferred.
constexpr const char *inferTypeInterfaceTemplate =
R"PY(_ods_context = _ods_get_default_loc_context(loc)
results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
R"PY(results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
operands=operands,
attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
context=_ods_context,