mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-12-04 20:20:54 +00:00
Support NativeCodeCall binding in rewrite pattern.
We are able to bind the result from native function while rewriting pattern. In matching pattern, if we want to get some values back, we can do that by passing parameter as return value placeholder. Besides, add the semantic of '$_self' in NativeCodeCall while matching, it'll be the operation that defines certain operand. Differential Revision: https://reviews.llvm.org/D100746
This commit is contained in:
parent
75f74f2673
commit
34b5482b33
@ -392,26 +392,31 @@ placeholder_.
|
||||
* `$_builder` will be replaced by the current `mlir::PatternRewriter`.
|
||||
* `$_loc` will be replaced by the fused location or custom location (as
|
||||
determined by location directive).
|
||||
* `$_self` will be replaced with the entity `NativeCodeCall` is attached to.
|
||||
* `$_self` will be replaced by the defining operation in a source pattern.
|
||||
|
||||
We have seen how `$_builder` can be used in the above; it allows us to pass a
|
||||
`mlir::Builder` (`mlir::PatternRewriter` is a subclass of `mlir::OpBuilder`,
|
||||
which is a subclass of `mlir::Builder`) to the C++ helper function to use the
|
||||
handy methods on `mlir::Builder`.
|
||||
|
||||
`$_self` is useful when we want to write something in the form of
|
||||
`NativeCodeCall<"...">:$symbol`. For example, if we want to reverse the previous
|
||||
example and decompose the array attribute into two attributes:
|
||||
Here's an example how we should use `$_self` in source pattern,
|
||||
|
||||
```tablegen
|
||||
class getNthAttr<int n> : NativeCodeCall<"$_self[" # n # "]">;
|
||||
|
||||
def : Pat<(OneAttrOp $attr),
|
||||
(TwoAttrOp (getNthAttr<0>:$attr), (getNthAttr<1>:$attr)>;
|
||||
def : Pat<(OneAttrOp (NativeCodeCall<"Foo($_self, &$0)"> I32Attr:$val)),
|
||||
(TwoAttrOp $val, $val)>;
|
||||
```
|
||||
|
||||
In the above, `$_self` is substituted by the attribute bound by `$attr`, which
|
||||
is `OneAttrOp`'s array attribute.
|
||||
In the above, `$_self` is substituted by the defining operation of the first
|
||||
operand of OneAttrOp. Note that we don't support binding name to NativeCodeCall
|
||||
in the source pattern. To carry some return values from helper function, put the
|
||||
names (constraint is optional) in the parameter list and they will be bound to
|
||||
the variables with correspoding type. Then these named must be either passed by
|
||||
reference or a pointer to variable used as argument so that the matched value
|
||||
can be returned. In the same example, `$val` will be bound to a variable with
|
||||
`Attribute` type(as `I32Attr`) and the type of the second argument in Foo()
|
||||
could be `Attribute&` or `Attribute*`. Names with attribute constraints will be
|
||||
captured as Attributes while everything else will be treated as Value.
|
||||
|
||||
Positional placeholders will be substituted by the `dag` object parameters at
|
||||
the `NativeCodeCall` use site. For example, if we define `SomeCall :
|
||||
|
@ -2530,9 +2530,9 @@ class Pat<dag pattern, dag result, list<dag> preds = [],
|
||||
// the wrapped expression can take special placeholders listed below:
|
||||
//
|
||||
// * `$_builder` will be replaced by the current `mlir::PatternRewriter`.
|
||||
// * `$_self` will be replaced with the entity this transformer is attached to.
|
||||
// E.g., with the definition `def transform : NativeCodeCall<"$_self...">`,
|
||||
// `$_self` in `transform:$attr` will be replaced by the value for `$attr`.
|
||||
// * `$_self` will be replaced by the defining operation in a source pattern.
|
||||
// E.g., `NativeCodeCall<"Foo($_self, &$0)> I32Attr:$attr)>`, `$_self` will be
|
||||
// replaced with the defining operation of the first operand of OneArgOp.
|
||||
//
|
||||
// If used as a DAG node, i.e., `(NativeCodeCall<"..."> <arg0>, ..., <argN>)`,
|
||||
// then positional placeholders are also supported; placeholder `$N` in the
|
||||
@ -2542,7 +2542,7 @@ class NativeCodeCall<string expr> {
|
||||
string expression = expr;
|
||||
}
|
||||
|
||||
def ConstantLikeMatcher : NativeCodeCall<"success(matchPattern($0->getResult(0), m_Constant(&$1)))">;
|
||||
def ConstantLikeMatcher : NativeCodeCall<"success(matchPattern($_self->getResult(0), m_Constant(&$0)))">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Rewrite directives
|
||||
|
@ -232,7 +232,7 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
|
||||
getVarName(name)));
|
||||
}
|
||||
case Kind::Value: {
|
||||
return std::string(formatv("::llvm::ArrayRef<::mlir::Value> {0};\n", name));
|
||||
return std::string(formatv("::mlir::Value {0};\n", name));
|
||||
}
|
||||
case Kind::Result: {
|
||||
// Use the op itself for captured results.
|
||||
@ -626,11 +626,16 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
|
||||
|
||||
if (tree.isNativeCodeCall()) {
|
||||
if (!treeName.empty()) {
|
||||
PrintFatalError(
|
||||
&def,
|
||||
formatv(
|
||||
"binding symbol '{0}' to native code call unsupported right now",
|
||||
treeName));
|
||||
if (!isSrcPattern) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: "
|
||||
<< treeName << '\n');
|
||||
verifyBind(infoMap.bindValue(treeName), treeName);
|
||||
} else {
|
||||
PrintFatalError(&def,
|
||||
formatv("binding symbol '{0}' to NativecodeCall in "
|
||||
"MatchPattern is not supported",
|
||||
treeName));
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i != numTreeArgs; ++i) {
|
||||
@ -649,24 +654,27 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
|
||||
|
||||
// `$_` is a special symbol meaning ignore the current argument.
|
||||
if (!treeArgName.empty() && treeArgName != "_") {
|
||||
if (tree.isNestedDagArg(i)) {
|
||||
auto err = formatv("cannot bind '{0}' for nested native call arg",
|
||||
treeArgName);
|
||||
PrintFatalError(&def, err);
|
||||
}
|
||||
|
||||
DagLeaf leaf = tree.getArgAsLeaf(i);
|
||||
auto constraint = leaf.getAsConstraint();
|
||||
bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
|
||||
leaf.isConstantAttr() ||
|
||||
constraint.getKind() == Constraint::Kind::CK_Attr;
|
||||
|
||||
if (isAttr) {
|
||||
verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
|
||||
continue;
|
||||
// In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
|
||||
if (leaf.isUnspecified()) {
|
||||
// This is case of $c, a Value without any constraints.
|
||||
verifyBind(infoMap.bindValue(treeArgName), treeArgName);
|
||||
} else {
|
||||
auto constraint = leaf.getAsConstraint();
|
||||
bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
|
||||
leaf.isConstantAttr() ||
|
||||
constraint.getKind() == Constraint::Kind::CK_Attr;
|
||||
|
||||
if (isAttr) {
|
||||
// This is case of $a, a binding to a certain attribute.
|
||||
verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
|
||||
continue;
|
||||
}
|
||||
|
||||
// This is case of $b, a binding to a certain type.
|
||||
verifyBind(infoMap.bindValue(treeArgName), treeArgName);
|
||||
}
|
||||
|
||||
verifyBind(infoMap.bindValue(treeArgName), treeArgName);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -837,6 +837,20 @@ def : Pattern<(OpNativeCodeCall3 $input),
|
||||
[(NativeCodeCall<"createOpI($_builder, $_loc, $0)"> $input),
|
||||
(OpK)]>;
|
||||
|
||||
def OpNativeCodeCall4 : TEST_Op<"native_code_call4"> {
|
||||
let arguments = (ins AnyType:$input1);
|
||||
let results = (outs I32:$output1, I32:$output2);
|
||||
}
|
||||
def OpNativeCodeCall5 : TEST_Op<"native_code_call5"> {
|
||||
let arguments = (ins I32:$input1, I32:$input2);
|
||||
let results = (outs I32:$output1, I32:$output2);
|
||||
}
|
||||
|
||||
def GetFirstI32Result : NativeCodeCall<"success(getFirstI32Result($_self, $0))">;
|
||||
def BindNativeCodeCallResult : NativeCodeCall<"bindNativeCodeCallResult($0)">;
|
||||
def : Pat<(OpNativeCodeCall4 (GetFirstI32Result $ret)),
|
||||
(OpNativeCodeCall5 (BindNativeCodeCallResult:$native $ret), $native)>;
|
||||
|
||||
// Test AllAttrConstraintsOf.
|
||||
def OpAllAttrConstraint1 : TEST_Op<"all_attr_constraint_of1"> {
|
||||
let arguments = (ins I64ArrayAttr:$attr);
|
||||
|
@ -35,6 +35,15 @@ static void handleNoResultOp(PatternRewriter &rewriter,
|
||||
op.operand());
|
||||
}
|
||||
|
||||
static bool getFirstI32Result(Operation *op, Value &value) {
|
||||
if (!Type(op->getResult(0).getType()).isSignlessInteger(32))
|
||||
return false;
|
||||
value = op->getResult(0);
|
||||
return true;
|
||||
}
|
||||
|
||||
static Value bindNativeCodeCallResult(Value value) { return value; }
|
||||
|
||||
// Test that natives calls are only called once during rewrites.
|
||||
// OpM_Test will return Pi, increased by 1 for each subsequent calls.
|
||||
// This let us check the number of times OpM_Test was called by inspecting
|
||||
|
@ -88,6 +88,20 @@ func @verifyAuxiliaryNativeCodeCall(%arg0: i32) -> (i32) {
|
||||
return %0 : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: verifyNativeCodeCallBinding
|
||||
func @verifyNativeCodeCallBinding(%arg0 : i32) -> (i32) {
|
||||
%0 = "test.op_k"() : () -> (i32)
|
||||
// CHECK: %[[A:.*]], %[[B:.*]] = "test.native_code_call5"(%1, %1) : (i32, i32) -> (i32, i32)
|
||||
%1, %2 = "test.native_code_call4"(%0) : (i32) -> (i32, i32)
|
||||
%3 = "test.constant"() {value = 1 : i8} : () -> i8
|
||||
// %3 is i8 so it'll fail at GetFirstI32Result match. The operation should
|
||||
// keep the same form.
|
||||
// CHECK: %{{.*}}, %{{.*}} = "test.native_code_call4"({{%.*}}) : (i8) -> (i32, i32)
|
||||
%4, %5 = "test.native_code_call4"(%3) : (i8) -> (i32, i32)
|
||||
// CHECK: return %[[A]]
|
||||
return %1 : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: verifyAllAttrConstraintOf
|
||||
func @verifyAllAttrConstraintOf() -> (i32, i32, i32) {
|
||||
// CHECK: "test.all_attr_constraint_of2"
|
||||
|
@ -1,5 +1,6 @@
|
||||
// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR1 %s 2>&1 | FileCheck --check-prefix=ERROR1 %s
|
||||
// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR2 %s 2>&1 | FileCheck --check-prefix=ERROR2 %s
|
||||
// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR3 %s 2>&1 | FileCheck --check-prefix=ERROR3 %s
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
@ -16,14 +17,21 @@ def OpB : A_Op<"op_b">, Arguments<(ins AnyInteger, AnyAttr:$value)>, Results<(ou
|
||||
|
||||
#ifdef ERROR1
|
||||
def NativeMatcher : NativeCodeCall<"success(nativeCall($0, $1))">;
|
||||
// ERROR1: [[@LINE+1]]:1: error: binding symbol 'error' to native code call unsupported right now
|
||||
def : Pat<(OpA (NativeMatcher:$error $val), AnyI32Attr:$arg),
|
||||
// ERROR1: [[@LINE+1]]:1: error: NativeCodeCall must have $_self as argument for passing the defining Operation
|
||||
def : Pat<(OpA (NativeMatcher $val), AnyI32Attr:$arg),
|
||||
(OpB $val, $arg)>;
|
||||
#endif
|
||||
|
||||
#ifdef ERROR2
|
||||
def NativeMatcher : NativeCodeCall<"success(nativeCall($0, $1))">;
|
||||
// ERROR2: [[@LINE+1]]:1: error: Matching nested tree in NativeCodecall not support for
|
||||
def NativeMatcher : NativeCodeCall<"success(nativeCall($_self, &$0))">;
|
||||
// ERROR2: [[@LINE+1]]:1: error: binding symbol 'error' to NativecodeCall in MatchPattern is not supported
|
||||
def : Pat<(OpA (NativeMatcher:$error $val), AnyI32Attr:$arg),
|
||||
(OpB $val, $arg)>;
|
||||
#endif
|
||||
|
||||
#ifdef ERROR3
|
||||
def NativeMatcher : NativeCodeCall<"success(nativeCall($_self, $0, $1))">;
|
||||
// ERROR3: [[@LINE+1]]:1: error: Matching nested tree in NativeCodecall not support for
|
||||
def : Pat<(OpA (NativeMatcher (OpB $val, $unused)), AnyI32Attr:$arg),
|
||||
(OpB $val, $arg)>;
|
||||
#endif
|
||||
|
@ -252,7 +252,6 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
|
||||
// TODO(suderman): iterate through arguments, determine their types, output
|
||||
// names.
|
||||
SmallVector<std::string, 8> capture;
|
||||
capture.push_back(opName.str());
|
||||
|
||||
raw_indented_ostream::DelimitedScope scope(os);
|
||||
|
||||
@ -265,8 +264,8 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
|
||||
auto leaf = tree.getArgAsLeaf(i);
|
||||
if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
|
||||
os << "Attribute " << argName << ";\n";
|
||||
} else if (leaf.isOperandMatcher()) {
|
||||
os << "Operation " << argName << ";\n";
|
||||
} else {
|
||||
os << "Value " << argName << ";\n";
|
||||
}
|
||||
}
|
||||
|
||||
@ -278,20 +277,25 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
|
||||
std::tie(hasLocationDirective, locToUse) = getLocation(tree);
|
||||
|
||||
auto fmt = tree.getNativeCodeTemplate();
|
||||
auto nativeCodeCall =
|
||||
std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), capture));
|
||||
if (fmt.count("$_self") != 1) {
|
||||
PrintFatalError(loc, "NativeCodeCall must have $_self as argument for "
|
||||
"passing the defining Operation");
|
||||
}
|
||||
|
||||
auto nativeCodeCall = std::string(tgfmt(
|
||||
fmt, &fmtCtx.addSubst("_loc", locToUse).withSelf(opName.str()), capture));
|
||||
|
||||
os << "if (failed(" << nativeCodeCall << ")) return ::mlir::failure();\n";
|
||||
|
||||
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
|
||||
auto name = tree.getArgName(i);
|
||||
if (!name.empty() && name != "_") {
|
||||
os << formatv("{0} = {1};\n", name, capture[i + 1]);
|
||||
os << formatv("{0} = {1};\n", name, capture[i]);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
|
||||
std::string argName = capture[i + 1];
|
||||
std::string argName = capture[i];
|
||||
|
||||
// Handle nested DAG construct first
|
||||
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
|
||||
@ -302,9 +306,18 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
|
||||
}
|
||||
|
||||
DagLeaf leaf = tree.getArgAsLeaf(i);
|
||||
|
||||
// The parameter for native function doesn't bind any constraints.
|
||||
if (leaf.isUnspecified())
|
||||
continue;
|
||||
|
||||
auto constraint = leaf.getAsConstraint();
|
||||
|
||||
auto self = formatv("{0}", argName);
|
||||
std::string self;
|
||||
if (leaf.isAttrMatcher() || leaf.isConstantAttr())
|
||||
self = argName;
|
||||
else
|
||||
self = formatv("{0}.getType()", argName);
|
||||
emitMatchCheck(
|
||||
opName,
|
||||
tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
|
||||
@ -362,6 +375,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
|
||||
os << "{\n";
|
||||
|
||||
// Attributes don't count for getODSOperands.
|
||||
// TODO: Operand is a Value, check if we should remove `getDefiningOp()`.
|
||||
os.indent() << formatv(
|
||||
"auto *{0} = "
|
||||
"(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
|
||||
@ -929,7 +943,13 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
|
||||
<< " replacement: " << attrs[i] << "\n");
|
||||
}
|
||||
|
||||
return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs));
|
||||
std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs);
|
||||
if (!tree.getSymbol().empty()) {
|
||||
os << formatv("auto {0} = {1};\n", tree.getSymbol(), symbol);
|
||||
symbol = tree.getSymbol().str();
|
||||
}
|
||||
|
||||
return symbol;
|
||||
}
|
||||
|
||||
int PatternEmitter::getNodeValueCount(DagNode node) {
|
||||
|
Loading…
Reference in New Issue
Block a user