[mlir][Linalg] Fold/erase self-copy linalg.copy on buffers

Differential Revision: https://reviews.llvm.org/D155203
This commit is contained in:
Nicolas Vasilache 2023-07-13 14:06:59 +02:00
parent 1377179396
commit 39427a4fbb
4 changed files with 46 additions and 7 deletions

View File

@ -9,6 +9,8 @@ metadata: !LinalgOpMetadata
Numeric casting is performed on the input operand, promoting it to the same
data type as the accumulator/output.
defines:
- hasCanonicalizer
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig

View File

@ -224,11 +224,11 @@ parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
if (addOperandSegmentSizes) {
// This is a bit complex because we're trying to be backward compatible with
// operation syntax that mix the inherent attributes and the discardable ones
// in the same dictionary.
// If the properties are used, we append the operand_segment_sizes there directly.
// Otherwise we append it to the discardable attributes dictionary where it is
// handled by the generic Operation::create(...) method.
// operation syntax that mix the inherent attributes and the discardable
// ones in the same dictionary. If the properties are used, we append the
// operand_segment_sizes there directly. Otherwise we append it to the
// discardable attributes dictionary where it is handled by the generic
// Operation::create(...) method.
if (result.propertiesAttr) {
NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
attrs.append("operand_segment_sizes",
@ -539,6 +539,33 @@ private:
} // namespace
//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//
namespace {
struct EraseSelfCopyOnBuffers : OpRewritePattern<CopyOp> {
using OpRewritePattern<CopyOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CopyOp copyOp,
PatternRewriter &rewriter) const override {
if (!copyOp.hasBufferSemantics())
return rewriter.notifyMatchFailure(copyOp,
"does not have buffer semantics");
if (copyOp.getInputs().front() != copyOp.getOutputs().front())
return rewriter.notifyMatchFailure(copyOp, "not a self copy");
rewriter.eraseOp(copyOp);
return success();
}
};
} // namespace
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<EraseSelfCopyOnBuffers>(context);
}
//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//
@ -2114,8 +2141,7 @@ static void createNewOperandWithStaticSizes(
for (unsigned i = 0; i < sourceShape.size(); i++) {
int64_t dimShape = sourceShape[i];
AffineExpr dimExpr = sourceMap.getResult(i);
if (!affineExprToSize.contains(dimExpr) ||
!sourceType.isDynamicDim(i)) {
if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
newShape.push_back(dimShape);
continue;
}

View File

@ -17,6 +17,7 @@ def copy(
Numeric casting is performed on the input operand, promoting it to the same
data type as the accumulator/output.
"""
defines(Canonicalizer)
O[None] = cast(U, I[None])

View File

@ -47,6 +47,16 @@ func.func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tenso
// -----
func.func @dce_self_linalg_copy(%arg0 : memref<?xf32>) {
linalg.copy ins(%arg0: memref<?xf32>) outs(%arg0: memref<?xf32>)
return
}
// CHECK-LABEL: @dce_self_linalg_copy
// CHECK-NOT: copy
// -----
// CHECK-LABEL: func @tensor.cast(
func.func @tensor.cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>)
-> tensor<3x?xf32>