[mlir] Add a conversion pass between PDL and the PDL Interpreter Dialect

The conversion between PDL and the interpreter is split into several different parts.
** The Matcher:

The matching section of all incoming pdl.pattern operations is converted into a predicate tree and merged. Each pattern is first converted into an ordered list of predicates starting from the root operation. A predicate is composed of three distinct parts:
* Position
  - A position refers to a specific location on the input DAG, i.e. an
    existing MLIR entity being matched. These can be attributes, operands,
    operations, results, and types. Each position also defines a relation to
    its parent. For example, the operand `[0] -> 1` has a parent operation
    position `[0]` (the root).
* Question
  - A question refers to a query on a specific positional value. For
  example, an operation name question checks the name of an operation
  position.
* Answer
  - An answer is the expected result of a question. For example, when
  matching an operation with the name "foo.op". The question would be an
  operation name question, with an expected answer of "foo.op".

After the predicate lists have been created and ordered(based on occurrence of common predicates and other factors), they are formed into a tree of nodes that represent the branching flow of a pattern match. This structure allows for efficient construction and merging of the input patterns. There are currently only 4 simple nodes in the tree:
* ExitNode: Represents the termination of a match
* SuccessNode: Represents a successful match of a specific pattern
* BoolNode/SwitchNode: Branch to a specific child node based on the expected answer to a predicate question.

Once the matcher tree has been generated, this tree is walked to generate the corresponding interpreter operations.

 ** The Rewriter:
The rewriter portion of a pattern is generated in a very straightforward manor, similarly to lowerings in other dialects. Each PDL operation that may exist within a rewrite has a mapping into the interpreter dialect. The code for the rewriter is generated within a FuncOp, that is invoked by the interpreter on a successful pattern match. Referenced values defined in the matcher become inputs the generated rewriter function.

An example lowering is shown below:

```mlir
// The following high level PDL pattern:
pdl.pattern : benefit(1) {
  %resultType = pdl.type
  %inputOperand = pdl.input
  %root, %results = pdl.operation "foo.op"(%inputOperand) -> %resultType
  pdl.rewrite %root {
    pdl.replace %root with (%inputOperand)
  }
}

// is lowered to the following:
module {
  // The matcher function takes the root operation as an input.
  func @matcher(%arg0: !pdl.operation) {
    pdl_interp.check_operation_name of %arg0 is "foo.op" -> ^bb2, ^bb1
  ^bb1:
    pdl_interp.return
  ^bb2:
    pdl_interp.check_operand_count of %arg0 is 1 -> ^bb3, ^bb1
  ^bb3:
    pdl_interp.check_result_count of %arg0 is 1 -> ^bb4, ^bb1
  ^bb4:
    %0 = pdl_interp.get_operand 0 of %arg0
    pdl_interp.is_not_null %0 : !pdl.value -> ^bb5, ^bb1
  ^bb5:
    %1 = pdl_interp.get_result 0 of %arg0
    pdl_interp.is_not_null %1 : !pdl.value -> ^bb6, ^bb1
  ^bb6:
    // This operation corresponds to a successful pattern match.
    pdl_interp.record_match @rewriters::@rewriter(%0, %arg0 : !pdl.value, !pdl.operation) : benefit(1), loc([%arg0]), root("foo.op") -> ^bb1
  }
  module @rewriters {
    // The inputs to the rewriter from the matcher are passed as arguments.
    func @rewriter(%arg0: !pdl.value, %arg1: !pdl.operation) {
      pdl_interp.replace %arg1 with(%arg0)
      pdl_interp.return
    }
  }
}
```

Differential Revision: https://reviews.llvm.org/D84580
This commit is contained in:
River Riddle 2020-10-26 17:23:16 -07:00
parent aab50af8c1
commit 8a1ca2cd34
15 changed files with 2355 additions and 2 deletions

View File

@ -0,0 +1,28 @@
//===- PDLToPDLInterp.h - PDL to PDL Interpreter conversion -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides a pass for PDL to PDL Interpreter dialect conversion.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H
#define MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H
#include <memory>
namespace mlir {
class ModuleOp;
template <typename OpT>
class OperationPass;
/// Creates and returns a pass to convert PDL ops to PDL interpreter ops.
std::unique_ptr<OperationPass<ModuleOp>> createPDLToPDLInterpPass();
} // namespace mlir
#endif // MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H

View File

@ -21,6 +21,7 @@
#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"

View File

@ -220,6 +220,16 @@ def ConvertOpenMPToLLVM : Pass<"convert-openmp-to-llvm", "ModuleOp"> {
let dependentDialects = ["LLVM::LLVMDialect"];
}
//===----------------------------------------------------------------------===//
// PDLToPDLInterp
//===----------------------------------------------------------------------===//
def ConvertPDLToPDLInterp : Pass<"convert-pdl-to-pdl-interp", "ModuleOp"> {
let summary = "Convert PDL ops to PDL interpreter ops";
let constructor = "mlir::createPDLToPDLInterpPass()";
let dependentDialects = ["pdl_interp::PDLInterpDialect"];
}
//===----------------------------------------------------------------------===//
// SCFToStandard
//===----------------------------------------------------------------------===//

View File

@ -36,6 +36,16 @@ def PDLInterp_Dialect : Dialect {
let name = "pdl_interp";
let cppNamespace = "::mlir::pdl_interp";
let dependentDialects = ["pdl::PDLDialect"];
let extraClassDeclaration = [{
/// Returns the name of the function containing the matcher code. This
/// function is called by the interpreter when matching an operation.
static StringRef getMatcherFunctionName() { return "matcher"; }
/// Returns the name of the module containing the rewrite functions. These
/// functions are invoked by distinct patterns within the matcher function
/// to rewrite the IR after a successful match.
static StringRef getRewriterModuleName() { return "rewriters"; }
}];
}
//===----------------------------------------------------------------------===//

View File

@ -157,8 +157,7 @@ public:
}
/// Utility override when the storage type represents the type id.
template <typename Storage>
void registerSingletonStorageType(
function_ref<void(Storage *)> initFn = llvm::None) {
void registerSingletonStorageType(function_ref<void(Storage *)> initFn = {}) {
registerSingletonStorageType<Storage>(TypeID::get<Storage>(), initFn);
}

View File

@ -10,6 +10,7 @@ add_subdirectory(LinalgToLLVM)
add_subdirectory(LinalgToSPIRV)
add_subdirectory(LinalgToStandard)
add_subdirectory(OpenMPToLLVM)
add_subdirectory(PDLToPDLInterp)
add_subdirectory(SCFToGPU)
add_subdirectory(SCFToSPIRV)
add_subdirectory(SCFToStandard)

View File

@ -0,0 +1,18 @@
add_mlir_conversion_library(MLIRPDLToPDLInterp
PDLToPDLInterp.cpp
Predicate.cpp
PredicateTree.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/PDLToPDLInterp
DEPENDS
MLIRConversionPassIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRInferTypeOpInterface
MLIRPDL
MLIRPDLInterp
MLIRPass
)

View File

@ -0,0 +1,694 @@
//===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
#include "../PassDetail.h"
#include "PredicateTree.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::pdl_to_pdl_interp;
//===----------------------------------------------------------------------===//
// PatternLowering
//===----------------------------------------------------------------------===//
namespace {
/// This class generators operations within the PDL Interpreter dialect from a
/// given module containing PDL pattern operations.
struct PatternLowering {
public:
PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule);
/// Generate code for matching and rewriting based on the pattern operations
/// within the module.
void lower(ModuleOp module);
private:
using ValueMap = llvm::ScopedHashTable<Position *, Value>;
using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;
/// Generate interpreter operations for the tree rooted at the given matcher
/// node.
Block *generateMatcher(MatcherNode &node);
/// Get or create an access to the provided positional value within the
/// current block.
Value getValueAt(Block *cur, Position *pos);
/// Create an interpreter predicate operation, branching to the provided true
/// and false destinations.
void generatePredicate(Block *currentBlock, Qualifier *question,
Qualifier *answer, Value val, Block *trueDest,
Block *falseDest);
/// Create an interpreter switch predicate operation, with a provided default
/// and several case destinations.
void generateSwitch(Block *currentBlock, Qualifier *question, Value val,
Block *defaultDest,
ArrayRef<std::pair<Qualifier *, Block *>> dests);
/// Create the interpreter operations to record a successful pattern match.
void generateRecordMatch(Block *currentBlock, Block *nextBlock,
pdl::PatternOp pattern);
/// Generate a rewriter function for the given pattern operation, and returns
/// a reference to that function.
SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
SmallVectorImpl<Position *> &usedMatchValues);
/// Generate the rewriter code for the given operation.
void generateRewriter(pdl::AttributeOp attrOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
void generateRewriter(pdl::EraseOp eraseOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
void generateRewriter(pdl::OperationOp operationOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
void generateRewriter(pdl::CreateNativeOp createNativeOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
void generateRewriter(pdl::ReplaceOp replaceOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
void generateRewriter(pdl::TypeOp typeOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
/// Generate the values used for resolving the result types of an operation
/// created within a dag rewriter region.
void generateOperationResultTypeRewriter(
pdl::OperationOp op, SmallVectorImpl<Value> &types,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
/// A builder to use when generating interpreter operations.
OpBuilder builder;
/// The matcher function used for all match related logic within PDL patterns.
FuncOp matcherFunc;
/// The rewriter module containing the all rewrite related logic within PDL
/// patterns.
ModuleOp rewriterModule;
/// The symbol table of the rewriter module used for insertion.
SymbolTable rewriterSymbolTable;
/// A scoped map connecting a position with the corresponding interpreter
/// value.
ValueMap values;
/// A stack of blocks used as the failure destination for matcher nodes that
/// don't have an explicit failure path.
SmallVector<Block *, 8> failureBlockStack;
/// A mapping between values defined in a pattern match, and the corresponding
/// positional value.
DenseMap<Value, Position *> valueToPosition;
/// The set of operation values whose whose location will be used for newly
/// generated operations.
llvm::SetVector<Value> locOps;
};
} // end anonymous namespace
PatternLowering::PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule)
: builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {}
void PatternLowering::lower(ModuleOp module) {
PredicateUniquer predicateUniquer;
PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
// Define top-level scope for the arguments to the matcher function.
ValueMapScope topLevelValueScope(values);
// Insert the root operation, i.e. argument to the matcher, at the root
// position.
Block *matcherEntryBlock = matcherFunc.addEntryBlock();
values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
// Generate a root matcher node from the provided PDL module.
std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
module, predicateBuilder, valueToPosition);
Block *firstMatcherBlock = generateMatcher(*root);
// After generation, merged the first matched block into the entry.
matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
firstMatcherBlock->getOperations());
firstMatcherBlock->erase();
}
Block *PatternLowering::generateMatcher(MatcherNode &node) {
// Push a new scope for the values used by this matcher.
Block *block = matcherFunc.addBlock();
ValueMapScope scope(values);
// If this is the return node, simply insert the corresponding interpreter
// finalize.
if (isa<ExitNode>(node)) {
builder.setInsertionPointToEnd(block);
builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
return block;
}
// If this node contains a position, get the corresponding value for this
// block.
Position *position = node.getPosition();
Value val = position ? getValueAt(block, position) : Value();
// Get the next block in the match sequence.
std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
Block *nextBlock;
if (failureNode) {
nextBlock = generateMatcher(*failureNode);
failureBlockStack.push_back(nextBlock);
} else {
assert(!failureBlockStack.empty() && "expected valid failure block");
nextBlock = failureBlockStack.back();
}
// If this value corresponds to an operation, record that we are going to use
// its location as part of a fused location.
bool isOperationValue = val && val.getType().isa<pdl::OperationType>();
if (isOperationValue)
locOps.insert(val);
// Generate code for a boolean predicate node.
if (auto *boolNode = dyn_cast<BoolNode>(&node)) {
auto *child = generateMatcher(*boolNode->getSuccessNode());
generatePredicate(block, node.getQuestion(), boolNode->getAnswer(), val,
child, nextBlock);
// Generate code for a switch node.
} else if (auto *switchNode = dyn_cast<SwitchNode>(&node)) {
// Collect the next blocks for all of the children and generate a switch.
llvm::MapVector<Qualifier *, Block *> children;
for (auto &it : switchNode->getChildren())
children.insert({it.first, generateMatcher(*it.second)});
generateSwitch(block, node.getQuestion(), val, nextBlock,
children.takeVector());
// Generate code for a success node.
} else if (auto *successNode = dyn_cast<SuccessNode>(&node)) {
generateRecordMatch(block, nextBlock, successNode->getPattern());
}
if (failureNode)
failureBlockStack.pop_back();
if (isOperationValue)
locOps.remove(val);
return block;
}
Value PatternLowering::getValueAt(Block *cur, Position *pos) {
if (Value val = values.lookup(pos))
return val;
// Get the value for the parent position.
Value parentVal = getValueAt(cur, pos->getParent());
// TODO: Use a location from the position.
Location loc = parentVal.getLoc();
builder.setInsertionPointToEnd(cur);
Value value;
switch (pos->getKind()) {
case Predicates::OperationPos:
value = builder.create<pdl_interp::GetDefiningOpOp>(
loc, builder.getType<pdl::OperationType>(), parentVal);
break;
case Predicates::OperandPos: {
auto *operandPos = cast<OperandPosition>(pos);
value = builder.create<pdl_interp::GetOperandOp>(
loc, builder.getType<pdl::ValueType>(), parentVal,
operandPos->getOperandNumber());
break;
}
case Predicates::AttributePos: {
auto *attrPos = cast<AttributePosition>(pos);
value = builder.create<pdl_interp::GetAttributeOp>(
loc, builder.getType<pdl::AttributeType>(), parentVal,
attrPos->getName().strref());
break;
}
case Predicates::TypePos: {
if (parentVal.getType().isa<pdl::ValueType>())
value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
else
value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
break;
}
case Predicates::ResultPos: {
auto *resPos = cast<ResultPosition>(pos);
value = builder.create<pdl_interp::GetResultOp>(
loc, builder.getType<pdl::ValueType>(), parentVal,
resPos->getResultNumber());
break;
}
default:
llvm_unreachable("Generating unknown Position getter");
break;
}
values.insert(pos, value);
return value;
}
void PatternLowering::generatePredicate(Block *currentBlock,
Qualifier *question, Qualifier *answer,
Value val, Block *trueDest,
Block *falseDest) {
builder.setInsertionPointToEnd(currentBlock);
Location loc = val.getLoc();
switch (question->getKind()) {
case Predicates::IsNotNullQuestion:
builder.create<pdl_interp::IsNotNullOp>(loc, val, trueDest, falseDest);
break;
case Predicates::OperationNameQuestion: {
auto *opNameAnswer = cast<OperationNameAnswer>(answer);
builder.create<pdl_interp::CheckOperationNameOp>(
loc, val, opNameAnswer->getValue().getStringRef(), trueDest, falseDest);
break;
}
case Predicates::TypeQuestion: {
auto *ans = cast<TypeAnswer>(answer);
builder.create<pdl_interp::CheckTypeOp>(
loc, val, TypeAttr::get(ans->getValue()), trueDest, falseDest);
break;
}
case Predicates::AttributeQuestion: {
auto *ans = cast<AttributeAnswer>(answer);
builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
trueDest, falseDest);
break;
}
case Predicates::OperandCountQuestion: {
auto *unsignedAnswer = cast<UnsignedAnswer>(answer);
builder.create<pdl_interp::CheckOperandCountOp>(
loc, val, unsignedAnswer->getValue(), trueDest, falseDest);
break;
}
case Predicates::ResultCountQuestion: {
auto *unsignedAnswer = cast<UnsignedAnswer>(answer);
builder.create<pdl_interp::CheckResultCountOp>(
loc, val, unsignedAnswer->getValue(), trueDest, falseDest);
break;
}
case Predicates::EqualToQuestion: {
auto *equalToQuestion = cast<EqualToQuestion>(question);
builder.create<pdl_interp::AreEqualOp>(
loc, val, getValueAt(currentBlock, equalToQuestion->getValue()),
trueDest, falseDest);
break;
}
case Predicates::ConstraintQuestion: {
auto *cstQuestion = cast<ConstraintQuestion>(question);
SmallVector<Value, 2> args;
for (Position *position : std::get<1>(cstQuestion->getValue()))
args.push_back(getValueAt(currentBlock, position));
builder.create<pdl_interp::ApplyConstraintOp>(
loc, std::get<0>(cstQuestion->getValue()), args,
std::get<2>(cstQuestion->getValue()).cast<ArrayAttr>(), trueDest,
falseDest);
break;
}
default:
llvm_unreachable("Generating unknown Predicate operation");
}
}
template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
ArrayRef<std::pair<Qualifier *, Block *>> dests) {
std::vector<ValT> values;
std::vector<Block *> blocks;
values.reserve(dests.size());
blocks.reserve(dests.size());
for (const auto &it : dests) {
blocks.push_back(it.second);
values.push_back(cast<PredT>(it.first)->getValue());
}
builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
}
void PatternLowering::generateSwitch(
Block *currentBlock, Qualifier *question, Value val, Block *defaultDest,
ArrayRef<std::pair<Qualifier *, Block *>> dests) {
builder.setInsertionPointToEnd(currentBlock);
switch (question->getKind()) {
case Predicates::OperandCountQuestion:
return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
int32_t>(val, defaultDest, builder, dests);
case Predicates::ResultCountQuestion:
return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
int32_t>(val, defaultDest, builder, dests);
case Predicates::OperationNameQuestion:
return createSwitchOp<pdl_interp::SwitchOperationNameOp,
OperationNameAnswer>(val, defaultDest, builder,
dests);
case Predicates::TypeQuestion:
return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
val, defaultDest, builder, dests);
case Predicates::AttributeQuestion:
return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
val, defaultDest, builder, dests);
default:
llvm_unreachable("Generating unknown switch predicate.");
}
}
void PatternLowering::generateRecordMatch(Block *currentBlock, Block *nextBlock,
pdl::PatternOp pattern) {
// Generate a rewriter for the pattern this success node represents, and track
// any values used from the match region.
SmallVector<Position *, 8> usedMatchValues;
SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
// Process any values used in the rewrite that are defined in the match.
std::vector<Value> mappedMatchValues;
mappedMatchValues.reserve(usedMatchValues.size());
for (Position *position : usedMatchValues)
mappedMatchValues.push_back(getValueAt(currentBlock, position));
// Collect the set of operations generated by the rewriter.
SmallVector<StringRef, 4> generatedOps;
for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>())
generatedOps.push_back(*op.name());
ArrayAttr generatedOpsAttr;
if (!generatedOps.empty())
generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
// Grab the root kind if present.
StringAttr rootKindAttr;
if (Optional<StringRef> rootKind = pattern.getRootKind())
rootKindAttr = builder.getStringAttr(*rootKind);
builder.setInsertionPointToEnd(currentBlock);
builder.create<pdl_interp::RecordMatchOp>(
pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(),
nextBlock);
}
SymbolRefAttr PatternLowering::generateRewriter(
pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
FuncOp rewriterFunc =
FuncOp::create(pattern.getLoc(), "pdl_generated_rewriter",
builder.getFunctionType(llvm::None, llvm::None));
rewriterSymbolTable.insert(rewriterFunc);
// Generate the rewriter function body.
builder.setInsertionPointToEnd(rewriterFunc.addEntryBlock());
// Map an input operand of the pattern to a generated interpreter value.
DenseMap<Value, Value> rewriteValues;
auto mapRewriteValue = [&](Value oldValue) {
Value &newValue = rewriteValues[oldValue];
if (newValue)
return newValue;
// Prefer materializing constants directly when possible.
Operation *oldOp = oldValue.getDefiningOp();
if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
if (Attribute value = attrOp.valueAttr()) {
return newValue = builder.create<pdl_interp::CreateAttributeOp>(
attrOp.getLoc(), value);
}
} else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
if (TypeAttr type = typeOp.typeAttr()) {
return newValue = builder.create<pdl_interp::CreateTypeOp>(
typeOp.getLoc(), type);
}
}
// Otherwise, add this as an input to the rewriter.
Position *inputPos = valueToPosition.lookup(oldValue);
assert(inputPos && "expected value to be a pattern input");
usedMatchValues.push_back(inputPos);
return newValue = rewriterFunc.front().addArgument(oldValue.getType());
};
// If this is a custom rewriter, simply dispatch to the registered rewrite
// method.
pdl::RewriteOp rewriter = pattern.getRewriter();
if (StringAttr rewriteName = rewriter.nameAttr()) {
Value root = mapRewriteValue(rewriter.root());
SmallVector<Value, 4> args = llvm::to_vector<4>(
llvm::map_range(rewriter.externalArgs(), mapRewriteValue));
builder.create<pdl_interp::ApplyRewriteOp>(
rewriter.getLoc(), rewriteName, root, args,
rewriter.externalConstParamsAttr());
} else {
// Otherwise this is a dag rewriter defined using PDL operations.
for (Operation &rewriteOp : *rewriter.getBody()) {
llvm::TypeSwitch<Operation *>(&rewriteOp)
.Case<pdl::AttributeOp, pdl::CreateNativeOp, pdl::EraseOp,
pdl::OperationOp, pdl::ReplaceOp, pdl::TypeOp>([&](auto op) {
this->generateRewriter(op, rewriteValues, mapRewriteValue);
});
}
}
// Update the signature of the rewrite function.
rewriterFunc.setType(builder.getFunctionType(
llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
/*results=*/llvm::None));
builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
return builder.getSymbolRefAttr(
pdl_interp::PDLInterpDialect::getRewriterModuleName(),
builder.getSymbolRefAttr(rewriterFunc));
}
void PatternLowering::generateRewriter(
pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
attrOp.getLoc(), attrOp.valueAttr());
rewriteValues[attrOp] = newAttr;
}
void PatternLowering::generateRewriter(
pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
mapRewriteValue(eraseOp.operation()));
}
void PatternLowering::generateRewriter(
pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
SmallVector<Value, 4> operands;
for (Value operand : operationOp.operands())
operands.push_back(mapRewriteValue(operand));
SmallVector<Value, 4> attributes;
for (Value attr : operationOp.attributes())
attributes.push_back(mapRewriteValue(attr));
SmallVector<Value, 2> types;
generateOperationResultTypeRewriter(operationOp, types, rewriteValues,
mapRewriteValue);
// Create the new operation.
Location loc = operationOp.getLoc();
Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
loc, *operationOp.name(), types, operands, attributes,
operationOp.attributeNames());
rewriteValues[operationOp.op()] = createdOp;
// Make all of the new operation results available.
OperandRange resultTypes = operationOp.types();
for (auto it : llvm::enumerate(operationOp.results())) {
Value getResultVal = builder.create<pdl_interp::GetResultOp>(
loc, builder.getType<pdl::ValueType>(), createdOp, it.index());
rewriteValues[it.value()] = getResultVal;
// If any of the types have not been resolved, make those available as well.
Value &type = rewriteValues[resultTypes[it.index()]];
if (!type)
type = builder.create<pdl_interp::GetValueTypeOp>(loc, getResultVal);
}
}
void PatternLowering::generateRewriter(
pdl::CreateNativeOp createNativeOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
SmallVector<Value, 2> arguments;
for (Value argument : createNativeOp.args())
arguments.push_back(mapRewriteValue(argument));
Value result = builder.create<pdl_interp::CreateNativeOp>(
createNativeOp.getLoc(), createNativeOp.result().getType(),
createNativeOp.nameAttr(), arguments, createNativeOp.constParamsAttr());
rewriteValues[createNativeOp] = result;
}
void PatternLowering::generateRewriter(
pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
// If the replacement was another operation, get its results. `pdl` allows
// for using an operation for simplicitly, but the interpreter isn't as
// user facing.
ValueRange origOperands;
if (Value replOp = replaceOp.replOperation())
origOperands = cast<pdl::OperationOp>(replOp.getDefiningOp()).results();
else
origOperands = replaceOp.replValues();
// If there are no replacement values, just create an erase instead.
if (origOperands.empty()) {
builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(),
mapRewriteValue(replaceOp.operation()));
return;
}
SmallVector<Value, 4> replOperands;
for (Value operand : origOperands)
replOperands.push_back(mapRewriteValue(operand));
builder.create<pdl_interp::ReplaceOp>(
replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands);
}
void PatternLowering::generateRewriter(
pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
// If the type isn't constant, the users (e.g. OperationOp) will resolve this
// type.
if (TypeAttr typeAttr = typeOp.typeAttr()) {
Value newType =
builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
rewriteValues[typeOp] = newType;
}
}
void PatternLowering::generateOperationResultTypeRewriter(
pdl::OperationOp op, SmallVectorImpl<Value> &types,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
// Functor that returns if the given use can be used to infer a type.
Block *rewriterBlock = op.getOperation()->getBlock();
auto getReplacedOperationFrom = [&](OpOperand &use) -> Operation * {
// Check that the use corresponds to a ReplaceOp and that it is the
// replacement value, not the operation being replaced.
pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
if (!replOpUser || use.getOperandNumber() == 0)
return nullptr;
// Make sure the replaced operation was defined before this one.
Operation *replacedOp = replOpUser.operation().getDefiningOp();
if (replacedOp->getBlock() != rewriterBlock ||
replacedOp->isBeforeInBlock(op))
return replacedOp;
return nullptr;
};
// If non-None/non-Null, this is an operation that is replaced by `op`.
// If Null, there is no full replacement operation for `op`.
// If None, a replacement operation hasn't been searched for.
Optional<Operation *> fullReplacedOperation;
bool hasTypeInference = op.hasTypeInference();
auto resultTypeValues = op.types();
types.reserve(resultTypeValues.size());
for (auto it : llvm::enumerate(op.results())) {
Value result = it.value(), resultType = resultTypeValues[it.index()];
// Check for an already translated value.
if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
types.push_back(existingRewriteValue);
continue;
}
// Check for an input from the matcher.
if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
types.push_back(mapRewriteValue(resultType));
continue;
}
// Check if the operation has type inference support.
if (hasTypeInference) {
types.push_back(builder.create<pdl_interp::InferredTypeOp>(op.getLoc()));
continue;
}
// Look for an operation that was replaced by `op`. The result type will be
// inferred from the result that was replaced. There is guaranteed to be a
// replacement for either the op, or this specific result. Note that this is
// guaranteed by the verifier of `pdl::OperationOp`.
Operation *replacedOp = nullptr;
if (!fullReplacedOperation.hasValue()) {
for (OpOperand &use : op.op().getUses())
if ((replacedOp = getReplacedOperationFrom(use)))
break;
fullReplacedOperation = replacedOp;
} else {
replacedOp = fullReplacedOperation.getValue();
}
// Infer from the result, as there was no fully replaced op.
if (!replacedOp) {
for (OpOperand &use : result.getUses())
if ((replacedOp = getReplacedOperationFrom(use)))
break;
assert(replacedOp && "expected replaced op to infer a result type from");
}
auto replOpOp = cast<pdl::OperationOp>(replacedOp);
types.push_back(mapRewriteValue(replOpOp.types()[it.index()]));
}
}
//===----------------------------------------------------------------------===//
// Conversion Pass
//===----------------------------------------------------------------------===//
namespace {
struct PDLToPDLInterpPass
: public ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
void runOnOperation() final;
};
} // namespace
/// Convert the given module containing PDL pattern operations into a PDL
/// Interpreter operations.
void PDLToPDLInterpPass::runOnOperation() {
ModuleOp module = getOperation();
// Create the main matcher function This function contains all of the match
// related functionality from patterns in the module.
OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
FuncOp matcherFunc = builder.create<FuncOp>(
module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
builder.getFunctionType(builder.getType<pdl::OperationType>(),
/*results=*/llvm::None),
/*attrs=*/llvm::None);
// Create a nested module to hold the functions invoked for rewriting the IR
// after a successful match.
ModuleOp rewriterModule = builder.create<ModuleOp>(
module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
// Generate the code for the patterns within the module.
PatternLowering generator(matcherFunc, rewriterModule);
generator.lower(module);
// After generation, delete all of the pattern operations.
for (pdl::PatternOp pattern :
llvm::make_early_inc_range(module.getOps<pdl::PatternOp>()))
pattern.erase();
}
std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
return std::make_unique<PDLToPDLInterpPass>();
}

View File

@ -0,0 +1,49 @@
//===- Predicate.cpp - Pattern predicates ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "Predicate.h"
using namespace mlir;
using namespace mlir::pdl_to_pdl_interp;
//===----------------------------------------------------------------------===//
// Positions
//===----------------------------------------------------------------------===//
Position::~Position() {}
//===----------------------------------------------------------------------===//
// AttributePosition
AttributePosition::AttributePosition(const KeyTy &key) : Base(key) {
parent = key.first;
}
//===----------------------------------------------------------------------===//
// OperandPosition
OperandPosition::OperandPosition(const KeyTy &key) : Base(key) {
parent = key.first;
}
//===----------------------------------------------------------------------===//
// OperationPosition
OperationPosition *OperationPosition::get(StorageUniquer &uniquer,
ArrayRef<unsigned> index) {
assert(!index.empty() && "expected at least two indices");
// Set the parent position if this isn't the root.
Position *parent = nullptr;
if (index.size() > 1) {
auto *node = OperationPosition::get(uniquer, index.drop_back());
parent = OperandPosition::get(uniquer, std::make_pair(node, index.back()));
}
return uniquer.get<OperationPosition>(
[parent](OperationPosition *node) { node->parent = parent; }, index);
}

View File

@ -0,0 +1,530 @@
//===- Predicate.h - Pattern predicates -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions for "predicates" used when converting PDL into
// a matcher tree. Predicates are composed of three different parts:
//
// * Positions
// - A position refers to a specific location on the input DAG, i.e. an
// existing MLIR entity being matched. These can be attributes, operands,
// operations, results, and types. Each position also defines a relation to
// its parent. For example, the operand `[0] -> 1` has a parent operation
// position `[0]`. The attribute `[0, 1] -> "myAttr"` has parent operation
// position of `[0, 1]`. The operation `[0, 1]` has a parent operand edge
// `[0] -> 1` (i.e. it is the defining op of operand 1). The only position
// without a parent is `[0]`, which refers to the root operation.
// * Questions
// - A question refers to a query on a specific positional value. For
// example, an operation name question checks the name of an operation
// position.
// * Answers
// - An answer is the expected result of a question. For example, when
// matching an operation with the name "foo.op". The question would be an
// operation name question, with an expected answer of "foo.op".
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
#define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
namespace mlir {
namespace pdl_to_pdl_interp {
namespace Predicates {
/// An enumeration of the kinds of predicates.
enum Kind : unsigned {
/// Positions, ordered by decreasing priority.
OperationPos,
OperandPos,
AttributePos,
ResultPos,
TypePos,
// Questions, ordered by dependency and decreasing priority.
IsNotNullQuestion,
OperationNameQuestion,
TypeQuestion,
AttributeQuestion,
OperandCountQuestion,
ResultCountQuestion,
EqualToQuestion,
ConstraintQuestion,
// Answers.
AttributeAnswer,
TrueAnswer,
OperationNameAnswer,
TypeAnswer,
UnsignedAnswer,
};
} // end namespace Predicates
/// Base class for all predicates, used to allow efficient pointer comparison.
template <typename ConcreteT, typename BaseT, typename Key,
Predicates::Kind Kind>
class PredicateBase : public BaseT {
public:
using KeyTy = Key;
using Base = PredicateBase<ConcreteT, BaseT, Key, Kind>;
template <typename KeyT>
explicit PredicateBase(KeyT &&key)
: BaseT(Kind), key(std::forward<KeyT>(key)) {}
/// Get an instance of this position.
template <typename... Args>
static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) {
return uniquer.get<ConcreteT>(/*initFn=*/{}, std::forward<Args>(args)...);
}
/// Construct an instance with the given storage allocator.
template <typename KeyT>
static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
KeyT &&key) {
return new (alloc.allocate<ConcreteT>()) ConcreteT(std::forward<KeyT>(key));
}
/// Utility methods required by the storage allocator.
bool operator==(const KeyTy &key) const { return this->key == key; }
static bool classof(const BaseT *pred) { return pred->getKind() == Kind; }
/// Return the key value of this predicate.
const KeyTy &getValue() const { return key; }
protected:
KeyTy key;
};
/// Base storage for simple predicates that only unique with the kind.
template <typename ConcreteT, typename BaseT, Predicates::Kind Kind>
class PredicateBase<ConcreteT, BaseT, void, Kind> : public BaseT {
public:
using Base = PredicateBase<ConcreteT, BaseT, void, Kind>;
explicit PredicateBase() : BaseT(Kind) {}
static ConcreteT *get(StorageUniquer &uniquer) {
return uniquer.get<ConcreteT>();
}
static bool classof(const BaseT *pred) { return pred->getKind() == Kind; }
};
//===----------------------------------------------------------------------===//
// Positions
//===----------------------------------------------------------------------===//
struct OperationPosition;
/// A position describes a value on the input IR on which a predicate may be
/// applied, such as an operation or attribute. This enables re-use between
/// predicates, and assists generating bytecode and memory management.
///
/// Operation positions form the base of other positions, which are formed
/// relative to a parent operation, e.g. OperandPosition<[0] -> 1>. Operations
/// are indexed by child index: [0, 1, 2] refers to the 3rd child of the 2nd
/// child of the root operation.
///
/// Positions are linked to their parent position, which describes how to obtain
/// a positional value. As a concrete example, getting OperationPosition<[0, 1]>
/// would be `root->getOperand(1)->getDefiningOp()`, so its parent is
/// OperandPosition<[0] -> 1>, whose parent is OperationPosition<[0]>.
class Position : public StorageUniquer::BaseStorage {
public:
explicit Position(Predicates::Kind kind) : kind(kind) {}
virtual ~Position();
/// Returns the base node position. This is an array of indices.
virtual ArrayRef<unsigned> getIndex() const = 0;
/// Returns the parent position. The root operation position has no parent.
Position *getParent() const { return parent; }
/// Returns the kind of this position.
Predicates::Kind getKind() const { return kind; }
protected:
/// Link to the parent position.
Position *parent = nullptr;
private:
/// The kind of this position.
Predicates::Kind kind;
};
//===----------------------------------------------------------------------===//
// AttributePosition
/// A position describing an attribute of an operation.
struct AttributePosition
: public PredicateBase<AttributePosition, Position,
std::pair<OperationPosition *, Identifier>,
Predicates::AttributePos> {
explicit AttributePosition(const KeyTy &key);
/// Returns the index of this position.
ArrayRef<unsigned> getIndex() const final { return parent->getIndex(); }
/// Returns the attribute name of this position.
Identifier getName() const { return key.second; }
};
//===----------------------------------------------------------------------===//
// OperandPosition
/// A position describing an operand of an operation.
struct OperandPosition
: public PredicateBase<OperandPosition, Position,
std::pair<OperationPosition *, unsigned>,
Predicates::OperandPos> {
explicit OperandPosition(const KeyTy &key);
/// Returns the index of this position.
ArrayRef<unsigned> getIndex() const final { return parent->getIndex(); }
/// Returns the operand number of this position.
unsigned getOperandNumber() const { return key.second; }
};
//===----------------------------------------------------------------------===//
// OperationPosition
/// An operation position describes an operation node in the IR. Other position
/// kinds are formed with respect to an operation position.
struct OperationPosition
: public PredicateBase<OperationPosition, Position, ArrayRef<unsigned>,
Predicates::OperationPos> {
using Base::Base;
/// Gets the root position, which is always [0].
static OperationPosition *getRoot(StorageUniquer &uniquer) {
return get(uniquer, ArrayRef<unsigned>(0));
}
/// Gets a node position for the given index.
static OperationPosition *get(StorageUniquer &uniquer,
ArrayRef<unsigned> index);
/// Constructs an instance with the given storage allocator.
static OperationPosition *construct(StorageUniquer::StorageAllocator &alloc,
ArrayRef<unsigned> key) {
return Base::construct(alloc, alloc.copyInto(key));
}
/// Returns the index of this position.
ArrayRef<unsigned> getIndex() const final { return key; }
/// Returns if this operation position corresponds to the root.
bool isRoot() const { return key.size() == 1 && key[0] == 0; }
};
//===----------------------------------------------------------------------===//
// ResultPosition
/// A position describing a result of an operation.
struct ResultPosition
: public PredicateBase<ResultPosition, Position,
std::pair<OperationPosition *, unsigned>,
Predicates::ResultPos> {
explicit ResultPosition(const KeyTy &key) : Base(key) { parent = key.first; }
/// Returns the index of this position.
ArrayRef<unsigned> getIndex() const final { return key.first->getIndex(); }
/// Returns the result number of this position.
unsigned getResultNumber() const { return key.second; }
};
//===----------------------------------------------------------------------===//
// TypePosition
/// A position describing the result type of an entity, i.e. an Attribute,
/// Operand, Result, etc.
struct TypePosition : public PredicateBase<TypePosition, Position, Position *,
Predicates::TypePos> {
explicit TypePosition(const KeyTy &key) : Base(key) {
assert((isa<AttributePosition>(key) || isa<OperandPosition>(key) ||
isa<ResultPosition>(key)) &&
"expected parent to be an attribute, operand, or result");
parent = key;
}
/// Returns the index of this position.
ArrayRef<unsigned> getIndex() const final { return key->getIndex(); }
};
//===----------------------------------------------------------------------===//
// Qualifiers
//===----------------------------------------------------------------------===//
/// An ordinal predicate consists of a "Question" and a set of acceptable
/// "Answers" (later converted to ordinal values). A predicate will query some
/// property of a positional value and decide what to do based on the result.
///
/// This makes top-level predicate representations ordinal (SwitchOp). Later,
/// predicates that end up with only one acceptable answer (including all
/// boolean kinds) will be converted to boolean predicates (PredicateOp) in the
/// matcher.
///
/// For simplicity, both are represented as "qualifiers", with a base kind and
/// perhaps additional properties. For example, all OperationName predicates ask
/// the same question, but GenericConstraint predicates may ask different ones.
class Qualifier : public StorageUniquer::BaseStorage {
public:
explicit Qualifier(Predicates::Kind kind) : kind(kind) {}
/// Returns the kind of this qualifier.
Predicates::Kind getKind() const { return kind; }
private:
/// The kind of this position.
Predicates::Kind kind;
};
//===----------------------------------------------------------------------===//
// Answers
/// An Answer representing an `Attribute` value.
struct AttributeAnswer
: public PredicateBase<AttributeAnswer, Qualifier, Attribute,
Predicates::AttributeAnswer> {
using Base::Base;
};
/// An Answer representing an `OperationName` value.
struct OperationNameAnswer
: public PredicateBase<OperationNameAnswer, Qualifier, OperationName,
Predicates::OperationNameAnswer> {
using Base::Base;
};
/// An Answer representing a boolean `true` value.
struct TrueAnswer
: PredicateBase<TrueAnswer, Qualifier, void, Predicates::TrueAnswer> {
using Base::Base;
};
/// An Answer representing a `Type` value.
struct TypeAnswer : public PredicateBase<TypeAnswer, Qualifier, Type,
Predicates::TypeAnswer> {
using Base::Base;
};
/// An Answer representing an unsigned value.
struct UnsignedAnswer
: public PredicateBase<UnsignedAnswer, Qualifier, unsigned,
Predicates::UnsignedAnswer> {
using Base::Base;
};
//===----------------------------------------------------------------------===//
// Questions
/// Compare an `Attribute` to a constant value.
struct AttributeQuestion
: public PredicateBase<AttributeQuestion, Qualifier, void,
Predicates::AttributeQuestion> {};
/// Apply a parameterized constraint to multiple position values.
struct ConstraintQuestion
: public PredicateBase<
ConstraintQuestion, Qualifier,
std::tuple<StringRef, ArrayRef<Position *>, Attribute>,
Predicates::ConstraintQuestion> {
using Base::Base;
/// Construct an instance with the given storage allocator.
static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
KeyTy key) {
return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
alloc.copyInto(std::get<1>(key)),
std::get<2>(key)});
}
};
/// Compare the equality of two values.
struct EqualToQuestion
: public PredicateBase<EqualToQuestion, Qualifier, Position *,
Predicates::EqualToQuestion> {
using Base::Base;
};
/// Compare a positional value with null, i.e. check if it exists.
struct IsNotNullQuestion
: public PredicateBase<IsNotNullQuestion, Qualifier, void,
Predicates::IsNotNullQuestion> {};
/// Compare the number of operands of an operation with a known value.
struct OperandCountQuestion
: public PredicateBase<OperandCountQuestion, Qualifier, void,
Predicates::OperandCountQuestion> {};
/// Compare the name of an operation with a known value.
struct OperationNameQuestion
: public PredicateBase<OperationNameQuestion, Qualifier, void,
Predicates::OperationNameQuestion> {};
/// Compare the number of results of an operation with a known value.
struct ResultCountQuestion
: public PredicateBase<ResultCountQuestion, Qualifier, void,
Predicates::ResultCountQuestion> {};
/// Compare the type of an attribute or value with a known type.
struct TypeQuestion : public PredicateBase<TypeQuestion, Qualifier, void,
Predicates::TypeQuestion> {};
//===----------------------------------------------------------------------===//
// PredicateUniquer
//===----------------------------------------------------------------------===//
/// This class provides a storage uniquer that is used to allocate predicate
/// instances.
class PredicateUniquer : public StorageUniquer {
public:
PredicateUniquer() {
// Register the types of Positions with the uniquer.
registerParametricStorageType<AttributePosition>();
registerParametricStorageType<OperandPosition>();
registerParametricStorageType<OperationPosition>();
registerParametricStorageType<ResultPosition>();
registerParametricStorageType<TypePosition>();
// Register the types of Questions with the uniquer.
registerParametricStorageType<AttributeAnswer>();
registerParametricStorageType<OperationNameAnswer>();
registerParametricStorageType<TypeAnswer>();
registerParametricStorageType<UnsignedAnswer>();
registerSingletonStorageType<TrueAnswer>();
// Register the types of Answers with the uniquer.
registerParametricStorageType<ConstraintQuestion>();
registerParametricStorageType<EqualToQuestion>();
registerSingletonStorageType<AttributeQuestion>();
registerSingletonStorageType<IsNotNullQuestion>();
registerSingletonStorageType<OperandCountQuestion>();
registerSingletonStorageType<OperationNameQuestion>();
registerSingletonStorageType<ResultCountQuestion>();
registerSingletonStorageType<TypeQuestion>();
}
};
//===----------------------------------------------------------------------===//
// PredicateBuilder
//===----------------------------------------------------------------------===//
/// This class provides utilties for constructing predicates.
class PredicateBuilder {
public:
PredicateBuilder(PredicateUniquer &uniquer, MLIRContext *ctx)
: uniquer(uniquer), ctx(ctx) {}
//===--------------------------------------------------------------------===//
// Positions
//===--------------------------------------------------------------------===//
/// Returns the root operation position.
Position *getRoot() { return OperationPosition::getRoot(uniquer); }
/// Returns the parent position defining the value held by the given operand.
Position *getParent(OperandPosition *p) {
std::vector<unsigned> index = p->getIndex();
index.push_back(p->getOperandNumber());
return OperationPosition::get(uniquer, index);
}
/// Returns an attribute position for an attribute of the given operation.
Position *getAttribute(OperationPosition *p, StringRef name) {
return AttributePosition::get(uniquer, p, Identifier::get(name, ctx));
}
/// Returns an operand position for an operand of the given operation.
Position *getOperand(OperationPosition *p, unsigned operand) {
return OperandPosition::get(uniquer, p, operand);
}
/// Returns a result position for a result of the given operation.
Position *getResult(OperationPosition *p, unsigned result) {
return ResultPosition::get(uniquer, p, result);
}
/// Returns a type position for the given entity.
Position *getType(Position *p) { return TypePosition::get(uniquer, p); }
//===--------------------------------------------------------------------===//
// Qualifiers
//===--------------------------------------------------------------------===//
/// An ordinal predicate consists of a "Question" and a set of acceptable
/// "Answers" (later converted to ordinal values). A predicate will query some
/// property of a positional value and decide what to do based on the result.
using Predicate = std::pair<Qualifier *, Qualifier *>;
/// Create a predicate comparing an attribute to a known value.
Predicate getAttributeConstraint(Attribute attr) {
return {AttributeQuestion::get(uniquer),
AttributeAnswer::get(uniquer, attr)};
}
/// Create a predicate comparing two values.
Predicate getEqualTo(Position *pos) {
return {EqualToQuestion::get(uniquer, pos), TrueAnswer::get(uniquer)};
}
/// Create a predicate that applies a generic constraint.
Predicate getConstraint(StringRef name, ArrayRef<Position *> pos,
Attribute params) {
return {
ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, params)),
TrueAnswer::get(uniquer)};
}
/// Create a predicate comparing a value with null.
Predicate getIsNotNull() {
return {IsNotNullQuestion::get(uniquer), TrueAnswer::get(uniquer)};
}
/// Create a predicate comparing the number of operands of an operation to a
/// known value.
Predicate getOperandCount(unsigned count) {
return {OperandCountQuestion::get(uniquer),
UnsignedAnswer::get(uniquer, count)};
}
/// Create a predicate comparing the name of an operation to a known value.
Predicate getOperationName(StringRef name) {
return {OperationNameQuestion::get(uniquer),
OperationNameAnswer::get(uniquer, OperationName(name, ctx))};
}
/// Create a predicate comparing the number of results of an operation to a
/// known value.
Predicate getResultCount(unsigned count) {
return {ResultCountQuestion::get(uniquer),
UnsignedAnswer::get(uniquer, count)};
}
/// Create a predicate comparing the type of an attribute or value to a known
/// type.
Predicate getTypeConstraint(Type type) {
return {TypeQuestion::get(uniquer), TypeAnswer::get(uniquer, type)};
}
private:
/// The uniquer used when allocating predicate nodes.
PredicateUniquer &uniquer;
/// The current MLIR context.
MLIRContext *ctx;
};
} // end namespace pdl_to_pdl_interp
} // end namespace mlir
#endif // MLIR_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_

View File

@ -0,0 +1,462 @@
//===- PredicateTree.cpp - Predicate tree merging -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PredicateTree.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/IR/Module.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
using namespace mlir;
using namespace mlir::pdl_to_pdl_interp;
//===----------------------------------------------------------------------===//
// Predicate List Building
//===----------------------------------------------------------------------===//
/// Compares the depths of two positions.
static bool comparePosDepth(Position *lhs, Position *rhs) {
return lhs->getIndex().size() < rhs->getIndex().size();
}
/// Collect the tree predicates anchored at the given value.
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
Value val, PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs,
Position *pos) {
// Make sure this input value is accessible to the rewrite.
auto it = inputs.try_emplace(val, pos);
// If this is an input value that has been visited in the tree, add a
// constraint to ensure that both instances refer to the same value.
if (!it.second &&
isa<pdl::AttributeOp, pdl::InputOp, pdl::TypeOp>(val.getDefiningOp())) {
auto minMaxPositions = std::minmax(pos, it.first->second, comparePosDepth);
predList.emplace_back(minMaxPositions.second,
builder.getEqualTo(minMaxPositions.first));
return;
}
// Check for a per-position predicate to apply.
switch (pos->getKind()) {
case Predicates::AttributePos: {
assert(val.getType().isa<pdl::AttributeType>() &&
"expected attribute type");
pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
predList.emplace_back(pos, builder.getIsNotNull());
// If the attribute has a type, add a type constraint.
if (Value type = attr.type()) {
getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
// Check for a constant value of the attribute.
} else if (Optional<Attribute> value = attr.value()) {
predList.emplace_back(pos, builder.getAttributeConstraint(*value));
}
break;
}
case Predicates::OperandPos: {
assert(val.getType().isa<pdl::ValueType>() && "expected value type");
// Prevent traversal into a null value.
predList.emplace_back(pos, builder.getIsNotNull());
// If this is a typed input, add a type constraint.
if (auto in = val.getDefiningOp<pdl::InputOp>()) {
if (Value type = in.type()) {
getTreePredicates(predList, type, builder, inputs,
builder.getType(pos));
}
// Otherwise, recurse into the parent node.
} else if (auto parentOp = val.getDefiningOp<pdl::OperationOp>()) {
getTreePredicates(predList, parentOp.op(), builder, inputs,
builder.getParent(cast<OperandPosition>(pos)));
}
break;
}
case Predicates::OperationPos: {
assert(val.getType().isa<pdl::OperationType>() && "expected operation");
pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
OperationPosition *opPos = cast<OperationPosition>(pos);
// Ensure getDefiningOp returns a non-null operation.
if (!opPos->isRoot())
predList.emplace_back(pos, builder.getIsNotNull());
// Check that this is the correct root operation.
if (Optional<StringRef> opName = op.name())
predList.emplace_back(pos, builder.getOperationName(*opName));
// Check that the operation has the proper number of operands and results.
OperandRange operands = op.operands();
ResultRange results = op.results();
predList.emplace_back(pos, builder.getOperandCount(operands.size()));
predList.emplace_back(pos, builder.getResultCount(results.size()));
// Recurse into any attributes, operands, or results.
for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
getTreePredicates(
predList, std::get<1>(it), builder, inputs,
builder.getAttribute(opPos,
std::get<0>(it).cast<StringAttr>().getValue()));
}
for (auto operandIt : llvm::enumerate(operands))
getTreePredicates(predList, operandIt.value(), builder, inputs,
builder.getOperand(opPos, operandIt.index()));
// Only recurse into results that are not referenced in the source tree.
for (auto resultIt : llvm::enumerate(results)) {
getTreePredicates(predList, resultIt.value(), builder, inputs,
builder.getResult(opPos, resultIt.index()));
}
break;
}
case Predicates::ResultPos: {
assert(val.getType().isa<pdl::ValueType>() && "expected value type");
pdl::OperationOp parentOp = cast<pdl::OperationOp>(val.getDefiningOp());
// Prevent traversing a null value.
predList.emplace_back(pos, builder.getIsNotNull());
// Traverse the type constraint.
unsigned resultNo = cast<ResultPosition>(pos)->getResultNumber();
getTreePredicates(predList, parentOp.types()[resultNo], builder, inputs,
builder.getType(pos));
break;
}
case Predicates::TypePos: {
assert(val.getType().isa<pdl::TypeType>() && "expected value type");
pdl::TypeOp typeOp = cast<pdl::TypeOp>(val.getDefiningOp());
// Check for a constraint on a constant type.
if (Optional<Type> type = typeOp.type())
predList.emplace_back(pos, builder.getTypeConstraint(*type));
break;
}
default:
llvm_unreachable("unknown position kind");
}
}
/// Collect all of the predicates related to constraints within the given
/// pattern operation.
static void collectConstraintPredicates(
pdl::PatternOp pattern, std::vector<PositionalPredicate> &predList,
PredicateBuilder &builder, DenseMap<Value, Position *> &inputs) {
for (auto op : pattern.body().getOps<pdl::ApplyConstraintOp>()) {
OperandRange arguments = op.args();
ArrayAttr parameters = op.constParamsAttr();
std::vector<Position *> allPositions;
allPositions.reserve(arguments.size());
for (Value arg : arguments)
allPositions.push_back(inputs.lookup(arg));
// Push the constraint to the furthest position.
Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
comparePosDepth);
PredicateBuilder::Predicate pred =
builder.getConstraint(op.name(), std::move(allPositions), parameters);
predList.emplace_back(pos, pred);
}
}
/// Given a pattern operation, build the set of matcher predicates necessary to
/// match this pattern.
static void buildPredicateList(pdl::PatternOp pattern,
PredicateBuilder &builder,
std::vector<PositionalPredicate> &predList,
DenseMap<Value, Position *> &valueToPosition) {
getTreePredicates(predList, pattern.getRewriter().root(), builder,
valueToPosition, builder.getRoot());
collectConstraintPredicates(pattern, predList, builder, valueToPosition);
}
//===----------------------------------------------------------------------===//
// Pattern Predicate Tree Merging
//===----------------------------------------------------------------------===//
namespace {
/// This class represents a specific predicate applied to a position, and
/// provides hashing and ordering operators. This class allows for computing a
/// frequence sum and ordering predicates based on a cost model.
struct OrderedPredicate {
OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
: position(ip.first), question(ip.second) {}
OrderedPredicate(const PositionalPredicate &ip)
: position(ip.position), question(ip.question) {}
/// The position this predicate is applied to.
Position *position;
/// The question that is applied by this predicate onto the position.
Qualifier *question;
/// The first and second order benefit sums.
/// The primary sum is the number of occurrences of this predicate among all
/// of the patterns.
unsigned primary = 0;
/// The secondary sum is a squared summation of the primary sum of all of the
/// predicates within each pattern that contains this predicate. This allows
/// for favoring predicates that are more commonly shared within a pattern, as
/// opposed to those shared across patterns.
unsigned secondary = 0;
/// A map between a pattern operation and the answer to the predicate question
/// within that pattern.
DenseMap<Operation *, Qualifier *> patternToAnswer;
/// Returns true if this predicate is ordered before `other`, based on the
/// cost model.
bool operator<(const OrderedPredicate &other) const {
// Sort by:
// * first and secondary order sums
// * lower depth
// * position dependency
// * predicate dependency.
auto *otherPos = other.position;
return std::make_tuple(other.primary, other.secondary,
otherPos->getIndex().size(), otherPos->getKind(),
other.question->getKind()) >
std::make_tuple(primary, secondary, position->getIndex().size(),
position->getKind(), question->getKind());
}
};
/// A DenseMapInfo for OrderedPredicate based solely on the position and
/// question.
struct OrderedPredicateDenseInfo {
using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>;
static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
static bool isEqual(const OrderedPredicate &lhs,
const OrderedPredicate &rhs) {
return lhs.position == rhs.position && lhs.question == rhs.question;
}
static unsigned getHashValue(const OrderedPredicate &p) {
return llvm::hash_combine(p.position, p.question);
}
};
/// This class wraps a set of ordered predicates that are used within a specific
/// pattern operation.
struct OrderedPredicateList {
OrderedPredicateList(pdl::PatternOp pattern) : pattern(pattern) {}
pdl::PatternOp pattern;
DenseSet<OrderedPredicate *> predicates;
};
} // end anonymous namespace
/// Returns true if the given matcher refers to the same predicate as the given
/// ordered predicate. This means that the position and questions of the two
/// match.
static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
return node->getPosition() == predicate->position &&
node->getQuestion() == predicate->question;
}
/// Get or insert a child matcher for the given parent switch node, given a
/// predicate and parent pattern.
std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node,
OrderedPredicate *predicate,
pdl::PatternOp pattern) {
assert(isSamePredicate(node, predicate) &&
"expected matcher to equal the given predicate");
auto it = predicate->patternToAnswer.find(pattern);
assert(it != predicate->patternToAnswer.end() &&
"expected pattern to exist in predicate");
return node->getChildren().insert({it->second, nullptr}).first->second;
}
/// Build the matcher CFG by "pushing" patterns through by sorted predicate
/// order. A pattern will traverse as far as possible using common predicates
/// and then either diverge from the CFG or reach the end of a branch and start
/// creating new nodes.
static void propagatePattern(std::unique_ptr<MatcherNode> &node,
OrderedPredicateList &list,
std::vector<OrderedPredicate *>::iterator current,
std::vector<OrderedPredicate *>::iterator end) {
if (current == end) {
// We've hit the end of a pattern, so create a successful result node.
node = std::make_unique<SuccessNode>(list.pattern, std::move(node));
// If the pattern doesn't contain this predicate, ignore it.
} else if (list.predicates.find(*current) == list.predicates.end()) {
propagatePattern(node, list, std::next(current), end);
// If the current matcher node is invalid, create a new one for this
// position and continue propagation.
} else if (!node) {
// Create a new node at this position and continue
node = std::make_unique<SwitchNode>((*current)->position,
(*current)->question);
propagatePattern(
getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
list, std::next(current), end);
// If the matcher has already been created, and it is for this predicate we
// continue propagation to the child.
} else if (isSamePredicate(node.get(), *current)) {
propagatePattern(
getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
list, std::next(current), end);
// If the matcher doesn't match the current predicate, insert a branch as
// the common set of matchers has diverged.
} else {
propagatePattern(node->getFailureNode(), list, current, end);
}
}
/// Fold any switch nodes nested under `node` to boolean nodes when possible.
/// `node` is updated in-place if it is a switch.
static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {
if (!node)
return;
if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
SwitchNode::ChildMapT &children = switchNode->getChildren();
for (auto &it : children)
foldSwitchToBool(it.second);
// If the node only contains one child, collapse it into a boolean predicate
// node.
if (children.size() == 1) {
auto childIt = children.begin();
node = std::make_unique<BoolNode>(
node->getPosition(), node->getQuestion(), childIt->first,
std::move(childIt->second), std::move(node->getFailureNode()));
}
} else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
foldSwitchToBool(boolNode->getSuccessNode());
}
foldSwitchToBool(node->getFailureNode());
}
/// Insert an exit node at the end of the failure path of the `root`.
static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
while (*root)
root = &(*root)->getFailureNode();
*root = std::make_unique<ExitNode>();
}
/// Given a module containing PDL pattern operations, generate a matcher tree
/// using the patterns within the given module and return the root matcher node.
std::unique_ptr<MatcherNode>
MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
DenseMap<Value, Position *> &valueToPosition) {
// Collect the set of predicates contained within the pattern operations of
// the module.
SmallVector<std::pair<pdl::PatternOp, std::vector<PositionalPredicate>>, 16>
patternsAndPredicates;
for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
std::vector<PositionalPredicate> predicateList;
buildPredicateList(pattern, builder, predicateList, valueToPosition);
patternsAndPredicates.emplace_back(pattern, std::move(predicateList));
}
// Associate a pattern result with each unique predicate.
DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued;
for (auto &patternAndPredList : patternsAndPredicates) {
for (auto &predicate : patternAndPredList.second) {
auto it = uniqued.insert(predicate);
it.first->patternToAnswer.try_emplace(patternAndPredList.first,
predicate.answer);
}
}
// Associate each pattern to a set of its ordered predicates for later lookup.
std::vector<OrderedPredicateList> lists;
lists.reserve(patternsAndPredicates.size());
for (auto &patternAndPredList : patternsAndPredicates) {
OrderedPredicateList list(patternAndPredList.first);
for (auto &predicate : patternAndPredList.second) {
OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
list.predicates.insert(orderedPredicate);
// Increment the primary sum for each reference to a particular predicate.
++orderedPredicate->primary;
}
lists.push_back(std::move(list));
}
// For a particular pattern, get the total primary sum and add it to the
// secondary sum of each predicate. Square the primary sums to emphasize
// shared predicates within rather than across patterns.
for (auto &list : lists) {
unsigned total = 0;
for (auto *predicate : list.predicates)
total += predicate->primary * predicate->primary;
for (auto *predicate : list.predicates)
predicate->secondary += total;
}
// Sort the set of predicates now that the cost primary and secondary sums
// have been computed.
std::vector<OrderedPredicate *> ordered;
ordered.reserve(uniqued.size());
for (auto &ip : uniqued)
ordered.push_back(&ip);
std::stable_sort(
ordered.begin(), ordered.end(),
[](OrderedPredicate *lhs, OrderedPredicate *rhs) { return *lhs < *rhs; });
// Build the matchers for each of the pattern predicate lists.
std::unique_ptr<MatcherNode> root;
for (OrderedPredicateList &list : lists)
propagatePattern(root, list, ordered.begin(), ordered.end());
// Collapse the graph and insert the exit node.
foldSwitchToBool(root);
insertExitNode(&root);
return root;
}
//===----------------------------------------------------------------------===//
// MatcherNode
//===----------------------------------------------------------------------===//
MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q,
std::unique_ptr<MatcherNode> failureNode)
: position(p), question(q), failureNode(std::move(failureNode)),
matcherTypeID(matcherTypeID) {}
//===----------------------------------------------------------------------===//
// BoolNode
//===----------------------------------------------------------------------===//
BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
std::unique_ptr<MatcherNode> successNode,
std::unique_ptr<MatcherNode> failureNode)
: MatcherNode(TypeID::get<BoolNode>(), position, question,
std::move(failureNode)),
answer(answer), successNode(std::move(successNode)) {}
//===----------------------------------------------------------------------===//
// SuccessNode
//===----------------------------------------------------------------------===//
SuccessNode::SuccessNode(pdl::PatternOp pattern,
std::unique_ptr<MatcherNode> failureNode)
: MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
/*question=*/nullptr, std::move(failureNode)),
pattern(pattern) {}
//===----------------------------------------------------------------------===//
// SwitchNode
//===----------------------------------------------------------------------===//
SwitchNode::SwitchNode(Position *position, Qualifier *question)
: MatcherNode(TypeID::get<SwitchNode>(), position, question) {}

View File

@ -0,0 +1,200 @@
//===- PredicateTree.h - Predicate tree node definitions --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions for nodes of a tree structure for representing
// the general control flow within a pattern match.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_
#define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_
#include "Predicate.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "llvm/ADT/MapVector.h"
namespace mlir {
namespace pdl_to_pdl_interp {
class MatcherNode;
/// A PositionalPredicate is a predicate that is associated with a specific
/// positional value.
struct PositionalPredicate {
PositionalPredicate(Position *pos,
const PredicateBuilder::Predicate &predicate)
: position(pos), question(predicate.first), answer(predicate.second) {}
/// The position the predicate is applied to.
Position *position;
/// The question that the predicate applies.
Qualifier *question;
/// The expected answer of the predicate.
Qualifier *answer;
};
//===----------------------------------------------------------------------===//
// MatcherNode
//===----------------------------------------------------------------------===//
/// This class represents the base of a predicate matcher node.
class MatcherNode {
public:
virtual ~MatcherNode() = default;
/// Given a module containing PDL pattern operations, generate a matcher tree
/// using the patterns within the given module and return the root matcher
/// node. `valueToPosition` is a map that is populated with the original
/// pdl values and their corresponding positions in the matcher tree.
static std::unique_ptr<MatcherNode>
generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
DenseMap<Value, Position *> &valueToPosition);
/// Returns the position on which the question predicate should be checked.
Position *getPosition() const { return position; }
/// Returns the predicate checked on this node.
Qualifier *getQuestion() const { return question; }
/// Returns the node that should be visited if this, or a subsequent node
/// fails.
std::unique_ptr<MatcherNode> &getFailureNode() { return failureNode; }
/// Sets the node that should be visited if this, or a subsequent node fails.
void setFailureNode(std::unique_ptr<MatcherNode> node) {
failureNode = std::move(node);
}
/// Returns the unique type ID of this matcher instance. This should not be
/// used directly, and is provided to support type casting.
TypeID getMatcherTypeID() const { return matcherTypeID; }
protected:
MatcherNode(TypeID matcherTypeID, Position *position = nullptr,
Qualifier *question = nullptr,
std::unique_ptr<MatcherNode> failureNode = nullptr);
private:
/// The position on which the predicate should be checked.
Position *position;
/// The predicate that is checked on the given position.
Qualifier *question;
/// The node to visit if this node fails.
std::unique_ptr<MatcherNode> failureNode;
/// An owning store for the failure node if it is owned by this node.
std::unique_ptr<MatcherNode> failureNodeStorage;
/// A unique identifier for the derived matcher node, used for type casting.
TypeID matcherTypeID;
};
//===----------------------------------------------------------------------===//
// BoolNode
/// A BoolNode denotes a question with a boolean-like result. These nodes branch
/// to a single node on a successful result, otherwise defaulting to the failure
/// node.
struct BoolNode : public MatcherNode {
BoolNode(Position *position, Qualifier *question, Qualifier *answer,
std::unique_ptr<MatcherNode> successNode,
std::unique_ptr<MatcherNode> failureNode = nullptr);
/// Returns if the given matcher node is an instance of this class, used to
/// support type casting.
static bool classof(const MatcherNode *node) {
return node->getMatcherTypeID() == TypeID::get<BoolNode>();
}
/// Returns the expected answer of this boolean node.
Qualifier *getAnswer() const { return answer; }
/// Returns the node that should be visited on success.
std::unique_ptr<MatcherNode> &getSuccessNode() { return successNode; }
private:
/// The expected answer of this boolean node.
Qualifier *answer;
/// The next node if this node succeeds. Otherwise, go to the failure node.
std::unique_ptr<MatcherNode> successNode;
};
//===----------------------------------------------------------------------===//
// ExitNode
/// An ExitNode is a special sentinel node that denotes the end of matcher.
struct ExitNode : public MatcherNode {
ExitNode() : MatcherNode(TypeID::get<ExitNode>()) {}
/// Returns if the given matcher node is an instance of this class, used to
/// support type casting.
static bool classof(const MatcherNode *node) {
return node->getMatcherTypeID() == TypeID::get<ExitNode>();
}
};
//===----------------------------------------------------------------------===//
// SuccessNode
/// A SuccessNode denotes that a given high level pattern has successfully been
/// matched. This does not terminate the matcher, as there may be multiple
/// successful matches.
struct SuccessNode : public MatcherNode {
explicit SuccessNode(pdl::PatternOp pattern,
std::unique_ptr<MatcherNode> failureNode);
/// Returns if the given matcher node is an instance of this class, used to
/// support type casting.
static bool classof(const MatcherNode *node) {
return node->getMatcherTypeID() == TypeID::get<SuccessNode>();
}
/// Return the high level pattern operation that is matched with this node.
pdl::PatternOp getPattern() const { return pattern; }
private:
/// The high level pattern operation that was successfully matched with this
/// node.
pdl::PatternOp pattern;
};
//===----------------------------------------------------------------------===//
// SwitchNode
/// A SwitchNode denotes a question with multiple potential results. These nodes
/// branch to a specific node based on the result of the question.
struct SwitchNode : public MatcherNode {
SwitchNode(Position *position, Qualifier *question);
/// Returns if the given matcher node is an instance of this class, used to
/// support type casting.
static bool classof(const MatcherNode *node) {
return node->getMatcherTypeID() == TypeID::get<SwitchNode>();
}
/// Returns the children of this switch node. The children are contained
/// within a mapping between the various case answers to destination matcher
/// nodes.
using ChildMapT = llvm::MapVector<Qualifier *, std::unique_ptr<MatcherNode>>;
ChildMapT &getChildren() { return children; }
private:
/// Switch predicate "answers" select the child. Answers that are not found
/// default to the failure node.
ChildMapT children;
};
} // end namespace pdl_to_pdl_interp
} // end namespace mlir
#endif // MLIR_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_

View File

@ -33,6 +33,10 @@ namespace NVVM {
class NVVMDialect;
} // end namespace NVVM
namespace pdl_interp {
class PDLInterpDialect;
} // end namespace pdl_interp
namespace ROCDL {
class ROCDLDialect;
} // end namespace ROCDL

View File

@ -0,0 +1,145 @@
// RUN: mlir-opt -split-input-file -convert-pdl-to-pdl-interp %s | FileCheck %s
// CHECK-LABEL: module @empty_module
module @empty_module {
// CHECK: func @matcher(%{{.*}}: !pdl.operation)
// CHECK-NEXT: pdl_interp.finalize
}
// -----
// CHECK-LABEL: module @simple
module @simple {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK: pdl_interp.check_operation_name of %[[ROOT]] is "foo.op" -> ^bb2, ^bb1
// CHECK: ^bb1:
// CHECK: pdl_interp.finalize
// CHECK: ^bb2:
// CHECK: pdl_interp.check_operand_count of %[[ROOT]] is 0 -> ^bb3, ^bb1
// CHECK: ^bb3:
// CHECK: pdl_interp.check_result_count of %[[ROOT]] is 0 -> ^bb4, ^bb1
// CHECK: ^bb4:
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter
// CHECK-SAME: benefit(1), loc([%[[ROOT]]]), root("foo.op") -> ^bb1
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[REWRITE_ROOT:.*]]: !pdl.operation)
// CHECK: pdl_interp.apply_rewrite "rewriter" on %[[REWRITE_ROOT]]
// CHECK: pdl_interp.finalize
pdl.pattern : benefit(1) {
%root = pdl.operation "foo.op"()
pdl.rewrite %root with "rewriter"
}
}
// -----
// CHECK-LABEL: module @attributes
module @attributes {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// Check the value of "attr".
// CHECK-DAG: %[[ATTR:.*]] = pdl_interp.get_attribute "attr" of %[[ROOT]]
// CHECK-DAG: pdl_interp.is_not_null %[[ATTR]] : !pdl.attribute
// CHECK-DAG: pdl_interp.check_attribute %[[ATTR]] is 10 : i64
// Check the type of "attr1".
// CHECK-DAG: %[[ATTR1:.*]] = pdl_interp.get_attribute "attr1" of %[[ROOT]]
// CHECK-DAG: pdl_interp.is_not_null %[[ATTR1]] : !pdl.attribute
// CHECK-DAG: %[[ATTR1_TYPE:.*]] = pdl_interp.get_attribute_type of %[[ATTR1]]
// CHECK-DAG: pdl_interp.check_type %[[ATTR1_TYPE]] is i64
pdl.pattern : benefit(1) {
%type = pdl.type : i64
%attr = pdl.attribute 10 : i64
%attr1 = pdl.attribute : %type
%root = pdl.operation {"attr" = %attr, "attr1" = %attr1}
pdl.rewrite %root with "rewriter"
}
}
// -----
// CHECK-LABEL: module @constraints
module @constraints {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK-DAG: %[[INPUT:.*]] = pdl_interp.get_operand 0 of %[[ROOT]]
// CHECK-DAG: %[[INPUT1:.*]] = pdl_interp.get_operand 1 of %[[ROOT]]
// CHECK: pdl_interp.apply_constraint "multi_constraint" [true](%[[INPUT]], %[[INPUT1]] : !pdl.value, !pdl.value)
pdl.pattern : benefit(1) {
%input0 = pdl.input
%input1 = pdl.input
pdl.apply_constraint "multi_constraint"[true](%input0, %input1 : !pdl.value, !pdl.value)
%root = pdl.operation(%input0, %input1)
pdl.rewrite %root with "rewriter"
}
}
// -----
// CHECK-LABEL: module @inputs
module @inputs {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK-DAG: pdl_interp.check_operand_count of %[[ROOT]] is 2
// Get the input and check the type.
// CHECK-DAG: %[[INPUT:.*]] = pdl_interp.get_operand 0 of %[[ROOT]]
// CHECK-DAG: pdl_interp.is_not_null %[[INPUT]] : !pdl.value
// CHECK-DAG: %[[INPUT_TYPE:.*]] = pdl_interp.get_value_type of %[[INPUT]]
// CHECK-DAG: pdl_interp.check_type %[[INPUT_TYPE]] is i64
// Get the second operand and check that it is equal to the first.
// CHECK-DAG: %[[INPUT1:.*]] = pdl_interp.get_operand 1 of %[[ROOT]]
// CHECK-DAG: pdl_interp.are_equal %[[INPUT]], %[[INPUT1]] : !pdl.value
pdl.pattern : benefit(1) {
%type = pdl.type : i64
%input = pdl.input : %type
%root = pdl.operation(%input, %input)
pdl.rewrite %root with "rewriter"
}
}
// -----
// CHECK-LABEL: module @results
module @results {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK: pdl_interp.check_result_count of %[[ROOT]] is 2
// Get the input and check the type.
// CHECK-DAG: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]]
// CHECK-DAG: pdl_interp.is_not_null %[[RESULT]] : !pdl.value
// CHECK-DAG: %[[RESULT_TYPE:.*]] = pdl_interp.get_value_type of %[[RESULT]]
// CHECK-DAG: pdl_interp.check_type %[[RESULT_TYPE]] is i32
// Get the second operand and check that it is equal to the first.
// CHECK-DAG: %[[RESULT1:.*]] = pdl_interp.get_result 1 of %[[ROOT]]
// CHECK-NOT: pdl_interp.get_value_type of %[[RESULT1]]
pdl.pattern : benefit(1) {
%type1 = pdl.type : i32
%type2 = pdl.type
%root, %results:2 = pdl.operation -> %type1, %type2
pdl.rewrite %root with "rewriter"
}
}
// -----
// CHECK-LABEL: module @switch_result_types
module @switch_result_types {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]]
// CHECK: %[[RESULT_TYPE:.*]] = pdl_interp.get_value_type of %[[RESULT]]
// CHECK: pdl_interp.switch_type %[[RESULT_TYPE]] to [i32, i64]
pdl.pattern : benefit(1) {
%type = pdl.type : i32
%root, %result = pdl.operation -> %type
pdl.rewrite %root with "rewriter"
}
pdl.pattern : benefit(1) {
%type = pdl.type : i64
%root, %result = pdl.operation -> %type
pdl.rewrite %root with "rewriter"
}
}

View File

@ -0,0 +1,202 @@
// RUN: mlir-opt -split-input-file -convert-pdl-to-pdl-interp %s | FileCheck %s
// -----
// CHECK-LABEL: module @external
module @external {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation, %[[INPUT:.*]]: !pdl.value)
// CHECK: pdl_interp.apply_rewrite "rewriter" [true](%[[INPUT]] : !pdl.value) on %[[ROOT]]
pdl.pattern : benefit(1) {
%input = pdl.input
%root = pdl.operation "foo.op"(%input)
pdl.rewrite %root with "rewriter"[true](%input : !pdl.value)
}
}
// -----
// CHECK-LABEL: module @erase
module @erase {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation)
// CHECK: pdl_interp.erase %[[ROOT]]
// CHECK: pdl_interp.finalize
pdl.pattern : benefit(1) {
%root = pdl.operation "foo.op"
pdl.rewrite %root {
pdl.erase %root
}
}
}
// -----
// CHECK-LABEL: module @operation_attributes
module @operation_attributes {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[ATTR:.*]]: !pdl.attribute, %[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[ATTR1:.*]] = pdl_interp.create_attribute true
// CHECK: pdl_interp.create_operation "foo.op"() {"attr" = %[[ATTR]], "attr1" = %[[ATTR1]]}
pdl.pattern : benefit(1) {
%attr = pdl.attribute
%root = pdl.operation "foo.op" {"attr" = %attr}
pdl.rewrite %root {
%attr1 = pdl.attribute true
%newOp = pdl.operation "foo.op" {"attr" = %attr, "attr1" = %attr1}
pdl.erase %root
}
}
}
// -----
// CHECK-LABEL: module @operation_operands
module @operation_operands {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[OPERAND:.*]]: !pdl.value, %[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation "foo.op"(%[[OPERAND]])
// CHECK: %[[OPERAND1:.*]] = pdl_interp.get_result 0 of %[[NEWOP]]
// CHECK: pdl_interp.create_operation "foo.op2"(%[[OPERAND1]])
pdl.pattern : benefit(1) {
%operand = pdl.input
%root = pdl.operation "foo.op"(%operand)
pdl.rewrite %root {
%type = pdl.type : i32
%newOp, %result = pdl.operation "foo.op"(%operand) -> %type
%newOp1 = pdl.operation "foo.op2"(%result)
pdl.erase %root
}
}
}
// -----
// CHECK-LABEL: module @operation_operands
module @operation_operands {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[OPERAND:.*]]: !pdl.value, %[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation "foo.op"(%[[OPERAND]])
// CHECK: %[[OPERAND1:.*]] = pdl_interp.get_result 0 of %[[NEWOP]]
// CHECK: pdl_interp.create_operation "foo.op2"(%[[OPERAND1]])
pdl.pattern : benefit(1) {
%operand = pdl.input
%root = pdl.operation "foo.op"(%operand)
pdl.rewrite %root {
%type = pdl.type : i32
%newOp, %result = pdl.operation "foo.op"(%operand) -> %type
%newOp1 = pdl.operation "foo.op2"(%result)
pdl.erase %root
}
}
}
// -----
// CHECK-LABEL: module @operation_result_types
module @operation_result_types {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[TYPE:.*]]: !pdl.type, %[[TYPE1:.*]]: !pdl.type
// CHECK: pdl_interp.create_operation "foo.op"() -> %[[TYPE]], %[[TYPE1]]
pdl.pattern : benefit(1) {
%rootType = pdl.type
%rootType1 = pdl.type
%root, %results:2 = pdl.operation "foo.op" -> %rootType, %rootType1
pdl.rewrite %root {
%newType1 = pdl.type
%newOp, %newResults:2 = pdl.operation "foo.op" -> %rootType, %newType1
pdl.replace %root with %newOp
}
}
}
// -----
// CHECK-LABEL: module @operation_result_types_infer_from_value_replacement
module @operation_result_types_infer_from_value_replacement {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[TYPE:.*]]: !pdl.type
// CHECK: pdl_interp.create_operation "foo.op"() -> %[[TYPE]]
pdl.pattern : benefit(1) {
%rootType = pdl.type
%root, %result = pdl.operation "foo.op" -> %rootType
pdl.rewrite %root {
%newType = pdl.type
%newOp, %newResult = pdl.operation "foo.op" -> %newType
pdl.replace %root with (%newResult)
}
}
}
// -----
// CHECK-LABEL: module @replace_with_op
module @replace_with_op {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation
// CHECK: %[[OP_RESULT:.*]] = pdl_interp.get_result 0 of %[[NEWOP]]
// CHECK: pdl_interp.replace %[[ROOT]] with(%[[OP_RESULT]])
pdl.pattern : benefit(1) {
%type = pdl.type : i32
%root, %result = pdl.operation "foo.op" -> %type
pdl.rewrite %root {
%newOp, %newResult = pdl.operation "foo.op" -> %type
pdl.replace %root with %newOp
}
}
}
// -----
// CHECK-LABEL: module @replace_with_values
module @replace_with_values {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation
// CHECK: %[[OP_RESULT:.*]] = pdl_interp.get_result 0 of %[[NEWOP]]
// CHECK: pdl_interp.replace %[[ROOT]] with(%[[OP_RESULT]])
pdl.pattern : benefit(1) {
%type = pdl.type : i32
%root, %result = pdl.operation "foo.op" -> %type
pdl.rewrite %root {
%newOp, %newResult = pdl.operation "foo.op" -> %type
pdl.replace %root with (%newResult)
}
}
}
// -----
// CHECK-LABEL: module @replace_with_no_results
module @replace_with_no_results {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation)
// CHECK: pdl_interp.create_operation "foo.op"
// CHECK: pdl_interp.erase %[[ROOT]]
pdl.pattern : benefit(1) {
%root = pdl.operation "foo.op"
pdl.rewrite %root {
%newOp = pdl.operation "foo.op"
pdl.replace %root with %newOp
}
}
}
// -----
// CHECK-LABEL: module @create_native
module @create_native {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[TYPE:.*]] = pdl_interp.create_native "functor" [true](%[[ROOT]] : !pdl.operation) : !pdl.type
// CHECK: pdl_interp.create_operation "foo.op"() -> %[[TYPE]]
pdl.pattern : benefit(1) {
%type = pdl.type
%root, %result = pdl.operation "foo.op" -> %type
pdl.rewrite %root {
%newType = pdl.create_native "functor"[true](%root : !pdl.operation) : !pdl.type
%newOp, %newResult = pdl.operation "foo.op" -> %newType
pdl.replace %root with %newOp
}
}
}