diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md index ec9e6fdc80dd..1e1abdc20d2f 100644 --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -964,6 +964,16 @@ Note that `extraClassDeclaration` is a mechanism intended for long-tail cases by power users; for not-yet-implemented widely-applicable cases, improving the infrastructure is preferable. +### Extra definitions + +When defining base op classes in TableGen that are inherited many times by +different ops, users may want to provide common definitions of utility and +interface functions. However, many of these definitions may not be desirable or +possible in `extraClassDeclaration`, which append them to the op's C++ class +declaration. In these cases, users can add an `extraClassDefinition` to define +code that is added to the generated source file inside the op's C++ namespace. +The substitution `$cppClass` is replaced by the op's C++ class name. + ### Generated C++ code [OpDefinitionsGen][OpDefinitionsGen] processes the op definition spec file and diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index f1a5446ad1f9..8e70f4844008 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2445,6 +2445,11 @@ class Op props = []> { // Additional code that will be added to the public part of the generated // C++ code of the op declaration. code extraClassDeclaration = ?; + + // Additional code that will be added to the generated source file. The + // generated code is placed inside the op's C++ namespace. `$cppClass` is + // replaced by the op's C++ class name. + code extraClassDefinition = ?; } // Base class for ops with static/dynamic offset, sizes and strides diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h index 1f310fe1d082..a8a710ff85fe 100644 --- a/mlir/include/mlir/TableGen/Class.h +++ b/mlir/include/mlir/TableGen/Class.h @@ -532,22 +532,32 @@ private: Visibility visibility; }; -/// Unstructured extra class declarations, from TableGen definitions. The -/// default visibility of extra class declarations is up to the owning class. +/// Unstructured extra class declarations and definitions, from TableGen +/// definitions. The default visibility of extra class declarations is up to the +/// owning class. class ExtraClassDeclaration : public ClassDeclarationBase { public: /// Create an extra class declaration. - ExtraClassDeclaration(StringRef extraClassDeclaration) - : extraClassDeclaration(extraClassDeclaration) {} + ExtraClassDeclaration(StringRef extraClassDeclaration, + StringRef extraClassDefinition = "") + : extraClassDeclaration(extraClassDeclaration), + extraClassDefinition(extraClassDefinition) {} /// Write the extra class declarations. void writeDeclTo(raw_indented_ostream &os) const override; + /// Write the extra class definitions. + void writeDefTo(raw_indented_ostream &os, + StringRef namePrefix) const override; + private: /// The string of the extra class declarations. It is re-indented before /// printed. StringRef extraClassDeclaration; + /// The string of the extra class definitions. It is re-indented before + /// printed. + StringRef extraClassDefinition; }; /// A class used to emit C++ classes from Tablegen. Contains a list of public diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index 44f10440c1e3..ddfb7dd0178b 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -235,6 +235,9 @@ public: // Returns this op's extra class declaration code. StringRef getExtraClassDeclaration() const; + // Returns this op's extra class definition code. + StringRef getExtraClassDefinition() const; + // Returns the Tablegen definition this operator was constructed from. // TODO: do not expose the TableGen record, this is a temporary solution to // OpEmitter requiring a Record because Operator does not provide enough diff --git a/mlir/lib/TableGen/Class.cpp b/mlir/lib/TableGen/Class.cpp index 9b7124e2e3a5..a7c02d3ae543 100644 --- a/mlir/lib/TableGen/Class.cpp +++ b/mlir/lib/TableGen/Class.cpp @@ -260,6 +260,11 @@ void ExtraClassDeclaration::writeDeclTo(raw_indented_ostream &os) const { os.printReindented(extraClassDeclaration); } +void ExtraClassDeclaration::writeDefTo(raw_indented_ostream &os, + StringRef namePrefix) const { + os.printReindented(extraClassDefinition); +} + //===----------------------------------------------------------------------===// // Class definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index f1c1fe534666..cde617dcd30b 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -128,6 +128,13 @@ StringRef Operator::getExtraClassDeclaration() const { return def.getValueAsString(attr); } +StringRef Operator::getExtraClassDefinition() const { + constexpr auto attr = "extraClassDefinition"; + if (def.isValueUnset(attr)) + return {}; + return def.getValueAsString(attr); +} + const llvm::Record &Operator::getDef() const { return def; } bool Operator::skipDefaultBuilders() const { diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 6fad11b85ad8..28dbc271d72d 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -382,7 +382,10 @@ def ConversionCallOp : TEST_Op<"conversion_call_op", let extraClassDeclaration = [{ /// Return the callee of this operation. - ::mlir::CallInterfaceCallable getCallableForCallee() { + ::mlir::CallInterfaceCallable getCallableForCallee(); + }]; + let extraClassDefinition = [{ + ::mlir::CallInterfaceCallable $cppClass::getCallableForCallee() { return (*this)->getAttrOfType<::mlir::SymbolRefAttr>("callee"); } }]; diff --git a/mlir/tools/mlir-tblgen/OpClass.cpp b/mlir/tools/mlir-tblgen/OpClass.cpp index 9524dc9210b8..3512212272f4 100644 --- a/mlir/tools/mlir-tblgen/OpClass.cpp +++ b/mlir/tools/mlir-tblgen/OpClass.cpp @@ -15,8 +15,10 @@ using namespace mlir::tblgen; // OpClass definitions //===----------------------------------------------------------------------===// -OpClass::OpClass(StringRef name, StringRef extraClassDeclaration) +OpClass::OpClass(StringRef name, StringRef extraClassDeclaration, + std::string extraClassDefinition) : Class(name.str()), extraClassDeclaration(extraClassDeclaration), + extraClassDefinition(std::move(extraClassDefinition)), parent(addParent("::mlir::Op")) { parent.addTemplateParam(getClassName().str()); declare(Visibility::Public); @@ -30,5 +32,5 @@ OpClass::OpClass(StringRef name, StringRef extraClassDeclaration) void OpClass::finalize() { Class::finalize(); declare(Visibility::Public); - declare(extraClassDeclaration); + declare(extraClassDeclaration, extraClassDefinition); } diff --git a/mlir/tools/mlir-tblgen/OpClass.h b/mlir/tools/mlir-tblgen/OpClass.h index b0558a0e5513..6b90dd2c3a3a 100644 --- a/mlir/tools/mlir-tblgen/OpClass.h +++ b/mlir/tools/mlir-tblgen/OpClass.h @@ -25,7 +25,8 @@ public: /// - inheritance of `print` /// - a type alias for the associated adaptor class /// - OpClass(StringRef name, StringRef extraClassDeclaration); + OpClass(StringRef name, StringRef extraClassDeclaration, + std::string extraClassDefinition); /// Add an op trait. void addTrait(Twine trait) { parent.addTemplateParam(trait.str()); } @@ -39,6 +40,8 @@ public: private: /// Hand-written extra class declarations. StringRef extraClassDeclaration; + /// Hand-written extra class definitions. + std::string extraClassDefinition; /// The parent class, which also contains the traits to be inherited. ParentClass &parent; }; diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index f024b90d3340..8511df9c54e6 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -557,10 +557,18 @@ static void genAttributeVerifier( } } +/// Op extra class definitions have a `$cppClass` substitution that is to be +/// replaced by the C++ class name. +static std::string formatExtraDefinitions(const Operator &op) { + FmtContext ctx = FmtContext().addSubst("cppClass", op.getCppClassName()); + return tgfmt(op.getExtraClassDefinition(), &ctx).str(); +} + OpEmitter::OpEmitter(const Operator &op, const StaticVerifierFunctionEmitter &staticVerifierEmitter) : def(op.getDef()), op(op), - opClass(op.getCppClassName(), op.getExtraClassDeclaration()), + opClass(op.getCppClassName(), op.getExtraClassDeclaration(), + formatExtraDefinitions(op)), staticVerifierEmitter(staticVerifierEmitter) { verifyCtx.withOp("(*this->getOperation())"); verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");