[mlir] Fix Regions takeBody method if the region is not empty

The current implementation of takeBody first clears the Region, before then taking ownership of the blocks of the other regions. The issue here however, is that when clearing the region, it does not take into account references of operations to each other. In particular, blocks are deleted from front to back, and operations within a block are very likely to be deleted despite still having uses, causing an assertion to trigger [0].

This patch fixes that issue by simply calling dropAllReferences()before clearing the blocks.

[0] 9a8bb4bc63/mlir/lib/IR/Operation.cpp (L154)

Differential Revision: https://reviews.llvm.org/D123913
This commit is contained in:
Markus Böck 2022-04-21 15:32:21 +02:00
parent 95d77383f2
commit 850b2c6b3c
5 changed files with 72 additions and 0 deletions

View File

@ -240,6 +240,7 @@ public:
/// Takes body of another region (that region will have no body after this /// Takes body of another region (that region will have no body after this
/// operation completes). The current body of this region is cleared. /// operation completes). The current body of this region is cleared.
void takeBody(Region &other) { void takeBody(Region &other) {
dropAllReferences();
blocks.clear(); blocks.clear();
blocks.splice(blocks.end(), other.getBlocks()); blocks.splice(blocks.end(), other.getBlocks());
} }

View File

@ -0,0 +1,23 @@
// RUN: mlir-opt -allow-unregistered-dialect %s --test-take-body -split-input-file
func @foo() {
%0 = "test.foo"() : () -> i32
cf.br ^header
^header:
cf.br ^body
^body:
"test.use"(%0) : (i32) -> ()
cf.br ^header
}
func private @bar() {
return
}
// CHECK-LABEL: func @foo
// CHECK-NEXT: return
// CHECK-LABEL: func private @bar()
// CHECK-NOT: {

View File

@ -15,6 +15,7 @@ add_mlir_library(MLIRTestIR
TestSideEffects.cpp TestSideEffects.cpp
TestSlicing.cpp TestSlicing.cpp
TestSymbolUses.cpp TestSymbolUses.cpp
TestRegions.cpp
TestTypes.cpp TestTypes.cpp
TestVisitors.cpp TestVisitors.cpp
TestVisitorsGeneric.cpp TestVisitorsGeneric.cpp

View File

@ -0,0 +1,45 @@
//===- TestRegions.cpp - Pass to test Region's methods --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
namespace {
/// This is a test pass that tests Region's takeBody method by making the first
/// function take the body of the second.
struct TakeBodyPass
: public PassWrapper<TakeBodyPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TakeBodyPass)
StringRef getArgument() const final { return "test-take-body"; }
StringRef getDescription() const final { return "Test Region's takeBody"; }
void runOnOperation() override {
auto module = getOperation();
SmallVector<func::FuncOp> functions =
llvm::to_vector(module.getOps<func::FuncOp>());
if (functions.size() != 2) {
module.emitError("Expected only two functions in test");
signalPassFailure();
return;
}
functions[0].getBody().takeBody(functions[1].getBody());
}
};
} // namespace
namespace mlir {
void registerRegionTestPasses() { PassRegistration<TakeBodyPass>(); }
} // namespace mlir

View File

@ -37,6 +37,7 @@ void registerShapeFunctionTestPasses();
void registerSideEffectTestPasses(); void registerSideEffectTestPasses();
void registerSliceAnalysisTestPass(); void registerSliceAnalysisTestPass();
void registerSymbolTestPasses(); void registerSymbolTestPasses();
void registerRegionTestPasses();
void registerTestAffineDataCopyPass(); void registerTestAffineDataCopyPass();
void registerTestAffineLoopUnswitchingPass(); void registerTestAffineLoopUnswitchingPass();
void registerTestAllReduceLoweringPass(); void registerTestAllReduceLoweringPass();
@ -128,6 +129,7 @@ void registerTestPasses() {
registerSideEffectTestPasses(); registerSideEffectTestPasses();
registerSliceAnalysisTestPass(); registerSliceAnalysisTestPass();
registerSymbolTestPasses(); registerSymbolTestPasses();
registerRegionTestPasses();
registerTestAffineDataCopyPass(); registerTestAffineDataCopyPass();
registerTestAffineLoopUnswitchingPass(); registerTestAffineLoopUnswitchingPass();
registerTestAllReduceLoweringPass(); registerTestAllReduceLoweringPass();