[mlir][sparse][taco] Split the evaluate method into compile and compute.

This is to align with the PyTACO API better.

Modify an existing unit test to test the new routines.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D121083
This commit is contained in:
Bixia Zheng 2022-03-06 19:34:40 -08:00
parent 355ad3a3cd
commit 5b87e0521d
2 changed files with 88 additions and 41 deletions

View File

@ -30,6 +30,7 @@ import os
import threading
# Import MLIR related modules.
from mlir import execution_engine
from mlir import ir
from mlir import runtime
from mlir.dialects import arith
@ -644,6 +645,7 @@ class Tensor:
dtype = dtype or DType(Type.FLOAT32)
self._name = name or self._get_unique_name()
self._assignment = None
self._engine = None
self._sparse_value_location = _SparseValueInfo._UNPACKED
self._dense_storage = None
self._dtype = dtype
@ -978,17 +980,72 @@ class Tensor:
f"len({indices}) != {self.order}.")
self._assignment = _Assignment(indices, value)
self._engine = None
def evaluate(self) -> None:
"""Evaluates the assignment to the tensor."""
result = self._assignment.expression.evaluate(self,
self._assignment.indices)
self._assignment = None
def compile(self, force_recompile: bool = False) -> None:
"""Compiles the tensor assignment to an execution engine.
Calling compile the second time does not do anything unless
force_recompile is True.
Args:
force_recompile: A boolean value to enable recompilation, such as for the
purpose of timing.
Raises:
ValueError: If the assignment is not proper or not supported.
"""
if self._assignment is None or (self._engine is not None and
not force_recompile):
return
self._engine = self._assignment.expression.compile(self,
self._assignment.indices)
def compute(self) -> None:
"""Executes the engine for the tensor assignment.
Raises:
ValueError: If the assignment hasn't been compiled yet.
"""
if self._assignment is None:
return
if self._engine is None:
raise ValueError("Need to invoke compile() before invoking compute().")
input_accesses = self._assignment.expression.get_input_accesses()
# Gather the pointers for the input buffers.
input_pointers = [a.tensor.ctype_pointer() for a in input_accesses]
if self.is_dense():
# The pointer to receive dense output is the first argument to the
# execution engine.
arg_pointers = [self.dense_dst_ctype_pointer()] + input_pointers
else:
# The pointer to receive the sparse tensor output is the last argument
# to the execution engine and is a pointer to pointer of char.
arg_pointers = input_pointers + [
ctypes.pointer(ctypes.pointer(ctypes.c_char(0)))
]
# Invoke the execution engine to run the module.
self._engine.invoke(_ENTRY_NAME, *arg_pointers)
# Retrieve the result.
if self.is_dense():
result = runtime.ranked_memref_to_numpy(arg_pointers[0][0])
assert isinstance(result, np.ndarray)
self._dense_storage = result
else:
self._set_packed_sparse_tensor(result)
self._set_packed_sparse_tensor(arg_pointers[-1][0])
self._assignment = None
self._engine = None
def evaluate(self) -> None:
"""Evaluates the tensor assignment."""
self.compile()
self.compute()
def _sync_value(self) -> None:
"""Updates the tensor value by evaluating the pending assignment."""
@ -1444,29 +1501,31 @@ class IndexExpr(abc.ABC):
linalg_funcop.func_op.attributes[
"llvm.emit_c_interface"] = ir.UnitAttr.get()
def evaluate(
def get_input_accesses(self) -> List["Access"]:
"""Compute the list of input accesses for the expression."""
input_accesses = []
self._visit(_gather_input_accesses_index_vars, (input_accesses,))
return input_accesses
def compile(
self,
dst: Tensor,
dst_indices: Tuple[IndexVar, ...],
) -> Union[np.ndarray, ctypes.c_void_p]:
"""Evaluates tensor assignment dst[dst_indices] = expression.
) -> execution_engine.ExecutionEngine:
"""Compiles the tensor assignment dst[dst_indices] = expression.
Args:
dst: The destination tensor.
dst_indices: The tuple of IndexVar used to access the destination tensor.
Returns:
The result of the dense tensor represented in numpy ndarray or the pointer
to the MLIR sparse tensor.
The execution engine for the tensor assignment.
Raises:
ValueError: If the expression is not proper or not supported.
"""
expr_to_info = self._validate_and_collect_expr_info(dst, dst_indices)
# Compute a list of input accesses.
input_accesses = []
self._visit(_gather_input_accesses_index_vars, (input_accesses,))
input_accesses = self.get_input_accesses()
# Build and compile the module to produce the execution engine.
with ir.Context(), ir.Location.unknown():
@ -1475,29 +1534,7 @@ class IndexExpr(abc.ABC):
input_accesses)
engine = utils.compile_and_build_engine(module)
# Gather the pointers for the input buffers.
input_pointers = [a.tensor.ctype_pointer() for a in input_accesses]
if dst.is_dense():
# The pointer to receive dense output is the first argument to the
# execution engine.
arg_pointers = [dst.dense_dst_ctype_pointer()] + input_pointers
else:
# The pointer to receive sparse output is the last argument to the
# execution engine. The pointer to receive a sparse tensor output is a
# pointer to pointer of char.
arg_pointers = input_pointers + [
ctypes.pointer(ctypes.pointer(ctypes.c_char(0)))
]
# Invoke the execution engine to run the module and return the result.
engine.invoke(_ENTRY_NAME, *arg_pointers)
if dst.is_dense():
return runtime.ranked_memref_to_numpy(arg_pointers[0][0])
# Return the sparse tensor pointer.
return arg_pointers[-1][0]
return engine
@dataclasses.dataclass(frozen=True)
class Access(IndexExpr):

View File

@ -279,13 +279,23 @@ def test_tensor_copy():
A.insert([1, 2], 6.0)
B = mlir_pytaco.Tensor([I, J])
B[i, j] = A[i, j]
passed = (B._assignment is not None)
passed += (B._engine is None)
try:
B.compute()
except ValueError as e:
passed += (str(e).startswith("Need to invoke compile"))
B.compile()
passed += (B._engine is not None)
B.compute()
passed += (B._assignment is None)
passed += (B._engine is None)
indices, values = B.get_coordinates_and_values()
passed = np.array_equal(indices, [[0, 1], [1, 2]])
passed += np.array_equal(indices, [[0, 1], [1, 2]])
passed += np.allclose(values, [5.0, 6.0])
# No temporary tensor is used.
passed += (B._stats.get_total() == 0)
# CHECK: Number of passed: 3
# CHECK: Number of passed: 9
print("Number of passed:", passed)