diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 6d21da3b4179..eb7f035fec7c 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -263,6 +263,11 @@ def _typeArrayAttr(x, context): return _arrayAttr([TypeAttr.get(t, context=context) for t in x], context) +@register_attribute_builder("MemRefTypeAttr") +def _memref_type_attr(x, context): + return _typeAttr(x, context) + + try: import numpy as np diff --git a/mlir/test/python/dialects/memref.py b/mlir/test/python/dialects/memref.py index 2e3cae671a9f..0c8a7ee282fe 100644 --- a/mlir/test/python/dialects/memref.py +++ b/mlir/test/python/dialects/memref.py @@ -3,6 +3,7 @@ from mlir.ir import * import mlir.dialects.func as func import mlir.dialects.memref as memref +import mlir.extras.types as T def run(f): @@ -76,3 +77,14 @@ def testCustomBuidlers(): # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] print(module) assert module.operation.verify() + + +# CHECK-LABEL: TEST: testMemRefAttr +@run +def testMemRefAttr(): + with Context() as ctx, Location.unknown(ctx): + module = Module.create() + with InsertionPoint(module.body): + memref.global_("objFifo_in0", T.memref(16, T.i32())) + # CHECK: memref.global @objFifo_in0 : memref<16xi32> + print(module)