[mlir][sparse] test for linalg tensor semantics (#70254)

This test used to be here, but somehow got lost while linalg rewrote
their interfaces. It is essential to test this on entry of
sparsification, however, since all subsequent analysis simply assumes
tensor types.

Fixes:
https://github.com/llvm/llvm-project/issues/64325
This commit is contained in:
Aart Bik 2023-10-25 14:07:38 -07:00 committed by GitHub
parent 3dbcd733ad
commit 740582fa4c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1939,9 +1939,12 @@ public:
LogicalResult matchAndRewrite(linalg::GenericOp op,
PatternRewriter &rewriter) const override {
// Only accept single output operations without affine index on sparse
// output.
if (op.getNumDpsInits() != 1 || hasNonTrivialAffineOnSparseOut(op))
// Only accept single output operations with pure tensor semantics.
if (op.getNumDpsInits() != 1 || !op.hasTensorSemantics())
return failure();
// Only accept trivial affine indices.
if (hasNonTrivialAffineOnSparseOut(op))
return failure();
// Sets up a code generation environment.
@ -1951,7 +1954,6 @@ public:
// TODO: we should probably always use slice-based codegen whenever
// possible, we can even intermix slice-based and filter-loop based codegen.
bool idxReducBased = options.enableIndexReduction && numFilterLoops != 0;
// If we have indexing map like (d0) -> (0, d0), there might be more
// levels then loops because of the constant index, that means we can not
// use numLoops as the upper bound for ranks of all tensors.
@ -1964,9 +1966,7 @@ public:
maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
}
}
// If we uses slice based algorithm for affine index, we do not need filter
// loop.
// A slice based algorithm for affine indices does not need filter loops.
CodegenEnv env(op, options, numTensors, numLoops,
/*numFilterLoops=*/idxReducBased ? 0 : numFilterLoops,
maxLvlRank);
@ -2006,7 +2006,6 @@ public:
// to resolve cycles by inserting a conversion.
bool isAdmissible = false;
bool hasCycle = true;
// A const list of all masks that we used for iteration graph
// computation. Must be ordered from more strict to less strict.
// Ideally (though might not be guaranteed), the earlier a constraint mask
@ -2030,7 +2029,6 @@ public:
? failure() // TODO: should cycle be resolved differently?
: resolveCycle(env, rewriter); // one last shot
}
if (!isAdmissible)
return failure(); // inadmissible expression, reject