[mlir:ODS] Support using attributes in AllTypesMatch to automatically add InferTypeOpInterface

This allows for using attribute types in result type inference for use with
InferTypeOpInterface. This was a TODO before, but it isn't much
additional work to properly support this. After this commit,
arith::ConstantOp can now have its InferTypeOpInterface implementation automatically
generated.

Differential Revision: https://reviews.llvm.org/D124580
This commit is contained in:
River Riddle 2022-04-26 11:00:35 -07:00
parent 53f775bbc0
commit 1bd1edaf40
11 changed files with 127 additions and 82 deletions

View File

@ -124,9 +124,7 @@ class Arith_CompareOpOfAnyRank<string mnemonic, list<Trait> traits = []> :
def Arith_ConstantOp : Op<Arithmetic_Dialect, "constant",
[ConstantLike, NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
TypesMatchWith<
"result and attribute have the same type",
"value", "result", "$_self">]> {
AllTypesMatch<["value", "result"]>]> {
let summary = "integer or floating point constant";
let description = [{
The `constant` operation produces an SSA value equal to some integer or
@ -154,8 +152,6 @@ def Arith_ConstantOp : Op<Arithmetic_Dialect, "constant",
let results = (outs /*SignlessIntegerOrFloatLike*/AnyType:$result);
let builders = [
OpBuilder<(ins "Attribute":$value),
[{ build($_builder, $_state, value.getType(), value); }]>,
OpBuilder<(ins "Attribute":$value, "Type":$type),
[{ build($_builder, $_state, type, value); }]>,
];

View File

@ -187,19 +187,9 @@ private:
/// ensure that the static functions have a unique name.
std::string uniqueOutputLabel;
/// Unique constraints by their predicate and summary. Constraints that share
/// the same predicate may have different descriptions; ensure that the
/// correct error message is reported when verification fails.
struct ConstraintUniquer {
static Constraint getEmptyKey();
static Constraint getTombstoneKey();
static unsigned getHashValue(Constraint constraint);
static bool isEqual(Constraint lhs, Constraint rhs);
};
/// Use a MapVector to ensure that functions are generated deterministically.
using ConstraintMap =
llvm::MapVector<Constraint, std::string,
llvm::DenseMap<Constraint, unsigned, ConstraintUniquer>>;
using ConstraintMap = llvm::MapVector<Constraint, std::string,
llvm::DenseMap<Constraint, unsigned>>;
/// A generic function to emit constraints
void emitConstraints(const ConstraintMap &constraints, StringRef selfName,

View File

@ -94,4 +94,20 @@ struct AppliedConstraint {
} // namespace tblgen
} // namespace mlir
namespace llvm {
/// Unique constraints by their predicate and summary. Constraints that share
/// the same predicate may have different descriptions; ensure that the
/// correct error message is reported when verification fails.
template <>
struct DenseMapInfo<mlir::tblgen::Constraint> {
using RecordDenseMapInfo = llvm::DenseMapInfo<const llvm::Record *>;
static mlir::tblgen::Constraint getEmptyKey();
static mlir::tblgen::Constraint getTombstoneKey();
static unsigned getHashValue(mlir::tblgen::Constraint constraint);
static bool isEqual(mlir::tblgen::Constraint lhs,
mlir::tblgen::Constraint rhs);
};
} // namespace llvm
#endif // MLIR_TABLEGEN_CONSTRAINT_H_

View File

@ -108,3 +108,34 @@ AppliedConstraint::AppliedConstraint(Constraint &&constraint,
std::vector<std::string> &&entities)
: constraint(constraint), self(std::string(self)),
entities(std::move(entities)) {}
Constraint DenseMapInfo<Constraint>::getEmptyKey() {
return Constraint(RecordDenseMapInfo::getEmptyKey(),
Constraint::CK_Uncategorized);
}
Constraint DenseMapInfo<Constraint>::getTombstoneKey() {
return Constraint(RecordDenseMapInfo::getTombstoneKey(),
Constraint::CK_Uncategorized);
}
unsigned DenseMapInfo<Constraint>::getHashValue(Constraint constraint) {
if (constraint == getEmptyKey())
return RecordDenseMapInfo::getHashValue(RecordDenseMapInfo::getEmptyKey());
if (constraint == getTombstoneKey()) {
return RecordDenseMapInfo::getHashValue(
RecordDenseMapInfo::getTombstoneKey());
}
return llvm::hash_combine(constraint.getPredicate(), constraint.getSummary());
}
bool DenseMapInfo<Constraint>::isEqual(Constraint lhs, Constraint rhs) {
if (lhs == rhs)
return true;
if (lhs == getEmptyKey() || lhs == getTombstoneKey())
return false;
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs.getPredicate() == rhs.getPredicate() &&
lhs.getSummary() == rhs.getSummary();
}

View File

@ -357,10 +357,6 @@ void Operator::populateTypeInferenceInfo(
continue;
}
if (getArg(*mi).is<NamedAttribute *>()) {
// TODO: Handle attributes.
continue;
}
resultTypeMapping[i].emplace_back(*mi);
found = true;
}

View File

@ -41,11 +41,11 @@ class ConstantOp:
loc=None,
ip=None):
if isinstance(value, int):
super().__init__(result, IntegerAttr.get(result, value), loc=loc, ip=ip)
super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
elif isinstance(value, float):
super().__init__(result, FloatAttr.get(result, value), loc=loc, ip=ip)
super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
else:
super().__init__(result, value, loc=loc, ip=ip)
super().__init__(value, loc=loc, ip=ip)
@classmethod
def create_index(cls, value: int, *, loc=None, ip=None):

View File

@ -25,7 +25,7 @@ func.func @non_signless_constant() {
// -----
func.func @complex_constant_wrong_attribute_type() {
// expected-error @+1 {{'arith.constant' op failed to verify that result and attribute have the same type}}
// expected-error @+1 {{'arith.constant' op failed to verify that all of {value, result} have same type}}
%0 = "arith.constant" () {value = 1.0 : f32} : () -> complex<f32>
return
}
@ -50,7 +50,7 @@ func.func @bitcast_different_bit_widths(%arg : f16) -> f32 {
func.func @constant() {
^bb:
%x = "arith.constant"(){value = "xyz"} : () -> i32 // expected-error {{'arith.constant' op failed to verify that result and attribute have the same type}}
%x = "arith.constant"(){value = "xyz"} : () -> i32 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}}
return
}
@ -58,7 +58,7 @@ func.func @constant() {
func.func @constant_out_of_range() {
^bb:
%x = "arith.constant"(){value = 100} : () -> i1 // expected-error {{'arith.constant' op failed to verify that result and attribute have the same type}}
%x = "arith.constant"(){value = 100} : () -> i1 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}}
return
}
@ -66,7 +66,7 @@ func.func @constant_out_of_range() {
func.func @constant_wrong_type() {
^bb:
%x = "arith.constant"(){value = 10.} : () -> f32 // expected-error {{'arith.constant' op failed to verify that result and attribute have the same type}}
%x = "arith.constant"(){value = 10.} : () -> f32 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}}
return
}

View File

@ -5,7 +5,7 @@
// Emit the first available call stack in the fused location.
func.func @constant_out_of_range() {
// CHECK: mysource1:0:0: error: 'arith.constant' op failed to verify that result and attribute have the same type
// CHECK: mysource1:0:0: error: 'arith.constant' op failed to verify that all of {value, result} have same type
// CHECK-NEXT: mysource2:1:0: note: called from
// CHECK-NEXT: mysource3:2:0: note: called from
%x = "arith.constant"() {value = 100} : () -> i1 loc(fused["bar", callsite("foo"("mysource1":0:0) at callsite("mysource2":1:0 at "mysource3":2:0))])

View File

@ -123,7 +123,8 @@ def OpL1 : NS_Op<"op_with_all_types_constraint",
// CHECK-LABEL: LogicalResult OpL1::inferReturnTypes
// CHECK-NOT: }
// CHECK: inferredReturnTypes[0] = operands[0].getType();
// CHECK: ::mlir::Type odsInferredType0 = operands[0].getType();
// CHECK: inferredReturnTypes[0] = odsInferredType0;
def OpL2 : NS_Op<"op_with_all_types_constraint",
[AllTypesMatch<["c", "b"]>, AllTypesMatch<["a", "d"]>]> {
@ -133,5 +134,18 @@ def OpL2 : NS_Op<"op_with_all_types_constraint",
// CHECK-LABEL: LogicalResult OpL2::inferReturnTypes
// CHECK-NOT: }
// CHECK: inferredReturnTypes[0] = operands[2].getType();
// CHECK: inferredReturnTypes[1] = operands[0].getType();
// CHECK: ::mlir::Type odsInferredType0 = operands[2].getType();
// CHECK: ::mlir::Type odsInferredType1 = operands[0].getType();
// CHECK: inferredReturnTypes[0] = odsInferredType0;
// CHECK: inferredReturnTypes[1] = odsInferredType1;
def OpL3 : NS_Op<"op_with_all_types_constraint",
[AllTypesMatch<["a", "b"]>]> {
let arguments = (ins I32Attr:$a);
let results = (outs AnyType:$b);
}
// CHECK-LABEL: LogicalResult OpL3::inferReturnTypes
// CHECK-NOT: }
// CHECK: ::mlir::Type odsInferredType0 = attributes.get("a").getType();
// CHECK: inferredReturnTypes[0] = odsInferredType0;

View File

@ -234,41 +234,6 @@ void StaticVerifierFunctionEmitter::emitPatternConstraints() {
//===----------------------------------------------------------------------===//
// Constraint Uniquing
using RecordDenseMapInfo = llvm::DenseMapInfo<const llvm::Record *>;
Constraint StaticVerifierFunctionEmitter::ConstraintUniquer::getEmptyKey() {
return Constraint(RecordDenseMapInfo::getEmptyKey(),
Constraint::CK_Uncategorized);
}
Constraint StaticVerifierFunctionEmitter::ConstraintUniquer::getTombstoneKey() {
return Constraint(RecordDenseMapInfo::getTombstoneKey(),
Constraint::CK_Uncategorized);
}
unsigned StaticVerifierFunctionEmitter::ConstraintUniquer::getHashValue(
Constraint constraint) {
if (constraint == getEmptyKey())
return RecordDenseMapInfo::getHashValue(RecordDenseMapInfo::getEmptyKey());
if (constraint == getTombstoneKey()) {
return RecordDenseMapInfo::getHashValue(
RecordDenseMapInfo::getTombstoneKey());
}
return llvm::hash_combine(constraint.getPredicate(), constraint.getSummary());
}
bool StaticVerifierFunctionEmitter::ConstraintUniquer::isEqual(Constraint lhs,
Constraint rhs) {
if (lhs == rhs)
return true;
if (lhs == getEmptyKey() || lhs == getTombstoneKey())
return false;
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs.getPredicate() == rhs.getPredicate() &&
lhs.getSummary() == rhs.getSummary();
}
/// An attribute constraint that references anything other than itself and the
/// current op cannot be generically extracted into a function. Most
/// prohibitive are operands and results, which require calls to

View File

@ -2336,23 +2336,60 @@ void OpEmitter::genTypeInterfaceMethods() {
fctx.withBuilder("odsBuilder");
body << " ::mlir::Builder odsBuilder(context);\n";
auto emitType = [&](const tblgen::Operator::ArgOrType &type) -> MethodBody & {
if (!type.isArg())
return body << tgfmt(*type.getType().getBuilderCall(), &fctx);
auto argIndex = type.getArg();
assert(!op.getArg(argIndex).is<NamedAttribute *>());
auto arg = op.getArgToOperandOrAttribute(argIndex);
if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand)
return body << "operands[" << arg.operandOrAttributeIndex()
<< "].getType()";
return body << "attributes[" << arg.operandOrAttributeIndex()
<< "].getType()";
};
// Preprocess the result types and build all of the types used during
// inferrence. This limits the amount of duplicated work when a type is used
// to infer multiple others.
llvm::DenseMap<Constraint, int> constraintsTypes;
llvm::DenseMap<int, int> argumentsTypes;
int inferredTypeIdx = 0;
for (int i = 0, e = op.getNumResults(); i != e; ++i) {
auto type = op.getSameTypeAsResult(i).front();
// If the type isn't an argument, it refers to a buildable type.
if (!type.isArg()) {
auto it = constraintsTypes.try_emplace(type.getType(), inferredTypeIdx);
if (!it.second)
continue;
// If we haven't seen this constraint, generate a variable for it.
body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = "
<< tgfmt(*type.getType().getBuilderCall(), &fctx) << ";\n";
continue;
}
// Otherwise, this is an argument.
int argIndex = type.getArg();
auto it = argumentsTypes.try_emplace(argIndex, inferredTypeIdx);
if (!it.second)
continue;
body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = ";
// If this is an operand, just index into operand list to access the type.
auto arg = op.getArgToOperandOrAttribute(argIndex);
if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
body << "operands[" << arg.operandOrAttributeIndex() << "].getType()";
// If this is an attribute, index into the attribute dictionary.
} else {
auto *attr =
op.getArg(arg.operandOrAttributeIndex()).get<NamedAttribute *>();
body << "attributes.get(\"" << attr->name << "\").getType()";
}
body << ";\n";
}
// Perform a second pass that handles assigning the inferred types to the
// results.
for (int i = 0, e = op.getNumResults(); i != e; ++i) {
body << " inferredReturnTypes[" << i << "] = ";
auto types = op.getSameTypeAsResult(i);
emitType(types[0]) << ";\n";
// Append the inferred type.
auto type = types.front();
body << " inferredReturnTypes[" << i << "] = odsInferredType"
<< (type.isArg() ? argumentsTypes[type.getArg()]
: constraintsTypes[type.getType()])
<< ";\n";
if (types.size() == 1)
continue;
// TODO: We could verify equality here, but skipping that for verification.