[mlir] support materialization for 1-1 type conversions

Dialect conversion infrastructure supports 1->N type conversions by requiring
individual conversions to provide facilities to generate operations
retrofitting N values into 1 of the original type when N > 1. This
functionality can also be used to materialize explicit "cast"-like operations,
but it did not support 1->1 type conversions until now. Modify TypeConverter to
support materialization of cast operations for 1-1 conversions.

This also makes materialization specification more extensible following the
same pattern as type conversions. Instead of overloading a virtual function,
users or subclasses of TypeConversion can now register type-specific
materialization callbacks that will be called in order for the given type.

Differential Revision: https://reviews.llvm.org/D79729
This commit is contained in:
Alex Zinenko 2020-06-02 13:24:04 +02:00
parent bff0c56ff9
commit 5c5dafc534
7 changed files with 145 additions and 79 deletions

View File

@ -217,16 +217,20 @@ class TypeConverter {
template <typename ConversionFnT>
void addConversion(ConversionFnT &&callback);
/// This hook allows for materializing a conversion from a set of types into
/// one result type by generating a cast operation of some kind. The generated
/// operation should produce one result, of 'resultType', with the provided
/// 'inputs' as operands. This hook must be overridden when a type conversion
/// Register a materialization function, which must be convertibe to the
/// following form
/// `Optional<Value>(PatternRewriter &, T, ValueRange, Location)`,
/// where `T` is any subclass of `Type`. This function is responsible for
/// creating an operation, using the PatternRewriter and Location provided,
/// that "casts" a range of values into a single value of the given type `T`.
/// It must return a Value of the converted type on success, an `llvm::None`
/// if it failed but other materialization can be attempted, and `nullptr` on
/// unrecoverable failure. It will only be called for (sub)types of `T`.
/// Materialization functions must be provided when a type conversion
/// results in more than one type, or if a type conversion may persist after
/// the conversion has finished.
virtual Operation *materializeConversion(PatternRewriter &rewriter,
Type resultType,
ArrayRef<Value> inputs,
Location loc);
template <typename FnT>
void addMaterialization(FnT &&callback);
};
```

View File

@ -122,11 +122,6 @@ public:
/// pointers to memref descriptors for arguments.
LLVM::LLVMType convertFunctionTypeCWrapper(FunctionType type);
/// Creates descriptor structs from individual values constituting them.
Operation *materializeConversion(PatternRewriter &rewriter, Type type,
ArrayRef<Value> values,
Location loc) override;
/// Gets the LLVM representation of the index type. The returned type is an
/// integer type with the size configured for this type converter.
LLVM::LLVMType getIndexType();

View File

@ -113,6 +113,25 @@ public:
registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
}
/// Register a materialization function, which must be convertible to the
/// following form:
/// `Optional<Value>(PatternRewriter &, T, ValueRange, Location)`,
/// where `T` is any subclass of `Type`. This function is responsible for
/// creating an operation, using the PatternRewriter and Location provided,
/// that "casts" a range of values into a single value of the given type `T`.
/// It must return a Value of the converted type on success, an `llvm::None`
/// if it failed but other materialization can be attempted, and `nullptr` on
/// unrecoverable failure. It will only be called for (sub)types of `T`.
/// Materialization functions must be provided when a type conversion
/// results in more than one type, or if a type conversion may persist after
/// the conversion has finished.
template <typename FnT,
typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
void addMaterialization(FnT &&callback) {
registerMaterialization(
wrapMaterialization<T>(std::forward<FnT>(callback)));
}
/// Convert the given type. This function should return failure if no valid
/// conversion exists, success otherwise. If the new set of types is empty,
/// the type is removed and any usages of the existing value are expected to
@ -148,18 +167,10 @@ public:
/// valid conversion for the signature on success, None otherwise.
Optional<SignatureConversion> convertBlockSignature(Block *block);
/// This hook allows for materializing a conversion from a set of types into
/// one result type by generating a cast operation of some kind. The generated
/// operation should produce one result, of 'resultType', with the provided
/// 'inputs' as operands. This hook must be overridden when a type conversion
/// results in more than one type, or if a type conversion may persist after
/// the conversion has finished.
virtual Operation *materializeConversion(PatternRewriter &rewriter,
Type resultType,
ArrayRef<Value> inputs,
Location loc) {
llvm_unreachable("expected 'materializeConversion' to be overridden");
}
/// Materialize a conversion from a set of types into one result type by
/// generating a cast operation of some kind.
Value materializeConversion(PatternRewriter &rewriter, Location loc,
Type resultType, ValueRange inputs);
private:
/// The signature of the callback used to convert a type. If the new set of
@ -168,6 +179,9 @@ private:
using ConversionCallbackFn =
std::function<Optional<LogicalResult>(Type, SmallVectorImpl<Type> &)>;
using MaterializationCallbackFn = std::function<Optional<Value>(
PatternRewriter &, Type, ValueRange, Location)>;
/// Generate a wrapper for the given callback. This allows for accepting
/// different callback forms, that all compose into a single version.
/// With callback of form: `Optional<Type>(T)`
@ -204,8 +218,30 @@ private:
conversions.emplace_back(std::move(callback));
}
/// Generate a wrapper for the given materialization callback. The callback
/// may take any subclass of `Type` and the wrapper will check for the target
/// type to be of the expected class before calling the callback.
template <typename T, typename FnT>
MaterializationCallbackFn wrapMaterialization(FnT &&callback) {
return [callback = std::forward<FnT>(callback)](
PatternRewriter &rewriter, Type resultType, ValueRange inputs,
Location loc) -> Optional<Value> {
if (T derivedType = resultType.dyn_cast<T>())
return callback(rewriter, derivedType, inputs, loc);
return llvm::None;
};
}
/// Register a materialization.
void registerMaterialization(MaterializationCallbackFn &&callback) {
materializations.emplace_back(std::move(callback));
}
/// The set of registered conversion functions.
SmallVector<ConversionCallbackFn, 4> conversions;
/// The list of registered materialization functions.
SmallVector<MaterializationCallbackFn, 2> materializations;
};
//===----------------------------------------------------------------------===//

View File

@ -150,6 +150,24 @@ LLVMTypeConverter::LLVMTypeConverter(
// LLVMType is legal, so add a pass-through conversion.
addConversion([](LLVM::LLVMType type) { return type; });
// Materialization for memrefs creates descriptor structs from individual
// values constituting them, when descriptors are used, i.e. more than one
// value represents a memref.
addMaterialization([&](PatternRewriter &rewriter,
UnrankedMemRefType resultType, ValueRange inputs,
Location loc) -> Optional<Value> {
if (inputs.size() == 1)
return llvm::None;
return UnrankedMemRefDescriptor::pack(rewriter, loc, *this, resultType,
inputs);
});
addMaterialization([&](PatternRewriter &rewriter, MemRefType resultType,
ValueRange inputs, Location loc) -> Optional<Value> {
if (inputs.size() == 1)
return llvm::None;
return MemRefDescriptor::pack(rewriter, loc, *this, resultType, inputs);
});
}
/// Returns the MLIR context.
@ -297,22 +315,6 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
return LLVM::LLVMType::getFunctionTy(resultType, inputs, false);
}
/// Creates descriptor structs from individual values constituting them.
Operation *LLVMTypeConverter::materializeConversion(PatternRewriter &rewriter,
Type type,
ArrayRef<Value> values,
Location loc) {
if (auto unrankedMemRefType = type.dyn_cast<UnrankedMemRefType>())
return UnrankedMemRefDescriptor::pack(rewriter, loc, *this,
unrankedMemRefType, values)
.getDefiningOp();
auto memRefType = type.dyn_cast<MemRefType>();
assert(memRefType && "1->N conversion is only supported for memrefs");
return MemRefDescriptor::pack(rewriter, loc, *this, memRefType, values)
.getDefiningOp();
}
// Convert a MemRef to an LLVM type. The result is a MemRef descriptor which
// contains:
// 1. the pointer to the data buffer, followed by

View File

@ -305,27 +305,18 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
// persist in the IR after conversion.
if (!origArg.use_empty()) {
rewriter.setInsertionPointToStart(newBlock);
auto *newOp = typeConverter->materializeConversion(
rewriter, origArg.getType(), llvm::None, loc);
origArg.replaceAllUsesWith(newOp->getResult(0));
Value newArg = typeConverter->materializeConversion(
rewriter, loc, origArg.getType(), llvm::None);
assert(newArg &&
"Couldn't materialize a block argument after 1->0 conversion");
origArg.replaceAllUsesWith(newArg);
}
continue;
}
// If mapping is 1-1, replace the remaining uses and drop the cast
// operation.
// FIXME(riverriddle) This should check that the result type and operand
// type are the same, otherwise it should force a conversion to be
// materialized.
if (argInfo->newArgSize == 1) {
origArg.replaceAllUsesWith(
mapping.lookupOrDefault(newBlock->getArgument(argInfo->newArgIdx)));
continue;
}
// Otherwise this is a 1->N value mapping.
// Otherwise this is a 1->1+ value mapping.
Value castValue = argInfo->castValue;
assert(argInfo->newArgSize > 1 && castValue && "expected 1->N mapping");
assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
// If the argument is still used, replace it with the generated cast.
if (!origArg.use_empty())
@ -333,7 +324,7 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
// If all users of the cast were removed, we can drop it. Otherwise, keep
// the operation alive and let the user handle any remaining usages.
if (castValue.use_empty())
if (castValue.use_empty() && castValue.getDefiningOp())
castValue.getDefiningOp()->erase();
}
}
@ -389,22 +380,22 @@ Block *ArgConverter::applySignatureConversion(
continue;
}
// If this is a 1->1 mapping, then map the argument directly.
if (inputMap->size == 1) {
mapping.map(origArg, newArgs[inputMap->inputNo]);
info.argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size);
continue;
}
// Otherwise, this is a 1->N mapping. Call into the provided type converter
// to pack the new values.
// Otherwise, this is a 1->1+ mapping. Call into the provided type converter
// to pack the new values. For 1->1 mappings, if there is no materialization
// provided, use the argument directly instead.
auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
Operation *cast = typeConverter->materializeConversion(
rewriter, origArg.getType(), replArgs, loc);
assert(cast->getNumResults() == 1);
mapping.map(origArg, cast->getResult(0));
Value newArg;
if (typeConverter)
newArg = typeConverter->materializeConversion(
rewriter, loc, origArg.getType(), replArgs);
if (!newArg) {
assert(replArgs.size() == 1 &&
"couldn't materialize the result of 1->N conversion");
newArg = replArgs.front();
}
mapping.map(origArg, newArg);
info.argInfo[i] =
ConvertedArgInfo(inputMap->inputNo, inputMap->size, cast->getResult(0));
ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
}
// Remove the original block from the region and return the new one.
@ -1815,6 +1806,15 @@ LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
return success();
}
Value TypeConverter::materializeConversion(PatternRewriter &rewriter,
Location loc, Type resultType,
ValueRange inputs) {
for (MaterializationCallbackFn &fn : llvm::reverse(materializations))
if (Optional<Value> result = fn(rewriter, resultType, inputs, loc))
return result.getValue();
return nullptr;
}
/// Create a default conversion pattern that rewrites the type signature of a
/// FuncOp.
namespace {

View File

@ -48,6 +48,13 @@ func @remap_input_1_to_N_remaining_use(%arg0: f32) {
"work"(%arg0) : (f32) -> ()
}
// CHECK-LABEL: func @remap_materialize_1_to_1(%{{.*}}: i43)
func @remap_materialize_1_to_1(%arg0: i42) {
// CHECK: %[[V:.*]] = "test.cast"(%arg0) : (i43) -> i42
// CHECK: "test.return"(%[[V]])
"test.return"(%arg0) : (i42) -> ()
}
// CHECK-LABEL: func @remap_input_to_self
func @remap_input_to_self(%arg0: index) {
// CHECK-NOT: test.cast

View File

@ -477,7 +477,11 @@ struct TestNestedOpCreationUndoRewrite
namespace {
struct TestTypeConverter : public TypeConverter {
using TypeConverter::TypeConverter;
TestTypeConverter() { addConversion(convertType); }
TestTypeConverter() {
addConversion(convertType);
addMaterialization(materializeCast);
addMaterialization(materializeOneToOneCast);
}
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
// Drop I16 types.
@ -490,6 +494,12 @@ struct TestTypeConverter : public TypeConverter {
return success();
}
// Convert I42 to I43.
if (t.isInteger(42)) {
results.push_back(IntegerType::get(43, t.getContext()));
return success();
}
// Split F32 into F16,F16.
if (t.isF32()) {
results.assign(2, FloatType::getF16(t.getContext()));
@ -501,12 +511,24 @@ struct TestTypeConverter : public TypeConverter {
return success();
}
/// Override the hook to materialize a conversion. This is necessary because
/// we generate 1->N type mappings.
Operation *materializeConversion(PatternRewriter &rewriter, Type resultType,
ArrayRef<Value> inputs,
Location loc) override {
return rewriter.create<TestCastOp>(loc, resultType, inputs);
/// Hook for materializing a conversion. This is necessary because we generate
/// 1->N type mappings.
static Optional<Value> materializeCast(PatternRewriter &rewriter,
Type resultType, ValueRange inputs,
Location loc) {
if (inputs.size() == 1)
return inputs[0];
return rewriter.create<TestCastOp>(loc, resultType, inputs).getResult();
}
/// Materialize the cast for one-to-one conversion from i64 to f64.
static Optional<Value> materializeOneToOneCast(PatternRewriter &rewriter,
IntegerType resultType,
ValueRange inputs,
Location loc) {
if (resultType.getWidth() == 42 && inputs.size() == 1)
return rewriter.create<TestCastOp>(loc, resultType, inputs).getResult();
return llvm::None;
}
};