mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-04-01 12:43:47 +00:00
[mlir] support materialization for 1-1 type conversions
Dialect conversion infrastructure supports 1->N type conversions by requiring individual conversions to provide facilities to generate operations retrofitting N values into 1 of the original type when N > 1. This functionality can also be used to materialize explicit "cast"-like operations, but it did not support 1->1 type conversions until now. Modify TypeConverter to support materialization of cast operations for 1-1 conversions. This also makes materialization specification more extensible following the same pattern as type conversions. Instead of overloading a virtual function, users or subclasses of TypeConversion can now register type-specific materialization callbacks that will be called in order for the given type. Differential Revision: https://reviews.llvm.org/D79729
This commit is contained in:
parent
bff0c56ff9
commit
5c5dafc534
@ -217,16 +217,20 @@ class TypeConverter {
|
||||
template <typename ConversionFnT>
|
||||
void addConversion(ConversionFnT &&callback);
|
||||
|
||||
/// This hook allows for materializing a conversion from a set of types into
|
||||
/// one result type by generating a cast operation of some kind. The generated
|
||||
/// operation should produce one result, of 'resultType', with the provided
|
||||
/// 'inputs' as operands. This hook must be overridden when a type conversion
|
||||
/// Register a materialization function, which must be convertibe to the
|
||||
/// following form
|
||||
/// `Optional<Value>(PatternRewriter &, T, ValueRange, Location)`,
|
||||
/// where `T` is any subclass of `Type`. This function is responsible for
|
||||
/// creating an operation, using the PatternRewriter and Location provided,
|
||||
/// that "casts" a range of values into a single value of the given type `T`.
|
||||
/// It must return a Value of the converted type on success, an `llvm::None`
|
||||
/// if it failed but other materialization can be attempted, and `nullptr` on
|
||||
/// unrecoverable failure. It will only be called for (sub)types of `T`.
|
||||
/// Materialization functions must be provided when a type conversion
|
||||
/// results in more than one type, or if a type conversion may persist after
|
||||
/// the conversion has finished.
|
||||
virtual Operation *materializeConversion(PatternRewriter &rewriter,
|
||||
Type resultType,
|
||||
ArrayRef<Value> inputs,
|
||||
Location loc);
|
||||
template <typename FnT>
|
||||
void addMaterialization(FnT &&callback);
|
||||
};
|
||||
```
|
||||
|
||||
|
@ -122,11 +122,6 @@ public:
|
||||
/// pointers to memref descriptors for arguments.
|
||||
LLVM::LLVMType convertFunctionTypeCWrapper(FunctionType type);
|
||||
|
||||
/// Creates descriptor structs from individual values constituting them.
|
||||
Operation *materializeConversion(PatternRewriter &rewriter, Type type,
|
||||
ArrayRef<Value> values,
|
||||
Location loc) override;
|
||||
|
||||
/// Gets the LLVM representation of the index type. The returned type is an
|
||||
/// integer type with the size configured for this type converter.
|
||||
LLVM::LLVMType getIndexType();
|
||||
|
@ -113,6 +113,25 @@ public:
|
||||
registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
|
||||
}
|
||||
|
||||
/// Register a materialization function, which must be convertible to the
|
||||
/// following form:
|
||||
/// `Optional<Value>(PatternRewriter &, T, ValueRange, Location)`,
|
||||
/// where `T` is any subclass of `Type`. This function is responsible for
|
||||
/// creating an operation, using the PatternRewriter and Location provided,
|
||||
/// that "casts" a range of values into a single value of the given type `T`.
|
||||
/// It must return a Value of the converted type on success, an `llvm::None`
|
||||
/// if it failed but other materialization can be attempted, and `nullptr` on
|
||||
/// unrecoverable failure. It will only be called for (sub)types of `T`.
|
||||
/// Materialization functions must be provided when a type conversion
|
||||
/// results in more than one type, or if a type conversion may persist after
|
||||
/// the conversion has finished.
|
||||
template <typename FnT,
|
||||
typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
|
||||
void addMaterialization(FnT &&callback) {
|
||||
registerMaterialization(
|
||||
wrapMaterialization<T>(std::forward<FnT>(callback)));
|
||||
}
|
||||
|
||||
/// Convert the given type. This function should return failure if no valid
|
||||
/// conversion exists, success otherwise. If the new set of types is empty,
|
||||
/// the type is removed and any usages of the existing value are expected to
|
||||
@ -148,18 +167,10 @@ public:
|
||||
/// valid conversion for the signature on success, None otherwise.
|
||||
Optional<SignatureConversion> convertBlockSignature(Block *block);
|
||||
|
||||
/// This hook allows for materializing a conversion from a set of types into
|
||||
/// one result type by generating a cast operation of some kind. The generated
|
||||
/// operation should produce one result, of 'resultType', with the provided
|
||||
/// 'inputs' as operands. This hook must be overridden when a type conversion
|
||||
/// results in more than one type, or if a type conversion may persist after
|
||||
/// the conversion has finished.
|
||||
virtual Operation *materializeConversion(PatternRewriter &rewriter,
|
||||
Type resultType,
|
||||
ArrayRef<Value> inputs,
|
||||
Location loc) {
|
||||
llvm_unreachable("expected 'materializeConversion' to be overridden");
|
||||
}
|
||||
/// Materialize a conversion from a set of types into one result type by
|
||||
/// generating a cast operation of some kind.
|
||||
Value materializeConversion(PatternRewriter &rewriter, Location loc,
|
||||
Type resultType, ValueRange inputs);
|
||||
|
||||
private:
|
||||
/// The signature of the callback used to convert a type. If the new set of
|
||||
@ -168,6 +179,9 @@ private:
|
||||
using ConversionCallbackFn =
|
||||
std::function<Optional<LogicalResult>(Type, SmallVectorImpl<Type> &)>;
|
||||
|
||||
using MaterializationCallbackFn = std::function<Optional<Value>(
|
||||
PatternRewriter &, Type, ValueRange, Location)>;
|
||||
|
||||
/// Generate a wrapper for the given callback. This allows for accepting
|
||||
/// different callback forms, that all compose into a single version.
|
||||
/// With callback of form: `Optional<Type>(T)`
|
||||
@ -204,8 +218,30 @@ private:
|
||||
conversions.emplace_back(std::move(callback));
|
||||
}
|
||||
|
||||
/// Generate a wrapper for the given materialization callback. The callback
|
||||
/// may take any subclass of `Type` and the wrapper will check for the target
|
||||
/// type to be of the expected class before calling the callback.
|
||||
template <typename T, typename FnT>
|
||||
MaterializationCallbackFn wrapMaterialization(FnT &&callback) {
|
||||
return [callback = std::forward<FnT>(callback)](
|
||||
PatternRewriter &rewriter, Type resultType, ValueRange inputs,
|
||||
Location loc) -> Optional<Value> {
|
||||
if (T derivedType = resultType.dyn_cast<T>())
|
||||
return callback(rewriter, derivedType, inputs, loc);
|
||||
return llvm::None;
|
||||
};
|
||||
}
|
||||
|
||||
/// Register a materialization.
|
||||
void registerMaterialization(MaterializationCallbackFn &&callback) {
|
||||
materializations.emplace_back(std::move(callback));
|
||||
}
|
||||
|
||||
/// The set of registered conversion functions.
|
||||
SmallVector<ConversionCallbackFn, 4> conversions;
|
||||
|
||||
/// The list of registered materialization functions.
|
||||
SmallVector<MaterializationCallbackFn, 2> materializations;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -150,6 +150,24 @@ LLVMTypeConverter::LLVMTypeConverter(
|
||||
|
||||
// LLVMType is legal, so add a pass-through conversion.
|
||||
addConversion([](LLVM::LLVMType type) { return type; });
|
||||
|
||||
// Materialization for memrefs creates descriptor structs from individual
|
||||
// values constituting them, when descriptors are used, i.e. more than one
|
||||
// value represents a memref.
|
||||
addMaterialization([&](PatternRewriter &rewriter,
|
||||
UnrankedMemRefType resultType, ValueRange inputs,
|
||||
Location loc) -> Optional<Value> {
|
||||
if (inputs.size() == 1)
|
||||
return llvm::None;
|
||||
return UnrankedMemRefDescriptor::pack(rewriter, loc, *this, resultType,
|
||||
inputs);
|
||||
});
|
||||
addMaterialization([&](PatternRewriter &rewriter, MemRefType resultType,
|
||||
ValueRange inputs, Location loc) -> Optional<Value> {
|
||||
if (inputs.size() == 1)
|
||||
return llvm::None;
|
||||
return MemRefDescriptor::pack(rewriter, loc, *this, resultType, inputs);
|
||||
});
|
||||
}
|
||||
|
||||
/// Returns the MLIR context.
|
||||
@ -297,22 +315,6 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
|
||||
return LLVM::LLVMType::getFunctionTy(resultType, inputs, false);
|
||||
}
|
||||
|
||||
/// Creates descriptor structs from individual values constituting them.
|
||||
Operation *LLVMTypeConverter::materializeConversion(PatternRewriter &rewriter,
|
||||
Type type,
|
||||
ArrayRef<Value> values,
|
||||
Location loc) {
|
||||
if (auto unrankedMemRefType = type.dyn_cast<UnrankedMemRefType>())
|
||||
return UnrankedMemRefDescriptor::pack(rewriter, loc, *this,
|
||||
unrankedMemRefType, values)
|
||||
.getDefiningOp();
|
||||
|
||||
auto memRefType = type.dyn_cast<MemRefType>();
|
||||
assert(memRefType && "1->N conversion is only supported for memrefs");
|
||||
return MemRefDescriptor::pack(rewriter, loc, *this, memRefType, values)
|
||||
.getDefiningOp();
|
||||
}
|
||||
|
||||
// Convert a MemRef to an LLVM type. The result is a MemRef descriptor which
|
||||
// contains:
|
||||
// 1. the pointer to the data buffer, followed by
|
||||
|
@ -305,27 +305,18 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
|
||||
// persist in the IR after conversion.
|
||||
if (!origArg.use_empty()) {
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
auto *newOp = typeConverter->materializeConversion(
|
||||
rewriter, origArg.getType(), llvm::None, loc);
|
||||
origArg.replaceAllUsesWith(newOp->getResult(0));
|
||||
Value newArg = typeConverter->materializeConversion(
|
||||
rewriter, loc, origArg.getType(), llvm::None);
|
||||
assert(newArg &&
|
||||
"Couldn't materialize a block argument after 1->0 conversion");
|
||||
origArg.replaceAllUsesWith(newArg);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// If mapping is 1-1, replace the remaining uses and drop the cast
|
||||
// operation.
|
||||
// FIXME(riverriddle) This should check that the result type and operand
|
||||
// type are the same, otherwise it should force a conversion to be
|
||||
// materialized.
|
||||
if (argInfo->newArgSize == 1) {
|
||||
origArg.replaceAllUsesWith(
|
||||
mapping.lookupOrDefault(newBlock->getArgument(argInfo->newArgIdx)));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Otherwise this is a 1->N value mapping.
|
||||
// Otherwise this is a 1->1+ value mapping.
|
||||
Value castValue = argInfo->castValue;
|
||||
assert(argInfo->newArgSize > 1 && castValue && "expected 1->N mapping");
|
||||
assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
|
||||
|
||||
// If the argument is still used, replace it with the generated cast.
|
||||
if (!origArg.use_empty())
|
||||
@ -333,7 +324,7 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
|
||||
|
||||
// If all users of the cast were removed, we can drop it. Otherwise, keep
|
||||
// the operation alive and let the user handle any remaining usages.
|
||||
if (castValue.use_empty())
|
||||
if (castValue.use_empty() && castValue.getDefiningOp())
|
||||
castValue.getDefiningOp()->erase();
|
||||
}
|
||||
}
|
||||
@ -389,22 +380,22 @@ Block *ArgConverter::applySignatureConversion(
|
||||
continue;
|
||||
}
|
||||
|
||||
// If this is a 1->1 mapping, then map the argument directly.
|
||||
if (inputMap->size == 1) {
|
||||
mapping.map(origArg, newArgs[inputMap->inputNo]);
|
||||
info.argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Otherwise, this is a 1->N mapping. Call into the provided type converter
|
||||
// to pack the new values.
|
||||
// Otherwise, this is a 1->1+ mapping. Call into the provided type converter
|
||||
// to pack the new values. For 1->1 mappings, if there is no materialization
|
||||
// provided, use the argument directly instead.
|
||||
auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
|
||||
Operation *cast = typeConverter->materializeConversion(
|
||||
rewriter, origArg.getType(), replArgs, loc);
|
||||
assert(cast->getNumResults() == 1);
|
||||
mapping.map(origArg, cast->getResult(0));
|
||||
Value newArg;
|
||||
if (typeConverter)
|
||||
newArg = typeConverter->materializeConversion(
|
||||
rewriter, loc, origArg.getType(), replArgs);
|
||||
if (!newArg) {
|
||||
assert(replArgs.size() == 1 &&
|
||||
"couldn't materialize the result of 1->N conversion");
|
||||
newArg = replArgs.front();
|
||||
}
|
||||
mapping.map(origArg, newArg);
|
||||
info.argInfo[i] =
|
||||
ConvertedArgInfo(inputMap->inputNo, inputMap->size, cast->getResult(0));
|
||||
ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
|
||||
}
|
||||
|
||||
// Remove the original block from the region and return the new one.
|
||||
@ -1815,6 +1806,15 @@ LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
|
||||
return success();
|
||||
}
|
||||
|
||||
Value TypeConverter::materializeConversion(PatternRewriter &rewriter,
|
||||
Location loc, Type resultType,
|
||||
ValueRange inputs) {
|
||||
for (MaterializationCallbackFn &fn : llvm::reverse(materializations))
|
||||
if (Optional<Value> result = fn(rewriter, resultType, inputs, loc))
|
||||
return result.getValue();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Create a default conversion pattern that rewrites the type signature of a
|
||||
/// FuncOp.
|
||||
namespace {
|
||||
|
@ -48,6 +48,13 @@ func @remap_input_1_to_N_remaining_use(%arg0: f32) {
|
||||
"work"(%arg0) : (f32) -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @remap_materialize_1_to_1(%{{.*}}: i43)
|
||||
func @remap_materialize_1_to_1(%arg0: i42) {
|
||||
// CHECK: %[[V:.*]] = "test.cast"(%arg0) : (i43) -> i42
|
||||
// CHECK: "test.return"(%[[V]])
|
||||
"test.return"(%arg0) : (i42) -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @remap_input_to_self
|
||||
func @remap_input_to_self(%arg0: index) {
|
||||
// CHECK-NOT: test.cast
|
||||
|
@ -477,7 +477,11 @@ struct TestNestedOpCreationUndoRewrite
|
||||
namespace {
|
||||
struct TestTypeConverter : public TypeConverter {
|
||||
using TypeConverter::TypeConverter;
|
||||
TestTypeConverter() { addConversion(convertType); }
|
||||
TestTypeConverter() {
|
||||
addConversion(convertType);
|
||||
addMaterialization(materializeCast);
|
||||
addMaterialization(materializeOneToOneCast);
|
||||
}
|
||||
|
||||
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
|
||||
// Drop I16 types.
|
||||
@ -490,6 +494,12 @@ struct TestTypeConverter : public TypeConverter {
|
||||
return success();
|
||||
}
|
||||
|
||||
// Convert I42 to I43.
|
||||
if (t.isInteger(42)) {
|
||||
results.push_back(IntegerType::get(43, t.getContext()));
|
||||
return success();
|
||||
}
|
||||
|
||||
// Split F32 into F16,F16.
|
||||
if (t.isF32()) {
|
||||
results.assign(2, FloatType::getF16(t.getContext()));
|
||||
@ -501,12 +511,24 @@ struct TestTypeConverter : public TypeConverter {
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Override the hook to materialize a conversion. This is necessary because
|
||||
/// we generate 1->N type mappings.
|
||||
Operation *materializeConversion(PatternRewriter &rewriter, Type resultType,
|
||||
ArrayRef<Value> inputs,
|
||||
Location loc) override {
|
||||
return rewriter.create<TestCastOp>(loc, resultType, inputs);
|
||||
/// Hook for materializing a conversion. This is necessary because we generate
|
||||
/// 1->N type mappings.
|
||||
static Optional<Value> materializeCast(PatternRewriter &rewriter,
|
||||
Type resultType, ValueRange inputs,
|
||||
Location loc) {
|
||||
if (inputs.size() == 1)
|
||||
return inputs[0];
|
||||
return rewriter.create<TestCastOp>(loc, resultType, inputs).getResult();
|
||||
}
|
||||
|
||||
/// Materialize the cast for one-to-one conversion from i64 to f64.
|
||||
static Optional<Value> materializeOneToOneCast(PatternRewriter &rewriter,
|
||||
IntegerType resultType,
|
||||
ValueRange inputs,
|
||||
Location loc) {
|
||||
if (resultType.getWidth() == 42 && inputs.size() == 1)
|
||||
return rewriter.create<TestCastOp>(loc, resultType, inputs).getResult();
|
||||
return llvm::None;
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user