mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-03-04 08:27:50 +00:00
[mlir][Transform] Add updateConversionTarget
to ConversionPatternDescriptorOpInterface
This change adds a method to modify the ConversionTarget used during `transform.apply_conversion_patterns` to the `ConversionPatternDescriptorOpInterface`. This is needed when the TypeConverter is used to dictate the dynamic legality of operations, as in "structural" conversion patterns present in, for example, the SCF and func dialects. As a first use case/test, this change also adds a `transform.apply_patterns.scf.structural_conversions` operation to the SCF dialect. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D158672
This commit is contained in:
parent
5a58e98c20
commit
e2d39f799b
@ -27,6 +27,17 @@ def ApplyForLoopCanonicalizationPatternsOp : Op<Transform_Dialect,
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def ApplySCFStructuralConversionPatternsOp : Op<Transform_Dialect,
|
||||
"apply_conversion_patterns.scf.structural_conversions",
|
||||
[DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface,
|
||||
["populateConversionTargetRules"]>]> {
|
||||
let description = [{
|
||||
Collects patterns for performing structural conversions of SCF operations.
|
||||
}];
|
||||
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def Transform_ScfForOp : Transform_ConcreteOpType<"scf.for">;
|
||||
|
||||
def GetParentForOp : Op<Transform_Dialect, "loop.get_parent_for",
|
||||
@ -273,8 +284,8 @@ def TakeAssumedBranchOp : Op<Transform_Dialect, "scf.take_assumed_branch", [
|
||||
TransformOpInterface, TransformEachOpTrait]> {
|
||||
let description = [{
|
||||
Given an scf.if conditional, inject user-defined information that it is
|
||||
always safe to execute only the if or else branch.
|
||||
|
||||
always safe to execute only the if or else branch.
|
||||
|
||||
This is achieved by just replacing the scf.if by the content of one of its
|
||||
branches.
|
||||
|
||||
|
@ -53,6 +53,16 @@ void populateSCFStructuralTypeConversionsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target);
|
||||
|
||||
/// Similar to `populateSCFStructuralTypeConversionsAndLegality` but does not
|
||||
/// populate the conversion target.
|
||||
void populateSCFStructuralTypeConversions(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Updates the ConversionTarget with dynamic legality of SCF operations based
|
||||
/// on the provided type converter.
|
||||
void populateSCFStructuralTypeConversionTarget(
|
||||
const TypeConverter &typeConverter, ConversionTarget &target);
|
||||
|
||||
/// Populates the provided pattern set with patterns that do 1:N type
|
||||
/// conversions on (some) SCF ops. This is intended to be used with
|
||||
/// applyPartialOneToNConversion.
|
||||
|
@ -333,6 +333,23 @@ def ConversionPatternDescriptorOpInterface
|
||||
/*arguments=*/(ins "::mlir::TypeConverter &":$typeConverter,
|
||||
"::mlir::RewritePatternSet &":$patterns)
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Populate the ConversionTarget using the final TypeConverter. The default
|
||||
implementation is to do nothing. Overriding this method can be useful
|
||||
in order to setup the ConversionTarget for structural type conversions.
|
||||
In such a situation, an op's legality depends on using the TypeConverter
|
||||
to determine whether the op's operand and result types are legal
|
||||
(defined as converting to themselves).
|
||||
|
||||
}],
|
||||
/*returnType=*/"void",
|
||||
/*name=*/"populateConversionTargetRules",
|
||||
/*arguments=*/(ins "const ::mlir::TypeConverter &":$typeConverter,
|
||||
"::mlir::ConversionTarget &":$conversionTarget),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/"return;"
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the type converter to be used with this pattern set. If no
|
||||
|
@ -32,6 +32,18 @@ void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns(
|
||||
scf::populateSCFForLoopCanonicalizationPatterns(patterns);
|
||||
}
|
||||
|
||||
void transform::ApplySCFStructuralConversionPatternsOp::populatePatterns(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
scf::populateSCFStructuralTypeConversions(typeConverter, patterns);
|
||||
}
|
||||
|
||||
void transform::ApplySCFStructuralConversionPatternsOp::
|
||||
populateConversionTargetRules(const TypeConverter &typeConverter,
|
||||
ConversionTarget &conversionTarget) {
|
||||
scf::populateSCFStructuralTypeConversionTarget(typeConverter,
|
||||
conversionTarget);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GetParentForOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -247,12 +247,15 @@ public:
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target) {
|
||||
void mlir::scf::populateSCFStructuralTypeConversions(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes,
|
||||
ConvertWhileOpTypes, ConvertConditionOpTypes>(
|
||||
typeConverter, patterns.getContext());
|
||||
}
|
||||
|
||||
void mlir::scf::populateSCFStructuralTypeConversionTarget(
|
||||
const TypeConverter &typeConverter, ConversionTarget &target) {
|
||||
target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
|
||||
return typeConverter.isLegal(op->getResultTypes());
|
||||
});
|
||||
@ -266,3 +269,10 @@ void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
|
||||
target.addDynamicallyLegalOp<WhileOp, ConditionOp>(
|
||||
[&](Operation *op) { return typeConverter.isLegal(op); });
|
||||
}
|
||||
|
||||
void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target) {
|
||||
populateSCFStructuralTypeConversions(typeConverter, patterns);
|
||||
populateSCFStructuralTypeConversionTarget(typeConverter, target);
|
||||
}
|
||||
|
@ -547,6 +547,12 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
|
||||
}
|
||||
converter = defaultTypeConverter.get();
|
||||
}
|
||||
|
||||
// Add descriptor-specific updates to the conversion target, which may
|
||||
// depend on the final type converter. In structural converters, the
|
||||
// legality of types dictates the dynamic legality of an operation.
|
||||
descriptor.populateConversionTargetRules(*converter, conversionTarget);
|
||||
|
||||
descriptor.populatePatterns(*converter, patterns);
|
||||
}
|
||||
}
|
||||
|
@ -280,3 +280,30 @@ transform.sequence failures(propagate) {
|
||||
%0 = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
transform.loop.promote_if_one_iteration %0 : !transform.any_op
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @test_structural_conversion_patterns(
|
||||
// CHECK: scf.for {{.*}} -> (memref<f32>) {
|
||||
|
||||
func.func @test_structural_conversion_patterns(%a: tensor<f32>) -> tensor<f32> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c10 = arith.constant 10 : index
|
||||
%0 = scf.for %j = %c0 to %c10 step %c1 iter_args(%arg0 = %a) -> tensor<f32> {
|
||||
%1 = "test.foo"(%arg0) : (tensor<f32>) -> (tensor<f32>)
|
||||
scf.yield %1 : tensor<f32>
|
||||
}
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb1(%arg1: !transform.any_op):
|
||||
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
transform.apply_conversion_patterns to %0 {
|
||||
transform.apply_conversion_patterns.scf.structural_conversions
|
||||
} with type_converter {
|
||||
transform.apply_conversion_patterns.transform.test_type_converter
|
||||
} { partial_conversion } : !transform.any_op
|
||||
}
|
||||
|
@ -956,17 +956,20 @@ namespace {
|
||||
class TestTypeConverter : public TypeConverter {
|
||||
public:
|
||||
TestTypeConverter() {
|
||||
addConversion([](Type t) { return t; });
|
||||
addConversion([](RankedTensorType type) -> Type {
|
||||
return MemRefType::get(type.getShape(), type.getElementType());
|
||||
});
|
||||
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
|
||||
ValueRange inputs,
|
||||
Location loc) -> std::optional<Value> {
|
||||
auto unrealizedCastConverter = [&](OpBuilder &builder, Type resultType,
|
||||
ValueRange inputs,
|
||||
Location loc) -> std::optional<Value> {
|
||||
if (inputs.size() != 1)
|
||||
return std::nullopt;
|
||||
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
|
||||
.getResult(0);
|
||||
});
|
||||
};
|
||||
addSourceMaterialization(unrealizedCastConverter);
|
||||
addTargetMaterialization(unrealizedCastConverter);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
Loading…
x
Reference in New Issue
Block a user