[mlir][sparse][taco] Add a few unary operations.

Add operations -, abs, ceil and floor to the index notation.

Add test cases.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D121388
This commit is contained in:
Bixia Zheng 2022-03-10 09:40:54 -08:00
parent 058c92f2a4
commit 30c5269d93
3 changed files with 197 additions and 1 deletions

View File

@ -0,0 +1,40 @@
# RUN: SUPPORTLIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext %PYTHON %s | FileCheck %s
import numpy as np
import os
import sys
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import mlir_pytaco_api as pt
i, j = pt.get_index_vars(2)
A = pt.tensor([2, 3])
B = pt.tensor([2, 3])
A.insert([0, 1], 10.3)
A.insert([1, 1], 40.7)
A.insert([0, 2], -11.3)
A.insert([1, 2], -41.7)
B[i, j] = abs(A[i, j])
indices, values = B.get_coordinates_and_values()
passed = np.array_equal(indices, [[0, 1], [0, 2], [1, 1], [1, 2]])
passed += np.allclose(values, [10.3, 11.3, 40.7, 41.7])
B[i, j] = pt.ceil(A[i, j])
indices, values = B.get_coordinates_and_values()
passed += np.array_equal(indices, [[0, 1], [0, 2], [1, 1], [1, 2]])
passed += np.allclose(values, [11, -11, 41, -41])
B[i, j] = pt.floor(A[i, j])
indices, values = B.get_coordinates_and_values()
passed += np.array_equal(indices, [[0, 1], [0, 2], [1, 1], [1, 2]])
passed += np.allclose(values, [10, -12, 40, -42])
B[i, j] = -A[i, j]
indices, values = B.get_coordinates_and_values()
passed += np.array_equal(indices, [[0, 1], [0, 2], [1, 1], [1, 2]])
passed += np.allclose(values, [-10.3, 11.3, -40.7, 41.7])
# CHECK: Number of passed: 8
print("Number of passed:", passed)

View File

@ -53,6 +53,7 @@ _INDEX_BIT_WIDTH = 0
_ENTRY_NAME = "main"
# Type aliases for type annotation.
_UnaryOp = Callable[[Any], Any]
_BinaryOp = Callable[[Any, Any], Any]
_ExprVisitor = Callable[..., None]
_ExprInfoDict = Dict["IndexExpr", "_ExprInfo"]
@ -1223,6 +1224,14 @@ class IndexExpr(abc.ABC):
raise ValueError(f"Expected IndexExpr: {rhs}")
return _BinaryExpr(op, self, rhs)
def _build_unary_expr(self, op: _UnaryOp) -> "_UnaryExpr":
"""Build a unary expression.
Args:
op: A _UnaryOp object representing the unary operation.
"""
return _UnaryExpr(op, self)
def __add__(self, rhs) -> "_BinaryExpr":
"""Defines the operator +.
@ -1253,6 +1262,22 @@ class IndexExpr(abc.ABC):
"""
return self._verify_operand_and_build_expr(rhs, operator.mul)
def __abs__(self) -> "_UnaryExpr":
"""Defines the operator abs.
Returns:
A _UnaryExpr object representing the operation.
"""
return self._build_unary_expr(operator.abs)
def __neg__(self) -> "_UnaryExpr":
"""Defines the operator neg.
Returns:
A _UnaryExpr object representing the operation.
"""
return self._build_unary_expr(operator.neg)
def __sub__(self, rhs) -> "_BinaryExpr":
"""Defines the operator -.
@ -1603,6 +1628,75 @@ def _gather_input_accesses_index_vars(
input_accesses.append(expr)
def _op_ceil(__a: Any) -> Any:
"""A _UnaryOp object for operation ceil."""
pass
def _op_floor(__a: Any) -> Any:
"""A _UnaryOp object for operation floor."""
pass
def _op_unary_to_callable(op: _UnaryOp) -> lang.UnaryFnType:
"""Returns the linalg dialect function object for the given operation."""
op_to_callable = {
operator.abs: lang.UnaryFn.abs,
operator.neg: lang.UnaryFn.negf,
_op_ceil: lang.UnaryFn.ceil,
_op_floor: lang.UnaryFn.floor,
}
return op_to_callable[op]
@dataclasses.dataclass(frozen=True)
class _UnaryExpr(IndexExpr):
"""The representation for a Unary operation.
Attributes:
op: A _UnaryOp representing the operation.
a: An IndexExpr representing the operand for the operation.
"""
op: _BinaryOp
a: IndexExpr
def __post_init__(self) -> None:
"""Verifies that the operand being added is an IndexExpr."""
assert isinstance(self.a, IndexExpr)
def _emit_expression(
self,
expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
expr_to_info: _ExprInfoDict,
) -> lang.ScalarExpression:
"""Emits the expression tree and returns the expression."""
# The current expression node is an internal node of the structured op.
if self not in expr_to_opnd:
a = self.a._emit_expression(expr_to_opnd, expr_to_info)
return _op_unary_to_callable(self.op)(a)
# The current expression is a leaf node of the structured op. That is, it is
# a temporary tensor generated by its child structured op.
op_info = expr_to_info[self].structop_info
assert op_info is not None
dims = _mlir_dimensions_from_index_vars(op_info.dst_indices)
return lang.TensorUse(expr_to_opnd[self], dims)
def _visit(self,
func: _ExprVisitor,
args,
*,
leaf_checker: _SubtreeLeafChecker = None) -> None:
"""A post-order visitor."""
if leaf_checker is None or not leaf_checker(self, *args):
self.a._visit(func, args, leaf_checker=leaf_checker)
func(self, *args)
def dtype(self) -> DType:
"""Returns the data type of the operation."""
return self.a.dtype()
def _op_to_callable(op: _BinaryOp) -> lang.BinaryFnType:
"""Returns the linalg dialect function object for the given operation."""
op_to_callable = {
@ -1612,7 +1706,6 @@ def _op_to_callable(op: _BinaryOp) -> lang.BinaryFnType:
}
return op_to_callable[op]
@dataclasses.dataclass(frozen=True)
class _BinaryExpr(IndexExpr):
"""The representation for a binary operation.
@ -1740,6 +1833,15 @@ def _validate_and_collect_expr_info(
mode_formats = tuple(expr.tensor.format.format_pack.formats)
assert len(src_dims) == len(mode_formats)
dim_infos = tuple([_DimInfo(d, m) for d, m in zip(src_dims, mode_formats)])
elif isinstance(expr, _UnaryExpr):
a_info = expr_to_info[expr.a]
index_to_dim_info = {
i: d for i, d in zip(a_info.src_indices, a_info.dim_infos)
}
# Here we rely on the fact that dictionaries keep the insertion order for
# keys and values.
src_indices = tuple(index_to_dim_info.keys())
dim_infos = tuple(index_to_dim_info.values())
else:
assert isinstance(expr, _BinaryExpr)
a_info = expr_to_info[expr.a]
@ -1826,6 +1928,10 @@ def _accumulate_reduce_indices(
expr_info.acc_reduce_indices = (
a_info.acc_reduce_indices | b_info.acc_reduce_indices
| expr_info.reduce_indices)
elif isinstance(expr, _UnaryExpr):
a_info = expr_to_info[expr.a]
expr_info.acc_reduce_indices = (
a_info.acc_reduce_indices | expr_info.reduce_indices)
else:
assert isinstance(expr, Access)
# Handle simple reduction expression in the format of A[i] = B[i, j].
@ -1965,3 +2071,51 @@ def _emit_structured_op_input(
opnd = lang.OperandDef(lang.OperandKind.INPUT_TENSOR, lang.T, dim_sym)
op_def.add_operand(name, opnd)
return opnd
def _check_and_build_unary(a: Access, op: _UnaryOp) -> "_UnaryExpr":
"""Build a unary operation ceil.
Args:
a: The operand, which could be any Python object from user inputs.
op: An _UnaryOp object representing the operation.
Returns:
A _UnaryExpr object representing the operation.
Raises:
ValueError: If a is not an IndexExpr.
"""
if not isinstance(a, Access):
raise ValueError(f"Expected an Access Operand: {a}")
return a._build_unary_expr(op)
def ceil(a: Access) -> "_UnaryExpr":
"""Defines the operation ceil.
Args:
a: The operand, which could be any Python object from user inputs.
Returns:
A _UnaryExpr object representing the operation.
Raises:
ValueError: If a is not an IndexExpr.
"""
return _check_and_build_unary(a, _op_ceil)
def floor(a: Access) -> "_UnaryExpr":
"""Defines the operation floor.
Args:
a: The operand, which could be any Python object from user inputs.
Returns:
A _UnaryExpr object representing the operation.
Raises:
ValueError: If a is not an IndexExpr.
"""
return _check_and_build_unary(a, _op_floor)

View File

@ -16,6 +16,8 @@ from . import mlir_pytaco
from . import mlir_pytaco_io
# Functions defined by PyTACO API.
ceil = mlir_pytaco.ceil
floor = mlir_pytaco.floor
get_index_vars = mlir_pytaco.get_index_vars
from_array = mlir_pytaco.Tensor.from_array
read = mlir_pytaco_io.read