mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-04-01 12:43:47 +00:00
[linalg][fusion] Disallow fusion when it would create an invalid expand_shape
The input type of a linalg.generic can be less dynamic than its output type. If this is the case moving a reshape across the generic op would create invalid IR, as expand_shape cannot expand arbitrary dynamic dimensions. Check that the reshape is actually valid before creating the expand_shape. This exposes the existing verification logic in reshape utils and removes the incomplete custom implementation in fusion. Differential Revision: https://reviews.llvm.org/D116600
This commit is contained in:
parent
f100bedb03
commit
ff5de8a9e0
@ -166,47 +166,19 @@ static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
|
||||
/// 2) if a dimension in the collaped type is dynamic, one and only one of the
|
||||
/// corresponding dimensions in the expanded type should be dynamic. This
|
||||
/// rule is only needed with reshape operations that are expanding.
|
||||
LogicalResult reshapeLikeShapesAreCompatible(
|
||||
function_ref<LogicalResult(const Twine &)> emitError,
|
||||
ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
|
||||
ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape);
|
||||
|
||||
template <typename OpTy>
|
||||
static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
|
||||
ShapedType expandedType,
|
||||
bool isExpandingReshape) {
|
||||
ArrayRef<int64_t> collapsedShape = collapsedType.getShape();
|
||||
ArrayRef<int64_t> expandedShape = expandedType.getShape();
|
||||
unsigned expandedDimStart = 0;
|
||||
for (auto map : llvm::enumerate(op.getReassociationMaps())) {
|
||||
Optional<int64_t> dynamicShape;
|
||||
int64_t linearizedStaticShape = 1;
|
||||
for (auto dim : llvm::enumerate(expandedShape.slice(
|
||||
expandedDimStart, map.value().getNumResults()))) {
|
||||
if (ShapedType::isDynamic(dim.value())) {
|
||||
if (isExpandingReshape && dynamicShape) {
|
||||
return op->emitOpError("invalid to have a single dimension (")
|
||||
<< map.index() << ") expanded into multiple dynamic dims ("
|
||||
<< expandedDimStart + dynamicShape.getValue() << ","
|
||||
<< expandedDimStart + dim.index() << ")";
|
||||
}
|
||||
dynamicShape = dim.index();
|
||||
} else {
|
||||
linearizedStaticShape *= dim.value();
|
||||
}
|
||||
}
|
||||
if (dynamicShape) {
|
||||
if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
|
||||
return op->emitOpError("expected dimension ")
|
||||
<< map.index()
|
||||
<< " of collapsed type to be dynamic since one or more of the "
|
||||
"corresponding dimensions in the expanded type is dynamic";
|
||||
}
|
||||
} else {
|
||||
if (collapsedShape[map.index()] != linearizedStaticShape) {
|
||||
return op->emitOpError("expected dimension ")
|
||||
<< map.index() << " of collapsed type to be static value of "
|
||||
<< linearizedStaticShape << " ";
|
||||
}
|
||||
}
|
||||
expandedDimStart += map.value().getNumResults();
|
||||
}
|
||||
return success();
|
||||
return reshapeLikeShapesAreCompatible(
|
||||
[&](const Twine &msg) { return op->emitOpError(msg); },
|
||||
collapsedType.getShape(), expandedType.getShape(),
|
||||
op.getReassociationIndices(), isExpandingReshape);
|
||||
}
|
||||
|
||||
/// Pattern to collapse producer/consumer reshape ops that are both collapsing
|
||||
|
@ -608,31 +608,6 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
|
||||
LogicalResult isGenericOpExpandable(GenericOp genericOp,
|
||||
const ExpansionInfo &expansionInfo,
|
||||
PatternRewriter &rewriter) {
|
||||
// Current reshape only supports expansion of a dynamic dim when only one of
|
||||
// the expanded dims are dynamic.
|
||||
for (const auto &originalShape :
|
||||
llvm::enumerate(expansionInfo.getOriginalShape()))
|
||||
if (ShapedType::isDynamic(originalShape.value())) {
|
||||
// All but one of the expanded dims must be static.
|
||||
bool foundDynamicExpandedDim = false;
|
||||
for (auto expandedShape :
|
||||
expansionInfo.getExpandedShapeOfDim(originalShape.index())) {
|
||||
if (ShapedType::isDynamic(expandedShape)) {
|
||||
if (foundDynamicExpandedDim) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
genericOp,
|
||||
"cannot expanded dynamic dims into multiple dynamic dims");
|
||||
}
|
||||
foundDynamicExpandedDim = true;
|
||||
}
|
||||
}
|
||||
if (!foundDynamicExpandedDim) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
genericOp, "dynamic dim expansion needs at least one dynamic dim "
|
||||
"in result shape");
|
||||
}
|
||||
}
|
||||
|
||||
if (!genericOp.hasIndexSemantics())
|
||||
return success();
|
||||
for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
|
||||
@ -793,13 +768,21 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
|
||||
}
|
||||
if (genericOp.isInputTensor(opOperand)) {
|
||||
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
|
||||
auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
|
||||
RankedTensorType expandedOperandType =
|
||||
getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
|
||||
indexingMap, expansionInfo);
|
||||
getExpandedType(opOperandType, indexingMap, expansionInfo);
|
||||
if (expandedOperandType != opOperand->get().getType()) {
|
||||
// Reshape the operand to get the right type.
|
||||
SmallVector<ReassociationIndices> reassociation =
|
||||
getReassociationForExpansion(indexingMap, expansionInfo);
|
||||
if (failed(reshapeLikeShapesAreCompatible(
|
||||
[&](const Twine &msg) {
|
||||
return rewriter.notifyMatchFailure(genericOp, msg);
|
||||
},
|
||||
opOperandType.getShape(), expandedOperandType.getShape(),
|
||||
reassociation,
|
||||
/*isExpandingReshape=*/true)))
|
||||
return llvm::None;
|
||||
expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
|
||||
genericOp.getLoc(), expandedOperandType, opOperand->get(),
|
||||
reassociation));
|
||||
@ -813,12 +796,20 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
|
||||
SmallVector<Value> outputs;
|
||||
for (OpOperand *opOperand : genericOp.getOutputOperands()) {
|
||||
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
|
||||
auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
|
||||
RankedTensorType expandedOutputType =
|
||||
getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
|
||||
indexingMap, expansionInfo);
|
||||
getExpandedType(opOperandType, indexingMap, expansionInfo);
|
||||
if (expandedOutputType != opOperand->get().getType()) {
|
||||
SmallVector<ReassociationIndices> reassociation =
|
||||
getReassociationForExpansion(indexingMap, expansionInfo);
|
||||
if (failed(reshapeLikeShapesAreCompatible(
|
||||
[&](const Twine &msg) {
|
||||
return rewriter.notifyMatchFailure(genericOp, msg);
|
||||
},
|
||||
opOperandType.getShape(), expandedOutputType.getShape(),
|
||||
reassociation,
|
||||
/*isExpandingReshape=*/true)))
|
||||
return llvm::None;
|
||||
outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
|
||||
genericOp.getLoc(), expandedOutputType, opOperand->get(),
|
||||
reassociation));
|
||||
|
@ -276,3 +276,45 @@ bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
LogicalResult mlir::reshapeLikeShapesAreCompatible(
|
||||
function_ref<LogicalResult(const Twine &)> emitError,
|
||||
ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
|
||||
ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) {
|
||||
unsigned expandedDimStart = 0;
|
||||
for (const auto &map : llvm::enumerate(reassociationMaps)) {
|
||||
Optional<int64_t> dynamicShape;
|
||||
int64_t linearizedStaticShape = 1;
|
||||
for (const auto &dim : llvm::enumerate(
|
||||
expandedShape.slice(expandedDimStart, map.value().size()))) {
|
||||
if (ShapedType::isDynamic(dim.value())) {
|
||||
if (isExpandingReshape && dynamicShape) {
|
||||
return emitError("invalid to have a single dimension (" +
|
||||
Twine(map.index()) +
|
||||
") expanded into multiple dynamic dims (" +
|
||||
Twine(expandedDimStart + dynamicShape.getValue()) +
|
||||
"," + Twine(expandedDimStart + dim.index()) + ")");
|
||||
}
|
||||
dynamicShape = dim.index();
|
||||
} else {
|
||||
linearizedStaticShape *= dim.value();
|
||||
}
|
||||
}
|
||||
if (dynamicShape) {
|
||||
if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
|
||||
return emitError(
|
||||
"expected dimension " + Twine(map.index()) +
|
||||
" of collapsed type to be dynamic since one or more of the "
|
||||
"corresponding dimensions in the expanded type is dynamic");
|
||||
}
|
||||
} else {
|
||||
if (collapsedShape[map.index()] != linearizedStaticShape) {
|
||||
return emitError("expected dimension " + Twine(map.index()) +
|
||||
" of collapsed type to be static value of " +
|
||||
Twine(linearizedStaticShape));
|
||||
}
|
||||
}
|
||||
expandedDimStart += map.value().size();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
@ -530,3 +530,30 @@ func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: %[[GENERIC:.+]] = linalg.generic
|
||||
// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?xf32>)
|
||||
// CHECK: return %[[GENERIC]]
|
||||
|
||||
// -----
|
||||
|
||||
func @no_fuse_mismatched_dynamism(%arg0: tensor<1x1xi64>, %arg1: tensor<?xi64>) -> tensor<1xi64> {
|
||||
%0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<1x1xi64> into tensor<1xi64>
|
||||
%1 = linalg.init_tensor [1] : tensor<1xi64>
|
||||
%2 = linalg.generic
|
||||
{indexing_maps = [affine_map<(d0) -> (d0)>,
|
||||
affine_map<(d0) -> (d0)>,
|
||||
affine_map<(d0) -> (d0)>],
|
||||
iterator_types = ["parallel"]}
|
||||
ins(%0, %arg1 : tensor<1xi64>, tensor<?xi64>)
|
||||
outs(%1 : tensor<1xi64>) {
|
||||
^bb0(%arg4: i64, %arg5: i64, %arg6: i64): // no predecessors
|
||||
%3 = arith.addi %arg4, %arg5 : i64
|
||||
linalg.yield %3 : i64
|
||||
} -> tensor<1xi64>
|
||||
return %2 : tensor<1xi64>
|
||||
}
|
||||
|
||||
// CHECK: func @no_fuse_mismatched_dynamism
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1xi64>
|
||||
// CHECK-SAME: %[[ARG1:.+]]: tensor<?xi64>
|
||||
// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
|
||||
// CHECK: %[[GENERIC:.+]] = linalg.generic
|
||||
// CHECK-SAME: ins(%[[RESHAPE]], %[[ARG1]] : tensor<1xi64>, tensor<?xi64>)
|
||||
// CHECK: return %[[GENERIC]]
|
||||
|
Loading…
x
Reference in New Issue
Block a user