Add canonicalization to remove AllocOps if there are no uses. AllocOp has side effects on the heap, but can still be deleted if it has zero uses.

PiperOrigin-RevId: 229596556
This commit is contained in:
River Riddle 2019-01-16 11:40:37 -08:00 committed by jpienaar
parent a5827fc91d
commit ada685f352
2 changed files with 31 additions and 5 deletions

View File

@ -302,11 +302,30 @@ struct SimplifyAllocConst : public RewritePattern {
rewriter.replaceOp(op, {resultCast}, droppedOperands);
}
};
/// Fold alloc instructions with no uses. Alloc has side effects on the heap,
/// but can still be deleted if it has zero uses.
struct SimplifyDeadAlloc : public RewritePattern {
SimplifyDeadAlloc(MLIRContext *context)
: RewritePattern(AllocOp::getOperationName(), 1, context) {}
PatternMatchResult match(OperationInst *op) const override {
auto alloc = op->cast<AllocOp>();
// Check if the alloc'ed value has no uses.
return alloc->use_empty() ? matchSuccess() : matchFailure();
}
void rewrite(OperationInst *op, PatternRewriter &rewriter) const override {
// Erase the alloc operation.
op->erase();
}
};
} // end anonymous namespace.
void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.push_back(std::make_unique<SimplifyAllocConst>(context));
results.push_back(std::make_unique<SimplifyDeadAlloc>(context));
}
//===----------------------------------------------------------------------===//

View File

@ -166,9 +166,16 @@ func @alloc_const_fold() -> memref<?xf32> {
return %a : memref<?xf32>
}
// CHECK-LABEL: func @dead_alloc_fold
func @dead_alloc_fold() {
// CHECK-NEXT: return
%c4 = constant 4 : index
%a = alloc(%c4) : memref<?xf32>
return
}
// CHECK-LABEL: func @dyn_shape_fold(%arg0: index, %arg1: index)
func @dyn_shape_fold(%L : index, %M : index) -> memref<? x ? x f32> {
func @dyn_shape_fold(%L : index, %M : index) -> (memref<? x ? x i32>, memref<? x ? x f32>) {
// CHECK: %c0 = constant 0 : index
%zero = constant 0 : index
// The constants below disappear after they propagate into shapes.
@ -189,17 +196,17 @@ func @dyn_shape_fold(%L : index, %M : index) -> memref<? x ? x f32> {
for %i = 0 to %L {
// CHECK-NEXT: for %i1 =
for %j = 0 to 10 {
// CHECK-NEXT: %3 = load %0[%i0, %i1] : memref<?x1024xf32>
// CHECK-NEXT: store %3, %1[%c0, %c0, %i0, %i1, %c0] : memref<4x1024x8x512x?xf32>
// CHECK-NEXT: %4 = load %0[%i0, %i1] : memref<?x1024xf32>
// CHECK-NEXT: store %4, %1[%c0, %c0, %i0, %i1, %c0] : memref<4x1024x8x512x?xf32>
%v = load %a[%i, %j] : memref<?x?xf32>
store %v, %b[%zero, %zero, %i, %j, %zero] : memref<4x?x8x?x?xf32>
}
}
// CHECK: %4 = alloc() : memref<9x9xf32>
// CHECK: %5 = alloc() : memref<9x9xf32>
%d = alloc(%nine, %nine) : memref<? x ? x f32>
return %d : memref<? x ? x f32>
return %c, %d : memref<? x ? x i32>, memref<? x ? x f32>
}
// CHECK-LABEL: func @merge_constants