[MLIR] LLVM dialect: modernize and cleanups

Summary:
Modernize some of the existing custom parsing code in the LLVM dialect.
While this reduces some boilerplate code, it also reduces the precision
of the diagnostic error messges.

Reviewers: ftynse, nicolasvasilache, rriddle

Reviewed By: rriddle

Subscribers: merge_guards_bot, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72967
This commit is contained in:
Frank Laub 2020-01-17 17:11:04 -08:00
parent df7900e218
commit ee2de95507
2 changed files with 39 additions and 64 deletions

View File

@ -55,47 +55,42 @@ template <typename CmpPredicateType>
static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
Builder &builder = parser.getBuilder();
Attribute predicate;
SmallVector<NamedAttribute, 4> attrs;
StringAttr predicateAttr;
OpAsmParser::OperandType lhs, rhs;
Type type;
llvm::SMLoc predicateLoc, trailingTypeLoc;
if (parser.getCurrentLocation(&predicateLoc) ||
parser.parseAttribute(predicate, "predicate", attrs) ||
parser.parseAttribute(predicateAttr, "predicate", result.attributes) ||
parser.parseOperand(lhs) || parser.parseComma() ||
parser.parseOperand(rhs) || parser.parseOptionalAttrDict(attrs) ||
parser.parseColon() || parser.getCurrentLocation(&trailingTypeLoc) ||
parser.parseType(type) ||
parser.parseOperand(rhs) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
parser.resolveOperand(lhs, type, result.operands) ||
parser.resolveOperand(rhs, type, result.operands))
return failure();
// Replace the string attribute `predicate` with an integer attribute.
auto predicateStr = predicate.dyn_cast<StringAttr>();
if (!predicateStr)
return parser.emitError(predicateLoc,
"expected 'predicate' attribute of string type");
int64_t predicateValue = 0;
if (std::is_same<CmpPredicateType, ICmpPredicate>()) {
Optional<ICmpPredicate> predicate =
symbolizeICmpPredicate(predicateStr.getValue());
symbolizeICmpPredicate(predicateAttr.getValue());
if (!predicate)
return parser.emitError(predicateLoc)
<< "'" << predicateStr.getValue()
<< "'" << predicateAttr.getValue()
<< "' is an incorrect value of the 'predicate' attribute";
predicateValue = static_cast<int64_t>(predicate.getValue());
} else {
Optional<FCmpPredicate> predicate =
symbolizeFCmpPredicate(predicateStr.getValue());
symbolizeFCmpPredicate(predicateAttr.getValue());
if (!predicate)
return parser.emitError(predicateLoc)
<< "'" << predicateStr.getValue()
<< "'" << predicateAttr.getValue()
<< "' is an incorrect value of the 'predicate' attribute";
predicateValue = static_cast<int64_t>(predicate.getValue());
}
attrs[0].second = parser.getBuilder().getI64IntegerAttr(predicateValue);
result.attributes[0].second =
parser.getBuilder().getI64IntegerAttr(predicateValue);
// The result type is either i1 or a vector type <? x i1> if the inputs are
// vectors.
@ -108,7 +103,6 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
resultType = LLVMType::getVectorTy(
resultType, argType.getUnderlyingType()->getVectorNumElements());
result.attributes = attrs;
result.addTypes({resultType});
return success();
}
@ -134,14 +128,13 @@ static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) {
// <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict?
// `:` type `,` type
static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) {
SmallVector<NamedAttribute, 4> attrs;
OpAsmParser::OperandType arraySize;
Type type, elemType;
llvm::SMLoc trailingTypeLoc;
if (parser.parseOperand(arraySize) || parser.parseKeyword("x") ||
parser.parseType(elemType) || parser.parseOptionalAttrDict(attrs) ||
parser.parseColon() || parser.getCurrentLocation(&trailingTypeLoc) ||
parser.parseType(type))
parser.parseType(elemType) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
return failure();
// Extract the result type from the trailing function type.
@ -155,7 +148,6 @@ static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) {
if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands))
return failure();
result.attributes = attrs;
result.addTypes({funcType.getResult(0)});
return success();
}
@ -177,14 +169,13 @@ static void printGEPOp(OpAsmPrinter &p, GEPOp &op) {
// <operation> ::= `llvm.getelementptr` ssa-use `[` ssa-use-list `]`
// attribute-dict? `:` type
static ParseResult parseGEPOp(OpAsmParser &parser, OperationState &result) {
SmallVector<NamedAttribute, 4> attrs;
OpAsmParser::OperandType base;
SmallVector<OpAsmParser::OperandType, 8> indices;
Type type;
llvm::SMLoc trailingTypeLoc;
if (parser.parseOperand(base) ||
parser.parseOperandList(indices, OpAsmParser::Delimiter::Square) ||
parser.parseOptionalAttrDict(attrs) || parser.parseColon() ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
return failure();
@ -202,7 +193,6 @@ static ParseResult parseGEPOp(OpAsmParser &parser, OperationState &result) {
parser.getNameLoc(), result.operands))
return failure();
result.attributes = attrs;
result.addTypes(funcType.getResults());
return success();
}
@ -233,20 +223,18 @@ static Type getLoadStoreElementType(OpAsmParser &parser, Type type,
// <operation> ::= `llvm.load` ssa-use attribute-dict? `:` type
static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
SmallVector<NamedAttribute, 4> attrs;
OpAsmParser::OperandType addr;
Type type;
llvm::SMLoc trailingTypeLoc;
if (parser.parseOperand(addr) || parser.parseOptionalAttrDict(attrs) ||
parser.parseColon() || parser.getCurrentLocation(&trailingTypeLoc) ||
parser.parseType(type) ||
if (parser.parseOperand(addr) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
parser.resolveOperand(addr, type, result.operands))
return failure();
Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
result.attributes = attrs;
result.addTypes(elemTy);
return success();
}
@ -263,15 +251,14 @@ static void printStoreOp(OpAsmPrinter &p, StoreOp &op) {
// <operation> ::= `llvm.store` ssa-use `,` ssa-use attribute-dict? `:` type
static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
SmallVector<NamedAttribute, 4> attrs;
OpAsmParser::OperandType addr, value;
Type type;
llvm::SMLoc trailingTypeLoc;
if (parser.parseOperand(value) || parser.parseComma() ||
parser.parseOperand(addr) || parser.parseOptionalAttrDict(attrs) ||
parser.parseColon() || parser.getCurrentLocation(&trailingTypeLoc) ||
parser.parseType(type))
parser.parseOperand(addr) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
return failure();
Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
@ -282,7 +269,6 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
parser.resolveOperand(addr, type, result.operands))
return failure();
result.attributes = attrs;
return success();
}
@ -316,7 +302,6 @@ static void printCallOp(OpAsmPrinter &p, CallOp &op) {
// <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
// attribute-dict? `:` function-type
static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
SmallVector<NamedAttribute, 4> attrs;
SmallVector<OpAsmParser::OperandType, 8> operands;
Type type;
SymbolRefAttr funcAttr;
@ -332,11 +317,11 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
// Optionally parse a function identifier.
if (isDirect)
if (parser.parseAttribute(funcAttr, "callee", attrs))
if (parser.parseAttribute(funcAttr, "callee", result.attributes))
return failure();
if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
parser.parseOptionalAttrDict(attrs) || parser.parseColon() ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
return failure();
@ -396,7 +381,6 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
result.addTypes(llvmResultType);
}
result.attributes = attrs;
return success();
}
@ -461,23 +445,18 @@ static void printExtractValueOp(OpAsmPrinter &p, ExtractValueOp &op) {
// resulting type wrapped in MLIR, or nullptr on error.
static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser &parser,
Type containerType,
Attribute positionAttr,
ArrayAttr positionAttr,
llvm::SMLoc attributeLoc,
llvm::SMLoc typeLoc) {
auto wrappedContainerType = containerType.dyn_cast<LLVM::LLVMType>();
if (!wrappedContainerType)
return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr;
auto positionArrayAttr = positionAttr.dyn_cast<ArrayAttr>();
if (!positionArrayAttr)
return parser.emitError(attributeLoc, "expected an array attribute"),
nullptr;
// Infer the element type from the structure type: iteratively step inside the
// type by taking the element type, indexed by the position attribute for
// structures. Check the position index before accessing, it is supposed to
// be in bounds.
for (Attribute subAttr : positionArrayAttr) {
for (Attribute subAttr : positionAttr) {
auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
if (!positionElementAttr)
return parser.emitError(attributeLoc,
@ -512,16 +491,15 @@ static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser &parser,
// attribute-dict? `:` type
static ParseResult parseExtractValueOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<NamedAttribute, 4> attrs;
OpAsmParser::OperandType container;
Type containerType;
Attribute positionAttr;
ArrayAttr positionAttr;
llvm::SMLoc attributeLoc, trailingTypeLoc;
if (parser.parseOperand(container) ||
parser.getCurrentLocation(&attributeLoc) ||
parser.parseAttribute(positionAttr, "position", attrs) ||
parser.parseOptionalAttrDict(attrs) || parser.parseColon() ||
parser.parseAttribute(positionAttr, "position", result.attributes) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.getCurrentLocation(&trailingTypeLoc) ||
parser.parseType(containerType) ||
parser.resolveOperand(container, containerType, result.operands))
@ -532,7 +510,6 @@ static ParseResult parseExtractValueOp(OpAsmParser &parser,
if (!elementType)
return failure();
result.attributes = attrs;
result.addTypes(elementType);
return success();
}
@ -599,7 +576,7 @@ static ParseResult parseInsertValueOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType container, value;
Type containerType;
Attribute positionAttr;
ArrayAttr positionAttr;
llvm::SMLoc attributeLoc, trailingTypeLoc;
if (parser.parseOperand(value) || parser.parseComma() ||
@ -1080,15 +1057,15 @@ static void printShuffleVectorOp(OpAsmPrinter &p, ShuffleVectorOp &op) {
static ParseResult parseShuffleVectorOp(OpAsmParser &parser,
OperationState &result) {
llvm::SMLoc loc;
SmallVector<NamedAttribute, 4> attrs;
OpAsmParser::OperandType v1, v2;
Attribute maskAttr;
ArrayAttr maskAttr;
Type typeV1, typeV2;
if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) ||
parser.parseComma() || parser.parseOperand(v2) ||
parser.parseAttribute(maskAttr, "mask", attrs) ||
parser.parseOptionalAttrDict(attrs) || parser.parseColonType(typeV1) ||
parser.parseComma() || parser.parseType(typeV2) ||
parser.parseAttribute(maskAttr, "mask", result.attributes) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(typeV1) || parser.parseComma() ||
parser.parseType(typeV2) ||
parser.resolveOperand(v1, typeV1, result.operands) ||
parser.resolveOperand(v2, typeV2, result.operands))
return failure();
@ -1097,10 +1074,8 @@ static ParseResult parseShuffleVectorOp(OpAsmParser &parser,
!wrappedContainerType1.getUnderlyingType()->isVectorTy())
return parser.emitError(
loc, "expected LLVM IR dialect vector type for operand #1");
auto vType =
LLVMType::getVectorTy(wrappedContainerType1.getVectorElementType(),
maskAttr.cast<ArrayAttr>().size());
result.attributes = attrs;
auto vType = LLVMType::getVectorTy(
wrappedContainerType1.getVectorElementType(), maskAttr.size());
result.addTypes(vType);
return success();
}

View File

@ -12,7 +12,7 @@ func @invalid_noalias(%arg0: !llvm.i32 {llvm.noalias = 3}) {
// -----
func @icmp_non_string(%arg0 : !llvm.i32, %arg1 : !llvm<"i16">) {
// expected-error@+1 {{expected 'predicate' attribute of string type}}
// expected-error@+1 {{invalid kind of attribute specified}}
llvm.icmp 42 %arg0, %arg0 : !llvm.i32
return
}
@ -156,7 +156,7 @@ func @insertvalue_non_llvm_type(%a : i32, %b : i32) {
func @insertvalue_non_array_position() {
// Note the double-type, otherwise attribute parsing consumes the trailing
// type of the op as the (wrong) attribute type.
// expected-error@+1 {{expected an array attribute}}
// expected-error@+1 {{invalid kind of attribute specified}}
llvm.insertvalue %a, %b 0 : i32 : !llvm<"{i32}">
}
@ -200,7 +200,7 @@ func @extractvalue_non_llvm_type(%a : i32, %b : i32) {
func @extractvalue_non_array_position() {
// Note the double-type, otherwise attribute parsing consumes the trailing
// type of the op as the (wrong) attribute type.
// expected-error@+1 {{expected an array attribute}}
// expected-error@+1 {{invalid kind of attribute specified}}
llvm.extractvalue %b 0 : i32 : !llvm<"{i32}">
}