diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 34e6cd08ea3c..dab6e106f951 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -907,6 +907,10 @@ public: virtual Operation *parseGenericOperation(Block *insertBlock, Block::iterator insertPt) = 0; + /// Parse the name of an operation, in the custom form. On success, return a + /// an object of type 'OperationName'. Otherwise, failure is returned. + virtual FailureOr parseCustomOperationName() = 0; + //===--------------------------------------------------------------------===// // Operand Parsing //===--------------------------------------------------------------------===// @@ -918,6 +922,20 @@ public: unsigned number; // Number, e.g. 12 for an operand like %xyz#12 }; + /// Parse different components, viz., use-info of operand(s), successor(s), + /// region(s), attribute(s) and function-type, of the generic form of an + /// operation instance and populate the input operation-state 'result' with + /// those components. If any of the components is explicitly provided, then + /// skip parsing that component. + virtual ParseResult parseGenericOperationAfterOpName( + OperationState &result, + Optional> parsedOperandType = llvm::None, + Optional> parsedSuccessors = llvm::None, + Optional>> parsedRegions = + llvm::None, + Optional> parsedAttributes = llvm::None, + Optional parsedFnType = llvm::None) = 0; + /// Parse a single operand. virtual ParseResult parseOperand(OperandType &result) = 0; diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index a1ba6a8010db..1818e420edd0 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -310,6 +310,20 @@ public: /// Parse an operation instance that is in the generic form. Operation *parseGenericOperation(); + /// Parse different components, viz., use-info of operand(s), successor(s), + /// region(s), attribute(s) and function-type, of the generic form of an + /// operation instance and populate the input operation-state 'result' with + /// those components. If any of the components is explicitly provided, then + /// skip parsing that component. + ParseResult parseGenericOperationAfterOpName( + OperationState &result, + Optional> parsedOperandUseInfo = llvm::None, + Optional> parsedSuccessors = llvm::None, + Optional>> parsedRegions = + llvm::None, + Optional> parsedAttributes = llvm::None, + Optional parsedFnType = llvm::None); + /// Parse an operation instance that is in the generic form and insert it at /// the provided insertion point. Operation *parseGenericOperation(Block *insertBlock, @@ -335,6 +349,10 @@ public: /// resultInfo specifies information about the "%name =" specifiers. Operation *parseCustomOperation(ArrayRef resultIDs); + /// Parse the name of an operation, in the custom form. On success, return a + /// an object of type 'OperationName'. Otherwise, failure is returned. + FailureOr parseCustomOperationName(); + //===--------------------------------------------------------------------===// // Region Parsing //===--------------------------------------------------------------------===// @@ -972,6 +990,105 @@ struct CleanupOpStateRegions { }; } // namespace +ParseResult OperationParser::parseGenericOperationAfterOpName( + OperationState &result, Optional> parsedOperandUseInfo, + Optional> parsedSuccessors, + Optional>> parsedRegions, + Optional> parsedAttributes, + Optional parsedFnType) { + + // Parse the operand list, if not explicitly provided. + SmallVector opInfo; + if (!parsedOperandUseInfo) { + if (parseToken(Token::l_paren, "expected '(' to start operand list") || + parseOptionalSSAUseList(opInfo) || + parseToken(Token::r_paren, "expected ')' to end operand list")) { + return failure(); + } + parsedOperandUseInfo = opInfo; + } + + // Parse the successor list, if not explicitly provided. + if (!parsedSuccessors) { + if (getToken().is(Token::l_square)) { + // Check if the operation is not a known terminator. + if (!result.name.mightHaveTrait()) + return emitError("successors in non-terminator"); + + SmallVector successors; + if (parseSuccessors(successors)) + return failure(); + result.addSuccessors(successors); + } + } else { + result.addSuccessors(*parsedSuccessors); + } + + // Parse the region list, if not explicitly provided. + if (!parsedRegions) { + if (consumeIf(Token::l_paren)) { + do { + // Create temporary regions with the top level region as parent. + result.regions.emplace_back(new Region(topLevelOp)); + if (parseRegion(*result.regions.back(), /*entryArguments=*/{})) + return failure(); + } while (consumeIf(Token::comma)); + if (parseToken(Token::r_paren, "expected ')' to end region list")) + return failure(); + } + } else { + result.addRegions(*parsedRegions); + } + + // Parse the attributes, if not explicitly provided. + if (!parsedAttributes) { + if (getToken().is(Token::l_brace)) { + if (parseAttributeDict(result.attributes)) + return failure(); + } + } else { + result.addAttributes(*parsedAttributes); + } + + // Parse the operation type, if not explicitly provided. + Location typeLoc = result.location; + if (!parsedFnType) { + if (parseToken(Token::colon, "expected ':' followed by operation type")) + return failure(); + + typeLoc = getEncodedSourceLocation(getToken().getLoc()); + auto type = parseType(); + if (!type) + return failure(); + auto fnType = type.dyn_cast(); + if (!fnType) + return mlir::emitError(typeLoc, "expected function type"); + + parsedFnType = fnType; + } + + result.addTypes(parsedFnType->getResults()); + + // Check that we have the right number of types for the operands. + ArrayRef operandTypes = parsedFnType->getInputs(); + if (operandTypes.size() != parsedOperandUseInfo->size()) { + auto plural = "s"[parsedOperandUseInfo->size() == 1]; + return mlir::emitError(typeLoc, "expected ") + << parsedOperandUseInfo->size() << " operand type" << plural + << " but had " << operandTypes.size(); + } + + // Resolve all of the operands. + for (unsigned i = 0, e = parsedOperandUseInfo->size(); i != e; ++i) { + result.operands.push_back( + resolveSSAUse((*parsedOperandUseInfo)[i], operandTypes[i])); + if (!result.operands.back()) + return failure(); + } + + return success(); +} + Operation *OperationParser::parseGenericOperation() { // Get location information for the operation. auto srcLocation = getEncodedSourceLocation(getToken().getLoc()); @@ -985,6 +1102,7 @@ Operation *OperationParser::parseGenericOperation() { consumeToken(Token::string); OperationState result(srcLocation, name); + CleanupOpStateRegions guard{result}; // Lazy load dialects in the context as needed. if (!result.name.isRegistered()) { @@ -1005,73 +1123,8 @@ Operation *OperationParser::parseGenericOperation() { if (state.asmState) state.asmState->startOperationDefinition(result.name); - // Parse the operand list. - SmallVector operandInfos; - if (parseToken(Token::l_paren, "expected '(' to start operand list") || - parseOptionalSSAUseList(operandInfos) || - parseToken(Token::r_paren, "expected ')' to end operand list")) { + if (parseGenericOperationAfterOpName(result)) return nullptr; - } - - // Parse the successor list. - if (getToken().is(Token::l_square)) { - // Check if the operation is not a known terminator. - if (!result.name.mightHaveTrait()) - return emitError("successors in non-terminator"), nullptr; - - SmallVector successors; - if (parseSuccessors(successors)) - return nullptr; - result.addSuccessors(successors); - } - - // Parse the region list. - CleanupOpStateRegions guard{result}; - if (consumeIf(Token::l_paren)) { - do { - // Create temporary regions with the top level region as parent. - result.regions.emplace_back(new Region(topLevelOp)); - if (parseRegion(*result.regions.back(), /*entryArguments=*/{})) - return nullptr; - } while (consumeIf(Token::comma)); - if (parseToken(Token::r_paren, "expected ')' to end region list")) - return nullptr; - } - - if (getToken().is(Token::l_brace)) { - if (parseAttributeDict(result.attributes)) - return nullptr; - } - - if (parseToken(Token::colon, "expected ':' followed by operation type")) - return nullptr; - - auto typeLoc = getToken().getLoc(); - auto type = parseType(); - if (!type) - return nullptr; - auto fnType = type.dyn_cast(); - if (!fnType) - return (emitError(typeLoc, "expected function type"), nullptr); - - result.addTypes(fnType.getResults()); - - // Check that we have the right number of types for the operands. - auto operandTypes = fnType.getInputs(); - if (operandTypes.size() != operandInfos.size()) { - auto plural = "s"[operandInfos.size() == 1]; - return (emitError(typeLoc, "expected ") - << operandInfos.size() << " operand type" << plural - << " but had " << operandTypes.size(), - nullptr); - } - - // Resolve all of the operands. - for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) { - result.operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i])); - if (!result.operands.back()) - return nullptr; - } // Create the operation and try to parse a location for it. Operation *op = opBuilder.createOperation(result); @@ -1133,6 +1186,37 @@ public: return parser.parseGenericOperation(insertBlock, insertPt); } + FailureOr parseCustomOperationName() final { + return parser.parseCustomOperationName(); + } + + ParseResult parseGenericOperationAfterOpName( + OperationState &result, + Optional> parsedOperandTypes, + Optional> parsedSuccessors, + Optional>> parsedRegions, + Optional> parsedAttributes, + Optional parsedFnType) final { + + // TODO: The types, OperandType and SSAUseInfo, both share the same members + // but in different order. It would be cleaner to make one alias of the + // other, making the following code redundant. + SmallVector parsedOperandUseInfo; + if (parsedOperandTypes) { + for (const OperandType &parsedOperandType : *parsedOperandTypes) + parsedOperandUseInfo.push_back({ + parsedOperandType.name, + parsedOperandType.number, + parsedOperandType.location, + }); + } + + return parser.parseGenericOperationAfterOpName( + result, + parsedOperandTypes ? llvm::makeArrayRef(parsedOperandUseInfo) + : llvm::None, + parsedSuccessors, parsedRegions, parsedAttributes, parsedFnType); + } //===--------------------------------------------------------------------===// // Utilities //===--------------------------------------------------------------------===// @@ -1506,10 +1590,13 @@ private: }; } // end anonymous namespace. -Operation * -OperationParser::parseCustomOperation(ArrayRef resultIDs) { - llvm::SMLoc opLoc = getToken().getLoc(); +FailureOr OperationParser::parseCustomOperationName() { std::string opName = getTokenSpelling().str(); + if (opName.empty()) + return (emitError("empty operation name is invalid"), failure()); + + consumeToken(); + Optional opInfo = RegisteredOperationName::lookup(opName, getContext()); StringRef defaultDialect = getState().defaultDialectStack.back(); @@ -1543,13 +1630,28 @@ OperationParser::parseCustomOperation(ArrayRef resultIDs) { } } + return OperationName(opName, getContext()); +} + +Operation * +OperationParser::parseCustomOperation(ArrayRef resultIDs) { + llvm::SMLoc opLoc = getToken().getLoc(); + + FailureOr opNameInfo = parseCustomOperationName(); + if (failed(opNameInfo)) + return nullptr; + + StringRef opName = opNameInfo->getStringRef(); + Dialect *dialect = opNameInfo->getDialect(); + Optional opInfo = opNameInfo->getRegisteredInfo(); + // This is the actual hook for the custom op parsing, usually implemented by // the op itself (`Op::parse()`). We retrieve it either from the // RegisteredOperationName or from the Dialect. function_ref parseAssemblyFn; bool isIsolatedFromAbove = false; - defaultDialect = ""; + StringRef defaultDialect = ""; if (opInfo) { parseAssemblyFn = opInfo->getParseAssemblyFn(); isIsolatedFromAbove = opInfo->hasTrait(); @@ -1570,16 +1672,14 @@ OperationParser::parseCustomOperation(ArrayRef resultIDs) { auto restoreDefaultDialect = llvm::make_scope_exit( [&]() { getState().defaultDialectStack.pop_back(); }); - consumeToken(); - // If the custom op parser crashes, produce some indication to help // debugging. llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'", - opName.c_str()); + opNameInfo->getIdentifier().data()); // Get location information for the operation. auto srcLocation = getEncodedSourceLocation(opLoc); - OperationState opState(srcLocation, opName); + OperationState opState(srcLocation, *opNameInfo); // If we are populating the parser state, start a new operation definition. if (state.asmState) diff --git a/mlir/test/IR/pretty_printed_region_op.mlir b/mlir/test/IR/pretty_printed_region_op.mlir new file mode 100644 index 000000000000..c12b26de4cc6 --- /dev/null +++ b/mlir/test/IR/pretty_printed_region_op.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file %s | FileCheck %s --check-prefixes=CHECK-CUSTOM,CHECK +// RUN: mlir-opt -allow-unregistered-dialect -mlir-print-op-generic -split-input-file %s | FileCheck %s --check-prefixes=CHECK,CHECK-GENERIC + +// ----- + +func @pretty_printed_region_op(%arg0 : f32, %arg1 : f32) -> (f32) { +// CHECK-CUSTOM: test.pretty_printed_region %arg1, %arg0 start special.op end : (f32, f32) -> f32 +// CHECK-GENERIC: "test.pretty_printed_region"(%arg1, %arg0) +// CHECK-GENERIC: ^bb0(%arg[[x:[0-9]+]]: f32, %arg[[y:[0-9]+]]: f32 +// CHECK-GENERIC: %[[RES:.*]] = "special.op"(%arg[[x]], %arg[[y]]) : (f32, f32) -> f32 +// CHECK-GENERIC: "test.return"(%[[RES]]) : (f32) -> () +// CHECK-GENERIC: : (f32, f32) -> f32 + + %res = test.pretty_printed_region %arg1, %arg0 start special.op end : (f32, f32) -> (f32) loc("some_NameLoc") + return %res : f32 +} + +// ----- + +func @pretty_printed_region_op(%arg0 : f32, %arg1 : f32) -> (f32) { +// CHECK-CUSTOM: test.pretty_printed_region %arg1, %arg0 +// CHECK-GENERIC: "test.pretty_printed_region"(%arg1, %arg0) +// CHECK: ^bb0(%arg[[x:[0-9]+]]: f32, %arg[[y:[0-9]+]]: f32): +// CHECK: %[[RES:.*]] = "non.special.op"(%arg[[x]], %arg[[y]]) : (f32, f32) -> f32 +// CHECK: "test.return"(%[[RES]]) : (f32) -> () +// CHECK: : (f32, f32) -> f32 + + %0 = test.pretty_printed_region %arg1, %arg0 ( { + ^bb0(%arg2: f32, %arg3: f32): + %1 = "non.special.op"(%arg2, %arg3) : (f32, f32) -> f32 + "test.return"(%1) : (f32) -> () + }) : (f32, f32) -> f32 + return %0 : f32 +} + diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index e045bb91ab18..cb09ea45795b 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -720,6 +720,107 @@ static void print(OpAsmPrinter &p, WrappingRegionOp op) { p.printGenericOp(&op.getRegion().front().front()); } +//===----------------------------------------------------------------------===// +// Test PrettyPrintedRegionOp - exercising the following parser APIs +// parseGenericOperationAfterOpName +// parseCustomOperationName +//===----------------------------------------------------------------------===// + +static ParseResult parsePrettyPrintedRegionOp(OpAsmParser &parser, + OperationState &result) { + + llvm::SMLoc loc = parser.getCurrentLocation(); + Location currLocation = parser.getEncodedSourceLoc(loc); + + // Parse the operands. + SmallVector operands; + if (parser.parseOperandList(operands)) + return failure(); + + // Check if we are parsing the pretty-printed version + // test.pretty_printed_region start end : + // Else fallback to parsing the "non pretty-printed" version. + if (!succeeded(parser.parseOptionalKeyword("start"))) + return parser.parseGenericOperationAfterOpName( + result, llvm::makeArrayRef(operands)); + + FailureOr parseOpNameInfo = parser.parseCustomOperationName(); + if (failed(parseOpNameInfo)) + return failure(); + + StringRef innerOpName = parseOpNameInfo->getStringRef(); + + FunctionType opFntype; + Optional explicitLoc; + if (parser.parseKeyword("end") || parser.parseColon() || + parser.parseType(opFntype) || + parser.parseOptionalLocationSpecifier(explicitLoc)) + return failure(); + + // If location of the op is explicitly provided, then use it; Else use + // the parser's current location. + Location opLoc = explicitLoc.getValueOr(currLocation); + + // Derive the SSA-values for op's operands. + if (parser.resolveOperands(operands, opFntype.getInputs(), loc, + result.operands)) + return failure(); + + // Add a region for op. + Region ®ion = *result.addRegion(); + + // Create a basic-block inside op's region. + Block &block = region.emplaceBlock(); + + // Create and insert an "inner-op" operation in the block. + // Just for testing purposes, we can assume that inner op is a binary op with + // result and operand types all same as the test-op's first operand. + Type innerOpType = opFntype.getInput(0); + Value lhs = block.addArgument(innerOpType, opLoc); + Value rhs = block.addArgument(innerOpType, opLoc); + + OpBuilder builder(parser.getBuilder().getContext()); + builder.setInsertionPointToStart(&block); + + OperationState innerOpState(opLoc, innerOpName); + innerOpState.operands.push_back(lhs); + innerOpState.operands.push_back(rhs); + innerOpState.addTypes(innerOpType); + + Operation *innerOp = builder.createOperation(innerOpState); + + // Insert a return statement in the block returning the inner-op's result. + builder.create(innerOp->getLoc(), innerOp->getResults()); + + // Populate the op operation-state with result-type and location. + result.addTypes(opFntype.getResults()); + result.location = innerOp->getLoc(); + + return success(); +} + +static void print(OpAsmPrinter &p, PrettyPrintedRegionOp op) { + p << ' '; + p.printOperands(op.getOperands()); + + Operation &innerOp = op.getRegion().front().front(); + // Assuming that region has a single non-terminator inner-op, if the inner-op + // meets some criteria (which in this case is a simple one based on the name + // of inner-op), then we can print the entire region in a succinct way. + // Here we assume that the prototype of "special.op" can be trivially derived + // while parsing it back. + if (innerOp.getName().getStringRef().equals("special.op")) { + p << " start special.op end"; + } else { + p << " ("; + p.printRegion(op.getRegion()); + p << ")"; + } + + p << " : "; + p.printFunctionalType(op); +} + //===----------------------------------------------------------------------===// // Test PolyForOp - parse list of region arguments. //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 6a06596838db..8ac9049ee486 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1630,6 +1630,25 @@ def WrappingRegionOp : TEST_Op<"wrapping_region", let printer = [{ return ::print(p, *this); }]; } +def PrettyPrintedRegionOp : TEST_Op<"pretty_printed_region", + [SingleBlockImplicitTerminator<"TestReturnOp">]> { + let summary = "pretty_printed_region operation"; + let description = [{ + Test-op can be printed either in a "pretty" or "non-pretty" way based on + some criteria. The custom parser parsers both the versions while testing + APIs: parseCustomOperationName & parseGenericOperationAfterOpName. + }]; + let arguments = (ins + AnyType:$input1, + AnyType:$input2 + ); + + let results = (outs AnyType); + let regions = (region SizedRegion<1>:$region); + let parser = [{ return ::parse$cppClass(parser, result); }]; + let printer = [{ return ::print(p, *this); }]; +} + def PolyForOp : TEST_Op<"polyfor"> { let summary = "polyfor operation";