Split ElementwiseMappable trait into four more precise traits.

Some elementwise operations are not scalarizable, vectorizable, or tensorizable.
Split `ElementwiseMappable` trait into the following, more precise traits.
  - `Elementwise`
  - `Scalarizable`
  - `Vectorizable`
  - `Tensorizable`
This allows for reuse of `Elementwise` in dialects like HLO.

Differential Revision: https://reviews.llvm.org/D97674
This commit is contained in:
Frederik Gossen 2021-03-02 15:29:08 +01:00
parent 7fce3322a2
commit bcc9b371e4
9 changed files with 193 additions and 130 deletions

View File

@ -20,11 +20,9 @@ class Complex_Op<string mnemonic, list<OpTrait> traits = []>
// floating-point element type. These operations take two operands and return
// one result, all of which must be complex numbers of the same type.
class ComplexArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
Complex_Op<mnemonic,
!listconcat(traits, [NoSideEffect,
SameOperandsAndResultType,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
ElementwiseMappable])> {
Complex_Op<mnemonic, traits # [NoSideEffect, SameOperandsAndResultType,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
ElementwiseMappable.traits> {
let arguments = (ins Complex<AnyFloat>:$lhs, Complex<AnyFloat>:$rhs);
let results = (outs Complex<AnyFloat>:$result);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";

View File

@ -17,10 +17,9 @@ class MathOp<string mnemonic, list<OpTrait> traits = []>
: Op<Math_Dialect, mnemonic, traits # [NoSideEffect]>;
class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
MathOp<mnemonic,
traits # [DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
ElementwiseMappable,
SameOperandsAndResultType]> {
MathOp<mnemonic, traits #
[DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
SameOperandsAndResultType] # ElementwiseMappable.traits> {
let arguments = (ins FloatLike:$operand);
let results = (outs FloatLike:$result);
@ -29,10 +28,9 @@ class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
}
class FloatBinaryOp<string mnemonic, list<OpTrait> traits = []> :
MathOp<mnemonic,
traits # [DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
ElementwiseMappable,
SameOperandsAndResultType]> {
MathOp<mnemonic, traits # [
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
SameOperandsAndResultType] # ElementwiseMappable.traits> {
let arguments = (ins FloatLike:$lhs, FloatLike:$rhs);
let results = (outs FloatLike:$result);

View File

@ -71,9 +71,9 @@ class CastOp<string mnemonic, list<OpTrait> traits = []> :
// Base class for arithmetic cast operations.
class ArithmeticCastOp<string mnemonic, list<OpTrait> traits = []> :
CastOp<mnemonic,
!listconcat(traits, [ElementwiseMappable,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>])> {
CastOp<mnemonic, traits # [
DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
ElementwiseMappable.traits> {
}
// Base class for unary ops. Requires single operand and result. Individual
@ -95,21 +95,18 @@ class UnaryOpSameOperandAndResultType<string mnemonic,
}
class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
UnaryOpSameOperandAndResultType<mnemonic,
!listconcat(traits,
[DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
ElementwiseMappable])>,
Arguments<(ins FloatLike:$operand)>;
UnaryOpSameOperandAndResultType<mnemonic, traits # [
DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
ElementwiseMappable.traits>, Arguments<(ins FloatLike:$operand)>;
// Base class for standard arithmetic operations. Requires operands and
// results to be of the same type, but does not constrain them to specific
// types.
class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
Op<StandardOps_Dialect, mnemonic,
!listconcat(traits, [NoSideEffect,
SameOperandsAndResultType,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
ElementwiseMappable])> {
Op<StandardOps_Dialect, mnemonic, traits # [NoSideEffect,
SameOperandsAndResultType,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
ElementwiseMappable.traits> {
let results = (outs AnyType:$result);
@ -930,12 +927,10 @@ def CmpFPredicateAttr : I64EnumAttr<
let cppNamespace = "::mlir";
}
def CmpFOp : Std_Op<"cmpf",
[NoSideEffect, SameTypeOperands, ElementwiseMappable,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
TypesMatchWith<
"result type has i1 element type and same shape as operands",
"lhs", "result", "getI1SameShape($_self)">]> {
def CmpFOp : Std_Op<"cmpf", [NoSideEffect, SameTypeOperands,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>, TypesMatchWith<
"result type has i1 element type and same shape as operands",
"lhs", "result", "getI1SameShape($_self)">] # ElementwiseMappable.traits> {
let summary = "floating-point comparison operation";
let description = [{
The `cmpf` operation compares its two operands according to the float
@ -1015,12 +1010,10 @@ def CmpIPredicateAttr : I64EnumAttr<
let cppNamespace = "::mlir";
}
def CmpIOp : Std_Op<"cmpi",
[NoSideEffect, SameTypeOperands, ElementwiseMappable,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
TypesMatchWith<
"result type has i1 element type and same shape as operands",
"lhs", "result", "getI1SameShape($_self)">]> {
def CmpIOp : Std_Op<"cmpi", [NoSideEffect, SameTypeOperands,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>, TypesMatchWith<
"result type has i1 element type and same shape as operands",
"lhs", "result", "getI1SameShape($_self)">] # ElementwiseMappable.traits> {
let summary = "integer comparison operation";
let description = [{
The `cmpi` operation is a generic comparison for integer-like types. Its two
@ -2160,8 +2153,9 @@ def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">,
//===----------------------------------------------------------------------===//
def SelectOp : Std_Op<"select", [NoSideEffect,
AllTypesMatch<["true_value", "false_value", "result"]>,
ElementwiseMappable, DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> {
AllTypesMatch<["true_value", "false_value", "result"]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
ElementwiseMappable.traits> {
let summary = "select operation";
let description = [{
The `select` operation chooses one value based on a binary condition
@ -2391,9 +2385,9 @@ def SignedShiftRightOp : IntBinaryOp<"shift_right_signed"> {
// SignExtendIOp
//===----------------------------------------------------------------------===//
def SignExtendIOp : Std_Op<"sexti",
[NoSideEffect, ElementwiseMappable,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> {
def SignExtendIOp : Std_Op<"sexti", [NoSideEffect,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
ElementwiseMappable.traits> {
let summary = "integer sign extension operation";
let description = [{
The integer sign extension operation takes an integer input of
@ -3220,9 +3214,9 @@ def TransposeOp : Std_Op<"transpose", [NoSideEffect]>,
// TruncateIOp
//===----------------------------------------------------------------------===//
def TruncateIOp : Std_Op<"trunci",
[NoSideEffect, ElementwiseMappable,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,]> {
def TruncateIOp : Std_Op<"trunci", [NoSideEffect,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
ElementwiseMappable.traits> {
let summary = "integer truncation operation";
let description = [{
The integer truncation operation takes an integer input of
@ -3463,9 +3457,9 @@ def XOrOp : IntBinaryOp<"xor", [Commutative]> {
// ZeroExtendIOp
//===----------------------------------------------------------------------===//
def ZeroExtendIOp : Std_Op<"zexti",
[NoSideEffect, ElementwiseMappable,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,]> {
def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
ElementwiseMappable.traits> {
let summary = "integer zero extension operation";
let description = [{
The integer zero extension operation takes an integer input of

View File

@ -1785,9 +1785,25 @@ def Terminator : NativeOpTrait<"IsTerminator">;
// Op can be safely normalized in the presence of MemRefs with
// non-identity maps.
def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">;
// Op can be systematically interconverted between scalar and vector/tensor
// form by mapping elementwise based on the type.
def ElementwiseMappable : NativeOpTrait<"ElementwiseMappable">;
// Op is elementwise on tensor/vector operands and results.
def Elementwise : NativeOpTrait<"Elementwise">;
// Elementwise op can be applied to scalars instead tensor/vector operands.
def Scalarizable : NativeOpTrait<"Scalarizable">;
// Elementwise op can be applied to all-vector operands.
def Vectorizable : NativeOpTrait<"Vectorizable">;
// Elementwise op can be applied to all-tensor operands.
def Tensorizable : NativeOpTrait<"Tensorizable">;
// Group together `Elementwise`, `Scalarizable`, `Vectorizable`, and
// `Tensorizable` for convenience.
def ElementwiseMappable {
list<OpTrait> traits = [
Elementwise,
Scalarizable,
Vectorizable,
Tensorizable,
];
}
// Op's regions have a single block with the specified terminator.
class SingleBlockImplicitTerminator<string op>

View File

@ -282,7 +282,7 @@ LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);
LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
LogicalResult verifyNoRegionArguments(Operation *op);
LogicalResult verifyElementwiseMappable(Operation *op);
LogicalResult verifyElementwise(Operation *op);
} // namespace impl
/// Helper class for implementing traits. Clients are not expected to interact
@ -1213,93 +1213,144 @@ template <typename ConcrentType>
struct MemRefsNormalizable
: public TraitBase<ConcrentType, MemRefsNormalizable> {};
/// This trait tags scalar ops that also can be applied to vectors/tensors, with
/// their semantics on vectors/tensors being elementwise application.
/// This trait tags element-wise ops that operate on scalars, vectors, or
/// tensors.
///
/// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this
/// trait. In particular, broadcasting behavior is not allowed. This trait
/// describes a set of invariants that allow systematic
/// vectorization/tensorization, and the reverse, scalarization. The properties
/// needed for this also can be used to implement a number of
/// transformations/analyses/interfaces.
/// trait. In particular, broadcasting behavior is not allowed.
///
/// An `ElementwiseMappable` op must satisfy the following properties:
/// An `Elementwise` op must satisfy the following properties:
///
/// 1. If any result is a vector (resp. tensor), then at least one operand must
/// be a vector (resp. tensor).
/// 2. If any operand is a vector (resp. tensor), then there must be at least
/// one result, and all results must be vectors (resp. tensors).
/// 3. The static types of all vector (resp. tensor) operands and results must
/// have the same shape.
/// 4. In the case of tensor operands, the dynamic shapes of all tensor operands
/// must be the same, otherwise the op has undefined behavior.
/// 5. ("systematic scalarization" property) If an op has vector/tensor
/// operands/results, then the same op, with the operand/result types changed to
/// their corresponding element type, shall be a verifier-valid op.
/// 6. The semantics of the op on vectors (resp. tensors) shall be the same as
/// applying the scalarized version of the op for each corresponding element of
/// the vector (resp. tensor) operands in parallel.
/// 7. ("systematic vectorization/tensorization" property) If an op has
/// scalar operands/results, the op shall remain verifier-valid if all scalar
/// operands are replaced with vectors/tensors of the same shape and
/// corresponding element types.
/// 1. If any result is a vector/tensor then at least one operand must also be a
/// vector/tensor.
/// 2. If any operand is a vector/tensor then there must be at least one result
/// and all results must be vectors/tensors.
/// 3. All operand and result vector/tensor types must be of the same shape. The
/// shape may be dynamic in which case the op's behaviour is undefined for
/// non-matching shapes.
/// 4. The operation must be elementwise on its vector/tensor operands and
/// results. When applied to single-element vectors/tensors, the result must
/// be the same per elememnt.
///
/// Together, these properties provide an easy way for scalar operations to
/// conveniently generalize their behavior to vectors/tensors, and systematize
/// conversion between these forms.
/// TODO: Avoid hardcoding vector/tensor, and generalize this trait to a new
/// interface `ElementwiseTypeInterface` that describes the container types for
/// which the operation is elementwise.
///
/// Rationale:
/// - 1. and 2. guarantee a well-defined iteration space and exclude the cases
/// of 0 non-scalar operands or 0 non-scalar results, which complicate a
/// generic definition of the iteration space.
/// - 3. guarantees that folding can be done across scalars/vectors/tensors with
/// the same pattern, as otherwise lots of special handling for type
/// mismatches would be needed.
/// - 4. guarantees that no error handling is needed. Higher-level dialects
/// should reify any needed guards or error handling code before lowering to
/// an `Elementwise` op.
template <typename ConcreteType>
struct Elementwise : public TraitBase<ConcreteType, Elementwise> {
static LogicalResult verifyTrait(Operation *op) {
return ::mlir::OpTrait::impl::verifyElementwise(op);
}
};
/// This trait tags `Elementwise` operatons that can be systematically
/// scalarized. All vector/tensor operands and results are then replaced by
/// scalars of the respective element type. Semantically, this is the operation
/// on a single element per vector/tensor.
///
/// Rationale:
/// Allow to define the vector/tensor semantics of elementwise operations based
/// on scalars. This provides a constructive procedure for IR transformations
/// to, e.g., create scalar loop bodies from tensor ops.
///
/// Example:
/// ```
/// %tensor_select = "std.select"(%pred_tensor, %true_val, %false_val)
/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
/// -> tensor<?xf32>
/// ```
/// can be scalarized to
///
/// ```
/// %scalar_select = "std.select"(%pred, %true_val_scalar, %false_val_scalar)
/// : (i1, f32, f32) -> f32
/// ```
template <typename ConcreteType>
struct Scalarizable : public TraitBase<ConcreteType, Scalarizable> {
static LogicalResult verifyTrait(Operation *op) {
static_assert(
ConcreteType::template hasTrait<Elementwise>(),
"`Scalarizable` trait is only applicable to `Elementwise` ops.");
return success();
}
};
/// This trait tags `Elementwise` operatons that can be systematically
/// vectorized. All scalar operands and results are then replaced by vectors
/// with the respective element type. Semantically, this is the operation on
/// multiple arguments simultaneously.
///
/// Rationale:
/// Provide the reverse to `Scalarizable` which, when chained together, allows
/// reasoning about the relationship between the tensor and vector case.
/// Additionally, it permits reasoning about promoting scalars to vectors via
/// broadcasting in cases like `%select_scalar_pred` above.
template <typename ConcreteType>
struct Vectorizable : public TraitBase<ConcreteType, Vectorizable> {
static LogicalResult verifyTrait(Operation *op) {
static_assert(
ConcreteType::template hasTrait<Elementwise>(),
"`Vectorizable` trait is only applicable to `Elementwise` ops.");
return success();
}
};
/// This trait tags `Elementwise` operatons that can be systematically
/// tensorized. All scalar operands and results are then replaced by tensors
/// with the respective element type. Semantically, this is the operation on
/// multiple arguments simultaneously.
///
/// Rationale:
/// Provide the reverse to `Scalarizable` which, when chained together, allows
/// reasoning about the relationship between the tensor and vector case.
/// Additionally, it permits reasoning about promoting scalars to tensors via
/// broadcasting in cases like `%select_scalar_pred` above.
///
/// Examples:
/// ```
/// %scalar = "std.addf"(%a, %b) : (f32, f32) -> f32
/// // Applying the systematic vectorization/tensorization property, this op
/// // must also be valid:
/// %tensor = "std.addf"(%a_tensor, %b_tensor)
/// : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>)
///
/// // These properties generalize well to the cases of non-scalar operands.
/// %select_scalar_pred = "std.select"(%pred, %true_val, %false_val)
/// : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
/// // Applying the systematic vectorization / tensorization property, this
/// // op must also be valid:
/// %select_tensor_pred = "std.select"(%pred_tensor, %true_val, %false_val)
/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
/// -> tensor<?xf32>
/// // Applying the systematic scalarization property, this op must also
/// // be valid.
/// %select_scalar = "std.select"(%pred, %true_val_scalar, %false_val_scalar)
/// : (i1, f32, f32) -> f32
/// ```
/// can be tensorized to
/// ```
/// %tensor = "std.addf"(%a, %b) : (tensor<?xf32>, tensor<?xf32>)
/// -> tensor<?xf32>)
/// ```
///
/// TODO: Avoid hardcoding vector/tensor, and generalize this to any type
/// implementing a new "ElementwiseMappableTypeInterface" that describes types
/// for which it makes sense to apply a scalar function to each element.
///
/// Rationale:
/// - 1. and 2. guarantee a well-defined iteration space for 6.
/// - These also exclude the cases of 0 non-scalar operands or 0 non-scalar
/// results, which complicate a generic definition of the iteration space.
/// - 3. guarantees that folding can be done across scalars/vectors/tensors
/// with the same pattern, as otherwise lots of special handling of type
/// mismatches would be needed.
/// - 4. guarantees that no error handling cases need to be considered.
/// - Higher-level dialects should reify any needed guards / error handling
/// code before lowering to an ElementwiseMappable op.
/// - 5. and 6. allow defining the semantics on vectors/tensors via the scalar
/// semantics and provide a constructive procedure for IR transformations
/// to e.g. create scalar loop bodies from tensor ops.
/// - 7. provides the reverse of 5., which when chained together allows
/// reasoning about the relationship between the tensor and vector case.
/// Additionally, it permits reasoning about promoting scalars to
/// vectors/tensors via broadcasting in cases like `%select_scalar_pred`
/// above.
/// ```
/// %scalar_pred = "std.select"(%pred, %true_val, %false_val)
/// : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
/// ```
/// can be tensorized to
/// ```
/// %tensor_pred = "std.select"(%pred, %true_val, %false_val)
/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
/// -> tensor<?xf32>
/// ```
template <typename ConcreteType>
struct ElementwiseMappable
: public TraitBase<ConcreteType, ElementwiseMappable> {
struct Tensorizable : public TraitBase<ConcreteType, Tensorizable> {
static LogicalResult verifyTrait(Operation *op) {
return ::mlir::OpTrait::impl::verifyElementwiseMappable(op);
static_assert(
ConcreteType::template hasTrait<Elementwise>(),
"`Tensorizable` trait is only applicable to `Elementwise` ops.");
return success();
}
};
/// Together, `Elementwise`, `Scalarizable`, `Vectorizable`, and `Tensorizable`
/// provide an easy way for scalar operations to conveniently generalize their
/// behavior to vectors/tensors, and systematize conversion between these forms.
bool hasElementwiseMappableTraits(Operation *op);
} // end namespace OpTrait
//===----------------------------------------------------------------------===//

View File

@ -18,7 +18,7 @@
using namespace mlir;
static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
if (!op->hasTrait<OpTrait::ElementwiseMappable>())
if (!OpTrait::hasElementwiseMappableTraits(op))
return false;
// TODO: The conversion pattern can be made to work for `any_of` here, but

View File

@ -205,7 +205,7 @@ vectorizeOneOp(OpBuilder &builder, Operation *op,
return VectorizationResult{VectorizationStatus::NewOp, builder.clone(*op)};
// 3. Only ElementwiseMappable are allowed in the generic vectorization.
if (!op->hasTrait<OpTrait::ElementwiseMappable>())
if (!OpTrait::hasElementwiseMappableTraits(op))
return VectorizationResult{VectorizationStatus::Failure, nullptr};
// 4. Generic vectorization path for ElementwiseMappable ops.
@ -323,7 +323,7 @@ static bool hasOnlyScalarElementwiseOp(Region &r) {
return false;
for (Operation &op : r.front()) {
if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
op.hasTrait<OpTrait::ElementwiseMappable>()) ||
OpTrait::hasElementwiseMappableTraits(&op)) ||
llvm::any_of(op.getResultTypes(),
[](Type type) { return !type.isIntOrIndexOrFloat(); }))
return false;

View File

@ -1085,7 +1085,7 @@ static bool areSameShapedTypeIgnoringElementType(ShapedType a, ShapedType b) {
return a.getShape() == b.getShape();
}
LogicalResult OpTrait::impl::verifyElementwiseMappable(Operation *op) {
LogicalResult OpTrait::impl::verifyElementwise(Operation *op) {
auto isMappableType = [](Type type) {
return type.isa<VectorType, TensorType>();
};
@ -1127,6 +1127,11 @@ LogicalResult OpTrait::impl::verifyElementwiseMappable(Operation *op) {
return success();
}
bool OpTrait::hasElementwiseMappableTraits(Operation *op) {
return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() &&
op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>();
}
//===----------------------------------------------------------------------===//
// BinaryOp implementation
//===----------------------------------------------------------------------===//

View File

@ -356,7 +356,8 @@ def SameOperandElementTypeOp : TEST_Op<"same_operand_element_type",
let results = (outs AnyType);
}
def SameOperandAndResultElementTypeOp : TEST_Op<"same_operand_and_result_element_type",
def SameOperandAndResultElementTypeOp :
TEST_Op<"same_operand_and_result_element_type",
[SameOperandsAndResultElementType]> {
let arguments = (ins Variadic<AnyType>);
let results = (outs Variadic<AnyType>);
@ -379,7 +380,7 @@ def SameOperandAndResultTypeOp : TEST_Op<"same_operand_and_result_type",
}
def ElementwiseMappableOp : TEST_Op<"elementwise_mappable",
[ElementwiseMappable]> {
ElementwiseMappable.traits> {
let arguments = (ins Variadic<AnyType>);
let results = (outs Variadic<AnyType>);
}