mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-27 15:41:46 +00:00
[mlir][sparse] rename DimLevelType to LevelType (#73561)
The "Dim" prefix is a legacy left-over that no longer makes sense, since we have a very strict "Dimension" vs. "Level" definition for sparse tensor types and their storage.
This commit is contained in:
parent
e1f69b863d
commit
1944c4f76b
@ -22,24 +22,24 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
|
||||
/// Dimension level types (and properties) that define sparse tensors.
|
||||
/// See the documentation in SparseTensorAttrDefs.td for their meaning.
|
||||
///
|
||||
/// These correspond to SparseTensorEncodingAttr::DimLevelType in the C++ API.
|
||||
/// These correspond to SparseTensorEncodingAttr::LevelType in the C++ API.
|
||||
/// If updating, keep them in sync and update the static_assert in the impl
|
||||
/// file.
|
||||
enum MlirSparseTensorDimLevelType {
|
||||
MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE = 4, // 0b00001_00
|
||||
MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED = 8, // 0b00010_00
|
||||
MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU = 9, // 0b00010_01
|
||||
MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO = 10, // 0b00010_10
|
||||
MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO = 11, // 0b00010_11
|
||||
MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON = 16, // 0b00100_00
|
||||
MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU = 17, // 0b00100_01
|
||||
MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO = 18, // 0b00100_10
|
||||
MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO = 19, // 0b00100_11
|
||||
MLIR_SPARSE_TENSOR_DIM_LEVEL_LOOSE_COMPRESSED = 32, // 0b01000_00
|
||||
MLIR_SPARSE_TENSOR_DIM_LEVEL_LOOSE_COMPRESSED_NU = 33, // 0b01000_01
|
||||
MLIR_SPARSE_TENSOR_DIM_LEVEL_LOOSE_COMPRESSED_NO = 34, // 0b01000_10
|
||||
MLIR_SPARSE_TENSOR_DIM_LEVEL_LOOSE_COMPRESSED_NU_NO = 35, // 0b01000_11
|
||||
MLIR_SPARSE_TENSOR_DIM_LEVEL_TWO_OUT_OF_FOUR = 64, // 0b10000_00
|
||||
enum MlirSparseTensorLevelType {
|
||||
MLIR_SPARSE_TENSOR_LEVEL_DENSE = 4, // 0b00001_00
|
||||
MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 8, // 0b00010_00
|
||||
MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 9, // 0b00010_01
|
||||
MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 10, // 0b00010_10
|
||||
MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 11, // 0b00010_11
|
||||
MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 16, // 0b00100_00
|
||||
MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 17, // 0b00100_01
|
||||
MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 18, // 0b00100_10
|
||||
MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 19, // 0b00100_11
|
||||
MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 32, // 0b01000_00
|
||||
MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 33, // 0b01000_01
|
||||
MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 34, // 0b01000_10
|
||||
MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 35, // 0b01000_11
|
||||
MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR = 64, // 0b10000_00
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -53,7 +53,7 @@ mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr);
|
||||
/// Creates a `sparse_tensor.encoding` attribute with the given parameters.
|
||||
MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet(
|
||||
MlirContext ctx, intptr_t lvlRank,
|
||||
enum MlirSparseTensorDimLevelType const *lvlTypes, MlirAffineMap dimToLvl,
|
||||
enum MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl,
|
||||
MlirAffineMap lvlTodim, int posWidth, int crdWidth);
|
||||
|
||||
/// Returns the level-rank of the `sparse_tensor.encoding` attribute.
|
||||
@ -61,7 +61,7 @@ MLIR_CAPI_EXPORTED intptr_t
|
||||
mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr);
|
||||
|
||||
/// Returns a specified level-type of the `sparse_tensor.encoding` attribute.
|
||||
MLIR_CAPI_EXPORTED enum MlirSparseTensorDimLevelType
|
||||
MLIR_CAPI_EXPORTED enum MlirSparseTensorLevelType
|
||||
mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl);
|
||||
|
||||
/// Returns the dimension-to-level mapping of the `sparse_tensor.encoding`
|
||||
|
@ -10,7 +10,7 @@
|
||||
// IR, and the lightweight runtime support library for sparse tensor
|
||||
// manipulations. That is, all the enums are used to define the API
|
||||
// of the runtime library and hence are also needed when generating
|
||||
// calls into the runtime library. Moveover, the `DimLevelType` enum
|
||||
// calls into the runtime library. Moveover, the `LevelType` enum
|
||||
// is also used as the internal IR encoding of dimension level types,
|
||||
// to avoid code duplication (e.g., for the predicates).
|
||||
//
|
||||
@ -162,10 +162,10 @@ enum class Action : uint32_t {
|
||||
/// about the particular binary encoding.
|
||||
///
|
||||
/// The `Undef` "format" is a special value used internally for cases
|
||||
/// where we need to store an undefined or indeterminate `DimLevelType`.
|
||||
/// where we need to store an undefined or indeterminate `LevelType`.
|
||||
/// It should not be used externally, since it does not indicate an
|
||||
/// actual/representable format.
|
||||
enum class DimLevelType : uint8_t {
|
||||
enum class LevelType : uint8_t {
|
||||
Undef = 0, // 0b00000_00
|
||||
Dense = 4, // 0b00001_00
|
||||
Compressed = 8, // 0b00010_00
|
||||
@ -199,44 +199,44 @@ enum class LevelPropertyNondefault : uint8_t {
|
||||
};
|
||||
|
||||
/// Returns string representation of the given dimension level type.
|
||||
constexpr const char *toMLIRString(DimLevelType lt) {
|
||||
constexpr const char *toMLIRString(LevelType lt) {
|
||||
switch (lt) {
|
||||
case DimLevelType::Undef:
|
||||
case LevelType::Undef:
|
||||
return "undef";
|
||||
case DimLevelType::Dense:
|
||||
case LevelType::Dense:
|
||||
return "dense";
|
||||
case DimLevelType::Compressed:
|
||||
case LevelType::Compressed:
|
||||
return "compressed";
|
||||
case DimLevelType::CompressedNu:
|
||||
case LevelType::CompressedNu:
|
||||
return "compressed(nonunique)";
|
||||
case DimLevelType::CompressedNo:
|
||||
case LevelType::CompressedNo:
|
||||
return "compressed(nonordered)";
|
||||
case DimLevelType::CompressedNuNo:
|
||||
case LevelType::CompressedNuNo:
|
||||
return "compressed(nonunique, nonordered)";
|
||||
case DimLevelType::Singleton:
|
||||
case LevelType::Singleton:
|
||||
return "singleton";
|
||||
case DimLevelType::SingletonNu:
|
||||
case LevelType::SingletonNu:
|
||||
return "singleton(nonunique)";
|
||||
case DimLevelType::SingletonNo:
|
||||
case LevelType::SingletonNo:
|
||||
return "singleton(nonordered)";
|
||||
case DimLevelType::SingletonNuNo:
|
||||
case LevelType::SingletonNuNo:
|
||||
return "singleton(nonunique, nonordered)";
|
||||
case DimLevelType::LooseCompressed:
|
||||
case LevelType::LooseCompressed:
|
||||
return "loose_compressed";
|
||||
case DimLevelType::LooseCompressedNu:
|
||||
case LevelType::LooseCompressedNu:
|
||||
return "loose_compressed(nonunique)";
|
||||
case DimLevelType::LooseCompressedNo:
|
||||
case LevelType::LooseCompressedNo:
|
||||
return "loose_compressed(nonordered)";
|
||||
case DimLevelType::LooseCompressedNuNo:
|
||||
case LevelType::LooseCompressedNuNo:
|
||||
return "loose_compressed(nonunique, nonordered)";
|
||||
case DimLevelType::TwoOutOfFour:
|
||||
case LevelType::TwoOutOfFour:
|
||||
return "block2_4";
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
/// Check that the `DimLevelType` contains a valid (possibly undefined) value.
|
||||
constexpr bool isValidLT(DimLevelType lt) {
|
||||
/// Check that the `LevelType` contains a valid (possibly undefined) value.
|
||||
constexpr bool isValidLT(LevelType lt) {
|
||||
const uint8_t formatBits = static_cast<uint8_t>(lt) >> 2;
|
||||
const uint8_t propertyBits = static_cast<uint8_t>(lt) & 3;
|
||||
// If undefined or dense, then must be unique and ordered.
|
||||
@ -246,75 +246,75 @@ constexpr bool isValidLT(DimLevelType lt) {
|
||||
: (formatBits == 2 || formatBits == 4 || formatBits == 8);
|
||||
}
|
||||
|
||||
/// Check if the `DimLevelType` is the special undefined value.
|
||||
constexpr bool isUndefLT(DimLevelType lt) { return lt == DimLevelType::Undef; }
|
||||
/// Check if the `LevelType` is the special undefined value.
|
||||
constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; }
|
||||
|
||||
/// Check if the `DimLevelType` is dense (regardless of properties).
|
||||
constexpr bool isDenseLT(DimLevelType lt) {
|
||||
/// Check if the `LevelType` is dense (regardless of properties).
|
||||
constexpr bool isDenseLT(LevelType lt) {
|
||||
return (static_cast<uint8_t>(lt) & ~3) ==
|
||||
static_cast<uint8_t>(DimLevelType::Dense);
|
||||
static_cast<uint8_t>(LevelType::Dense);
|
||||
}
|
||||
|
||||
/// Check if the `DimLevelType` is compressed (regardless of properties).
|
||||
constexpr bool isCompressedLT(DimLevelType lt) {
|
||||
/// Check if the `LevelType` is compressed (regardless of properties).
|
||||
constexpr bool isCompressedLT(LevelType lt) {
|
||||
return (static_cast<uint8_t>(lt) & ~3) ==
|
||||
static_cast<uint8_t>(DimLevelType::Compressed);
|
||||
static_cast<uint8_t>(LevelType::Compressed);
|
||||
}
|
||||
|
||||
/// Check if the `DimLevelType` is singleton (regardless of properties).
|
||||
constexpr bool isSingletonLT(DimLevelType lt) {
|
||||
/// Check if the `LevelType` is singleton (regardless of properties).
|
||||
constexpr bool isSingletonLT(LevelType lt) {
|
||||
return (static_cast<uint8_t>(lt) & ~3) ==
|
||||
static_cast<uint8_t>(DimLevelType::Singleton);
|
||||
static_cast<uint8_t>(LevelType::Singleton);
|
||||
}
|
||||
|
||||
/// Check if the `DimLevelType` is loose compressed (regardless of properties).
|
||||
constexpr bool isLooseCompressedLT(DimLevelType lt) {
|
||||
/// Check if the `LevelType` is loose compressed (regardless of properties).
|
||||
constexpr bool isLooseCompressedLT(LevelType lt) {
|
||||
return (static_cast<uint8_t>(lt) & ~3) ==
|
||||
static_cast<uint8_t>(DimLevelType::LooseCompressed);
|
||||
static_cast<uint8_t>(LevelType::LooseCompressed);
|
||||
}
|
||||
|
||||
/// Check if the `DimLevelType` is 2OutOf4 (regardless of properties).
|
||||
constexpr bool is2OutOf4LT(DimLevelType lt) {
|
||||
/// Check if the `LevelType` is 2OutOf4 (regardless of properties).
|
||||
constexpr bool is2OutOf4LT(LevelType lt) {
|
||||
return (static_cast<uint8_t>(lt) & ~3) ==
|
||||
static_cast<uint8_t>(DimLevelType::TwoOutOfFour);
|
||||
static_cast<uint8_t>(LevelType::TwoOutOfFour);
|
||||
}
|
||||
|
||||
/// Check if the `DimLevelType` needs positions array.
|
||||
constexpr bool isWithPosLT(DimLevelType lt) {
|
||||
/// Check if the `LevelType` needs positions array.
|
||||
constexpr bool isWithPosLT(LevelType lt) {
|
||||
return isCompressedLT(lt) || isLooseCompressedLT(lt);
|
||||
}
|
||||
|
||||
/// Check if the `DimLevelType` needs coordinates array.
|
||||
constexpr bool isWithCrdLT(DimLevelType lt) {
|
||||
/// Check if the `LevelType` needs coordinates array.
|
||||
constexpr bool isWithCrdLT(LevelType lt) {
|
||||
return isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
|
||||
is2OutOf4LT(lt);
|
||||
}
|
||||
|
||||
/// Check if the `DimLevelType` is ordered (regardless of storage format).
|
||||
constexpr bool isOrderedLT(DimLevelType lt) {
|
||||
/// Check if the `LevelType` is ordered (regardless of storage format).
|
||||
constexpr bool isOrderedLT(LevelType lt) {
|
||||
return !(static_cast<uint8_t>(lt) & 2);
|
||||
}
|
||||
|
||||
/// Check if the `DimLevelType` is unique (regardless of storage format).
|
||||
constexpr bool isUniqueLT(DimLevelType lt) {
|
||||
/// Check if the `LevelType` is unique (regardless of storage format).
|
||||
constexpr bool isUniqueLT(LevelType lt) {
|
||||
return !(static_cast<uint8_t>(lt) & 1);
|
||||
}
|
||||
|
||||
/// Convert a DimLevelType to its corresponding LevelFormat.
|
||||
/// Convert a LevelType to its corresponding LevelFormat.
|
||||
/// Returns std::nullopt when input lt is Undef.
|
||||
constexpr std::optional<LevelFormat> getLevelFormat(DimLevelType lt) {
|
||||
if (lt == DimLevelType::Undef)
|
||||
constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
|
||||
if (lt == LevelType::Undef)
|
||||
return std::nullopt;
|
||||
return static_cast<LevelFormat>(static_cast<uint8_t>(lt) & ~3);
|
||||
}
|
||||
|
||||
/// Convert a LevelFormat to its corresponding DimLevelType with the given
|
||||
/// Convert a LevelFormat to its corresponding LevelType with the given
|
||||
/// properties. Returns std::nullopt when the properties are not applicable
|
||||
/// for the input level format.
|
||||
constexpr std::optional<DimLevelType>
|
||||
buildLevelType(LevelFormat lf, bool ordered, bool unique) {
|
||||
auto lt = static_cast<DimLevelType>(static_cast<uint8_t>(lf) |
|
||||
(ordered ? 0 : 2) | (unique ? 0 : 1));
|
||||
constexpr std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
|
||||
bool unique) {
|
||||
auto lt = static_cast<LevelType>(static_cast<uint8_t>(lf) |
|
||||
(ordered ? 0 : 2) | (unique ? 0 : 1));
|
||||
return isValidLT(lt) ? std::optional(lt) : std::nullopt;
|
||||
}
|
||||
|
||||
@ -323,190 +323,187 @@ buildLevelType(LevelFormat lf, bool ordered, bool unique) {
|
||||
//
|
||||
|
||||
static_assert(
|
||||
(getLevelFormat(DimLevelType::Undef) == std::nullopt &&
|
||||
*getLevelFormat(DimLevelType::Dense) == LevelFormat::Dense &&
|
||||
*getLevelFormat(DimLevelType::Compressed) == LevelFormat::Compressed &&
|
||||
*getLevelFormat(DimLevelType::CompressedNu) == LevelFormat::Compressed &&
|
||||
*getLevelFormat(DimLevelType::CompressedNo) == LevelFormat::Compressed &&
|
||||
*getLevelFormat(DimLevelType::CompressedNuNo) == LevelFormat::Compressed &&
|
||||
*getLevelFormat(DimLevelType::Singleton) == LevelFormat::Singleton &&
|
||||
*getLevelFormat(DimLevelType::SingletonNu) == LevelFormat::Singleton &&
|
||||
*getLevelFormat(DimLevelType::SingletonNo) == LevelFormat::Singleton &&
|
||||
*getLevelFormat(DimLevelType::SingletonNuNo) == LevelFormat::Singleton &&
|
||||
*getLevelFormat(DimLevelType::LooseCompressed) ==
|
||||
(getLevelFormat(LevelType::Undef) == std::nullopt &&
|
||||
*getLevelFormat(LevelType::Dense) == LevelFormat::Dense &&
|
||||
*getLevelFormat(LevelType::Compressed) == LevelFormat::Compressed &&
|
||||
*getLevelFormat(LevelType::CompressedNu) == LevelFormat::Compressed &&
|
||||
*getLevelFormat(LevelType::CompressedNo) == LevelFormat::Compressed &&
|
||||
*getLevelFormat(LevelType::CompressedNuNo) == LevelFormat::Compressed &&
|
||||
*getLevelFormat(LevelType::Singleton) == LevelFormat::Singleton &&
|
||||
*getLevelFormat(LevelType::SingletonNu) == LevelFormat::Singleton &&
|
||||
*getLevelFormat(LevelType::SingletonNo) == LevelFormat::Singleton &&
|
||||
*getLevelFormat(LevelType::SingletonNuNo) == LevelFormat::Singleton &&
|
||||
*getLevelFormat(LevelType::LooseCompressed) ==
|
||||
LevelFormat::LooseCompressed &&
|
||||
*getLevelFormat(DimLevelType::LooseCompressedNu) ==
|
||||
*getLevelFormat(LevelType::LooseCompressedNu) ==
|
||||
LevelFormat::LooseCompressed &&
|
||||
*getLevelFormat(DimLevelType::LooseCompressedNo) ==
|
||||
*getLevelFormat(LevelType::LooseCompressedNo) ==
|
||||
LevelFormat::LooseCompressed &&
|
||||
*getLevelFormat(DimLevelType::LooseCompressedNuNo) ==
|
||||
*getLevelFormat(LevelType::LooseCompressedNuNo) ==
|
||||
LevelFormat::LooseCompressed &&
|
||||
*getLevelFormat(DimLevelType::TwoOutOfFour) == LevelFormat::TwoOutOfFour),
|
||||
*getLevelFormat(LevelType::TwoOutOfFour) == LevelFormat::TwoOutOfFour),
|
||||
"getLevelFormat conversion is broken");
|
||||
|
||||
static_assert(
|
||||
(buildLevelType(LevelFormat::Dense, false, true) == std::nullopt &&
|
||||
buildLevelType(LevelFormat::Dense, true, false) == std::nullopt &&
|
||||
buildLevelType(LevelFormat::Dense, false, false) == std::nullopt &&
|
||||
*buildLevelType(LevelFormat::Dense, true, true) == DimLevelType::Dense &&
|
||||
*buildLevelType(LevelFormat::Dense, true, true) == LevelType::Dense &&
|
||||
*buildLevelType(LevelFormat::Compressed, true, true) ==
|
||||
DimLevelType::Compressed &&
|
||||
LevelType::Compressed &&
|
||||
*buildLevelType(LevelFormat::Compressed, true, false) ==
|
||||
DimLevelType::CompressedNu &&
|
||||
LevelType::CompressedNu &&
|
||||
*buildLevelType(LevelFormat::Compressed, false, true) ==
|
||||
DimLevelType::CompressedNo &&
|
||||
LevelType::CompressedNo &&
|
||||
*buildLevelType(LevelFormat::Compressed, false, false) ==
|
||||
DimLevelType::CompressedNuNo &&
|
||||
LevelType::CompressedNuNo &&
|
||||
*buildLevelType(LevelFormat::Singleton, true, true) ==
|
||||
DimLevelType::Singleton &&
|
||||
LevelType::Singleton &&
|
||||
*buildLevelType(LevelFormat::Singleton, true, false) ==
|
||||
DimLevelType::SingletonNu &&
|
||||
LevelType::SingletonNu &&
|
||||
*buildLevelType(LevelFormat::Singleton, false, true) ==
|
||||
DimLevelType::SingletonNo &&
|
||||
LevelType::SingletonNo &&
|
||||
*buildLevelType(LevelFormat::Singleton, false, false) ==
|
||||
DimLevelType::SingletonNuNo &&
|
||||
LevelType::SingletonNuNo &&
|
||||
*buildLevelType(LevelFormat::LooseCompressed, true, true) ==
|
||||
DimLevelType::LooseCompressed &&
|
||||
LevelType::LooseCompressed &&
|
||||
*buildLevelType(LevelFormat::LooseCompressed, true, false) ==
|
||||
DimLevelType::LooseCompressedNu &&
|
||||
LevelType::LooseCompressedNu &&
|
||||
*buildLevelType(LevelFormat::LooseCompressed, false, true) ==
|
||||
DimLevelType::LooseCompressedNo &&
|
||||
LevelType::LooseCompressedNo &&
|
||||
*buildLevelType(LevelFormat::LooseCompressed, false, false) ==
|
||||
DimLevelType::LooseCompressedNuNo &&
|
||||
LevelType::LooseCompressedNuNo &&
|
||||
buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
|
||||
buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
|
||||
buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
|
||||
*buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
|
||||
DimLevelType::TwoOutOfFour),
|
||||
LevelType::TwoOutOfFour),
|
||||
"buildLevelType conversion is broken");
|
||||
|
||||
static_assert((isValidLT(DimLevelType::Undef) &&
|
||||
isValidLT(DimLevelType::Dense) &&
|
||||
isValidLT(DimLevelType::Compressed) &&
|
||||
isValidLT(DimLevelType::CompressedNu) &&
|
||||
isValidLT(DimLevelType::CompressedNo) &&
|
||||
isValidLT(DimLevelType::CompressedNuNo) &&
|
||||
isValidLT(DimLevelType::Singleton) &&
|
||||
isValidLT(DimLevelType::SingletonNu) &&
|
||||
isValidLT(DimLevelType::SingletonNo) &&
|
||||
isValidLT(DimLevelType::SingletonNuNo) &&
|
||||
isValidLT(DimLevelType::LooseCompressed) &&
|
||||
isValidLT(DimLevelType::LooseCompressedNu) &&
|
||||
isValidLT(DimLevelType::LooseCompressedNo) &&
|
||||
isValidLT(DimLevelType::LooseCompressedNuNo) &&
|
||||
isValidLT(DimLevelType::TwoOutOfFour)),
|
||||
"isValidLT definition is broken");
|
||||
static_assert(
|
||||
(isValidLT(LevelType::Undef) && isValidLT(LevelType::Dense) &&
|
||||
isValidLT(LevelType::Compressed) && isValidLT(LevelType::CompressedNu) &&
|
||||
isValidLT(LevelType::CompressedNo) &&
|
||||
isValidLT(LevelType::CompressedNuNo) && isValidLT(LevelType::Singleton) &&
|
||||
isValidLT(LevelType::SingletonNu) && isValidLT(LevelType::SingletonNo) &&
|
||||
isValidLT(LevelType::SingletonNuNo) &&
|
||||
isValidLT(LevelType::LooseCompressed) &&
|
||||
isValidLT(LevelType::LooseCompressedNu) &&
|
||||
isValidLT(LevelType::LooseCompressedNo) &&
|
||||
isValidLT(LevelType::LooseCompressedNuNo) &&
|
||||
isValidLT(LevelType::TwoOutOfFour)),
|
||||
"isValidLT definition is broken");
|
||||
|
||||
static_assert((isDenseLT(DimLevelType::Dense) &&
|
||||
!isDenseLT(DimLevelType::Compressed) &&
|
||||
!isDenseLT(DimLevelType::CompressedNu) &&
|
||||
!isDenseLT(DimLevelType::CompressedNo) &&
|
||||
!isDenseLT(DimLevelType::CompressedNuNo) &&
|
||||
!isDenseLT(DimLevelType::Singleton) &&
|
||||
!isDenseLT(DimLevelType::SingletonNu) &&
|
||||
!isDenseLT(DimLevelType::SingletonNo) &&
|
||||
!isDenseLT(DimLevelType::SingletonNuNo) &&
|
||||
!isDenseLT(DimLevelType::LooseCompressed) &&
|
||||
!isDenseLT(DimLevelType::LooseCompressedNu) &&
|
||||
!isDenseLT(DimLevelType::LooseCompressedNo) &&
|
||||
!isDenseLT(DimLevelType::LooseCompressedNuNo) &&
|
||||
!isDenseLT(DimLevelType::TwoOutOfFour)),
|
||||
static_assert((isDenseLT(LevelType::Dense) &&
|
||||
!isDenseLT(LevelType::Compressed) &&
|
||||
!isDenseLT(LevelType::CompressedNu) &&
|
||||
!isDenseLT(LevelType::CompressedNo) &&
|
||||
!isDenseLT(LevelType::CompressedNuNo) &&
|
||||
!isDenseLT(LevelType::Singleton) &&
|
||||
!isDenseLT(LevelType::SingletonNu) &&
|
||||
!isDenseLT(LevelType::SingletonNo) &&
|
||||
!isDenseLT(LevelType::SingletonNuNo) &&
|
||||
!isDenseLT(LevelType::LooseCompressed) &&
|
||||
!isDenseLT(LevelType::LooseCompressedNu) &&
|
||||
!isDenseLT(LevelType::LooseCompressedNo) &&
|
||||
!isDenseLT(LevelType::LooseCompressedNuNo) &&
|
||||
!isDenseLT(LevelType::TwoOutOfFour)),
|
||||
"isDenseLT definition is broken");
|
||||
|
||||
static_assert((!isCompressedLT(DimLevelType::Dense) &&
|
||||
isCompressedLT(DimLevelType::Compressed) &&
|
||||
isCompressedLT(DimLevelType::CompressedNu) &&
|
||||
isCompressedLT(DimLevelType::CompressedNo) &&
|
||||
isCompressedLT(DimLevelType::CompressedNuNo) &&
|
||||
!isCompressedLT(DimLevelType::Singleton) &&
|
||||
!isCompressedLT(DimLevelType::SingletonNu) &&
|
||||
!isCompressedLT(DimLevelType::SingletonNo) &&
|
||||
!isCompressedLT(DimLevelType::SingletonNuNo) &&
|
||||
!isCompressedLT(DimLevelType::LooseCompressed) &&
|
||||
!isCompressedLT(DimLevelType::LooseCompressedNu) &&
|
||||
!isCompressedLT(DimLevelType::LooseCompressedNo) &&
|
||||
!isCompressedLT(DimLevelType::LooseCompressedNuNo) &&
|
||||
!isCompressedLT(DimLevelType::TwoOutOfFour)),
|
||||
static_assert((!isCompressedLT(LevelType::Dense) &&
|
||||
isCompressedLT(LevelType::Compressed) &&
|
||||
isCompressedLT(LevelType::CompressedNu) &&
|
||||
isCompressedLT(LevelType::CompressedNo) &&
|
||||
isCompressedLT(LevelType::CompressedNuNo) &&
|
||||
!isCompressedLT(LevelType::Singleton) &&
|
||||
!isCompressedLT(LevelType::SingletonNu) &&
|
||||
!isCompressedLT(LevelType::SingletonNo) &&
|
||||
!isCompressedLT(LevelType::SingletonNuNo) &&
|
||||
!isCompressedLT(LevelType::LooseCompressed) &&
|
||||
!isCompressedLT(LevelType::LooseCompressedNu) &&
|
||||
!isCompressedLT(LevelType::LooseCompressedNo) &&
|
||||
!isCompressedLT(LevelType::LooseCompressedNuNo) &&
|
||||
!isCompressedLT(LevelType::TwoOutOfFour)),
|
||||
"isCompressedLT definition is broken");
|
||||
|
||||
static_assert((!isSingletonLT(DimLevelType::Dense) &&
|
||||
!isSingletonLT(DimLevelType::Compressed) &&
|
||||
!isSingletonLT(DimLevelType::CompressedNu) &&
|
||||
!isSingletonLT(DimLevelType::CompressedNo) &&
|
||||
!isSingletonLT(DimLevelType::CompressedNuNo) &&
|
||||
isSingletonLT(DimLevelType::Singleton) &&
|
||||
isSingletonLT(DimLevelType::SingletonNu) &&
|
||||
isSingletonLT(DimLevelType::SingletonNo) &&
|
||||
isSingletonLT(DimLevelType::SingletonNuNo) &&
|
||||
!isSingletonLT(DimLevelType::LooseCompressed) &&
|
||||
!isSingletonLT(DimLevelType::LooseCompressedNu) &&
|
||||
!isSingletonLT(DimLevelType::LooseCompressedNo) &&
|
||||
!isSingletonLT(DimLevelType::LooseCompressedNuNo) &&
|
||||
!isSingletonLT(DimLevelType::TwoOutOfFour)),
|
||||
static_assert((!isSingletonLT(LevelType::Dense) &&
|
||||
!isSingletonLT(LevelType::Compressed) &&
|
||||
!isSingletonLT(LevelType::CompressedNu) &&
|
||||
!isSingletonLT(LevelType::CompressedNo) &&
|
||||
!isSingletonLT(LevelType::CompressedNuNo) &&
|
||||
isSingletonLT(LevelType::Singleton) &&
|
||||
isSingletonLT(LevelType::SingletonNu) &&
|
||||
isSingletonLT(LevelType::SingletonNo) &&
|
||||
isSingletonLT(LevelType::SingletonNuNo) &&
|
||||
!isSingletonLT(LevelType::LooseCompressed) &&
|
||||
!isSingletonLT(LevelType::LooseCompressedNu) &&
|
||||
!isSingletonLT(LevelType::LooseCompressedNo) &&
|
||||
!isSingletonLT(LevelType::LooseCompressedNuNo) &&
|
||||
!isSingletonLT(LevelType::TwoOutOfFour)),
|
||||
"isSingletonLT definition is broken");
|
||||
|
||||
static_assert((!isLooseCompressedLT(DimLevelType::Dense) &&
|
||||
!isLooseCompressedLT(DimLevelType::Compressed) &&
|
||||
!isLooseCompressedLT(DimLevelType::CompressedNu) &&
|
||||
!isLooseCompressedLT(DimLevelType::CompressedNo) &&
|
||||
!isLooseCompressedLT(DimLevelType::CompressedNuNo) &&
|
||||
!isLooseCompressedLT(DimLevelType::Singleton) &&
|
||||
!isLooseCompressedLT(DimLevelType::SingletonNu) &&
|
||||
!isLooseCompressedLT(DimLevelType::SingletonNo) &&
|
||||
!isLooseCompressedLT(DimLevelType::SingletonNuNo) &&
|
||||
isLooseCompressedLT(DimLevelType::LooseCompressed) &&
|
||||
isLooseCompressedLT(DimLevelType::LooseCompressedNu) &&
|
||||
isLooseCompressedLT(DimLevelType::LooseCompressedNo) &&
|
||||
isLooseCompressedLT(DimLevelType::LooseCompressedNuNo) &&
|
||||
!isLooseCompressedLT(DimLevelType::TwoOutOfFour)),
|
||||
static_assert((!isLooseCompressedLT(LevelType::Dense) &&
|
||||
!isLooseCompressedLT(LevelType::Compressed) &&
|
||||
!isLooseCompressedLT(LevelType::CompressedNu) &&
|
||||
!isLooseCompressedLT(LevelType::CompressedNo) &&
|
||||
!isLooseCompressedLT(LevelType::CompressedNuNo) &&
|
||||
!isLooseCompressedLT(LevelType::Singleton) &&
|
||||
!isLooseCompressedLT(LevelType::SingletonNu) &&
|
||||
!isLooseCompressedLT(LevelType::SingletonNo) &&
|
||||
!isLooseCompressedLT(LevelType::SingletonNuNo) &&
|
||||
isLooseCompressedLT(LevelType::LooseCompressed) &&
|
||||
isLooseCompressedLT(LevelType::LooseCompressedNu) &&
|
||||
isLooseCompressedLT(LevelType::LooseCompressedNo) &&
|
||||
isLooseCompressedLT(LevelType::LooseCompressedNuNo) &&
|
||||
!isLooseCompressedLT(LevelType::TwoOutOfFour)),
|
||||
"isLooseCompressedLT definition is broken");
|
||||
|
||||
static_assert((!is2OutOf4LT(DimLevelType::Dense) &&
|
||||
!is2OutOf4LT(DimLevelType::Compressed) &&
|
||||
!is2OutOf4LT(DimLevelType::CompressedNu) &&
|
||||
!is2OutOf4LT(DimLevelType::CompressedNo) &&
|
||||
!is2OutOf4LT(DimLevelType::CompressedNuNo) &&
|
||||
!is2OutOf4LT(DimLevelType::Singleton) &&
|
||||
!is2OutOf4LT(DimLevelType::SingletonNu) &&
|
||||
!is2OutOf4LT(DimLevelType::SingletonNo) &&
|
||||
!is2OutOf4LT(DimLevelType::SingletonNuNo) &&
|
||||
!is2OutOf4LT(DimLevelType::LooseCompressed) &&
|
||||
!is2OutOf4LT(DimLevelType::LooseCompressedNu) &&
|
||||
!is2OutOf4LT(DimLevelType::LooseCompressedNo) &&
|
||||
!is2OutOf4LT(DimLevelType::LooseCompressedNuNo) &&
|
||||
is2OutOf4LT(DimLevelType::TwoOutOfFour)),
|
||||
static_assert((!is2OutOf4LT(LevelType::Dense) &&
|
||||
!is2OutOf4LT(LevelType::Compressed) &&
|
||||
!is2OutOf4LT(LevelType::CompressedNu) &&
|
||||
!is2OutOf4LT(LevelType::CompressedNo) &&
|
||||
!is2OutOf4LT(LevelType::CompressedNuNo) &&
|
||||
!is2OutOf4LT(LevelType::Singleton) &&
|
||||
!is2OutOf4LT(LevelType::SingletonNu) &&
|
||||
!is2OutOf4LT(LevelType::SingletonNo) &&
|
||||
!is2OutOf4LT(LevelType::SingletonNuNo) &&
|
||||
!is2OutOf4LT(LevelType::LooseCompressed) &&
|
||||
!is2OutOf4LT(LevelType::LooseCompressedNu) &&
|
||||
!is2OutOf4LT(LevelType::LooseCompressedNo) &&
|
||||
!is2OutOf4LT(LevelType::LooseCompressedNuNo) &&
|
||||
is2OutOf4LT(LevelType::TwoOutOfFour)),
|
||||
"is2OutOf4LT definition is broken");
|
||||
|
||||
static_assert((isOrderedLT(DimLevelType::Dense) &&
|
||||
isOrderedLT(DimLevelType::Compressed) &&
|
||||
isOrderedLT(DimLevelType::CompressedNu) &&
|
||||
!isOrderedLT(DimLevelType::CompressedNo) &&
|
||||
!isOrderedLT(DimLevelType::CompressedNuNo) &&
|
||||
isOrderedLT(DimLevelType::Singleton) &&
|
||||
isOrderedLT(DimLevelType::SingletonNu) &&
|
||||
!isOrderedLT(DimLevelType::SingletonNo) &&
|
||||
!isOrderedLT(DimLevelType::SingletonNuNo) &&
|
||||
isOrderedLT(DimLevelType::LooseCompressed) &&
|
||||
isOrderedLT(DimLevelType::LooseCompressedNu) &&
|
||||
!isOrderedLT(DimLevelType::LooseCompressedNo) &&
|
||||
!isOrderedLT(DimLevelType::LooseCompressedNuNo) &&
|
||||
isOrderedLT(DimLevelType::TwoOutOfFour)),
|
||||
static_assert((isOrderedLT(LevelType::Dense) &&
|
||||
isOrderedLT(LevelType::Compressed) &&
|
||||
isOrderedLT(LevelType::CompressedNu) &&
|
||||
!isOrderedLT(LevelType::CompressedNo) &&
|
||||
!isOrderedLT(LevelType::CompressedNuNo) &&
|
||||
isOrderedLT(LevelType::Singleton) &&
|
||||
isOrderedLT(LevelType::SingletonNu) &&
|
||||
!isOrderedLT(LevelType::SingletonNo) &&
|
||||
!isOrderedLT(LevelType::SingletonNuNo) &&
|
||||
isOrderedLT(LevelType::LooseCompressed) &&
|
||||
isOrderedLT(LevelType::LooseCompressedNu) &&
|
||||
!isOrderedLT(LevelType::LooseCompressedNo) &&
|
||||
!isOrderedLT(LevelType::LooseCompressedNuNo) &&
|
||||
isOrderedLT(LevelType::TwoOutOfFour)),
|
||||
"isOrderedLT definition is broken");
|
||||
|
||||
static_assert((isUniqueLT(DimLevelType::Dense) &&
|
||||
isUniqueLT(DimLevelType::Compressed) &&
|
||||
!isUniqueLT(DimLevelType::CompressedNu) &&
|
||||
isUniqueLT(DimLevelType::CompressedNo) &&
|
||||
!isUniqueLT(DimLevelType::CompressedNuNo) &&
|
||||
isUniqueLT(DimLevelType::Singleton) &&
|
||||
!isUniqueLT(DimLevelType::SingletonNu) &&
|
||||
isUniqueLT(DimLevelType::SingletonNo) &&
|
||||
!isUniqueLT(DimLevelType::SingletonNuNo) &&
|
||||
isUniqueLT(DimLevelType::LooseCompressed) &&
|
||||
!isUniqueLT(DimLevelType::LooseCompressedNu) &&
|
||||
isUniqueLT(DimLevelType::LooseCompressedNo) &&
|
||||
!isUniqueLT(DimLevelType::LooseCompressedNuNo) &&
|
||||
isUniqueLT(DimLevelType::TwoOutOfFour)),
|
||||
static_assert((isUniqueLT(LevelType::Dense) &&
|
||||
isUniqueLT(LevelType::Compressed) &&
|
||||
!isUniqueLT(LevelType::CompressedNu) &&
|
||||
isUniqueLT(LevelType::CompressedNo) &&
|
||||
!isUniqueLT(LevelType::CompressedNuNo) &&
|
||||
isUniqueLT(LevelType::Singleton) &&
|
||||
!isUniqueLT(LevelType::SingletonNu) &&
|
||||
isUniqueLT(LevelType::SingletonNo) &&
|
||||
!isUniqueLT(LevelType::SingletonNuNo) &&
|
||||
isUniqueLT(LevelType::LooseCompressed) &&
|
||||
!isUniqueLT(LevelType::LooseCompressedNu) &&
|
||||
isUniqueLT(LevelType::LooseCompressedNo) &&
|
||||
!isUniqueLT(LevelType::LooseCompressedNuNo) &&
|
||||
isUniqueLT(LevelType::TwoOutOfFour)),
|
||||
"isUniqueLT definition is broken");
|
||||
|
||||
/// Bit manipulations for affine encoding.
|
||||
|
@ -278,7 +278,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
|
||||
// A level-type for each level of the sparse storage
|
||||
// (consists of a level-format combined with level-properties).
|
||||
ArrayRefParameter<
|
||||
"::mlir::sparse_tensor::DimLevelType",
|
||||
"::mlir::sparse_tensor::LevelType",
|
||||
"level-types"
|
||||
>: $lvlTypes,
|
||||
|
||||
@ -302,7 +302,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
|
||||
);
|
||||
|
||||
let builders = [
|
||||
AttrBuilder<(ins "ArrayRef<::mlir::sparse_tensor::DimLevelType>":$lvlTypes,
|
||||
AttrBuilder<(ins "ArrayRef<::mlir::sparse_tensor::LevelType>":$lvlTypes,
|
||||
CArg<"AffineMap", "{}">:$dimToLvl,
|
||||
CArg<"AffineMap", "{}">:$lvlToDim,
|
||||
CArg<"unsigned", "0">:$posWidth,
|
||||
@ -366,9 +366,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
|
||||
//
|
||||
|
||||
/// Safely looks up the level-type for the requested level. (Returns
|
||||
/// `DimLevelType::Dense` for the null encoding, since dense-tensors
|
||||
/// `LevelType::Dense` for the null encoding, since dense-tensors
|
||||
/// are always all-dense.)
|
||||
::mlir::sparse_tensor::DimLevelType getLvlType(::mlir::sparse_tensor::Level l) const;
|
||||
::mlir::sparse_tensor::LevelType getLvlType(::mlir::sparse_tensor::Level l) const;
|
||||
|
||||
bool isDenseLvl(::mlir::sparse_tensor::Level l) const { return isDenseLT(getLvlType(l)); }
|
||||
bool isCompressedLvl(::mlir::sparse_tensor::Level l) const { return isCompressedLT(getLvlType(l)); }
|
||||
@ -428,7 +428,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
|
||||
|
||||
void printSymbols(AffineMap &map, AsmPrinter &printer) const;
|
||||
void printDimensions(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr> dimSlices) const;
|
||||
void printLevels(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::DimLevelType> lvlTypes) const;
|
||||
void printLevels(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::LevelType> lvlTypes) const;
|
||||
}];
|
||||
|
||||
let genVerifyDecl = 1;
|
||||
|
@ -126,7 +126,7 @@ public:
|
||||
void foreachField(
|
||||
llvm::function_ref<bool(
|
||||
FieldIndex /*fieldIdx*/, SparseTensorFieldKind /*fieldKind*/,
|
||||
Level /*lvl (if applicable)*/, DimLevelType /*LT (if applicable)*/)>)
|
||||
Level /*lvl (if applicable)*/, LevelType /*LT (if applicable)*/)>)
|
||||
const;
|
||||
|
||||
/// Gets the field index for required field.
|
||||
@ -165,7 +165,7 @@ inline unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) {
|
||||
inline void foreachFieldInSparseTensor(
|
||||
SparseTensorEncodingAttr enc,
|
||||
llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
|
||||
DimLevelType)>
|
||||
LevelType)>
|
||||
callback) {
|
||||
return StorageLayout(enc).foreachField(callback);
|
||||
}
|
||||
@ -173,7 +173,7 @@ inline void foreachFieldInSparseTensor(
|
||||
void foreachFieldAndTypeInSparseTensor(
|
||||
SparseTensorType,
|
||||
llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
|
||||
DimLevelType)>);
|
||||
LevelType)>);
|
||||
|
||||
} // namespace sparse_tensor
|
||||
} // namespace mlir
|
||||
|
@ -282,8 +282,8 @@ public:
|
||||
/// `ShapedType::Trait<T>::getNumDynamicDims`.
|
||||
int64_t getNumDynamicDims() const { return rtp.getNumDynamicDims(); }
|
||||
|
||||
ArrayRef<DimLevelType> getLvlTypes() const { return enc.getLvlTypes(); }
|
||||
DimLevelType getLvlType(Level l) const {
|
||||
ArrayRef<LevelType> getLvlTypes() const { return enc.getLvlTypes(); }
|
||||
LevelType getLvlType(Level l) const {
|
||||
// This OOB check is for dense-tensors, since this class knows
|
||||
// their lvlRank (whereas STEA::getLvlType will/can only check
|
||||
// OOB for sparse-tensors).
|
||||
|
@ -56,8 +56,8 @@ using LatPointId = unsigned;
|
||||
/// for the corresponding `SmallVector<LatPointId>` object.
|
||||
using LatSetId = unsigned;
|
||||
|
||||
/// A pair of level and its corresponding DimLevelType of a tensor.
|
||||
using LvlLTPair = std::pair<Level, DimLevelType>;
|
||||
/// A pair of level and its corresponding LevelType of a tensor.
|
||||
using LvlLTPair = std::pair<Level, LevelType>;
|
||||
|
||||
/// A pair of loop id and its coefficients. E.g., for affine expression in the
|
||||
/// affine map `2 * d0`, loop id = 0, coefficient = 2.
|
||||
@ -395,13 +395,13 @@ public:
|
||||
bool hasSparseIdxReduction(const BitVector &bits) const;
|
||||
|
||||
/// Gets the level-type of the `t`th tensor on `i`th loop.
|
||||
DimLevelType getLvlType(TensorId t, LoopId i) const {
|
||||
LevelType getLvlType(TensorId t, LoopId i) const {
|
||||
assert(isValidTensorId(t) && isValidLoopId(i));
|
||||
return lvlTypes[t][i];
|
||||
}
|
||||
|
||||
/// Gets the level-type of the TensorLoopId.
|
||||
DimLevelType getLvlType(TensorLoopId b) const {
|
||||
LevelType getLvlType(TensorLoopId b) const {
|
||||
return getLvlType(tensor(b), loop(b));
|
||||
}
|
||||
|
||||
@ -422,7 +422,7 @@ public:
|
||||
|
||||
/// Sets the level number and level-type of the `t`th tensor on
|
||||
/// `i`th loop.
|
||||
void setLevelAndType(TensorId t, LoopId i, Level lvl, DimLevelType lt) {
|
||||
void setLevelAndType(TensorId t, LoopId i, Level lvl, LevelType lt) {
|
||||
assert(isValidLevel(t, lvl) && isValidLoopId(i) && isValidLT(lt));
|
||||
lvlTypes[t][i] = lt;
|
||||
loopToLvl[t][i] = lvl;
|
||||
@ -432,7 +432,7 @@ public:
|
||||
}
|
||||
|
||||
using ForeachTensorLoopIdCallback = function_ref<void(
|
||||
TensorLoopId, TensorId, std::optional<Level>, DimLevelType, bool)>;
|
||||
TensorLoopId, TensorId, std::optional<Level>, LevelType, bool)>;
|
||||
|
||||
/// Iterates over a set of `TensorLoopId`s, invoking the callback
|
||||
/// for each `TensorLoopId` and passing it the corresponding tensor
|
||||
@ -469,7 +469,7 @@ public:
|
||||
|
||||
/// Establishes the two-way map that i <-> <t, lvl, lt>.
|
||||
void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl,
|
||||
DimLevelType lt, unsigned coefficient) {
|
||||
LevelType lt, unsigned coefficient) {
|
||||
assert(isValidLoopId(i) && isValidLevel(t, lvl));
|
||||
assert(!loopToUnresolvedLvls[i][t].has_value()); // must be the first def
|
||||
loopToUnresolvedLvls[i][t] = std::make_pair(lvl, lt);
|
||||
@ -520,7 +520,7 @@ public:
|
||||
return loopToUnresolvedLvls[loop(b)][tensor(b)]->first;
|
||||
}
|
||||
|
||||
DimLevelType getLoopDependentLevelType(TensorLoopId b) const {
|
||||
LevelType getLoopDependentLevelType(TensorLoopId b) const {
|
||||
assert(isLvlWithNonTrivialIdxExp(b));
|
||||
return loopToUnresolvedLvls[loop(b)][tensor(b)]->second;
|
||||
}
|
||||
@ -636,7 +636,7 @@ private:
|
||||
// does not.
|
||||
|
||||
/// Map that converts pair<TensorId, LoopId> to the corresponding lvl-type.
|
||||
std::vector<std::vector<DimLevelType>> lvlTypes;
|
||||
std::vector<std::vector<LevelType>> lvlTypes;
|
||||
|
||||
/// Map that converts pair<TensorId, LoopId> to the corresponding lvl.
|
||||
std::vector<std::vector<std::optional<Level>>> loopToLvl;
|
||||
|
@ -197,7 +197,7 @@ public:
|
||||
template <typename P, typename I, typename V>
|
||||
SparseTensorStorage<P, I, V> *
|
||||
readSparseTensor(uint64_t lvlRank, const uint64_t *lvlSizes,
|
||||
const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
|
||||
const LevelType *lvlTypes, const uint64_t *dim2lvl,
|
||||
const uint64_t *lvl2dim) {
|
||||
const uint64_t dimRank = getRank();
|
||||
MapRef map(dimRank, lvlRank, dim2lvl, lvl2dim);
|
||||
|
@ -70,7 +70,7 @@ public:
|
||||
/// Constructs a new sparse-tensor storage object with the given encoding.
|
||||
SparseTensorStorageBase(uint64_t dimRank, const uint64_t *dimSizes,
|
||||
uint64_t lvlRank, const uint64_t *lvlSizes,
|
||||
const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
|
||||
const LevelType *lvlTypes, const uint64_t *dim2lvl,
|
||||
const uint64_t *lvl2dim);
|
||||
virtual ~SparseTensorStorageBase() = default;
|
||||
|
||||
@ -99,10 +99,10 @@ public:
|
||||
}
|
||||
|
||||
/// Gets the level-types array.
|
||||
const std::vector<DimLevelType> &getLvlTypes() const { return lvlTypes; }
|
||||
const std::vector<LevelType> &getLvlTypes() const { return lvlTypes; }
|
||||
|
||||
/// Safely looks up the type of the given level.
|
||||
DimLevelType getLvlType(uint64_t l) const {
|
||||
LevelType getLvlType(uint64_t l) const {
|
||||
assert(l < getLvlRank());
|
||||
return lvlTypes[l];
|
||||
}
|
||||
@ -180,7 +180,7 @@ public:
|
||||
private:
|
||||
const std::vector<uint64_t> dimSizes;
|
||||
const std::vector<uint64_t> lvlSizes;
|
||||
const std::vector<DimLevelType> lvlTypes;
|
||||
const std::vector<LevelType> lvlTypes;
|
||||
const std::vector<uint64_t> dim2lvlVec;
|
||||
const std::vector<uint64_t> lvl2dimVec;
|
||||
|
||||
@ -203,7 +203,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
|
||||
/// doesn't entail `!(positions[l].empty())`.
|
||||
SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
|
||||
uint64_t lvlRank, const uint64_t *lvlSizes,
|
||||
const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
|
||||
const LevelType *lvlTypes, const uint64_t *dim2lvl,
|
||||
const uint64_t *lvl2dim)
|
||||
: SparseTensorStorageBase(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
|
||||
dim2lvl, lvl2dim),
|
||||
@ -219,7 +219,7 @@ public:
|
||||
/// some other form of initialization.
|
||||
SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
|
||||
uint64_t lvlRank, const uint64_t *lvlSizes,
|
||||
const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
|
||||
const LevelType *lvlTypes, const uint64_t *dim2lvl,
|
||||
const uint64_t *lvl2dim, SparseTensorCOO<V> *lvlCOO,
|
||||
bool initializeValuesIfAllDense);
|
||||
|
||||
@ -228,7 +228,7 @@ public:
|
||||
/// overhead-storage allocation as the ctor above.
|
||||
SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
|
||||
uint64_t lvlRank, const uint64_t *lvlSizes,
|
||||
const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
|
||||
const LevelType *lvlTypes, const uint64_t *dim2lvl,
|
||||
const uint64_t *lvl2dim, SparseTensorCOO<V> &lvlCOO);
|
||||
|
||||
/// Constructs a sparse tensor with the given encoding, and initializes
|
||||
@ -240,19 +240,19 @@ public:
|
||||
/// passed in as a single AoS memory.
|
||||
SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
|
||||
uint64_t lvlRank, const uint64_t *lvlSizes,
|
||||
const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
|
||||
const LevelType *lvlTypes, const uint64_t *dim2lvl,
|
||||
const uint64_t *lvl2dim, const intptr_t *lvlBufs);
|
||||
|
||||
/// Allocates a new empty sparse tensor.
|
||||
static SparseTensorStorage<P, C, V> *
|
||||
newEmpty(uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
|
||||
const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
|
||||
const uint64_t *lvlSizes, const LevelType *lvlTypes,
|
||||
const uint64_t *dim2lvl, const uint64_t *lvl2dim, bool forwarding);
|
||||
|
||||
/// Allocates a new sparse tensor and initializes it from the given COO.
|
||||
static SparseTensorStorage<P, C, V> *
|
||||
newFromCOO(uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
|
||||
const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
|
||||
const uint64_t *lvlSizes, const LevelType *lvlTypes,
|
||||
const uint64_t *dim2lvl, const uint64_t *lvl2dim,
|
||||
SparseTensorCOO<V> &lvlCOO);
|
||||
|
||||
@ -261,7 +261,7 @@ public:
|
||||
static SparseTensorStorage<P, C, V> *
|
||||
packFromLvlBuffers(uint64_t dimRank, const uint64_t *dimSizes,
|
||||
uint64_t lvlRank, const uint64_t *lvlSizes,
|
||||
const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
|
||||
const LevelType *lvlTypes, const uint64_t *dim2lvl,
|
||||
const uint64_t *lvl2dim, uint64_t srcRank,
|
||||
const intptr_t *buffers);
|
||||
|
||||
@ -294,7 +294,7 @@ public:
|
||||
void lexInsert(const uint64_t *lvlCoords, V val) final {
|
||||
assert(lvlCoords);
|
||||
bool allDense = std::all_of(getLvlTypes().begin(), getLvlTypes().end(),
|
||||
[](DimLevelType lt) { return isDenseLT(lt); });
|
||||
[](LevelType lt) { return isDenseLT(lt); });
|
||||
if (allDense) {
|
||||
uint64_t lvlRank = getLvlRank();
|
||||
uint64_t valIdx = 0;
|
||||
@ -654,7 +654,7 @@ private:
|
||||
template <typename P, typename C, typename V>
|
||||
SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newEmpty(
|
||||
uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
|
||||
const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
|
||||
const uint64_t *lvlSizes, const LevelType *lvlTypes,
|
||||
const uint64_t *dim2lvl, const uint64_t *lvl2dim, bool forwarding) {
|
||||
SparseTensorCOO<V> *lvlCOO = nullptr;
|
||||
if (forwarding)
|
||||
@ -667,7 +667,7 @@ SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newEmpty(
|
||||
template <typename P, typename C, typename V>
|
||||
SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newFromCOO(
|
||||
uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
|
||||
const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
|
||||
const uint64_t *lvlSizes, const LevelType *lvlTypes,
|
||||
const uint64_t *dim2lvl, const uint64_t *lvl2dim,
|
||||
SparseTensorCOO<V> &lvlCOO) {
|
||||
return new SparseTensorStorage<P, C, V>(dimRank, dimSizes, lvlRank, lvlSizes,
|
||||
@ -677,7 +677,7 @@ SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newFromCOO(
|
||||
template <typename P, typename C, typename V>
|
||||
SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::packFromLvlBuffers(
|
||||
uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
|
||||
const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
|
||||
const uint64_t *lvlSizes, const LevelType *lvlTypes,
|
||||
const uint64_t *dim2lvl, const uint64_t *lvl2dim, uint64_t srcRank,
|
||||
const intptr_t *buffers) {
|
||||
return new SparseTensorStorage<P, C, V>(dimRank, dimSizes, lvlRank, lvlSizes,
|
||||
@ -693,7 +693,7 @@ SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::packFromLvlBuffers(
|
||||
template <typename P, typename C, typename V>
|
||||
SparseTensorStorage<P, C, V>::SparseTensorStorage(
|
||||
uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
|
||||
const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
|
||||
const uint64_t *lvlSizes, const LevelType *lvlTypes,
|
||||
const uint64_t *dim2lvl, const uint64_t *lvl2dim,
|
||||
SparseTensorCOO<V> *lvlCOO, bool initializeValuesIfAllDense)
|
||||
: SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
|
||||
@ -742,7 +742,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
|
||||
template <typename P, typename C, typename V>
|
||||
SparseTensorStorage<P, C, V>::SparseTensorStorage( // NOLINT
|
||||
uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
|
||||
const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
|
||||
const uint64_t *lvlSizes, const LevelType *lvlTypes,
|
||||
const uint64_t *dim2lvl, const uint64_t *lvl2dim,
|
||||
SparseTensorCOO<V> &lvlCOO)
|
||||
: SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
|
||||
@ -761,7 +761,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage( // NOLINT
|
||||
template <typename P, typename C, typename V>
|
||||
SparseTensorStorage<P, C, V>::SparseTensorStorage(
|
||||
uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
|
||||
const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
|
||||
const uint64_t *lvlSizes, const LevelType *lvlTypes,
|
||||
const uint64_t *dim2lvl, const uint64_t *lvl2dim, const intptr_t *lvlBufs)
|
||||
: SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
|
||||
dim2lvl, lvl2dim) {
|
||||
|
@ -52,7 +52,7 @@ extern "C" {
|
||||
MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_newSparseTensor( // NOLINT
|
||||
StridedMemRefType<index_type, 1> *dimSizesRef,
|
||||
StridedMemRefType<index_type, 1> *lvlSizesRef,
|
||||
StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
|
||||
StridedMemRefType<LevelType, 1> *lvlTypesRef,
|
||||
StridedMemRefType<index_type, 1> *dim2lvlRef,
|
||||
StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp,
|
||||
OverheadType crdTp, PrimaryType valTp, Action action, void *ptr);
|
||||
|
@ -23,30 +23,30 @@ using namespace mlir;
|
||||
using namespace mlir::python::adaptors;
|
||||
|
||||
static void populateDialectSparseTensorSubmodule(const py::module &m) {
|
||||
py::enum_<MlirSparseTensorDimLevelType>(m, "DimLevelType", py::module_local())
|
||||
.value("dense", MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE)
|
||||
.value("compressed24", MLIR_SPARSE_TENSOR_DIM_LEVEL_TWO_OUT_OF_FOUR)
|
||||
.value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED)
|
||||
.value("compressed_nu", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU)
|
||||
.value("compressed_no", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO)
|
||||
.value("compressed_nu_no", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO)
|
||||
.value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON)
|
||||
.value("singleton_nu", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU)
|
||||
.value("singleton_no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO)
|
||||
.value("singleton_nu_no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO)
|
||||
.value("loose_compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_LOOSE_COMPRESSED)
|
||||
py::enum_<MlirSparseTensorLevelType>(m, "LevelType", py::module_local())
|
||||
.value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE)
|
||||
.value("compressed24", MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR)
|
||||
.value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED)
|
||||
.value("compressed_nu", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU)
|
||||
.value("compressed_no", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO)
|
||||
.value("compressed_nu_no", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO)
|
||||
.value("singleton", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON)
|
||||
.value("singleton_nu", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU)
|
||||
.value("singleton_no", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO)
|
||||
.value("singleton_nu_no", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO)
|
||||
.value("loose_compressed", MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED)
|
||||
.value("loose_compressed_nu",
|
||||
MLIR_SPARSE_TENSOR_DIM_LEVEL_LOOSE_COMPRESSED_NU)
|
||||
MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU)
|
||||
.value("loose_compressed_no",
|
||||
MLIR_SPARSE_TENSOR_DIM_LEVEL_LOOSE_COMPRESSED_NO)
|
||||
MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO)
|
||||
.value("loose_compressed_nu_no",
|
||||
MLIR_SPARSE_TENSOR_DIM_LEVEL_LOOSE_COMPRESSED_NU_NO);
|
||||
MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO);
|
||||
|
||||
mlir_attribute_subclass(m, "EncodingAttr",
|
||||
mlirAttributeIsASparseTensorEncodingAttr)
|
||||
.def_classmethod(
|
||||
"get",
|
||||
[](py::object cls, std::vector<MlirSparseTensorDimLevelType> lvlTypes,
|
||||
[](py::object cls, std::vector<MlirSparseTensorLevelType> lvlTypes,
|
||||
std::optional<MlirAffineMap> dimToLvl,
|
||||
std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
|
||||
MlirContext context) {
|
||||
@ -64,7 +64,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
|
||||
"lvl_types",
|
||||
[](MlirAttribute self) {
|
||||
const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
|
||||
std::vector<MlirSparseTensorDimLevelType> ret;
|
||||
std::vector<MlirSparseTensorLevelType> ret;
|
||||
ret.reserve(lvlRank);
|
||||
for (int l = 0; l < lvlRank; ++l)
|
||||
ret.push_back(mlirSparseTensorEncodingAttrGetLvlType(self, l));
|
||||
|
@ -20,26 +20,25 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor,
|
||||
mlir::sparse_tensor::SparseTensorDialect)
|
||||
|
||||
// Ensure the C-API enums are int-castable to C++ equivalents.
|
||||
static_assert(
|
||||
static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) ==
|
||||
static_cast<int>(DimLevelType::Dense) &&
|
||||
static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) ==
|
||||
static_cast<int>(DimLevelType::Compressed) &&
|
||||
static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU) ==
|
||||
static_cast<int>(DimLevelType::CompressedNu) &&
|
||||
static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO) ==
|
||||
static_cast<int>(DimLevelType::CompressedNo) &&
|
||||
static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO) ==
|
||||
static_cast<int>(DimLevelType::CompressedNuNo) &&
|
||||
static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) ==
|
||||
static_cast<int>(DimLevelType::Singleton) &&
|
||||
static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU) ==
|
||||
static_cast<int>(DimLevelType::SingletonNu) &&
|
||||
static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO) ==
|
||||
static_cast<int>(DimLevelType::SingletonNo) &&
|
||||
static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO) ==
|
||||
static_cast<int>(DimLevelType::SingletonNuNo),
|
||||
"MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch");
|
||||
static_assert(static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_DENSE) ==
|
||||
static_cast<int>(LevelType::Dense) &&
|
||||
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) ==
|
||||
static_cast<int>(LevelType::Compressed) &&
|
||||
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) ==
|
||||
static_cast<int>(LevelType::CompressedNu) &&
|
||||
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) ==
|
||||
static_cast<int>(LevelType::CompressedNo) &&
|
||||
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) ==
|
||||
static_cast<int>(LevelType::CompressedNuNo) &&
|
||||
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) ==
|
||||
static_cast<int>(LevelType::Singleton) &&
|
||||
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) ==
|
||||
static_cast<int>(LevelType::SingletonNu) &&
|
||||
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) ==
|
||||
static_cast<int>(LevelType::SingletonNo) &&
|
||||
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) ==
|
||||
static_cast<int>(LevelType::SingletonNuNo),
|
||||
"MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch");
|
||||
|
||||
bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
|
||||
return isa<SparseTensorEncodingAttr>(unwrap(attr));
|
||||
@ -47,13 +46,13 @@ bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
|
||||
|
||||
MlirAttribute
|
||||
mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
|
||||
MlirSparseTensorDimLevelType const *lvlTypes,
|
||||
MlirSparseTensorLevelType const *lvlTypes,
|
||||
MlirAffineMap dimToLvl, MlirAffineMap lvlToDim,
|
||||
int posWidth, int crdWidth) {
|
||||
SmallVector<DimLevelType> cppLvlTypes;
|
||||
SmallVector<LevelType> cppLvlTypes;
|
||||
cppLvlTypes.reserve(lvlRank);
|
||||
for (intptr_t l = 0; l < lvlRank; ++l)
|
||||
cppLvlTypes.push_back(static_cast<DimLevelType>(lvlTypes[l]));
|
||||
cppLvlTypes.push_back(static_cast<LevelType>(lvlTypes[l]));
|
||||
return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppLvlTypes,
|
||||
unwrap(dimToLvl), unwrap(lvlToDim),
|
||||
posWidth, crdWidth));
|
||||
@ -71,9 +70,9 @@ intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) {
|
||||
return cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlRank();
|
||||
}
|
||||
|
||||
MlirSparseTensorDimLevelType
|
||||
MlirSparseTensorLevelType
|
||||
mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl) {
|
||||
return static_cast<MlirSparseTensorDimLevelType>(
|
||||
return static_cast<MlirSparseTensorLevelType>(
|
||||
cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlType(lvl));
|
||||
}
|
||||
|
||||
|
@ -62,7 +62,7 @@ bool DimSpec::isValid(Ranks const &ranks) const {
|
||||
// `LvlSpec` implementation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LvlSpec::LvlSpec(LvlVar var, LvlExpr expr, DimLevelType type)
|
||||
LvlSpec::LvlSpec(LvlVar var, LvlExpr expr, LevelType type)
|
||||
: var(var), expr(expr), type(type) {
|
||||
assert(expr);
|
||||
assert(isValidLT(type) && !isUndefLT(type));
|
||||
|
@ -202,10 +202,10 @@ class LvlSpec final {
|
||||
/// The level-expression.
|
||||
LvlExpr expr;
|
||||
/// The level-type (== level-format + lvl-properties).
|
||||
DimLevelType type;
|
||||
LevelType type;
|
||||
|
||||
public:
|
||||
LvlSpec(LvlVar var, LvlExpr expr, DimLevelType type);
|
||||
LvlSpec(LvlVar var, LvlExpr expr, LevelType type);
|
||||
|
||||
MLIRContext *getContext() const {
|
||||
MLIRContext *ctx = expr.tryGetContext();
|
||||
@ -217,7 +217,7 @@ public:
|
||||
constexpr bool canElideVar() const { return elideVar; }
|
||||
void setElideVar(bool b) { elideVar = b; }
|
||||
constexpr LvlExpr getExpr() const { return expr; }
|
||||
constexpr DimLevelType getType() const { return type; }
|
||||
constexpr LevelType getType() const { return type; }
|
||||
|
||||
/// Checks whether the variables bound/used by this spec are valid
|
||||
/// with respect to the given ranks.
|
||||
@ -246,7 +246,7 @@ public:
|
||||
|
||||
ArrayRef<LvlSpec> getLvls() const { return lvlSpecs; }
|
||||
const LvlSpec &getLvl(Level lvl) const { return lvlSpecs[lvl]; }
|
||||
DimLevelType getLvlType(Level lvl) const { return getLvl(lvl).getType(); }
|
||||
LevelType getLvlType(Level lvl) const { return getLvl(lvl).getType(); }
|
||||
|
||||
AffineMap getDimToLvlMap(MLIRContext *context) const;
|
||||
AffineMap getLvlToDimMap(MLIRContext *context) const;
|
||||
|
@ -298,7 +298,7 @@ ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) {
|
||||
const auto type = lvlTypeParser.parseLvlType(parser);
|
||||
FAILURE_IF_FAILED(type)
|
||||
|
||||
lvlSpecs.emplace_back(var, expr, static_cast<DimLevelType>(*type));
|
||||
lvlSpecs.emplace_back(var, expr, static_cast<LevelType>(*type));
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
//===- LvlTypeParser.h - `DimLevelType` parser ----------------------------===//
|
||||
//===- LvlTypeParser.h - `LevelType` parser ----------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
@ -58,7 +58,7 @@ FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
|
||||
return failure();
|
||||
}
|
||||
|
||||
ERROR_IF(!isValidLT(static_cast<DimLevelType>(properties)),
|
||||
ERROR_IF(!isValidLT(static_cast<LevelType>(properties)),
|
||||
"invalid level type: level format doesn't support the properties");
|
||||
return properties;
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
//===- LvlTypeParser.h - `DimLevelType` parser ------------------*- C++ -*-===//
|
||||
//===- LvlTypeParser.h - `LevelType` parser ------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -62,7 +62,7 @@ static constexpr FieldIndex kDataFieldStartingIdx = 0;
|
||||
|
||||
void StorageLayout::foreachField(
|
||||
llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
|
||||
DimLevelType)>
|
||||
LevelType)>
|
||||
callback) const {
|
||||
const auto lvlTypes = enc.getLvlTypes();
|
||||
const Level lvlRank = enc.getLvlRank();
|
||||
@ -83,18 +83,18 @@ void StorageLayout::foreachField(
|
||||
}
|
||||
// The values array.
|
||||
if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
|
||||
DimLevelType::Undef)))
|
||||
LevelType::Undef)))
|
||||
return;
|
||||
// Put metadata at the end.
|
||||
if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
|
||||
DimLevelType::Undef)))
|
||||
LevelType::Undef)))
|
||||
return;
|
||||
}
|
||||
|
||||
void sparse_tensor::foreachFieldAndTypeInSparseTensor(
|
||||
SparseTensorType stt,
|
||||
llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
|
||||
DimLevelType)>
|
||||
LevelType)>
|
||||
callback) {
|
||||
assert(stt.hasEncoding());
|
||||
// Construct the basic types.
|
||||
@ -110,28 +110,28 @@ void sparse_tensor::foreachFieldAndTypeInSparseTensor(
|
||||
// memref<? x eltType> values
|
||||
const Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);
|
||||
|
||||
StorageLayout(stt).foreachField(
|
||||
[specType, posMemType, crdMemType, valMemType,
|
||||
callback](FieldIndex fieldIdx, SparseTensorFieldKind fieldKind,
|
||||
Level lvl, DimLevelType lt) -> bool {
|
||||
switch (fieldKind) {
|
||||
case SparseTensorFieldKind::StorageSpec:
|
||||
return callback(specType, fieldIdx, fieldKind, lvl, lt);
|
||||
case SparseTensorFieldKind::PosMemRef:
|
||||
return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
|
||||
case SparseTensorFieldKind::CrdMemRef:
|
||||
return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
|
||||
case SparseTensorFieldKind::ValMemRef:
|
||||
return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
|
||||
};
|
||||
llvm_unreachable("unrecognized field kind");
|
||||
});
|
||||
StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
|
||||
callback](FieldIndex fieldIdx,
|
||||
SparseTensorFieldKind fieldKind,
|
||||
Level lvl, LevelType lt) -> bool {
|
||||
switch (fieldKind) {
|
||||
case SparseTensorFieldKind::StorageSpec:
|
||||
return callback(specType, fieldIdx, fieldKind, lvl, lt);
|
||||
case SparseTensorFieldKind::PosMemRef:
|
||||
return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
|
||||
case SparseTensorFieldKind::CrdMemRef:
|
||||
return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
|
||||
case SparseTensorFieldKind::ValMemRef:
|
||||
return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
|
||||
};
|
||||
llvm_unreachable("unrecognized field kind");
|
||||
});
|
||||
}
|
||||
|
||||
unsigned StorageLayout::getNumFields() const {
|
||||
unsigned numFields = 0;
|
||||
foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
|
||||
DimLevelType) -> bool {
|
||||
LevelType) -> bool {
|
||||
numFields++;
|
||||
return true;
|
||||
});
|
||||
@ -141,7 +141,7 @@ unsigned StorageLayout::getNumFields() const {
|
||||
unsigned StorageLayout::getNumDataFields() const {
|
||||
unsigned numFields = 0; // one value memref
|
||||
foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
|
||||
DimLevelType) -> bool {
|
||||
LevelType) -> bool {
|
||||
if (fidx >= kDataFieldStartingIdx)
|
||||
numFields++;
|
||||
return true;
|
||||
@ -167,7 +167,7 @@ StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
|
||||
}
|
||||
foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
|
||||
SparseTensorFieldKind fKind, Level fLvl,
|
||||
DimLevelType lt) -> bool {
|
||||
LevelType lt) -> bool {
|
||||
if ((lvl && fLvl == lvl.value() && kind == fKind) ||
|
||||
(kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
|
||||
fieldIdx = fIdx;
|
||||
@ -343,9 +343,9 @@ Level SparseTensorEncodingAttr::getLvlRank() const {
|
||||
return getLvlTypes().size();
|
||||
}
|
||||
|
||||
DimLevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
|
||||
LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
|
||||
if (!getImpl())
|
||||
return DimLevelType::Dense;
|
||||
return LevelType::Dense;
|
||||
assert(l < getLvlRank() && "Level is out of bounds");
|
||||
return getLvlTypes()[l];
|
||||
}
|
||||
@ -469,7 +469,7 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
return {};
|
||||
|
||||
// Process the data from the parsed dictionary value into struct-like data.
|
||||
SmallVector<DimLevelType> lvlTypes;
|
||||
SmallVector<LevelType> lvlTypes;
|
||||
SmallVector<SparseTensorDimSliceAttr> dimSlices;
|
||||
AffineMap dimToLvl = {};
|
||||
AffineMap lvlToDim = {};
|
||||
@ -621,9 +621,8 @@ void SparseTensorEncodingAttr::printDimensions(
|
||||
}
|
||||
}
|
||||
|
||||
void SparseTensorEncodingAttr::printLevels(
|
||||
AffineMap &map, AsmPrinter &printer,
|
||||
ArrayRef<DimLevelType> lvlTypes) const {
|
||||
void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
|
||||
ArrayRef<LevelType> lvlTypes) const {
|
||||
for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
|
||||
map.getResult(i).print(printer.getStream());
|
||||
printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
|
||||
@ -635,12 +634,10 @@ void SparseTensorEncodingAttr::printLevels(
|
||||
}
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
SparseTensorEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
ArrayRef<DimLevelType> lvlTypes,
|
||||
AffineMap dimToLvl, AffineMap lvlToDim,
|
||||
unsigned posWidth, unsigned crdWidth,
|
||||
ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
|
||||
LogicalResult SparseTensorEncodingAttr::verify(
|
||||
function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
|
||||
AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
|
||||
unsigned crdWidth, ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
|
||||
if (!acceptBitWidth(posWidth))
|
||||
return emitError() << "unexpected position bitwidth: " << posWidth;
|
||||
if (!acceptBitWidth(crdWidth))
|
||||
@ -652,7 +649,7 @@ SparseTensorEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
return emitError() << "expected compressed or loose_compressed level "
|
||||
"before singleton level";
|
||||
if (!std::all_of(it, lvlTypes.end(),
|
||||
[](DimLevelType i) { return isSingletonLT(i); }))
|
||||
[](LevelType i) { return isSingletonLT(i); }))
|
||||
return emitError() << "expected all singleton lvlTypes "
|
||||
"following a singleton level";
|
||||
}
|
||||
@ -891,7 +888,7 @@ RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
|
||||
bool ordered) {
|
||||
const SparseTensorType src(rtt);
|
||||
const Level lvlRank = src.getLvlRank();
|
||||
SmallVector<DimLevelType> lvlTypes;
|
||||
SmallVector<LevelType> lvlTypes;
|
||||
lvlTypes.reserve(lvlRank);
|
||||
|
||||
// An unordered and non-unique compressed level at beginning.
|
||||
@ -960,7 +957,7 @@ Level mlir::sparse_tensor::toStoredDim(SparseTensorEncodingAttr enc,
|
||||
/// irrelevant fields that do not alter the sparse tensor memory layout.
|
||||
static SparseTensorEncodingAttr
|
||||
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
|
||||
SmallVector<DimLevelType> lts;
|
||||
SmallVector<LevelType> lts;
|
||||
for (auto lt : enc.getLvlTypes())
|
||||
lts.push_back(*buildLevelType(*getLevelFormat(lt), true, true));
|
||||
|
||||
@ -1070,7 +1067,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
|
||||
bool misMatch = false;
|
||||
layout.foreachField([&idx, &misMatch, stt, valTp,
|
||||
lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
|
||||
Level lvl, DimLevelType lt) -> bool {
|
||||
Level lvl, LevelType lt) -> bool {
|
||||
if (fKind == SparseTensorFieldKind::StorageSpec)
|
||||
return true;
|
||||
|
||||
@ -1301,8 +1298,8 @@ void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
|
||||
LogicalResult ReinterpretMapOp::verify() {
|
||||
auto srcStt = getSparseTensorType(getSource());
|
||||
auto dstStt = getSparseTensorType(getDest());
|
||||
ArrayRef<DimLevelType> srcLvlTps = srcStt.getLvlTypes();
|
||||
ArrayRef<DimLevelType> dstLvlTps = dstStt.getLvlTypes();
|
||||
ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes();
|
||||
ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes();
|
||||
|
||||
if (srcLvlTps.size() != dstLvlTps.size())
|
||||
return emitError("Level rank mismatch between source/dest tensors");
|
||||
|
@ -77,10 +77,10 @@ public:
|
||||
const TensorExp &exp(ExprId e) const { return latticeMerger.exp(e); }
|
||||
const LatPoint &lat(LatPointId l) const { return latticeMerger.lat(l); }
|
||||
ArrayRef<LatPointId> set(LatSetId s) const { return latticeMerger.set(s); }
|
||||
DimLevelType lt(TensorId t, LoopId i) const {
|
||||
LevelType lt(TensorId t, LoopId i) const {
|
||||
return latticeMerger.getLvlType(t, i);
|
||||
}
|
||||
DimLevelType lt(TensorLoopId b) const { return latticeMerger.getLvlType(b); }
|
||||
LevelType lt(TensorLoopId b) const { return latticeMerger.getLvlType(b); }
|
||||
|
||||
unsigned getLoopNum() const { return latticeMerger.getNumLoops(); }
|
||||
|
||||
|
@ -428,8 +428,8 @@ inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc,
|
||||
}
|
||||
|
||||
/// Generates a constant of the internal dimension level type encoding.
|
||||
inline Value constantDimLevelTypeEncoding(OpBuilder &builder, Location loc,
|
||||
DimLevelType lt) {
|
||||
inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc,
|
||||
LevelType lt) {
|
||||
return constantI8(builder, loc, static_cast<uint8_t>(lt));
|
||||
}
|
||||
|
||||
|
@ -295,7 +295,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
|
||||
// tensors array (len == numManifestTensor).
|
||||
this->tensors.assign(ts.begin(), ts.end());
|
||||
// Arrays with len == numTensor.
|
||||
this->lvlTypes.assign(numTensors, std::vector<DimLevelType>());
|
||||
this->lvlTypes.assign(numTensors, std::vector<LevelType>());
|
||||
this->lvlSizes.assign(numTensors, std::vector<Value>());
|
||||
this->highs.assign(numTensors, std::vector<Value>());
|
||||
this->segHi.assign(numTensors, std::vector<Value>());
|
||||
@ -330,7 +330,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
|
||||
// to the total number of loops (each level can potentially be mapped to
|
||||
// one of the loop being generated).
|
||||
lvlRank = numLoops;
|
||||
lvlTypes[tid].assign(lvlRank, DimLevelType::Dense);
|
||||
lvlTypes[tid].assign(lvlRank, LevelType::Dense);
|
||||
} else {
|
||||
const Value t = tensors[tid];
|
||||
// a scalar or 0-dimension tensors
|
||||
@ -349,7 +349,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
|
||||
for (auto lvlTp : enc.getLvlTypes())
|
||||
lvlTypes[tid].push_back(lvlTp);
|
||||
} else {
|
||||
lvlTypes[tid].assign(lvlRank, DimLevelType::Dense);
|
||||
lvlTypes[tid].assign(lvlRank, LevelType::Dense);
|
||||
}
|
||||
}
|
||||
|
||||
@ -2072,7 +2072,7 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
|
||||
|
||||
// Only when the level is sorted, the next-non-empty slice can be computed
|
||||
// efficiently.
|
||||
const DimLevelType lvlType = lvlTypes[tid][lvl];
|
||||
const LevelType lvlType = lvlTypes[tid][lvl];
|
||||
assert(isOrderedLT(lvlType));
|
||||
if (isSingletonLT(lvlType)) {
|
||||
llvm_unreachable("TODO: dense level should be easy to support, while "
|
||||
|
@ -676,7 +676,7 @@ private:
|
||||
/// Input and (optional) output tensors.
|
||||
std::vector<Value> tensors;
|
||||
/// Level-types for each `(TensorId, Level)` pair.
|
||||
std::vector<std::vector<DimLevelType>> lvlTypes;
|
||||
std::vector<std::vector<LevelType>> lvlTypes;
|
||||
// Sparse iteration information for each `(TensorId, Level)` pair.
|
||||
// These arrays are updated to remain current within the current loop.
|
||||
// TODO: Clarify which of these are indexed by dstLvl vs srcLvl.
|
||||
|
@ -216,7 +216,7 @@ static void createAllocFields(OpBuilder &builder, Location loc,
|
||||
stt,
|
||||
[&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
|
||||
enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
|
||||
Level /*lvl*/, DimLevelType /*lt*/) -> bool {
|
||||
Level /*lvl*/, LevelType /*lt*/) -> bool {
|
||||
assert(fields.size() == fIdx);
|
||||
Value field;
|
||||
switch (fKind) {
|
||||
@ -1155,7 +1155,7 @@ public:
|
||||
SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
|
||||
[&rewriter, &fields, srcDesc,
|
||||
loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
|
||||
DimLevelType /*lt*/) -> bool {
|
||||
LevelType /*lt*/) -> bool {
|
||||
// Simply reuses the storage specifier as it is an SSA value.
|
||||
if (fKind == SparseTensorFieldKind::StorageSpec) {
|
||||
fields.push_back(srcDesc.getSpecifier());
|
||||
@ -1284,7 +1284,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
|
||||
stt,
|
||||
[&rewriter, &fields, &op, &stt,
|
||||
loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
|
||||
Level /*lvl*/, DimLevelType lt) -> bool {
|
||||
Level /*lvl*/, LevelType lt) -> bool {
|
||||
assert(fields.size() == fIdx);
|
||||
if (fKind == SparseTensorFieldKind::StorageSpec) {
|
||||
fields.push_back(
|
||||
@ -1333,7 +1333,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
|
||||
continue;
|
||||
|
||||
// Sets up the memory size by reading the last value in position array.
|
||||
DimLevelType lt = stt.getLvlType(lvl);
|
||||
LevelType lt = stt.getLvlType(lvl);
|
||||
// Simply forwards the position index when this is a dense level.
|
||||
if (isDenseLT(lt)) {
|
||||
memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
|
||||
@ -1387,10 +1387,10 @@ struct SparseDisassembleOpConverter
|
||||
Location loc = op.getLoc();
|
||||
SmallVector<Value> retMem;
|
||||
SmallVector<Value> retLen;
|
||||
desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem, &retLen](
|
||||
FieldIndex fid,
|
||||
SparseTensorFieldKind fKind, Level lvl,
|
||||
DimLevelType lt) -> bool {
|
||||
desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem,
|
||||
&retLen](FieldIndex fid,
|
||||
SparseTensorFieldKind fKind,
|
||||
Level lvl, LevelType lt) -> bool {
|
||||
if (fKind == SparseTensorFieldKind::StorageSpec)
|
||||
return true;
|
||||
SparseTensorType stt(desc.getRankedTensorType());
|
||||
|
@ -146,7 +146,7 @@ static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
|
||||
SmallVector<Value> lvlTypes;
|
||||
lvlTypes.reserve(stt.getLvlRank());
|
||||
for (const auto lt : stt.getEncoding().getLvlTypes())
|
||||
lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, lt));
|
||||
lvlTypes.push_back(constantLevelTypeEncoding(builder, loc, lt));
|
||||
return allocaBuffer(builder, loc, lvlTypes);
|
||||
}
|
||||
|
||||
|
@ -42,7 +42,7 @@ convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
|
||||
stt,
|
||||
[&fields](Type fieldType, FieldIndex fieldIdx,
|
||||
SparseTensorFieldKind /*fieldKind*/, Level /*lvl*/,
|
||||
DimLevelType /*lt*/) -> bool {
|
||||
LevelType /*lt*/) -> bool {
|
||||
assert(fieldIdx == fields.size());
|
||||
fields.push_back(fieldType);
|
||||
return true;
|
||||
|
@ -45,9 +45,8 @@ static bool isZeroValue(Value val) {
|
||||
// Helper to detect a sparse tensor type operand.
|
||||
static bool isSparseTensor(Value v) {
|
||||
auto enc = getSparseTensorEncoding(v.getType());
|
||||
return enc && !llvm::all_of(enc.getLvlTypes(), [](auto lt) {
|
||||
return lt == DimLevelType::Dense;
|
||||
});
|
||||
return enc && !llvm::all_of(enc.getLvlTypes(),
|
||||
[](auto lt) { return lt == LevelType::Dense; });
|
||||
}
|
||||
static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
|
||||
|
||||
|
@ -79,7 +79,7 @@ static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
|
||||
/// same index is used more than once. Also rejects compound affine
|
||||
/// expressions in sparse dimensions.
|
||||
static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
|
||||
DimLevelType lt, bool setLvlFormat = true) {
|
||||
LevelType lt, bool setLvlFormat = true) {
|
||||
switch (a.getKind()) {
|
||||
case AffineExprKind::DimId: {
|
||||
const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
|
||||
@ -125,7 +125,7 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
|
||||
///
|
||||
/// TODO: constant should be easy to handle.
|
||||
static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
|
||||
AffineExpr a, DimLevelType lt, bool isSubExp = false,
|
||||
AffineExpr a, LevelType lt, bool isSubExp = false,
|
||||
int64_t coefficient = 1) {
|
||||
switch (a.getKind()) {
|
||||
case AffineExprKind::DimId: {
|
||||
@ -275,7 +275,7 @@ static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
|
||||
// to be sliced.
|
||||
for (Level l = 0; l < lvlRank; l++) {
|
||||
const AffineExpr a = map.getResult(l);
|
||||
const DimLevelType lt = enc.getLvlType(l);
|
||||
const LevelType lt = enc.getLvlType(l);
|
||||
if (idxReducBased && needIdxReduc) {
|
||||
if (!findDepIdxSet(env.merger(), tid, l, a, lt))
|
||||
return false; // inadmissible affine expression
|
||||
@ -883,8 +883,8 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
|
||||
Value cond;
|
||||
env.merger().foreachTensorLoopId(
|
||||
p, /*simple=*/true,
|
||||
[&](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
|
||||
DimLevelType lt, bool isIdxRed) {
|
||||
[&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt,
|
||||
bool isIdxRed) {
|
||||
if (isIdxRed) {
|
||||
// Since there is no 1:1 mapping from loop to level (multiple loops
|
||||
// are required to resolve one level with non-trivial index
|
||||
@ -970,7 +970,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
|
||||
SmallVector<TensorLevel> tidLvls;
|
||||
env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
|
||||
std::optional<Level> lvl,
|
||||
DimLevelType lt, bool isIdxReduc) {
|
||||
LevelType lt, bool isIdxReduc) {
|
||||
assert(env.merger().loop(b) == idx);
|
||||
if (isDenseLT(lt) || isUndefLT(lt)) {
|
||||
if (tid == env.merger().getSynTensorID()) {
|
||||
@ -1048,89 +1048,89 @@ static bool translateBitsToTidLvlPairs(
|
||||
|
||||
unsigned numloopCond = 0;
|
||||
bool hasNonUnique = false;
|
||||
env.merger().foreachTensorLoopId(
|
||||
li, [&, ldx](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
|
||||
DimLevelType lt, bool isIdxReduc) {
|
||||
if (simple[b]) {
|
||||
if (isIdxReduc) {
|
||||
tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
|
||||
numloopCond++;
|
||||
return;
|
||||
}
|
||||
if (isUndefLT(lt)) {
|
||||
// An undefined lt in the lattices, we probably mean to
|
||||
// iterate based on the level of output tensor. E.g., this
|
||||
// could be a synthetic tensor (for invariants and sparse
|
||||
// output tensor).
|
||||
auto itType = env.op().getIteratorTypesArray()[ldx];
|
||||
if (linalg::isReductionIterator(itType) &&
|
||||
env.merger().getSynTensorID() == tid) {
|
||||
// Coiterating with an invariant, and this is a reduction loop
|
||||
// e.g., out = prod(in[i][j] op invariant);
|
||||
// In this case, we can not infer the loop bound from output
|
||||
// (whose level is reduced). Instead we use the synthetic tensor
|
||||
// to infer the bound.
|
||||
// The level of the synthetic tensor is the current loop depth;
|
||||
// the rank of the synthetic tensor equals to number of loops.
|
||||
lvl = env.emitter().getCurrentDepth();
|
||||
} else {
|
||||
// or a broadcast
|
||||
// out[i][j] = in[i] (j is undef for input)
|
||||
tid = outTid;
|
||||
lvl = outLvl;
|
||||
// Skips invalid lvl (e.g., when this is a zero ranked tensor).
|
||||
if (!lvl)
|
||||
return;
|
||||
}
|
||||
}
|
||||
hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
|
||||
tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
|
||||
numloopCond++;
|
||||
} else if (isDenseLT(lt) || isIdxReduc) {
|
||||
tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
|
||||
env.merger().foreachTensorLoopId(li, [&, ldx](TensorLoopId b, TensorId tid,
|
||||
std::optional<Level> lvl,
|
||||
LevelType lt, bool isIdxReduc) {
|
||||
if (simple[b]) {
|
||||
if (isIdxReduc) {
|
||||
tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
|
||||
numloopCond++;
|
||||
return;
|
||||
}
|
||||
if (isUndefLT(lt)) {
|
||||
// An undefined lt in the lattices, we probably mean to
|
||||
// iterate based on the level of output tensor. E.g., this
|
||||
// could be a synthetic tensor (for invariants and sparse
|
||||
// output tensor).
|
||||
auto itType = env.op().getIteratorTypesArray()[ldx];
|
||||
if (linalg::isReductionIterator(itType) &&
|
||||
env.merger().getSynTensorID() == tid) {
|
||||
// Coiterating with an invariant, and this is a reduction loop
|
||||
// e.g., out = prod(in[i][j] op invariant);
|
||||
// In this case, we can not infer the loop bound from output
|
||||
// (whose level is reduced). Instead we use the synthetic tensor
|
||||
// to infer the bound.
|
||||
// The level of the synthetic tensor is the current loop depth;
|
||||
// the rank of the synthetic tensor equals to number of loops.
|
||||
lvl = env.emitter().getCurrentDepth();
|
||||
} else {
|
||||
assert(isUndefLT(lt));
|
||||
linalg::GenericOp op = env.op();
|
||||
if (tid >= op.getNumDpsInputs())
|
||||
// We only handle affine expression on input tensors (for now).
|
||||
return;
|
||||
OpOperand *operand = &op->getOpOperand(tid);
|
||||
const auto stt = getSparseTensorType(operand->get());
|
||||
// Non-annotated dense tensors requires no special handling.
|
||||
if (!stt.hasEncoding())
|
||||
// or a broadcast
|
||||
// out[i][j] = in[i] (j is undef for input)
|
||||
tid = outTid;
|
||||
lvl = outLvl;
|
||||
// Skips invalid lvl (e.g., when this is a zero ranked tensor).
|
||||
if (!lvl)
|
||||
return;
|
||||
}
|
||||
}
|
||||
hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
|
||||
tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
|
||||
numloopCond++;
|
||||
} else if (isDenseLT(lt) || isIdxReduc) {
|
||||
tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
|
||||
} else {
|
||||
assert(isUndefLT(lt));
|
||||
linalg::GenericOp op = env.op();
|
||||
if (tid >= op.getNumDpsInputs())
|
||||
// We only handle affine expression on input tensors (for now).
|
||||
return;
|
||||
OpOperand *operand = &op->getOpOperand(tid);
|
||||
const auto stt = getSparseTensorType(operand->get());
|
||||
// Non-annotated dense tensors requires no special handling.
|
||||
if (!stt.hasEncoding())
|
||||
return;
|
||||
|
||||
ArrayRef<AffineExpr> affines =
|
||||
op.getMatchingIndexingMap(operand).getResults();
|
||||
const Level lvlRank = stt.getLvlRank();
|
||||
assert(affines.size() == static_cast<size_t>(lvlRank));
|
||||
for (Level l = 0; l < lvlRank; l++) {
|
||||
AffineExpr exp = affines[l];
|
||||
// Skip simple affine expression and non-dense levels (which
|
||||
// have their own filter loop).
|
||||
if (isa<AffineDimExpr>(exp) || !stt.isDenseLvl(l))
|
||||
continue;
|
||||
ArrayRef<AffineExpr> affines =
|
||||
op.getMatchingIndexingMap(operand).getResults();
|
||||
const Level lvlRank = stt.getLvlRank();
|
||||
assert(affines.size() == static_cast<size_t>(lvlRank));
|
||||
for (Level l = 0; l < lvlRank; l++) {
|
||||
AffineExpr exp = affines[l];
|
||||
// Skip simple affine expression and non-dense levels (which
|
||||
// have their own filter loop).
|
||||
if (isa<AffineDimExpr>(exp) || !stt.isDenseLvl(l))
|
||||
continue;
|
||||
|
||||
// Constant affine expression are handled in genLoop
|
||||
if (!isa<AffineConstantExpr>(exp)) {
|
||||
bool isAtLoop = false;
|
||||
if (isInvariantAffine(exp, env.getLoopDepth(), ldx, isAtLoop) &&
|
||||
isAtLoop) {
|
||||
// If the compound affine is invariant and we are right at the
|
||||
// level. We need to generate the address according to the
|
||||
// affine expression. This is also the best place we can do it
|
||||
// to avoid putting it inside inner loops.
|
||||
// NOTE: It assumes that the levels of the input tensor are
|
||||
// initialized in order (and it is also currently guaranteed by
|
||||
// computeIterationGraph), another more admissible approach
|
||||
// might be accepting out-of-order access between consecutive
|
||||
// dense levels.
|
||||
affineTidLvls.emplace_back(env.makeTensorLevel(tid, l), exp);
|
||||
}
|
||||
}
|
||||
// Constant affine expression are handled in genLoop
|
||||
if (!isa<AffineConstantExpr>(exp)) {
|
||||
bool isAtLoop = false;
|
||||
if (isInvariantAffine(exp, env.getLoopDepth(), ldx, isAtLoop) &&
|
||||
isAtLoop) {
|
||||
// If the compound affine is invariant and we are right at the
|
||||
// level. We need to generate the address according to the
|
||||
// affine expression. This is also the best place we can do it
|
||||
// to avoid putting it inside inner loops.
|
||||
// NOTE: It assumes that the levels of the input tensor are
|
||||
// initialized in order (and it is also currently guaranteed by
|
||||
// computeIterationGraph), another more admissible approach
|
||||
// might be accepting out-of-order access between consecutive
|
||||
// dense levels.
|
||||
affineTidLvls.emplace_back(env.makeTensorLevel(tid, l), exp);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (isDenseLT(env.lt(outTid, ldx))) {
|
||||
// Note that we generate dense indices of the output tensor
|
||||
|
@ -226,8 +226,7 @@ Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,
|
||||
syntheticTensor(numInputOutputTensors),
|
||||
numTensors(numInputOutputTensors + 1), numLoops(numLoops),
|
||||
hasSparseOut(false),
|
||||
lvlTypes(numTensors,
|
||||
std::vector<DimLevelType>(numLoops, DimLevelType::Undef)),
|
||||
lvlTypes(numTensors, std::vector<LevelType>(numLoops, LevelType::Undef)),
|
||||
loopToLvl(numTensors,
|
||||
std::vector<std::optional<Level>>(numLoops, std::nullopt)),
|
||||
lvlToLoop(numTensors,
|
||||
|
@ -19,7 +19,7 @@ using namespace mlir::sparse_tensor;
|
||||
|
||||
SparseTensorStorageBase::SparseTensorStorageBase( // NOLINT
|
||||
uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
|
||||
const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
|
||||
const uint64_t *lvlSizes, const LevelType *lvlTypes,
|
||||
const uint64_t *dim2lvl, const uint64_t *lvl2dim)
|
||||
: dimSizes(dimSizes, dimSizes + dimRank),
|
||||
lvlSizes(lvlSizes, lvlSizes + lvlRank),
|
||||
|
@ -173,7 +173,7 @@ static_assert(std::is_same<index_type, uint64_t>::value,
|
||||
void *_mlir_ciface_newSparseTensor( // NOLINT
|
||||
StridedMemRefType<index_type, 1> *dimSizesRef,
|
||||
StridedMemRefType<index_type, 1> *lvlSizesRef,
|
||||
StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
|
||||
StridedMemRefType<LevelType, 1> *lvlTypesRef,
|
||||
StridedMemRefType<index_type, 1> *dim2lvlRef,
|
||||
StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp,
|
||||
OverheadType crdTp, PrimaryType valTp, Action action, void *ptr) {
|
||||
@ -189,7 +189,7 @@ void *_mlir_ciface_newSparseTensor( // NOLINT
|
||||
ASSERT_USIZE_EQ(lvl2dimRef, dimRank);
|
||||
const index_type *dimSizes = MEMREF_GET_PAYLOAD(dimSizesRef);
|
||||
const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef);
|
||||
const DimLevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef);
|
||||
const LevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef);
|
||||
const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);
|
||||
const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);
|
||||
|
||||
|
@ -43,8 +43,8 @@ static int testRoundtripEncoding(MlirContext ctx) {
|
||||
MlirAffineMap lvlToDim =
|
||||
mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
|
||||
int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
|
||||
enum MlirSparseTensorDimLevelType *lvlTypes =
|
||||
malloc(sizeof(enum MlirSparseTensorDimLevelType) * lvlRank);
|
||||
enum MlirSparseTensorLevelType *lvlTypes =
|
||||
malloc(sizeof(enum MlirSparseTensorLevelType) * lvlRank);
|
||||
for (int l = 0; l < lvlRank; ++l) {
|
||||
lvlTypes[l] = mlirSparseTensorEncodingAttrGetLvlType(originalAttr, l);
|
||||
fprintf(stderr, "level_type: %d\n", lvlTypes[l]);
|
||||
|
@ -53,7 +53,7 @@ module {
|
||||
}
|
||||
|
||||
//
|
||||
// The first test suite (for non-singleton DimLevelTypes).
|
||||
// The first test suite (for non-singleton LevelTypes).
|
||||
//
|
||||
func.func @entry() {
|
||||
//
|
||||
|
@ -72,7 +72,7 @@ module {
|
||||
}
|
||||
|
||||
//
|
||||
// The first test suite (for non-singleton DimLevelTypes).
|
||||
// The first test suite (for non-singleton LevelTypes).
|
||||
//
|
||||
func.func @testNonSingleton() {
|
||||
//
|
||||
@ -125,7 +125,7 @@ module {
|
||||
}
|
||||
|
||||
//
|
||||
// The second test suite (for singleton DimLevelTypes).
|
||||
// The second test suite (for singleton LevelTypes).
|
||||
//
|
||||
func.func @testSingleton() {
|
||||
//
|
||||
|
@ -140,11 +140,11 @@ def main():
|
||||
# straightforward to adapt the code below to explore more combinations.
|
||||
# For these simple orderings, dim2lvl and lvl2dim are the same.
|
||||
levels = [
|
||||
[st.DimLevelType.compressed_nu, st.DimLevelType.singleton],
|
||||
[st.DimLevelType.dense, st.DimLevelType.dense],
|
||||
[st.DimLevelType.dense, st.DimLevelType.compressed],
|
||||
[st.DimLevelType.compressed, st.DimLevelType.dense],
|
||||
[st.DimLevelType.compressed, st.DimLevelType.compressed],
|
||||
[st.LevelType.compressed_nu, st.LevelType.singleton],
|
||||
[st.LevelType.dense, st.LevelType.dense],
|
||||
[st.LevelType.dense, st.LevelType.compressed],
|
||||
[st.LevelType.compressed, st.LevelType.dense],
|
||||
[st.LevelType.compressed, st.LevelType.compressed],
|
||||
]
|
||||
orderings = [
|
||||
ir.AffineMap.get_permutation([0, 1]),
|
||||
|
@ -126,11 +126,11 @@ def main():
|
||||
e = False
|
||||
opt = f"parallelization-strategy=none"
|
||||
levels = [
|
||||
[st.DimLevelType.compressed_nu, st.DimLevelType.singleton],
|
||||
[st.DimLevelType.dense, st.DimLevelType.dense],
|
||||
[st.DimLevelType.dense, st.DimLevelType.compressed],
|
||||
[st.DimLevelType.compressed, st.DimLevelType.dense],
|
||||
[st.DimLevelType.compressed, st.DimLevelType.compressed],
|
||||
[st.LevelType.compressed_nu, st.LevelType.singleton],
|
||||
[st.LevelType.dense, st.LevelType.dense],
|
||||
[st.LevelType.dense, st.LevelType.compressed],
|
||||
[st.LevelType.compressed, st.LevelType.dense],
|
||||
[st.LevelType.compressed, st.LevelType.compressed],
|
||||
]
|
||||
orderings = [
|
||||
ir.AffineMap.get_permutation([0, 1]),
|
||||
|
@ -125,10 +125,10 @@ def main():
|
||||
# regular and loose compression and various metadata bitwidths.
|
||||
# For these simple orderings, dim2lvl and lvl2dim are the same.
|
||||
levels = [
|
||||
[st.DimLevelType.compressed_nu, st.DimLevelType.singleton],
|
||||
[st.DimLevelType.dense, st.DimLevelType.compressed],
|
||||
[st.DimLevelType.dense, st.DimLevelType.loose_compressed],
|
||||
[st.DimLevelType.compressed, st.DimLevelType.compressed],
|
||||
[st.LevelType.compressed_nu, st.LevelType.singleton],
|
||||
[st.LevelType.dense, st.LevelType.compressed],
|
||||
[st.LevelType.dense, st.LevelType.loose_compressed],
|
||||
[st.LevelType.compressed, st.LevelType.compressed],
|
||||
]
|
||||
orderings = [
|
||||
(ir.AffineMap.get_permutation([0, 1]), 0),
|
||||
@ -149,10 +149,10 @@ def main():
|
||||
|
||||
# Now do the same for BSR.
|
||||
level = [
|
||||
st.DimLevelType.dense,
|
||||
st.DimLevelType.compressed,
|
||||
st.DimLevelType.dense,
|
||||
st.DimLevelType.dense,
|
||||
st.LevelType.dense,
|
||||
st.LevelType.compressed,
|
||||
st.LevelType.dense,
|
||||
st.LevelType.dense,
|
||||
]
|
||||
d0 = ir.AffineDimExpr.get(0)
|
||||
d1 = ir.AffineDimExpr.get(1)
|
||||
|
@ -25,6 +25,7 @@ from tools import sparsifier
|
||||
|
||||
# ===----------------------------------------------------------------------=== #
|
||||
|
||||
|
||||
class TypeConverter:
|
||||
"""Converter between NumPy types and MLIR types."""
|
||||
|
||||
@ -204,9 +205,7 @@ def main():
|
||||
# All combinations.
|
||||
levels = list(
|
||||
itertools.product(
|
||||
*itertools.repeat(
|
||||
[st.DimLevelType.dense, st.DimLevelType.compressed], rank
|
||||
)
|
||||
*itertools.repeat([st.LevelType.dense, st.LevelType.compressed], rank)
|
||||
)
|
||||
)
|
||||
# All permutations.
|
||||
|
@ -28,7 +28,7 @@ def testEncodingAttr1D():
|
||||
# CHECK: equal: True
|
||||
print(f"equal: {casted == parsed}")
|
||||
|
||||
# CHECK: lvl_types: [<DimLevelType.compressed: 8>]
|
||||
# CHECK: lvl_types: [<LevelType.compressed: 8>]
|
||||
print(f"lvl_types: {casted.lvl_types}")
|
||||
# CHECK: dim_to_lvl: (d0) -> (d0)
|
||||
print(f"dim_to_lvl: {casted.dim_to_lvl}")
|
||||
@ -70,7 +70,7 @@ def testEncodingAttr2D():
|
||||
# CHECK: equal: True
|
||||
print(f"equal: {casted == parsed}")
|
||||
|
||||
# CHECK: lvl_types: [<DimLevelType.dense: 4>, <DimLevelType.compressed: 8>]
|
||||
# CHECK: lvl_types: [<LevelType.dense: 4>, <LevelType.compressed: 8>]
|
||||
print(f"lvl_types: {casted.lvl_types}")
|
||||
# CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
|
||||
print(f"dim_to_lvl: {casted.dim_to_lvl}")
|
||||
|
@ -313,11 +313,11 @@ protected:
|
||||
MergerTest3T1L() : MergerTestBase(3, 1) {
|
||||
EXPECT_TRUE(merger.getOutTensorID() == tid(2));
|
||||
// Tensor 0: sparse input vector.
|
||||
merger.setLevelAndType(tid(0), lid(0), 0, DimLevelType::Compressed);
|
||||
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed);
|
||||
// Tensor 1: sparse input vector.
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, DimLevelType::Compressed);
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Compressed);
|
||||
// Tensor 2: dense output vector.
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, DimLevelType::Dense);
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Dense);
|
||||
}
|
||||
};
|
||||
|
||||
@ -327,13 +327,13 @@ protected:
|
||||
MergerTest4T1L() : MergerTestBase(4, 1) {
|
||||
EXPECT_TRUE(merger.getOutTensorID() == tid(3));
|
||||
// Tensor 0: sparse input vector.
|
||||
merger.setLevelAndType(tid(0), lid(0), 0, DimLevelType::Compressed);
|
||||
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed);
|
||||
// Tensor 1: sparse input vector.
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, DimLevelType::Compressed);
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Compressed);
|
||||
// Tensor 2: sparse input vector
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, DimLevelType::Compressed);
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Compressed);
|
||||
// Tensor 3: dense output vector
|
||||
merger.setLevelAndType(tid(3), lid(0), 0, DimLevelType::Dense);
|
||||
merger.setLevelAndType(tid(3), lid(0), 0, LevelType::Dense);
|
||||
}
|
||||
};
|
||||
|
||||
@ -347,11 +347,11 @@ protected:
|
||||
MergerTest3T1LD() : MergerTestBase(3, 1) {
|
||||
EXPECT_TRUE(merger.getOutTensorID() == tid(2));
|
||||
// Tensor 0: sparse input vector.
|
||||
merger.setLevelAndType(tid(0), lid(0), 0, DimLevelType::Compressed);
|
||||
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed);
|
||||
// Tensor 1: dense input vector.
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, DimLevelType::Dense);
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Dense);
|
||||
// Tensor 2: dense output vector.
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, DimLevelType::Dense);
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Dense);
|
||||
}
|
||||
};
|
||||
|
||||
@ -365,13 +365,13 @@ protected:
|
||||
MergerTest4T1LU() : MergerTestBase(4, 1) {
|
||||
EXPECT_TRUE(merger.getOutTensorID() == tid(3));
|
||||
// Tensor 0: undef input vector.
|
||||
merger.setLevelAndType(tid(0), lid(0), 0, DimLevelType::Undef);
|
||||
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Undef);
|
||||
// Tensor 1: dense input vector.
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, DimLevelType::Dense);
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Dense);
|
||||
// Tensor 2: undef input vector.
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, DimLevelType::Undef);
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Undef);
|
||||
// Tensor 3: dense output vector.
|
||||
merger.setLevelAndType(tid(3), lid(0), 0, DimLevelType::Dense);
|
||||
merger.setLevelAndType(tid(3), lid(0), 0, LevelType::Dense);
|
||||
}
|
||||
};
|
||||
|
||||
@ -387,11 +387,11 @@ protected:
|
||||
EXPECT_TRUE(merger.getSynTensorID() == tid(3));
|
||||
merger.setHasSparseOut(true);
|
||||
// Tensor 0: undef input vector.
|
||||
merger.setLevelAndType(tid(0), lid(0), 0, DimLevelType::Undef);
|
||||
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Undef);
|
||||
// Tensor 1: undef input vector.
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, DimLevelType::Undef);
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Undef);
|
||||
// Tensor 2: sparse output vector.
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, DimLevelType::Compressed);
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Compressed);
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user