[mlir][python] Swap shape and element_type order for MemRefType.

* Matches how all of the other shaped types are declared.
* No super principled reason fro this ordering beyond that it makes the one that was different be like the rest.
* Also matches ordering of things like ndarray, et al.

Reviewed By: ftynse, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D94812
This commit is contained in:
Stella Laurenzo 2021-01-19 16:02:02 -08:00
parent 7f36df0fb1
commit b62c7e0474
3 changed files with 8 additions and 10 deletions

View File

@ -31,9 +31,9 @@ def FuncOp(name: str, func_type: Type) -> Tuple[Operation, Block]:
def build_matmul_buffers_func(func_name, m, k, n, dtype):
lhs_type = MemRefType.get(dtype, [m, k])
rhs_type = MemRefType.get(dtype, [k, n])
result_type = MemRefType.get(dtype, [m, n])
lhs_type = MemRefType.get([m, k], dtype)
rhs_type = MemRefType.get([k, n], dtype)
result_type = MemRefType.get([m, n], dtype)
# TODO: There should be a one-liner for this.
func_type = FunctionType.get([lhs_type, rhs_type, result_type], [])
_, entry = FuncOp(func_name, func_type)
@ -49,8 +49,6 @@ def build_matmul_buffers_func(func_name, m, k, n, dtype):
def build_matmul_tensors_func(func_name, m, k, n, dtype):
# TODO: MemRefType and TensorTypes should not have inverted dtype/shapes
# from each other.
lhs_type = RankedTensorType.get([m, k], dtype)
rhs_type = RankedTensorType.get([k, n], dtype)
result_type = RankedTensorType.get([m, n], dtype)

View File

@ -2832,7 +2832,7 @@ public:
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](PyType &elementType, std::vector<int64_t> shape,
[](std::vector<int64_t> shape, PyType &elementType,
std::vector<PyAffineMap> layout, unsigned memorySpace,
DefaultingPyLocation loc) {
SmallVector<MlirAffineMap> maps;
@ -2856,7 +2856,7 @@ public:
}
return PyMemRefType(elementType.getContext(), t);
},
py::arg("element_type"), py::arg("shape"),
py::arg("shape"), py::arg("element_type"),
py::arg("layout") = py::list(), py::arg("memory_space") = 0,
py::arg("loc") = py::none(), "Create a memref type")
.def_property_readonly("layout", &PyMemRefType::getLayout,

View File

@ -326,7 +326,7 @@ def testMemRefType():
f32 = F32Type.get()
shape = [2, 3]
loc = Location.unknown()
memref = MemRefType.get(f32, shape, memory_space=2)
memref = MemRefType.get(shape, f32, memory_space=2)
# CHECK: memref type: memref<2x3xf32, 2>
print("memref type:", memref)
# CHECK: number of affine layout maps: 0
@ -335,7 +335,7 @@ def testMemRefType():
print("memory space:", memref.memory_space)
layout = AffineMap.get_permutation([1, 0])
memref_layout = MemRefType.get(f32, shape, [layout])
memref_layout = MemRefType.get(shape, f32, [layout])
# CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
print("memref type:", memref_layout)
assert len(memref_layout.layout) == 1
@ -346,7 +346,7 @@ def testMemRefType():
none = NoneType.get()
try:
memref_invalid = MemRefType.get(none, shape)
memref_invalid = MemRefType.get(shape, none)
except ValueError as e:
# CHECK: invalid 'Type(none)' and expected floating point, integer, vector
# CHECK: or complex type.