[mlir] Refactoring a few Parser APIs

Refactored two new parser APIs parseGenericOperationAfterOperands and
 parseCustomOperationName out of parseGenericOperation and parseCustomOperation.

Motivation: Sometimes an op can be printed in a special way if certain criteria
is met. While parsing, we need to handle all the versions.
`parseGenericOperationAfterOperands` is handy in situation where we already
parsed the operands and decide to fall back to default parsing.

`parseCustomOperationName` is useful when we need to know details (dialect,
operation name etc.) about a parsed token meant to be an mlir operation.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D113719
This commit is contained in:
Sandeep Dasgupta 2021-11-23 06:05:41 +00:00 committed by Mehdi Amini
parent d5b73a70a0
commit e5a8c8c883
5 changed files with 347 additions and 74 deletions

View File

@ -907,6 +907,10 @@ public:
virtual Operation *parseGenericOperation(Block *insertBlock,
Block::iterator insertPt) = 0;
/// Parse the name of an operation, in the custom form. On success, return a
/// an object of type 'OperationName'. Otherwise, failure is returned.
virtual FailureOr<OperationName> parseCustomOperationName() = 0;
//===--------------------------------------------------------------------===//
// Operand Parsing
//===--------------------------------------------------------------------===//
@ -918,6 +922,20 @@ public:
unsigned number; // Number, e.g. 12 for an operand like %xyz#12
};
/// Parse different components, viz., use-info of operand(s), successor(s),
/// region(s), attribute(s) and function-type, of the generic form of an
/// operation instance and populate the input operation-state 'result' with
/// those components. If any of the components is explicitly provided, then
/// skip parsing that component.
virtual ParseResult parseGenericOperationAfterOpName(
OperationState &result,
Optional<ArrayRef<OperandType>> parsedOperandType = llvm::None,
Optional<ArrayRef<Block *>> parsedSuccessors = llvm::None,
Optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions =
llvm::None,
Optional<ArrayRef<NamedAttribute>> parsedAttributes = llvm::None,
Optional<FunctionType> parsedFnType = llvm::None) = 0;
/// Parse a single operand.
virtual ParseResult parseOperand(OperandType &result) = 0;

View File

@ -310,6 +310,20 @@ public:
/// Parse an operation instance that is in the generic form.
Operation *parseGenericOperation();
/// Parse different components, viz., use-info of operand(s), successor(s),
/// region(s), attribute(s) and function-type, of the generic form of an
/// operation instance and populate the input operation-state 'result' with
/// those components. If any of the components is explicitly provided, then
/// skip parsing that component.
ParseResult parseGenericOperationAfterOpName(
OperationState &result,
Optional<ArrayRef<SSAUseInfo>> parsedOperandUseInfo = llvm::None,
Optional<ArrayRef<Block *>> parsedSuccessors = llvm::None,
Optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions =
llvm::None,
Optional<ArrayRef<NamedAttribute>> parsedAttributes = llvm::None,
Optional<FunctionType> parsedFnType = llvm::None);
/// Parse an operation instance that is in the generic form and insert it at
/// the provided insertion point.
Operation *parseGenericOperation(Block *insertBlock,
@ -335,6 +349,10 @@ public:
/// resultInfo specifies information about the "%name =" specifiers.
Operation *parseCustomOperation(ArrayRef<ResultRecord> resultIDs);
/// Parse the name of an operation, in the custom form. On success, return a
/// an object of type 'OperationName'. Otherwise, failure is returned.
FailureOr<OperationName> parseCustomOperationName();
//===--------------------------------------------------------------------===//
// Region Parsing
//===--------------------------------------------------------------------===//
@ -972,6 +990,105 @@ struct CleanupOpStateRegions {
};
} // namespace
ParseResult OperationParser::parseGenericOperationAfterOpName(
OperationState &result, Optional<ArrayRef<SSAUseInfo>> parsedOperandUseInfo,
Optional<ArrayRef<Block *>> parsedSuccessors,
Optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions,
Optional<ArrayRef<NamedAttribute>> parsedAttributes,
Optional<FunctionType> parsedFnType) {
// Parse the operand list, if not explicitly provided.
SmallVector<SSAUseInfo, 8> opInfo;
if (!parsedOperandUseInfo) {
if (parseToken(Token::l_paren, "expected '(' to start operand list") ||
parseOptionalSSAUseList(opInfo) ||
parseToken(Token::r_paren, "expected ')' to end operand list")) {
return failure();
}
parsedOperandUseInfo = opInfo;
}
// Parse the successor list, if not explicitly provided.
if (!parsedSuccessors) {
if (getToken().is(Token::l_square)) {
// Check if the operation is not a known terminator.
if (!result.name.mightHaveTrait<OpTrait::IsTerminator>())
return emitError("successors in non-terminator");
SmallVector<Block *, 2> successors;
if (parseSuccessors(successors))
return failure();
result.addSuccessors(successors);
}
} else {
result.addSuccessors(*parsedSuccessors);
}
// Parse the region list, if not explicitly provided.
if (!parsedRegions) {
if (consumeIf(Token::l_paren)) {
do {
// Create temporary regions with the top level region as parent.
result.regions.emplace_back(new Region(topLevelOp));
if (parseRegion(*result.regions.back(), /*entryArguments=*/{}))
return failure();
} while (consumeIf(Token::comma));
if (parseToken(Token::r_paren, "expected ')' to end region list"))
return failure();
}
} else {
result.addRegions(*parsedRegions);
}
// Parse the attributes, if not explicitly provided.
if (!parsedAttributes) {
if (getToken().is(Token::l_brace)) {
if (parseAttributeDict(result.attributes))
return failure();
}
} else {
result.addAttributes(*parsedAttributes);
}
// Parse the operation type, if not explicitly provided.
Location typeLoc = result.location;
if (!parsedFnType) {
if (parseToken(Token::colon, "expected ':' followed by operation type"))
return failure();
typeLoc = getEncodedSourceLocation(getToken().getLoc());
auto type = parseType();
if (!type)
return failure();
auto fnType = type.dyn_cast<FunctionType>();
if (!fnType)
return mlir::emitError(typeLoc, "expected function type");
parsedFnType = fnType;
}
result.addTypes(parsedFnType->getResults());
// Check that we have the right number of types for the operands.
ArrayRef<Type> operandTypes = parsedFnType->getInputs();
if (operandTypes.size() != parsedOperandUseInfo->size()) {
auto plural = "s"[parsedOperandUseInfo->size() == 1];
return mlir::emitError(typeLoc, "expected ")
<< parsedOperandUseInfo->size() << " operand type" << plural
<< " but had " << operandTypes.size();
}
// Resolve all of the operands.
for (unsigned i = 0, e = parsedOperandUseInfo->size(); i != e; ++i) {
result.operands.push_back(
resolveSSAUse((*parsedOperandUseInfo)[i], operandTypes[i]));
if (!result.operands.back())
return failure();
}
return success();
}
Operation *OperationParser::parseGenericOperation() {
// Get location information for the operation.
auto srcLocation = getEncodedSourceLocation(getToken().getLoc());
@ -985,6 +1102,7 @@ Operation *OperationParser::parseGenericOperation() {
consumeToken(Token::string);
OperationState result(srcLocation, name);
CleanupOpStateRegions guard{result};
// Lazy load dialects in the context as needed.
if (!result.name.isRegistered()) {
@ -1005,73 +1123,8 @@ Operation *OperationParser::parseGenericOperation() {
if (state.asmState)
state.asmState->startOperationDefinition(result.name);
// Parse the operand list.
SmallVector<SSAUseInfo, 8> operandInfos;
if (parseToken(Token::l_paren, "expected '(' to start operand list") ||
parseOptionalSSAUseList(operandInfos) ||
parseToken(Token::r_paren, "expected ')' to end operand list")) {
if (parseGenericOperationAfterOpName(result))
return nullptr;
}
// Parse the successor list.
if (getToken().is(Token::l_square)) {
// Check if the operation is not a known terminator.
if (!result.name.mightHaveTrait<OpTrait::IsTerminator>())
return emitError("successors in non-terminator"), nullptr;
SmallVector<Block *, 2> successors;
if (parseSuccessors(successors))
return nullptr;
result.addSuccessors(successors);
}
// Parse the region list.
CleanupOpStateRegions guard{result};
if (consumeIf(Token::l_paren)) {
do {
// Create temporary regions with the top level region as parent.
result.regions.emplace_back(new Region(topLevelOp));
if (parseRegion(*result.regions.back(), /*entryArguments=*/{}))
return nullptr;
} while (consumeIf(Token::comma));
if (parseToken(Token::r_paren, "expected ')' to end region list"))
return nullptr;
}
if (getToken().is(Token::l_brace)) {
if (parseAttributeDict(result.attributes))
return nullptr;
}
if (parseToken(Token::colon, "expected ':' followed by operation type"))
return nullptr;
auto typeLoc = getToken().getLoc();
auto type = parseType();
if (!type)
return nullptr;
auto fnType = type.dyn_cast<FunctionType>();
if (!fnType)
return (emitError(typeLoc, "expected function type"), nullptr);
result.addTypes(fnType.getResults());
// Check that we have the right number of types for the operands.
auto operandTypes = fnType.getInputs();
if (operandTypes.size() != operandInfos.size()) {
auto plural = "s"[operandInfos.size() == 1];
return (emitError(typeLoc, "expected ")
<< operandInfos.size() << " operand type" << plural
<< " but had " << operandTypes.size(),
nullptr);
}
// Resolve all of the operands.
for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) {
result.operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i]));
if (!result.operands.back())
return nullptr;
}
// Create the operation and try to parse a location for it.
Operation *op = opBuilder.createOperation(result);
@ -1133,6 +1186,37 @@ public:
return parser.parseGenericOperation(insertBlock, insertPt);
}
FailureOr<OperationName> parseCustomOperationName() final {
return parser.parseCustomOperationName();
}
ParseResult parseGenericOperationAfterOpName(
OperationState &result,
Optional<ArrayRef<OperandType>> parsedOperandTypes,
Optional<ArrayRef<Block *>> parsedSuccessors,
Optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions,
Optional<ArrayRef<NamedAttribute>> parsedAttributes,
Optional<FunctionType> parsedFnType) final {
// TODO: The types, OperandType and SSAUseInfo, both share the same members
// but in different order. It would be cleaner to make one alias of the
// other, making the following code redundant.
SmallVector<OperationParser::SSAUseInfo> parsedOperandUseInfo;
if (parsedOperandTypes) {
for (const OperandType &parsedOperandType : *parsedOperandTypes)
parsedOperandUseInfo.push_back({
parsedOperandType.name,
parsedOperandType.number,
parsedOperandType.location,
});
}
return parser.parseGenericOperationAfterOpName(
result,
parsedOperandTypes ? llvm::makeArrayRef(parsedOperandUseInfo)
: llvm::None,
parsedSuccessors, parsedRegions, parsedAttributes, parsedFnType);
}
//===--------------------------------------------------------------------===//
// Utilities
//===--------------------------------------------------------------------===//
@ -1506,10 +1590,13 @@ private:
};
} // end anonymous namespace.
Operation *
OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
llvm::SMLoc opLoc = getToken().getLoc();
FailureOr<OperationName> OperationParser::parseCustomOperationName() {
std::string opName = getTokenSpelling().str();
if (opName.empty())
return (emitError("empty operation name is invalid"), failure());
consumeToken();
Optional<RegisteredOperationName> opInfo =
RegisteredOperationName::lookup(opName, getContext());
StringRef defaultDialect = getState().defaultDialectStack.back();
@ -1543,13 +1630,28 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
}
}
return OperationName(opName, getContext());
}
Operation *
OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
llvm::SMLoc opLoc = getToken().getLoc();
FailureOr<OperationName> opNameInfo = parseCustomOperationName();
if (failed(opNameInfo))
return nullptr;
StringRef opName = opNameInfo->getStringRef();
Dialect *dialect = opNameInfo->getDialect();
Optional<RegisteredOperationName> opInfo = opNameInfo->getRegisteredInfo();
// This is the actual hook for the custom op parsing, usually implemented by
// the op itself (`Op::parse()`). We retrieve it either from the
// RegisteredOperationName or from the Dialect.
function_ref<ParseResult(OpAsmParser &, OperationState &)> parseAssemblyFn;
bool isIsolatedFromAbove = false;
defaultDialect = "";
StringRef defaultDialect = "";
if (opInfo) {
parseAssemblyFn = opInfo->getParseAssemblyFn();
isIsolatedFromAbove = opInfo->hasTrait<OpTrait::IsIsolatedFromAbove>();
@ -1570,16 +1672,14 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
auto restoreDefaultDialect = llvm::make_scope_exit(
[&]() { getState().defaultDialectStack.pop_back(); });
consumeToken();
// If the custom op parser crashes, produce some indication to help
// debugging.
llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'",
opName.c_str());
opNameInfo->getIdentifier().data());
// Get location information for the operation.
auto srcLocation = getEncodedSourceLocation(opLoc);
OperationState opState(srcLocation, opName);
OperationState opState(srcLocation, *opNameInfo);
// If we are populating the parser state, start a new operation definition.
if (state.asmState)

View File

@ -0,0 +1,35 @@
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file %s | FileCheck %s --check-prefixes=CHECK-CUSTOM,CHECK
// RUN: mlir-opt -allow-unregistered-dialect -mlir-print-op-generic -split-input-file %s | FileCheck %s --check-prefixes=CHECK,CHECK-GENERIC
// -----
func @pretty_printed_region_op(%arg0 : f32, %arg1 : f32) -> (f32) {
// CHECK-CUSTOM: test.pretty_printed_region %arg1, %arg0 start special.op end : (f32, f32) -> f32
// CHECK-GENERIC: "test.pretty_printed_region"(%arg1, %arg0)
// CHECK-GENERIC: ^bb0(%arg[[x:[0-9]+]]: f32, %arg[[y:[0-9]+]]: f32
// CHECK-GENERIC: %[[RES:.*]] = "special.op"(%arg[[x]], %arg[[y]]) : (f32, f32) -> f32
// CHECK-GENERIC: "test.return"(%[[RES]]) : (f32) -> ()
// CHECK-GENERIC: : (f32, f32) -> f32
%res = test.pretty_printed_region %arg1, %arg0 start special.op end : (f32, f32) -> (f32) loc("some_NameLoc")
return %res : f32
}
// -----
func @pretty_printed_region_op(%arg0 : f32, %arg1 : f32) -> (f32) {
// CHECK-CUSTOM: test.pretty_printed_region %arg1, %arg0
// CHECK-GENERIC: "test.pretty_printed_region"(%arg1, %arg0)
// CHECK: ^bb0(%arg[[x:[0-9]+]]: f32, %arg[[y:[0-9]+]]: f32):
// CHECK: %[[RES:.*]] = "non.special.op"(%arg[[x]], %arg[[y]]) : (f32, f32) -> f32
// CHECK: "test.return"(%[[RES]]) : (f32) -> ()
// CHECK: : (f32, f32) -> f32
%0 = test.pretty_printed_region %arg1, %arg0 ( {
^bb0(%arg2: f32, %arg3: f32):
%1 = "non.special.op"(%arg2, %arg3) : (f32, f32) -> f32
"test.return"(%1) : (f32) -> ()
}) : (f32, f32) -> f32
return %0 : f32
}

View File

@ -720,6 +720,107 @@ static void print(OpAsmPrinter &p, WrappingRegionOp op) {
p.printGenericOp(&op.getRegion().front().front());
}
//===----------------------------------------------------------------------===//
// Test PrettyPrintedRegionOp - exercising the following parser APIs
// parseGenericOperationAfterOpName
// parseCustomOperationName
//===----------------------------------------------------------------------===//
static ParseResult parsePrettyPrintedRegionOp(OpAsmParser &parser,
OperationState &result) {
llvm::SMLoc loc = parser.getCurrentLocation();
Location currLocation = parser.getEncodedSourceLoc(loc);
// Parse the operands.
SmallVector<OpAsmParser::OperandType, 2> operands;
if (parser.parseOperandList(operands))
return failure();
// Check if we are parsing the pretty-printed version
// test.pretty_printed_region start <inner-op> end : <functional-type>
// Else fallback to parsing the "non pretty-printed" version.
if (!succeeded(parser.parseOptionalKeyword("start")))
return parser.parseGenericOperationAfterOpName(
result, llvm::makeArrayRef(operands));
FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
if (failed(parseOpNameInfo))
return failure();
StringRef innerOpName = parseOpNameInfo->getStringRef();
FunctionType opFntype;
Optional<Location> explicitLoc;
if (parser.parseKeyword("end") || parser.parseColon() ||
parser.parseType(opFntype) ||
parser.parseOptionalLocationSpecifier(explicitLoc))
return failure();
// If location of the op is explicitly provided, then use it; Else use
// the parser's current location.
Location opLoc = explicitLoc.getValueOr(currLocation);
// Derive the SSA-values for op's operands.
if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
result.operands))
return failure();
// Add a region for op.
Region &region = *result.addRegion();
// Create a basic-block inside op's region.
Block &block = region.emplaceBlock();
// Create and insert an "inner-op" operation in the block.
// Just for testing purposes, we can assume that inner op is a binary op with
// result and operand types all same as the test-op's first operand.
Type innerOpType = opFntype.getInput(0);
Value lhs = block.addArgument(innerOpType, opLoc);
Value rhs = block.addArgument(innerOpType, opLoc);
OpBuilder builder(parser.getBuilder().getContext());
builder.setInsertionPointToStart(&block);
OperationState innerOpState(opLoc, innerOpName);
innerOpState.operands.push_back(lhs);
innerOpState.operands.push_back(rhs);
innerOpState.addTypes(innerOpType);
Operation *innerOp = builder.createOperation(innerOpState);
// Insert a return statement in the block returning the inner-op's result.
builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
// Populate the op operation-state with result-type and location.
result.addTypes(opFntype.getResults());
result.location = innerOp->getLoc();
return success();
}
static void print(OpAsmPrinter &p, PrettyPrintedRegionOp op) {
p << ' ';
p.printOperands(op.getOperands());
Operation &innerOp = op.getRegion().front().front();
// Assuming that region has a single non-terminator inner-op, if the inner-op
// meets some criteria (which in this case is a simple one based on the name
// of inner-op), then we can print the entire region in a succinct way.
// Here we assume that the prototype of "special.op" can be trivially derived
// while parsing it back.
if (innerOp.getName().getStringRef().equals("special.op")) {
p << " start special.op end";
} else {
p << " (";
p.printRegion(op.getRegion());
p << ")";
}
p << " : ";
p.printFunctionalType(op);
}
//===----------------------------------------------------------------------===//
// Test PolyForOp - parse list of region arguments.
//===----------------------------------------------------------------------===//

View File

@ -1630,6 +1630,25 @@ def WrappingRegionOp : TEST_Op<"wrapping_region",
let printer = [{ return ::print(p, *this); }];
}
def PrettyPrintedRegionOp : TEST_Op<"pretty_printed_region",
[SingleBlockImplicitTerminator<"TestReturnOp">]> {
let summary = "pretty_printed_region operation";
let description = [{
Test-op can be printed either in a "pretty" or "non-pretty" way based on
some criteria. The custom parser parsers both the versions while testing
APIs: parseCustomOperationName & parseGenericOperationAfterOpName.
}];
let arguments = (ins
AnyType:$input1,
AnyType:$input2
);
let results = (outs AnyType);
let regions = (region SizedRegion<1>:$region);
let parser = [{ return ::parse$cppClass(parser, result); }];
let printer = [{ return ::print(p, *this); }];
}
def PolyForOp : TEST_Op<"polyfor">
{
let summary = "polyfor operation";