[mlir][RegionBranchOpInterface] explicitly check for existance of block terminator (#76831)

This commit is contained in:
Maksim Levental 2024-01-04 14:43:52 -06:00 committed by GitHub
parent 58f1640635
commit a0c19bd455
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 80 additions and 4 deletions

View File

@ -177,9 +177,10 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
SmallVector<RegionBranchTerminatorOpInterface> regionReturnOps;
for (Block &block : region)
if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
block.getTerminator()))
regionReturnOps.push_back(terminator);
if (!block.empty())
if (auto terminator =
dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
regionReturnOps.push_back(terminator);
// If there is no return-like terminator, the op itself should verify
// type consistency.

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s
// RUN: mlir-opt %s -split-input-file
func.func @test_ops_verify(%arg: i32) -> f32 {
%0 = "test.constant"() { value = 5.3 : f32 } : () -> f32
@ -8,3 +8,16 @@ func.func @test_ops_verify(%arg: i32) -> f32 {
}
return %1 : f32
}
// -----
func.func @test_no_terminator(%arg: index) {
test.switch_with_no_break %arg
case 0 {
^bb:
}
case 1 {
^bb:
}
return
}

View File

@ -53,6 +53,7 @@ using namespace test;
Attribute MyPropStruct::asAttribute(MLIRContext *ctx) const {
return StringAttr::get(ctx, content);
}
LogicalResult
MyPropStruct::setFromAttr(MyPropStruct &prop, Attribute attr,
function_ref<InFlightDiagnostic()> emitError) {
@ -64,6 +65,7 @@ MyPropStruct::setFromAttr(MyPropStruct &prop, Attribute attr,
prop.content = strAttr.getValue();
return success();
}
llvm::hash_code MyPropStruct::hash() const {
return hash_value(StringRef(content));
}
@ -127,6 +129,12 @@ static void customPrintProperties(OpAsmPrinter &p,
const VersionedProperties &prop);
static ParseResult customParseProperties(OpAsmParser &parser,
VersionedProperties &prop);
static ParseResult
parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
SmallVectorImpl<std::unique_ptr<Region>> &caseRegions);
static void printSwitchCases(OpAsmPrinter &p, Operation *op,
DenseI64ArrayAttr cases, RegionRange caseRegions);
void test::registerTestDialect(DialectRegistry &registry) {
registry.insert<TestDialect>();
@ -230,6 +238,7 @@ void TestDialect::initialize() {
// unregistered op.
fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
TestDialect::~TestDialect() {
delete static_cast<TestOpEffectInterfaceFallback *>(
fallbackEffectOpInterfaces);
@ -1013,6 +1022,13 @@ LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
return getNextIterArgMutable();
}
//===----------------------------------------------------------------------===//
// SwitchWithNoBreakOp
//===----------------------------------------------------------------------===//
void TestNoTerminatorOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {}
//===----------------------------------------------------------------------===//
// SingleNoTerminatorCustomAsmOp
//===----------------------------------------------------------------------===//
@ -1160,6 +1176,7 @@ setPropertiesFromAttribute(PropertiesWithCustomPrint &prop, Attribute attr,
prop.value = valueAttr.getValue().getSExtValue();
return success();
}
static DictionaryAttr
getPropertiesAsAttribute(MLIRContext *ctx,
const PropertiesWithCustomPrint &prop) {
@ -1169,14 +1186,17 @@ getPropertiesAsAttribute(MLIRContext *ctx,
attrs.push_back(b.getNamedAttr("value", b.getI32IntegerAttr(prop.value)));
return b.getDictionaryAttr(attrs);
}
static llvm::hash_code computeHash(const PropertiesWithCustomPrint &prop) {
return llvm::hash_combine(prop.value, StringRef(*prop.label));
}
static void customPrintProperties(OpAsmPrinter &p,
const PropertiesWithCustomPrint &prop) {
p.printKeywordOrString(*prop.label);
p << " is " << prop.value;
}
static ParseResult customParseProperties(OpAsmParser &parser,
PropertiesWithCustomPrint &prop) {
std::string label;
@ -1186,6 +1206,31 @@ static ParseResult customParseProperties(OpAsmParser &parser,
prop.label = std::make_shared<std::string>(std::move(label));
return success();
}
static ParseResult
parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
SmallVector<int64_t> caseValues;
while (succeeded(p.parseOptionalKeyword("case"))) {
int64_t value;
Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
return failure();
caseValues.push_back(value);
}
cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
return success();
}
static void printSwitchCases(OpAsmPrinter &p, Operation *op,
DenseI64ArrayAttr cases, RegionRange caseRegions) {
for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
p.printNewline();
p << "case " << value << ' ';
p.printRegion(*region, /*printEntryBlockArgs=*/false);
}
}
static LogicalResult
setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr,
function_ref<InFlightDiagnostic()> emitError) {
@ -1209,6 +1254,7 @@ setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr,
prop.value2 = value2Attr.getValue().getSExtValue();
return success();
}
static DictionaryAttr
getPropertiesAsAttribute(MLIRContext *ctx, const VersionedProperties &prop) {
SmallVector<NamedAttribute> attrs;
@ -1217,13 +1263,16 @@ getPropertiesAsAttribute(MLIRContext *ctx, const VersionedProperties &prop) {
attrs.push_back(b.getNamedAttr("value2", b.getI32IntegerAttr(prop.value2)));
return b.getDictionaryAttr(attrs);
}
static llvm::hash_code computeHash(const VersionedProperties &prop) {
return llvm::hash_combine(prop.value1, prop.value2);
}
static void customPrintProperties(OpAsmPrinter &p,
const VersionedProperties &prop) {
p << prop.value1 << " | " << prop.value2;
}
static ParseResult customParseProperties(OpAsmParser &parser,
VersionedProperties &prop) {
if (parser.parseInteger(prop.value1) || parser.parseVerticalBar() ||
@ -1393,6 +1442,7 @@ void TestVersionedOpA::writeProperties(::mlir::DialectBytecodeWriter &writer) {
prop.value2 = value2;
return success();
}
void TestOpWithVersionedProperties::writeToMlirBytecode(
::mlir::DialectBytecodeWriter &writer,
const test::VersionedProperties &prop) {

View File

@ -2213,6 +2213,18 @@ def LoopBlockTerminatorOp : TEST_Op<"loop_block_term",
}];
}
def TestNoTerminatorOp : TEST_Op<"switch_with_no_break", [
NoTerminator,
DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorRegions"]>
]> {
let arguments = (ins Index:$arg, DenseI64ArrayAttr:$cases);
let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions);
let assemblyFormat = [{
$arg attr-dict custom<SwitchCases>($cases, $caseRegions)
}];
}
//===----------------------------------------------------------------------===//
// Test TableGen generated build() methods
//===----------------------------------------------------------------------===//