diff --git a/mlir/examples/python/linalg_matmul.py b/mlir/examples/python/linalg_matmul.py index e9be189bfaaf..0bd3c12a0378 100644 --- a/mlir/examples/python/linalg_matmul.py +++ b/mlir/examples/python/linalg_matmul.py @@ -31,9 +31,9 @@ def FuncOp(name: str, func_type: Type) -> Tuple[Operation, Block]: def build_matmul_buffers_func(func_name, m, k, n, dtype): - lhs_type = MemRefType.get(dtype, [m, k]) - rhs_type = MemRefType.get(dtype, [k, n]) - result_type = MemRefType.get(dtype, [m, n]) + lhs_type = MemRefType.get([m, k], dtype) + rhs_type = MemRefType.get([k, n], dtype) + result_type = MemRefType.get([m, n], dtype) # TODO: There should be a one-liner for this. func_type = FunctionType.get([lhs_type, rhs_type, result_type], []) _, entry = FuncOp(func_name, func_type) @@ -49,8 +49,6 @@ def build_matmul_buffers_func(func_name, m, k, n, dtype): def build_matmul_tensors_func(func_name, m, k, n, dtype): - # TODO: MemRefType and TensorTypes should not have inverted dtype/shapes - # from each other. lhs_type = RankedTensorType.get([m, k], dtype) rhs_type = RankedTensorType.get([k, n], dtype) result_type = RankedTensorType.get([m, n], dtype) diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp index 63bdd0c7a184..3c9f79e2a17a 100644 --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -2832,7 +2832,7 @@ public: static void bindDerived(ClassTy &c) { c.def_static( "get", - [](PyType &elementType, std::vector shape, + [](std::vector shape, PyType &elementType, std::vector layout, unsigned memorySpace, DefaultingPyLocation loc) { SmallVector maps; @@ -2856,7 +2856,7 @@ public: } return PyMemRefType(elementType.getContext(), t); }, - py::arg("element_type"), py::arg("shape"), + py::arg("shape"), py::arg("element_type"), py::arg("layout") = py::list(), py::arg("memory_space") = 0, py::arg("loc") = py::none(), "Create a memref type") .def_property_readonly("layout", &PyMemRefType::getLayout, diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py index 64b684ee99e9..7402c644a1c1 100644 --- a/mlir/test/Bindings/Python/ir_types.py +++ b/mlir/test/Bindings/Python/ir_types.py @@ -326,7 +326,7 @@ def testMemRefType(): f32 = F32Type.get() shape = [2, 3] loc = Location.unknown() - memref = MemRefType.get(f32, shape, memory_space=2) + memref = MemRefType.get(shape, f32, memory_space=2) # CHECK: memref type: memref<2x3xf32, 2> print("memref type:", memref) # CHECK: number of affine layout maps: 0 @@ -335,7 +335,7 @@ def testMemRefType(): print("memory space:", memref.memory_space) layout = AffineMap.get_permutation([1, 0]) - memref_layout = MemRefType.get(f32, shape, [layout]) + memref_layout = MemRefType.get(shape, f32, [layout]) # CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>> print("memref type:", memref_layout) assert len(memref_layout.layout) == 1 @@ -346,7 +346,7 @@ def testMemRefType(): none = NoneType.get() try: - memref_invalid = MemRefType.get(none, shape) + memref_invalid = MemRefType.get(shape, none) except ValueError as e: # CHECK: invalid 'Type(none)' and expected floating point, integer, vector # CHECK: or complex type.