mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-12-14 19:49:36 +00:00
[mlir:ODS] Support using attributes in AllTypesMatch to automatically add InferTypeOpInterface
This allows for using attribute types in result type inference for use with InferTypeOpInterface. This was a TODO before, but it isn't much additional work to properly support this. After this commit, arith::ConstantOp can now have its InferTypeOpInterface implementation automatically generated. Differential Revision: https://reviews.llvm.org/D124580
This commit is contained in:
parent
53f775bbc0
commit
1bd1edaf40
@ -124,9 +124,7 @@ class Arith_CompareOpOfAnyRank<string mnemonic, list<Trait> traits = []> :
|
||||
def Arith_ConstantOp : Op<Arithmetic_Dialect, "constant",
|
||||
[ConstantLike, NoSideEffect,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||
TypesMatchWith<
|
||||
"result and attribute have the same type",
|
||||
"value", "result", "$_self">]> {
|
||||
AllTypesMatch<["value", "result"]>]> {
|
||||
let summary = "integer or floating point constant";
|
||||
let description = [{
|
||||
The `constant` operation produces an SSA value equal to some integer or
|
||||
@ -154,8 +152,6 @@ def Arith_ConstantOp : Op<Arithmetic_Dialect, "constant",
|
||||
let results = (outs /*SignlessIntegerOrFloatLike*/AnyType:$result);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Attribute":$value),
|
||||
[{ build($_builder, $_state, value.getType(), value); }]>,
|
||||
OpBuilder<(ins "Attribute":$value, "Type":$type),
|
||||
[{ build($_builder, $_state, type, value); }]>,
|
||||
];
|
||||
|
@ -187,19 +187,9 @@ private:
|
||||
/// ensure that the static functions have a unique name.
|
||||
std::string uniqueOutputLabel;
|
||||
|
||||
/// Unique constraints by their predicate and summary. Constraints that share
|
||||
/// the same predicate may have different descriptions; ensure that the
|
||||
/// correct error message is reported when verification fails.
|
||||
struct ConstraintUniquer {
|
||||
static Constraint getEmptyKey();
|
||||
static Constraint getTombstoneKey();
|
||||
static unsigned getHashValue(Constraint constraint);
|
||||
static bool isEqual(Constraint lhs, Constraint rhs);
|
||||
};
|
||||
/// Use a MapVector to ensure that functions are generated deterministically.
|
||||
using ConstraintMap =
|
||||
llvm::MapVector<Constraint, std::string,
|
||||
llvm::DenseMap<Constraint, unsigned, ConstraintUniquer>>;
|
||||
using ConstraintMap = llvm::MapVector<Constraint, std::string,
|
||||
llvm::DenseMap<Constraint, unsigned>>;
|
||||
|
||||
/// A generic function to emit constraints
|
||||
void emitConstraints(const ConstraintMap &constraints, StringRef selfName,
|
||||
|
@ -94,4 +94,20 @@ struct AppliedConstraint {
|
||||
} // namespace tblgen
|
||||
} // namespace mlir
|
||||
|
||||
namespace llvm {
|
||||
/// Unique constraints by their predicate and summary. Constraints that share
|
||||
/// the same predicate may have different descriptions; ensure that the
|
||||
/// correct error message is reported when verification fails.
|
||||
template <>
|
||||
struct DenseMapInfo<mlir::tblgen::Constraint> {
|
||||
using RecordDenseMapInfo = llvm::DenseMapInfo<const llvm::Record *>;
|
||||
|
||||
static mlir::tblgen::Constraint getEmptyKey();
|
||||
static mlir::tblgen::Constraint getTombstoneKey();
|
||||
static unsigned getHashValue(mlir::tblgen::Constraint constraint);
|
||||
static bool isEqual(mlir::tblgen::Constraint lhs,
|
||||
mlir::tblgen::Constraint rhs);
|
||||
};
|
||||
} // namespace llvm
|
||||
|
||||
#endif // MLIR_TABLEGEN_CONSTRAINT_H_
|
||||
|
@ -108,3 +108,34 @@ AppliedConstraint::AppliedConstraint(Constraint &&constraint,
|
||||
std::vector<std::string> &&entities)
|
||||
: constraint(constraint), self(std::string(self)),
|
||||
entities(std::move(entities)) {}
|
||||
|
||||
Constraint DenseMapInfo<Constraint>::getEmptyKey() {
|
||||
return Constraint(RecordDenseMapInfo::getEmptyKey(),
|
||||
Constraint::CK_Uncategorized);
|
||||
}
|
||||
|
||||
Constraint DenseMapInfo<Constraint>::getTombstoneKey() {
|
||||
return Constraint(RecordDenseMapInfo::getTombstoneKey(),
|
||||
Constraint::CK_Uncategorized);
|
||||
}
|
||||
|
||||
unsigned DenseMapInfo<Constraint>::getHashValue(Constraint constraint) {
|
||||
if (constraint == getEmptyKey())
|
||||
return RecordDenseMapInfo::getHashValue(RecordDenseMapInfo::getEmptyKey());
|
||||
if (constraint == getTombstoneKey()) {
|
||||
return RecordDenseMapInfo::getHashValue(
|
||||
RecordDenseMapInfo::getTombstoneKey());
|
||||
}
|
||||
return llvm::hash_combine(constraint.getPredicate(), constraint.getSummary());
|
||||
}
|
||||
|
||||
bool DenseMapInfo<Constraint>::isEqual(Constraint lhs, Constraint rhs) {
|
||||
if (lhs == rhs)
|
||||
return true;
|
||||
if (lhs == getEmptyKey() || lhs == getTombstoneKey())
|
||||
return false;
|
||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||
return false;
|
||||
return lhs.getPredicate() == rhs.getPredicate() &&
|
||||
lhs.getSummary() == rhs.getSummary();
|
||||
}
|
||||
|
@ -357,10 +357,6 @@ void Operator::populateTypeInferenceInfo(
|
||||
continue;
|
||||
}
|
||||
|
||||
if (getArg(*mi).is<NamedAttribute *>()) {
|
||||
// TODO: Handle attributes.
|
||||
continue;
|
||||
}
|
||||
resultTypeMapping[i].emplace_back(*mi);
|
||||
found = true;
|
||||
}
|
||||
|
@ -41,11 +41,11 @@ class ConstantOp:
|
||||
loc=None,
|
||||
ip=None):
|
||||
if isinstance(value, int):
|
||||
super().__init__(result, IntegerAttr.get(result, value), loc=loc, ip=ip)
|
||||
super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
|
||||
elif isinstance(value, float):
|
||||
super().__init__(result, FloatAttr.get(result, value), loc=loc, ip=ip)
|
||||
super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
|
||||
else:
|
||||
super().__init__(result, value, loc=loc, ip=ip)
|
||||
super().__init__(value, loc=loc, ip=ip)
|
||||
|
||||
@classmethod
|
||||
def create_index(cls, value: int, *, loc=None, ip=None):
|
||||
|
@ -25,7 +25,7 @@ func.func @non_signless_constant() {
|
||||
// -----
|
||||
|
||||
func.func @complex_constant_wrong_attribute_type() {
|
||||
// expected-error @+1 {{'arith.constant' op failed to verify that result and attribute have the same type}}
|
||||
// expected-error @+1 {{'arith.constant' op failed to verify that all of {value, result} have same type}}
|
||||
%0 = "arith.constant" () {value = 1.0 : f32} : () -> complex<f32>
|
||||
return
|
||||
}
|
||||
@ -50,7 +50,7 @@ func.func @bitcast_different_bit_widths(%arg : f16) -> f32 {
|
||||
|
||||
func.func @constant() {
|
||||
^bb:
|
||||
%x = "arith.constant"(){value = "xyz"} : () -> i32 // expected-error {{'arith.constant' op failed to verify that result and attribute have the same type}}
|
||||
%x = "arith.constant"(){value = "xyz"} : () -> i32 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}}
|
||||
return
|
||||
}
|
||||
|
||||
@ -58,7 +58,7 @@ func.func @constant() {
|
||||
|
||||
func.func @constant_out_of_range() {
|
||||
^bb:
|
||||
%x = "arith.constant"(){value = 100} : () -> i1 // expected-error {{'arith.constant' op failed to verify that result and attribute have the same type}}
|
||||
%x = "arith.constant"(){value = 100} : () -> i1 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}}
|
||||
return
|
||||
}
|
||||
|
||||
@ -66,7 +66,7 @@ func.func @constant_out_of_range() {
|
||||
|
||||
func.func @constant_wrong_type() {
|
||||
^bb:
|
||||
%x = "arith.constant"(){value = 10.} : () -> f32 // expected-error {{'arith.constant' op failed to verify that result and attribute have the same type}}
|
||||
%x = "arith.constant"(){value = 10.} : () -> f32 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}}
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -5,7 +5,7 @@
|
||||
|
||||
// Emit the first available call stack in the fused location.
|
||||
func.func @constant_out_of_range() {
|
||||
// CHECK: mysource1:0:0: error: 'arith.constant' op failed to verify that result and attribute have the same type
|
||||
// CHECK: mysource1:0:0: error: 'arith.constant' op failed to verify that all of {value, result} have same type
|
||||
// CHECK-NEXT: mysource2:1:0: note: called from
|
||||
// CHECK-NEXT: mysource3:2:0: note: called from
|
||||
%x = "arith.constant"() {value = 100} : () -> i1 loc(fused["bar", callsite("foo"("mysource1":0:0) at callsite("mysource2":1:0 at "mysource3":2:0))])
|
||||
|
@ -123,7 +123,8 @@ def OpL1 : NS_Op<"op_with_all_types_constraint",
|
||||
|
||||
// CHECK-LABEL: LogicalResult OpL1::inferReturnTypes
|
||||
// CHECK-NOT: }
|
||||
// CHECK: inferredReturnTypes[0] = operands[0].getType();
|
||||
// CHECK: ::mlir::Type odsInferredType0 = operands[0].getType();
|
||||
// CHECK: inferredReturnTypes[0] = odsInferredType0;
|
||||
|
||||
def OpL2 : NS_Op<"op_with_all_types_constraint",
|
||||
[AllTypesMatch<["c", "b"]>, AllTypesMatch<["a", "d"]>]> {
|
||||
@ -133,5 +134,18 @@ def OpL2 : NS_Op<"op_with_all_types_constraint",
|
||||
|
||||
// CHECK-LABEL: LogicalResult OpL2::inferReturnTypes
|
||||
// CHECK-NOT: }
|
||||
// CHECK: inferredReturnTypes[0] = operands[2].getType();
|
||||
// CHECK: inferredReturnTypes[1] = operands[0].getType();
|
||||
// CHECK: ::mlir::Type odsInferredType0 = operands[2].getType();
|
||||
// CHECK: ::mlir::Type odsInferredType1 = operands[0].getType();
|
||||
// CHECK: inferredReturnTypes[0] = odsInferredType0;
|
||||
// CHECK: inferredReturnTypes[1] = odsInferredType1;
|
||||
|
||||
def OpL3 : NS_Op<"op_with_all_types_constraint",
|
||||
[AllTypesMatch<["a", "b"]>]> {
|
||||
let arguments = (ins I32Attr:$a);
|
||||
let results = (outs AnyType:$b);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: LogicalResult OpL3::inferReturnTypes
|
||||
// CHECK-NOT: }
|
||||
// CHECK: ::mlir::Type odsInferredType0 = attributes.get("a").getType();
|
||||
// CHECK: inferredReturnTypes[0] = odsInferredType0;
|
||||
|
@ -234,41 +234,6 @@ void StaticVerifierFunctionEmitter::emitPatternConstraints() {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Constraint Uniquing
|
||||
|
||||
using RecordDenseMapInfo = llvm::DenseMapInfo<const llvm::Record *>;
|
||||
|
||||
Constraint StaticVerifierFunctionEmitter::ConstraintUniquer::getEmptyKey() {
|
||||
return Constraint(RecordDenseMapInfo::getEmptyKey(),
|
||||
Constraint::CK_Uncategorized);
|
||||
}
|
||||
|
||||
Constraint StaticVerifierFunctionEmitter::ConstraintUniquer::getTombstoneKey() {
|
||||
return Constraint(RecordDenseMapInfo::getTombstoneKey(),
|
||||
Constraint::CK_Uncategorized);
|
||||
}
|
||||
|
||||
unsigned StaticVerifierFunctionEmitter::ConstraintUniquer::getHashValue(
|
||||
Constraint constraint) {
|
||||
if (constraint == getEmptyKey())
|
||||
return RecordDenseMapInfo::getHashValue(RecordDenseMapInfo::getEmptyKey());
|
||||
if (constraint == getTombstoneKey()) {
|
||||
return RecordDenseMapInfo::getHashValue(
|
||||
RecordDenseMapInfo::getTombstoneKey());
|
||||
}
|
||||
return llvm::hash_combine(constraint.getPredicate(), constraint.getSummary());
|
||||
}
|
||||
|
||||
bool StaticVerifierFunctionEmitter::ConstraintUniquer::isEqual(Constraint lhs,
|
||||
Constraint rhs) {
|
||||
if (lhs == rhs)
|
||||
return true;
|
||||
if (lhs == getEmptyKey() || lhs == getTombstoneKey())
|
||||
return false;
|
||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||
return false;
|
||||
return lhs.getPredicate() == rhs.getPredicate() &&
|
||||
lhs.getSummary() == rhs.getSummary();
|
||||
}
|
||||
|
||||
/// An attribute constraint that references anything other than itself and the
|
||||
/// current op cannot be generically extracted into a function. Most
|
||||
/// prohibitive are operands and results, which require calls to
|
||||
|
@ -2336,23 +2336,60 @@ void OpEmitter::genTypeInterfaceMethods() {
|
||||
fctx.withBuilder("odsBuilder");
|
||||
body << " ::mlir::Builder odsBuilder(context);\n";
|
||||
|
||||
auto emitType = [&](const tblgen::Operator::ArgOrType &type) -> MethodBody & {
|
||||
if (!type.isArg())
|
||||
return body << tgfmt(*type.getType().getBuilderCall(), &fctx);
|
||||
auto argIndex = type.getArg();
|
||||
assert(!op.getArg(argIndex).is<NamedAttribute *>());
|
||||
auto arg = op.getArgToOperandOrAttribute(argIndex);
|
||||
if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand)
|
||||
return body << "operands[" << arg.operandOrAttributeIndex()
|
||||
<< "].getType()";
|
||||
return body << "attributes[" << arg.operandOrAttributeIndex()
|
||||
<< "].getType()";
|
||||
};
|
||||
|
||||
// Preprocess the result types and build all of the types used during
|
||||
// inferrence. This limits the amount of duplicated work when a type is used
|
||||
// to infer multiple others.
|
||||
llvm::DenseMap<Constraint, int> constraintsTypes;
|
||||
llvm::DenseMap<int, int> argumentsTypes;
|
||||
int inferredTypeIdx = 0;
|
||||
for (int i = 0, e = op.getNumResults(); i != e; ++i) {
|
||||
auto type = op.getSameTypeAsResult(i).front();
|
||||
|
||||
// If the type isn't an argument, it refers to a buildable type.
|
||||
if (!type.isArg()) {
|
||||
auto it = constraintsTypes.try_emplace(type.getType(), inferredTypeIdx);
|
||||
if (!it.second)
|
||||
continue;
|
||||
|
||||
// If we haven't seen this constraint, generate a variable for it.
|
||||
body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = "
|
||||
<< tgfmt(*type.getType().getBuilderCall(), &fctx) << ";\n";
|
||||
continue;
|
||||
}
|
||||
|
||||
// Otherwise, this is an argument.
|
||||
int argIndex = type.getArg();
|
||||
auto it = argumentsTypes.try_emplace(argIndex, inferredTypeIdx);
|
||||
if (!it.second)
|
||||
continue;
|
||||
body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = ";
|
||||
|
||||
// If this is an operand, just index into operand list to access the type.
|
||||
auto arg = op.getArgToOperandOrAttribute(argIndex);
|
||||
if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
|
||||
body << "operands[" << arg.operandOrAttributeIndex() << "].getType()";
|
||||
|
||||
// If this is an attribute, index into the attribute dictionary.
|
||||
} else {
|
||||
auto *attr =
|
||||
op.getArg(arg.operandOrAttributeIndex()).get<NamedAttribute *>();
|
||||
body << "attributes.get(\"" << attr->name << "\").getType()";
|
||||
}
|
||||
body << ";\n";
|
||||
}
|
||||
|
||||
// Perform a second pass that handles assigning the inferred types to the
|
||||
// results.
|
||||
for (int i = 0, e = op.getNumResults(); i != e; ++i) {
|
||||
body << " inferredReturnTypes[" << i << "] = ";
|
||||
auto types = op.getSameTypeAsResult(i);
|
||||
emitType(types[0]) << ";\n";
|
||||
|
||||
// Append the inferred type.
|
||||
auto type = types.front();
|
||||
body << " inferredReturnTypes[" << i << "] = odsInferredType"
|
||||
<< (type.isArg() ? argumentsTypes[type.getArg()]
|
||||
: constraintsTypes[type.getType()])
|
||||
<< ";\n";
|
||||
|
||||
if (types.size() == 1)
|
||||
continue;
|
||||
// TODO: We could verify equality here, but skipping that for verification.
|
||||
|
Loading…
Reference in New Issue
Block a user