From 0fbf4ff232cdad9f6131527f2a23019bfb331b9e Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Mon, 28 Jan 2019 07:13:40 -0800 Subject: [PATCH] Define mAttr in terms of AttrConstraint. * Matching an attribute and specifying a attribute constraint is the same thing executionally, so represent it such. * Extract AttrConstraint helper to match TypeConstraint and use that where mAttr was previously used in RewriterGen. PiperOrigin-RevId: 231213580 --- mlir/include/mlir/IR/op_base.td | 8 ++--- mlir/include/mlir/TableGen/Attribute.h | 43 ++++++++++++++--------- mlir/lib/TableGen/Attribute.cpp | 48 +++++++++++++++----------- mlir/lib/TableGen/Type.cpp | 2 +- mlir/tools/mlir-tblgen/RewriterGen.cpp | 19 ++++------ 5 files changed, 65 insertions(+), 55 deletions(-) diff --git a/mlir/include/mlir/IR/op_base.td b/mlir/include/mlir/IR/op_base.td index 85d41a0de3c5..63a055f14695 100644 --- a/mlir/include/mlir/IR/op_base.td +++ b/mlir/include/mlir/IR/op_base.td @@ -221,7 +221,7 @@ def FloatLike : TypeConstraint { - // The predicates that this type satisfies. + // The predicates that this attribute satisfies. // Format: {0} will be expanded to the attribute. Pred predicate = condition; } @@ -443,11 +443,7 @@ class Pat : Pattern; // Attribute matcher. This is the base class to specify a predicate // that has to match. Used on the input attributes of a rewrite rule. -class mAttr { - // Code to match the attribute. - // Format: {0} represents the attribute. - CPred predicate = pred; -} +class mAttr : AttrConstraint; // Attribute transforms. This is the base class to specify a // transformation of a matched attribute. Used on the output of a rewrite diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h index 64c2db02cb9e..eca480410e62 100644 --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -35,13 +35,37 @@ class Record; namespace mlir { namespace tblgen { +// Wrapper class with helper methods for accessing Attribute constraints defined +// in TableGen. +class AttrConstraint { +public: + explicit AttrConstraint(const llvm::Record *record); + explicit AttrConstraint(const llvm::DefInit *init); + + // Returns the predicate that can be used to check if a attribute satisfies + // this attribute constraint. + Pred getPredicate() const; + + // Returns the condition template that can be used to check if a attribute + // satisfies this attribute constraint. The template may contain "{0}" that + // must be substituted with an expression returning an mlir::Attribute. + std::string getConditionTemplate() const; + + // Returns the user-readable description of the constraint. If the + // description is not provided, returns an empty string. + StringRef getDescription() const; + +protected: + // The TableGen definition of this attribute. + const llvm::Record *def; +}; + // Wrapper class providing helper methods for accessing MLIR Attribute defined // in TableGen. This class should closely reflect what is defined as class // `Attr` in TableGen. -class Attribute { +class Attribute : public AttrConstraint { public: - explicit Attribute(const llvm::Record &def); - explicit Attribute(const llvm::Record *def); + explicit Attribute(const llvm::Record *record); explicit Attribute(const llvm::DefInit *init); // Returns true if this attribute is a derived attribute (i.e., a subclass @@ -85,19 +109,6 @@ public: // Returns the code body for derived attribute. Aborts if this is not a // derived attribute. StringRef getDerivedCodeBody() const; - - // Returns the predicate that can be used to check if a attribute satisfies - // this attribute's constraint. - Pred getPredicate() const; - - // Returns the template that can be used to verify that an attribute satisfies - // the constraints for its declared attribute type. - // Syntax: {0} should be replaced with the attribute. - std::string getConditionTemplate() const; - -private: - // The TableGen definition of this attribute. - const llvm::Record *def; }; } // end namespace tblgen diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp index 42dd333b3f74..f4c040ce179b 100644 --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -36,15 +36,38 @@ static StringRef getValueAsString(const llvm::Init *init) { return {}; } -tblgen::Attribute::Attribute(const llvm::Record *def) : def(def) { - assert(def->isSubClassOf("Attr") && +tblgen::AttrConstraint::AttrConstraint(const llvm::Record *record) + : def(record) { + assert(def->isSubClassOf("AttrConstraint") && + "must be subclass of TableGen 'AttrConstraint' class"); +} + +tblgen::AttrConstraint::AttrConstraint(const llvm::DefInit *init) + : AttrConstraint(init->getDef()) {} + +tblgen::Pred tblgen::AttrConstraint::getPredicate() const { + auto *val = def->getValue("predicate"); + // If no predicate is specified, then return the null predicate (which + // corresponds to true). + if (!val) + return Pred(); + + const auto *pred = dyn_cast(val->getValue()); + return Pred(pred); +} + +std::string tblgen::AttrConstraint::getConditionTemplate() const { + return getPredicate().getCondition(); +} + +tblgen::Attribute::Attribute(const llvm::Record *record) + : AttrConstraint(record) { + assert(record->isSubClassOf("Attr") && "must be subclass of TableGen 'Attr' class"); } -tblgen::Attribute::Attribute(const llvm::Record &def) : Attribute(&def) {} - tblgen::Attribute::Attribute(const llvm::DefInit *init) - : Attribute(*init->getDef()) {} + : AttrConstraint(init->getDef()) {} bool tblgen::Attribute::isDerivedAttr() const { return def->isSubClassOf("DerivedAttr"); @@ -103,18 +126,3 @@ StringRef tblgen::Attribute::getDerivedCodeBody() const { assert(isDerivedAttr() && "only derived attribute has 'body' field"); return def->getValueAsString("body"); } - -tblgen::Pred tblgen::Attribute::getPredicate() const { - auto *val = def->getValue("predicate"); - // If no predicate is specified, then return the null predicate (which - // corresponds to true). - if (!val) - return Pred(); - - const auto *pred = dyn_cast(val->getValue()); - return Pred(pred); -} - -std::string tblgen::Attribute::getConditionTemplate() const { - return getPredicate().getCondition(); -} diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp index d7f2d7118f64..0c5def6d6219 100644 --- a/mlir/lib/TableGen/Type.cpp +++ b/mlir/lib/TableGen/Type.cpp @@ -53,7 +53,7 @@ llvm::StringRef tblgen::TypeConstraint::getDescription() const { } tblgen::TypeConstraint::TypeConstraint(const llvm::DefInit &init) - : def(*init.getDef()) {} + : TypeConstraint(*init.getDef()) {} tblgen::Type::Type(const llvm::Record &record) : TypeConstraint(record) { assert(def.isSubClassOf("Type") && diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 5c83528a8f80..7af3134bcf14 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -208,18 +208,13 @@ void Pattern::matchOp(DagInit *tree, int depth) { // TODO(jpienaar): Verify attributes. if (auto *namedAttr = opArg.dyn_cast()) { - // TODO(jpienaar): move to helper class. - if (defInit->getDef()->isSubClassOf("mAttr")) { - auto pred = - tblgen::Pred(defInit->getDef()->getValueInit("predicate")); - os.indent(indent) - << "if (!(" - << formatv(pred.getCondition().c_str(), - formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth, - namedAttr->attr.getStorageType(), - namedAttr->getName())) - << ")) return matchFailure();\n"; - } + auto constraint = tblgen::AttrConstraint(defInit); + std::string condition = formatv( + constraint.getConditionTemplate().c_str(), + formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth, + namedAttr->attr.getStorageType(), namedAttr->getName())); + os.indent(indent) << "if (!(" << condition + << ")) return matchFailure();\n"; } }