mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-30 17:21:10 +00:00
[mlir][CallOpInterface] Add setCalleeFromCallable
method
Currently `CallOpInterface` has a method `getCallableForCallee` to have a consistent way to get the callee from an operation with `CallOpInterface`, but missing a consistent way to set a callee for an operation with `CallOpInterface`. A set callee method is useful for transformations that operate on `CallOpInterface`, and change the callee, e.g., a pass that specialize function, which clone the callee, and change the `CallOpInterface`'s callee to the cloned version. Without such method, transformation would need to understand the implementation for every operations with `CallOpInterface`, and have a type switch to handle them. This review adds a method to set callee for operation with `CallOpInterface`. Reviewed By: gysit, zero9178o Differential Revision: https://reviews.llvm.org/D149763
This commit is contained in:
parent
9fca0313f8
commit
a2ab6a5e2b
@ -2357,6 +2357,14 @@ def fir_CallOp : fir_Op<"call",
|
||||
return calling;
|
||||
return getOperand(0);
|
||||
}
|
||||
|
||||
/// Set the callee for this operation.
|
||||
void setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
|
||||
if (auto calling =
|
||||
(*this)->getAttrOfType<mlir::SymbolRefAttr>(getCalleeAttrName()))
|
||||
(*this)->setAttr(getCalleeAttrName(), callee.get<mlir::SymbolRefAttr>());
|
||||
setOperand(0, callee.get<mlir::Value>());
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -728,6 +728,7 @@ interface section goes as follows:
|
||||
|
||||
* `CallOpInterface` - Used to represent operations like 'call'
|
||||
- `CallInterfaceCallable getCallableForCallee()`
|
||||
- `void setCalleeFromCallable(CallInterfaceCallable)`
|
||||
* `CallableOpInterface` - Used to represent the target callee of call.
|
||||
- `Region * getCallableRegion()`
|
||||
- `ArrayRef<Type> getCallableResults()`
|
||||
|
@ -189,6 +189,12 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
|
||||
return getAttrOfType<SymbolRefAttr>("callee");
|
||||
}
|
||||
|
||||
/// Set the callee for the generic call operation, this is required by the call
|
||||
/// interface.
|
||||
void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
|
||||
(*this)->setAttr("callee", callee.get<SymbolRefAttr>());
|
||||
}
|
||||
|
||||
/// Get the argument operands to the called function, this is required by the
|
||||
/// call interface.
|
||||
Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
|
||||
|
@ -338,6 +338,12 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
|
||||
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
|
||||
}
|
||||
|
||||
/// Set the callee for the generic call operation, this is required by the call
|
||||
/// interface.
|
||||
void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
|
||||
(*this)->setAttr("callee", callee.get<SymbolRefAttr>());
|
||||
}
|
||||
|
||||
/// Get the argument operands to the called function, this is required by the
|
||||
/// call interface.
|
||||
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
|
||||
|
@ -338,6 +338,12 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
|
||||
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
|
||||
}
|
||||
|
||||
/// Set the callee for the generic call operation, this is required by the call
|
||||
/// interface.
|
||||
void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
|
||||
(*this)->setAttr("callee", callee.get<SymbolRefAttr>());
|
||||
}
|
||||
|
||||
/// Get the argument operands to the called function, this is required by the
|
||||
/// call interface.
|
||||
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
|
||||
|
@ -338,6 +338,12 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
|
||||
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
|
||||
}
|
||||
|
||||
/// Set the callee for the generic call operation, this is required by the call
|
||||
/// interface.
|
||||
void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
|
||||
(*this)->setAttr("callee", callee.get<SymbolRefAttr>());
|
||||
}
|
||||
|
||||
/// Get the argument operands to the called function, this is required by the
|
||||
/// call interface.
|
||||
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
|
||||
|
@ -367,6 +367,12 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
|
||||
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
|
||||
}
|
||||
|
||||
/// Set the callee for the generic call operation, this is required by the call
|
||||
/// interface.
|
||||
void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
|
||||
(*this)->setAttr("callee", callee.get<SymbolRefAttr>());
|
||||
}
|
||||
|
||||
/// Get the argument operands to the called function, this is required by the
|
||||
/// call interface.
|
||||
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
|
||||
|
@ -271,6 +271,11 @@ def Async_CallOp : Async_Op<"call",
|
||||
CallInterfaceCallable getCallableForCallee() {
|
||||
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
|
||||
}
|
||||
|
||||
/// Set the callee for this operation.
|
||||
void setCalleeFromCallable(CallInterfaceCallable callee) {
|
||||
(*this)->setAttr("callee", callee.get<SymbolRefAttr>());
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
|
@ -91,6 +91,11 @@ def CallOp : Func_Op<"call",
|
||||
CallInterfaceCallable getCallableForCallee() {
|
||||
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
|
||||
}
|
||||
|
||||
/// Set the callee for this operation.
|
||||
void setCalleeFromCallable(CallInterfaceCallable callee) {
|
||||
(*this)->setAttr("callee", callee.get<SymbolRefAttr>());
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
@ -153,6 +158,11 @@ def CallIndirectOp : Func_Op<"call_indirect", [
|
||||
|
||||
/// Return the callee of this operation.
|
||||
CallInterfaceCallable getCallableForCallee() { return getCallee(); }
|
||||
|
||||
/// Set the callee for this operation.
|
||||
void setCalleeFromCallable(CallInterfaceCallable callee) {
|
||||
setOperand(0, callee.get<Value>());
|
||||
}
|
||||
}];
|
||||
|
||||
let hasCanonicalizeMethod = 1;
|
||||
|
@ -372,6 +372,10 @@ def IncludeOp : TransformDialectOp<"include",
|
||||
return getTarget();
|
||||
}
|
||||
|
||||
void setCalleeFromCallable(::mlir::CallInterfaceCallable callee) {
|
||||
setTargetAttr(callee.get<SymbolRefAttr>());
|
||||
}
|
||||
|
||||
::mlir::Operation::operand_range getArgOperands() {
|
||||
return getOperands();
|
||||
}
|
||||
|
@ -40,6 +40,15 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
|
||||
}],
|
||||
"::mlir::CallInterfaceCallable", "getCallableForCallee"
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
Sets the callee of this call-like operation. A `callee` is either a
|
||||
reference to a symbol, via SymbolRefAttr, or a reference to a defined
|
||||
SSA value. The type of the `callee` is expected to be the same as the
|
||||
return type of `getCallableForCallee`, e.g., `callee` should be
|
||||
SymbolRefAttr for `func.call`.
|
||||
}],
|
||||
"void", "setCalleeFromCallable", (ins "::mlir::CallInterfaceCallable":$callee)
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
Returns the operands within this call that are used as arguments to the
|
||||
callee.
|
||||
|
@ -933,6 +933,16 @@ CallInterfaceCallable CallOp::getCallableForCallee() {
|
||||
return getOperand(0);
|
||||
}
|
||||
|
||||
void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
|
||||
// Direct call.
|
||||
if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
|
||||
auto symRef = callee.get<SymbolRefAttr>();
|
||||
return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
|
||||
}
|
||||
// Indirect call, callee Value is the first operand.
|
||||
return setOperand(0, callee.get<Value>());
|
||||
}
|
||||
|
||||
Operation::operand_range CallOp::getArgOperands() {
|
||||
return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
|
||||
}
|
||||
@ -1157,6 +1167,16 @@ CallInterfaceCallable InvokeOp::getCallableForCallee() {
|
||||
return getOperand(0);
|
||||
}
|
||||
|
||||
void InvokeOp::setCalleeFromCallable(CallInterfaceCallable callee) {
|
||||
// Direct call.
|
||||
if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
|
||||
auto symRef = callee.get<SymbolRefAttr>();
|
||||
return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
|
||||
}
|
||||
// Indirect call, callee Value is the first operand.
|
||||
return setOperand(0, callee.get<Value>());
|
||||
}
|
||||
|
||||
Operation::operand_range InvokeOp::getArgOperands() {
|
||||
return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
|
||||
}
|
||||
|
@ -2576,6 +2576,11 @@ CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
|
||||
return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
|
||||
}
|
||||
|
||||
void spirv::FunctionCallOp::setCalleeFromCallable(
|
||||
CallInterfaceCallable callee) {
|
||||
(*this)->setAttr(kCallee, callee.get<SymbolRefAttr>());
|
||||
}
|
||||
|
||||
Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
|
||||
return getArguments();
|
||||
}
|
||||
|
@ -495,11 +495,18 @@ def ConversionCallOp : TEST_Op<"conversion_call_op",
|
||||
let extraClassDeclaration = [{
|
||||
/// Return the callee of this operation.
|
||||
::mlir::CallInterfaceCallable getCallableForCallee();
|
||||
|
||||
/// Set the callee for this operation.
|
||||
void setCalleeFromCallable(::mlir::CallInterfaceCallable);
|
||||
}];
|
||||
let extraClassDefinition = [{
|
||||
::mlir::CallInterfaceCallable $cppClass::getCallableForCallee() {
|
||||
return (*this)->getAttrOfType<::mlir::SymbolRefAttr>("callee");
|
||||
}
|
||||
|
||||
void $cppClass::setCalleeFromCallable(::mlir::CallInterfaceCallable callee) {
|
||||
(*this)->setAttr("callee", callee.get<SymbolRefAttr>());
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user