[mlir][linalg] Allow TC ops to take an unused shaped operand.

If an operand is not used in the formula, it is considered a shape-only
operand, and the results of its indexing map are the reduction dimensions of
the first expression.
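
For illustration, here is a minimal standalone sketch (mirroring the new test9
case added below; the main() wrapper and dump() calls are only for
demonstration) of the indexing maps expected when B: f32(K) is left unused in
C(m) = std_addf<k>(C(m), A(m, k)):

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

int main() {
  MLIRContext context;
  AffineExpr d0 = getAffineDimExpr(0, &context); // parallel dim m
  AffineExpr d1 = getAffineDimExpr(1, &context); // reduction dim k

  // A(m, k) is used in the expression:     (d0, d1)[s0, s1] -> (d0, d1)
  AffineMap mapA = AffineMap::get(2, 2, {d0, d1}, &context);
  // B is unused, i.e. shape-only, so its results are the reduction dims:
  //                                         (d0, d1)[s0, s1] -> (d1)
  AffineMap mapB = AffineMap::get(2, 2, {d1}, &context);
  // C(m) is the output:                     (d0, d1)[s0, s1] -> (d0)
  AffineMap mapC = AffineMap::get(2, 2, {d0}, &context);

  mapA.dump();
  mapB.dump();
  mapC.dump();
  return 0;
}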

Depends On D97383

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D97384
Hanhan Wang 2021-02-26 06:45:08 -08:00
parent 4941fef9c4
commit 855a119604
3 changed files with 45 additions and 17 deletions


@@ -582,8 +582,9 @@ better adapt to Linalg:
resorting to more general MLIR parsing.
1. Reduction dimensions are specified with angle bracket notation on the
operation they apply to (e.g. `std_add<k>` specifies that `k` is a reduction
dimension). In TC, a reduction is specified with `op=` operator and the
reduction dimensions are inferred.
dimension). In TC, the reduction dimensions are inferred. If one of the
operands is not used in any expression, it is considered a shape-only
operand, and the results of its indexing_map are the reduction dimensions.
1. The parallel and reduction dimensions are ordered by the textual program
order. For instance, in the comprehension `O(i, j) = std_add<k, l>(...)`,
`i` (resp. `j`) is a parallel iterator encoded by affine dimension of


@@ -190,3 +190,14 @@ def test8(A: f32(M, K), B: f32(K)) -> (C: f32(M))
{
C(m) = std_subf<k>(std_mulf(A(m, k), B(k)), C(m));
}
// Test shape-only operand.
// IMPL-LABEL: ArrayAttr Test9Op::indexing_maps() {
// IMPL: auto map0 = AffineMap::get(2, 2, {d0, d1}, context);
// IMPL: auto map1 = AffineMap::get(2, 2, {d1}, context);
// IMPL: auto map2 = AffineMap::get(2, 2, {d0}, context);
ods_def<Test9Op>:
def test9(A: f32(M, K), B: f32(K)) -> (C: f32(M))
{
C(m) = std_addf<k>(C(m), A(m, k));
}


@@ -1634,7 +1634,26 @@ TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
tensor.indexingMap = use.indexingMap;
state.orderedTensorArgs[use] = tensor.index;
});
state.numArgs = seenDefs.size();
// If not all operands are used in the comprehension, the unused ones are
// shape-only operands, which are used to define reduction loops. For now,
// only accept exactly one shape-only operand.
if (state.numArgs > seenDefs.size() + 1) {
failed = true;
} else if (state.numArgs == seenDefs.size() + 1) {
for (auto &tensorIter : registeredTensors) {
auto &tensor = tensorIter.getValue();
if (tensor.indexingMap)
continue;
if (auto *pTensorExpr =
dyn_cast<TensorExpr>(state.expressions[0].get())) {
SmallVector<AffineExpr, 4> exprs;
for (auto dim : pTensorExpr->reductionDimensions)
exprs.push_back(getAffineDimExpr(dim, parser.context));
tensor.indexingMap = AffineMap::get(state.dims.size(), symbols.size(),
exprs, parser.context);
}
}
}
if (failed)
return failure();
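
Taken in isolation, the fallback above amounts to roughly the following
standalone sketch (the helper name buildShapeOnlyMap and the main() driver are
illustrative only, not part of the parser):

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Build the indexing map for the single shape-only operand from the reduction
// dimensions of the first expression in the comprehension.
static AffineMap buildShapeOnlyMap(ArrayRef<unsigned> reductionDims,
                                   unsigned numDims, unsigned numSymbols,
                                   MLIRContext *context) {
  SmallVector<AffineExpr, 4> exprs;
  for (unsigned dim : reductionDims)
    exprs.push_back(getAffineDimExpr(dim, context));
  // Same dim/symbol counts as the other operands; only the reduction
  // dimensions appear as results.
  return AffineMap::get(numDims, numSymbols, exprs, context);
}

int main() {
  MLIRContext context;
  // For test9: 2 loop dims (m = d0, k = d1), 2 symbols (M, K), and k is the
  // only reduction dimension, so the unused operand B gets (d0, d1) -> (d1).
  buildShapeOnlyMap(/*reductionDims=*/{1}, /*numDims=*/2, /*numSymbols=*/2,
                    &context)
      .dump();
  return 0;
}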
@@ -1762,6 +1781,7 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
SmallVector<ComprehensionParsingState, 4> perComprehensionStates;
while (parser.curToken.isNot(Token::Kind::r_brace)) {
perComprehensionStates.push_back(ComprehensionParsingState());
perComprehensionStates.back().numArgs = registeredTensors.size();
if (failed(parseOneComprehension(cppOpName, tcName,
perComprehensionStates.back())))
return failure();
@@ -2207,10 +2227,6 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
std::string mapsStr;
llvm::raw_string_ostream mapsStringStream(mapsStr);
SmallVector<TensorUse, 4> orderedUses(state.numArgs);
for (const auto &it : state.orderedTensorArgs)
orderedUses[it.second] = it.first;
// Create a list of all symbols.
SmallVector<std::string, 4> symbolReplacements;
symbolReplacements.reserve(symbols.size());
@@ -2242,10 +2258,11 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
symbolReplacements[position] = llvm::formatv("cst{0}", attrUse.index());
}
// For each tensor use, construct the affine map, replace symbols by the
// corresponding attribute values, and simplify the affine map.
for (auto tensorUse : llvm::enumerate(orderedUses)) {
auto indexingMap = tensorUse.value().indexingMap;
// For each registered tensor, construct the affine map, replace symbols by
// the corresponding attribute values, and simplify the affine map.
for (auto &tensorIter : registeredTensors) {
auto &tensor = tensorIter.getValue();
auto indexingMap = tensor.indexingMap;
const char *mapFmt =
"\n\tauto map{0} = AffineMap::get({1}, {2}, {3}, context);";
@@ -2255,8 +2272,7 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
llvm::interleaveComma(indexingMap.getResults(), exprsStringStream);
exprsStringStream << "}";
exprsStringStream.flush();
mapsStringStream << llvm::formatv(mapFmt, tensorUse.index(),
state.dims.size(),
mapsStringStream << llvm::formatv(mapFmt, tensor.index, state.dims.size(),
indexingMap.getNumSymbols(), exprsStr);
std::string replaceSymbolList =
@@ -2269,17 +2285,17 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
// need that.
const char *replaceFmt =
"\n\tmap{0} = map{0}.replaceDimsAndSymbols({{}, {1}, {2}, 0);";
mapsStringStream << llvm::formatv(replaceFmt, tensorUse.index(),
mapsStringStream << llvm::formatv(replaceFmt, tensor.index,
replaceSymbolList, state.dims.size());
const char *simplifyFmt = "\n\tmap{0} = simplifyAffineMap(map{0});";
mapsStringStream << llvm::formatv(simplifyFmt, tensorUse.index());
mapsStringStream << llvm::formatv(simplifyFmt, tensor.index);
}
mapsStringStream.flush();
SmallVector<std::string, 4> mapList;
mapList.reserve(orderedUses.size());
for (unsigned i = 0; i < orderedUses.size(); ++i)
mapList.reserve(state.numArgs);
for (auto i : llvm::seq<unsigned>(0, state.numArgs))
mapList.push_back(llvm::formatv("map{0}", i));
// 4. Apply format to 1. using 2. and 3.