mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-02-10 11:23:52 +00:00
[mlir][sparse][taco] Add a few unary operations.
Add operations -, abs, ceil and floor to the index notation. Add test cases. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D121388
This commit is contained in:
parent
058c92f2a4
commit
30c5269d93
@ -0,0 +1,40 @@
|
||||
# RUN: SUPPORTLIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext %PYTHON %s | FileCheck %s
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
|
||||
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(_SCRIPT_PATH)
|
||||
from tools import mlir_pytaco_api as pt
|
||||
|
||||
i, j = pt.get_index_vars(2)
|
||||
A = pt.tensor([2, 3])
|
||||
B = pt.tensor([2, 3])
|
||||
A.insert([0, 1], 10.3)
|
||||
A.insert([1, 1], 40.7)
|
||||
A.insert([0, 2], -11.3)
|
||||
A.insert([1, 2], -41.7)
|
||||
|
||||
B[i, j] = abs(A[i, j])
|
||||
indices, values = B.get_coordinates_and_values()
|
||||
passed = np.array_equal(indices, [[0, 1], [0, 2], [1, 1], [1, 2]])
|
||||
passed += np.allclose(values, [10.3, 11.3, 40.7, 41.7])
|
||||
|
||||
B[i, j] = pt.ceil(A[i, j])
|
||||
indices, values = B.get_coordinates_and_values()
|
||||
passed += np.array_equal(indices, [[0, 1], [0, 2], [1, 1], [1, 2]])
|
||||
passed += np.allclose(values, [11, -11, 41, -41])
|
||||
|
||||
B[i, j] = pt.floor(A[i, j])
|
||||
indices, values = B.get_coordinates_and_values()
|
||||
passed += np.array_equal(indices, [[0, 1], [0, 2], [1, 1], [1, 2]])
|
||||
passed += np.allclose(values, [10, -12, 40, -42])
|
||||
|
||||
B[i, j] = -A[i, j]
|
||||
indices, values = B.get_coordinates_and_values()
|
||||
passed += np.array_equal(indices, [[0, 1], [0, 2], [1, 1], [1, 2]])
|
||||
passed += np.allclose(values, [-10.3, 11.3, -40.7, 41.7])
|
||||
|
||||
# CHECK: Number of passed: 8
|
||||
print("Number of passed:", passed)
|
@ -53,6 +53,7 @@ _INDEX_BIT_WIDTH = 0
|
||||
_ENTRY_NAME = "main"
|
||||
|
||||
# Type aliases for type annotation.
|
||||
_UnaryOp = Callable[[Any], Any]
|
||||
_BinaryOp = Callable[[Any, Any], Any]
|
||||
_ExprVisitor = Callable[..., None]
|
||||
_ExprInfoDict = Dict["IndexExpr", "_ExprInfo"]
|
||||
@ -1223,6 +1224,14 @@ class IndexExpr(abc.ABC):
|
||||
raise ValueError(f"Expected IndexExpr: {rhs}")
|
||||
return _BinaryExpr(op, self, rhs)
|
||||
|
||||
def _build_unary_expr(self, op: _UnaryOp) -> "_UnaryExpr":
|
||||
"""Build a unary expression.
|
||||
|
||||
Args:
|
||||
op: A _UnaryOp object representing the unary operation.
|
||||
"""
|
||||
return _UnaryExpr(op, self)
|
||||
|
||||
def __add__(self, rhs) -> "_BinaryExpr":
|
||||
"""Defines the operator +.
|
||||
|
||||
@ -1253,6 +1262,22 @@ class IndexExpr(abc.ABC):
|
||||
"""
|
||||
return self._verify_operand_and_build_expr(rhs, operator.mul)
|
||||
|
||||
def __abs__(self) -> "_UnaryExpr":
|
||||
"""Defines the operator abs.
|
||||
|
||||
Returns:
|
||||
A _UnaryExpr object representing the operation.
|
||||
"""
|
||||
return self._build_unary_expr(operator.abs)
|
||||
|
||||
def __neg__(self) -> "_UnaryExpr":
|
||||
"""Defines the operator neg.
|
||||
|
||||
Returns:
|
||||
A _UnaryExpr object representing the operation.
|
||||
"""
|
||||
return self._build_unary_expr(operator.neg)
|
||||
|
||||
def __sub__(self, rhs) -> "_BinaryExpr":
|
||||
"""Defines the operator -.
|
||||
|
||||
@ -1603,6 +1628,75 @@ def _gather_input_accesses_index_vars(
|
||||
input_accesses.append(expr)
|
||||
|
||||
|
||||
def _op_ceil(__a: Any) -> Any:
|
||||
"""A _UnaryOp object for operation ceil."""
|
||||
pass
|
||||
|
||||
|
||||
def _op_floor(__a: Any) -> Any:
|
||||
"""A _UnaryOp object for operation floor."""
|
||||
pass
|
||||
|
||||
|
||||
def _op_unary_to_callable(op: _UnaryOp) -> lang.UnaryFnType:
|
||||
"""Returns the linalg dialect function object for the given operation."""
|
||||
op_to_callable = {
|
||||
operator.abs: lang.UnaryFn.abs,
|
||||
operator.neg: lang.UnaryFn.negf,
|
||||
_op_ceil: lang.UnaryFn.ceil,
|
||||
_op_floor: lang.UnaryFn.floor,
|
||||
}
|
||||
return op_to_callable[op]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _UnaryExpr(IndexExpr):
|
||||
"""The representation for a Unary operation.
|
||||
|
||||
Attributes:
|
||||
op: A _UnaryOp representing the operation.
|
||||
a: An IndexExpr representing the operand for the operation.
|
||||
"""
|
||||
op: _BinaryOp
|
||||
a: IndexExpr
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Verifies that the operand being added is an IndexExpr."""
|
||||
assert isinstance(self.a, IndexExpr)
|
||||
|
||||
def _emit_expression(
|
||||
self,
|
||||
expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
|
||||
expr_to_info: _ExprInfoDict,
|
||||
) -> lang.ScalarExpression:
|
||||
"""Emits the expression tree and returns the expression."""
|
||||
# The current expression node is an internal node of the structured op.
|
||||
if self not in expr_to_opnd:
|
||||
a = self.a._emit_expression(expr_to_opnd, expr_to_info)
|
||||
return _op_unary_to_callable(self.op)(a)
|
||||
|
||||
# The current expression is a leaf node of the structured op. That is, it is
|
||||
# a temporary tensor generated by its child structured op.
|
||||
op_info = expr_to_info[self].structop_info
|
||||
assert op_info is not None
|
||||
dims = _mlir_dimensions_from_index_vars(op_info.dst_indices)
|
||||
return lang.TensorUse(expr_to_opnd[self], dims)
|
||||
|
||||
def _visit(self,
|
||||
func: _ExprVisitor,
|
||||
args,
|
||||
*,
|
||||
leaf_checker: _SubtreeLeafChecker = None) -> None:
|
||||
"""A post-order visitor."""
|
||||
if leaf_checker is None or not leaf_checker(self, *args):
|
||||
self.a._visit(func, args, leaf_checker=leaf_checker)
|
||||
func(self, *args)
|
||||
|
||||
def dtype(self) -> DType:
|
||||
"""Returns the data type of the operation."""
|
||||
return self.a.dtype()
|
||||
|
||||
|
||||
def _op_to_callable(op: _BinaryOp) -> lang.BinaryFnType:
|
||||
"""Returns the linalg dialect function object for the given operation."""
|
||||
op_to_callable = {
|
||||
@ -1612,7 +1706,6 @@ def _op_to_callable(op: _BinaryOp) -> lang.BinaryFnType:
|
||||
}
|
||||
return op_to_callable[op]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _BinaryExpr(IndexExpr):
|
||||
"""The representation for a binary operation.
|
||||
@ -1740,6 +1833,15 @@ def _validate_and_collect_expr_info(
|
||||
mode_formats = tuple(expr.tensor.format.format_pack.formats)
|
||||
assert len(src_dims) == len(mode_formats)
|
||||
dim_infos = tuple([_DimInfo(d, m) for d, m in zip(src_dims, mode_formats)])
|
||||
elif isinstance(expr, _UnaryExpr):
|
||||
a_info = expr_to_info[expr.a]
|
||||
index_to_dim_info = {
|
||||
i: d for i, d in zip(a_info.src_indices, a_info.dim_infos)
|
||||
}
|
||||
# Here we rely on the fact that dictionaries keep the insertion order for
|
||||
# keys and values.
|
||||
src_indices = tuple(index_to_dim_info.keys())
|
||||
dim_infos = tuple(index_to_dim_info.values())
|
||||
else:
|
||||
assert isinstance(expr, _BinaryExpr)
|
||||
a_info = expr_to_info[expr.a]
|
||||
@ -1826,6 +1928,10 @@ def _accumulate_reduce_indices(
|
||||
expr_info.acc_reduce_indices = (
|
||||
a_info.acc_reduce_indices | b_info.acc_reduce_indices
|
||||
| expr_info.reduce_indices)
|
||||
elif isinstance(expr, _UnaryExpr):
|
||||
a_info = expr_to_info[expr.a]
|
||||
expr_info.acc_reduce_indices = (
|
||||
a_info.acc_reduce_indices | expr_info.reduce_indices)
|
||||
else:
|
||||
assert isinstance(expr, Access)
|
||||
# Handle simple reduction expression in the format of A[i] = B[i, j].
|
||||
@ -1965,3 +2071,51 @@ def _emit_structured_op_input(
|
||||
opnd = lang.OperandDef(lang.OperandKind.INPUT_TENSOR, lang.T, dim_sym)
|
||||
op_def.add_operand(name, opnd)
|
||||
return opnd
|
||||
|
||||
|
||||
def _check_and_build_unary(a: Access, op: _UnaryOp) -> "_UnaryExpr":
|
||||
"""Build a unary operation ceil.
|
||||
|
||||
Args:
|
||||
a: The operand, which could be any Python object from user inputs.
|
||||
op: An _UnaryOp object representing the operation.
|
||||
|
||||
Returns:
|
||||
A _UnaryExpr object representing the operation.
|
||||
|
||||
Raises:
|
||||
ValueError: If a is not an IndexExpr.
|
||||
"""
|
||||
if not isinstance(a, Access):
|
||||
raise ValueError(f"Expected an Access Operand: {a}")
|
||||
return a._build_unary_expr(op)
|
||||
|
||||
|
||||
def ceil(a: Access) -> "_UnaryExpr":
|
||||
"""Defines the operation ceil.
|
||||
|
||||
Args:
|
||||
a: The operand, which could be any Python object from user inputs.
|
||||
|
||||
Returns:
|
||||
A _UnaryExpr object representing the operation.
|
||||
|
||||
Raises:
|
||||
ValueError: If a is not an IndexExpr.
|
||||
"""
|
||||
return _check_and_build_unary(a, _op_ceil)
|
||||
|
||||
|
||||
def floor(a: Access) -> "_UnaryExpr":
|
||||
"""Defines the operation floor.
|
||||
|
||||
Args:
|
||||
a: The operand, which could be any Python object from user inputs.
|
||||
|
||||
Returns:
|
||||
A _UnaryExpr object representing the operation.
|
||||
|
||||
Raises:
|
||||
ValueError: If a is not an IndexExpr.
|
||||
"""
|
||||
return _check_and_build_unary(a, _op_floor)
|
||||
|
@ -16,6 +16,8 @@ from . import mlir_pytaco
|
||||
from . import mlir_pytaco_io
|
||||
|
||||
# Functions defined by PyTACO API.
|
||||
ceil = mlir_pytaco.ceil
|
||||
floor = mlir_pytaco.floor
|
||||
get_index_vars = mlir_pytaco.get_index_vars
|
||||
from_array = mlir_pytaco.Tensor.from_array
|
||||
read = mlir_pytaco_io.read
|
||||
|
Loading…
x
Reference in New Issue
Block a user