mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-27 07:31:28 +00:00
[mlir][transform] Add an op for replacing values with function calls (#78398)
Adds `transform.func.cast_and_call` that takes a set of inputs and outputs and replaces the uses of those outputs with a call to a function at a specified insertion point. The idea with this operation is to allow users to author independent IR outside of a to-be-compiled module, and then match and replace a slice of the program with a call to the external function. Additionally adds a mechanism for populating a type converter with a set of conversion materialization functions that allow insertion of casts on the inputs/outputs to and from the types of the function signature.
This commit is contained in:
parent
0784b1eefa
commit
42b160356f
@ -12,6 +12,8 @@
|
||||
include "mlir/Dialect/Transform/IR/TransformDialect.td"
|
||||
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
|
||||
include "mlir/Dialect/Transform/IR/TransformTypes.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/IR/RegionKindInterface.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def ApplyFuncToLLVMConversionPatternsOp : Op<Transform_Dialect,
|
||||
@ -26,4 +28,74 @@ def ApplyFuncToLLVMConversionPatternsOp : Op<Transform_Dialect,
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def CastAndCallOp : Op<Transform_Dialect,
|
||||
"func.cast_and_call",
|
||||
[DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
AttrSizedOperandSegments,
|
||||
ReportTrackingListenerFailuresOpTrait]
|
||||
# GraphRegionNoTerminator.traits> {
|
||||
let summary = "Casts values to the signature of a function and replaces them "
|
||||
"with a call";
|
||||
let description = [{
|
||||
This transform takes value handles to a set of `inputs` and `outputs` and
|
||||
attempts to cast them to the function signature of the attached function
|
||||
op, then builds a call to the function and replaces the users of the
|
||||
outputs. It is the responsibility of the user to ensure that the slice of
|
||||
the program replaced by this operation makes sense, i.e. there is no
|
||||
verification that the inputs to this operation have any relation to the
|
||||
outputs outside of basic dominance requirements needed for the call.
|
||||
|
||||
The casting materialization functions are specified in the graph region of
|
||||
this op. They must implement the `TypeConverterBuilderOpInterface`. The
|
||||
order of ops within the region is irrelevant.
|
||||
|
||||
The target function can be specified by a symbol name or by a handle to the
|
||||
operation.
|
||||
|
||||
This transform only reads the operand handles and only replaces the users of
|
||||
the outputs with the results of the call. No handles are consumed and no
|
||||
operations are removed. Users are expected to run cleanup separately if
|
||||
desired.
|
||||
|
||||
Warning: The replacement of the uses of the outputs could invalidate certain
|
||||
restricted value handle types (e.g. `transform.block_arg` if it existed, by
|
||||
replacing the use with something not coming from a block argument). The
|
||||
value will still exist in such cases but wouldn't verify against the type.
|
||||
See the discussion here for more information:
|
||||
https://github.com/llvm/llvm-project/pull/78398#discussion_r1455070087
|
||||
|
||||
This transform will emit a silenceable failure if:
|
||||
- The set of outputs isn't unique
|
||||
- The handle for the insertion point does not include exactly one operation
|
||||
- The insertion point op does not dominate any of the output users
|
||||
- The insertion point op is not dominated by any of the inputs
|
||||
- The function signature does not match the number of inputs/outputs
|
||||
|
||||
This transform will emit a definite failure if it fails to resolve the
|
||||
target function, or if it fails to materialize the conversion casts of
|
||||
either the inputs to the function argument types, or the call results to
|
||||
the output types.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TransformHandleTypeInterface:$insertion_point,
|
||||
UnitAttr:$insert_after,
|
||||
Optional<TransformValueHandleTypeInterface>:$inputs,
|
||||
Optional<TransformValueHandleTypeInterface>:$outputs,
|
||||
OptionalAttr<SymbolRefAttr>:$function_name,
|
||||
Optional<TransformHandleTypeInterface>:$function);
|
||||
let results = (outs TransformHandleTypeInterface:$result);
|
||||
let regions = (region MaxSizedRegion<1>:$conversions);
|
||||
|
||||
let assemblyFormat = [{
|
||||
($function_name^)? ($function^)?
|
||||
( `(` $inputs^ `)` )?
|
||||
( `->` $outputs^ )?
|
||||
(`after` $insert_after^):(`before`)? $insertion_point
|
||||
($conversions^)? attr-dict `:` functional-type(operands, results)
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
#endif // FUNC_TRANSFORM_OPS
|
||||
|
@ -18,7 +18,8 @@ include "mlir/IR/OpBase.td"
|
||||
def MemrefToLLVMTypeConverterOp : Op<Transform_Dialect,
|
||||
"apply_conversion_patterns.memref.memref_to_llvm_type_converter",
|
||||
[DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
|
||||
["getTypeConverterType"]>]> {
|
||||
["getTypeConverter",
|
||||
"getTypeConverterType"]>]> {
|
||||
let description = [{
|
||||
This operation provides an "LLVMTypeConverter" that lowers memref types to
|
||||
LLVM types.
|
||||
|
@ -169,4 +169,22 @@ def MakeLoopIndependentOp
|
||||
}];
|
||||
}
|
||||
|
||||
def TypeConversionCastShapeDynamicDimsOp : Op<Transform_Dialect,
|
||||
"type_conversion.tensor.cast_shape_dynamic_dims",
|
||||
[DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
|
||||
["populateTypeMaterializations"]>]> {
|
||||
let description = [{
|
||||
Populates a type converter with conversion materialization functions that
|
||||
cast a tensor value between two cast-compatible tensors. See `tensor.cast`
|
||||
for more information on cast compatibility between tensors.
|
||||
|
||||
If `ignore_dynamic_info` is not set, this will set an additional constraint
|
||||
that source materializations do not cast dynamic dimensions to static ones.
|
||||
}];
|
||||
let arguments = (ins UnitAttr:$ignore_dynamic_info);
|
||||
|
||||
let assemblyFormat =
|
||||
"(`ignore_dynamic_info` $ignore_dynamic_info^)? attr-dict";
|
||||
}
|
||||
|
||||
#endif // TENSOR_TRANSFORM_OPS
|
||||
|
@ -284,8 +284,14 @@ def TypeConverterBuilderOpInterface
|
||||
: OpInterface<"TypeConverterBuilderOpInterface"> {
|
||||
let description = [{
|
||||
This interface should be implemented by ops that specify a type converter
|
||||
for a dialect conversion. Such ops can be used with
|
||||
"apply_conversion_patterns".
|
||||
for a dialect conversion, or to populate a type converter with
|
||||
conversions.
|
||||
|
||||
When such ops are intended to be used with "apply_conversion_patterns" or
|
||||
other operations that expect a type converter, a non-default implementation
|
||||
of `getTypeConverter` should be implemented. For use with "cast_and_call"
|
||||
like ops that construct a type converter iteratively, non-default
|
||||
`populateTypeMaterializations` should be implemented.
|
||||
}];
|
||||
|
||||
let cppNamespace = "::mlir::transform";
|
||||
@ -297,7 +303,11 @@ def TypeConverterBuilderOpInterface
|
||||
}],
|
||||
/*returnType=*/"std::unique_ptr<::mlir::TypeConverter>",
|
||||
/*name=*/"getTypeConverter",
|
||||
/*arguments=*/(ins)
|
||||
/*arguments=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return std::make_unique<::mlir::TypeConverter>();
|
||||
}]
|
||||
>,
|
||||
StaticInterfaceMethod<
|
||||
/*desc=*/[{
|
||||
@ -310,6 +320,17 @@ def TypeConverterBuilderOpInterface
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{ return "TypeConverter"; }]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Populate the given type converter with source/target materialization
|
||||
functions.
|
||||
}],
|
||||
/*returnType=*/"void",
|
||||
/*name=*/"populateTypeMaterializations",
|
||||
/*arguments=*/(ins "::mlir::TypeConverter &":$converter),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{ return; }]
|
||||
>,
|
||||
];
|
||||
}
|
||||
|
||||
|
@ -15,6 +15,7 @@
|
||||
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformOps.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@ -36,6 +37,196 @@ transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CastAndCallOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DiagnosedSilenceableFailure
|
||||
transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
|
||||
transform::TransformResults &results,
|
||||
transform::TransformState &state) {
|
||||
SmallVector<Value> inputs;
|
||||
if (getInputs())
|
||||
llvm::append_range(inputs, state.getPayloadValues(getInputs()));
|
||||
|
||||
SetVector<Value> outputs;
|
||||
if (getOutputs()) {
|
||||
for (auto output : state.getPayloadValues(getOutputs()))
|
||||
outputs.insert(output);
|
||||
|
||||
// Verify that the set of output values to be replaced is unique.
|
||||
if (outputs.size() !=
|
||||
llvm::range_size(state.getPayloadValues(getOutputs()))) {
|
||||
return emitSilenceableFailure(getLoc())
|
||||
<< "cast and call output values must be unique";
|
||||
}
|
||||
}
|
||||
|
||||
// Get the insertion point for the call.
|
||||
auto insertionOps = state.getPayloadOps(getInsertionPoint());
|
||||
if (!llvm::hasSingleElement(insertionOps)) {
|
||||
return emitSilenceableFailure(getLoc())
|
||||
<< "Only one op can be specified as an insertion point";
|
||||
}
|
||||
bool insertAfter = getInsertAfter();
|
||||
Operation *insertionPoint = *insertionOps.begin();
|
||||
|
||||
// Check that all inputs dominate the insertion point, and the insertion
|
||||
// point dominates all users of the outputs.
|
||||
DominanceInfo dom(insertionPoint);
|
||||
for (Value output : outputs) {
|
||||
for (Operation *user : output.getUsers()) {
|
||||
// If we are inserting after the insertion point operation, the
|
||||
// insertion point operation must properly dominate the user. Otherwise
|
||||
// basic dominance is enough.
|
||||
bool doesDominate = insertAfter
|
||||
? dom.properlyDominates(insertionPoint, user)
|
||||
: dom.dominates(insertionPoint, user);
|
||||
if (!doesDominate) {
|
||||
return emitDefiniteFailure()
|
||||
<< "User " << user << " is not dominated by insertion point "
|
||||
<< insertionPoint;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (Value input : inputs) {
|
||||
// If we are inserting before the insertion point operation, the
|
||||
// input must properly dominate the insertion point operation. Otherwise
|
||||
// basic dominance is enough.
|
||||
bool doesDominate = insertAfter
|
||||
? dom.dominates(input, insertionPoint)
|
||||
: dom.properlyDominates(input, insertionPoint);
|
||||
if (!doesDominate) {
|
||||
return emitDefiniteFailure()
|
||||
<< "input " << input << " does not dominate insertion point "
|
||||
<< insertionPoint;
|
||||
}
|
||||
}
|
||||
|
||||
// Get the function to call. This can either be specified by symbol or as a
|
||||
// transform handle.
|
||||
func::FuncOp targetFunction = nullptr;
|
||||
if (getFunctionName()) {
|
||||
targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
|
||||
insertionPoint, *getFunctionName());
|
||||
if (!targetFunction) {
|
||||
return emitDefiniteFailure()
|
||||
<< "unresolved symbol " << *getFunctionName();
|
||||
}
|
||||
} else if (getFunction()) {
|
||||
auto payloadOps = state.getPayloadOps(getFunction());
|
||||
if (!llvm::hasSingleElement(payloadOps)) {
|
||||
return emitDefiniteFailure() << "requires a single function to call";
|
||||
}
|
||||
targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
|
||||
if (!targetFunction) {
|
||||
return emitDefiniteFailure() << "invalid non-function callee";
|
||||
}
|
||||
} else {
|
||||
llvm_unreachable("Invalid CastAndCall op without a function to call");
|
||||
return emitDefiniteFailure();
|
||||
}
|
||||
|
||||
// Verify that the function argument and result lengths match the inputs and
|
||||
// outputs given to this op.
|
||||
if (targetFunction.getNumArguments() != inputs.size()) {
|
||||
return emitSilenceableFailure(targetFunction.getLoc())
|
||||
<< "mismatch between number of function arguments "
|
||||
<< targetFunction.getNumArguments() << " and number of inputs "
|
||||
<< inputs.size();
|
||||
}
|
||||
if (targetFunction.getNumResults() != outputs.size()) {
|
||||
return emitSilenceableFailure(targetFunction.getLoc())
|
||||
<< "mismatch between number of function results "
|
||||
<< targetFunction->getNumResults() << " and number of outputs "
|
||||
<< outputs.size();
|
||||
}
|
||||
|
||||
// Gather all specified converters.
|
||||
mlir::TypeConverter converter;
|
||||
if (!getRegion().empty()) {
|
||||
for (Operation &op : getRegion().front()) {
|
||||
cast<transform::TypeConverterBuilderOpInterface>(&op)
|
||||
.populateTypeMaterializations(converter);
|
||||
}
|
||||
}
|
||||
|
||||
if (insertAfter)
|
||||
rewriter.setInsertionPointAfter(insertionPoint);
|
||||
else
|
||||
rewriter.setInsertionPoint(insertionPoint);
|
||||
|
||||
for (auto [input, type] :
|
||||
llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
|
||||
if (input.getType() != type) {
|
||||
Value newInput = converter.materializeSourceConversion(
|
||||
rewriter, input.getLoc(), type, input);
|
||||
if (!newInput) {
|
||||
return emitDefiniteFailure() << "Failed to materialize conversion of "
|
||||
<< input << " to type " << type;
|
||||
}
|
||||
input = newInput;
|
||||
}
|
||||
}
|
||||
|
||||
auto callOp = rewriter.create<func::CallOp>(insertionPoint->getLoc(),
|
||||
targetFunction, inputs);
|
||||
|
||||
// Cast the call results back to the expected types. If any conversions fail
|
||||
// this is a definite failure as the call has been constructed at this point.
|
||||
for (auto [output, newOutput] :
|
||||
llvm::zip_equal(outputs, callOp.getResults())) {
|
||||
Value convertedOutput = newOutput;
|
||||
if (output.getType() != newOutput.getType()) {
|
||||
convertedOutput = converter.materializeTargetConversion(
|
||||
rewriter, output.getLoc(), output.getType(), newOutput);
|
||||
if (!convertedOutput) {
|
||||
return emitDefiniteFailure()
|
||||
<< "Failed to materialize conversion of " << newOutput
|
||||
<< " to type " << output.getType();
|
||||
}
|
||||
}
|
||||
rewriter.replaceAllUsesExcept(output, convertedOutput, callOp);
|
||||
}
|
||||
results.set(cast<OpResult>(getResult()), {callOp});
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
LogicalResult transform::CastAndCallOp::verify() {
|
||||
if (!getRegion().empty()) {
|
||||
for (Operation &op : getRegion().front()) {
|
||||
if (!isa<transform::TypeConverterBuilderOpInterface>(&op)) {
|
||||
InFlightDiagnostic diag = emitOpError()
|
||||
<< "expected children ops to implement "
|
||||
"TypeConverterBuilderOpInterface";
|
||||
diag.attachNote(op.getLoc()) << "op without interface";
|
||||
return diag;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!getFunction() && !getFunctionName()) {
|
||||
return emitOpError() << "expected a function handle or name to call";
|
||||
}
|
||||
if (getFunction() && getFunctionName()) {
|
||||
return emitOpError() << "function handle and name are mutually exclusive";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
void transform::CastAndCallOp::getEffects(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
transform::onlyReadsHandle(getInsertionPoint(), effects);
|
||||
if (getInputs())
|
||||
transform::onlyReadsHandle(getInputs(), effects);
|
||||
if (getOutputs())
|
||||
transform::onlyReadsHandle(getOutputs(), effects);
|
||||
if (getFunction())
|
||||
transform::onlyReadsHandle(getFunction(), effects);
|
||||
transform::producesHandle(getResult(), effects);
|
||||
transform::modifiesPayload(effects);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Transform op registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -15,6 +15,8 @@
|
||||
#include "mlir/Dialect/Tensor/Utils/Utils.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace tensor;
|
||||
@ -128,6 +130,44 @@ void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
|
||||
tensor::populateRewriteAsConstantPatterns(patterns);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeConversionCastTensorShapeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void transform::TypeConversionCastShapeDynamicDimsOp::
|
||||
populateTypeMaterializations(TypeConverter &converter) {
|
||||
bool ignoreDynamicInfo = getIgnoreDynamicInfo();
|
||||
converter.addSourceMaterialization([ignoreDynamicInfo](
|
||||
OpBuilder &builder, Type resultType,
|
||||
ValueRange inputs,
|
||||
Location loc) -> std::optional<Value> {
|
||||
if (inputs.size() != 1) {
|
||||
return std::nullopt;
|
||||
}
|
||||
Value input = inputs[0];
|
||||
if (!ignoreDynamicInfo &&
|
||||
!tensor::preservesStaticInformation(resultType, input.getType())) {
|
||||
return std::nullopt;
|
||||
}
|
||||
if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return builder.create<tensor::CastOp>(loc, resultType, input).getResult();
|
||||
});
|
||||
converter.addTargetMaterialization([](OpBuilder &builder, Type resultType,
|
||||
ValueRange inputs,
|
||||
Location loc) -> std::optional<Value> {
|
||||
if (inputs.size() != 1) {
|
||||
return std::nullopt;
|
||||
}
|
||||
Value input = inputs[0];
|
||||
if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return builder.create<tensor::CastOp>(loc, resultType, input).getResult();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MakeLoopIndependentOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -16,10 +16,12 @@
|
||||
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Dominance.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Verifier.h"
|
||||
#include "mlir/Interfaces/CallInterfaces.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/Interfaces/FunctionImplementation.h"
|
||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||
@ -30,11 +32,13 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/ScopeExit.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include <optional>
|
||||
|
||||
#define DEBUG_TYPE "transform-dialect"
|
||||
|
120
mlir/test/Dialect/Func/func-transform.mlir
Normal file
120
mlir/test/Dialect/Func/func-transform.mlir
Normal file
@ -0,0 +1,120 @@
|
||||
// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @basic_cast_and_call
|
||||
func.func @basic_cast_and_call() {
|
||||
// CHECK-NEXT: call @second()
|
||||
"test.foo"() : () -> ()
|
||||
// CHECK-NEXT: test.foo
|
||||
// CHECK-NEXT: call @third()
|
||||
func.return
|
||||
}
|
||||
|
||||
func.func @second() {
|
||||
"test.bar"() : () -> ()
|
||||
func.return
|
||||
}
|
||||
|
||||
func.func private @third()
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
|
||||
%funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
%f:3 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
|
||||
%foo = transform.structured.match ops{["test.foo"]} in %f#0 : (!transform.any_op) -> !transform.any_op
|
||||
transform.func.cast_and_call @second before %foo : (!transform.any_op) -> !transform.any_op
|
||||
transform.func.cast_and_call %f#2 after %foo : (!transform.any_op, !transform.any_op) -> !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @non_empty_arg_and_out
|
||||
func.func @non_empty_arg_and_out(%arg0 : index) -> i32 {
|
||||
// CHECK-NEXT: %[[FOO:.+]] = "test.foo"
|
||||
%0 = "test.foo"(%arg0) : (index) -> (index)
|
||||
// CHECK-NEXT: %[[CALL:.+]] = call @second(%[[FOO]]) : (index) -> i32
|
||||
%1 = "test.bar"(%0) : (index) -> (i32)
|
||||
// CHECK: return %[[CALL]] : i32
|
||||
func.return %1 : i32
|
||||
}
|
||||
|
||||
func.func private @second(%arg1 : index) -> i32
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
|
||||
%funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
%f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
%foo = transform.structured.match ops{["test.foo"]} in %f#0 : (!transform.any_op) -> !transform.any_op
|
||||
%bar = transform.structured.match ops{["test.bar"]} in %f#0 : (!transform.any_op) -> !transform.any_op
|
||||
%in = transform.get_result %foo[0] : (!transform.any_op) -> !transform.any_value
|
||||
%out = transform.get_result %bar[0] : (!transform.any_op) -> !transform.any_value
|
||||
transform.func.cast_and_call %f#1(%in) -> %out before %bar
|
||||
: (!transform.any_op, !transform.any_value,
|
||||
!transform.any_value, !transform.any_op) -> !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @multi_arg_and_result
|
||||
func.func @multi_arg_and_result(%arg0 : index) -> (index, index) {
|
||||
// CHECK-NEXT: %[[FOO:.+]] = "test.foo"
|
||||
%0 = "test.foo"(%arg0) : (index) -> (index)
|
||||
%1 = "test.bar"(%0) : (index) -> (index)
|
||||
%2 = "test.bar"(%0) : (index) -> (index)
|
||||
// CHECK: %[[CALL:.+]]:2 = call @second(%[[FOO]], %[[FOO]]) : (index, index) -> (index, index)
|
||||
// CHECK: return %[[CALL]]#0, %[[CALL]]#1 : index, index
|
||||
func.return %1, %2 : index, index
|
||||
}
|
||||
|
||||
func.func private @second(%arg1: index, %arg2: index) -> (index, index)
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
|
||||
%funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
%f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
%foo = transform.structured.match ops{["test.foo"]} in %f#0 : (!transform.any_op) -> !transform.any_op
|
||||
%bars = transform.structured.match ops{["test.bar"]} in %f#0 : (!transform.any_op) -> !transform.any_op
|
||||
%in0 = transform.get_result %foo[0] : (!transform.any_op) -> !transform.any_value
|
||||
%in1 = transform.get_result %foo[0] : (!transform.any_op) -> !transform.any_value
|
||||
%ins = transform.merge_handles %in0, %in1 : !transform.any_value
|
||||
|
||||
%outs = transform.get_result %bars[0] : (!transform.any_op) -> !transform.any_value
|
||||
|
||||
transform.func.cast_and_call %f#1(%ins) -> %outs after %foo
|
||||
: (!transform.any_op, !transform.any_value,
|
||||
!transform.any_value, !transform.any_op) -> !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @nested_call
|
||||
func.func @nested_call() {
|
||||
// CHECK-NEXT: %[[ARG:.+]] = "test.arg"
|
||||
// CHECK-NEXT: test.foo
|
||||
%0 = "test.arg"() : () -> (index)
|
||||
"test.foo"() ({
|
||||
// CHECK-NEXT: call @second(%[[ARG]]) : (index) -> ()
|
||||
"test.bar"(%0) : (index) -> ()
|
||||
}) : () -> ()
|
||||
}
|
||||
|
||||
func.func private @second(%arg1: index) -> ()
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
|
||||
%funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
%f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
%arg = transform.structured.match ops{["test.arg"]} in %f#0 : (!transform.any_op) -> !transform.any_op
|
||||
%bar = transform.structured.match ops{["test.bar"]} in %f#0 : (!transform.any_op) -> !transform.any_op
|
||||
%in = transform.get_result %arg[0] : (!transform.any_op) -> !transform.any_value
|
||||
|
||||
transform.func.cast_and_call %f#1(%in) before %bar
|
||||
: (!transform.any_op, !transform.any_value, !transform.any_op) -> !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
65
mlir/test/Dialect/Tensor/transform-op-casting.mlir
Normal file
65
mlir/test/Dialect/Tensor/transform-op-casting.mlir
Normal file
@ -0,0 +1,65 @@
|
||||
// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file | FileCheck %s
|
||||
|
||||
func.func @cast_to_dynamic(%arg0: tensor<10x13xf32>, %arg1: tensor<3x13xf32>) -> tensor<13x13xf32> {
|
||||
%0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<10x13xf32>, tensor<3x13xf32>) -> tensor<13x13xf32>
|
||||
func.return %0 : tensor<13x13xf32>
|
||||
}
|
||||
|
||||
func.func private @concat_replacement(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
|
||||
%funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
%f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
%concat = transform.structured.match ops{["tensor.concat"]} in %f#0 : (!transform.any_op) -> !transform.any_op
|
||||
%ins = transform.get_operand %concat[all] : (!transform.any_op) -> !transform.any_value
|
||||
%out = transform.get_result %concat[all] : (!transform.any_op) -> !transform.any_value
|
||||
transform.func.cast_and_call %f#1(%ins) -> %out before %concat {
|
||||
transform.type_conversion.tensor.cast_shape_dynamic_dims
|
||||
} : (!transform.any_op, !transform.any_value,
|
||||
!transform.any_value, !transform.any_op) -> !transform.any_op
|
||||
transform.apply_dce to %f#0 : !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @cast_to_dynamic
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<10x13xf32>
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<3x13xf32>
|
||||
// CHECK-DAG: %[[CAST0:.+]] = tensor.cast %[[ARG0]] : tensor<10x13xf32> to tensor<?x?xf32>
|
||||
// CHECK-DAG: %[[CAST1:.+]] = tensor.cast %[[ARG1]] : tensor<3x13xf32> to tensor<?x?xf32>
|
||||
// CHECK: %[[CALL:.+]] = call @concat_replacement(%[[CAST0]], %[[CAST1]])
|
||||
// CHECK: %[[CAST_RES:.+]] = tensor.cast %[[CALL]] : tensor<?x?xf32> to tensor<13x13xf32>
|
||||
// CHECK: return %[[CAST_RES]] : tensor<13x13xf32>
|
||||
|
||||
// -----
|
||||
|
||||
func.func @cast_to_static(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
|
||||
%0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
|
||||
func.return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
func.func private @collapse_replacement(%arg0: tensor<4x5xf32>) -> tensor<20xf32>
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
|
||||
%funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
%f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
%concat = transform.structured.match ops{["tensor.collapse_shape"]} in %f#0 : (!transform.any_op) -> !transform.any_op
|
||||
%ins = transform.get_operand %concat[all] : (!transform.any_op) -> !transform.any_value
|
||||
%out = transform.get_result %concat[all] : (!transform.any_op) -> !transform.any_value
|
||||
transform.func.cast_and_call %f#1(%ins) -> %out before %concat {
|
||||
transform.type_conversion.tensor.cast_shape_dynamic_dims ignore_dynamic_info
|
||||
} : (!transform.any_op, !transform.any_value,
|
||||
!transform.any_value, !transform.any_op) -> !transform.any_op
|
||||
transform.apply_dce to %f#0 : !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @cast_to_static
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
|
||||
// CHECK-DAG: %[[CAST_IN:.+]] = tensor.cast %[[ARG0]] : tensor<?x?xf32> to tensor<4x5xf32>
|
||||
// CHECK: %[[CALL:.+]] = call @collapse_replacement(%[[CAST_IN]])
|
||||
// CHECK: %[[CAST_RES:.+]] = tensor.cast %[[CALL]] : tensor<20xf32> to tensor<?xf32>
|
||||
// CHECK: return %[[CAST_RES]] : tensor<?xf32>
|
@ -502,7 +502,8 @@ def ApplyTestConversionPatternsOp
|
||||
|
||||
def TestTypeConverterOp
|
||||
: Op<Transform_Dialect, "apply_conversion_patterns.transform.test_type_converter",
|
||||
[DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface>]> {
|
||||
[DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
|
||||
["getTypeConverter"]>]> {
|
||||
let arguments = (ins);
|
||||
let results = (outs);
|
||||
let assemblyFormat = "attr-dict";
|
||||
|
Loading…
Reference in New Issue
Block a user