[linalg][fusion] Disallow fusion when it would create an invalid expand_shape

The input type of a linalg.generic can be less dynamic than its output
type. If this is the case, moving a reshape across the generic op would
create invalid IR, as expand_shape cannot expand arbitrary dynamic
dimensions.
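
For illustration, a minimal sketch (not part of this patch; %a, %ok, and %bad
are placeholder names): for an expanding reshape, a dynamic source dimension
must expand into exactly one dynamic result dimension, so the second op below
is rejected by the tensor.expand_shape verifier.

  // OK: the dynamic dim expands into exactly one dynamic dim.
  %ok = tensor.expand_shape %a [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
  // Invalid: a single dimension expanded into multiple dynamic dims.
  %bad = tensor.expand_shape %a [[0, 1]] : tensor<?xf32> into tensor<?x?xf32>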

Check that the reshape is actually valid before creating the
expand_shape. This exposes the existing verification logic in reshape
utils and removes the incomplete custom implementation in fusion.
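
Concretely, in the test added below the first input is collapsed from
tensor<1x1xi64> to tensor<1xi64>, so fusing by expansion would have to
re-expand the second input, which is only known as tensor<?xi64>. Roughly (a
sketch of the IR fusion would have tried to create, with %e and %arg1 as
placeholder names), that op is invalid because dimension 0 of the collapsed
type would have to be the static value 1, so the pattern now bails out before
building it:

  %e = tensor.expand_shape %arg1 [[0, 1]] : tensor<?xi64> into tensor<1x1xi64>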

Differential Revision: https://reviews.llvm.org/D116600
Benjamin Kramer 2022-01-04 16:52:44 +01:00
parent f100bedb03
commit ff5de8a9e0
4 changed files with 98 additions and 66 deletions

@@ -166,47 +166,19 @@ static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
/// 2) if a dimension in the collapsed type is dynamic, one and only one of the
/// corresponding dimensions in the expanded type should be dynamic. This
/// rule is only needed with reshape operations that are expanding.
LogicalResult reshapeLikeShapesAreCompatible(
function_ref<LogicalResult(const Twine &)> emitError,
ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape);
template <typename OpTy>
static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
ShapedType expandedType,
bool isExpandingReshape) {
ArrayRef<int64_t> collapsedShape = collapsedType.getShape();
ArrayRef<int64_t> expandedShape = expandedType.getShape();
unsigned expandedDimStart = 0;
for (auto map : llvm::enumerate(op.getReassociationMaps())) {
Optional<int64_t> dynamicShape;
int64_t linearizedStaticShape = 1;
for (auto dim : llvm::enumerate(expandedShape.slice(
expandedDimStart, map.value().getNumResults()))) {
if (ShapedType::isDynamic(dim.value())) {
if (isExpandingReshape && dynamicShape) {
return op->emitOpError("invalid to have a single dimension (")
<< map.index() << ") expanded into multiple dynamic dims ("
<< expandedDimStart + dynamicShape.getValue() << ","
<< expandedDimStart + dim.index() << ")";
}
dynamicShape = dim.index();
} else {
linearizedStaticShape *= dim.value();
}
}
if (dynamicShape) {
if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
return op->emitOpError("expected dimension ")
<< map.index()
<< " of collapsed type to be dynamic since one or more of the "
"corresponding dimensions in the expanded type is dynamic";
}
} else {
if (collapsedShape[map.index()] != linearizedStaticShape) {
return op->emitOpError("expected dimension ")
<< map.index() << " of collapsed type to be static value of "
<< linearizedStaticShape << " ";
}
}
expandedDimStart += map.value().getNumResults();
}
return success();
return reshapeLikeShapesAreCompatible(
[&](const Twine &msg) { return op->emitOpError(msg); },
collapsedType.getShape(), expandedType.getShape(),
op.getReassociationIndices(), isExpandingReshape);
}
/// Pattern to collapse producer/consumer reshape ops that are both collapsing

@@ -608,31 +608,6 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
LogicalResult isGenericOpExpandable(GenericOp genericOp,
const ExpansionInfo &expansionInfo,
PatternRewriter &rewriter) {
// Current reshape only supports expansion of a dynamic dim when only one of
// the expanded dims are dynamic.
for (const auto &originalShape :
llvm::enumerate(expansionInfo.getOriginalShape()))
if (ShapedType::isDynamic(originalShape.value())) {
// All but one of the expanded dims must be static.
bool foundDynamicExpandedDim = false;
for (auto expandedShape :
expansionInfo.getExpandedShapeOfDim(originalShape.index())) {
if (ShapedType::isDynamic(expandedShape)) {
if (foundDynamicExpandedDim) {
return rewriter.notifyMatchFailure(
genericOp,
"cannot expanded dynamic dims into multiple dynamic dims");
}
foundDynamicExpandedDim = true;
}
}
if (!foundDynamicExpandedDim) {
return rewriter.notifyMatchFailure(
genericOp, "dynamic dim expansion needs at least one dynamic dim "
"in result shape");
}
}
if (!genericOp.hasIndexSemantics())
return success();
for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
@@ -793,13 +768,21 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
}
if (genericOp.isInputTensor(opOperand)) {
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
RankedTensorType expandedOperandType =
getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
indexingMap, expansionInfo);
getExpandedType(opOperandType, indexingMap, expansionInfo);
if (expandedOperandType != opOperand->get().getType()) {
// Reshape the operand to get the right type.
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(indexingMap, expansionInfo);
if (failed(reshapeLikeShapesAreCompatible(
[&](const Twine &msg) {
return rewriter.notifyMatchFailure(genericOp, msg);
},
opOperandType.getShape(), expandedOperandType.getShape(),
reassociation,
/*isExpandingReshape=*/true)))
return llvm::None;
expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
genericOp.getLoc(), expandedOperandType, opOperand->get(),
reassociation));
@@ -813,12 +796,20 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
SmallVector<Value> outputs;
for (OpOperand *opOperand : genericOp.getOutputOperands()) {
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
RankedTensorType expandedOutputType =
getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
indexingMap, expansionInfo);
getExpandedType(opOperandType, indexingMap, expansionInfo);
if (expandedOutputType != opOperand->get().getType()) {
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(indexingMap, expansionInfo);
if (failed(reshapeLikeShapesAreCompatible(
[&](const Twine &msg) {
return rewriter.notifyMatchFailure(genericOp, msg);
},
opOperandType.getShape(), expandedOutputType.getShape(),
reassociation,
/*isExpandingReshape=*/true)))
return llvm::None;
outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
genericOp.getLoc(), expandedOutputType, opOperand->get(),
reassociation));

@@ -276,3 +276,45 @@ bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
}
return true;
}
LogicalResult mlir::reshapeLikeShapesAreCompatible(
function_ref<LogicalResult(const Twine &)> emitError,
ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) {
unsigned expandedDimStart = 0;
for (const auto &map : llvm::enumerate(reassociationMaps)) {
Optional<int64_t> dynamicShape;
int64_t linearizedStaticShape = 1;
for (const auto &dim : llvm::enumerate(
expandedShape.slice(expandedDimStart, map.value().size()))) {
if (ShapedType::isDynamic(dim.value())) {
if (isExpandingReshape && dynamicShape) {
return emitError("invalid to have a single dimension (" +
Twine(map.index()) +
") expanded into multiple dynamic dims (" +
Twine(expandedDimStart + dynamicShape.getValue()) +
"," + Twine(expandedDimStart + dim.index()) + ")");
}
dynamicShape = dim.index();
} else {
linearizedStaticShape *= dim.value();
}
}
if (dynamicShape) {
if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
return emitError(
"expected dimension " + Twine(map.index()) +
" of collapsed type to be dynamic since one or more of the "
"corresponding dimensions in the expanded type is dynamic");
}
} else {
if (collapsedShape[map.index()] != linearizedStaticShape) {
return emitError("expected dimension " + Twine(map.index()) +
" of collapsed type to be static value of " +
Twine(linearizedStaticShape));
}
}
expandedDimStart += map.value().size();
}
return success();
}

@@ -530,3 +530,30 @@ func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?xf32>)
// CHECK: return %[[GENERIC]]
// -----
func @no_fuse_mismatched_dynamism(%arg0: tensor<1x1xi64>, %arg1: tensor<?xi64>) -> tensor<1xi64> {
%0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<1x1xi64> into tensor<1xi64>
%1 = linalg.init_tensor [1] : tensor<1xi64>
%2 = linalg.generic
{indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
ins(%0, %arg1 : tensor<1xi64>, tensor<?xi64>)
outs(%1 : tensor<1xi64>) {
^bb0(%arg4: i64, %arg5: i64, %arg6: i64): // no predecessors
%3 = arith.addi %arg4, %arg5 : i64
linalg.yield %3 : i64
} -> tensor<1xi64>
return %2 : tensor<1xi64>
}
// CHECK: func @no_fuse_mismatched_dynamism
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1xi64>
// CHECK-SAME: %[[ARG1:.+]]: tensor<?xi64>
// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[RESHAPE]], %[[ARG1]] : tensor<1xi64>, tensor<?xi64>)
// CHECK: return %[[GENERIC]]