mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-04-02 21:22:44 +00:00
[mlir] support materialization for 1-1 type conversions
Dialect conversion infrastructure supports 1->N type conversions by requiring individual conversions to provide facilities to generate operations retrofitting N values into 1 of the original type when N > 1. This functionality can also be used to materialize explicit "cast"-like operations, but it did not support 1->1 type conversions until now. Modify TypeConverter to support materialization of cast operations for 1-1 conversions. This also makes materialization specification more extensible following the same pattern as type conversions. Instead of overloading a virtual function, users or subclasses of TypeConversion can now register type-specific materialization callbacks that will be called in order for the given type. Differential Revision: https://reviews.llvm.org/D79729
This commit is contained in:
parent
bff0c56ff9
commit
5c5dafc534
@ -217,16 +217,20 @@ class TypeConverter {
|
|||||||
template <typename ConversionFnT>
|
template <typename ConversionFnT>
|
||||||
void addConversion(ConversionFnT &&callback);
|
void addConversion(ConversionFnT &&callback);
|
||||||
|
|
||||||
/// This hook allows for materializing a conversion from a set of types into
|
/// Register a materialization function, which must be convertibe to the
|
||||||
/// one result type by generating a cast operation of some kind. The generated
|
/// following form
|
||||||
/// operation should produce one result, of 'resultType', with the provided
|
/// `Optional<Value>(PatternRewriter &, T, ValueRange, Location)`,
|
||||||
/// 'inputs' as operands. This hook must be overridden when a type conversion
|
/// where `T` is any subclass of `Type`. This function is responsible for
|
||||||
|
/// creating an operation, using the PatternRewriter and Location provided,
|
||||||
|
/// that "casts" a range of values into a single value of the given type `T`.
|
||||||
|
/// It must return a Value of the converted type on success, an `llvm::None`
|
||||||
|
/// if it failed but other materialization can be attempted, and `nullptr` on
|
||||||
|
/// unrecoverable failure. It will only be called for (sub)types of `T`.
|
||||||
|
/// Materialization functions must be provided when a type conversion
|
||||||
/// results in more than one type, or if a type conversion may persist after
|
/// results in more than one type, or if a type conversion may persist after
|
||||||
/// the conversion has finished.
|
/// the conversion has finished.
|
||||||
virtual Operation *materializeConversion(PatternRewriter &rewriter,
|
template <typename FnT>
|
||||||
Type resultType,
|
void addMaterialization(FnT &&callback);
|
||||||
ArrayRef<Value> inputs,
|
|
||||||
Location loc);
|
|
||||||
};
|
};
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -122,11 +122,6 @@ public:
|
|||||||
/// pointers to memref descriptors for arguments.
|
/// pointers to memref descriptors for arguments.
|
||||||
LLVM::LLVMType convertFunctionTypeCWrapper(FunctionType type);
|
LLVM::LLVMType convertFunctionTypeCWrapper(FunctionType type);
|
||||||
|
|
||||||
/// Creates descriptor structs from individual values constituting them.
|
|
||||||
Operation *materializeConversion(PatternRewriter &rewriter, Type type,
|
|
||||||
ArrayRef<Value> values,
|
|
||||||
Location loc) override;
|
|
||||||
|
|
||||||
/// Gets the LLVM representation of the index type. The returned type is an
|
/// Gets the LLVM representation of the index type. The returned type is an
|
||||||
/// integer type with the size configured for this type converter.
|
/// integer type with the size configured for this type converter.
|
||||||
LLVM::LLVMType getIndexType();
|
LLVM::LLVMType getIndexType();
|
||||||
|
@ -113,6 +113,25 @@ public:
|
|||||||
registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
|
registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Register a materialization function, which must be convertible to the
|
||||||
|
/// following form:
|
||||||
|
/// `Optional<Value>(PatternRewriter &, T, ValueRange, Location)`,
|
||||||
|
/// where `T` is any subclass of `Type`. This function is responsible for
|
||||||
|
/// creating an operation, using the PatternRewriter and Location provided,
|
||||||
|
/// that "casts" a range of values into a single value of the given type `T`.
|
||||||
|
/// It must return a Value of the converted type on success, an `llvm::None`
|
||||||
|
/// if it failed but other materialization can be attempted, and `nullptr` on
|
||||||
|
/// unrecoverable failure. It will only be called for (sub)types of `T`.
|
||||||
|
/// Materialization functions must be provided when a type conversion
|
||||||
|
/// results in more than one type, or if a type conversion may persist after
|
||||||
|
/// the conversion has finished.
|
||||||
|
template <typename FnT,
|
||||||
|
typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
|
||||||
|
void addMaterialization(FnT &&callback) {
|
||||||
|
registerMaterialization(
|
||||||
|
wrapMaterialization<T>(std::forward<FnT>(callback)));
|
||||||
|
}
|
||||||
|
|
||||||
/// Convert the given type. This function should return failure if no valid
|
/// Convert the given type. This function should return failure if no valid
|
||||||
/// conversion exists, success otherwise. If the new set of types is empty,
|
/// conversion exists, success otherwise. If the new set of types is empty,
|
||||||
/// the type is removed and any usages of the existing value are expected to
|
/// the type is removed and any usages of the existing value are expected to
|
||||||
@ -148,18 +167,10 @@ public:
|
|||||||
/// valid conversion for the signature on success, None otherwise.
|
/// valid conversion for the signature on success, None otherwise.
|
||||||
Optional<SignatureConversion> convertBlockSignature(Block *block);
|
Optional<SignatureConversion> convertBlockSignature(Block *block);
|
||||||
|
|
||||||
/// This hook allows for materializing a conversion from a set of types into
|
/// Materialize a conversion from a set of types into one result type by
|
||||||
/// one result type by generating a cast operation of some kind. The generated
|
/// generating a cast operation of some kind.
|
||||||
/// operation should produce one result, of 'resultType', with the provided
|
Value materializeConversion(PatternRewriter &rewriter, Location loc,
|
||||||
/// 'inputs' as operands. This hook must be overridden when a type conversion
|
Type resultType, ValueRange inputs);
|
||||||
/// results in more than one type, or if a type conversion may persist after
|
|
||||||
/// the conversion has finished.
|
|
||||||
virtual Operation *materializeConversion(PatternRewriter &rewriter,
|
|
||||||
Type resultType,
|
|
||||||
ArrayRef<Value> inputs,
|
|
||||||
Location loc) {
|
|
||||||
llvm_unreachable("expected 'materializeConversion' to be overridden");
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// The signature of the callback used to convert a type. If the new set of
|
/// The signature of the callback used to convert a type. If the new set of
|
||||||
@ -168,6 +179,9 @@ private:
|
|||||||
using ConversionCallbackFn =
|
using ConversionCallbackFn =
|
||||||
std::function<Optional<LogicalResult>(Type, SmallVectorImpl<Type> &)>;
|
std::function<Optional<LogicalResult>(Type, SmallVectorImpl<Type> &)>;
|
||||||
|
|
||||||
|
using MaterializationCallbackFn = std::function<Optional<Value>(
|
||||||
|
PatternRewriter &, Type, ValueRange, Location)>;
|
||||||
|
|
||||||
/// Generate a wrapper for the given callback. This allows for accepting
|
/// Generate a wrapper for the given callback. This allows for accepting
|
||||||
/// different callback forms, that all compose into a single version.
|
/// different callback forms, that all compose into a single version.
|
||||||
/// With callback of form: `Optional<Type>(T)`
|
/// With callback of form: `Optional<Type>(T)`
|
||||||
@ -204,8 +218,30 @@ private:
|
|||||||
conversions.emplace_back(std::move(callback));
|
conversions.emplace_back(std::move(callback));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generate a wrapper for the given materialization callback. The callback
|
||||||
|
/// may take any subclass of `Type` and the wrapper will check for the target
|
||||||
|
/// type to be of the expected class before calling the callback.
|
||||||
|
template <typename T, typename FnT>
|
||||||
|
MaterializationCallbackFn wrapMaterialization(FnT &&callback) {
|
||||||
|
return [callback = std::forward<FnT>(callback)](
|
||||||
|
PatternRewriter &rewriter, Type resultType, ValueRange inputs,
|
||||||
|
Location loc) -> Optional<Value> {
|
||||||
|
if (T derivedType = resultType.dyn_cast<T>())
|
||||||
|
return callback(rewriter, derivedType, inputs, loc);
|
||||||
|
return llvm::None;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register a materialization.
|
||||||
|
void registerMaterialization(MaterializationCallbackFn &&callback) {
|
||||||
|
materializations.emplace_back(std::move(callback));
|
||||||
|
}
|
||||||
|
|
||||||
/// The set of registered conversion functions.
|
/// The set of registered conversion functions.
|
||||||
SmallVector<ConversionCallbackFn, 4> conversions;
|
SmallVector<ConversionCallbackFn, 4> conversions;
|
||||||
|
|
||||||
|
/// The list of registered materialization functions.
|
||||||
|
SmallVector<MaterializationCallbackFn, 2> materializations;
|
||||||
};
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -150,6 +150,24 @@ LLVMTypeConverter::LLVMTypeConverter(
|
|||||||
|
|
||||||
// LLVMType is legal, so add a pass-through conversion.
|
// LLVMType is legal, so add a pass-through conversion.
|
||||||
addConversion([](LLVM::LLVMType type) { return type; });
|
addConversion([](LLVM::LLVMType type) { return type; });
|
||||||
|
|
||||||
|
// Materialization for memrefs creates descriptor structs from individual
|
||||||
|
// values constituting them, when descriptors are used, i.e. more than one
|
||||||
|
// value represents a memref.
|
||||||
|
addMaterialization([&](PatternRewriter &rewriter,
|
||||||
|
UnrankedMemRefType resultType, ValueRange inputs,
|
||||||
|
Location loc) -> Optional<Value> {
|
||||||
|
if (inputs.size() == 1)
|
||||||
|
return llvm::None;
|
||||||
|
return UnrankedMemRefDescriptor::pack(rewriter, loc, *this, resultType,
|
||||||
|
inputs);
|
||||||
|
});
|
||||||
|
addMaterialization([&](PatternRewriter &rewriter, MemRefType resultType,
|
||||||
|
ValueRange inputs, Location loc) -> Optional<Value> {
|
||||||
|
if (inputs.size() == 1)
|
||||||
|
return llvm::None;
|
||||||
|
return MemRefDescriptor::pack(rewriter, loc, *this, resultType, inputs);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the MLIR context.
|
/// Returns the MLIR context.
|
||||||
@ -297,22 +315,6 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
|
|||||||
return LLVM::LLVMType::getFunctionTy(resultType, inputs, false);
|
return LLVM::LLVMType::getFunctionTy(resultType, inputs, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates descriptor structs from individual values constituting them.
|
|
||||||
Operation *LLVMTypeConverter::materializeConversion(PatternRewriter &rewriter,
|
|
||||||
Type type,
|
|
||||||
ArrayRef<Value> values,
|
|
||||||
Location loc) {
|
|
||||||
if (auto unrankedMemRefType = type.dyn_cast<UnrankedMemRefType>())
|
|
||||||
return UnrankedMemRefDescriptor::pack(rewriter, loc, *this,
|
|
||||||
unrankedMemRefType, values)
|
|
||||||
.getDefiningOp();
|
|
||||||
|
|
||||||
auto memRefType = type.dyn_cast<MemRefType>();
|
|
||||||
assert(memRefType && "1->N conversion is only supported for memrefs");
|
|
||||||
return MemRefDescriptor::pack(rewriter, loc, *this, memRefType, values)
|
|
||||||
.getDefiningOp();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert a MemRef to an LLVM type. The result is a MemRef descriptor which
|
// Convert a MemRef to an LLVM type. The result is a MemRef descriptor which
|
||||||
// contains:
|
// contains:
|
||||||
// 1. the pointer to the data buffer, followed by
|
// 1. the pointer to the data buffer, followed by
|
||||||
|
@ -305,27 +305,18 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
|
|||||||
// persist in the IR after conversion.
|
// persist in the IR after conversion.
|
||||||
if (!origArg.use_empty()) {
|
if (!origArg.use_empty()) {
|
||||||
rewriter.setInsertionPointToStart(newBlock);
|
rewriter.setInsertionPointToStart(newBlock);
|
||||||
auto *newOp = typeConverter->materializeConversion(
|
Value newArg = typeConverter->materializeConversion(
|
||||||
rewriter, origArg.getType(), llvm::None, loc);
|
rewriter, loc, origArg.getType(), llvm::None);
|
||||||
origArg.replaceAllUsesWith(newOp->getResult(0));
|
assert(newArg &&
|
||||||
|
"Couldn't materialize a block argument after 1->0 conversion");
|
||||||
|
origArg.replaceAllUsesWith(newArg);
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// If mapping is 1-1, replace the remaining uses and drop the cast
|
// Otherwise this is a 1->1+ value mapping.
|
||||||
// operation.
|
|
||||||
// FIXME(riverriddle) This should check that the result type and operand
|
|
||||||
// type are the same, otherwise it should force a conversion to be
|
|
||||||
// materialized.
|
|
||||||
if (argInfo->newArgSize == 1) {
|
|
||||||
origArg.replaceAllUsesWith(
|
|
||||||
mapping.lookupOrDefault(newBlock->getArgument(argInfo->newArgIdx)));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise this is a 1->N value mapping.
|
|
||||||
Value castValue = argInfo->castValue;
|
Value castValue = argInfo->castValue;
|
||||||
assert(argInfo->newArgSize > 1 && castValue && "expected 1->N mapping");
|
assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
|
||||||
|
|
||||||
// If the argument is still used, replace it with the generated cast.
|
// If the argument is still used, replace it with the generated cast.
|
||||||
if (!origArg.use_empty())
|
if (!origArg.use_empty())
|
||||||
@ -333,7 +324,7 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
|
|||||||
|
|
||||||
// If all users of the cast were removed, we can drop it. Otherwise, keep
|
// If all users of the cast were removed, we can drop it. Otherwise, keep
|
||||||
// the operation alive and let the user handle any remaining usages.
|
// the operation alive and let the user handle any remaining usages.
|
||||||
if (castValue.use_empty())
|
if (castValue.use_empty() && castValue.getDefiningOp())
|
||||||
castValue.getDefiningOp()->erase();
|
castValue.getDefiningOp()->erase();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -389,22 +380,22 @@ Block *ArgConverter::applySignatureConversion(
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// If this is a 1->1 mapping, then map the argument directly.
|
// Otherwise, this is a 1->1+ mapping. Call into the provided type converter
|
||||||
if (inputMap->size == 1) {
|
// to pack the new values. For 1->1 mappings, if there is no materialization
|
||||||
mapping.map(origArg, newArgs[inputMap->inputNo]);
|
// provided, use the argument directly instead.
|
||||||
info.argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise, this is a 1->N mapping. Call into the provided type converter
|
|
||||||
// to pack the new values.
|
|
||||||
auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
|
auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
|
||||||
Operation *cast = typeConverter->materializeConversion(
|
Value newArg;
|
||||||
rewriter, origArg.getType(), replArgs, loc);
|
if (typeConverter)
|
||||||
assert(cast->getNumResults() == 1);
|
newArg = typeConverter->materializeConversion(
|
||||||
mapping.map(origArg, cast->getResult(0));
|
rewriter, loc, origArg.getType(), replArgs);
|
||||||
|
if (!newArg) {
|
||||||
|
assert(replArgs.size() == 1 &&
|
||||||
|
"couldn't materialize the result of 1->N conversion");
|
||||||
|
newArg = replArgs.front();
|
||||||
|
}
|
||||||
|
mapping.map(origArg, newArg);
|
||||||
info.argInfo[i] =
|
info.argInfo[i] =
|
||||||
ConvertedArgInfo(inputMap->inputNo, inputMap->size, cast->getResult(0));
|
ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove the original block from the region and return the new one.
|
// Remove the original block from the region and return the new one.
|
||||||
@ -1815,6 +1806,15 @@ LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value TypeConverter::materializeConversion(PatternRewriter &rewriter,
|
||||||
|
Location loc, Type resultType,
|
||||||
|
ValueRange inputs) {
|
||||||
|
for (MaterializationCallbackFn &fn : llvm::reverse(materializations))
|
||||||
|
if (Optional<Value> result = fn(rewriter, resultType, inputs, loc))
|
||||||
|
return result.getValue();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
/// Create a default conversion pattern that rewrites the type signature of a
|
/// Create a default conversion pattern that rewrites the type signature of a
|
||||||
/// FuncOp.
|
/// FuncOp.
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -48,6 +48,13 @@ func @remap_input_1_to_N_remaining_use(%arg0: f32) {
|
|||||||
"work"(%arg0) : (f32) -> ()
|
"work"(%arg0) : (f32) -> ()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @remap_materialize_1_to_1(%{{.*}}: i43)
|
||||||
|
func @remap_materialize_1_to_1(%arg0: i42) {
|
||||||
|
// CHECK: %[[V:.*]] = "test.cast"(%arg0) : (i43) -> i42
|
||||||
|
// CHECK: "test.return"(%[[V]])
|
||||||
|
"test.return"(%arg0) : (i42) -> ()
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @remap_input_to_self
|
// CHECK-LABEL: func @remap_input_to_self
|
||||||
func @remap_input_to_self(%arg0: index) {
|
func @remap_input_to_self(%arg0: index) {
|
||||||
// CHECK-NOT: test.cast
|
// CHECK-NOT: test.cast
|
||||||
|
@ -477,7 +477,11 @@ struct TestNestedOpCreationUndoRewrite
|
|||||||
namespace {
|
namespace {
|
||||||
struct TestTypeConverter : public TypeConverter {
|
struct TestTypeConverter : public TypeConverter {
|
||||||
using TypeConverter::TypeConverter;
|
using TypeConverter::TypeConverter;
|
||||||
TestTypeConverter() { addConversion(convertType); }
|
TestTypeConverter() {
|
||||||
|
addConversion(convertType);
|
||||||
|
addMaterialization(materializeCast);
|
||||||
|
addMaterialization(materializeOneToOneCast);
|
||||||
|
}
|
||||||
|
|
||||||
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
|
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
|
||||||
// Drop I16 types.
|
// Drop I16 types.
|
||||||
@ -490,6 +494,12 @@ struct TestTypeConverter : public TypeConverter {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Convert I42 to I43.
|
||||||
|
if (t.isInteger(42)) {
|
||||||
|
results.push_back(IntegerType::get(43, t.getContext()));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
// Split F32 into F16,F16.
|
// Split F32 into F16,F16.
|
||||||
if (t.isF32()) {
|
if (t.isF32()) {
|
||||||
results.assign(2, FloatType::getF16(t.getContext()));
|
results.assign(2, FloatType::getF16(t.getContext()));
|
||||||
@ -501,12 +511,24 @@ struct TestTypeConverter : public TypeConverter {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Override the hook to materialize a conversion. This is necessary because
|
/// Hook for materializing a conversion. This is necessary because we generate
|
||||||
/// we generate 1->N type mappings.
|
/// 1->N type mappings.
|
||||||
Operation *materializeConversion(PatternRewriter &rewriter, Type resultType,
|
static Optional<Value> materializeCast(PatternRewriter &rewriter,
|
||||||
ArrayRef<Value> inputs,
|
Type resultType, ValueRange inputs,
|
||||||
Location loc) override {
|
Location loc) {
|
||||||
return rewriter.create<TestCastOp>(loc, resultType, inputs);
|
if (inputs.size() == 1)
|
||||||
|
return inputs[0];
|
||||||
|
return rewriter.create<TestCastOp>(loc, resultType, inputs).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Materialize the cast for one-to-one conversion from i64 to f64.
|
||||||
|
static Optional<Value> materializeOneToOneCast(PatternRewriter &rewriter,
|
||||||
|
IntegerType resultType,
|
||||||
|
ValueRange inputs,
|
||||||
|
Location loc) {
|
||||||
|
if (resultType.getWidth() == 42 && inputs.size() == 1)
|
||||||
|
return rewriter.create<TestCastOp>(loc, resultType, inputs).getResult();
|
||||||
|
return llvm::None;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user