[mlir][sparse] add full dimension ordering support

This revision completes the "dimension ordering" feature
of sparse tensor types that enables the programmer to
define a preferred order on dimension access (other than
the default left-to-right order). This enables e.g. selection
of column-major over row-major storage for sparse matrices,
but generalized to any rank, as in:

dimOrdering = affine_map<(i,j,k,l,m,n,o,p) -> (p,o,j,k,i,l,m,n)>

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D102856
This commit is contained in:
Aart Bik 2021-05-21 11:52:34 -07:00
parent bbdabb044d
commit c194b49c9c
9 changed files with 449 additions and 141 deletions

View File

@ -351,8 +351,8 @@ extern "C" MLIR_CRUNNERUTILS_EXPORT double rtclock();
//===----------------------------------------------------------------------===//
// Small runtime support library for sparse tensors.
//===----------------------------------------------------------------------===//
extern "C" MLIR_CRUNNERUTILS_EXPORT void *openTensorC(char *filename,
uint64_t *idata);
extern "C" MLIR_CRUNNERUTILS_EXPORT void *
openTensorC(char *filename, uint64_t *idata, uint64_t *perm);
extern "C" MLIR_CRUNNERUTILS_EXPORT void
readTensorItemC(void *tensor, uint64_t *idata, double *ddata);
extern "C" MLIR_CRUNNERUTILS_EXPORT void closeTensor(void *tensor);

View File

@ -54,6 +54,20 @@ getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
}
}
/// Returns integers of given width and values as a constant tensor.
/// We cast the static shape into a dynamic shape to ensure that the
/// method signature remains uniform accross different tensor dimensions.
static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width,
Location loc, ArrayRef<APInt> values) {
Type etp = rewriter.getIntegerType(width);
unsigned sz = values.size();
RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp);
auto elts =
rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, values));
return rewriter.create<tensor::CastOp>(loc, tt2, elts);
}
/// Returns function reference (first hit also inserts into module).
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
ValueRange operands) {
@ -117,22 +131,29 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
return failure();
// User pointer.
params.push_back(operands[0]);
// Sparsity annotations in tensor constant form. Note that we cast
// the static shape into a dynamic shape to ensure that the method
// signature remains uniform accross different tensor dimensions.
// Sparsity annotations in tensor constant form.
SmallVector<APInt, 4> attrs;
unsigned sz = enc.getDimLevelType().size();
for (unsigned i = 0; i < sz; i++)
attrs.push_back(
APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
Type etp = rewriter.getIntegerType(8);
RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
RankedTensorType tt2 =
RankedTensorType::get({ShapedType::kDynamicSize}, etp);
auto elts =
rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, attrs));
params.push_back(rewriter.create<tensor::CastOp>(loc, tt2, elts));
// Seconary and primary types encoding.
params.push_back(getTensor(rewriter, 8, loc, attrs));
// Dimension order permutation array. This is the "identity"
// permutation by default, or otherwise the "reverse" permutation
// of a given ordering, so that indices can be mapped quickly
// to the right position.
SmallVector<APInt, 4> perm(sz);
AffineMap p = enc.getDimOrdering();
if (p) {
assert(p.isPermutation() && p.getNumResults() == sz);
for (unsigned i = 0; i < sz; i++)
perm[p.getDimPosition(i)] = APInt(64, i);
} else {
for (unsigned i = 0; i < sz; i++)
perm[i] = APInt(64, i);
}
params.push_back(getTensor(rewriter, 64, loc, perm));
// Secondary and primary types encoding.
unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
unsigned primary;

View File

@ -333,6 +333,18 @@ struct CodeGen {
} // namespace
// Helper method to apply dimension ordering permutation.
static unsigned perm(SparseTensorEncodingAttr &enc, unsigned d) {
if (enc) {
auto order = enc.getDimOrdering();
if (order) {
assert(order.isPermutation());
return order.getDimPosition(d);
}
}
return d;
}
// Helper method to translate dim level type to internal representation.
static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) {
if (enc) {
@ -353,17 +365,17 @@ static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
unsigned lhs = numTensors - 1;
for (unsigned t = 0; t < numTensors; t++) {
auto map = op.getIndexingMap(t);
unsigned rank = op.getShapedType(t).getRank();
if (!map.isProjectedPermutation())
return false;
auto enc = getSparseTensorEncoding(op.getShapedType(t));
if (enc) {
annotated = true;
if (enc.getDimOrdering() && !enc.getDimOrdering().isIdentity())
return false; // TODO: handle permutations
if (t == lhs)
return false; // TODO: handle sparse outputs
}
for (unsigned d = 0; d < rank; d++) {
unsigned idx = map.getDimPosition(d);
assert(map.getNumResults() == op.getShapedType(t).getRank());
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
unsigned idx = map.getDimPosition(perm(enc, d));
merger.setDim(t, idx, toDim(enc, d));
}
}
@ -405,18 +417,18 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
unsigned numTensors = op.getNumShapedOperands();
for (unsigned t = 0; t < numTensors; t++) {
auto map = op.getIndexingMap(t);
auto enc = getSparseTensorEncoding(op.getShapedType(t));
assert(map.getNumDims() == n);
// Skip dense tensor constraints when sparse only is requested.
if (sparseOnly && !getSparseTensorEncoding(op.getShapedType(t)))
if (sparseOnly && !enc)
continue;
// At the moment, we take the index variables in the tensor access
// expression in the order in which they appear (conceptually a
// "row-major" layout of every tensor). So, a tensor access A_ijk
// forces the ordering i < j < k on the loop indices.
// TODO: support affine map to define alternative dimension orders.
for (unsigned d = 1, e = map.getNumResults(); d < e; d++) {
unsigned f = map.getDimPosition(d - 1);
unsigned t = map.getDimPosition(d);
// Each tensor expression and optional dimension ordering (row-major
// by default) puts an ordering constraint on the loop indices. For
// example, the tensor expresion A_ijk forces the ordering i < j < k
// on the loop indices if no explicit dimension ordering is given.
for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
unsigned f = map.getDimPosition(perm(enc, d - 1));
unsigned t = map.getDimPosition(perm(enc, d));
adjM[f][t] = true;
}
}
@ -441,15 +453,10 @@ static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
Value val) {
if (auto arg = val.dyn_cast<BlockArgument>()) {
unsigned argN = arg.getArgNumber();
if (arg.getOwner()->getParentOp() == op) {
// Any parameter of the generic op is considered a tensor,
// indexed by the implicit loop bounds.
auto map = op.getIndexingMap(argN);
if (map.isProjectedPermutation())
return merger.addExp(Kind::kTensor, argN);
// Cannot handle (yet).
return None;
}
// Any parameter of the generic op is considered a tensor,
// indexed by the implicit loop bounds.
if (arg.getOwner()->getParentOp() == op)
return merger.addExp(Kind::kTensor, argN);
// Any parameter of a higher op is invariant.
return merger.addExp(Kind::kInvariant, val);
}
@ -568,10 +575,10 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
auto enc = getSparseTensorEncoding(tensorType);
// Scan all dimensions of current tensor.
args.clear();
for (unsigned d = 0, rank = shape.size(); d < rank; d++) {
unsigned i = map.getDimPosition(d);
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
unsigned idx = map.getDimPosition(perm(enc, d));
// Handle sparse storage schemes.
if (merger.isDim(t, i, Dim::kSparse)) {
if (merger.isDim(t, idx, Dim::kSparse)) {
auto dynShape = {ShapedType::kDynamicSize};
auto ptrTp = MemRefType::get(
dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
@ -579,9 +586,9 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
Value dim = rewriter.create<ConstantIndexOp>(loc, d);
// Generate sparse primitives to obtains pointer and indices.
codegen.pointers[t][i] =
codegen.pointers[t][idx] =
rewriter.create<ToPointersOp>(loc, ptrTp, tensor, dim);
codegen.indices[t][i] =
codegen.indices[t][idx] =
rewriter.create<ToIndicesOp>(loc, indTp, tensor, dim);
}
// Find lower and upper bound in current dimension.
@ -592,7 +599,7 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
} else {
up = rewriter.create<ConstantIndexOp>(loc, shape[d]);
}
codegen.sizes[i] = codegen.highs[t][i] = up;
codegen.sizes[idx] = codegen.highs[t][idx] = up;
}
// Perform the required bufferization. All dense inputs materialize
// from the input tensor. The dense output tensor needs special
@ -705,8 +712,8 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen,
unsigned tensor = merger.exp(exp).e0;
auto map = op.getIndexingMap(tensor);
auto enc = getSparseTensorEncoding(op.getShapedType(tensor));
for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
unsigned idx = map.getDimPosition(i);
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
unsigned idx = map.getDimPosition(perm(enc, d));
args.push_back(codegen.loops[idx]); // universal dense index
if (enc) {
args.clear();
@ -737,8 +744,9 @@ static void genTensorStore(Merger &merger, CodeGen &codegen,
// Actual store.
SmallVector<Value, 4> args;
auto map = op.getIndexingMap(tensor);
for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
unsigned idx = map.getDimPosition(i);
assert(!getSparseTensorEncoding(op.getShapedType(tensor)));
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
unsigned idx = map.getDimPosition(d);
args.push_back(codegen.loops[idx]); // universal dense index
}
Value ptr = codegen.buffers[tensor];
@ -888,8 +896,9 @@ static void genInvariants(Merger &merger, CodeGen &codegen,
bool atLevel = ldx == -1u;
unsigned tensor = merger.exp(exp).e0;
auto map = op.getIndexingMap(tensor);
for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
unsigned idx = map.getDimPosition(i);
auto enc = getSparseTensorEncoding(op.getShapedType(tensor));
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
unsigned idx = map.getDimPosition(perm(enc, d));
if (!codegen.loops[idx])
return; // still in play
else if (idx == ldx)
@ -1001,9 +1010,8 @@ static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
for (unsigned t = 0; t < numTensors; t++) {
if (!getSparseTensorEncoding(op.getShapedType(t))) {
auto map = op.getIndexingMap(t);
unsigned r = map.getNumResults();
for (unsigned i = 0; i < r; i++) {
if (map.getDimPosition(i) == idx && i != r - 1)
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
if (map.getDimPosition(d) == idx && d != rank - 1)
return false;
}
}

View File

@ -243,9 +243,11 @@ private:
/// Templated reader.
template <typename P, typename I, typename V>
void *newSparseTensor(char *filename, uint8_t *sparsity, uint64_t size) {
void *newSparseTensor(char *filename, uint8_t *sparsity, uint64_t *perm,
uint64_t size) {
uint64_t idata[64];
SparseTensor *t = static_cast<SparseTensor *>(openTensorC(filename, idata));
SparseTensor *t =
static_cast<SparseTensor *>(openTensorC(filename, idata, perm));
assert(size == t->getRank()); // sparsity array must match rank
SparseTensorStorageBase *tensor =
new SparseTensorStorage<P, I, V>(t, sparsity);
@ -371,7 +373,7 @@ extern "C" {
/// understood by other methods in the sparse runtime support library. An
/// array parameter is used to pass the rank, the number of nonzero elements,
/// and the dimension sizes (one per rank).
void *openTensorC(char *filename, uint64_t *idata) {
void *openTensorC(char *filename, uint64_t *idata, uint64_t *perm) {
// Open the file.
FILE *file = fopen(filename, "r");
if (!file) {
@ -393,16 +395,24 @@ void *openTensorC(char *filename, uint64_t *idata) {
uint64_t nnz = idata[1];
std::vector<uint64_t> indices(rank);
for (uint64_t r = 0; r < rank; r++)
indices[r] = idata[2 + r];
if (perm)
indices[perm[r]] = idata[2 + r];
else
indices[r] = idata[2 + r];
SparseTensor *tensor = new SparseTensor(indices, nnz);
// Read all nonzero elements.
for (uint64_t k = 0; k < nnz; k++) {
uint64_t idx = -1;
for (uint64_t r = 0; r < rank; r++) {
if (fscanf(file, "%" PRIu64, &indices[r]) != 1) {
if (fscanf(file, "%" PRIu64, &idx) != 1) {
fprintf(stderr, "Cannot find next index in %s\n", filename);
exit(1);
}
indices[r]--; // 0-based index
// Add 0-based index.
if (perm)
indices[perm[r]] = idx - 1;
else
indices[r] = idx - 1;
}
double value;
if (fscanf(file, "%lg\n", &value) != 1) {
@ -421,7 +431,7 @@ void *openTensorC(char *filename, uint64_t *idata) {
void *openTensor(char *filename, uint64_t *ibase, uint64_t *idata,
uint64_t ioff, uint64_t isize, uint64_t istride) {
assert(istride == 1);
return openTensorC(filename, idata + ioff);
return openTensorC(filename, idata + ioff, nullptr);
}
/// Yields the next element from the given opaque sparse tensor object.
@ -477,7 +487,7 @@ char *getTensorFilename(uint64_t id) {
#define CASE(p, i, v, P, I, V) \
if (ptrTp == (p) && indTp == (i) && valTp == (v)) \
return newSparseTensor<P, I, V>(filename, sparsity, asize)
return newSparseTensor<P, I, V>(filename, sparsity, perm, asize)
#define IMPL1(RET, NAME, TYPE, LIB) \
RET NAME(void *tensor) { \
@ -515,9 +525,12 @@ enum PrimaryTypeEnum : uint64_t {
void *newSparseTensor(char *filename, uint8_t *abase, uint8_t *adata,
uint64_t aoff, uint64_t asize, uint64_t astride,
uint64_t ptrTp, uint64_t indTp, uint64_t valTp) {
assert(astride == 1);
uint64_t *pbase, uint64_t *pdata, uint64_t poff,
uint64_t psize, uint64_t pstride, uint64_t ptrTp,
uint64_t indTp, uint64_t valTp) {
assert(astride == 1 && pstride == 1);
uint8_t *sparsity = adata + aoff;
uint64_t *perm = pdata + poff;
// The most common cases: 64-bit or 32-bit overhead, double/float values.
CASE(kU64, kU64, kF64, uint64_t, uint64_t, double);

View File

@ -20,6 +20,11 @@
dimLevelType = ["dense", "compressed"]
}>
#SparseTensor = #sparse_tensor.encoding<{
dimLevelType = ["dense", "compressed", "compressed"],
dimOrdering = affine_map<(i,j,k) -> (k,i,j)>
}>
// CHECK-LABEL: func @sparse_dim(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
// CHECK: %[[C:.*]] = constant 0 : index
@ -35,7 +40,9 @@ func @sparse_dim(%arg0: tensor<?xf64, #SparseVector>) -> index {
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK: %[[D:.*]] = constant dense<1> : tensor<1xi8>
// CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<1xi8> to tensor<?xi8>
// CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi8>, i64, i64, i64) -> !llvm.ptr<i8>
// CHECK: %[[P:.*]] = constant dense<0> : tensor<1xi64>
// CHECK: %[[Q:.*]] = tensor.cast %[[P]] : tensor<1xi64> to tensor<?xi64>
// CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %[[Q]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi8>, tensor<?xi64>, i64, i64, i64) -> !llvm.ptr<i8>
// CHECK: return %[[T]] : !llvm.ptr<i8>
func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
%0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<128xf64, #SparseVector>
@ -46,13 +53,28 @@ func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK: %[[D:.*]] = constant dense<[0, 1]> : tensor<2xi8>
// CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<2xi8> to tensor<?xi8>
// CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi8>, i64, i64, i64) -> !llvm.ptr<i8>
// CHECK: %[[P:.*]] = constant dense<[0, 1]> : tensor<2xi64>
// CHECK: %[[Q:.*]] = tensor.cast %[[P]] : tensor<2xi64> to tensor<?xi64>
// CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %[[Q]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi8>, tensor<?xi64>, i64, i64, i64) -> !llvm.ptr<i8>
// CHECK: return %[[T]] : !llvm.ptr<i8>
func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #SparseMatrix> {
%0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #SparseMatrix>
return %0 : tensor<?x?xf32, #SparseMatrix>
}
// CHECK-LABEL: func @sparse_new3d(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK: %[[D:.*]] = constant dense<[0, 1, 1]> : tensor<3xi8>
// CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<3xi8> to tensor<?xi8>
// CHECK: %[[P:.*]] = constant dense<[1, 2, 0]> : tensor<3xi64>
// CHECK: %[[Q:.*]] = tensor.cast %[[P]] : tensor<3xi64> to tensor<?xi64>
// CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %[[Q]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi8>, tensor<?xi64>, i64, i64, i64) -> !llvm.ptr<i8>
// CHECK: return %[[T]] : !llvm.ptr<i8>
func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor> {
%0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?x?xf32, #SparseTensor>
return %0 : tensor<?x?x?xf32, #SparseTensor>
}
// CHECK-LABEL: func @sparse_pointers(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
// CHECK: %[[C:.*]] = constant 0 : index

View File

@ -21,60 +21,60 @@
}
// CHECK-HIR-LABEL: func @matvec(
// CHECK-HIR-SAME: %[[VAL_0:.*]]: tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>>,
// CHECK-HIR-SAME: %[[VAL_0:.*]]: tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
// CHECK-HIR-SAME: %[[VAL_1:.*]]: tensor<64xf64>,
// CHECK-HIR-SAME: %[[VAL_2:.*]]: tensor<64xf64>) -> tensor<64xf64> {
// CHECK-HIR: %[[VAL_3:.*]] = constant 64 : index
// CHECK-HIR-SAME: %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> {
// CHECK-HIR: %[[VAL_3:.*]] = constant 32 : index
// CHECK-HIR: %[[VAL_4:.*]] = constant 0 : index
// CHECK-HIR: %[[VAL_5:.*]] = constant 1 : index
// CHECK-HIR: %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
// CHECK-HIR: %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
// CHECK-HIR: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xf64>
// CHECK-HIR: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64>
// CHECK-HIR: %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<64xf64>
// CHECK-HIR: %[[VAL_12:.*]] = memref.alloc() : memref<64xf64>
// CHECK-HIR: linalg.copy(%[[VAL_11]], %[[VAL_12]]) : memref<64xf64>, memref<64xf64>
// CHECK-HIR: scf.for %[[VAL_13:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
// CHECK-HIR: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_13]]] : memref<?xindex>
// CHECK-HIR: %[[VAL_15:.*]] = addi %[[VAL_13]], %[[VAL_5]] : index
// CHECK-HIR: %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xindex>
// CHECK-HIR: %[[VAL_17:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_13]]] : memref<64xf64>
// CHECK-HIR: %[[VAL_18:.*]] = scf.for %[[VAL_19:.*]] = %[[VAL_14]] to %[[VAL_16]] step %[[VAL_5]] iter_args(%[[VAL_20:.*]] = %[[VAL_17]]) -> (f64) {
// CHECK-HIR: %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_19]]] : memref<?xindex>
// CHECK-HIR: %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref<?xf64>
// CHECK-HIR: %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_21]]] : memref<64xf64>
// CHECK-HIR: %[[VAL_24:.*]] = mulf %[[VAL_22]], %[[VAL_23]] : f64
// CHECK-HIR: %[[VAL_25:.*]] = addf %[[VAL_20]], %[[VAL_24]] : f64
// CHECK-HIR: scf.yield %[[VAL_25]] : f64
// CHECK-HIR: %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-HIR: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-HIR: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf64>
// CHECK-HIR: %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64>
// CHECK-HIR: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
// CHECK-HIR: %[[VAL_11:.*]] = memref.alloc() : memref<32xf64>
// CHECK-HIR: linalg.copy(%[[VAL_10]], %[[VAL_11]]) : memref<32xf64>, memref<32xf64>
// CHECK-HIR: scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
// CHECK-HIR: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
// CHECK-HIR: %[[VAL_14:.*]] = addi %[[VAL_12]], %[[VAL_5]] : index
// CHECK-HIR: %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_14]]] : memref<?xindex>
// CHECK-HIR: %[[VAL_16:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<32xf64>
// CHECK-HIR: %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_13]] to %[[VAL_15]] step %[[VAL_5]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (f64) {
// CHECK-HIR: %[[VAL_20:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_18]]] : memref<?xindex>
// CHECK-HIR: %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xf64>
// CHECK-HIR: %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_20]]] : memref<64xf64>
// CHECK-HIR: %[[VAL_23:.*]] = mulf %[[VAL_21]], %[[VAL_22]] : f64
// CHECK-HIR: %[[VAL_24:.*]] = addf %[[VAL_19]], %[[VAL_23]] : f64
// CHECK-HIR: scf.yield %[[VAL_24]] : f64
// CHECK-HIR: }
// CHECK-HIR: store %[[VAL_26:.*]], %[[VAL_12]]{{\[}}%[[VAL_13]]] : memref<64xf64>
// CHECK-HIR: memref.store %[[VAL_25:.*]], %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<32xf64>
// CHECK-HIR: }
// CHECK-HIR: %[[VAL_27:.*]] = memref.tensor_load %[[VAL_12]] : memref<64xf64>
// CHECK-HIR: return %[[VAL_27]] : tensor<64xf64>
// CHECK-HIR: %[[VAL_26:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xf64>
// CHECK-HIR: return %[[VAL_26]] : tensor<32xf64>
// CHECK-HIR: }
// CHECK-MIR-LABEL: func @matvec(
// CHECK-MIR-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>,
// CHECK-MIR-SAME: %[[VAL_1:.*]]: tensor<64xf64>,
// CHECK-MIR-SAME: %[[VAL_2:.*]]: tensor<64xf64>) -> tensor<64xf64> {
// CHECK-MIR: %[[VAL_3:.*]] = constant 64 : index
// CHECK-MIR-SAME: %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> {
// CHECK-MIR: %[[VAL_3:.*]] = constant 32 : index
// CHECK-MIR: %[[VAL_4:.*]] = constant 0 : index
// CHECK-MIR: %[[VAL_5:.*]] = constant 1 : index
// CHECK-MIR: %[[VAL_6:.*]] = call @sparsePointers(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
// CHECK-MIR: %[[VAL_7:.*]] = call @sparseIndices(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
// CHECK-MIR: %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf64>
// CHECK-MIR: %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64>
// CHECK-MIR: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<64xf64>
// CHECK-MIR: %[[VAL_11:.*]] = memref.alloc() : memref<64xf64>
// CHECK-MIR: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
// CHECK-MIR: %[[VAL_11:.*]] = memref.alloc() : memref<32xf64>
// CHECK-MIR: scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
// CHECK-MIR: %[[VAL_13:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_12]]] : memref<64xf64>
// CHECK-MIR: store %[[VAL_13]], %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<64xf64>
// CHECK-MIR: %[[VAL_13:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_12]]] : memref<32xf64>
// CHECK-MIR: memref.store %[[VAL_13]], %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<32xf64>
// CHECK-MIR: }
// CHECK-MIR: scf.for %[[VAL_14:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
// CHECK-MIR: %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_14]]] : memref<?xindex>
// CHECK-MIR: %[[VAL_16:.*]] = addi %[[VAL_14]], %[[VAL_5]] : index
// CHECK-MIR: %[[VAL_17:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_16]]] : memref<?xindex>
// CHECK-MIR: %[[VAL_18:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_14]]] : memref<64xf64>
// CHECK-MIR: %[[VAL_18:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_14]]] : memref<32xf64>
// CHECK-MIR: %[[VAL_19:.*]] = scf.for %[[VAL_20:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_5]] iter_args(%[[VAL_21:.*]] = %[[VAL_18]]) -> (f64) {
// CHECK-MIR: %[[VAL_22:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_20]]] : memref<?xindex>
// CHECK-MIR: %[[VAL_23:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref<?xf64>
@ -83,32 +83,32 @@
// CHECK-MIR: %[[VAL_26:.*]] = addf %[[VAL_21]], %[[VAL_25]] : f64
// CHECK-MIR: scf.yield %[[VAL_26]] : f64
// CHECK-MIR: }
// CHECK-MIR: store %[[VAL_27:.*]], %[[VAL_11]]{{\[}}%[[VAL_14]]] : memref<64xf64>
// CHECK-MIR: memref.store %[[VAL_27:.*]], %[[VAL_11]]{{\[}}%[[VAL_14]]] : memref<32xf64>
// CHECK-MIR: }
// CHECK-MIR: %[[VAL_28:.*]] = memref.tensor_load %[[VAL_11]] : memref<64xf64>
// CHECK-MIR: return %[[VAL_28]] : tensor<64xf64>
// CHECK-MIR: %[[VAL_28:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xf64>
// CHECK-MIR: return %[[VAL_28]] : tensor<32xf64>
// CHECK-MIR: }
// CHECK-LIR-LABEL: func @matvec(
// CHECK-LIR-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>,
// CHECK-LIR-SAME: %[[VAL_1:.*]]: memref<64xf64>,
// CHECK-LIR-SAME: %[[VAL_2:.*]]: memref<64xf64>) -> memref<64xf64> {
// CHECK-LIR: %[[VAL_3:.*]] = constant 64 : index
// CHECK-LIR-SAME: %[[VAL_2:.*]]: memref<32xf64>) -> memref<32xf64> {
// CHECK-LIR: %[[VAL_3:.*]] = constant 32 : index
// CHECK-LIR: %[[VAL_4:.*]] = constant 0 : index
// CHECK-LIR: %[[VAL_5:.*]] = constant 1 : index
// CHECK-LIR: %[[VAL_6:.*]] = call @sparsePointers(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
// CHECK-LIR: %[[VAL_7:.*]] = call @sparseIndices(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
// CHECK-LIR: %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf64>
// CHECK-LIR: %[[VAL_9:.*]] = memref.alloc() : memref<64xf64>
// CHECK-LIR: %[[VAL_9:.*]] = memref.alloc() : memref<32xf64>
// CHECK-LIR: scf.for %[[VAL_10:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
// CHECK-LIR: %[[VAL_11:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref<64xf64>
// CHECK-LIR: store %[[VAL_11]], %[[VAL_9]]{{\[}}%[[VAL_10]]] : memref<64xf64>
// CHECK-LIR: %[[VAL_11:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref<32xf64>
// CHECK-LIR: memref.store %[[VAL_11]], %[[VAL_9]]{{\[}}%[[VAL_10]]] : memref<32xf64>
// CHECK-LIR: }
// CHECK-LIR: scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
// CHECK-LIR: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
// CHECK-LIR: %[[VAL_14:.*]] = addi %[[VAL_12]], %[[VAL_5]] : index
// CHECK-LIR: %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_14]]] : memref<?xindex>
// CHECK-LIR: %[[VAL_16:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<64xf64>
// CHECK-LIR: %[[VAL_16:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<32xf64>
// CHECK-LIR: %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_13]] to %[[VAL_15]] step %[[VAL_5]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (f64) {
// CHECK-LIR: %[[VAL_20:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_18]]] : memref<?xindex>
// CHECK-LIR: %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xf64>
@ -117,21 +117,21 @@
// CHECK-LIR: %[[VAL_24:.*]] = addf %[[VAL_19]], %[[VAL_23]] : f64
// CHECK-LIR: scf.yield %[[VAL_24]] : f64
// CHECK-LIR: }
// CHECK-LIR: store %[[VAL_25:.*]], %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<64xf64>
// CHECK-LIR: memref.store %[[VAL_25:.*]], %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<32xf64>
// CHECK-LIR: }
// CHECK-LIR: return %[[VAL_9]] : memref<64xf64>
// CHECK-LIR: return %[[VAL_9]] : memref<32xf64>
// CHECK-LIR: }
func @matvec(%arga: tensor<64x64xf64, #CSR>,
func @matvec(%arga: tensor<32x64xf64, #CSR>,
%argb: tensor<64xf64>,
%argx: tensor<64xf64>) -> tensor<64xf64> {
%argx: tensor<32xf64>) -> tensor<32xf64> {
%0 = linalg.generic #trait_matvec
ins(%arga, %argb : tensor<64x64xf64, #CSR>, tensor<64xf64>)
outs(%argx: tensor<64xf64>) {
ins(%arga, %argb : tensor<32x64xf64, #CSR>, tensor<64xf64>)
outs(%argx: tensor<32xf64>) {
^bb(%A: f64, %b: f64, %x: f64):
%0 = mulf %A, %b : f64
%1 = addf %x, %0 : f64
linalg.yield %1 : f64
} -> tensor<64xf64>
return %0 : tensor<64xf64>
} -> tensor<32xf64>
return %0 : tensor<32xf64>
}

View File

@ -0,0 +1,139 @@
// RUN: mlir-opt %s -sparsification | FileCheck %s --check-prefix=CHECK-HIR
//
// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion \
// RUN: --convert-linalg-to-loops | FileCheck %s --check-prefix=CHECK-MIR
//
// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion \
// RUN: --convert-linalg-to-loops --func-bufferize --tensor-constant-bufferize \
// RUN: --tensor-bufferize --finalizing-bufferize | \
// RUN: FileCheck %s --check-prefix=CHECK-LIR
#CSC = #sparse_tensor.encoding<{
dimLevelType = [ "dense", "compressed" ],
dimOrdering = affine_map<(i,j) -> (j,i)>
}>
#trait_matvec = {
indexing_maps = [
affine_map<(i,j) -> (i,j)>, // A
affine_map<(i,j) -> (j)>, // b
affine_map<(i,j) -> (i)> // x (out)
],
iterator_types = ["parallel","reduction"],
doc = "x(i) += A(i,j) * b(j)"
}
// CHECK-HIR-LABEL: func @matvec(
// CHECK-HIR-SAME: %[[VAL_0:.*]]: tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, pointerBitWidth = 0, indexBitWidth = 0 }>>,
// CHECK-HIR-SAME: %[[VAL_1:.*]]: tensor<64xf64>,
// CHECK-HIR-SAME: %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> {
// CHECK-HIR: %[[VAL_3:.*]] = constant 64 : index
// CHECK-HIR: %[[VAL_4:.*]] = constant 0 : index
// CHECK-HIR: %[[VAL_5:.*]] = constant 1 : index
// CHECK-HIR: %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-HIR: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-HIR: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf64>
// CHECK-HIR: %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64>
// CHECK-HIR: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
// CHECK-HIR: %[[VAL_11:.*]] = memref.alloc() : memref<32xf64>
// CHECK-HIR: linalg.copy(%[[VAL_10]], %[[VAL_11]]) : memref<32xf64>, memref<32xf64>
// CHECK-HIR: scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
// CHECK-HIR: %[[VAL_13:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<64xf64>
// CHECK-HIR: %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
// CHECK-HIR: %[[VAL_15:.*]] = addi %[[VAL_12]], %[[VAL_5]] : index
// CHECK-HIR: %[[VAL_16:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_15]]] : memref<?xindex>
// CHECK-HIR: scf.for %[[VAL_17:.*]] = %[[VAL_14]] to %[[VAL_16]] step %[[VAL_5]] {
// CHECK-HIR: %[[VAL_18:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_17]]] : memref<?xindex>
// CHECK-HIR: %[[VAL_19:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_18]]] : memref<32xf64>
// CHECK-HIR: %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<?xf64>
// CHECK-HIR: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_13]] : f64
// CHECK-HIR: %[[VAL_22:.*]] = addf %[[VAL_19]], %[[VAL_21]] : f64
// CHECK-HIR: memref.store %[[VAL_22]], %[[VAL_11]]{{\[}}%[[VAL_18]]] : memref<32xf64>
// CHECK-HIR: }
// CHECK-HIR: }
// CHECK-HIR: %[[VAL_23:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xf64>
// CHECK-HIR: return %[[VAL_23]] : tensor<32xf64>
// CHECK-HIR: }
// CHECK-MIR-LABEL: func @matvec(
// CHECK-MIR-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>,
// CHECK-MIR-SAME: %[[VAL_1:.*]]: tensor<64xf64>,
// CHECK-MIR-SAME: %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> {
// CHECK-MIR: %[[VAL_3:.*]] = constant 64 : index
// CHECK-MIR: %[[VAL_4:.*]] = constant 32 : index
// CHECK-MIR: %[[VAL_5:.*]] = constant 0 : index
// CHECK-MIR: %[[VAL_6:.*]] = constant 1 : index
// CHECK-MIR: %[[VAL_7:.*]] = call @sparsePointers(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
// CHECK-MIR: %[[VAL_8:.*]] = call @sparseIndices(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
// CHECK-MIR: %[[VAL_9:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf64>
// CHECK-MIR: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64>
// CHECK-MIR: %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
// CHECK-MIR: %[[VAL_12:.*]] = memref.alloc() : memref<32xf64>
// CHECK-MIR: scf.for %[[VAL_13:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
// CHECK-MIR: %[[VAL_14:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_13]]] : memref<32xf64>
// CHECK-MIR: memref.store %[[VAL_14]], %[[VAL_12]]{{\[}}%[[VAL_13]]] : memref<32xf64>
// CHECK-MIR: }
// CHECK-MIR: scf.for %[[VAL_15:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
// CHECK-MIR: %[[VAL_16:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_15]]] : memref<64xf64>
// CHECK-MIR: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xindex>
// CHECK-MIR: %[[VAL_18:.*]] = addi %[[VAL_15]], %[[VAL_6]] : index
// CHECK-MIR: %[[VAL_19:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_18]]] : memref<?xindex>
// CHECK-MIR: scf.for %[[VAL_20:.*]] = %[[VAL_17]] to %[[VAL_19]] step %[[VAL_6]] {
// CHECK-MIR: %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref<?xindex>
// CHECK-MIR: %[[VAL_22:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_21]]] : memref<32xf64>
// CHECK-MIR: %[[VAL_23:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_20]]] : memref<?xf64>
// CHECK-MIR: %[[VAL_24:.*]] = mulf %[[VAL_23]], %[[VAL_16]] : f64
// CHECK-MIR: %[[VAL_25:.*]] = addf %[[VAL_22]], %[[VAL_24]] : f64
// CHECK-MIR: memref.store %[[VAL_25]], %[[VAL_12]]{{\[}}%[[VAL_21]]] : memref<32xf64>
// CHECK-MIR: }
// CHECK-MIR: }
// CHECK-MIR: %[[VAL_26:.*]] = memref.tensor_load %[[VAL_12]] : memref<32xf64>
// CHECK-MIR: return %[[VAL_26]] : tensor<32xf64>
// CHECK-MIR: }
// CHECK-LIR-LABEL: func @matvec(
// CHECK-LIR-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>,
// CHECK-LIR-SAME: %[[VAL_1:.*]]: memref<64xf64>,
// CHECK-LIR-SAME: %[[VAL_2:.*]]: memref<32xf64>) -> memref<32xf64> {
// CHECK-LIR: %[[VAL_3:.*]] = constant 64 : index
// CHECK-LIR: %[[VAL_4:.*]] = constant 32 : index
// CHECK-LIR: %[[VAL_5:.*]] = constant 0 : index
// CHECK-LIR: %[[VAL_6:.*]] = constant 1 : index
// CHECK-LIR: %[[VAL_7:.*]] = call @sparsePointers(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
// CHECK-LIR: %[[VAL_8:.*]] = call @sparseIndices(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
// CHECK-LIR: %[[VAL_9:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf64>
// CHECK-LIR: %[[VAL_10:.*]] = memref.alloc() : memref<32xf64>
// CHECK-LIR: scf.for %[[VAL_11:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
// CHECK-LIR: %[[VAL_12:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_11]]] : memref<32xf64>
// CHECK-LIR: memref.store %[[VAL_12]], %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<32xf64>
// CHECK-LIR: }
// CHECK-LIR: scf.for %[[VAL_13:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
// CHECK-LIR: %[[VAL_14:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_13]]] : memref<64xf64>
// CHECK-LIR: %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_13]]] : memref<?xindex>
// CHECK-LIR: %[[VAL_16:.*]] = addi %[[VAL_13]], %[[VAL_6]] : index
// CHECK-LIR: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref<?xindex>
// CHECK-LIR: scf.for %[[VAL_18:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_6]] {
// CHECK-LIR: %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
// CHECK-LIR: %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xf64>
// CHECK-LIR: %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xf64>
// CHECK-LIR: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_14]] : f64
// CHECK-LIR: %[[VAL_23:.*]] = addf %[[VAL_20]], %[[VAL_22]] : f64
// CHECK-LIR: memref.store %[[VAL_23]], %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xf64>
// CHECK-LIR: }
// CHECK-LIR: }
// CHECK-LIR: return %[[VAL_10]] : memref<32xf64>
// CHECK-LIR: }
func @matvec(%arga: tensor<32x64xf64, #CSC>,
%argb: tensor<64xf64>,
%argx: tensor<32xf64>) -> tensor<32xf64> {
%0 = linalg.generic #trait_matvec
ins(%arga, %argb : tensor<32x64xf64, #CSC>, tensor<64xf64>)
outs(%argx: tensor<32xf64>) {
^bb(%A: f64, %b: f64, %x: f64):
%0 = mulf %A, %b : f64
%1 = addf %x, %0 : f64
linalg.yield %1 : f64
} -> tensor<32xf64>
return %0 : tensor<32xf64>
}

View File

@ -21,22 +21,22 @@
}
// CHECK-HIR-LABEL: func @matvec(
// CHECK-HIR-SAME: %[[VAL_0:.*]]: tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>>,
// CHECK-HIR-SAME: %[[VAL_0:.*]]: tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
// CHECK-HIR-SAME: %[[VAL_1:.*]]: tensor<64xf64>,
// CHECK-HIR-SAME: %[[VAL_2:.*]]: tensor<64xf64> {linalg.inplaceable = true}) -> tensor<64xf64> {
// CHECK-HIR: %[[VAL_3:.*]] = constant 64 : index
// CHECK-HIR-SAME: %[[VAL_2:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
// CHECK-HIR: %[[VAL_3:.*]] = constant 32 : index
// CHECK-HIR: %[[VAL_4:.*]] = constant 0 : index
// CHECK-HIR: %[[VAL_5:.*]] = constant 1 : index
// CHECK-HIR: %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
// CHECK-HIR: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
// CHECK-HIR: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xf64>
// CHECK-HIR: %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-HIR: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-HIR: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf64>
// CHECK-HIR: %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64>
// CHECK-HIR: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<64xf64>
// CHECK-HIR: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
// CHECK-HIR: scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
// CHECK-HIR: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
// CHECK-HIR: %[[VAL_13:.*]] = addi %[[VAL_11]], %[[VAL_5]] : index
// CHECK-HIR: %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xindex>
// CHECK-HIR: %[[VAL_15:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<64xf64>
// CHECK-HIR: %[[VAL_15:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<32xf64>
// CHECK-HIR: %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_12]] to %[[VAL_14]] step %[[VAL_5]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f64) {
// CHECK-HIR: %[[VAL_19:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_17]]] : memref<?xindex>
// CHECK-HIR: %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<?xf64>
@ -45,29 +45,29 @@
// CHECK-HIR: %[[VAL_23:.*]] = addf %[[VAL_18]], %[[VAL_22]] : f64
// CHECK-HIR: scf.yield %[[VAL_23]] : f64
// CHECK-HIR: }
// CHECK-HIR: memref.store %[[VAL_24:.*]], %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<64xf64>
// CHECK-HIR: memref.store %[[VAL_24:.*]], %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<32xf64>
// CHECK-HIR: }
// CHECK-HIR: %[[VAL_25:.*]] = memref.tensor_load %[[VAL_10]] : memref<64xf64>
// CHECK-HIR: return %[[VAL_25]] : tensor<64xf64>
// CHECK-HIR: %[[VAL_25:.*]] = memref.tensor_load %[[VAL_10]] : memref<32xf64>
// CHECK-HIR: return %[[VAL_25]] : tensor<32xf64>
// CHECK-HIR: }
// CHECK-MIR-LABEL: func @matvec(
// CHECK-MIR-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>,
// CHECK-MIR-SAME: %[[VAL_1:.*]]: tensor<64xf64>,
// CHECK-MIR-SAME: %[[VAL_2:.*]]: tensor<64xf64> {linalg.inplaceable = true}) -> tensor<64xf64> {
// CHECK-MIR: %[[VAL_3:.*]] = constant 64 : index
// CHECK-MIR-SAME: %[[VAL_2:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
// CHECK-MIR: %[[VAL_3:.*]] = constant 32 : index
// CHECK-MIR: %[[VAL_4:.*]] = constant 0 : index
// CHECK-MIR: %[[VAL_5:.*]] = constant 1 : index
// CHECK-MIR: %[[VAL_6:.*]] = call @sparsePointers(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
// CHECK-MIR: %[[VAL_7:.*]] = call @sparseIndices(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
// CHECK-MIR: %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf64>
// CHECK-MIR: %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64>
// CHECK-MIR: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<64xf64>
// CHECK-MIR: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
// CHECK-MIR: scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
// CHECK-MIR: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
// CHECK-MIR: %[[VAL_13:.*]] = addi %[[VAL_11]], %[[VAL_5]] : index
// CHECK-MIR: %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xindex>
// CHECK-MIR: %[[VAL_15:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<64xf64>
// CHECK-MIR: %[[VAL_15:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<32xf64>
// CHECK-MIR: %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_12]] to %[[VAL_14]] step %[[VAL_5]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f64) {
// CHECK-MIR: %[[VAL_19:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_17]]] : memref<?xindex>
// CHECK-MIR: %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<?xf64>
@ -76,17 +76,17 @@
// CHECK-MIR: %[[VAL_23:.*]] = addf %[[VAL_18]], %[[VAL_22]] : f64
// CHECK-MIR: scf.yield %[[VAL_23]] : f64
// CHECK-MIR: }
// CHECK-MIR: memref.store %[[VAL_24:.*]], %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<64xf64>
// CHECK-MIR: memref.store %[[VAL_24:.*]], %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<32xf64>
// CHECK-MIR: }
// CHECK-MIR: %[[VAL_25:.*]] = memref.tensor_load %[[VAL_10]] : memref<64xf64>
// CHECK-MIR: return %[[VAL_25]] : tensor<64xf64>
// CHECK-MIR: %[[VAL_25:.*]] = memref.tensor_load %[[VAL_10]] : memref<32xf64>
// CHECK-MIR: return %[[VAL_25]] : tensor<32xf64>
// CHECK-MIR: }
// CHECK-LIR-LABEL: func @matvec(
// CHECK-LIR-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>,
// CHECK-LIR-SAME: %[[VAL_1:.*]]: memref<64xf64>,
// CHECK-LIR-SAME: %[[VAL_2:.*]]: memref<64xf64> {linalg.inplaceable = true}) -> memref<64xf64> {
// CHECK-LIR: %[[VAL_3:.*]] = constant 64 : index
// CHECK-LIR-SAME: %[[VAL_2:.*]]: memref<32xf64> {linalg.inplaceable = true}) -> memref<32xf64> {
// CHECK-LIR: %[[VAL_3:.*]] = constant 32 : index
// CHECK-LIR: %[[VAL_4:.*]] = constant 0 : index
// CHECK-LIR: %[[VAL_5:.*]] = constant 1 : index
// CHECK-LIR: %[[VAL_6:.*]] = call @sparsePointers(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
@ -96,7 +96,7 @@
// CHECK-LIR: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_9]]] : memref<?xindex>
// CHECK-LIR: %[[VAL_11:.*]] = addi %[[VAL_9]], %[[VAL_5]] : index
// CHECK-LIR: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
// CHECK-LIR: %[[VAL_13:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] : memref<64xf64>
// CHECK-LIR: %[[VAL_13:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] : memref<32xf64>
// CHECK-LIR: %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_10]] to %[[VAL_12]] step %[[VAL_5]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (f64) {
// CHECK-LIR: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xindex>
// CHECK-LIR: %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref<?xf64>
@ -105,21 +105,21 @@
// CHECK-LIR: %[[VAL_21:.*]] = addf %[[VAL_16]], %[[VAL_20]] : f64
// CHECK-LIR: scf.yield %[[VAL_21]] : f64
// CHECK-LIR: }
// CHECK-LIR: memref.store %[[VAL_22:.*]], %[[VAL_2]]{{\[}}%[[VAL_9]]] : memref<64xf64>
// CHECK-LIR: memref.store %[[VAL_22:.*]], %[[VAL_2]]{{\[}}%[[VAL_9]]] : memref<32xf64>
// CHECK-LIR: }
// CHECK-LIR: return %[[VAL_2]] : memref<64xf64>
// CHECK-LIR: return %[[VAL_2]] : memref<32xf64>
// CHECK-LIR: }
func @matvec(%arga: tensor<64x64xf64, #CSR>,
func @matvec(%arga: tensor<32x64xf64, #CSR>,
%argb: tensor<64xf64>,
%argx: tensor<64xf64> {linalg.inplaceable = true}) -> tensor<64xf64> {
%argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
%0 = linalg.generic #trait_matvec
ins(%arga, %argb : tensor<64x64xf64, #CSR>, tensor<64xf64>)
outs(%argx: tensor<64xf64>) {
ins(%arga, %argb : tensor<32x64xf64, #CSR>, tensor<64xf64>)
outs(%argx: tensor<32xf64>) {
^bb(%A: f64, %b: f64, %x: f64):
%0 = mulf %A, %b : f64
%1 = addf %x, %0 : f64
linalg.yield %1 : f64
} -> tensor<64xf64>
return %0 : tensor<64xf64>
} -> tensor<32xf64>
return %0 : tensor<32xf64>
}

View File

@ -0,0 +1,105 @@
// RUN: mlir-opt %s \
// RUN: --sparsification --sparse-tensor-conversion \
// RUN: --convert-linalg-to-loops --convert-vector-to-scf --convert-scf-to-std \
// RUN: --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
// RUN: --std-bufferize --finalizing-bufferize \
// RUN: --convert-vector-to-llvm --convert-std-to-llvm | \
// RUN: TENSOR0="%mlir_integration_test_dir/data/test.tns" \
// RUN: mlir-cpu-runner \
// RUN: -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
!Filename = type !llvm.ptr<i8>
#SparseTensor = #sparse_tensor.encoding<{
dimLevelType = [ "compressed", "compressed", "compressed", "compressed",
"compressed", "compressed", "compressed", "compressed" ],
// Note that any dimOrdering permutation should give the same results
// since, even though it impacts the sparse storage scheme layout,
// it should not change the semantics.
dimOrdering = affine_map<(i,j,k,l,m,n,o,p) -> (p,o,j,k,i,l,m,n)>
}>
#trait_flatten = {
indexing_maps = [
affine_map<(i,j,k,l,m,n,o,p) -> (i,j,k,l,m,n,o,p)>, // A
affine_map<(i,j,k,l,m,n,o,p) -> (i,j)> // X (out)
],
iterator_types = [ "parallel", "parallel", "reduction", "reduction",
"reduction", "reduction", "reduction", "reduction" ],
doc = "X(i,j) += A(i,j,k,l,m,n,o,p)"
}
//
// Integration test that lowers a kernel annotated as sparse to
// actual sparse code, initializes a matching sparse storage scheme
// from file, and runs the resulting code with the JIT compiler.
//
module {
//
// A kernel that flattens a rank 8 tensor into a dense matrix.
//
func @kernel_flatten(%arga: tensor<7x3x3x3x3x3x5x3xf64, #SparseTensor>,
%argx: tensor<7x3xf64>) -> tensor<7x3xf64> {
%0 = linalg.generic #trait_flatten
ins(%arga: tensor<7x3x3x3x3x3x5x3xf64, #SparseTensor>)
outs(%argx: tensor<7x3xf64>) {
^bb(%a: f64, %x: f64):
%0 = addf %x, %a : f64
linalg.yield %0 : f64
} -> tensor<7x3xf64>
return %0 : tensor<7x3xf64>
}
func private @getTensorFilename(index) -> (!Filename)
//
// Main driver that reads tensor from file and calls the sparse kernel.
//
func @entry() {
%d0 = constant 0.0 : f64
%c0 = constant 0 : index
%c1 = constant 1 : index
%c3 = constant 3 : index
%c7 = constant 7 : index
// Setup matrix memory that is initialized to zero.
%xdata = memref.alloc() : memref<7x3xf64>
scf.for %i = %c0 to %c7 step %c1 {
scf.for %j = %c0 to %c3 step %c1 {
memref.store %d0, %xdata[%i, %j] : memref<7x3xf64>
}
}
%x = memref.tensor_load %xdata : memref<7x3xf64>
// Read the sparse tensor from file, construct sparse storage.
%fileName = call @getTensorFilename(%c0) : (index) -> (!Filename)
%a = sparse_tensor.new %fileName : !llvm.ptr<i8> to tensor<7x3x3x3x3x3x5x3xf64, #SparseTensor>
// Call the kernel.
%0 = call @kernel_flatten(%a, %x)
: (tensor<7x3x3x3x3x3x5x3xf64, #SparseTensor>, tensor<7x3xf64>) -> tensor<7x3xf64>
// Print the result for verification.
//
// CHECK: ( 6.25, 0, 0 )
// CHECK: ( 4.224, 6.21, 0 )
// CHECK: ( 0, 0, 15.455 )
// CHECK: ( 0, 0, 0 )
// CHECK: ( 0, 0, 0 )
// CHECK: ( 0, 0, 0 )
// CHECK: ( 7, 0, 0 )
//
%r = memref.buffer_cast %0 : memref<7x3xf64>
scf.for %i = %c0 to %c7 step %c1 {
%v = vector.transfer_read %r[%i, %c0], %d0: memref<7x3xf64>, vector<3xf64>
vector.print %v : vector<3xf64>
}
// Release the resources.
memref.dealloc %xdata : memref<7x3xf64>
return
}
}