[mlir][transform] Add an op for replacing values with function calls (#78398)

Adds `transform.func.cast_and_call` that takes a set of inputs and
outputs and replaces the uses of those outputs with a call to a function
at a specified insertion point.

The idea with this operation is to allow users to author independent IR
outside of a to-be-compiled module, and then match and replace a slice
of the program with a call to the external function.

Additionally adds a mechanism for populating a type converter with a set of
conversion materialization functions, which allow casts to be inserted
between the inputs/outputs and the types of the function signature.
Quinn Dawkins 2024-01-19 10:21:52 -08:00 committed by GitHub
parent 0784b1eefa
commit 42b160356f
10 changed files with 538 additions and 5 deletions
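For orientation, here is a minimal sketch of the intended usage, closely mirroring the tests added in this commit (the payload ops, the callee `@second`, and the handle names are illustrative; like the tests, it uses an unregistered `test.foo` op):

```mlir
// Payload IR: a function containing the op used as the insertion point, and
// an externally authored callee.
func.func @payload() {
  "test.foo"() : () -> ()
  func.return
}
func.func private @second()

// Transform script: build a call to @second immediately before "test.foo".
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
    %foo = transform.structured.match ops{["test.foo"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
    transform.func.cast_and_call @second before %foo
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
```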


@@ -12,6 +12,8 @@
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/IR/OpBase.td"
def ApplyFuncToLLVMConversionPatternsOp : Op<Transform_Dialect,
@@ -26,4 +28,74 @@ def ApplyFuncToLLVMConversionPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
def CastAndCallOp : Op<Transform_Dialect,
"func.cast_and_call",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
AttrSizedOperandSegments,
ReportTrackingListenerFailuresOpTrait]
# GraphRegionNoTerminator.traits> {
let summary = "Casts values to the signature of a function and replaces them "
"with a call";
let description = [{
This transform takes value handles to a set of `inputs` and `outputs` and
attempts to cast them to the function signature of the attached function
op, then builds a call to the function and replaces the users of the
outputs. It is the responsibility of the user to ensure that the slice of
the program replaced by this operation makes sense, i.e. there is no
verification that the inputs to this operation have any relation to the
outputs outside of basic dominance requirements needed for the call.
The casting materialization functions are specified in the graph region of
this op. They must implement the `TypeConverterBuilderOpInterface`. The
order of ops within the region is irrelevant.
The target function can be specified by a symbol name or by a handle to the
operation.
This transform only reads the operand handles and only replaces the users of
the outputs with the results of the call. No handles are consumed and no
operations are removed. Users are expected to run cleanup separately if
desired.
Warning: The replacement of the uses of the outputs could invalidate certain
restricted value handle types (e.g. `transform.block_arg` if it existed, by
replacing the use with something not coming from a block argument). The
value will still exist in such cases but wouldn't verify against the type.
See the discussion here for more information:
https://github.com/llvm/llvm-project/pull/78398#discussion_r1455070087
This transform will emit a silenceable failure if:
- The set of outputs isn't unique
- The handle for the insertion point does not include exactly one operation
- The insertion point op does not dominate one or more users of the outputs
- The insertion point op is not dominated by one or more of the inputs
- The function signature does not match the number of inputs/outputs
This transform will emit a definite failure if it fails to resolve the
target function, or if it fails to materialize the conversion casts of
either the inputs to the function argument types, or the call results to
the output types.
}];
let arguments = (ins
TransformHandleTypeInterface:$insertion_point,
UnitAttr:$insert_after,
Optional<TransformValueHandleTypeInterface>:$inputs,
Optional<TransformValueHandleTypeInterface>:$outputs,
OptionalAttr<SymbolRefAttr>:$function_name,
Optional<TransformHandleTypeInterface>:$function);
let results = (outs TransformHandleTypeInterface:$result);
let regions = (region MaxSizedRegion<1>:$conversions);
let assemblyFormat = [{
($function_name^)? ($function^)?
( `(` $inputs^ `)` )?
( `->` $outputs^ )?
(`after` $insert_after^):(`before`)? $insertion_point
($conversions^)? attr-dict `:` functional-type(operands, results)
}];
let hasVerifier = 1;
}
#endif // FUNC_TRANSFORM_OPS
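To make the custom assembly format above concrete, here is a sketch of the handle-based form with explicit inputs and outputs, mirroring the tests added in this commit; the `%replacement`, `%in`, `%out`, and `%anchor` handles are illustrative and would be produced by earlier match/get_result ops:

```mlir
// Replace the users of %out with the result of a call to the function held in
// %replacement, casting %in to the callee's argument type if necessary. The
// trailing functional type lists the operand handle types and the result type.
%call = transform.func.cast_and_call %replacement(%in) -> %out before %anchor
    : (!transform.any_op, !transform.any_value,
       !transform.any_value, !transform.any_op) -> !transform.any_op
```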


@@ -18,7 +18,8 @@ include "mlir/IR/OpBase.td"
def MemrefToLLVMTypeConverterOp : Op<Transform_Dialect,
"apply_conversion_patterns.memref.memref_to_llvm_type_converter",
[DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
["getTypeConverterType"]>]> {
["getTypeConverter",
"getTypeConverterType"]>]> {
let description = [{
This operation provides an "LLVMTypeConverter" that lowers memref types to
LLVM types.


@@ -169,4 +169,22 @@ def MakeLoopIndependentOp
}];
}
def TypeConversionCastShapeDynamicDimsOp : Op<Transform_Dialect,
"type_conversion.tensor.cast_shape_dynamic_dims",
[DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
["populateTypeMaterializations"]>]> {
let description = [{
Populates a type converter with conversion materialization functions that
cast a tensor value between two cast-compatible tensors. See `tensor.cast`
for more information on cast compatibility between tensors.
If `ignore_dynamic_info` is not set, this will set an additional constraint
that source materializations do not cast dynamic dimensions to static ones.
}];
let arguments = (ins UnitAttr:$ignore_dynamic_info);
let assemblyFormat =
"(`ignore_dynamic_info` $ignore_dynamic_info^)? attr-dict";
}
#endif // TENSOR_TRANSFORM_OPS
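As a usage sketch (mirroring the test added in this commit), the op is placed in the conversions region of `transform.func.cast_and_call`, where it populates the type converter used to materialize casts; the `%replacement`, `%ins`, `%outs`, and `%anchor` handles are illustrative:

```mlir
// tensor.cast ops are materialized wherever a matched value and the callee
// signature have differing but cast-compatible tensor types.
%call = transform.func.cast_and_call %replacement(%ins) -> %outs before %anchor {
  transform.type_conversion.tensor.cast_shape_dynamic_dims
} : (!transform.any_op, !transform.any_value,
     !transform.any_value, !transform.any_op) -> !transform.any_op
```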


@@ -284,8 +284,14 @@ def TypeConverterBuilderOpInterface
: OpInterface<"TypeConverterBuilderOpInterface"> {
let description = [{
This interface should be implemented by ops that specify a type converter
-for a dialect conversion. Such ops can be used with
-"apply_conversion_patterns".
+for a dialect conversion, or to populate a type converter with
+conversions.
+
+When such ops are intended to be used with "apply_conversion_patterns" or
+other operations that expect a type converter, a non-default
+`getTypeConverter` should be provided. For use with "cast_and_call"-like ops
+that construct a type converter iteratively, a non-default
+`populateTypeMaterializations` should be implemented.
}];
let cppNamespace = "::mlir::transform";
@@ -297,7 +303,11 @@ def TypeConverterBuilderOpInterface
}],
/*returnType=*/"std::unique_ptr<::mlir::TypeConverter>",
/*name=*/"getTypeConverter",
-/*arguments=*/(ins)
+/*arguments=*/(ins),
+/*methodBody=*/"",
+/*defaultImplementation=*/[{
+  return std::make_unique<::mlir::TypeConverter>();
+}]
>,
StaticInterfaceMethod<
/*desc=*/[{
@@ -310,6 +320,17 @@ def TypeConverterBuilderOpInterface
/*methodBody=*/"",
/*defaultImplementation=*/[{ return "TypeConverter"; }]
>,
InterfaceMethod<
/*desc=*/[{
Populate the given type converter with source/target materialization
functions.
}],
/*returnType=*/"void",
/*name=*/"populateTypeMaterializations",
/*arguments=*/(ins "::mlir::TypeConverter &":$converter),
/*methodBody=*/"",
/*defaultImplementation=*/[{ return; }]
>,
];
}


@@ -15,6 +15,7 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
@@ -36,6 +37,196 @@ transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
return success();
}
//===----------------------------------------------------------------------===//
// CastAndCallOp
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
SmallVector<Value> inputs;
if (getInputs())
llvm::append_range(inputs, state.getPayloadValues(getInputs()));
SetVector<Value> outputs;
if (getOutputs()) {
for (auto output : state.getPayloadValues(getOutputs()))
outputs.insert(output);
// Verify that the set of output values to be replaced is unique.
if (outputs.size() !=
llvm::range_size(state.getPayloadValues(getOutputs()))) {
return emitSilenceableFailure(getLoc())
<< "cast and call output values must be unique";
}
}
// Get the insertion point for the call.
auto insertionOps = state.getPayloadOps(getInsertionPoint());
if (!llvm::hasSingleElement(insertionOps)) {
return emitSilenceableFailure(getLoc())
<< "Only one op can be specified as an insertion point";
}
bool insertAfter = getInsertAfter();
Operation *insertionPoint = *insertionOps.begin();
// Check that all inputs dominate the insertion point, and the insertion
// point dominates all users of the outputs.
DominanceInfo dom(insertionPoint);
for (Value output : outputs) {
for (Operation *user : output.getUsers()) {
// If we are inserting after the insertion point operation, the
// insertion point operation must properly dominate the user. Otherwise
// basic dominance is enough.
bool doesDominate = insertAfter
? dom.properlyDominates(insertionPoint, user)
: dom.dominates(insertionPoint, user);
if (!doesDominate) {
return emitDefiniteFailure()
<< "User " << user << " is not dominated by insertion point "
<< insertionPoint;
}
}
}
for (Value input : inputs) {
// If we are inserting before the insertion point operation, the
// input must properly dominate the insertion point operation. Otherwise
// basic dominance is enough.
bool doesDominate = insertAfter
? dom.dominates(input, insertionPoint)
: dom.properlyDominates(input, insertionPoint);
if (!doesDominate) {
return emitDefiniteFailure()
<< "input " << input << " does not dominate insertion point "
<< insertionPoint;
}
}
// Get the function to call. This can either be specified by symbol or as a
// transform handle.
func::FuncOp targetFunction = nullptr;
if (getFunctionName()) {
targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
insertionPoint, *getFunctionName());
if (!targetFunction) {
return emitDefiniteFailure()
<< "unresolved symbol " << *getFunctionName();
}
} else if (getFunction()) {
auto payloadOps = state.getPayloadOps(getFunction());
if (!llvm::hasSingleElement(payloadOps)) {
return emitDefiniteFailure() << "requires a single function to call";
}
targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
if (!targetFunction) {
return emitDefiniteFailure() << "invalid non-function callee";
}
} else {
llvm_unreachable("Invalid CastAndCall op without a function to call");
return emitDefiniteFailure();
}
// Verify that the function argument and result lengths match the inputs and
// outputs given to this op.
if (targetFunction.getNumArguments() != inputs.size()) {
return emitSilenceableFailure(targetFunction.getLoc())
<< "mismatch between number of function arguments "
<< targetFunction.getNumArguments() << " and number of inputs "
<< inputs.size();
}
if (targetFunction.getNumResults() != outputs.size()) {
return emitSilenceableFailure(targetFunction.getLoc())
<< "mismatch between number of function results "
<< targetFunction->getNumResults() << " and number of outputs "
<< outputs.size();
}
// Gather all specified converters.
mlir::TypeConverter converter;
if (!getRegion().empty()) {
for (Operation &op : getRegion().front()) {
cast<transform::TypeConverterBuilderOpInterface>(&op)
.populateTypeMaterializations(converter);
}
}
if (insertAfter)
rewriter.setInsertionPointAfter(insertionPoint);
else
rewriter.setInsertionPoint(insertionPoint);
for (auto [input, type] :
llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
if (input.getType() != type) {
Value newInput = converter.materializeSourceConversion(
rewriter, input.getLoc(), type, input);
if (!newInput) {
return emitDefiniteFailure() << "Failed to materialize conversion of "
<< input << " to type " << type;
}
input = newInput;
}
}
auto callOp = rewriter.create<func::CallOp>(insertionPoint->getLoc(),
targetFunction, inputs);
// Cast the call results back to the expected types. If any conversions fail
// this is a definite failure as the call has been constructed at this point.
for (auto [output, newOutput] :
llvm::zip_equal(outputs, callOp.getResults())) {
Value convertedOutput = newOutput;
if (output.getType() != newOutput.getType()) {
convertedOutput = converter.materializeTargetConversion(
rewriter, output.getLoc(), output.getType(), newOutput);
if (!convertedOutput) {
return emitDefiniteFailure()
<< "Failed to materialize conversion of " << newOutput
<< " to type " << output.getType();
}
}
rewriter.replaceAllUsesExcept(output, convertedOutput, callOp);
}
results.set(cast<OpResult>(getResult()), {callOp});
return DiagnosedSilenceableFailure::success();
}
LogicalResult transform::CastAndCallOp::verify() {
if (!getRegion().empty()) {
for (Operation &op : getRegion().front()) {
if (!isa<transform::TypeConverterBuilderOpInterface>(&op)) {
InFlightDiagnostic diag = emitOpError()
<< "expected children ops to implement "
"TypeConverterBuilderOpInterface";
diag.attachNote(op.getLoc()) << "op without interface";
return diag;
}
}
}
if (!getFunction() && !getFunctionName()) {
return emitOpError() << "expected a function handle or name to call";
}
if (getFunction() && getFunctionName()) {
return emitOpError() << "function handle and name are mutually exclusive";
}
return success();
}
void transform::CastAndCallOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getInsertionPoint(), effects);
if (getInputs())
transform::onlyReadsHandle(getInputs(), effects);
if (getOutputs())
transform::onlyReadsHandle(getOutputs(), effects);
if (getFunction())
transform::onlyReadsHandle(getFunction(), effects);
transform::producesHandle(getResult(), effects);
transform::modifiesPayload(effects);
}
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//


@@ -15,6 +15,8 @@
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace tensor;
@@ -128,6 +130,44 @@ void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
tensor::populateRewriteAsConstantPatterns(patterns);
}
//===----------------------------------------------------------------------===//
// TypeConversionCastShapeDynamicDimsOp
//===----------------------------------------------------------------------===//
void transform::TypeConversionCastShapeDynamicDimsOp::
populateTypeMaterializations(TypeConverter &converter) {
bool ignoreDynamicInfo = getIgnoreDynamicInfo();
converter.addSourceMaterialization([ignoreDynamicInfo](
OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
if (inputs.size() != 1) {
return std::nullopt;
}
Value input = inputs[0];
if (!ignoreDynamicInfo &&
!tensor::preservesStaticInformation(resultType, input.getType())) {
return std::nullopt;
}
if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
return std::nullopt;
}
return builder.create<tensor::CastOp>(loc, resultType, input).getResult();
});
converter.addTargetMaterialization([](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
if (inputs.size() != 1) {
return std::nullopt;
}
Value input = inputs[0];
if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
return std::nullopt;
}
return builder.create<tensor::CastOp>(loc, resultType, input).getResult();
});
}
//===----------------------------------------------------------------------===//
// MakeLoopIndependentOp
//===----------------------------------------------------------------------===//


@@ -16,10 +16,12 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
@@ -30,11 +32,13 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>
#define DEBUG_TYPE "transform-dialect"


@@ -0,0 +1,120 @@
// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file | FileCheck %s
// CHECK-LABEL: func.func @basic_cast_and_call
func.func @basic_cast_and_call() {
// CHECK-NEXT: call @second()
"test.foo"() : () -> ()
// CHECK-NEXT: test.foo
// CHECK-NEXT: call @third()
func.return
}
func.func @second() {
"test.bar"() : () -> ()
func.return
}
func.func private @third()
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
%funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%f:3 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
%foo = transform.structured.match ops{["test.foo"]} in %f#0 : (!transform.any_op) -> !transform.any_op
transform.func.cast_and_call @second before %foo : (!transform.any_op) -> !transform.any_op
transform.func.cast_and_call %f#2 after %foo : (!transform.any_op, !transform.any_op) -> !transform.any_op
transform.yield
}
}
// -----
// CHECK-LABEL: func.func @non_empty_arg_and_out
func.func @non_empty_arg_and_out(%arg0 : index) -> i32 {
// CHECK-NEXT: %[[FOO:.+]] = "test.foo"
%0 = "test.foo"(%arg0) : (index) -> (index)
// CHECK-NEXT: %[[CALL:.+]] = call @second(%[[FOO]]) : (index) -> i32
%1 = "test.bar"(%0) : (index) -> (i32)
// CHECK: return %[[CALL]] : i32
func.return %1 : i32
}
func.func private @second(%arg1 : index) -> i32
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
%funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%foo = transform.structured.match ops{["test.foo"]} in %f#0 : (!transform.any_op) -> !transform.any_op
%bar = transform.structured.match ops{["test.bar"]} in %f#0 : (!transform.any_op) -> !transform.any_op
%in = transform.get_result %foo[0] : (!transform.any_op) -> !transform.any_value
%out = transform.get_result %bar[0] : (!transform.any_op) -> !transform.any_value
transform.func.cast_and_call %f#1(%in) -> %out before %bar
: (!transform.any_op, !transform.any_value,
!transform.any_value, !transform.any_op) -> !transform.any_op
transform.yield
}
}
// -----
// CHECK-LABEL: func.func @multi_arg_and_result
func.func @multi_arg_and_result(%arg0 : index) -> (index, index) {
// CHECK-NEXT: %[[FOO:.+]] = "test.foo"
%0 = "test.foo"(%arg0) : (index) -> (index)
%1 = "test.bar"(%0) : (index) -> (index)
%2 = "test.bar"(%0) : (index) -> (index)
// CHECK: %[[CALL:.+]]:2 = call @second(%[[FOO]], %[[FOO]]) : (index, index) -> (index, index)
// CHECK: return %[[CALL]]#0, %[[CALL]]#1 : index, index
func.return %1, %2 : index, index
}
func.func private @second(%arg1: index, %arg2: index) -> (index, index)
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
%funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%foo = transform.structured.match ops{["test.foo"]} in %f#0 : (!transform.any_op) -> !transform.any_op
%bars = transform.structured.match ops{["test.bar"]} in %f#0 : (!transform.any_op) -> !transform.any_op
%in0 = transform.get_result %foo[0] : (!transform.any_op) -> !transform.any_value
%in1 = transform.get_result %foo[0] : (!transform.any_op) -> !transform.any_value
%ins = transform.merge_handles %in0, %in1 : !transform.any_value
%outs = transform.get_result %bars[0] : (!transform.any_op) -> !transform.any_value
transform.func.cast_and_call %f#1(%ins) -> %outs after %foo
: (!transform.any_op, !transform.any_value,
!transform.any_value, !transform.any_op) -> !transform.any_op
transform.yield
}
}
// -----
// CHECK-LABEL: func.func @nested_call
func.func @nested_call() {
// CHECK-NEXT: %[[ARG:.+]] = "test.arg"
// CHECK-NEXT: test.foo
%0 = "test.arg"() : () -> (index)
"test.foo"() ({
// CHECK-NEXT: call @second(%[[ARG]]) : (index) -> ()
"test.bar"(%0) : (index) -> ()
}) : () -> ()
func.return
}
func.func private @second(%arg1: index) -> ()
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
%funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%arg = transform.structured.match ops{["test.arg"]} in %f#0 : (!transform.any_op) -> !transform.any_op
%bar = transform.structured.match ops{["test.bar"]} in %f#0 : (!transform.any_op) -> !transform.any_op
%in = transform.get_result %arg[0] : (!transform.any_op) -> !transform.any_value
transform.func.cast_and_call %f#1(%in) before %bar
: (!transform.any_op, !transform.any_value, !transform.any_op) -> !transform.any_op
transform.yield
}
}


@@ -0,0 +1,65 @@
// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file | FileCheck %s
func.func @cast_to_dynamic(%arg0: tensor<10x13xf32>, %arg1: tensor<3x13xf32>) -> tensor<13x13xf32> {
%0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<10x13xf32>, tensor<3x13xf32>) -> tensor<13x13xf32>
func.return %0 : tensor<13x13xf32>
}
func.func private @concat_replacement(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
%funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%concat = transform.structured.match ops{["tensor.concat"]} in %f#0 : (!transform.any_op) -> !transform.any_op
%ins = transform.get_operand %concat[all] : (!transform.any_op) -> !transform.any_value
%out = transform.get_result %concat[all] : (!transform.any_op) -> !transform.any_value
transform.func.cast_and_call %f#1(%ins) -> %out before %concat {
transform.type_conversion.tensor.cast_shape_dynamic_dims
} : (!transform.any_op, !transform.any_value,
!transform.any_value, !transform.any_op) -> !transform.any_op
transform.apply_dce to %f#0 : !transform.any_op
transform.yield
}
}
// CHECK-LABEL: func.func @cast_to_dynamic
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<10x13xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<3x13xf32>
// CHECK-DAG: %[[CAST0:.+]] = tensor.cast %[[ARG0]] : tensor<10x13xf32> to tensor<?x?xf32>
// CHECK-DAG: %[[CAST1:.+]] = tensor.cast %[[ARG1]] : tensor<3x13xf32> to tensor<?x?xf32>
// CHECK: %[[CALL:.+]] = call @concat_replacement(%[[CAST0]], %[[CAST1]])
// CHECK: %[[CAST_RES:.+]] = tensor.cast %[[CALL]] : tensor<?x?xf32> to tensor<13x13xf32>
// CHECK: return %[[CAST_RES]] : tensor<13x13xf32>
// -----
func.func @cast_to_static(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
func.return %0 : tensor<?xf32>
}
func.func private @collapse_replacement(%arg0: tensor<4x5xf32>) -> tensor<20xf32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
%funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%concat = transform.structured.match ops{["tensor.collapse_shape"]} in %f#0 : (!transform.any_op) -> !transform.any_op
%ins = transform.get_operand %concat[all] : (!transform.any_op) -> !transform.any_value
%out = transform.get_result %concat[all] : (!transform.any_op) -> !transform.any_value
transform.func.cast_and_call %f#1(%ins) -> %out before %concat {
transform.type_conversion.tensor.cast_shape_dynamic_dims ignore_dynamic_info
} : (!transform.any_op, !transform.any_value,
!transform.any_value, !transform.any_op) -> !transform.any_op
transform.apply_dce to %f#0 : !transform.any_op
transform.yield
}
}
// CHECK-LABEL: func.func @cast_to_static
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[CAST_IN:.+]] = tensor.cast %[[ARG0]] : tensor<?x?xf32> to tensor<4x5xf32>
// CHECK: %[[CALL:.+]] = call @collapse_replacement(%[[CAST_IN]])
// CHECK: %[[CAST_RES:.+]] = tensor.cast %[[CALL]] : tensor<20xf32> to tensor<?xf32>
// CHECK: return %[[CAST_RES]] : tensor<?xf32>


@@ -502,7 +502,8 @@ def ApplyTestConversionPatternsOp
def TestTypeConverterOp
: Op<Transform_Dialect, "apply_conversion_patterns.transform.test_type_converter",
-[DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface>]> {
+[DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
+     ["getTypeConverter"]>]> {
let arguments = (ins);
let results = (outs);
let assemblyFormat = "attr-dict";