[mlir] Add a new fold API using Generic Adaptors

This is part of the RFC for a better fold API: https://discourse.llvm.org/t/rfc-a-better-fold-api-using-more-generic-adaptors/67374

This patch implements the required foldHook changes and the TableGen machinery for generating `fold` method signatures using `FoldAdaptor` for ops, based on the value of `useFoldAPI` of the dialect. It may be one of 2 values, with convenient named constants to create a quasi enum. The new `fold` method will then be generated if `kEmitFoldAdaptorFolder` is used.

Since the new `FoldAdaptor` approach is strictly better than the old signature, part of this patch updates the documentation and all example to encourage use of the new `fold` signature.
Included are also tests exercising the new API, ensuring proper construction of the `FoldAdaptor` and proper generation by TableGen.

Differential Revision: https://reviews.llvm.org/D140886
This commit is contained in:
Markus Böck 2022-12-25 19:29:31 +01:00
parent cf6f217516
commit bbfa7ef16d
17 changed files with 235 additions and 43 deletions

View File

@ -156,7 +156,7 @@ If the operation has a single result the following will be generated:
/// of the operation. The caller will remove the operation and use that
/// result instead.
///
OpFoldResult MyOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult MyOp::fold(FoldAdaptor adaptor) {
...
}
```
@ -178,19 +178,19 @@ Otherwise, the following is generated:
/// the operation and use those results instead.
///
/// Note that this mechanism cannot be used to remove 0-result operations.
LogicalResult MyOp::fold(ArrayRef<Attribute> operands,
LogicalResult MyOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
...
}
```
In the above, for each method an `ArrayRef<Attribute>` is provided that
corresponds to the constant attribute value of each of the operands. These
In the above, for each method a `FoldAdaptor` is provided with getters for
each of the operands, returning the corresponding constant attribute. These
operands are those that implement the `ConstantLike` trait. If any of the
operands are non-constant, a null `Attribute` value is provided instead. For
example, if MyOp provides three operands [`a`, `b`, `c`], but only `b` is
constant then `operands` will be of the form [Attribute(), b-value,
Attribute()].
constant then `adaptor` will return Attribute() for `getA()` and `getC()`,
and b-value for `getB()`.
Also above, is the use of `OpFoldResult`. This class represents the possible
result of folding an operation result: either an SSA `Value`, or an

View File

@ -255,6 +255,31 @@ LogicalResult MyDialect::verifyRegionResultAttribute(Operation *op, unsigned reg
unsigned argIndex, NamedAttribute attribute);
```
#### `useFoldAPI`
There are currently two possible values that are allowed to be assigned to this
field:
* `kEmitFoldAdaptorFolder` generates a `fold` method making use of the op's
`FoldAdaptor` to allow access of operands via convenient getter.
Generated code example:
```cpp
OpFoldResult fold(FoldAdaptor adaptor);
// or
LogicalResult fold(FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult>& results);
```
* `kEmitRawAttributesFolder` generates the deprecated legacy `fold`
method, containing `ArrayRef<Attribute>` in the parameter list instead of
the op's `FoldAdaptor`. This API is scheduled for removal and should not be
used by new dialects.
Generated code example:
```cpp
OpFoldResult fold(ArrayRef<Attribute> operands);
// or
LogicalResult fold(ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult>& results);
```
### Operation Interface Fallback
Some dialects have an open ecosystem and don't register all of the possible operations. In such

View File

@ -458,16 +458,16 @@ method.
```c++
/// Fold constants.
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { return value(); }
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return value(); }
/// Fold struct constants.
OpFoldResult StructConstantOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult StructConstantOp::fold(FoldAdaptor adaptor) {
return value();
}
/// Fold simple struct access operations that access into a constant.
OpFoldResult StructAccessOp::fold(ArrayRef<Attribute> operands) {
auto structAttr = operands.front().dyn_cast_or_null<mlir::ArrayAttr>();
OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) {
auto structAttr = adaptor.getInput().dyn_cast_or_null<mlir::ArrayAttr>();
if (!structAttr)
return nullptr;

View File

@ -33,6 +33,8 @@ def Toy_Dialect : Dialect {
// We set this bit to generate the declarations for the dialect's type parsing
// and printing hooks.
let useDefaultTypePrinterParser = 1;
let useFoldAPI = kEmitFoldAdaptorFolder;
}
// Base class for toy dialect operations. This operation inherits from the base

View File

@ -24,18 +24,14 @@ namespace {
} // namespace
/// Fold constants.
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
return getValue();
}
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
/// Fold struct constants.
OpFoldResult StructConstantOp::fold(ArrayRef<Attribute> operands) {
return getValue();
}
OpFoldResult StructConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
/// Fold simple struct access operations that access into a constant.
OpFoldResult StructAccessOp::fold(ArrayRef<Attribute> operands) {
auto structAttr = operands.front().dyn_cast_or_null<mlir::ArrayAttr>();
OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) {
auto structAttr = adaptor.getInput().dyn_cast_or_null<mlir::ArrayAttr>();
if (!structAttr)
return nullptr;

View File

@ -17,6 +17,14 @@
// Dialect definitions
//===----------------------------------------------------------------------===//
// Generate 'fold' method with 'ArrayRef<Attribute>' parameter.
// New code should prefer using 'kEmitFoldAdaptorFolder' and
// consider 'kEmitRawAttributesFolder' deprecated and to be
// removed in the future.
defvar kEmitRawAttributesFolder = 0;
// Generate 'fold' method with 'FoldAdaptor' parameter.
defvar kEmitFoldAdaptorFolder = 1;
class Dialect {
// The name of the dialect.
string name = ?;
@ -85,6 +93,9 @@ class Dialect {
// If this dialect can be extended at runtime with new operations or types.
bit isExtensible = 0;
// Fold API to use for operations in this dialect.
int useFoldAPI = kEmitRawAttributesFolder;
}
#endif // DIALECTBASE_TD

View File

@ -1686,18 +1686,35 @@ public:
private:
/// Trait to check if T provides a 'fold' method for a single result op.
template <typename T, typename... Args>
using has_single_result_fold =
using has_single_result_fold_t =
decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
template <typename T>
using detect_has_single_result_fold =
llvm::is_detected<has_single_result_fold, T>;
constexpr static bool has_single_result_fold_v =
llvm::is_detected<has_single_result_fold_t, T>::value;
/// Trait to check if T provides a general 'fold' method.
template <typename T, typename... Args>
using has_fold = decltype(std::declval<T>().fold(
using has_fold_t = decltype(std::declval<T>().fold(
std::declval<ArrayRef<Attribute>>(),
std::declval<SmallVectorImpl<OpFoldResult> &>()));
template <typename T>
using detect_has_fold = llvm::is_detected<has_fold, T>;
constexpr static bool has_fold_v = llvm::is_detected<has_fold_t, T>::value;
/// Trait to check if T provides a 'fold' method with a FoldAdaptor for a
/// single result op.
template <typename T, typename... Args>
using has_fold_adaptor_single_result_fold_t =
decltype(std::declval<T>().fold(std::declval<typename T::FoldAdaptor>()));
template <class T>
constexpr static bool has_fold_adaptor_single_result_v =
llvm::is_detected<has_fold_adaptor_single_result_fold_t, T>::value;
/// Trait to check if T provides a general 'fold' method with a FoldAdaptor.
template <typename T, typename... Args>
using has_fold_adaptor_fold_t = decltype(std::declval<T>().fold(
std::declval<typename T::FoldAdaptor>(),
std::declval<SmallVectorImpl<OpFoldResult> &>()));
template <class T>
constexpr static bool has_fold_adaptor_v =
llvm::is_detected<has_fold_adaptor_fold_t, T>::value;
/// Trait to check if T provides a 'print' method.
template <typename T, typename... Args>
using has_print =
@ -1746,13 +1763,14 @@ private:
// If the operation is single result and defines a `fold` method.
if constexpr (llvm::is_one_of<OpTrait::OneResult<ConcreteType>,
Traits<ConcreteType>...>::value &&
detect_has_single_result_fold<ConcreteType>::value)
(has_single_result_fold_v<ConcreteType> ||
has_fold_adaptor_single_result_v<ConcreteType>))
return [](Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
return foldSingleResultHook<ConcreteType>(op, operands, results);
};
// The operation is not single result and defines a `fold` method.
if constexpr (detect_has_fold<ConcreteType>::value)
if constexpr (has_fold_v<ConcreteType> || has_fold_adaptor_v<ConcreteType>)
return [](Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
return foldHook<ConcreteType>(op, operands, results);
@ -1771,7 +1789,12 @@ private:
static LogicalResult
foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
OpFoldResult result = cast<ConcreteOpT>(op).fold(operands);
OpFoldResult result;
if constexpr (has_fold_adaptor_single_result_v<ConcreteOpT>)
result = cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
operands, op->getAttrDictionary(), op->getRegions()));
else
result = cast<ConcreteOpT>(op).fold(operands);
// If the fold failed or was in-place, try to fold the traits of the
// operation.
@ -1788,7 +1811,15 @@ private:
template <typename ConcreteOpT>
static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
LogicalResult result = cast<ConcreteOpT>(op).fold(operands, results);
auto result = LogicalResult::failure();
if constexpr (has_fold_adaptor_v<ConcreteOpT>) {
result = cast<ConcreteOpT>(op).fold(
typename ConcreteOpT::FoldAdaptor(operands, op->getAttrDictionary(),
op->getRegions()),
results);
} else {
result = cast<ConcreteOpT>(op).fold(operands, results);
}
// If the fold failed or was in-place, try to fold the traits of the
// operation.

View File

@ -86,6 +86,15 @@ public:
/// operations or types.
bool isExtensible() const;
enum class FolderAPI {
RawAttributes = 0, /// fold method with ArrayRef<Attribute>.
FolderAdaptor = 1, /// fold method with the operation's FoldAdaptor.
};
/// Returns the folder API that should be emitted for operations in this
/// dialect.
FolderAPI getFolderAPI() const;
// Returns whether two dialects are equal by checking the equality of the
// underlying record.
bool operator==(const Dialect &other) const;

View File

@ -314,6 +314,8 @@ public:
/// Returns the remove name for the accessor of `name`.
std::string getRemoverName(StringRef name) const;
bool hasFolder() const;
private:
/// Populates the vectors containing operands, attributes, results and traits.
void populateOpStructure();

View File

@ -102,6 +102,16 @@ bool Dialect::isExtensible() const {
return def->getValueAsBit("isExtensible");
}
Dialect::FolderAPI Dialect::getFolderAPI() const {
int64_t value = def->getValueAsInt("useFoldAPI");
if (value < static_cast<int64_t>(FolderAPI::RawAttributes) ||
value > static_cast<int64_t>(FolderAPI::FolderAdaptor))
llvm::PrintFatalError(def->getLoc(),
"Invalid value for dialect field `useFoldAPI`");
return static_cast<FolderAPI>(value);
}
bool Dialect::operator==(const Dialect &other) const {
return def == other.def;
}

View File

@ -745,3 +745,5 @@ std::string Operator::getSetterName(StringRef name) const {
std::string Operator::getRemoverName(StringRef name) const {
return "remove" + convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true);
}
bool Operator::hasFolder() const { return def.getValueAsBit("hasFolder"); }

View File

@ -0,0 +1,16 @@
// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
func.func @test() -> i32 {
%c5 = "test.constant"() {value = 5 : i32} : () -> i32
%c1 = "test.constant"() {value = 1 : i32} : () -> i32
%c2 = "test.constant"() {value = 2 : i32} : () -> i32
%c3 = "test.constant"() {value = 3 : i32} : () -> i32
%res = test.fold_with_fold_adaptor %c5, [ %c1, %c2], { (%c3), (%c3) } {
%c0 = "test.constant"() {value = 0 : i32} : () -> i32
}
return %res : i32
}
// CHECK-LABEL: func.func @test
// CHECK-NEXT: %[[C:.*]] = "test.constant"() {value = 33 : i32}
// CHECK-NEXT: return %[[C]]

View File

@ -33,6 +33,8 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include <numeric>
// Include this before the using namespace lines below to
// test that we don't have namespace dependencies.
#include "TestOpsDialect.cpp.inc"
@ -1126,6 +1128,25 @@ OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
return getOperand();
}
OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
int64_t sum = 0;
if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()))
sum += value.getValue().getSExtValue();
for (Attribute attr : adaptor.getVariadic())
if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
sum += 2 * value.getValue().getSExtValue();
for (ArrayRef<Attribute> attrs : adaptor.getVarOfVar())
for (Attribute attr : attrs)
if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
sum += 3 * value.getValue().getSExtValue();
sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end());
return IntegerAttr::get(getType(), sum);
}
LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
MLIRContext *, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,

View File

@ -1297,6 +1297,31 @@ def TestOpInPlaceFoldSuccess : TEST_Op<"op_in_place_fold_success"> {
}];
}
def TestOpFoldWithFoldAdaptor
: TEST_Op<"fold_with_fold_adaptor",
[AttrSizedOperandSegments, NoTerminator]> {
let arguments = (ins
I32:$op,
DenseI32ArrayAttr:$attr,
Variadic<I32>:$variadic,
VariadicOfVariadic<I32, "attr">:$var_of_var
);
let results = (outs I32:$res);
let regions = (region AnyRegion:$body);
let assemblyFormat = [{
$op `,` `[` $variadic `]` `,` `{` $var_of_var `}` $body attr-dict-with-keyword
}];
let hasFolder = 0;
let extraClassDeclaration = [{
::mlir::OpFoldResult fold(FoldAdaptor adaptor);
}];
}
// An op that always fold itself.
def TestPassthroughFold : TEST_Op<"passthrough_fold"> {
let arguments = (ins AnyType:$op);

View File

@ -0,0 +1,15 @@
// RUN: not mlir-tblgen -gen-op-decls -I %S/../../include %s 2>&1 | FileCheck %s
include "mlir/IR/OpBase.td"
def Test_Dialect : Dialect {
let name = "test";
let cppNamespace = "NS";
let useFoldAPI = 3;
}
def InvalidValue_Op : Op<Test_Dialect, "invalid_op"> {
let hasFolder = 1;
}
// CHECK: Invalid value for dialect field `useFoldAPI`

View File

@ -317,6 +317,29 @@ def NS_LOp : NS_Op<"op_with_same_operands_and_result_types_unwrapped_attr", [Sam
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
def TestWithNewFold_Dialect : Dialect {
let name = "test";
let cppNamespace = "::mlir::testWithFold";
let useFoldAPI = kEmitFoldAdaptorFolder;
}
def NS_MOp : Op<TestWithNewFold_Dialect, "op_with_single_result_and_fold_adaptor_fold", []> {
let results = (outs AnyType:$res);
let hasFolder = 1;
}
// CHECK-LABEL: class MOp :
// CHECK: ::mlir::OpFoldResult fold(FoldAdaptor adaptor);
def NS_NOp : Op<TestWithNewFold_Dialect, "op_with_multiple_results_and_fold_adaptor_fold", []> {
let results = (outs AnyType:$res1, AnyType:$res2);
let hasFolder = 1;
}
// CHECK-LABEL: class NOp :
// CHECK: ::mlir::LogicalResult fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult> &results);
// Test that type defs have the proper namespaces when used as a constraint.
// ---

View File

@ -2326,25 +2326,29 @@ void OpEmitter::genCanonicalizerDecls() {
}
void OpEmitter::genFolderDecls() {
if (!op.hasFolder())
return;
Dialect::FolderAPI folderApi = op.getDialect().getFolderAPI();
SmallVector<MethodParameter> paramList;
if (folderApi == Dialect::FolderAPI::RawAttributes)
paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands");
else
paramList.emplace_back("FoldAdaptor", "adaptor");
StringRef retType;
bool hasSingleResult =
op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0;
if (def.getValueAsBit("hasFolder")) {
if (hasSingleResult) {
auto *m = opClass.declareMethod(
"::mlir::OpFoldResult", "fold",
MethodParameter("::llvm::ArrayRef<::mlir::Attribute>", "operands"));
ERROR_IF_PRUNED(m, "operands", op);
} else {
SmallVector<MethodParameter> paramList;
paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands");
paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
"results");
auto *m = opClass.declareMethod("::mlir::LogicalResult", "fold",
std::move(paramList));
ERROR_IF_PRUNED(m, "fold", op);
}
if (hasSingleResult) {
retType = "::mlir::OpFoldResult";
} else {
paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
"results");
retType = "::mlir::LogicalResult";
}
auto *m = opClass.declareMethod(retType, "fold", std::move(paramList));
ERROR_IF_PRUNED(m, "fold", op);
}
void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {