[mlir] make the bitwidth of device side index computations configurable
The patch makes the index type lowering of the GPU-to-NVVM and GPU-to-ROCDL conversions configurable. It introduces a pass option that controls the bitwidth used when lowering device-side index computations.

Differential Revision: https://reviews.llvm.org/D80285
Commit d10b1a38a7 (parent e935a540ea)
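As a usage sketch (not part of the patch itself): the new factory overloads introduced below can be used from C++ to force 32-bit device-side index computations. The pass-manager setup and include paths are assumptions based on the MLIR tree of this era; only the createLowerGpuOpsTo*Pass(unsigned) signatures come from this change.

    #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
    #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
    #include "mlir/Dialect/GPU/GPUDialect.h"
    #include "mlir/Pass/PassManager.h"

    // Lower gpu.module bodies with 32-bit index arithmetic instead of the
    // width derived from the data layout (the default, encoded as 0).
    void buildGpuLoweringPipeline(mlir::PassManager &pm, bool targetNVVM) {
      auto &gpuModulePM = pm.nest<mlir::gpu::GPUModuleOp>();
      if (targetNVVM)
        gpuModulePM.addPass(
            mlir::createLowerGpuOpsToNVVMOpsPass(/*indexBitwidth=*/32));
      else
        gpuModulePM.addPass(
            mlir::createLowerGpuOpsToROCDLOpsPass(/*indexBitwidth=*/32));
    }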
@@ -8,6 +8,7 @@
 #ifndef MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_
 #define MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_

+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include <memory>

 namespace mlir {
@@ -24,9 +25,11 @@ class GPUModuleOp;
 void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                          OwningRewritePatternList &patterns);

-/// Creates a pass that lowers GPU dialect operations to NVVM counterparts.
-std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
-createLowerGpuOpsToNVVMOpsPass();
+/// Creates a pass that lowers GPU dialect operations to NVVM counterparts. The
+/// index bitwidth used for the lowering of the device side index computations
+/// is configurable.
+std::unique_ptr<OperationPass<gpu::GPUModuleOp>> createLowerGpuOpsToNVVMOpsPass(
+    unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout);

 } // namespace mlir
@@ -8,6 +8,7 @@
 #ifndef MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_
 #define MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_

+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include <memory>

 namespace mlir {
@@ -25,9 +26,12 @@ class GPUModuleOp;
 void populateGpuToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                           OwningRewritePatternList &patterns);

-/// Creates a pass that lowers GPU dialect operations to ROCDL counterparts.
+/// Creates a pass that lowers GPU dialect operations to ROCDL counterparts. The
+/// index bitwidth used for the lowering of the device side index computations
+/// is configurable.
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
-createLowerGpuOpsToROCDLOpsPass();
+createLowerGpuOpsToROCDLOpsPass(
+    unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout);

 } // namespace mlir
@@ -100,6 +100,11 @@ def ConvertGpuLaunchFuncToGpuRuntimeCalls : Pass<"launch-func-to-gpu-runtime",
 def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
   let summary = "Generate NVVM operations for gpu operations";
   let constructor = "mlir::createLowerGpuOpsToNVVMOpsPass()";
+  let options = [
+    Option<"indexBitwidth", "index-bitwidth", "unsigned",
+           /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
+           "Bitwidth of the index type, 0 to use size of machine word">
+  ];
 }

 //===----------------------------------------------------------------------===//
@@ -109,6 +114,11 @@ def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
 def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
   let summary = "Generate ROCDL operations for gpu operations";
   let constructor = "mlir::createLowerGpuOpsToROCDLOpsPass()";
+  let options = [
+    Option<"indexBitwidth", "index-bitwidth", "unsigned",
+           /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
+           "Bitwidth of the index type, 0 to use size of machine word">
+  ];
 }

 //===----------------------------------------------------------------------===//
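The TableGen default of "0" stands in for kDeriveIndexBitwidthFromDataLayout, so leaving index-bitwidth unset keeps the data-layout-derived width; the test RUN lines later in this change exercise the textual form -convert-gpu-to-nvvm='index-bitwidth=32'. A minimal sketch of that convention follows; resolveIndexBitwidth is a hypothetical helper name, not something added by this patch.

    #include "llvm/IR/DataLayout.h"

    // 0 means "derive the index width from the LLVM data layout"; any other
    // value is used verbatim as the bitwidth of the lowered index type.
    static unsigned resolveIndexBitwidth(unsigned requested,
                                         const llvm::DataLayout &dataLayout) {
      return requested == 0 ? dataLayout.getPointerSizeInBits() : requested;
    }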
@@ -15,6 +15,7 @@
 #ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
 #define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H

+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Transforms/DialectConversion.h"

 namespace llvm {
@@ -35,22 +36,6 @@ class LLVMDialect;
 class LLVMType;
 } // namespace LLVM

-/// Set of callbacks that allows the customization of LLVMTypeConverter.
-struct LLVMTypeConverterCustomization {
-  using CustomCallback = std::function<LogicalResult(LLVMTypeConverter &, Type,
-                                                     SmallVectorImpl<Type> &)>;
-
-  /// Customize the type conversion of function arguments.
-  CustomCallback funcArgConverter;
-
-  /// Used to determine the bitwidth of the LLVM integer type that the index
-  /// type gets lowered to. Defaults to deriving the size from the data layout.
-  unsigned indexBitwidth;
-
-  /// Initialize customization to default callbacks.
-  LLVMTypeConverterCustomization();
-};
-
 /// Callback to convert function argument types. It converts a MemRef function
 /// argument to a list of non-aggregate types containing descriptor
 /// information, and an UnrankedmemRef function argument to a list containing
@@ -75,13 +60,11 @@ class LLVMTypeConverter : public TypeConverter {
 public:
   using TypeConverter::convertType;

-  /// Create an LLVMTypeConverter using the default
-  /// LLVMTypeConverterCustomization.
+  /// Create an LLVMTypeConverter using the default LowerToLLVMOptions.
   LLVMTypeConverter(MLIRContext *ctx);

-  /// Create an LLVMTypeConverter using 'custom' customizations.
-  LLVMTypeConverter(MLIRContext *ctx,
-                    const LLVMTypeConverterCustomization &custom);
+  /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
+  LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);

   /// Convert a function type. The arguments and results are converted one by
   /// one and results are packed into a wrapped LLVM IR structure type. `result`
@@ -127,7 +110,7 @@ public:
   LLVM::LLVMType getIndexType();

   /// Gets the bitwidth of the index type when converted to LLVM.
-  unsigned getIndexTypeBitwidth() { return customizations.indexBitwidth; }
+  unsigned getIndexTypeBitwidth() { return options.indexBitwidth; }

 protected:
   /// LLVM IR module used to parse/create types.
@@ -193,8 +176,8 @@ private:
   // Convert a 1D vector type into an LLVM vector type.
   Type convertVectorType(VectorType type);

-  /// Callbacks for customizing the type conversion.
-  LLVMTypeConverterCustomization customizations;
+  /// Options for customizing the llvm lowering.
+  LowerToLLVMOptions options;
 };

 /// Helper class to produce LLVM dialect operations extracting or inserting
@@ -389,11 +372,17 @@ public:
 };

 /// Base class for operation conversions targeting the LLVM IR dialect. Provides
-/// conversion patterns with access to an LLVMTypeConverter.
+/// conversion patterns with access to an LLVMTypeConverter and the
+/// LowerToLLVMOptions.
 class ConvertToLLVMPattern : public ConversionPattern {
 public:
   ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
                        LLVMTypeConverter &typeConverter,
+                       const LowerToLLVMOptions &options = {
+                           /*useBarePtrCallConv=*/false,
+                           /*emitCWrappers=*/false,
+                           /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout,
+                           /*useAlignedAlloc=*/false},
                        PatternBenefit benefit = 1);

   /// Returns the LLVM dialect.
@@ -445,6 +434,9 @@ public:
 protected:
   /// Reference to the type converter, with potential extensions.
   LLVMTypeConverter &typeConverter;
+
+  /// Reference to the llvm lowering options.
+  const LowerToLLVMOptions &options;
 };

 /// Utility class for operation conversions targeting the LLVM dialect that
@@ -453,10 +445,11 @@ template <typename OpTy>
 class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
 public:
   ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
+                         const LowerToLLVMOptions &options,
                          PatternBenefit benefit = 1)
       : ConvertToLLVMPattern(OpTy::getOperationName(),
                              &typeConverter.getContext(), typeConverter,
-                             benefit) {}
+                             options, benefit) {}
 };

 namespace LLVM {
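To make the replacement of LLVMTypeConverterCustomization by LowerToLLVMOptions concrete, a small sketch follows. It assumes the LLVM dialect has been registered with the context; configureIndex32Lowering is an illustrative name and the include paths reflect the tree layout of this era.

    #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
    #include "mlir/IR/MLIRContext.h"
    #include <cassert>

    // Build a standard-to-LLVM type converter that lowers `index` to i32
    // instead of deriving the width from the module data layout.
    void configureIndex32Lowering(mlir::MLIRContext &context) {
      mlir::LowerToLLVMOptions options = {/*useBarePtrCallConv=*/false,
                                          /*emitCWrappers=*/false,
                                          /*indexBitwidth=*/32,
                                          /*useAlignedAlloc=*/false};
      mlir::LLVMTypeConverter converter(&context, options);
      assert(converter.getIndexTypeBitwidth() == 32);
      (void)converter;
    }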
@@ -14,43 +14,10 @@
 namespace mlir {
 class LLVMTypeConverter;
 class ModuleOp;
-template <typename T> class OperationPass;
+template <typename T>
+class OperationPass;
 class OwningRewritePatternList;

-/// Collect a set of patterns to convert memory-related operations from the
-/// Standard dialect to the LLVM dialect, excluding non-memory-related
-/// operations and FuncOp.
-void populateStdToLLVMMemoryConversionPatterns(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool useAlignedAlloc);
-
-/// Collect a set of patterns to convert from the Standard dialect to the LLVM
-/// dialect, excluding the memory-related operations.
-void populateStdToLLVMNonMemoryConversionPatterns(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
-
-/// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
-/// `emitCWrappers` is set, the pattern will also produce functions
-/// that pass memref descriptors by pointer-to-structure in addition to the
-/// default unpacked form.
-void populateStdToLLVMDefaultFuncOpConversionPattern(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool emitCWrappers = false);
-
-/// Collect a set of default patterns to convert from the Standard dialect to
-/// LLVM.
-void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                         OwningRewritePatternList &patterns,
-                                         bool emitCWrappers = false,
-                                         bool useAlignedAlloc = false);
-
-/// Collect a set of patterns to convert from the Standard dialect to
-/// LLVM using the bare pointer calling convention for MemRef function
-/// arguments.
-void populateStdToLLVMBarePtrConversionPatterns(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool useAlignedAlloc);
-
 /// Value to pass as bitwidth for the index type when the converter is expected
 /// to derive the bitwidth from the LLVM data layout.
 static constexpr unsigned kDeriveIndexBitwidthFromDataLayout = 0;
@@ -63,6 +30,35 @@ struct LowerToLLVMOptions {
   bool useAlignedAlloc = false;
 };

+/// Collect a set of patterns to convert memory-related operations from the
+/// Standard dialect to the LLVM dialect, excluding non-memory-related
+/// operations and FuncOp.
+void populateStdToLLVMMemoryConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    const LowerToLLVMOptions &options);
+
+/// Collect a set of patterns to convert from the Standard dialect to the LLVM
+/// dialect, excluding the memory-related operations.
+void populateStdToLLVMNonMemoryConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    const LowerToLLVMOptions &options);
+
+/// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
+/// `emitCWrappers` is set, the pattern will also produce functions
+/// that pass memref descriptors by pointer-to-structure in addition to the
+/// default unpacked form.
+void populateStdToLLVMFuncOpConversionPattern(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    const LowerToLLVMOptions &options);
+
+/// Collect the patterns to convert from the Standard dialect to LLVM.
+void populateStdToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    const LowerToLLVMOptions &options = {
+        /*useBarePtrCallConv=*/false, /*emitCWrappers=*/false,
+        /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout,
+        /*useAlignedAlloc=*/false});
+
 /// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
 /// stdlib malloc/free is used by default for allocating memrefs allocated with
 /// std.alloc, while LLVM's alloca is used for those allocated with std.alloca.
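With the declarations above, the calling convention and allocation strategy are selected through a single options struct rather than dedicated populate* overloads. The sketch below mirrors the reworked LLVMLoweringPass later in this change; the function name is hypothetical, and the same options value is deliberately given to both the type converter and the pattern population so that signature conversion and the chosen FuncOp pattern agree.

    #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
    #include "mlir/IR/Module.h"
    #include "mlir/Transforms/DialectConversion.h"

    // Run the standard-to-LLVM conversion with the bare-pointer calling
    // convention and aligned_alloc, all selected through LowerToLLVMOptions.
    mlir::LogicalResult lowerWithBarePtrConvention(mlir::ModuleOp module) {
      mlir::LowerToLLVMOptions options = {
          /*useBarePtrCallConv=*/true, /*emitCWrappers=*/false,
          /*indexBitwidth=*/mlir::kDeriveIndexBitwidthFromDataLayout,
          /*useAlignedAlloc=*/true};
      mlir::LLVMTypeConverter converter(module.getContext(), options);

      mlir::OwningRewritePatternList patterns;
      mlir::populateStdToLLVMConversionPatterns(converter, patterns, options);

      mlir::LLVMConversionTarget target(*module.getContext());
      return mlir::applyPartialConversion(module, target, patterns);
    }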
@@ -30,7 +30,6 @@ using namespace mlir;

 namespace {

 struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
   explicit GPUShuffleOpLowering(LLVMTypeConverter &lowering_)
       : ConvertToLLVMPattern(gpu::ShuffleOp::getOperationName(),
@@ -97,17 +96,27 @@ struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
 ///
 /// This pass only handles device code and is not meant to be run on GPU host
 /// code.
-class LowerGpuOpsToNVVMOpsPass
+struct LowerGpuOpsToNVVMOpsPass
     : public ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
-public:
+  LowerGpuOpsToNVVMOpsPass() = default;
+  LowerGpuOpsToNVVMOpsPass(unsigned indexBitwidth) {
+    this->indexBitwidth = indexBitwidth;
+  }
+
   void runOnOperation() override {
     gpu::GPUModuleOp m = getOperation();

+    /// Customize the bitwidth used for the device side index computations.
+    LowerToLLVMOptions options = {/*useBarePtrCallConv =*/false,
+                                  /*emitCWrappers = */ true,
+                                  /*indexBitwidth =*/indexBitwidth,
+                                  /*useAlignedAlloc =*/false};
+
     /// MemRef conversion for GPU to NVVM lowering. The GPU dialect uses memory
     /// space 5 for private memory attributions, but NVVM represents private
     /// memory allocations as local `alloca`s in the default address space. This
     /// converter drops the private memory space to support the use case above.
-    LLVMTypeConverter converter(m.getContext());
+    LLVMTypeConverter converter(m.getContext(), options);
     converter.addConversion([&](MemRefType type) -> Optional<Type> {
       if (type.getMemorySpace() != gpu::GPUDialect::getPrivateAddressSpace())
         return llvm::None;
@@ -176,6 +185,6 @@ void mlir::populateGpuToNVVMConversionPatterns(
 }

 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
-mlir::createLowerGpuOpsToNVVMOpsPass() {
-  return std::make_unique<LowerGpuOpsToNVVMOpsPass>();
+mlir::createLowerGpuOpsToNVVMOpsPass(unsigned indexBitwidth) {
+  return std::make_unique<LowerGpuOpsToNVVMOpsPass>(indexBitwidth);
 }
@@ -41,13 +41,22 @@ namespace {
 //
 // This pass only handles device code and is not meant to be run on GPU host
 // code.
-class LowerGpuOpsToROCDLOpsPass
+struct LowerGpuOpsToROCDLOpsPass
     : public ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
-public:
+  LowerGpuOpsToROCDLOpsPass() = default;
+  LowerGpuOpsToROCDLOpsPass(unsigned indexBitwidth) {
+    this->indexBitwidth = indexBitwidth;
+  }
+
   void runOnOperation() override {
     gpu::GPUModuleOp m = getOperation();

-    LLVMTypeConverter converter(m.getContext());
+    /// Customize the bitwidth used for the device side index computations.
+    LowerToLLVMOptions options = {/*useBarePtrCallConv =*/false,
+                                  /*emitCWrappers = */ true,
+                                  /*indexBitwidth =*/indexBitwidth,
+                                  /*useAlignedAlloc =*/false};
+    LLVMTypeConverter converter(m.getContext(), options);

     OwningRewritePatternList patterns;

@@ -106,6 +115,6 @@ void mlir::populateGpuToROCDLConversionPatterns(
 }

 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
-mlir::createLowerGpuOpsToROCDLOpsPass() {
-  return std::make_unique<LowerGpuOpsToROCDLOpsPass>();
+mlir::createLowerGpuOpsToROCDLOpsPass(unsigned indexBitwidth) {
+  return std::make_unique<LowerGpuOpsToROCDLOpsPass>(indexBitwidth);
 }
@@ -51,11 +51,6 @@ static LLVM::LLVMType unwrap(Type type) {
   return wrappedLLVMType;
 }

-/// Initialize customization to default callbacks.
-LLVMTypeConverterCustomization::LLVMTypeConverterCustomization()
-    : funcArgConverter(structFuncArgTypeConverter),
-      indexBitwidth(kDeriveIndexBitwidthFromDataLayout) {}
-
 /// Callback to convert function argument types. It converts a MemRef function
 /// argument to a list of non-aggregate types containing descriptor
 /// information, and an UnrankedmemRef function argument to a list containing
@@ -122,20 +117,19 @@ LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
   return success();
 }

-/// Create an LLVMTypeConverter using default LLVMTypeConverterCustomization.
+/// Create an LLVMTypeConverter using default LowerToLLVMOptions.
 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
-    : LLVMTypeConverter(ctx, LLVMTypeConverterCustomization()) {}
+    : LLVMTypeConverter(ctx, LowerToLLVMOptions()) {}

-/// Create an LLVMTypeConverter using 'custom' customizations.
-LLVMTypeConverter::LLVMTypeConverter(
-    MLIRContext *ctx, const LLVMTypeConverterCustomization &customs)
+/// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
+LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
+                                     const LowerToLLVMOptions &options_)
     : llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()),
-      customizations(customs) {
+      options(options_) {
   assert(llvmDialect && "LLVM IR dialect is not registered");
   module = &llvmDialect->getLLVMModule();
-  if (customizations.indexBitwidth == kDeriveIndexBitwidthFromDataLayout)
-    customizations.indexBitwidth =
-        module->getDataLayout().getPointerSizeInBits();
+  if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout)
+    options.indexBitwidth = module->getDataLayout().getPointerSizeInBits();

   // Register conversions for the standard types.
   addConversion([&](ComplexType type) { return convertComplexType(type); });
@@ -262,11 +256,15 @@ SmallVector<Type, 2> LLVMTypeConverter::convertUnrankedMemRefSignature() {
 LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
     FunctionType type, bool isVariadic,
     LLVMTypeConverter::SignatureConversion &result) {
+  // Select the argument converter depending on the calling convetion.
+  auto funcArgConverter = options.useBarePtrCallConv
+                              ? barePtrFuncArgTypeConverter
+                              : structFuncArgTypeConverter;
   // Convert argument types one by one and check for errors.
   for (auto &en : llvm::enumerate(type.getInputs())) {
     Type type = en.value();
     SmallVector<Type, 8> converted;
-    if (failed(customizations.funcArgConverter(*this, type, converted)))
+    if (failed(funcArgConverter(*this, type, converted)))
       return {};
     result.addInputs(en.index(), converted);
   }
@@ -397,9 +395,10 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) {
 ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
                                            MLIRContext *context,
                                            LLVMTypeConverter &typeConverter_,
+                                           const LowerToLLVMOptions &options_,
                                            PatternBenefit benefit)
     : ConversionPattern(rootOpName, benefit, typeConverter_, context),
-      typeConverter(typeConverter_) {}
+      typeConverter(typeConverter_), options(options_) {}

 /*============================================================================*/
 /* StructBuilder implementation */
@@ -1051,8 +1050,10 @@ protected:
 /// information.
 static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
 struct FuncOpConversion : public FuncOpConversionBase {
-  FuncOpConversion(LLVMTypeConverter &converter, bool emitCWrappers)
-      : FuncOpConversionBase(converter), emitWrappers(emitCWrappers) {}
+  FuncOpConversion(LLVMTypeConverter &converter,
+                   const LowerToLLVMOptions &options)
+      : FuncOpConversionBase(converter, options) {}
+  using ConvertOpToLLVMPattern<FuncOp>::options;

   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -1063,7 +1064,7 @@ struct FuncOpConversion : public FuncOpConversionBase {
     if (!newFuncOp)
       return failure();

-    if (emitWrappers || funcOp.getAttrOfType<UnitAttr>(kEmitIfaceAttrName)) {
+    if (options.emitCWrappers || funcOp.getAttrOfType<UnitAttr>(kEmitIfaceAttrName)) {
       if (newFuncOp.isExternal())
         wrapExternalFunction(rewriter, op->getLoc(), typeConverter, funcOp,
                              newFuncOp);
@@ -1075,11 +1076,6 @@ struct FuncOpConversion : public FuncOpConversionBase {
     rewriter.eraseOp(op);
     return success();
   }
-
-private:
-  /// If true, also create the adaptor functions having signatures compatible
-  /// with those produced by clang.
-  const bool emitWrappers;
 };

 /// FuncOp legalization pattern that converts MemRef arguments to bare pointers
@@ -1506,11 +1502,11 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
   using ConvertOpToLLVMPattern<AllocLikeOp>::getIndexType;
   using ConvertOpToLLVMPattern<AllocLikeOp>::typeConverter;
   using ConvertOpToLLVMPattern<AllocLikeOp>::getVoidPtrType;
+  using ConvertOpToLLVMPattern<AllocLikeOp>::options;

   explicit AllocLikeOpLowering(LLVMTypeConverter &converter,
-                               bool useAlignedAlloc = false)
-      : ConvertOpToLLVMPattern<AllocLikeOp>(converter),
-        useAlignedAlloc(useAlignedAlloc) {}
+                               const LowerToLLVMOptions &options)
+      : ConvertOpToLLVMPattern<AllocLikeOp>(converter, options) {}

   LogicalResult match(Operation *op) const override {
     MemRefType memRefType = cast<AllocLikeOp>(op).getType();
@@ -1677,7 +1673,7 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
   /// allocation size to be a multiple of alignment,
   Optional<int64_t> getAllocationAlignment(AllocOp allocOp) const {
     // No alignment can be used for the 'malloc' call itself.
-    if (!useAlignedAlloc)
+    if (!options.useAlignedAlloc)
       return None;

     if (allocOp.alignment())
@@ -1849,16 +1845,14 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
   }

 protected:
-  /// Use aligned_alloc instead of malloc for all heap allocations.
-  bool useAlignedAlloc;
   /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
   uint64_t kMinAlignedAllocAlignment = 16UL;
 };

 struct AllocOpLowering : public AllocLikeOpLowering<AllocOp> {
   explicit AllocOpLowering(LLVMTypeConverter &converter,
-                           bool useAlignedAlloc = false)
-      : AllocLikeOpLowering<AllocOp>(converter, useAlignedAlloc) {}
+                           const LowerToLLVMOptions &options)
+      : AllocLikeOpLowering<AllocOp>(converter, options) {}
 };

 using AllocaOpLowering = AllocLikeOpLowering<AllocaOp>;
@@ -1939,8 +1933,9 @@ struct CallIndirectOpLowering : public CallOpInterfaceLowering<CallIndirectOp> {
 struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
   using ConvertOpToLLVMPattern<DeallocOp>::ConvertOpToLLVMPattern;

-  explicit DeallocOpLowering(LLVMTypeConverter &converter)
-      : ConvertOpToLLVMPattern<DeallocOp>(converter) {}
+  explicit DeallocOpLowering(LLVMTypeConverter &converter,
+                             const LowerToLLVMOptions &options)
+      : ConvertOpToLLVMPattern<DeallocOp>(converter, options) {}

   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -2960,7 +2955,8 @@ private:

 /// Collect a set of patterns to convert from the Standard dialect to LLVM.
 void mlir::populateStdToLLVMNonMemoryConversionPatterns(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    const LowerToLLVMOptions &options) {
   // FIXME: this should be tablegen'ed
   // clang-format off
   patterns.insert<
@@ -3023,13 +3019,13 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
       UnsignedRemIOpLowering,
       UnsignedShiftRightOpLowering,
       XOrOpLowering,
-      ZeroExtendIOpLowering>(converter);
+      ZeroExtendIOpLowering>(converter, options);
   // clang-format on
 }

 void mlir::populateStdToLLVMMemoryConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool useAlignedAlloc) {
+    const LowerToLLVMOptions &options) {
   // clang-format off
   patterns.insert<
       AssumeAlignmentOpLowering,
@@ -3039,41 +3035,26 @@ void mlir::populateStdToLLVMMemoryConversionPatterns(
       MemRefCastOpLowering,
       StoreOpLowering,
       SubViewOpLowering,
-      ViewOpLowering>(converter);
-  patterns.insert<
-      AllocOpLowering
-      >(converter, useAlignedAlloc);
+      ViewOpLowering,
+      AllocOpLowering>(converter, options);
   // clang-format on
 }

-void mlir::populateStdToLLVMDefaultFuncOpConversionPattern(
+void mlir::populateStdToLLVMFuncOpConversionPattern(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool emitCWrappers) {
-  patterns.insert<FuncOpConversion>(converter, emitCWrappers);
+    const LowerToLLVMOptions &options) {
+  if (options.useBarePtrCallConv)
+    patterns.insert<BarePtrFuncOpConversion>(converter, options);
+  else
+    patterns.insert<FuncOpConversion>(converter, options);
 }

 void mlir::populateStdToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool emitCWrappers, bool useAlignedAlloc) {
-  populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns,
-                                                  emitCWrappers);
-  populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
-  populateStdToLLVMMemoryConversionPatterns(converter, patterns,
-                                            useAlignedAlloc);
-}
-
-static void populateStdToLLVMBarePtrFuncOpConversionPattern(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
-  patterns.insert<BarePtrFuncOpConversion>(converter);
-}
-
-void mlir::populateStdToLLVMBarePtrConversionPatterns(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool useAlignedAlloc) {
-  populateStdToLLVMBarePtrFuncOpConversionPattern(converter, patterns);
-  populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
-  populateStdToLLVMMemoryConversionPatterns(converter, patterns,
-                                            useAlignedAlloc);
+    const LowerToLLVMOptions &options) {
+  populateStdToLLVMFuncOpConversionPattern(converter, patterns, options);
+  populateStdToLLVMNonMemoryConversionPatterns(converter, patterns, options);
+  populateStdToLLVMMemoryConversionPatterns(converter, patterns, options);
 }

 // Create an LLVM IR structure type if there is more than one result.
@@ -3163,19 +3144,12 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {

     ModuleOp m = getOperation();

-    LLVMTypeConverterCustomization customs;
-    customs.funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
-                                                  : structFuncArgTypeConverter;
-    customs.indexBitwidth = indexBitwidth;
-    LLVMTypeConverter typeConverter(&getContext(), customs);
+    LowerToLLVMOptions options = {useBarePtrCallConv, emitCWrappers,
+                                  indexBitwidth, useAlignedAlloc};
+    LLVMTypeConverter typeConverter(&getContext(), options);

     OwningRewritePatternList patterns;
-    if (useBarePtrCallConv)
-      populateStdToLLVMBarePtrConversionPatterns(typeConverter, patterns,
-                                                 useAlignedAlloc);
-    else
-      populateStdToLLVMConversionPatterns(typeConverter, patterns,
-                                          emitCWrappers, useAlignedAlloc);
+    populateStdToLLVMConversionPatterns(typeConverter, patterns, options);

     LLVMConversionTarget target(getContext());
     if (failed(applyPartialConversion(m, target, patterns)))
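For downstream conversions, the change means a pattern forwards the options struct instead of individual boolean flags. A hypothetical skeleton following the signatures above is sketched below; MyOp and MyOpLowering are placeholder names, not part of this patch.

    // Hypothetical pattern skeleton showing how LowerToLLVMOptions is threaded
    // through ConvertOpToLLVMPattern after this change.
    struct MyOpLowering : public mlir::ConvertOpToLLVMPattern<MyOp> {
      MyOpLowering(mlir::LLVMTypeConverter &converter,
                   const mlir::LowerToLLVMOptions &options)
          : ConvertOpToLLVMPattern<MyOp>(converter, options) {}

      mlir::LogicalResult
      matchAndRewrite(mlir::Operation *op, llvm::ArrayRef<mlir::Value> operands,
                      mlir::ConversionPatternRewriter &rewriter) const override {
        // `options` is inherited from ConvertToLLVMPattern and can be consulted
        // here, e.g. options.useAlignedAlloc, alongside
        // typeConverter.getIndexTypeBitwidth() for the configured index width.
        return mlir::failure();
      }
    };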
@@ -1,36 +1,52 @@
 // RUN: mlir-opt %s -convert-gpu-to-nvvm -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-gpu-to-nvvm='index-bitwidth=32' -split-input-file | FileCheck --check-prefix=CHECK32 %s

 gpu.module @test_module {
   // CHECK-LABEL: func @gpu_index_ops()
+  // CHECK32-LABEL: func @gpu_index_ops()
   func @gpu_index_ops()
       -> (index, index, index, index, index, index,
           index, index, index, index, index, index) {
+    // CHECK32-NOT: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64

     // CHECK: = nvvm.read.ptx.sreg.tid.x : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %tIdX = "gpu.thread_id"() {dimension = "x"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.tid.y : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %tIdY = "gpu.thread_id"() {dimension = "y"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.tid.z : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %tIdZ = "gpu.thread_id"() {dimension = "z"} : () -> (index)

     // CHECK: = nvvm.read.ptx.sreg.ntid.x : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %bDimX = "gpu.block_dim"() {dimension = "x"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.ntid.y : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %bDimY = "gpu.block_dim"() {dimension = "y"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.ntid.z : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %bDimZ = "gpu.block_dim"() {dimension = "z"} : () -> (index)

     // CHECK: = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %bIdX = "gpu.block_id"() {dimension = "x"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %bIdY = "gpu.block_id"() {dimension = "y"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %bIdZ = "gpu.block_id"() {dimension = "z"} : () -> (index)

     // CHECK: = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %gDimX = "gpu.grid_dim"() {dimension = "x"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %gDimY = "gpu.grid_dim"() {dimension = "y"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %gDimZ = "gpu.grid_dim"() {dimension = "z"} : () -> (index)

     std.return %tIdX, %tIdY, %tIdZ, %bDimX, %bDimY, %bDimZ,
@@ -42,6 +58,21 @@ gpu.module @test_module {

 // -----

+gpu.module @test_module {
+  // CHECK-LABEL: func @gpu_index_comp
+  // CHECK32-LABEL: func @gpu_index_comp
+  func @gpu_index_comp(%idx : index) -> index {
+    // CHECK: = llvm.add %{{.*}}, %{{.*}} : !llvm.i64
+    // CHECK32: = llvm.add %{{.*}}, %{{.*}} : !llvm.i32
+    %0 = addi %idx, %idx : index
+    // CHECK: llvm.return %{{.*}} : !llvm.i64
+    // CHECK32: llvm.return %{{.*}} : !llvm.i32
+    std.return %0 : index
+  }
+}
+
+// -----
+
 gpu.module @test_module {
   // CHECK-LABEL: func @gpu_all_reduce_op()
   gpu.func @gpu_all_reduce_op() {
@@ -1,36 +1,52 @@
 // RUN: mlir-opt %s -convert-gpu-to-rocdl -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-gpu-to-rocdl='index-bitwidth=32' -split-input-file | FileCheck --check-prefix=CHECK32 %s

 gpu.module @test_module {
   // CHECK-LABEL: func @gpu_index_ops()
+  // CHECK32-LABEL: func @gpu_index_ops()
   func @gpu_index_ops()
       -> (index, index, index, index, index, index,
           index, index, index, index, index, index) {
+    // CHECK32-NOT: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64

     // CHECK: rocdl.workitem.id.x : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %tIdX = "gpu.thread_id"() {dimension = "x"} : () -> (index)
     // CHECK: rocdl.workitem.id.y : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %tIdY = "gpu.thread_id"() {dimension = "y"} : () -> (index)
     // CHECK: rocdl.workitem.id.z : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %tIdZ = "gpu.thread_id"() {dimension = "z"} : () -> (index)

     // CHECK: rocdl.workgroup.dim.x : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %bDimX = "gpu.block_dim"() {dimension = "x"} : () -> (index)
     // CHECK: rocdl.workgroup.dim.y : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %bDimY = "gpu.block_dim"() {dimension = "y"} : () -> (index)
     // CHECK: rocdl.workgroup.dim.z : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %bDimZ = "gpu.block_dim"() {dimension = "z"} : () -> (index)

     // CHECK: rocdl.workgroup.id.x : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %bIdX = "gpu.block_id"() {dimension = "x"} : () -> (index)
     // CHECK: rocdl.workgroup.id.y : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %bIdY = "gpu.block_id"() {dimension = "y"} : () -> (index)
     // CHECK: rocdl.workgroup.id.z : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %bIdZ = "gpu.block_id"() {dimension = "z"} : () -> (index)

     // CHECK: rocdl.grid.dim.x : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %gDimX = "gpu.grid_dim"() {dimension = "x"} : () -> (index)
     // CHECK: rocdl.grid.dim.y : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %gDimY = "gpu.grid_dim"() {dimension = "y"} : () -> (index)
     // CHECK: rocdl.grid.dim.z : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %gDimZ = "gpu.grid_dim"() {dimension = "z"} : () -> (index)

     std.return %tIdX, %tIdY, %tIdZ, %bDimX, %bDimY, %bDimZ,
@@ -42,6 +58,21 @@ gpu.module @test_module {

 // -----

+gpu.module @test_module {
+  // CHECK-LABEL: func @gpu_index_comp
+  // CHECK32-LABEL: func @gpu_index_comp
+  func @gpu_index_comp(%idx : index) -> index {
+    // CHECK: = llvm.add %{{.*}}, %{{.*}} : !llvm.i64
+    // CHECK32: = llvm.add %{{.*}}, %{{.*}} : !llvm.i32
+    %0 = addi %idx, %idx : index
+    // CHECK: llvm.return %{{.*}} : !llvm.i64
+    // CHECK32: llvm.return %{{.*}} : !llvm.i32
+    std.return %0 : index
+  }
+}
+
+// -----
+
 gpu.module @test_module {
   // CHECK-LABEL: func @gpu_sync()
   func @gpu_sync() {