Mogball b0774e5f50 [mlir][ods] ODS ops get an extraClassDefinition
Extra definitions are placed in the generated source file for each op class. The substitution `$cppClass` is replaced by the op's C++ class name.

This is useful when declaring but not defining methods in TableGen base classes:

```
class BaseOp<string mnemonic>
    : Op<MyDialect, mnemonic, [DeclareOpInterfaceMethods<SomeInterface>]> {
  let extraClassDeclaration = [{
    // ZOp is declared at the bottom of the file and is incomplete here
    ZOp getParent();
  }];
  let extraClassDefinition = [{
    int $cppClass::someInterfaceMethod() {
      return someUtilityFunction(*this);
    }
    ZOp $cppClass::getParent() {
      return dyn_cast<ZOp>(this->getParentOp());
    }
  }];
}
```

Certain things may prevent defining these functions inline, in the declaration. In this example, `ZOp` in the same dialect is incomplete at the function declaration because op classes are declared in alphabetical order. Alternatively, functions may be too big to be desired as inlined, or they may require dependencies that create cyclic includes, or they may be calling a templated utility function that one may not want to expose in a header. If the functions are not inlined, then inheriting from the base class N times means that each function will need to be defined N times. With `extraClassDefinition`, they only need to be defined once.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D115783
2022-01-06 01:43:26 +00:00

370 lines
12 KiB
C++

//===- Class.cpp - Helper classes for Op C++ code emission --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/Format.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
using namespace mlir::tblgen;
/// Returns the separator to emit between the C++ `type` and whatever follows
/// it: "" when `type` is empty or ends with '&' or '*', and " " otherwise.
static StringRef getSpaceAfterType(StringRef type) {
  const bool needsNoSpace =
      type.empty() || type.endswith("&") || type.endswith("*");
  if (needsNoSpace)
    return "";
  return " ";
}
//===----------------------------------------------------------------------===//
// MethodParameter definitions
//===----------------------------------------------------------------------===//
/// Writes this parameter in declaration form: an `/*optional*/` marker when
/// the parameter is optional, then the type and name, then ` = <default>` when
/// a default value is present.
void MethodParameter::writeDeclTo(raw_indented_ostream &os) const {
  if (optional)
    os << "/*optional*/";
  os << type;
  os << getSpaceAfterType(type);
  os << name;
  // Only declarations carry the default value.
  if (!hasDefaultValue())
    return;
  os << " = " << defaultValue;
}
/// Writes this parameter in definition form: an `/*optional*/` marker when the
/// parameter is optional, then the type and name. No default value is written.
void MethodParameter::writeDefTo(raw_indented_ostream &os) const {
  if (optional)
    os << "/*optional*/";
  os << type;
  os << getSpaceAfterType(type);
  os << name;
}
//===----------------------------------------------------------------------===//
// MethodParameters definitions
//===----------------------------------------------------------------------===//
/// Writes every parameter in declaration form, separated by commas.
void MethodParameters::writeDeclTo(raw_indented_ostream &os) const {
  auto emitDecl = [&os](const auto &parameter) { parameter.writeDeclTo(os); };
  llvm::interleaveComma(parameters, os, emitDecl);
}
/// Writes every parameter in definition form, separated by commas.
void MethodParameters::writeDefTo(raw_indented_ostream &os) const {
  auto emitDef = [&os](const auto &parameter) { parameter.writeDefTo(os); };
  llvm::interleaveComma(parameters, os, emitDef);
}
/// Returns true if this parameter list subsumes `other`: the first
/// `other.size()` parameter types match pairwise and this list either has the
/// same length or its first extra parameter has a default value.
bool MethodParameters::subsumes(const MethodParameters &other) const {
  const auto numCommon = other.parameters.size();

  // A shorter parameter list can never subsume a longer one.
  if (parameters.size() < numCommon)
    return false;

  // Every parameter of `other` must have the same type as the parameter at the
  // same position in this list.
  auto haveSameType = [](auto &lhs, auto &rhs) {
    return lhs.getType() == rhs.getType();
  };
  if (!std::equal(other.parameters.begin(), other.parameters.end(),
                  parameters.begin(), haveSameType))
    return false;

  // With identical common parameters, the other method can be elided when both
  // lists have the same length...
  if (parameters.size() == numCommon)
    return true;

  // ...or when the first parameter after the common ones has a default value
  // (and, as required by C++, all subsequent parameters do too).
  return parameters[numCommon].hasDefaultValue();
}
//===----------------------------------------------------------------------===//
// MethodSignature definitions
//===----------------------------------------------------------------------===//
/// Returns true if `other` has the same method name as this signature and this
/// signature's parameters subsume the parameters of `other`.
bool MethodSignature::makesRedundant(const MethodSignature &other) const {
  // Differently named methods never make each other redundant.
  if (!(methodName == other.methodName))
    return false;
  return parameters.subsumes(other.parameters);
}
/// Writes the signature in declaration form:
/// `<returnType> <methodName>(<parameter declarations>)`.
void MethodSignature::writeDeclTo(raw_indented_ostream &os) const {
  os << returnType << getSpaceAfterType(returnType);
  os << methodName << "(";
  parameters.writeDeclTo(os);
  os << ")";
}
/// Writes the signature in definition form:
/// `<returnType> <namePrefix>::<methodName>(<parameter definitions>)`. The
/// `<namePrefix>::` qualifier is omitted when `namePrefix` is empty.
void MethodSignature::writeDefTo(raw_indented_ostream &os,
                                 StringRef namePrefix) const {
  os << returnType << getSpaceAfterType(returnType);
  // Qualify the method name only when a prefix was supplied.
  if (!namePrefix.empty())
    os << namePrefix << "::";
  os << methodName << "(";
  parameters.writeDefTo(os);
  os << ")";
}
//===----------------------------------------------------------------------===//
// MethodBody definitions
//===----------------------------------------------------------------------===//
/// Constructs a method body and records the `declOnly` flag. `stringOs` is
/// constructed over the `body` member and `os` is constructed over `stringOs`.
MethodBody::MethodBody(bool declOnly)
: declOnly(declOnly), stringOs(body), os(stringOs) {}
/// Writes the accumulated body text to `os` with leading newlines removed. A
/// non-empty body is always terminated by a newline; an empty body writes
/// nothing further.
void MethodBody::writeTo(raw_indented_ostream &os) const {
  // Strip only leading '\n' characters; other leading whitespace is kept.
  StringRef trimmed = StringRef(body).ltrim('\n');
  os << trimmed;
  // Append the terminating newline only when the text does not already end
  // with one.
  if (!trimmed.empty() && trimmed.back() != '\n')
    os << "\n";
}
//===----------------------------------------------------------------------===//
// Method definitions
//===----------------------------------------------------------------------===//
/// Writes the method declaration. `static ` and `constexpr ` are emitted when
/// the corresponding properties are set, followed by the signature and an
/// optional ` const`. Inline methods are written together with their body;
/// all others end with `;`.
void Method::writeDeclTo(raw_indented_ostream &os) const {
  if (isStatic())
    os << "static ";
  if (properties & ConstexprValue)
    os << "constexpr ";
  methodSignature.writeDeclTo(os);
  if (isConst())
    os << " const";

  if (isInline()) {
    // Inline methods carry their body in the declaration.
    os << " {\n";
    methodBody.writeTo(os);
    os << "}\n\n";
  } else {
    os << ";\n";
  }
}
/// Writes the out-of-line definition of the method, qualified by `namePrefix`.
/// Nothing is written for declaration-only or inline methods.
void Method::writeDefTo(raw_indented_ostream &os, StringRef namePrefix) const {
  // Only methods that are neither declaration-only nor inline have a
  // definition to write here.
  if (!(properties & Declaration) && !isInline()) {
    methodSignature.writeDefTo(os, namePrefix);
    if (isConst())
      os << " const";
    os << " {\n";
    methodBody.writeTo(os);
    os << "}\n\n";
  }
}
//===----------------------------------------------------------------------===//
// Constructor definitions
//===----------------------------------------------------------------------===//
/// Writes the constructor declaration. `constexpr ` is emitted when the
/// corresponding property is set. Inline constructors are written with their
/// member initializer list (if any) and body; all others end with `;`.
void Constructor::writeDeclTo(raw_indented_ostream &os) const {
  if (properties & ConstexprValue)
    os << "constexpr ";
  methodSignature.writeDeclTo(os);

  if (isInline()) {
    os << ' ';
    // Emit `: a(x), b(y) ` only when there is at least one initializer.
    if (!initializers.empty()) {
      os << ": ";
      llvm::interleaveComma(initializers, os, [&](auto &memberInit) {
        memberInit.writeTo(os);
      });
      os << ' ';
    }
    os << "{";
    methodBody.writeTo(os);
    os << "}\n\n";
  } else {
    os << ";\n\n";
  }
}
/// Writes the out-of-line definition of the constructor, qualified by
/// `namePrefix`, including its member initializer list (if any) and body.
/// Nothing is written for declaration-only or inline constructors.
void Constructor::writeDefTo(raw_indented_ostream &os,
                             StringRef namePrefix) const {
  // Only constructors that are neither declaration-only nor inline have a
  // definition to write here.
  if (!(properties & Declaration) && !isInline()) {
    methodSignature.writeDefTo(os, namePrefix);
    os << ' ';
    // Emit `: a(x), b(y) ` only when there is at least one initializer.
    if (!initializers.empty()) {
      os << ": ";
      llvm::interleaveComma(initializers, os, [&](auto &memberInit) {
        memberInit.writeTo(os);
      });
      os << ' ';
    }
    os << "{";
    methodBody.writeTo(os);
    os << "}\n\n";
  }
}
/// Writes the initializer as `<name>(<value>)`.
void Constructor::MemberInitializer::writeTo(raw_indented_ostream &os) const {
  os << name;
  os << '(';
  os << value;
  os << ')';
}
//===----------------------------------------------------------------------===//
// Visibility definitions
//===----------------------------------------------------------------------===//
namespace mlir {
namespace tblgen {
/// Writes the C++ access specifier keyword for `visibility` ("public",
/// "protected" or "private") to `os` and returns `os`. Nothing is written for
/// a value that matches none of the enumerators.
raw_ostream &operator<<(raw_ostream &os, Visibility visibility) {
  switch (visibility) {
  case Visibility::Public:
    os << "public";
    break;
  case Visibility::Protected:
    os << "protected";
    break;
  case Visibility::Private:
    os << "private";
    break;
  }
  return os;
}
} // namespace tblgen
} // namespace mlir
//===----------------------------------------------------------------------===//
// ParentClass definitions
//===----------------------------------------------------------------------===//
/// Writes the parent class as `<visibility> <name>`, followed by a
/// comma-separated `<...>` template parameter list when there are any
/// template parameters.
void ParentClass::writeTo(raw_indented_ostream &os) const {
  os << visibility << ' ' << name;
  if (templateParams.empty())
    return;

  // The scope object is created with the "<" and ">" delimiters and lives
  // until the end of the function, after all parameters have been written.
  auto angleScope = os.scope("<", ">", /*indent=*/false);
  llvm::interleaveComma(templateParams, os,
                        [&](auto &templateParam) { os << templateParam; });
}
//===----------------------------------------------------------------------===//
// UsingDeclaration definitions
//===----------------------------------------------------------------------===//
/// Writes `using <name>;`, or `using <name> = <value>;` when a value is set.
void UsingDeclaration::writeDeclTo(raw_indented_ostream &os) const {
  os << "using " << name;
  const bool hasValue = !value.empty();
  if (hasValue)
    os << " = " << value;
  os << ";\n";
}
//===----------------------------------------------------------------------===//
// Field definitions
//===----------------------------------------------------------------------===//
/// Writes the field declaration as `<type> <name>;` on its own line.
void Field::writeDeclTo(raw_indented_ostream &os) const {
  os << type << ' ';
  os << name << ";\n";
}
//===----------------------------------------------------------------------===//
// VisibilityDeclaration definitions
//===----------------------------------------------------------------------===//
/// Writes the access specifier label, i.e. `<visibility>:` followed by a
/// newline.
void VisibilityDeclaration::writeDeclTo(raw_indented_ostream &os) const {
// Unindent before writing the label and re-indent afterwards; presumably this
// places the label one level out from the members that follow it.
os.unindent();
os << visibility << ":\n";
os.indent();
}
//===----------------------------------------------------------------------===//
// ExtraClassDeclaration definitions
//===----------------------------------------------------------------------===//
/// Writes the stored extra class declaration text to `os` via
/// `printReindented`.
void ExtraClassDeclaration::writeDeclTo(raw_indented_ostream &os) const {
os.printReindented(extraClassDeclaration);
}
/// Writes the stored extra class definition text to `os` via
/// `printReindented`. `namePrefix` is not used by this function.
void ExtraClassDeclaration::writeDefTo(raw_indented_ostream &os,
StringRef namePrefix) const {
os.printReindented(extraClassDefinition);
}
//===----------------------------------------------------------------------===//
// Class definitions
//===----------------------------------------------------------------------===//
/// Appends `parent` to the list of parent classes and returns a reference to
/// the stored element.
ParentClass &Class::addParent(ParentClass parent) {
parents.push_back(std::move(parent));
return parents.back();
}
/// Writes the class declaration: the `struct`/`class` keyword and name, the
/// comma-separated parent classes (if any), and then every declaration inside
/// a `{ ... };` scope.
void Class::writeDeclTo(raw_indented_ostream &os) const {
  // Class head.
  const char *classKey = isStruct ? "struct" : "class";
  os << classKey << ' ' << className << ' ';

  // Base clause, only when there is at least one parent.
  if (!parents.empty()) {
    os << ": ";
    llvm::interleaveComma(parents, os, [&](auto &parentClass) {
      parentClass.writeTo(os);
    });
    os << ' ';
  }

  // The scope object is created with the "{\n" and "};\n" delimiters and
  // lives until the end of the function, after all members are written.
  auto classScope = os.scope("{\n", "};\n", /*indent=*/true);
  for (const auto &declaration : declarations)
    declaration->writeDeclTo(os);
}
/// Writes the definitions of all class declarations in order, passing the
/// class name as the name prefix.
void Class::writeDefTo(raw_indented_ostream &os) const {
  for (const auto &declaration : declarations)
    declaration->writeDefTo(os, className);
}
void Class::finalize() {
// Sort the methods by public and private. Remove them from the pending list
// of methods.
SmallVector<std::unique_ptr<Method>> publicMethods, privateMethods;
for (auto &method : methods) {
if (method->isPrivate())
privateMethods.push_back(std::move(method));
else
publicMethods.push_back(std::move(method));
}
methods.clear();
// If the last visibility declaration wasn't `public`, add one that is. Then,
// declare the public methods.
if (!publicMethods.empty() && getLastVisibilityDecl() != Visibility::Public)
declare<VisibilityDeclaration>(Visibility::Public);
for (auto &method : publicMethods)
declarations.push_back(std::move(method));
// If the last visibility declaration wasn't `private`, add one that is. Then,
// declare the private methods.
if (!privateMethods.empty() && getLastVisibilityDecl() != Visibility::Private)
declare<VisibilityDeclaration>(Visibility::Private);
for (auto &method : privateMethods)
declarations.push_back(std::move(method));
// All fields added to the pending list are private and declared at the bottom
// of the class. If the last visibility declaration wasn't `private`, add one
// that is, then declare the fields.
if (!fields.empty() && getLastVisibilityDecl() != Visibility::Private)
declare<VisibilityDeclaration>(Visibility::Private);
for (auto &field : fields)
declare<Field>(std::move(field));
fields.clear();
}
/// Returns the visibility of the most recently added visibility declaration.
/// If there is none, returns the default for the class kind: public for a
/// struct and private otherwise.
Visibility Class::getLastVisibilityDecl() const {
  // Search from the back so that the latest visibility declaration wins.
  auto reversed = llvm::reverse(declarations);
  auto lastVisibility = llvm::find_if(reversed, [](auto &declaration) {
    return isa<VisibilityDeclaration>(declaration);
  });
  if (lastVisibility != reversed.end())
    return cast<VisibilityDeclaration>(*lastVisibility).getVisibility();
  return isStruct ? Visibility::Public : Visibility::Private;
}
/// Inserts `newMethod` into `methods` unless an existing method already makes
/// it redundant, in which case nothing is changed and nullptr is returned.
/// Otherwise, every existing method made redundant by `newMethod` is erased,
/// `newMethod` is appended, and a pointer to the stored method is returned.
Method *insertAndPruneMethods(std::vector<std::unique_ptr<Method>> &methods,
                              std::unique_ptr<Method> newMethod) {
  // Drop the new method if something already in the list covers it.
  auto coversNewMethod = [&](auto &existing) {
    return existing->makesRedundant(*newMethod);
  };
  if (llvm::any_of(methods, coversNewMethod))
    return nullptr;

  // Remove every existing method that the new method covers.
  auto coveredByNewMethod = [&](auto &existing) {
    return newMethod->makesRedundant(*existing);
  };
  llvm::erase_if(methods, coveredByNewMethod);

  methods.push_back(std::move(newMethod));
  return methods.back().get();
}
/// Moves `newMethod` onto the heap and inserts it into the pending methods via
/// `insertAndPruneMethods`, returning that function's result.
Method *Class::addMethodAndPrune(Method &&newMethod) {
  auto owned = std::make_unique<Method>(std::move(newMethod));
  return insertAndPruneMethods(methods, std::move(owned));
}
/// Moves `newCtor` onto the heap and inserts it into the pending methods via
/// `insertAndPruneMethods`. The result is converted back to a `Constructor`
/// pointer with `dyn_cast_or_null`, so a null result is passed through.
Constructor *Class::addConstructorAndPrune(Constructor &&newCtor) {
  std::unique_ptr<Method> owned =
      std::make_unique<Constructor>(std::move(newCtor));
  Method *inserted = insertAndPruneMethods(methods, std::move(owned));
  return dyn_cast_or_null<Constructor>(inserted);
}