mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-29 16:41:27 +00:00
Split ElementwiseMappable
trait into four more precise traits.
Some elementwise operations are not scalarizable, vectorizable, or tensorizable. Split `ElementwiseMappable` trait into the following, more precise traits. - `Elementwise` - `Scalarizable` - `Vectorizable` - `Tensorizable` This allows for reuse of `Elementwise` in dialects like HLO. Differential Revision: https://reviews.llvm.org/D97674
This commit is contained in:
parent
7fce3322a2
commit
bcc9b371e4
@ -20,11 +20,9 @@ class Complex_Op<string mnemonic, list<OpTrait> traits = []>
|
||||
// floating-point element type. These operations take two operands and return
|
||||
// one result, all of which must be complex numbers of the same type.
|
||||
class ComplexArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
Complex_Op<mnemonic,
|
||||
!listconcat(traits, [NoSideEffect,
|
||||
SameOperandsAndResultType,
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
|
||||
ElementwiseMappable])> {
|
||||
Complex_Op<mnemonic, traits # [NoSideEffect, SameOperandsAndResultType,
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
|
||||
ElementwiseMappable.traits> {
|
||||
let arguments = (ins Complex<AnyFloat>:$lhs, Complex<AnyFloat>:$rhs);
|
||||
let results = (outs Complex<AnyFloat>:$result);
|
||||
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
|
||||
|
@ -17,10 +17,9 @@ class MathOp<string mnemonic, list<OpTrait> traits = []>
|
||||
: Op<Math_Dialect, mnemonic, traits # [NoSideEffect]>;
|
||||
|
||||
class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
MathOp<mnemonic,
|
||||
traits # [DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
|
||||
ElementwiseMappable,
|
||||
SameOperandsAndResultType]> {
|
||||
MathOp<mnemonic, traits #
|
||||
[DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
|
||||
SameOperandsAndResultType] # ElementwiseMappable.traits> {
|
||||
let arguments = (ins FloatLike:$operand);
|
||||
|
||||
let results = (outs FloatLike:$result);
|
||||
@ -29,10 +28,9 @@ class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
}
|
||||
|
||||
class FloatBinaryOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
MathOp<mnemonic,
|
||||
traits # [DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
|
||||
ElementwiseMappable,
|
||||
SameOperandsAndResultType]> {
|
||||
MathOp<mnemonic, traits # [
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
|
||||
SameOperandsAndResultType] # ElementwiseMappable.traits> {
|
||||
let arguments = (ins FloatLike:$lhs, FloatLike:$rhs);
|
||||
let results = (outs FloatLike:$result);
|
||||
|
||||
|
@ -71,9 +71,9 @@ class CastOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
|
||||
// Base class for arithmetic cast operations.
|
||||
class ArithmeticCastOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
CastOp<mnemonic,
|
||||
!listconcat(traits, [ElementwiseMappable,
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface>])> {
|
||||
CastOp<mnemonic, traits # [
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
|
||||
ElementwiseMappable.traits> {
|
||||
}
|
||||
|
||||
// Base class for unary ops. Requires single operand and result. Individual
|
||||
@ -95,21 +95,18 @@ class UnaryOpSameOperandAndResultType<string mnemonic,
|
||||
}
|
||||
|
||||
class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
UnaryOpSameOperandAndResultType<mnemonic,
|
||||
!listconcat(traits,
|
||||
[DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
|
||||
ElementwiseMappable])>,
|
||||
Arguments<(ins FloatLike:$operand)>;
|
||||
UnaryOpSameOperandAndResultType<mnemonic, traits # [
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
|
||||
ElementwiseMappable.traits>, Arguments<(ins FloatLike:$operand)>;
|
||||
|
||||
// Base class for standard arithmetic operations. Requires operands and
|
||||
// results to be of the same type, but does not constrain them to specific
|
||||
// types.
|
||||
class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<StandardOps_Dialect, mnemonic,
|
||||
!listconcat(traits, [NoSideEffect,
|
||||
SameOperandsAndResultType,
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
|
||||
ElementwiseMappable])> {
|
||||
Op<StandardOps_Dialect, mnemonic, traits # [NoSideEffect,
|
||||
SameOperandsAndResultType,
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
|
||||
ElementwiseMappable.traits> {
|
||||
|
||||
let results = (outs AnyType:$result);
|
||||
|
||||
@ -930,12 +927,10 @@ def CmpFPredicateAttr : I64EnumAttr<
|
||||
let cppNamespace = "::mlir";
|
||||
}
|
||||
|
||||
def CmpFOp : Std_Op<"cmpf",
|
||||
[NoSideEffect, SameTypeOperands, ElementwiseMappable,
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
|
||||
TypesMatchWith<
|
||||
"result type has i1 element type and same shape as operands",
|
||||
"lhs", "result", "getI1SameShape($_self)">]> {
|
||||
def CmpFOp : Std_Op<"cmpf", [NoSideEffect, SameTypeOperands,
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface>, TypesMatchWith<
|
||||
"result type has i1 element type and same shape as operands",
|
||||
"lhs", "result", "getI1SameShape($_self)">] # ElementwiseMappable.traits> {
|
||||
let summary = "floating-point comparison operation";
|
||||
let description = [{
|
||||
The `cmpf` operation compares its two operands according to the float
|
||||
@ -1015,12 +1010,10 @@ def CmpIPredicateAttr : I64EnumAttr<
|
||||
let cppNamespace = "::mlir";
|
||||
}
|
||||
|
||||
def CmpIOp : Std_Op<"cmpi",
|
||||
[NoSideEffect, SameTypeOperands, ElementwiseMappable,
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
|
||||
TypesMatchWith<
|
||||
"result type has i1 element type and same shape as operands",
|
||||
"lhs", "result", "getI1SameShape($_self)">]> {
|
||||
def CmpIOp : Std_Op<"cmpi", [NoSideEffect, SameTypeOperands,
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface>, TypesMatchWith<
|
||||
"result type has i1 element type and same shape as operands",
|
||||
"lhs", "result", "getI1SameShape($_self)">] # ElementwiseMappable.traits> {
|
||||
let summary = "integer comparison operation";
|
||||
let description = [{
|
||||
The `cmpi` operation is a generic comparison for integer-like types. Its two
|
||||
@ -2160,8 +2153,9 @@ def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">,
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SelectOp : Std_Op<"select", [NoSideEffect,
|
||||
AllTypesMatch<["true_value", "false_value", "result"]>,
|
||||
ElementwiseMappable, DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> {
|
||||
AllTypesMatch<["true_value", "false_value", "result"]>,
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
|
||||
ElementwiseMappable.traits> {
|
||||
let summary = "select operation";
|
||||
let description = [{
|
||||
The `select` operation chooses one value based on a binary condition
|
||||
@ -2391,9 +2385,9 @@ def SignedShiftRightOp : IntBinaryOp<"shift_right_signed"> {
|
||||
// SignExtendIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SignExtendIOp : Std_Op<"sexti",
|
||||
[NoSideEffect, ElementwiseMappable,
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> {
|
||||
def SignExtendIOp : Std_Op<"sexti", [NoSideEffect,
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
|
||||
ElementwiseMappable.traits> {
|
||||
let summary = "integer sign extension operation";
|
||||
let description = [{
|
||||
The integer sign extension operation takes an integer input of
|
||||
@ -3220,9 +3214,9 @@ def TransposeOp : Std_Op<"transpose", [NoSideEffect]>,
|
||||
// TruncateIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TruncateIOp : Std_Op<"trunci",
|
||||
[NoSideEffect, ElementwiseMappable,
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,]> {
|
||||
def TruncateIOp : Std_Op<"trunci", [NoSideEffect,
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
|
||||
ElementwiseMappable.traits> {
|
||||
let summary = "integer truncation operation";
|
||||
let description = [{
|
||||
The integer truncation operation takes an integer input of
|
||||
@ -3463,9 +3457,9 @@ def XOrOp : IntBinaryOp<"xor", [Commutative]> {
|
||||
// ZeroExtendIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ZeroExtendIOp : Std_Op<"zexti",
|
||||
[NoSideEffect, ElementwiseMappable,
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,]> {
|
||||
def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect,
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
|
||||
ElementwiseMappable.traits> {
|
||||
let summary = "integer zero extension operation";
|
||||
let description = [{
|
||||
The integer zero extension operation takes an integer input of
|
||||
|
@ -1785,9 +1785,25 @@ def Terminator : NativeOpTrait<"IsTerminator">;
|
||||
// Op can be safely normalized in the presence of MemRefs with
|
||||
// non-identity maps.
|
||||
def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">;
|
||||
// Op can be systematically interconverted between scalar and vector/tensor
|
||||
// form by mapping elementwise based on the type.
|
||||
def ElementwiseMappable : NativeOpTrait<"ElementwiseMappable">;
|
||||
// Op is elementwise on tensor/vector operands and results.
|
||||
def Elementwise : NativeOpTrait<"Elementwise">;
|
||||
// Elementwise op can be applied to scalars instead tensor/vector operands.
|
||||
def Scalarizable : NativeOpTrait<"Scalarizable">;
|
||||
// Elementwise op can be applied to all-vector operands.
|
||||
def Vectorizable : NativeOpTrait<"Vectorizable">;
|
||||
// Elementwise op can be applied to all-tensor operands.
|
||||
def Tensorizable : NativeOpTrait<"Tensorizable">;
|
||||
|
||||
// Group together `Elementwise`, `Scalarizable`, `Vectorizable`, and
|
||||
// `Tensorizable` for convenience.
|
||||
def ElementwiseMappable {
|
||||
list<OpTrait> traits = [
|
||||
Elementwise,
|
||||
Scalarizable,
|
||||
Vectorizable,
|
||||
Tensorizable,
|
||||
];
|
||||
}
|
||||
|
||||
// Op's regions have a single block with the specified terminator.
|
||||
class SingleBlockImplicitTerminator<string op>
|
||||
|
@ -282,7 +282,7 @@ LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);
|
||||
LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
|
||||
LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
|
||||
LogicalResult verifyNoRegionArguments(Operation *op);
|
||||
LogicalResult verifyElementwiseMappable(Operation *op);
|
||||
LogicalResult verifyElementwise(Operation *op);
|
||||
} // namespace impl
|
||||
|
||||
/// Helper class for implementing traits. Clients are not expected to interact
|
||||
@ -1213,93 +1213,144 @@ template <typename ConcrentType>
|
||||
struct MemRefsNormalizable
|
||||
: public TraitBase<ConcrentType, MemRefsNormalizable> {};
|
||||
|
||||
/// This trait tags scalar ops that also can be applied to vectors/tensors, with
|
||||
/// their semantics on vectors/tensors being elementwise application.
|
||||
/// This trait tags element-wise ops that operate on scalars, vectors, or
|
||||
/// tensors.
|
||||
///
|
||||
/// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this
|
||||
/// trait. In particular, broadcasting behavior is not allowed. This trait
|
||||
/// describes a set of invariants that allow systematic
|
||||
/// vectorization/tensorization, and the reverse, scalarization. The properties
|
||||
/// needed for this also can be used to implement a number of
|
||||
/// transformations/analyses/interfaces.
|
||||
/// trait. In particular, broadcasting behavior is not allowed.
|
||||
///
|
||||
/// An `ElementwiseMappable` op must satisfy the following properties:
|
||||
/// An `Elementwise` op must satisfy the following properties:
|
||||
///
|
||||
/// 1. If any result is a vector (resp. tensor), then at least one operand must
|
||||
/// be a vector (resp. tensor).
|
||||
/// 2. If any operand is a vector (resp. tensor), then there must be at least
|
||||
/// one result, and all results must be vectors (resp. tensors).
|
||||
/// 3. The static types of all vector (resp. tensor) operands and results must
|
||||
/// have the same shape.
|
||||
/// 4. In the case of tensor operands, the dynamic shapes of all tensor operands
|
||||
/// must be the same, otherwise the op has undefined behavior.
|
||||
/// 5. ("systematic scalarization" property) If an op has vector/tensor
|
||||
/// operands/results, then the same op, with the operand/result types changed to
|
||||
/// their corresponding element type, shall be a verifier-valid op.
|
||||
/// 6. The semantics of the op on vectors (resp. tensors) shall be the same as
|
||||
/// applying the scalarized version of the op for each corresponding element of
|
||||
/// the vector (resp. tensor) operands in parallel.
|
||||
/// 7. ("systematic vectorization/tensorization" property) If an op has
|
||||
/// scalar operands/results, the op shall remain verifier-valid if all scalar
|
||||
/// operands are replaced with vectors/tensors of the same shape and
|
||||
/// corresponding element types.
|
||||
/// 1. If any result is a vector/tensor then at least one operand must also be a
|
||||
/// vector/tensor.
|
||||
/// 2. If any operand is a vector/tensor then there must be at least one result
|
||||
/// and all results must be vectors/tensors.
|
||||
/// 3. All operand and result vector/tensor types must be of the same shape. The
|
||||
/// shape may be dynamic in which case the op's behaviour is undefined for
|
||||
/// non-matching shapes.
|
||||
/// 4. The operation must be elementwise on its vector/tensor operands and
|
||||
/// results. When applied to single-element vectors/tensors, the result must
|
||||
/// be the same per elememnt.
|
||||
///
|
||||
/// Together, these properties provide an easy way for scalar operations to
|
||||
/// conveniently generalize their behavior to vectors/tensors, and systematize
|
||||
/// conversion between these forms.
|
||||
/// TODO: Avoid hardcoding vector/tensor, and generalize this trait to a new
|
||||
/// interface `ElementwiseTypeInterface` that describes the container types for
|
||||
/// which the operation is elementwise.
|
||||
///
|
||||
/// Rationale:
|
||||
/// - 1. and 2. guarantee a well-defined iteration space and exclude the cases
|
||||
/// of 0 non-scalar operands or 0 non-scalar results, which complicate a
|
||||
/// generic definition of the iteration space.
|
||||
/// - 3. guarantees that folding can be done across scalars/vectors/tensors with
|
||||
/// the same pattern, as otherwise lots of special handling for type
|
||||
/// mismatches would be needed.
|
||||
/// - 4. guarantees that no error handling is needed. Higher-level dialects
|
||||
/// should reify any needed guards or error handling code before lowering to
|
||||
/// an `Elementwise` op.
|
||||
template <typename ConcreteType>
|
||||
struct Elementwise : public TraitBase<ConcreteType, Elementwise> {
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return ::mlir::OpTrait::impl::verifyElementwise(op);
|
||||
}
|
||||
};
|
||||
|
||||
/// This trait tags `Elementwise` operatons that can be systematically
|
||||
/// scalarized. All vector/tensor operands and results are then replaced by
|
||||
/// scalars of the respective element type. Semantically, this is the operation
|
||||
/// on a single element per vector/tensor.
|
||||
///
|
||||
/// Rationale:
|
||||
/// Allow to define the vector/tensor semantics of elementwise operations based
|
||||
/// on scalars. This provides a constructive procedure for IR transformations
|
||||
/// to, e.g., create scalar loop bodies from tensor ops.
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// %tensor_select = "std.select"(%pred_tensor, %true_val, %false_val)
|
||||
/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
|
||||
/// -> tensor<?xf32>
|
||||
/// ```
|
||||
/// can be scalarized to
|
||||
///
|
||||
/// ```
|
||||
/// %scalar_select = "std.select"(%pred, %true_val_scalar, %false_val_scalar)
|
||||
/// : (i1, f32, f32) -> f32
|
||||
/// ```
|
||||
template <typename ConcreteType>
|
||||
struct Scalarizable : public TraitBase<ConcreteType, Scalarizable> {
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
static_assert(
|
||||
ConcreteType::template hasTrait<Elementwise>(),
|
||||
"`Scalarizable` trait is only applicable to `Elementwise` ops.");
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// This trait tags `Elementwise` operatons that can be systematically
|
||||
/// vectorized. All scalar operands and results are then replaced by vectors
|
||||
/// with the respective element type. Semantically, this is the operation on
|
||||
/// multiple arguments simultaneously.
|
||||
///
|
||||
/// Rationale:
|
||||
/// Provide the reverse to `Scalarizable` which, when chained together, allows
|
||||
/// reasoning about the relationship between the tensor and vector case.
|
||||
/// Additionally, it permits reasoning about promoting scalars to vectors via
|
||||
/// broadcasting in cases like `%select_scalar_pred` above.
|
||||
template <typename ConcreteType>
|
||||
struct Vectorizable : public TraitBase<ConcreteType, Vectorizable> {
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
static_assert(
|
||||
ConcreteType::template hasTrait<Elementwise>(),
|
||||
"`Vectorizable` trait is only applicable to `Elementwise` ops.");
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// This trait tags `Elementwise` operatons that can be systematically
|
||||
/// tensorized. All scalar operands and results are then replaced by tensors
|
||||
/// with the respective element type. Semantically, this is the operation on
|
||||
/// multiple arguments simultaneously.
|
||||
///
|
||||
/// Rationale:
|
||||
/// Provide the reverse to `Scalarizable` which, when chained together, allows
|
||||
/// reasoning about the relationship between the tensor and vector case.
|
||||
/// Additionally, it permits reasoning about promoting scalars to tensors via
|
||||
/// broadcasting in cases like `%select_scalar_pred` above.
|
||||
///
|
||||
/// Examples:
|
||||
/// ```
|
||||
/// %scalar = "std.addf"(%a, %b) : (f32, f32) -> f32
|
||||
/// // Applying the systematic vectorization/tensorization property, this op
|
||||
/// // must also be valid:
|
||||
/// %tensor = "std.addf"(%a_tensor, %b_tensor)
|
||||
/// : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>)
|
||||
///
|
||||
/// // These properties generalize well to the cases of non-scalar operands.
|
||||
/// %select_scalar_pred = "std.select"(%pred, %true_val, %false_val)
|
||||
/// : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
/// // Applying the systematic vectorization / tensorization property, this
|
||||
/// // op must also be valid:
|
||||
/// %select_tensor_pred = "std.select"(%pred_tensor, %true_val, %false_val)
|
||||
/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
|
||||
/// -> tensor<?xf32>
|
||||
/// // Applying the systematic scalarization property, this op must also
|
||||
/// // be valid.
|
||||
/// %select_scalar = "std.select"(%pred, %true_val_scalar, %false_val_scalar)
|
||||
/// : (i1, f32, f32) -> f32
|
||||
/// ```
|
||||
/// can be tensorized to
|
||||
/// ```
|
||||
/// %tensor = "std.addf"(%a, %b) : (tensor<?xf32>, tensor<?xf32>)
|
||||
/// -> tensor<?xf32>)
|
||||
/// ```
|
||||
///
|
||||
/// TODO: Avoid hardcoding vector/tensor, and generalize this to any type
|
||||
/// implementing a new "ElementwiseMappableTypeInterface" that describes types
|
||||
/// for which it makes sense to apply a scalar function to each element.
|
||||
///
|
||||
/// Rationale:
|
||||
/// - 1. and 2. guarantee a well-defined iteration space for 6.
|
||||
/// - These also exclude the cases of 0 non-scalar operands or 0 non-scalar
|
||||
/// results, which complicate a generic definition of the iteration space.
|
||||
/// - 3. guarantees that folding can be done across scalars/vectors/tensors
|
||||
/// with the same pattern, as otherwise lots of special handling of type
|
||||
/// mismatches would be needed.
|
||||
/// - 4. guarantees that no error handling cases need to be considered.
|
||||
/// - Higher-level dialects should reify any needed guards / error handling
|
||||
/// code before lowering to an ElementwiseMappable op.
|
||||
/// - 5. and 6. allow defining the semantics on vectors/tensors via the scalar
|
||||
/// semantics and provide a constructive procedure for IR transformations
|
||||
/// to e.g. create scalar loop bodies from tensor ops.
|
||||
/// - 7. provides the reverse of 5., which when chained together allows
|
||||
/// reasoning about the relationship between the tensor and vector case.
|
||||
/// Additionally, it permits reasoning about promoting scalars to
|
||||
/// vectors/tensors via broadcasting in cases like `%select_scalar_pred`
|
||||
/// above.
|
||||
/// ```
|
||||
/// %scalar_pred = "std.select"(%pred, %true_val, %false_val)
|
||||
/// : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
/// ```
|
||||
/// can be tensorized to
|
||||
/// ```
|
||||
/// %tensor_pred = "std.select"(%pred, %true_val, %false_val)
|
||||
/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
|
||||
/// -> tensor<?xf32>
|
||||
/// ```
|
||||
template <typename ConcreteType>
|
||||
struct ElementwiseMappable
|
||||
: public TraitBase<ConcreteType, ElementwiseMappable> {
|
||||
struct Tensorizable : public TraitBase<ConcreteType, Tensorizable> {
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return ::mlir::OpTrait::impl::verifyElementwiseMappable(op);
|
||||
static_assert(
|
||||
ConcreteType::template hasTrait<Elementwise>(),
|
||||
"`Tensorizable` trait is only applicable to `Elementwise` ops.");
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Together, `Elementwise`, `Scalarizable`, `Vectorizable`, and `Tensorizable`
|
||||
/// provide an easy way for scalar operations to conveniently generalize their
|
||||
/// behavior to vectors/tensors, and systematize conversion between these forms.
|
||||
bool hasElementwiseMappableTraits(Operation *op);
|
||||
|
||||
} // end namespace OpTrait
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -18,7 +18,7 @@
|
||||
using namespace mlir;
|
||||
|
||||
static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
|
||||
if (!op->hasTrait<OpTrait::ElementwiseMappable>())
|
||||
if (!OpTrait::hasElementwiseMappableTraits(op))
|
||||
return false;
|
||||
|
||||
// TODO: The conversion pattern can be made to work for `any_of` here, but
|
||||
|
@ -205,7 +205,7 @@ vectorizeOneOp(OpBuilder &builder, Operation *op,
|
||||
return VectorizationResult{VectorizationStatus::NewOp, builder.clone(*op)};
|
||||
|
||||
// 3. Only ElementwiseMappable are allowed in the generic vectorization.
|
||||
if (!op->hasTrait<OpTrait::ElementwiseMappable>())
|
||||
if (!OpTrait::hasElementwiseMappableTraits(op))
|
||||
return VectorizationResult{VectorizationStatus::Failure, nullptr};
|
||||
|
||||
// 4. Generic vectorization path for ElementwiseMappable ops.
|
||||
@ -323,7 +323,7 @@ static bool hasOnlyScalarElementwiseOp(Region &r) {
|
||||
return false;
|
||||
for (Operation &op : r.front()) {
|
||||
if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
|
||||
op.hasTrait<OpTrait::ElementwiseMappable>()) ||
|
||||
OpTrait::hasElementwiseMappableTraits(&op)) ||
|
||||
llvm::any_of(op.getResultTypes(),
|
||||
[](Type type) { return !type.isIntOrIndexOrFloat(); }))
|
||||
return false;
|
||||
|
@ -1085,7 +1085,7 @@ static bool areSameShapedTypeIgnoringElementType(ShapedType a, ShapedType b) {
|
||||
return a.getShape() == b.getShape();
|
||||
}
|
||||
|
||||
LogicalResult OpTrait::impl::verifyElementwiseMappable(Operation *op) {
|
||||
LogicalResult OpTrait::impl::verifyElementwise(Operation *op) {
|
||||
auto isMappableType = [](Type type) {
|
||||
return type.isa<VectorType, TensorType>();
|
||||
};
|
||||
@ -1127,6 +1127,11 @@ LogicalResult OpTrait::impl::verifyElementwiseMappable(Operation *op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
bool OpTrait::hasElementwiseMappableTraits(Operation *op) {
|
||||
return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() &&
|
||||
op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BinaryOp implementation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -356,7 +356,8 @@ def SameOperandElementTypeOp : TEST_Op<"same_operand_element_type",
|
||||
let results = (outs AnyType);
|
||||
}
|
||||
|
||||
def SameOperandAndResultElementTypeOp : TEST_Op<"same_operand_and_result_element_type",
|
||||
def SameOperandAndResultElementTypeOp :
|
||||
TEST_Op<"same_operand_and_result_element_type",
|
||||
[SameOperandsAndResultElementType]> {
|
||||
let arguments = (ins Variadic<AnyType>);
|
||||
let results = (outs Variadic<AnyType>);
|
||||
@ -379,7 +380,7 @@ def SameOperandAndResultTypeOp : TEST_Op<"same_operand_and_result_type",
|
||||
}
|
||||
|
||||
def ElementwiseMappableOp : TEST_Op<"elementwise_mappable",
|
||||
[ElementwiseMappable]> {
|
||||
ElementwiseMappable.traits> {
|
||||
let arguments = (ins Variadic<AnyType>);
|
||||
let results = (outs Variadic<AnyType>);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user