[mlir][sparse] Add lowering for unary and binary ops

Lowering the Unary and Binary ops required several changes because, unlike
other operations, they carry custom code for the different "regions" of the
sparse structure being operated on. Along with a Kind, a pointer to the
Operation is now passed along so that the custom code can be merged in once
the lattice structure has been worked out.
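
Concretely, the expression builder now records the originating operation next
to the kind whenever it recognizes one of the custom semiring ops. The snippet
below is a condensed sketch of that pattern, taken from the Merger::buildTensorExp
changes further down in this diff (not a standalone program; `def` is the
defining op of the value yielded by the linalg.generic block):

```cpp
// Condensed from Merger::buildTensorExp in this diff.
if (isa<sparse_tensor::UnaryOp>(def))
  return addExp(kUnary, e, Value(), def);       // keep the op; its regions are needed later
if (isa<sparse_tensor::BinaryOp>(def))
  return addExp(kBinary, e0, e1, Value(), def); // same for the binary case
```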

The original operation is maintained, as it is required for subsequent
lattice decisions. However, some branches of sparse_tensor.binary are
considered fully handled and are therefore marked as kBinaryBranch to
distinguish them.
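
For reference, the kBinary case of Merger::buildLattices (shown in full later
in this diff) wraps the yield of each non-empty left/right region as its own
kBinaryBranch expression when combining the operand lattices. A condensed
sketch, assuming `child0`/`child1` are the already-built operand lattice sets:

```cpp
// Condensed from the kBinary case of Merger::buildLattices in this diff.
BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
Region &leftRegion = binop.leftRegion();
Region &rightRegion = binop.rightRegion();
Operation *leftYield =
    leftRegion.empty() ? nullptr : leftRegion.front().getTerminator();
Operation *rightYield =
    rightRegion.empty() ? nullptr : rightRegion.front().getTerminator();
bool includeLeft = binop.left_identity() || !leftRegion.empty();
bool includeRight = binop.right_identity() || !rightRegion.empty();
// Each kept region becomes a kBinaryBranch expression in the merged set.
return takeCombi(kBinary, child0, child1, binop, includeLeft, kBinaryBranch,
                 leftYield, includeRight, kBinaryBranch, rightYield);
```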

A unique aspect of the custom code is that sometimes the desired result is no
result at all, e.g. when a user wants overlapping sparse entries to become
empty in the output. This case is handled by returning an uninitialized
Value(), which is checked and handled elsewhere in the code so that nothing is
written to the output tensor for that entry.
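
As a condensed sketch of that flow (taken from the codegen changes in this
diff, not a standalone example), a missing `present` region produces an empty
Value:

```cpp
// Condensed from buildUnaryPresent in this diff.
static Value buildUnaryPresent(PatternRewriter &rewriter, Location loc,
                               Operation *op, Value v0) {
  if (!v0)
    return Value(); // propagate a missing input
  UnaryOp unop = cast<UnaryOp>(op);
  Region &presentRegion = unop.presentRegion();
  if (presentRegion.empty())
    return Value(); // no output entry for this element
  return insertYieldOp(rewriter, loc, presentRegion, {v0});
}
```

On the store side, genTensorStore then simply skips genInsertionStore when the
rhs is empty, asserting that the expression kind is kUnary or kBinary.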

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D123057
Jim Kitchen 2022-05-03 15:50:26 -05:00
parent c4546091ed
commit 2c33266084
8 changed files with 736 additions and 41 deletions

View File

@ -415,7 +415,7 @@ def SparseTensor_BinaryOp : SparseTensor_Op<"binary", [NoSideEffect]>,
would be equivalent to a union operation where non-overlapping values
in the inputs are copied to the output unchanged.
Example of isEqual applied to intersecting elements only:
Example of isEqual applied to intersecting elements only.
```mlir
%C = sparse_tensor.init...
%0 = linalg.generic #trait
@ -435,7 +435,8 @@ def SparseTensor_BinaryOp : SparseTensor_Op<"binary", [NoSideEffect]>,
} -> tensor<?xi8, #SparseVec>
```
Example of A+B in upper triangle, A-B in lower triangle:
Example of A+B in upper triangle, A-B in lower triangle
(not working yet, but construct will be available soon).
```mlir
%C = sparse_tensor.init...
%1 = linalg.generic #trait

View File

@ -46,6 +46,8 @@ enum Kind {
kCastIdx,
kTruncI,
kBitCast,
kBinaryBranch, // semiring unary branch created from a binary op
kUnary, // semiring unary op
// Binary operations.
kMulF,
kMulI,
@ -62,6 +64,7 @@ enum Kind {
kShrS, // signed
kShrU, // unsigned
kShlI,
kBinary, // semiring binary op
};
/// Children subexpressions of tensor operations.
@ -72,7 +75,7 @@ struct Children {
/// Tensor expression. Represents a MLIR expression in tensor index notation.
struct TensorExp {
TensorExp(Kind k, unsigned x, unsigned y, Value v);
TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *operation);
/// Tensor expression kind.
Kind kind;
@ -92,6 +95,12 @@ struct TensorExp {
/// infer destination type) of a cast operation During code generation,
/// this field may be used to cache "hoisted" loop invariant tensor loads.
Value val;
/// Code blocks used by semirings. For the case of kUnary and
/// kBinary, this holds the original operation with all regions. For
/// kBinaryBranch, this holds the YieldOp for the left or right half
/// to be merged into a nested scf loop.
Operation *op;
};
/// Lattice point. Each lattice point consists of a conjunction of tensor
@ -110,7 +119,7 @@ struct LatPoint {
/// must execute. Pre-computed during codegen to avoid repeated eval.
BitVector simple;
/// Index of the tensor expresssion.
/// Index of the tensor expression.
unsigned exp;
};
@ -130,9 +139,14 @@ public:
hasSparseOut(false), dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
/// Adds a tensor expression. Returns its index.
unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value());
unsigned addExp(Kind k, unsigned e, Value v) { return addExp(k, e, -1u, v); }
unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(),
Operation *op = nullptr);
unsigned addExp(Kind k, unsigned e, Value v, Operation *op = nullptr) {
return addExp(k, e, -1u, v, op);
}
unsigned addExp(Kind k, Value v, Operation *op = nullptr) {
return addExp(k, -1u, -1u, v, op);
}
/// Adds an iteration lattice point. Returns its index.
unsigned addLat(unsigned t, unsigned i, unsigned e);
@ -144,20 +158,31 @@ public:
/// of loop indices (effectively constructing a larger "intersection" of those
/// indices) with a newly constructed tensor (sub)expression of given kind.
/// Returns the index of the new lattice point.
unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1);
unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1,
Operation *op = nullptr);
/// Conjunctive merge of two lattice sets L0 and L1 is conjunction of
/// cartesian product. Returns the index of the new set.
unsigned takeConj(Kind kind, unsigned s0, unsigned s1);
unsigned takeConj(Kind kind, unsigned s0, unsigned s1,
Operation *op = nullptr);
/// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
/// Returns the index of the new set.
unsigned takeDisj(Kind kind, unsigned s0, unsigned s1);
unsigned takeDisj(Kind kind, unsigned s0, unsigned s1,
Operation *op = nullptr);
/// Disjunctive merge of two lattice sets L0 and L1 with custom handling of
/// the overlap, left, and right regions. Any region may be left missing in
/// the output. Returns the index of the new set.
unsigned takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
bool includeLeft, Kind ltrans, Operation *opleft,
bool includeRight, Kind rtrans, Operation *opright);
/// Maps the unary operator over the lattice set of the operand, i.e. each
/// lattice point on an expression E is simply copied over, but with OP E
/// as new expression. Returns the index of the new set.
unsigned mapSet(Kind kind, unsigned s0, Value v = Value());
unsigned mapSet(Kind kind, unsigned s0, Value v = Value(),
Operation *op = nullptr);
/// Optimizes the iteration lattice points in the given set. This
/// method should be called right before code generation to avoid

View File

@ -782,7 +782,7 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen,
/// Generates a store on a dense or sparse tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen,
PatternRewriter &rewriter, linalg::GenericOp op,
Value rhs) {
unsigned exp, Value rhs) {
Location loc = op.getLoc();
// Test if this is a scalarized reduction.
if (codegen.redVal) {
@ -795,7 +795,14 @@ static void genTensorStore(Merger &merger, CodeGen &codegen,
// Store during insertion.
OpOperand *t = op.getOutputOperand(0);
if (t == codegen.sparseOut) {
genInsertionStore(codegen, rewriter, op, t, rhs);
if (!rhs) {
// Only unary and binary are allowed to return uninitialized rhs
// to indicate missing output.
Kind kind = merger.exp(exp).kind;
assert(kind == kUnary || kind == kBinary);
} else {
genInsertionStore(codegen, rewriter, op, t, rhs);
}
return;
}
// Actual store.
@ -982,7 +989,7 @@ static void genInvariants(Merger &merger, CodeGen &codegen,
updateReduc(merger, codegen, Value());
codegen.redExp = -1u;
codegen.redKind = kNoReduc;
genTensorStore(merger, codegen, rewriter, op, redVal);
genTensorStore(merger, codegen, rewriter, op, exp, redVal);
}
} else {
// Start or end loop invariant hoisting of a tensor load.
@ -1225,8 +1232,7 @@ static Operation *genFor(Merger &merger, CodeGen &codegen,
/// Emit a while-loop for co-iteration over multiple indices.
static Operation *genWhile(Merger &merger, CodeGen &codegen,
PatternRewriter &rewriter, linalg::GenericOp op,
unsigned idx, bool needsUniv,
BitVector &indices) {
unsigned idx, bool needsUniv, BitVector &indices) {
SmallVector<Type, 4> types;
SmallVector<Value, 4> operands;
// Construct the while-loop with a parameter for each index.
@ -1373,8 +1379,7 @@ static void genLocals(Merger &merger, CodeGen &codegen,
static void genWhileInduction(Merger &merger, CodeGen &codegen,
PatternRewriter &rewriter, linalg::GenericOp op,
unsigned idx, bool needsUniv,
BitVector &induction,
scf::WhileOp whileOp) {
BitVector &induction, scf::WhileOp whileOp) {
Location loc = op.getLoc();
// Finalize each else branch of all if statements.
if (codegen.redVal || codegen.expValues) {
@ -1599,7 +1604,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
if (at == topSort.size()) {
unsigned ldx = topSort[at - 1];
Value rhs = genExp(merger, codegen, rewriter, op, exp, ldx);
genTensorStore(merger, codegen, rewriter, op, rhs);
genTensorStore(merger, codegen, rewriter, op, exp, rhs);
return;
}

View File

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Math/IR/Math.h"
@ -20,18 +21,18 @@ namespace sparse_tensor {
// Constructors.
//===----------------------------------------------------------------------===//
TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
: kind(k), val(v) {
TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
: kind(k), val(v), op(o) {
switch (kind) {
case kTensor:
assert(x != -1u && y == -1u && !v);
assert(x != -1u && y == -1u && !v && !o);
tensor = x;
break;
case kInvariant:
assert(x == -1u && y == -1u && v);
assert(x == -1u && y == -1u && v && !o);
break;
case kIndex:
assert(x != -1u && y == -1u && !v);
assert(x != -1u && y == -1u && !v && !o);
index = x;
break;
case kAbsF:
@ -39,7 +40,7 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
case kFloorF:
case kNegF:
case kNegI:
assert(x != -1u && y == -1u && !v);
assert(x != -1u && y == -1u && !v && !o);
children.e0 = x;
children.e1 = y;
break;
@ -54,12 +55,29 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
case kCastIdx:
case kTruncI:
case kBitCast:
assert(x != -1u && y == -1u && v);
assert(x != -1u && y == -1u && v && !o);
children.e0 = x;
children.e1 = y;
break;
case kBinaryBranch:
assert(x != -1u && y == -1u && !v && o);
children.e0 = x;
children.e1 = y;
break;
case kUnary:
// No assertion on y can be made, as the branching paths involve both
// a unary (mapSet) and binary (takeDisj) pathway.
assert(x != -1u && !v && o);
children.e0 = x;
children.e1 = y;
break;
case kBinary:
assert(x != -1u && y != -1u && !v && o);
children.e0 = x;
children.e1 = y;
break;
default:
assert(x != -1u && y != -1u && !v);
assert(x != -1u && y != -1u && !v && !o);
children.e0 = x;
children.e1 = y;
break;
@ -78,9 +96,10 @@ LatPoint::LatPoint(const BitVector &b, unsigned e)
// Lattice methods.
//===----------------------------------------------------------------------===//
unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v,
Operation *op) {
unsigned e = tensorExps.size();
tensorExps.push_back(TensorExp(k, e0, e1, v));
tensorExps.push_back(TensorExp(k, e0, e1, v, op));
return e;
}
@ -97,29 +116,31 @@ unsigned Merger::addSet() {
return s;
}
unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1,
Operation *op) {
unsigned p = latPoints.size();
BitVector nb = BitVector(latPoints[p0].bits);
nb |= latPoints[p1].bits;
unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
latPoints.push_back(LatPoint(nb, e));
return p;
}
unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
unsigned s = addSet();
for (unsigned p0 : latSets[s0])
for (unsigned p1 : latSets[s1])
latSets[s].push_back(conjLatPoint(kind, p0, p1));
latSets[s].push_back(conjLatPoint(kind, p0, p1, op));
return s;
}
unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
unsigned s = takeConj(kind, s0, s1);
unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
unsigned s = takeConj(kind, s0, s1, op);
// Followed by all in s0.
for (unsigned p : latSets[s0])
latSets[s].push_back(p);
// Map binary 0-y to unary -y.
// TODO: move this if-else logic into buildLattices
if (kind == kSubF)
s1 = mapSet(kNegF, s1);
else if (kind == kSubI)
@ -130,11 +151,32 @@ unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
return s;
}
unsigned Merger::mapSet(Kind kind, unsigned s0, Value v) {
assert(kAbsF <= kind && kind <= kBitCast);
unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
bool includeLeft, Kind ltrans, Operation *opleft,
bool includeRight, Kind rtrans, Operation *opright) {
unsigned s = takeConj(kind, s0, s1, orig);
// Left Region.
if (includeLeft) {
if (opleft)
s0 = mapSet(ltrans, s0, Value(), opleft);
for (unsigned p : latSets[s0])
latSets[s].push_back(p);
}
// Right Region.
if (includeRight) {
if (opright)
s1 = mapSet(rtrans, s1, Value(), opright);
for (unsigned p : latSets[s1])
latSets[s].push_back(p);
}
return s;
}
unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
assert(kAbsF <= kind && kind <= kUnary);
unsigned s = addSet();
for (unsigned p : latSets[s0]) {
unsigned e = addExp(kind, latPoints[p].exp, v);
unsigned e = addExp(kind, latPoints[p].exp, v, op);
latPoints.push_back(LatPoint(latPoints[p].bits, e));
latSets[s].push_back(latPoints.size() - 1);
}
@ -304,6 +346,10 @@ static const char *kindToOpSymbol(Kind kind) {
case kTruncI:
case kBitCast:
return "cast";
case kBinaryBranch:
return "binary_branch";
case kUnary:
return "unary";
case kMulF:
return "*";
case kMulI:
@ -334,6 +380,8 @@ static const char *kindToOpSymbol(Kind kind) {
return ">>";
case kShlI:
return "<<";
case kBinary:
return "binary";
}
llvm_unreachable("unexpected kind for symbol");
}
@ -475,6 +523,35 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
// | 0 |-y |
return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
tensorExps[e].val);
case kBinaryBranch:
// The left or right half of a binary operation which has already
// been split into separate operations for each region.
return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
tensorExps[e].op);
case kUnary:
// A custom unary operation.
//
// op y| !y | y |
// ----+----------+------------+
// | absent() | present(y) |
{
unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
UnaryOp unop = cast<UnaryOp>(tensorExps[e].op);
Region &absentRegion = unop.absentRegion();
if (absentRegion.empty()) {
// Simple mapping over existing values.
return mapSet(kind, child0, Value(), unop);
} else {
// Use a disjunction with `unop` on the left and the absent value as an
// invariant on the right.
Block &absentBlock = absentRegion.front();
YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
Value absentVal = absentYield.result();
unsigned rhs = addExp(kInvariant, absentVal);
return takeDisj(kind, child0, buildLattices(rhs, i), unop);
}
}
case kMulF:
case kMulI:
case kAndI:
@ -534,6 +611,37 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
return takeConj(kind, // take binary conjunction
buildLattices(tensorExps[e].children.e0, i),
buildLattices(tensorExps[e].children.e1, i));
case kBinary:
// A custom binary operation.
//
// x op y| !y | y |
// ------+---------+--------------+
// !x | empty | right(y) |
// x | left(x) | overlap(x,y) |
{
unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
unsigned child1 = buildLattices(tensorExps[e].children.e1, i);
BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
Region &leftRegion = binop.leftRegion();
Region &rightRegion = binop.rightRegion();
// Left Region.
Operation *leftYield = nullptr;
if (!leftRegion.empty()) {
Block &leftBlock = leftRegion.front();
leftYield = leftBlock.getTerminator();
}
// Right Region.
Operation *rightYield = nullptr;
if (!rightRegion.empty()) {
Block &rightBlock = rightRegion.front();
rightYield = rightBlock.getTerminator();
}
bool includeLeft = binop.left_identity() || !leftRegion.empty();
bool includeRight = binop.right_identity() || !rightRegion.empty();
return takeCombi(kBinary, child0, child1, binop, includeLeft,
kBinaryBranch, leftYield, includeRight, kBinaryBranch,
rightYield);
}
}
llvm_unreachable("unexpected expression kind");
}
@ -628,6 +736,8 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
return addExp(kTruncI, e, v);
if (isa<arith::BitcastOp>(def))
return addExp(kBitCast, e, v);
if (isa<sparse_tensor::UnaryOp>(def))
return addExp(kUnary, e, Value(), def);
}
}
// Construct binary operations if subexpressions can be built.
@ -669,12 +779,59 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
return addExp(kShrU, e0, e1);
if (isa<arith::ShLIOp>(def) && isInvariant(e1))
return addExp(kShlI, e0, e1);
if (isa<sparse_tensor::BinaryOp>(def))
return addExp(kBinary, e0, e1, Value(), def);
}
}
// Cannot build.
return None;
}
static Value insertYieldOp(PatternRewriter &rewriter, Location loc,
Region &region, ValueRange vals) {
// Make a clone of overlap region.
Region tmpRegion;
BlockAndValueMapping mapper;
region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
Block &clonedBlock = tmpRegion.front();
YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
// Merge cloned block and return yield value.
Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals);
Value val = clonedYield.result();
rewriter.eraseOp(clonedYield);
rewriter.eraseOp(placeholder);
return val;
}
static Value buildUnaryPresent(PatternRewriter &rewriter, Location loc,
Operation *op, Value v0) {
if (!v0)
// Empty input value must be propagated.
return Value();
UnaryOp unop = cast<UnaryOp>(op);
Region &presentRegion = unop.presentRegion();
if (presentRegion.empty())
// Uninitialized Value() will be interpreted as missing data in the
// output.
return Value();
return insertYieldOp(rewriter, loc, presentRegion, {v0});
}
static Value buildBinaryOverlap(PatternRewriter &rewriter, Location loc,
Operation *op, Value v0, Value v1) {
if (!v0 || !v1)
// Empty input values must be propagated.
return Value();
BinaryOp binop = cast<BinaryOp>(op);
Region &overlapRegion = binop.overlapRegion();
if (overlapRegion.empty())
// Uninitialized Value() will be interpreted as missing data in the
// output.
return Value();
return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
}
Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
Value v0, Value v1) {
switch (tensorExps[e].kind) {
@ -750,6 +907,14 @@ Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
case kShlI:
return rewriter.create<arith::ShLIOp>(loc, v0, v1);
// Semiring ops with custom logic.
case kBinaryBranch:
return insertYieldOp(rewriter, loc,
*tensorExps[e].op->getBlock()->getParent(), {v0});
case kUnary:
return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
case kBinary:
return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
}
llvm_unreachable("unexpected expression kind in build");
}

View File

@ -0,0 +1,294 @@
// RUN: mlir-opt %s --sparse-compiler | \
// RUN: mlir-cpu-runner \
// RUN: -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
//
// Traits for tensor operations.
//
#trait_vec_scale = {
indexing_maps = [
affine_map<(i) -> (i)>, // a (in)
affine_map<(i) -> (i)> // x (out)
],
iterator_types = ["parallel"]
}
#trait_vec_op = {
indexing_maps = [
affine_map<(i) -> (i)>, // a (in)
affine_map<(i) -> (i)>, // b (in)
affine_map<(i) -> (i)> // x (out)
],
iterator_types = ["parallel"]
}
#trait_mat_op = {
indexing_maps = [
affine_map<(i,j) -> (i,j)>, // A (in)
affine_map<(i,j) -> (i,j)>, // B (in)
affine_map<(i,j) -> (i,j)> // X (out)
],
iterator_types = ["parallel", "parallel"],
doc = "X(i,j) = A(i,j) OP B(i,j)"
}
module {
// Creates a new sparse vector using the minimum values from two input sparse vectors.
// When there is no overlap, include the present value in the output.
func @vector_min(%arga: tensor<?xf64, #SparseVector>,
%argb: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
%c = arith.constant 0 : index
%d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
%xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
%0 = linalg.generic #trait_vec_op
ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
outs(%xv: tensor<?xf64, #SparseVector>) {
^bb(%a: f64, %b: f64, %x: f64):
%1 = sparse_tensor.binary %a, %b : f64, f64 to f64
overlap={
^bb0(%a0: f64, %b0: f64):
%cmp = arith.cmpf "olt", %a0, %b0 : f64
%2 = arith.select %cmp, %a0, %b0: f64
sparse_tensor.yield %2 : f64
}
left=identity
right=identity
linalg.yield %1 : f64
} -> tensor<?xf64, #SparseVector>
return %0 : tensor<?xf64, #SparseVector>
}
// Creates a new sparse vector by multiplying a sparse vector with a dense vector.
// When there is no overlap, leave the result empty.
func @vector_mul(%arga: tensor<?xf64, #SparseVector>,
%argb: tensor<?xf64>) -> tensor<?xf64, #SparseVector> {
%c = arith.constant 0 : index
%d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
%xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
%0 = linalg.generic #trait_vec_op
ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64>)
outs(%xv: tensor<?xf64, #SparseVector>) {
^bb(%a: f64, %b: f64, %x: f64):
%1 = sparse_tensor.binary %a, %b : f64, f64 to f64
overlap={
^bb0(%a0: f64, %b0: f64):
%ret = arith.mulf %a0, %b0 : f64
sparse_tensor.yield %ret : f64
}
left={}
right={}
linalg.yield %1 : f64
} -> tensor<?xf64, #SparseVector>
return %0 : tensor<?xf64, #SparseVector>
}
// Take a set difference of two sparse vectors. The result will include only those
// sparse elements present in the first, but not the second vector.
func @vector_setdiff(%arga: tensor<?xf64, #SparseVector>,
%argb: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
%c = arith.constant 0 : index
%d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
%xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
%0 = linalg.generic #trait_vec_op
ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
outs(%xv: tensor<?xf64, #SparseVector>) {
^bb(%a: f64, %b: f64, %x: f64):
%1 = sparse_tensor.binary %a, %b : f64, f64 to f64
overlap={}
left=identity
right={}
linalg.yield %1 : f64
} -> tensor<?xf64, #SparseVector>
return %0 : tensor<?xf64, #SparseVector>
}
// Return the index of each entry
func @vector_index(%arga: tensor<?xf64, #SparseVector>) -> tensor<?xi32, #SparseVector> {
%c = arith.constant 0 : index
%d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
%xv = sparse_tensor.init [%d] : tensor<?xi32, #SparseVector>
%0 = linalg.generic #trait_vec_scale
ins(%arga: tensor<?xf64, #SparseVector>)
outs(%xv: tensor<?xi32, #SparseVector>) {
^bb(%a: f64, %x: i32):
%idx = linalg.index 0 : index
%1 = sparse_tensor.binary %a, %idx : f64, index to i32
overlap={
^bb0(%x0: f64, %i: index):
%ret = arith.index_cast %i : index to i32
sparse_tensor.yield %ret : i32
}
left={}
right={}
linalg.yield %1 : i32
} -> tensor<?xi32, #SparseVector>
return %0 : tensor<?xi32, #SparseVector>
}
// Adds two sparse matrices when they intersect. Where they don't intersect,
// negate the 2nd argument's values; ignore 1st argument-only values.
func @matrix_intersect(%arga: tensor<?x?xf64, #DCSR>,
%argb: tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #DCSR>
%d1 = tensor.dim %arga, %c1 : tensor<?x?xf64, #DCSR>
%xv = sparse_tensor.init [%d0, %d1] : tensor<?x?xf64, #DCSR>
%0 = linalg.generic #trait_mat_op
ins(%arga, %argb: tensor<?x?xf64, #DCSR>, tensor<?x?xf64, #DCSR>)
outs(%xv: tensor<?x?xf64, #DCSR>) {
^bb(%a: f64, %b: f64, %x: f64):
%1 = sparse_tensor.binary %a, %b: f64, f64 to f64
overlap={
^bb0(%x0: f64, %y0: f64):
%ret = arith.addf %x0, %y0 : f64
sparse_tensor.yield %ret : f64
}
left={}
right={
^bb0(%x1: f64):
%lret = arith.negf %x1 : f64
sparse_tensor.yield %lret : f64
}
linalg.yield %1 : f64
} -> tensor<?x?xf64, #DCSR>
return %0 : tensor<?x?xf64, #DCSR>
}
// Dumps a sparse vector of type f64.
func @dump_vec(%arg0: tensor<?xf64, #SparseVector>) {
// Dump the values array to verify only sparse contents are stored.
%c0 = arith.constant 0 : index
%d0 = arith.constant -1.0 : f64
%0 = sparse_tensor.values %arg0 : tensor<?xf64, #SparseVector> to memref<?xf64>
%1 = vector.transfer_read %0[%c0], %d0: memref<?xf64>, vector<16xf64>
vector.print %1 : vector<16xf64>
// Dump the dense vector to verify structure is correct.
%dv = sparse_tensor.convert %arg0 : tensor<?xf64, #SparseVector> to tensor<?xf64>
%2 = bufferization.to_memref %dv : memref<?xf64>
%3 = vector.transfer_read %2[%c0], %d0: memref<?xf64>, vector<32xf64>
vector.print %3 : vector<32xf64>
memref.dealloc %2 : memref<?xf64>
return
}
// Dumps a sparse vector of type i32.
func @dump_vec_i32(%arg0: tensor<?xi32, #SparseVector>) {
// Dump the values array to verify only sparse contents are stored.
%c0 = arith.constant 0 : index
%d0 = arith.constant -1 : i32
%0 = sparse_tensor.values %arg0 : tensor<?xi32, #SparseVector> to memref<?xi32>
%1 = vector.transfer_read %0[%c0], %d0: memref<?xi32>, vector<24xi32>
vector.print %1 : vector<24xi32>
// Dump the dense vector to verify structure is correct.
%dv = sparse_tensor.convert %arg0 : tensor<?xi32, #SparseVector> to tensor<?xi32>
%2 = bufferization.to_memref %dv : memref<?xi32>
%3 = vector.transfer_read %2[%c0], %d0: memref<?xi32>, vector<32xi32>
vector.print %3 : vector<32xi32>
memref.dealloc %2 : memref<?xi32>
return
}
// Dump a sparse matrix.
func @dump_mat(%arg0: tensor<?x?xf64, #DCSR>) {
%d0 = arith.constant 0.0 : f64
%c0 = arith.constant 0 : index
%dm = sparse_tensor.convert %arg0 : tensor<?x?xf64, #DCSR> to tensor<?x?xf64>
%0 = bufferization.to_memref %dm : memref<?x?xf64>
%1 = vector.transfer_read %0[%c0, %c0], %d0: memref<?x?xf64>, vector<4x8xf64>
vector.print %1 : vector<4x8xf64>
memref.dealloc %0 : memref<?x?xf64>
return
}
// Driver method to call and verify vector kernels.
func @entry() {
%c0 = arith.constant 0 : index
// Setup sparse vectors.
%v1 = arith.constant sparse<
[ [0], [3], [11], [17], [20], [21], [28], [29], [31] ],
[ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 ]
> : tensor<32xf64>
%v2 = arith.constant sparse<
[ [1], [3], [4], [10], [16], [18], [21], [28], [29], [31] ],
[11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0 ]
> : tensor<32xf64>
%v3 = arith.constant dense<
[0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.,
0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 0., 1.]
> : tensor<32xf64>
%sv1 = sparse_tensor.convert %v1 : tensor<32xf64> to tensor<?xf64, #SparseVector>
%sv2 = sparse_tensor.convert %v2 : tensor<32xf64> to tensor<?xf64, #SparseVector>
%dv3 = tensor.cast %v3 : tensor<32xf64> to tensor<?xf64>
// Setup sparse matrices.
%m1 = arith.constant sparse<
[ [0,0], [0,1], [1,7], [2,2], [2,4], [2,7], [3,0], [3,2], [3,3] ],
[ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 ]
> : tensor<4x8xf64>
%m2 = arith.constant sparse<
[ [0,0], [0,7], [1,0], [1,6], [2,1], [2,7] ],
[6.0, 5.0, 4.0, 3.0, 2.0, 1.0 ]
> : tensor<4x8xf64>
%sm1 = sparse_tensor.convert %m1 : tensor<4x8xf64> to tensor<?x?xf64, #DCSR>
%sm2 = sparse_tensor.convert %m2 : tensor<4x8xf64> to tensor<?x?xf64, #DCSR>
// Call sparse vector kernels.
%0 = call @vector_min(%sv1, %sv2)
: (tensor<?xf64, #SparseVector>,
tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
%1 = call @vector_mul(%sv1, %dv3)
: (tensor<?xf64, #SparseVector>,
tensor<?xf64>) -> tensor<?xf64, #SparseVector>
%2 = call @vector_setdiff(%sv1, %sv2)
: (tensor<?xf64, #SparseVector>,
tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
%3 = call @vector_index(%sv1)
: (tensor<?xf64, #SparseVector>) -> tensor<?xi32, #SparseVector>
// Call sparse matrix kernels.
%5 = call @matrix_intersect(%sm1, %sm2)
: (tensor<?x?xf64, #DCSR>, tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR>
//
// Verify the results.
//
// CHECK: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1 )
// CHECK-NEXT: ( 1, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 4, 0, 0, 5, 6, 0, 0, 0, 0, 0, 0, 7, 8, 0, 9 )
// CHECK-NEXT: ( 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, -1, -1, -1, -1, -1, -1 )
// CHECK-NEXT: ( 0, 11, 0, 12, 13, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 15, 0, 16, 0, 0, 17, 0, 0, 0, 0, 0, 0, 18, 19, 0, 20 )
// CHECK-NEXT: ( 1, 11, 2, 13, 14, 3, 15, 4, 16, 5, 6, 7, 8, 9, -1, -1 )
// CHECK-NEXT: ( 1, 11, 0, 2, 13, 0, 0, 0, 0, 0, 14, 3, 0, 0, 0, 0, 15, 4, 16, 0, 5, 6, 0, 0, 0, 0, 0, 0, 7, 8, 0, 9 )
// CHECK-NEXT: ( 0, 6, 3, 28, 0, 6, 56, 72, 9, -1, -1, -1, -1, -1, -1, -1 )
// CHECK-NEXT: ( 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 28, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 56, 72, 0, 9 )
// CHECK-NEXT: ( 1, 3, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
// CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 4, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
// CHECK-NEXT: ( 0, 3, 11, 17, 20, 21, 28, 29, 31, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
// CHECK-NEXT: ( 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0, 17, 0, 0, 20, 21, 0, 0, 0, 0, 0, 0, 28, 29, 0, 31 )
// CHECK-NEXT: ( ( 7, 0, 0, 0, 0, 0, 0, -5 ), ( -4, 0, 0, 0, 0, 0, -3, 0 ), ( 0, -2, 0, 0, 0, 0, 0, 7 ), ( 0, 0, 0, 0, 0, 0, 0, 0 ) )
//
call @dump_vec(%sv1) : (tensor<?xf64, #SparseVector>) -> ()
call @dump_vec(%sv2) : (tensor<?xf64, #SparseVector>) -> ()
call @dump_vec(%0) : (tensor<?xf64, #SparseVector>) -> ()
call @dump_vec(%1) : (tensor<?xf64, #SparseVector>) -> ()
call @dump_vec(%2) : (tensor<?xf64, #SparseVector>) -> ()
call @dump_vec_i32(%3) : (tensor<?xi32, #SparseVector>) -> ()
call @dump_mat(%5) : (tensor<?x?xf64, #DCSR>) -> ()
// Release the resources.
sparse_tensor.release %sv1 : tensor<?xf64, #SparseVector>
sparse_tensor.release %sv2 : tensor<?xf64, #SparseVector>
sparse_tensor.release %sm1 : tensor<?x?xf64, #DCSR>
sparse_tensor.release %sm2 : tensor<?x?xf64, #DCSR>
sparse_tensor.release %0 : tensor<?xf64, #SparseVector>
sparse_tensor.release %1 : tensor<?xf64, #SparseVector>
sparse_tensor.release %2 : tensor<?xf64, #SparseVector>
sparse_tensor.release %3 : tensor<?xi32, #SparseVector>
sparse_tensor.release %5 : tensor<?x?xf64, #DCSR>
return
}
}

View File

@ -131,7 +131,7 @@ module {
%sm1 = sparse_tensor.convert %m1 : tensor<4x8xf64> to tensor<?x?xf64, #DCSR>
%sm2 = sparse_tensor.convert %m2 : tensor<4x8xf64> to tensor<?x?xf64, #DCSR>
// Call sparse vector kernels.
// Call sparse matrix kernels.
%0 = call @matrix_scale(%sm1)
: (tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR>
%1 = call @matrix_scale_inplace(%sm1)

View File

@ -0,0 +1,205 @@
// RUN: mlir-opt %s --sparse-compiler | \
// RUN: mlir-cpu-runner \
// RUN: -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
//
// Traits for tensor operations.
//
#trait_vec_scale = {
indexing_maps = [
affine_map<(i) -> (i)>, // a (in)
affine_map<(i) -> (i)> // x (out)
],
iterator_types = ["parallel"]
}
#trait_mat_scale = {
indexing_maps = [
affine_map<(i,j) -> (i,j)>, // A (in)
affine_map<(i,j) -> (i,j)> // X (out)
],
iterator_types = ["parallel", "parallel"]
}
module {
// Invert the structure of a sparse vector. Present values become missing.
// Missing values are filled with 1 (i32).
func @vector_complement(%arga: tensor<?xf64, #SparseVector>) -> tensor<?xi32, #SparseVector> {
%c = arith.constant 0 : index
%ci1 = arith.constant 1 : i32
%d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
%xv = sparse_tensor.init [%d] : tensor<?xi32, #SparseVector>
%0 = linalg.generic #trait_vec_scale
ins(%arga: tensor<?xf64, #SparseVector>)
outs(%xv: tensor<?xi32, #SparseVector>) {
^bb(%a: f64, %x: i32):
%1 = sparse_tensor.unary %a : f64 to i32
present={}
absent={
sparse_tensor.yield %ci1 : i32
}
linalg.yield %1 : i32
} -> tensor<?xi32, #SparseVector>
return %0 : tensor<?xi32, #SparseVector>
}
// Negate existing values. Fill missing ones with +1.
func @vector_negation(%arga: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
%c = arith.constant 0 : index
%cf1 = arith.constant 1.0 : f64
%d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
%xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
%0 = linalg.generic #trait_vec_scale
ins(%arga: tensor<?xf64, #SparseVector>)
outs(%xv: tensor<?xf64, #SparseVector>) {
^bb(%a: f64, %x: f64):
%1 = sparse_tensor.unary %a : f64 to f64
present={
^bb0(%x0: f64):
%ret = arith.negf %x0 : f64
sparse_tensor.yield %ret : f64
}
absent={
sparse_tensor.yield %cf1 : f64
}
linalg.yield %1 : f64
} -> tensor<?xf64, #SparseVector>
return %0 : tensor<?xf64, #SparseVector>
}
// Clips values to the range [3, 7].
func @matrix_clip(%argx: tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%cfmin = arith.constant 3.0 : f64
%cfmax = arith.constant 7.0 : f64
%d0 = tensor.dim %argx, %c0 : tensor<?x?xf64, #DCSR>
%d1 = tensor.dim %argx, %c1 : tensor<?x?xf64, #DCSR>
%xv = sparse_tensor.init [%d0, %d1] : tensor<?x?xf64, #DCSR>
%0 = linalg.generic #trait_mat_scale
ins(%argx: tensor<?x?xf64, #DCSR>)
outs(%xv: tensor<?x?xf64, #DCSR>) {
^bb(%a: f64, %x: f64):
%1 = sparse_tensor.unary %a: f64 to f64
present={
^bb0(%x0: f64):
%mincmp = arith.cmpf "ogt", %x0, %cfmin : f64
%x1 = arith.select %mincmp, %x0, %cfmin : f64
%maxcmp = arith.cmpf "olt", %x1, %cfmax : f64
%x2 = arith.select %maxcmp, %x1, %cfmax : f64
sparse_tensor.yield %x2 : f64
}
absent={}
linalg.yield %1 : f64
} -> tensor<?x?xf64, #DCSR>
return %0 : tensor<?x?xf64, #DCSR>
}
// Dumps a sparse vector of type f64.
func @dump_vec_f64(%arg0: tensor<?xf64, #SparseVector>) {
// Dump the values array to verify only sparse contents are stored.
%c0 = arith.constant 0 : index
%d0 = arith.constant -1.0 : f64
%0 = sparse_tensor.values %arg0 : tensor<?xf64, #SparseVector> to memref<?xf64>
%1 = vector.transfer_read %0[%c0], %d0: memref<?xf64>, vector<32xf64>
vector.print %1 : vector<32xf64>
// Dump the dense vector to verify structure is correct.
%dv = sparse_tensor.convert %arg0 : tensor<?xf64, #SparseVector> to tensor<?xf64>
%2 = bufferization.to_memref %dv : memref<?xf64>
%3 = vector.transfer_read %2[%c0], %d0: memref<?xf64>, vector<32xf64>
vector.print %3 : vector<32xf64>
memref.dealloc %2 : memref<?xf64>
return
}
// Dumps a sparse vector of type i32.
func @dump_vec_i32(%arg0: tensor<?xi32, #SparseVector>) {
// Dump the values array to verify only sparse contents are stored.
%c0 = arith.constant 0 : index
%d0 = arith.constant -1 : i32
%0 = sparse_tensor.values %arg0 : tensor<?xi32, #SparseVector> to memref<?xi32>
%1 = vector.transfer_read %0[%c0], %d0: memref<?xi32>, vector<24xi32>
vector.print %1 : vector<24xi32>
// Dump the dense vector to verify structure is correct.
%dv = sparse_tensor.convert %arg0 : tensor<?xi32, #SparseVector> to tensor<?xi32>
%2 = bufferization.to_memref %dv : memref<?xi32>
%3 = vector.transfer_read %2[%c0], %d0: memref<?xi32>, vector<32xi32>
vector.print %3 : vector<32xi32>
memref.dealloc %2 : memref<?xi32>
return
}
// Dump a sparse matrix.
func @dump_mat(%arg0: tensor<?x?xf64, #DCSR>) {
%c0 = arith.constant 0 : index
%d0 = arith.constant -1.0 : f64
%0 = sparse_tensor.values %arg0 : tensor<?x?xf64, #DCSR> to memref<?xf64>
%1 = vector.transfer_read %0[%c0], %d0: memref<?xf64>, vector<16xf64>
vector.print %1 : vector<16xf64>
%dm = sparse_tensor.convert %arg0 : tensor<?x?xf64, #DCSR> to tensor<?x?xf64>
%2 = bufferization.to_memref %dm : memref<?x?xf64>
%3 = vector.transfer_read %2[%c0, %c0], %d0: memref<?x?xf64>, vector<4x8xf64>
vector.print %3 : vector<4x8xf64>
memref.dealloc %2 : memref<?x?xf64>
return
}
// Driver method to call and verify vector kernels.
func @entry() {
%c0 = arith.constant 0 : index
// Setup sparse vectors.
%v1 = arith.constant sparse<
[ [0], [3], [11], [17], [20], [21], [28], [29], [31] ],
[ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 ]
> : tensor<32xf64>
%sv1 = sparse_tensor.convert %v1 : tensor<32xf64> to tensor<?xf64, #SparseVector>
// Setup sparse matrices.
%m1 = arith.constant sparse<
[ [0,0], [0,1], [1,7], [2,2], [2,4], [2,7], [3,0], [3,2], [3,3] ],
[ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 ]
> : tensor<4x8xf64>
%sm1 = sparse_tensor.convert %m1 : tensor<4x8xf64> to tensor<?x?xf64, #DCSR>
// Call sparse vector kernels.
%0 = call @vector_complement(%sv1)
: (tensor<?xf64, #SparseVector>) -> tensor<?xi32, #SparseVector>
%1 = call @vector_negation(%sv1)
: (tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
// Call sparse matrix kernels.
%2 = call @matrix_clip(%sm1)
: (tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR>
//
// Verify the results.
//
// CHECK: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
// CHECK-NEXT: ( 1, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 4, 0, 0, 5, 6, 0, 0, 0, 0, 0, 0, 7, 8, 0, 9 )
// CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1 )
// CHECK-NEXT: ( 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0 )
// CHECK-NEXT: ( -1, 1, 1, -2, 1, 1, 1, 1, 1, 1, 1, -3, 1, 1, 1, 1, 1, -4, 1, 1, -5, -6, 1, 1, 1, 1, 1, 1, -7, -8, 1, -9 )
// CHECK-NEXT: ( -1, 1, 1, -2, 1, 1, 1, 1, 1, 1, 1, -3, 1, 1, 1, 1, 1, -4, 1, 1, -5, -6, 1, 1, 1, 1, 1, 1, -7, -8, 1, -9 )
// CHECK-NEXT: ( 3, 3, 3, 4, 5, 6, 7, 7, 7, -1, -1, -1, -1, -1, -1, -1 )
// CHECK-NEXT: ( ( 3, 3, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 3 ), ( 0, 0, 4, 0, 5, 0, 0, 6 ), ( 7, 0, 7, 7, 0, 0, 0, 0 ) )
//
call @dump_vec_f64(%sv1) : (tensor<?xf64, #SparseVector>) -> ()
call @dump_vec_i32(%0) : (tensor<?xi32, #SparseVector>) -> ()
call @dump_vec_f64(%1) : (tensor<?xf64, #SparseVector>) -> ()
call @dump_mat(%2) : (tensor<?x?xf64, #DCSR>) -> ()
// Release the resources.
sparse_tensor.release %sv1 : tensor<?xf64, #SparseVector>
sparse_tensor.release %sm1 : tensor<?x?xf64, #DCSR>
sparse_tensor.release %0 : tensor<?xi32, #SparseVector>
sparse_tensor.release %1 : tensor<?xf64, #SparseVector>
sparse_tensor.release %2 : tensor<?x?xf64, #DCSR>
return
}
}

View File

@ -125,13 +125,13 @@ module {
// Sum reduces dot product of two sparse vectors.
func.func @vector_dotprod(%arga: tensor<?xf64, #SparseVector>,
%argb: tensor<?xf64, #SparseVector>,
%argx: tensor<f64> {linalg.inplaceable = true}) -> tensor<f64> {
%argx: tensor<f64> {linalg.inplaceable = true}) -> tensor<f64> {
%0 = linalg.generic #trait_dot
ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
outs(%argx: tensor<f64>) {
^bb(%a: f64, %b: f64, %x: f64):
%1 = arith.mulf %a, %b : f64
%2 = arith.addf %x, %1 : f64
%2 = arith.addf %x, %1 : f64
linalg.yield %2 : f64
} -> tensor<f64>
return %0 : tensor<f64>