[mlir][tensor] Add shape inference methods to tensor::PackOp.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D143686
This commit is contained in:
Hanhan Wang 2023-02-09 17:24:26 -08:00
parent 02718433a0
commit f71de259c3
2 changed files with 80 additions and 8 deletions

View File

@ -1772,6 +1772,14 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
];
let extraClassDeclaration = commonExtraClassDeclaration # [{
// Method to get the shape of the result as `SmallVector<OpFoldResult>`.
// This is a static method to allow getting the shape of the destination
// expected while creating a `pack` op.
static SmallVector<OpFoldResult> getResultShape(OpBuilder &builder,
Location loc, ArrayRef<OpFoldResult> sourceDims,
ArrayRef<OpFoldResult> innerTileDims, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});
// Method to get the `ShapedType` of the result based on the inner tiles,
// position of the inner tiles (innerDimsPos) and interchange vector of
// outer loops (outerDimsPerm).

View File

@ -3479,14 +3479,29 @@ LogicalResult PackOp::verify() {
return success();
}
/// Get the expected packed type based on source type, tile factors, position of
/// the inner tiles and permutation of the outer tiled loop.
ShapedType PackOp::inferPackedType(ShapedType sourceType,
ArrayRef<int64_t> innerTileSizes,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm) {
SmallVector<int64_t> resultShape = llvm::to_vector(sourceType.getShape());
for (const auto &tiledDim : llvm::enumerate(innerDimsPos)) {
/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
/// Value's to kDynamic, even if they are arith.constant values.
static SmallVector<int64_t>
asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
SmallVector<int64_t> result;
for (auto o : ofrs) {
// Have to do this first, as getConstantIntValue special-cases constants.
if (o.dyn_cast<Value>())
result.push_back(ShapedType::kDynamic);
else
result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
}
return result;
}
/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
/// the packed type. Having a shared helper helps implement these two methods in
/// a way that ensures that they agree on which dimensions are dynamic.
static SmallVector<int64_t> getPackOpResultTypeShape(
ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
for (auto tiledDim : llvm::enumerate(innerDimsPos)) {
if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
continue;
if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
@ -3497,11 +3512,60 @@ ShapedType PackOp::inferPackedType(ShapedType sourceType,
innerTileSizes[tiledDim.index()]);
}
// Swap tile loops if outer_dims_perm is available.
if (!outerDimsPerm.empty())
applyPermutationToVector(resultShape, outerDimsPerm);
// Append the inner tile dimensions.
resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
return resultShape;
}
SmallVector<OpFoldResult> PackOp::getResultShape(
OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm) {
SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
AffineExpr s0, s1;
bindSymbols(builder.getContext(), s0, s1);
AffineExpr ceilDivExpr = s0.ceilDiv(s1);
for (auto tiledDim : llvm::enumerate(innerDimsPos)) {
resultDims[tiledDim.value()] = makeComposedFoldedAffineApply(
builder, loc, ceilDivExpr,
{resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
}
if (!outerDimsPerm.empty())
applyPermutationToVector(resultDims, outerDimsPerm);
resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
SmallVector<int64_t> resultTypeShape =
getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
asShapeWithAnyValueAsDynamic(innerTileSizes),
innerDimsPos, outerDimsPerm);
// Fix-up `resultDims` to ensure that they are Value's if and only if the
// result type shape says it's a dynamic dim. This is needed as callers may
// use dispatchIndexOpFoldResults on the result, and rely on exact number of
// dynamic dims returned by that.
for (unsigned i = 0; i < resultDims.size(); ++i) {
if (!ShapedType::isDynamic(resultTypeShape[i]))
continue;
resultDims[i] =
getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
}
return resultDims;
}
/// Get the expected packed type based on source type, tile factors, position of
/// the inner tiles and permutation of the outer tiled loop.
ShapedType PackOp::inferPackedType(ShapedType sourceType,
ArrayRef<int64_t> innerTileSizes,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm) {
SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
return RankedTensorType::get(resultShape, sourceType.getElementType());
}