[mlir][Transform] Add updateConversionTarget to ConversionPatternDescriptorOpInterface

This change adds a method to modify the ConversionTarget used during
`transform.apply_conversion_patterns` to the
`ConversionPatternDescriptorOpInterface`. This is needed when the TypeConverter
is used to dictate the dynamic legality of operations, as in "structural"
conversion patterns present in, for example, the SCF and func dialects.

As a first use case/test, this change also adds a
`transform.apply_patterns.scf.structural_conversions` operation to the SCF
dialect.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D158672
This commit is contained in:
Christopher Bate 2023-08-23 14:56:56 -06:00
parent 5a58e98c20
commit e2d39f799b
8 changed files with 105 additions and 9 deletions

View File

@ -27,6 +27,17 @@ def ApplyForLoopCanonicalizationPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
def ApplySCFStructuralConversionPatternsOp : Op<Transform_Dialect,
"apply_conversion_patterns.scf.structural_conversions",
[DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface,
["populateConversionTargetRules"]>]> {
let description = [{
Collects patterns for performing structural conversions of SCF operations.
}];
let assemblyFormat = "attr-dict";
}
def Transform_ScfForOp : Transform_ConcreteOpType<"scf.for">;
def GetParentForOp : Op<Transform_Dialect, "loop.get_parent_for",
@ -273,8 +284,8 @@ def TakeAssumedBranchOp : Op<Transform_Dialect, "scf.take_assumed_branch", [
TransformOpInterface, TransformEachOpTrait]> {
let description = [{
Given an scf.if conditional, inject user-defined information that it is
always safe to execute only the if or else branch.
always safe to execute only the if or else branch.
This is achieved by just replacing the scf.if by the content of one of its
branches.

View File

@ -53,6 +53,16 @@ void populateSCFStructuralTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target);
/// Similar to `populateSCFStructuralTypeConversionsAndLegality` but does not
/// populate the conversion target.
void populateSCFStructuralTypeConversions(TypeConverter &typeConverter,
RewritePatternSet &patterns);
/// Updates the ConversionTarget with dynamic legality of SCF operations based
/// on the provided type converter.
void populateSCFStructuralTypeConversionTarget(
const TypeConverter &typeConverter, ConversionTarget &target);
/// Populates the provided pattern set with patterns that do 1:N type
/// conversions on (some) SCF ops. This is intended to be used with
/// applyPartialOneToNConversion.

View File

@ -333,6 +333,23 @@ def ConversionPatternDescriptorOpInterface
/*arguments=*/(ins "::mlir::TypeConverter &":$typeConverter,
"::mlir::RewritePatternSet &":$patterns)
>,
InterfaceMethod<
/*desc=*/[{
Populate the ConversionTarget using the final TypeConverter. The default
implementation is to do nothing. Overriding this method can be useful
in order to setup the ConversionTarget for structural type conversions.
In such a situation, an op's legality depends on using the TypeConverter
to determine whether the op's operand and result types are legal
(defined as converting to themselves).
}],
/*returnType=*/"void",
/*name=*/"populateConversionTargetRules",
/*arguments=*/(ins "const ::mlir::TypeConverter &":$typeConverter,
"::mlir::ConversionTarget &":$conversionTarget),
/*methodBody=*/"",
/*defaultImplementation=*/"return;"
>,
InterfaceMethod<
/*desc=*/[{
Return the type converter to be used with this pattern set. If no

View File

@ -32,6 +32,18 @@ void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns(
scf::populateSCFForLoopCanonicalizationPatterns(patterns);
}
void transform::ApplySCFStructuralConversionPatternsOp::populatePatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns) {
scf::populateSCFStructuralTypeConversions(typeConverter, patterns);
}
void transform::ApplySCFStructuralConversionPatternsOp::
populateConversionTargetRules(const TypeConverter &typeConverter,
ConversionTarget &conversionTarget) {
scf::populateSCFStructuralTypeConversionTarget(typeConverter,
conversionTarget);
}
//===----------------------------------------------------------------------===//
// GetParentForOp
//===----------------------------------------------------------------------===//

View File

@ -247,12 +247,15 @@ public:
};
} // namespace
void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
void mlir::scf::populateSCFStructuralTypeConversions(
TypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes,
ConvertWhileOpTypes, ConvertConditionOpTypes>(
typeConverter, patterns.getContext());
}
void mlir::scf::populateSCFStructuralTypeConversionTarget(
const TypeConverter &typeConverter, ConversionTarget &target) {
target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
return typeConverter.isLegal(op->getResultTypes());
});
@ -266,3 +269,10 @@ void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
target.addDynamicallyLegalOp<WhileOp, ConditionOp>(
[&](Operation *op) { return typeConverter.isLegal(op); });
}
void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
populateSCFStructuralTypeConversions(typeConverter, patterns);
populateSCFStructuralTypeConversionTarget(typeConverter, target);
}

View File

@ -547,6 +547,12 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
}
converter = defaultTypeConverter.get();
}
// Add descriptor-specific updates to the conversion target, which may
// depend on the final type converter. In structural converters, the
// legality of types dictates the dynamic legality of an operation.
descriptor.populateConversionTargetRules(*converter, conversionTarget);
descriptor.populatePatterns(*converter, patterns);
}
}

View File

@ -280,3 +280,30 @@ transform.sequence failures(propagate) {
%0 = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.loop.promote_if_one_iteration %0 : !transform.any_op
}
// -----
// CHECK-LABEL: func @test_structural_conversion_patterns(
// CHECK: scf.for {{.*}} -> (memref<f32>) {
func.func @test_structural_conversion_patterns(%a: tensor<f32>) -> tensor<f32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
%0 = scf.for %j = %c0 to %c10 step %c1 iter_args(%arg0 = %a) -> tensor<f32> {
%1 = "test.foo"(%arg0) : (tensor<f32>) -> (tensor<f32>)
scf.yield %1 : tensor<f32>
}
return %0 : tensor<f32>
}
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_conversion_patterns to %0 {
transform.apply_conversion_patterns.scf.structural_conversions
} with type_converter {
transform.apply_conversion_patterns.transform.test_type_converter
} { partial_conversion } : !transform.any_op
}

View File

@ -956,17 +956,20 @@ namespace {
class TestTypeConverter : public TypeConverter {
public:
TestTypeConverter() {
addConversion([](Type t) { return t; });
addConversion([](RankedTensorType type) -> Type {
return MemRefType::get(type.getShape(), type.getElementType());
});
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
auto unrealizedCastConverter = [&](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
if (inputs.size() != 1)
return std::nullopt;
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
};
addSourceMaterialization(unrealizedCastConverter);
addTargetMaterialization(unrealizedCastConverter);
}
};
} // namespace