mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-27 15:41:46 +00:00
[NFC] Use ConvertOpToLLVMPattern instead of ConvertToLLVMPattern.
- use ConvertOpToLLVMPattern to avoid explicit casting and in most cases the constructor can be reused to save a few lines of code. Differential Revision: https://reviews.llvm.org/D92989
This commit is contained in:
parent
a1ae3c6ac9
commit
563879b6f9
@ -18,8 +18,7 @@ template <typename T> class OperationPass;
|
||||
|
||||
/// Populate the given list with patterns that convert from Linalg to LLVM.
|
||||
void populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
OwningRewritePatternList &patterns,
|
||||
MLIRContext *ctx);
|
||||
OwningRewritePatternList &patterns);
|
||||
|
||||
/// Create a pass to convert Linalg operations to the LLVMIR dialect.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertLinalgToLLVMPass();
|
||||
|
@ -19,8 +19,7 @@ class OperationPass;
|
||||
class OwningRewritePatternList;
|
||||
|
||||
/// Populate the given list with patterns that convert from OpenMP to LLVM.
|
||||
void populateOpenMPToLLVMConversionPatterns(MLIRContext *context,
|
||||
LLVMTypeConverter &converter,
|
||||
void populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
OwningRewritePatternList &patterns);
|
||||
|
||||
/// Create a pass to convert OpenMP operations to the LLVMIR dialect.
|
||||
|
@ -565,8 +565,8 @@ protected:
|
||||
template <typename SourceOp>
|
||||
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
|
||||
public:
|
||||
ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
explicit ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertToLLVMPattern(SourceOp::getOperationName(),
|
||||
&typeConverter.getContext(), typeConverter,
|
||||
benefit) {}
|
||||
|
@ -34,8 +34,7 @@ static Type getSrcVectorElementType(OpTy op) {
|
||||
/// operands as is, preserve attributes.
|
||||
template <typename SourceOp, typename TargetOp>
|
||||
static LogicalResult
|
||||
matchAndRewriteOneToOne(const ConvertToLLVMPattern &lowering,
|
||||
LLVMTypeConverter &typeConverter, Operation *op,
|
||||
matchAndRewriteOneToOne(LLVMTypeConverter &typeConverter, Operation *op,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
unsigned numResults = op->getNumResults();
|
||||
@ -73,71 +72,61 @@ namespace {
|
||||
// TODO: Patterns are too verbose due to the fact that we have 1 op (e.g.
|
||||
// MaskRndScaleOp) and different possible target ops. It would be better to take
|
||||
// a Functor so that all these conversions become 1-liners.
|
||||
struct MaskRndScaleOpPS512Conversion : public ConvertToLLVMPattern {
|
||||
explicit MaskRndScaleOpPS512Conversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
struct MaskRndScaleOpPS512Conversion
|
||||
: public ConvertOpToLLVMPattern<MaskRndScaleOp> {
|
||||
using ConvertOpToLLVMPattern<MaskRndScaleOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(MaskRndScaleOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (!getSrcVectorElementType(cast<MaskRndScaleOp>(op)).isF32())
|
||||
if (!getSrcVectorElementType(op).isF32())
|
||||
return failure();
|
||||
return matchAndRewriteOneToOne<MaskRndScaleOp,
|
||||
LLVM::x86_avx512_mask_rndscale_ps_512>(
|
||||
*this, *getTypeConverter(), op, operands, rewriter);
|
||||
*getTypeConverter(), op, operands, rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
struct MaskRndScaleOpPD512Conversion : public ConvertToLLVMPattern {
|
||||
explicit MaskRndScaleOpPD512Conversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
struct MaskRndScaleOpPD512Conversion
|
||||
: public ConvertOpToLLVMPattern<MaskRndScaleOp> {
|
||||
using ConvertOpToLLVMPattern<MaskRndScaleOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(MaskRndScaleOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (!getSrcVectorElementType(cast<MaskRndScaleOp>(op)).isF64())
|
||||
if (!getSrcVectorElementType(op).isF64())
|
||||
return failure();
|
||||
return matchAndRewriteOneToOne<MaskRndScaleOp,
|
||||
LLVM::x86_avx512_mask_rndscale_pd_512>(
|
||||
*this, *getTypeConverter(), op, operands, rewriter);
|
||||
*getTypeConverter(), op, operands, rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
struct ScaleFOpPS512Conversion : public ConvertToLLVMPattern {
|
||||
explicit ScaleFOpPS512Conversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
struct ScaleFOpPS512Conversion : public ConvertOpToLLVMPattern<MaskScaleFOp> {
|
||||
using ConvertOpToLLVMPattern<MaskScaleFOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(MaskScaleFOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (!getSrcVectorElementType(cast<MaskScaleFOp>(op)).isF32())
|
||||
if (!getSrcVectorElementType(op).isF32())
|
||||
return failure();
|
||||
return matchAndRewriteOneToOne<MaskScaleFOp,
|
||||
LLVM::x86_avx512_mask_scalef_ps_512>(
|
||||
*this, *getTypeConverter(), op, operands, rewriter);
|
||||
*getTypeConverter(), op, operands, rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
struct ScaleFOpPD512Conversion : public ConvertToLLVMPattern {
|
||||
explicit ScaleFOpPD512Conversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
struct ScaleFOpPD512Conversion : public ConvertOpToLLVMPattern<MaskScaleFOp> {
|
||||
using ConvertOpToLLVMPattern<MaskScaleFOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(MaskScaleFOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (!getSrcVectorElementType(cast<MaskScaleFOp>(op)).isF64())
|
||||
if (!getSrcVectorElementType(op).isF64())
|
||||
return failure();
|
||||
return matchAndRewriteOneToOne<MaskScaleFOp,
|
||||
LLVM::x86_avx512_mask_scalef_pd_512>(
|
||||
*this, *getTypeConverter(), op, operands, rewriter);
|
||||
*getTypeConverter(), op, operands, rewriter);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
@ -145,11 +134,10 @@ struct ScaleFOpPD512Conversion : public ConvertToLLVMPattern {
|
||||
/// Populate the given list with patterns that convert from AVX512 to LLVM.
|
||||
void mlir::populateAVX512ToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
||||
MLIRContext *ctx = converter.getDialect()->getContext();
|
||||
// clang-format off
|
||||
patterns.insert<MaskRndScaleOpPS512Conversion,
|
||||
MaskRndScaleOpPD512Conversion,
|
||||
ScaleFOpPS512Conversion,
|
||||
ScaleFOpPD512Conversion>(ctx, converter);
|
||||
ScaleFOpPD512Conversion>(converter);
|
||||
// clang-format on
|
||||
}
|
||||
|
@ -18,17 +18,13 @@
|
||||
namespace mlir {
|
||||
|
||||
template <unsigned AllocaAddrSpace>
|
||||
struct GPUFuncOpLowering : ConvertToLLVMPattern {
|
||||
explicit GPUFuncOpLowering(LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(gpu::GPUFuncOp::getOperationName(),
|
||||
typeConverter.getDialect()->getContext(),
|
||||
typeConverter) {}
|
||||
struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
|
||||
using ConvertOpToLLVMPattern<gpu::GPUFuncOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
assert(operands.empty() && "func op is not expected to have operands");
|
||||
auto gpuFuncOp = cast<gpu::GPUFuncOp>(op);
|
||||
Location loc = gpuFuncOp.getLoc();
|
||||
|
||||
SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
|
||||
@ -154,14 +150,11 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
|
||||
}
|
||||
};
|
||||
|
||||
struct GPUReturnOpLowering : public ConvertToLLVMPattern {
|
||||
GPUReturnOpLowering(LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(gpu::ReturnOp::getOperationName(),
|
||||
typeConverter.getDialect()->getContext(),
|
||||
typeConverter) {}
|
||||
struct GPUReturnOpLowering : public ConvertOpToLLVMPattern<gpu::ReturnOp> {
|
||||
using ConvertOpToLLVMPattern<gpu::ReturnOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(gpu::ReturnOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
|
||||
return success();
|
||||
|
@ -21,7 +21,7 @@ namespace mlir {
|
||||
// `indexBitwidth`, sign-extend or truncate the resulting value to match the
|
||||
// bitwidth expected by the consumers of the value.
|
||||
template <typename Op, typename XOp, typename YOp, typename ZOp>
|
||||
struct GPUIndexIntrinsicOpLowering : public ConvertToLLVMPattern {
|
||||
struct GPUIndexIntrinsicOpLowering : public ConvertOpToLLVMPattern<Op> {
|
||||
private:
|
||||
enum dimension { X = 0, Y = 1, Z = 2, invalid };
|
||||
unsigned indexBitwidth;
|
||||
@ -36,19 +36,17 @@ private:
|
||||
|
||||
public:
|
||||
explicit GPUIndexIntrinsicOpLowering(LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(Op::getOperationName(),
|
||||
typeConverter.getDialect()->getContext(),
|
||||
typeConverter),
|
||||
: ConvertOpToLLVMPattern<Op>(typeConverter),
|
||||
indexBitwidth(typeConverter.getIndexTypeBitwidth()) {}
|
||||
|
||||
// Convert the kernel arguments to an LLVM type, preserve the rest.
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(Op op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
MLIRContext *context = rewriter.getContext();
|
||||
Value newOp;
|
||||
switch (dimensionToIndex(cast<Op>(op))) {
|
||||
switch (dimensionToIndex(op)) {
|
||||
case X:
|
||||
newOp = rewriter.create<XOp>(loc, LLVM::LLVMType::getInt32Ty(context));
|
||||
break;
|
||||
|
@ -29,16 +29,15 @@ namespace mlir {
|
||||
/// will be transformed into
|
||||
/// llvm.call @__nv_expf(%arg_f32) : (!llvm.float) -> !llvm.float
|
||||
template <typename SourceOp>
|
||||
struct OpToFuncCallLowering : public ConvertToLLVMPattern {
|
||||
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
|
||||
public:
|
||||
explicit OpToFuncCallLowering(LLVMTypeConverter &lowering_, StringRef f32Func,
|
||||
StringRef f64Func)
|
||||
: ConvertToLLVMPattern(SourceOp::getOperationName(),
|
||||
lowering_.getDialect()->getContext(), lowering_),
|
||||
f32Func(f32Func), f64Func(f64Func) {}
|
||||
: ConvertOpToLLVMPattern<SourceOp>(lowering_), f32Func(f32Func),
|
||||
f64Func(f64Func) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
using LLVM::LLVMFuncOp;
|
||||
using LLVM::LLVMType;
|
||||
|
@ -31,10 +31,8 @@ using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
|
||||
explicit GPUShuffleOpLowering(LLVMTypeConverter &lowering_)
|
||||
: ConvertToLLVMPattern(gpu::ShuffleOp::getOperationName(),
|
||||
lowering_.getDialect()->getContext(), lowering_) {}
|
||||
struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
|
||||
using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
/// Lowers a shuffle to the corresponding NVVM op.
|
||||
///
|
||||
@ -53,7 +51,7 @@ struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
|
||||
/// %shfl_pred = llvm.extractvalue %shfl[1 : index] :
|
||||
/// !llvm<"{ float, i1 }">
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(gpu::ShuffleOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
gpu::ShuffleOpAdaptor adaptor(operands);
|
||||
|
@ -126,19 +126,17 @@ private:
|
||||
};
|
||||
|
||||
// RangeOp creates a new range descriptor.
|
||||
class RangeOpConversion : public ConvertToLLVMPattern {
|
||||
class RangeOpConversion : public ConvertOpToLLVMPattern<RangeOp> {
|
||||
public:
|
||||
explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
|
||||
: ConvertToLLVMPattern(RangeOp::getOperationName(), context, lowering_) {}
|
||||
using ConvertOpToLLVMPattern<RangeOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(RangeOp rangeOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto rangeOp = cast<RangeOp>(op);
|
||||
auto rangeDescriptorTy = convertRangeType(
|
||||
rangeOp.getType().cast<RangeType>(), *getTypeConverter());
|
||||
|
||||
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
edsc::ScopedContext context(rewriter, rangeOp->getLoc());
|
||||
|
||||
// Fill in an aggregate value of the descriptor.
|
||||
RangeOpAdaptor adaptor(operands);
|
||||
@ -146,7 +144,7 @@ public:
|
||||
desc = llvm_insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0));
|
||||
desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1));
|
||||
desc = llvm_insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2));
|
||||
rewriter.replaceOp(op, desc);
|
||||
rewriter.replaceOp(rangeOp, desc);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@ -154,17 +152,13 @@ public:
|
||||
// ReshapeOp creates a new view descriptor of the proper rank.
|
||||
// For now, the only conversion supported is for target MemRef with static sizes
|
||||
// and strides.
|
||||
class ReshapeOpConversion : public ConvertToLLVMPattern {
|
||||
class ReshapeOpConversion : public ConvertOpToLLVMPattern<ReshapeOp> {
|
||||
public:
|
||||
explicit ReshapeOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &lowering_)
|
||||
: ConvertToLLVMPattern(ReshapeOp::getOperationName(), context,
|
||||
lowering_) {}
|
||||
using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(ReshapeOp reshapeOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto reshapeOp = cast<ReshapeOp>(op);
|
||||
MemRefType dstType = reshapeOp.getResultType();
|
||||
|
||||
if (!dstType.hasStaticShape())
|
||||
@ -178,7 +172,7 @@ public:
|
||||
}))
|
||||
return failure();
|
||||
|
||||
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
edsc::ScopedContext context(rewriter, reshapeOp->getLoc());
|
||||
ReshapeOpAdaptor adaptor(operands);
|
||||
BaseViewConversionHelper baseDesc(adaptor.src());
|
||||
BaseViewConversionHelper desc(typeConverter->convertType(dstType));
|
||||
@ -189,7 +183,7 @@ public:
|
||||
desc.setConstantSize(en.index(), en.value());
|
||||
for (auto en : llvm::enumerate(strides))
|
||||
desc.setConstantStride(en.index(), en.value());
|
||||
rewriter.replaceOp(op, {desc});
|
||||
rewriter.replaceOp(reshapeOp, {desc});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@ -200,19 +194,17 @@ public:
|
||||
/// and stride corresponding to the region of memory within the bounds of
|
||||
/// the parent view.
|
||||
/// The linalg.slice op is replaced by the alloca'ed pointer.
|
||||
class SliceOpConversion : public ConvertToLLVMPattern {
|
||||
class SliceOpConversion : public ConvertOpToLLVMPattern<SliceOp> {
|
||||
public:
|
||||
explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
|
||||
: ConvertToLLVMPattern(SliceOp::getOperationName(), context, lowering_) {}
|
||||
using ConvertOpToLLVMPattern<SliceOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(SliceOp sliceOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
edsc::ScopedContext context(rewriter, sliceOp->getLoc());
|
||||
SliceOpAdaptor adaptor(operands);
|
||||
BaseViewConversionHelper baseDesc(adaptor.view());
|
||||
|
||||
auto sliceOp = cast<SliceOp>(op);
|
||||
auto memRefType = sliceOp.getBaseViewType();
|
||||
auto int64Ty = typeConverter->convertType(rewriter.getIntegerType(64))
|
||||
.cast<LLVM::LLVMType>();
|
||||
@ -248,7 +240,7 @@ public:
|
||||
|
||||
// Corner case, no sizes or strides: early return the descriptor.
|
||||
if (sliceOp.getShapedType().getRank() == 0)
|
||||
return rewriter.replaceOp(op, {desc}), success();
|
||||
return rewriter.replaceOp(sliceOp, {desc}), success();
|
||||
|
||||
Value zero = llvm_constant(
|
||||
int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
|
||||
@ -279,20 +271,18 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, {desc});
|
||||
rewriter.replaceOp(sliceOp, {desc});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// YieldOp produces and LLVM::ReturnOp.
|
||||
class YieldOpConversion : public ConvertToLLVMPattern {
|
||||
class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
|
||||
public:
|
||||
explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
|
||||
: ConvertToLLVMPattern(linalg::YieldOp::getOperationName(), context,
|
||||
lowering_) {}
|
||||
using ConvertOpToLLVMPattern<linalg::YieldOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(linalg::YieldOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
|
||||
return success();
|
||||
@ -302,10 +292,9 @@ public:
|
||||
|
||||
/// Populate the given list with patterns that convert from Linalg to LLVM.
|
||||
void mlir::populateLinalgToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
|
||||
MLIRContext *ctx) {
|
||||
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
||||
patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
|
||||
YieldOpConversion>(ctx, converter);
|
||||
YieldOpConversion>(converter);
|
||||
|
||||
// Populate the type conversions for the linalg types.
|
||||
converter.addConversion(
|
||||
@ -331,7 +320,7 @@ void ConvertLinalgToLLVMPass::runOnOperation() {
|
||||
populateVectorToSCFConversionPatterns(patterns, &getContext());
|
||||
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
|
||||
populateVectorToLLVMConversionPatterns(converter, patterns);
|
||||
populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
|
||||
populateLinalgToLLVMConversionPatterns(converter, patterns);
|
||||
|
||||
LLVMConversionTarget target(getContext());
|
||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||
|
@ -21,34 +21,30 @@ namespace {
|
||||
/// expected to either be processed by the conversion infrastructure or already
|
||||
/// contain ops compatible with LLVM dialect types.
|
||||
template <typename OpType>
|
||||
struct RegionOpConversion : public ConvertToLLVMPattern {
|
||||
explicit RegionOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(OpType::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
|
||||
using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(OpType curOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto curOp = cast<OpType>(op);
|
||||
auto newOp = rewriter.create<OpType>(curOp.getLoc(), TypeRange(), operands,
|
||||
curOp.getAttrs());
|
||||
rewriter.inlineRegionBefore(curOp.region(), newOp.region(),
|
||||
newOp.region().end());
|
||||
if (failed(rewriter.convertRegionTypes(&newOp.region(), *typeConverter)))
|
||||
if (failed(rewriter.convertRegionTypes(&newOp.region(),
|
||||
*this->getTypeConverter())))
|
||||
return failure();
|
||||
|
||||
rewriter.eraseOp(op);
|
||||
rewriter.eraseOp(curOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::populateOpenMPToLLVMConversionPatterns(
|
||||
MLIRContext *context, LLVMTypeConverter &converter,
|
||||
OwningRewritePatternList &patterns) {
|
||||
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
||||
patterns.insert<RegionOpConversion<omp::ParallelOp>,
|
||||
RegionOpConversion<omp::WsLoopOp>>(context, converter);
|
||||
RegionOpConversion<omp::WsLoopOp>>(converter);
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -60,13 +56,12 @@ struct ConvertOpenMPToLLVMPass
|
||||
|
||||
void ConvertOpenMPToLLVMPass::runOnOperation() {
|
||||
auto module = getOperation();
|
||||
MLIRContext *context = &getContext();
|
||||
|
||||
// Convert to OpenMP operations with LLVM IR dialect
|
||||
OwningRewritePatternList patterns;
|
||||
LLVMTypeConverter converter(&getContext());
|
||||
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||
populateOpenMPToLLVMConversionPatterns(context, converter, patterns);
|
||||
populateOpenMPToLLVMConversionPatterns(converter, patterns);
|
||||
|
||||
LLVMConversionTarget target(getContext());
|
||||
target.addDynamicallyLegalOp<omp::ParallelOp, omp::WsLoopOp>(
|
||||
|
@ -296,39 +296,33 @@ namespace {
|
||||
|
||||
/// Conversion pattern for a vector.matrix_multiply.
|
||||
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
|
||||
class VectorMatmulOpConversion : public ConvertToLLVMPattern {
|
||||
class VectorMatmulOpConversion
|
||||
: public ConvertOpToLLVMPattern<vector::MatmulOp> {
|
||||
public:
|
||||
explicit VectorMatmulOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto matmulOp = cast<vector::MatmulOp>(op);
|
||||
auto adaptor = vector::MatmulOpAdaptor(operands);
|
||||
rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
|
||||
op, typeConverter->convertType(matmulOp.res().getType()), adaptor.lhs(),
|
||||
adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
|
||||
matmulOp.rhs_columns());
|
||||
matmulOp, typeConverter->convertType(matmulOp.res().getType()),
|
||||
adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
|
||||
matmulOp.lhs_columns(), matmulOp.rhs_columns());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Conversion pattern for a vector.flat_transpose.
|
||||
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
|
||||
class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
|
||||
class VectorFlatTransposeOpConversion
|
||||
: public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
|
||||
public:
|
||||
explicit VectorFlatTransposeOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
|
||||
context, typeConverter) {}
|
||||
using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto transOp = cast<vector::FlatTransposeOp>(op);
|
||||
auto adaptor = vector::FlatTransposeOpAdaptor(operands);
|
||||
rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
|
||||
transOp, typeConverter->convertType(transOp.res().getType()),
|
||||
@ -338,18 +332,15 @@ public:
|
||||
};
|
||||
|
||||
/// Conversion pattern for a vector.maskedload.
|
||||
class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern {
|
||||
class VectorMaskedLoadOpConversion
|
||||
: public ConvertOpToLLVMPattern<vector::MaskedLoadOp> {
|
||||
public:
|
||||
explicit VectorMaskedLoadOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto load = cast<vector::MaskedLoadOp>(op);
|
||||
auto loc = load->getLoc();
|
||||
auto adaptor = vector::MaskedLoadOpAdaptor(operands);
|
||||
|
||||
// Resolve alignment.
|
||||
@ -371,18 +362,15 @@ public:
|
||||
};
|
||||
|
||||
/// Conversion pattern for a vector.maskedstore.
|
||||
class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern {
|
||||
class VectorMaskedStoreOpConversion
|
||||
: public ConvertOpToLLVMPattern<vector::MaskedStoreOp> {
|
||||
public:
|
||||
explicit VectorMaskedStoreOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto store = cast<vector::MaskedStoreOp>(op);
|
||||
auto loc = store->getLoc();
|
||||
auto adaptor = vector::MaskedStoreOpAdaptor(operands);
|
||||
|
||||
// Resolve alignment.
|
||||
@ -404,18 +392,15 @@ public:
|
||||
};
|
||||
|
||||
/// Conversion pattern for a vector.gather.
|
||||
class VectorGatherOpConversion : public ConvertToLLVMPattern {
|
||||
class VectorGatherOpConversion
|
||||
: public ConvertOpToLLVMPattern<vector::GatherOp> {
|
||||
public:
|
||||
explicit VectorGatherOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto gather = cast<vector::GatherOp>(op);
|
||||
auto loc = gather->getLoc();
|
||||
auto adaptor = vector::GatherOpAdaptor(operands);
|
||||
|
||||
// Resolve alignment.
|
||||
@ -440,18 +425,15 @@ public:
|
||||
};
|
||||
|
||||
/// Conversion pattern for a vector.scatter.
|
||||
class VectorScatterOpConversion : public ConvertToLLVMPattern {
|
||||
class VectorScatterOpConversion
|
||||
: public ConvertOpToLLVMPattern<vector::ScatterOp> {
|
||||
public:
|
||||
explicit VectorScatterOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto scatter = cast<vector::ScatterOp>(op);
|
||||
auto loc = scatter->getLoc();
|
||||
auto adaptor = vector::ScatterOpAdaptor(operands);
|
||||
|
||||
// Resolve alignment.
|
||||
@ -476,18 +458,15 @@ public:
|
||||
};
|
||||
|
||||
/// Conversion pattern for a vector.expandload.
|
||||
class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
|
||||
class VectorExpandLoadOpConversion
|
||||
: public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
|
||||
public:
|
||||
explicit VectorExpandLoadOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto expand = cast<vector::ExpandLoadOp>(op);
|
||||
auto loc = expand->getLoc();
|
||||
auto adaptor = vector::ExpandLoadOpAdaptor(operands);
|
||||
|
||||
Value ptr;
|
||||
@ -497,25 +476,22 @@ public:
|
||||
|
||||
auto vType = expand.getResultVectorType();
|
||||
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
|
||||
op, typeConverter->convertType(vType), ptr, adaptor.mask(),
|
||||
expand, typeConverter->convertType(vType), ptr, adaptor.mask(),
|
||||
adaptor.pass_thru());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Conversion pattern for a vector.compressstore.
|
||||
class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
|
||||
class VectorCompressStoreOpConversion
|
||||
: public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
|
||||
public:
|
||||
explicit VectorCompressStoreOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
|
||||
context, typeConverter) {}
|
||||
using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto compress = cast<vector::CompressStoreOp>(op);
|
||||
auto loc = compress->getLoc();
|
||||
auto adaptor = vector::CompressStoreOpAdaptor(operands);
|
||||
|
||||
Value ptr;
|
||||
@ -524,25 +500,23 @@ public:
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
|
||||
op, adaptor.value(), ptr, adaptor.mask());
|
||||
compress, adaptor.value(), ptr, adaptor.mask());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Conversion pattern for all vector reductions.
|
||||
class VectorReductionOpConversion : public ConvertToLLVMPattern {
|
||||
class VectorReductionOpConversion
|
||||
: public ConvertOpToLLVMPattern<vector::ReductionOp> {
|
||||
public:
|
||||
explicit VectorReductionOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter,
|
||||
explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
|
||||
bool reassociateFPRed)
|
||||
: ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
|
||||
typeConverter),
|
||||
: ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
|
||||
reassociateFPReductions(reassociateFPRed) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto reductionOp = cast<vector::ReductionOp>(op);
|
||||
auto kind = reductionOp.kind();
|
||||
Type eltType = reductionOp.dest().getType();
|
||||
Type llvmType = typeConverter->convertType(eltType);
|
||||
@ -550,33 +524,33 @@ public:
|
||||
// Integer reductions: add/mul/min/max/and/or/xor.
|
||||
if (kind == "add")
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
|
||||
op, llvmType, operands[0]);
|
||||
reductionOp, llvmType, operands[0]);
|
||||
else if (kind == "mul")
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
|
||||
op, llvmType, operands[0]);
|
||||
reductionOp, llvmType, operands[0]);
|
||||
else if (kind == "min" &&
|
||||
(eltType.isIndex() || eltType.isUnsignedInteger()))
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
|
||||
op, llvmType, operands[0]);
|
||||
reductionOp, llvmType, operands[0]);
|
||||
else if (kind == "min")
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
|
||||
op, llvmType, operands[0]);
|
||||
reductionOp, llvmType, operands[0]);
|
||||
else if (kind == "max" &&
|
||||
(eltType.isIndex() || eltType.isUnsignedInteger()))
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
|
||||
op, llvmType, operands[0]);
|
||||
reductionOp, llvmType, operands[0]);
|
||||
else if (kind == "max")
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
|
||||
op, llvmType, operands[0]);
|
||||
reductionOp, llvmType, operands[0]);
|
||||
else if (kind == "and")
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
|
||||
op, llvmType, operands[0]);
|
||||
reductionOp, llvmType, operands[0]);
|
||||
else if (kind == "or")
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
|
||||
op, llvmType, operands[0]);
|
||||
reductionOp, llvmType, operands[0]);
|
||||
else if (kind == "xor")
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
|
||||
op, llvmType, operands[0]);
|
||||
reductionOp, llvmType, operands[0]);
|
||||
else
|
||||
return failure();
|
||||
return success();
|
||||
@ -590,27 +564,27 @@ public:
|
||||
// Optional accumulator (or zero).
|
||||
Value acc = operands.size() > 1 ? operands[1]
|
||||
: rewriter.create<LLVM::ConstantOp>(
|
||||
op->getLoc(), llvmType,
|
||||
reductionOp->getLoc(), llvmType,
|
||||
rewriter.getZeroAttr(eltType));
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
|
||||
op, llvmType, acc, operands[0],
|
||||
reductionOp, llvmType, acc, operands[0],
|
||||
rewriter.getBoolAttr(reassociateFPReductions));
|
||||
} else if (kind == "mul") {
|
||||
// Optional accumulator (or one).
|
||||
Value acc = operands.size() > 1
|
||||
? operands[1]
|
||||
: rewriter.create<LLVM::ConstantOp>(
|
||||
op->getLoc(), llvmType,
|
||||
reductionOp->getLoc(), llvmType,
|
||||
rewriter.getFloatAttr(eltType, 1.0));
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
|
||||
op, llvmType, acc, operands[0],
|
||||
reductionOp, llvmType, acc, operands[0],
|
||||
rewriter.getBoolAttr(reassociateFPReductions));
|
||||
} else if (kind == "min")
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(op, llvmType,
|
||||
operands[0]);
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
|
||||
reductionOp, llvmType, operands[0]);
|
||||
else if (kind == "max")
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(op, llvmType,
|
||||
operands[0]);
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
|
||||
reductionOp, llvmType, operands[0]);
|
||||
else
|
||||
return failure();
|
||||
return success();
|
||||
@ -621,17 +595,16 @@ private:
|
||||
};
|
||||
|
||||
/// Conversion pattern for a vector.create_mask (1-D only).
|
||||
class VectorCreateMaskOpConversion : public ConvertToLLVMPattern {
|
||||
class VectorCreateMaskOpConversion
|
||||
: public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
|
||||
public:
|
||||
explicit VectorCreateMaskOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter,
|
||||
explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
|
||||
bool enableIndexOpt)
|
||||
: ConvertToLLVMPattern(vector::CreateMaskOp::getOperationName(), context,
|
||||
typeConverter),
|
||||
: ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
|
||||
enableIndexOptimizations(enableIndexOpt) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto dstType = op->getResult(0).getType().cast<VectorType>();
|
||||
int64_t rank = dstType.getRank();
|
||||
@ -648,19 +621,16 @@ private:
|
||||
const bool enableIndexOptimizations;
|
||||
};
|
||||
|
||||
class VectorShuffleOpConversion : public ConvertToLLVMPattern {
|
||||
class VectorShuffleOpConversion
|
||||
: public ConvertOpToLLVMPattern<vector::ShuffleOp> {
|
||||
public:
|
||||
explicit VectorShuffleOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto loc = shuffleOp->getLoc();
|
||||
auto adaptor = vector::ShuffleOpAdaptor(operands);
|
||||
auto shuffleOp = cast<vector::ShuffleOp>(op);
|
||||
auto v1Type = shuffleOp.getV1VectorType();
|
||||
auto v2Type = shuffleOp.getV2VectorType();
|
||||
auto vectorType = shuffleOp.getVectorType();
|
||||
@ -680,9 +650,9 @@ public:
|
||||
// For rank 1, where both operands have *exactly* the same vector type,
|
||||
// there is direct shuffle support in LLVM. Use it!
|
||||
if (rank == 1 && v1Type == v2Type) {
|
||||
Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
|
||||
Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
|
||||
loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
|
||||
rewriter.replaceOp(op, shuffle);
|
||||
rewriter.replaceOp(shuffleOp, llvmShuffleOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
@ -701,23 +671,22 @@ public:
|
||||
insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
|
||||
llvmType, rank, insPos++);
|
||||
}
|
||||
rewriter.replaceOp(op, insert);
|
||||
rewriter.replaceOp(shuffleOp, insert);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
|
||||
class VectorExtractElementOpConversion
|
||||
: public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
|
||||
public:
|
||||
explicit VectorExtractElementOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
|
||||
context, typeConverter) {}
|
||||
using ConvertOpToLLVMPattern<
|
||||
vector::ExtractElementOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::ExtractElementOp extractEltOp,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto adaptor = vector::ExtractElementOpAdaptor(operands);
|
||||
auto extractEltOp = cast<vector::ExtractElementOp>(op);
|
||||
auto vectorType = extractEltOp.getVectorType();
|
||||
auto llvmType = typeConverter->convertType(vectorType.getElementType());
|
||||
|
||||
@ -726,24 +695,21 @@ public:
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
|
||||
op, llvmType, adaptor.vector(), adaptor.position());
|
||||
extractEltOp, llvmType, adaptor.vector(), adaptor.position());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class VectorExtractOpConversion : public ConvertToLLVMPattern {
|
||||
class VectorExtractOpConversion
|
||||
: public ConvertOpToLLVMPattern<vector::ExtractOp> {
|
||||
public:
|
||||
explicit VectorExtractOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto loc = extractOp->getLoc();
|
||||
auto adaptor = vector::ExtractOpAdaptor(operands);
|
||||
auto extractOp = cast<vector::ExtractOp>(op);
|
||||
auto vectorType = extractOp.getVectorType();
|
||||
auto resultType = extractOp.getResult().getType();
|
||||
auto llvmResultType = typeConverter->convertType(resultType);
|
||||
@ -757,12 +723,12 @@ public:
|
||||
if (resultType.isa<VectorType>()) {
|
||||
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, llvmResultType, adaptor.vector(), positionArrayAttr);
|
||||
rewriter.replaceOp(op, extracted);
|
||||
rewriter.replaceOp(extractOp, extracted);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Potential extraction of 1-D vector from array.
|
||||
auto *context = op->getContext();
|
||||
auto *context = extractOp->getContext();
|
||||
Value extracted = adaptor.vector();
|
||||
auto positionAttrs = positionArrayAttr.getValue();
|
||||
if (positionAttrs.size() > 1) {
|
||||
@ -780,7 +746,7 @@ public:
|
||||
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
|
||||
extracted =
|
||||
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
|
||||
rewriter.replaceOp(op, extracted);
|
||||
rewriter.replaceOp(extractOp, extracted);
|
||||
|
||||
return success();
|
||||
}
|
||||
@ -800,39 +766,32 @@ public:
|
||||
/// (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
|
||||
/// -> !llvm<"<8 x float>">
|
||||
/// ```
|
||||
class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
|
||||
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
|
||||
public:
|
||||
explicit VectorFMAOp1DConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto adaptor = vector::FMAOpAdaptor(operands);
|
||||
vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
|
||||
VectorType vType = fmaOp.getVectorType();
|
||||
if (vType.getRank() != 1)
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(),
|
||||
rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
|
||||
adaptor.rhs(), adaptor.acc());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
|
||||
class VectorInsertElementOpConversion
|
||||
: public ConvertOpToLLVMPattern<vector::InsertElementOp> {
|
||||
public:
|
||||
explicit VectorInsertElementOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
|
||||
context, typeConverter) {}
|
||||
using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto adaptor = vector::InsertElementOpAdaptor(operands);
|
||||
auto insertEltOp = cast<vector::InsertElementOp>(op);
|
||||
auto vectorType = insertEltOp.getDestVectorType();
|
||||
auto llvmType = typeConverter->convertType(vectorType);
|
||||
|
||||
@ -841,24 +800,22 @@ public:
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
|
||||
op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
|
||||
insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
|
||||
adaptor.position());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class VectorInsertOpConversion : public ConvertToLLVMPattern {
|
||||
class VectorInsertOpConversion
|
||||
: public ConvertOpToLLVMPattern<vector::InsertOp> {
|
||||
public:
|
||||
explicit VectorInsertOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto loc = insertOp->getLoc();
|
||||
auto adaptor = vector::InsertOpAdaptor(operands);
|
||||
auto insertOp = cast<vector::InsertOp>(op);
|
||||
auto sourceType = insertOp.getSourceType();
|
||||
auto destVectorType = insertOp.getDestVectorType();
|
||||
auto llvmResultType = typeConverter->convertType(destVectorType);
|
||||
@ -873,12 +830,12 @@ public:
|
||||
Value inserted = rewriter.create<LLVM::InsertValueOp>(
|
||||
loc, llvmResultType, adaptor.dest(), adaptor.source(),
|
||||
positionArrayAttr);
|
||||
rewriter.replaceOp(op, inserted);
|
||||
rewriter.replaceOp(insertOp, inserted);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Potential extraction of 1-D vector from array.
|
||||
auto *context = op->getContext();
|
||||
auto *context = insertOp->getContext();
|
||||
Value extracted = adaptor.dest();
|
||||
auto positionAttrs = positionArrayAttr.getValue();
|
||||
auto position = positionAttrs.back().cast<IntegerAttr>();
|
||||
@ -908,7 +865,7 @@ public:
|
||||
nMinusOnePositionAttrs);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, inserted);
|
||||
rewriter.replaceOp(insertOp, inserted);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@ -1117,18 +1074,15 @@ computeContiguousStrides(MemRefType memRefType) {
|
||||
return strides;
|
||||
}
|
||||
|
||||
class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
|
||||
class VectorTypeCastOpConversion
|
||||
: public ConvertOpToLLVMPattern<vector::TypeCastOp> {
|
||||
public:
|
||||
explicit VectorTypeCastOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
|
||||
auto loc = castOp->getLoc();
|
||||
MemRefType sourceMemRefType =
|
||||
castOp.getOperand().getType().cast<MemRefType>();
|
||||
MemRefType targetMemRefType =
|
||||
@ -1195,7 +1149,7 @@ public:
|
||||
desc.setStride(rewriter, loc, index, stride);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, {desc});
|
||||
rewriter.replaceOp(castOp, {desc});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@ -1208,18 +1162,16 @@ public:
|
||||
/// 4. Create a mask where offsetVector is compared against memref upper bound.
|
||||
/// 5. Rewrite op as a masked read or write.
|
||||
template <typename ConcreteOp>
|
||||
class VectorTransferConversion : public ConvertToLLVMPattern {
|
||||
class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
|
||||
public:
|
||||
explicit VectorTransferConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConv,
|
||||
explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
|
||||
bool enableIndexOpt)
|
||||
: ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, typeConv),
|
||||
: ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
|
||||
enableIndexOptimizations(enableIndexOpt) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto xferOp = cast<ConcreteOp>(op);
|
||||
auto adaptor = getTransferOpAdapter(xferOp, operands);
|
||||
|
||||
if (xferOp.getVectorType().getRank() > 1 ||
|
||||
@ -1228,16 +1180,18 @@ public:
|
||||
if (xferOp.permutation_map() !=
|
||||
AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
|
||||
xferOp.getVectorType().getRank(),
|
||||
op->getContext()))
|
||||
xferOp->getContext()))
|
||||
return failure();
|
||||
// Only contiguous source tensors supported atm.
|
||||
auto strides = computeContiguousStrides(xferOp.getMemRefType());
|
||||
if (!strides)
|
||||
return failure();
|
||||
|
||||
auto toLLVMTy = [&](Type t) { return typeConverter->convertType(t); };
|
||||
auto toLLVMTy = [&](Type t) {
|
||||
return this->getTypeConverter()->convertType(t);
|
||||
};
|
||||
|
||||
Location loc = op->getLoc();
|
||||
Location loc = xferOp->getLoc();
|
||||
MemRefType memRefType = xferOp.getMemRefType();
|
||||
|
||||
if (auto memrefVectorElementType =
|
||||
@ -1267,8 +1221,8 @@ public:
|
||||
// addrspacecast shall be used when source/dst memrefs are not on
|
||||
// address space 0.
|
||||
// TODO: support alignment when possible.
|
||||
Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
|
||||
adaptor.indices(), rewriter);
|
||||
Value dataPtr = this->getStridedElementPtr(
|
||||
loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter);
|
||||
auto vecTy =
|
||||
toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
|
||||
Value vectorDataPtr;
|
||||
@ -1280,8 +1234,9 @@ public:
|
||||
loc, vecTy.getPointerTo(), dataPtr);
|
||||
|
||||
if (!xferOp.isMaskedDim(0))
|
||||
return replaceTransferOpWithLoadOrStore(
|
||||
rewriter, *getTypeConverter(), loc, xferOp, operands, vectorDataPtr);
|
||||
return replaceTransferOpWithLoadOrStore(rewriter,
|
||||
*this->getTypeConverter(), loc,
|
||||
xferOp, operands, vectorDataPtr);
|
||||
|
||||
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
|
||||
// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
|
||||
@ -1294,11 +1249,11 @@ public:
|
||||
unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
|
||||
Value off = xferOp.indices()[lastIndex];
|
||||
Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
|
||||
Value mask = buildVectorComparison(rewriter, op, enableIndexOptimizations,
|
||||
vecWidth, dim, &off);
|
||||
Value mask = buildVectorComparison(
|
||||
rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
|
||||
|
||||
// 5. Rewrite as a masked read / write.
|
||||
return replaceTransferOpWithMasked(rewriter, *getTypeConverter(), loc,
|
||||
return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
|
||||
xferOp, operands, vectorDataPtr, mask);
|
||||
}
|
||||
|
||||
@ -1306,12 +1261,9 @@ private:
|
||||
const bool enableIndexOptimizations;
|
||||
};
|
||||
|
||||
class VectorPrintOpConversion : public ConvertToLLVMPattern {
|
||||
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
|
||||
public:
|
||||
explicit VectorPrintOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
// Proof-of-concept lowering implementation that relies on a small
|
||||
// runtime support library, which only needs to provide a few
|
||||
@ -1326,9 +1278,8 @@ public:
|
||||
// TODO: rely solely on libc in future? something else?
|
||||
//
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto printOp = cast<vector::PrintOp>(op);
|
||||
auto adaptor = vector::PrintOpAdaptor(operands);
|
||||
Type printType = printOp.getPrintType();
|
||||
|
||||
@ -1341,11 +1292,11 @@ public:
|
||||
Type eltType = vectorType ? vectorType.getElementType() : printType;
|
||||
Operation *printer;
|
||||
if (eltType.isF32()) {
|
||||
printer = getPrintFloat(op);
|
||||
printer = getPrintFloat(printOp);
|
||||
} else if (eltType.isF64()) {
|
||||
printer = getPrintDouble(op);
|
||||
printer = getPrintDouble(printOp);
|
||||
} else if (eltType.isIndex()) {
|
||||
printer = getPrintU64(op);
|
||||
printer = getPrintU64(printOp);
|
||||
} else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
|
||||
// Integers need a zero or sign extension on the operand
|
||||
// (depending on the source type) as well as a signed or
|
||||
@ -1355,7 +1306,7 @@ public:
|
||||
if (width <= 64) {
|
||||
if (width < 64)
|
||||
conversion = PrintConversion::ZeroExt64;
|
||||
printer = getPrintU64(op);
|
||||
printer = getPrintU64(printOp);
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
@ -1368,7 +1319,7 @@ public:
|
||||
conversion = PrintConversion::ZeroExt64;
|
||||
else if (width < 64)
|
||||
conversion = PrintConversion::SignExt64;
|
||||
printer = getPrintI64(op);
|
||||
printer = getPrintI64(printOp);
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
@ -1379,10 +1330,10 @@ public:
|
||||
|
||||
// Unroll vector into elementary print calls.
|
||||
int64_t rank = vectorType ? vectorType.getRank() : 0;
|
||||
emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank,
|
||||
emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
|
||||
conversion);
|
||||
emitCall(rewriter, op->getLoc(), getPrintNewline(op));
|
||||
rewriter.eraseOp(op);
|
||||
emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp));
|
||||
rewriter.eraseOp(printOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
@ -1560,11 +1511,11 @@ void mlir::populateVectorToLLVMConversionPatterns(
|
||||
VectorInsertStridedSliceOpSameRankRewritePattern,
|
||||
VectorExtractStridedSliceOpConversion>(ctx);
|
||||
patterns.insert<VectorReductionOpConversion>(
|
||||
ctx, converter, reassociateFPReductions);
|
||||
converter, reassociateFPReductions);
|
||||
patterns.insert<VectorCreateMaskOpConversion,
|
||||
VectorTransferConversion<TransferReadOp>,
|
||||
VectorTransferConversion<TransferWriteOp>>(
|
||||
ctx, converter, enableIndexOptimizations);
|
||||
converter, enableIndexOptimizations);
|
||||
patterns
|
||||
.insert<VectorShuffleOpConversion,
|
||||
VectorExtractElementOpConversion,
|
||||
@ -1579,13 +1530,12 @@ void mlir::populateVectorToLLVMConversionPatterns(
|
||||
VectorGatherOpConversion,
|
||||
VectorScatterOpConversion,
|
||||
VectorExpandLoadOpConversion,
|
||||
VectorCompressStoreOpConversion>(ctx, converter);
|
||||
VectorCompressStoreOpConversion>(converter);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void mlir::populateVectorToLLVMMatrixConversionPatterns(
|
||||
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
||||
MLIRContext *ctx = converter.getDialect()->getContext();
|
||||
patterns.insert<VectorMatmulOpConversion>(ctx, converter);
|
||||
patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
|
||||
patterns.insert<VectorMatmulOpConversion>(converter);
|
||||
patterns.insert<VectorFlatTransposeOpConversion>(converter);
|
||||
}
|
||||
|
@ -55,17 +55,13 @@ namespace {
|
||||
/// types. For unsupported cases, they will fall back to the vector to
|
||||
/// llvm conversion pattern.
|
||||
template <typename ConcreteOp>
|
||||
class VectorTransferConversion : public ConvertToLLVMPattern {
|
||||
class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
|
||||
public:
|
||||
explicit VectorTransferConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConv)
|
||||
: ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
|
||||
typeConv) {}
|
||||
using ConvertOpToLLVMPattern<ConcreteOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto xferOp = cast<ConcreteOp>(op);
|
||||
typename ConcreteOp::Adaptor adaptor(operands);
|
||||
|
||||
if (xferOp.getVectorType().getRank() > 1 ||
|
||||
@ -79,11 +75,13 @@ public:
|
||||
if (!xferOp.isMaskedDim(0))
|
||||
return failure();
|
||||
|
||||
auto toLLVMTy = [&](Type t) { return typeConverter->convertType(t); };
|
||||
auto toLLVMTy = [&](Type t) {
|
||||
return this->getTypeConverter()->convertType(t);
|
||||
};
|
||||
LLVM::LLVMType vecTy =
|
||||
toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
|
||||
unsigned vecWidth = vecTy.getVectorNumElements();
|
||||
Location loc = op->getLoc();
|
||||
Location loc = xferOp->getLoc();
|
||||
|
||||
// The backend result vector scalarization have trouble scalarize
|
||||
// <1 x ty> result, exclude the x1 width from the lowering.
|
||||
@ -102,8 +100,8 @@ public:
|
||||
// Note that the dataPtr starts at the offset address specified by
|
||||
// indices, so no need to calculate offset size in bytes again in
|
||||
// the MUBUF instruction.
|
||||
Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
|
||||
adaptor.indices(), rewriter);
|
||||
Value dataPtr = this->getStridedElementPtr(
|
||||
loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter);
|
||||
|
||||
// 1. Create and fill a <4 x i32> dwordConfig with:
|
||||
// 1st two elements holding the address of dataPtr.
|
||||
@ -126,7 +124,7 @@ public:
|
||||
constConfig);
|
||||
Value dataPtrAsI64 = rewriter.create<LLVM::PtrToIntOp>(
|
||||
loc, toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), dataPtr);
|
||||
Value zero = createIndexConstant(rewriter, loc, 0);
|
||||
Value zero = this->createIndexConstant(rewriter, loc, 0);
|
||||
Value dwordConfig = rewriter.create<LLVM::InsertElementOp>(
|
||||
loc,
|
||||
LLVM::LLVMType::getVectorTy(
|
||||
@ -143,7 +141,7 @@ public:
|
||||
loc, toLLVMTy(i32Ty),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(32), 0));
|
||||
return replaceTransferOpWithMubuf(
|
||||
rewriter, operands, *getTypeConverter(), loc, xferOp, vecTy,
|
||||
rewriter, operands, *this->getTypeConverter(), loc, xferOp, vecTy,
|
||||
dwordConfig, int32Zero, int32Zero, int1False, int1False);
|
||||
}
|
||||
};
|
||||
@ -151,9 +149,8 @@ public:
|
||||
|
||||
void mlir::populateVectorToROCDLConversionPatterns(
|
||||
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
||||
MLIRContext *ctx = converter.getDialect()->getContext();
|
||||
patterns.insert<VectorTransferConversion<TransferReadOp>,
|
||||
VectorTransferConversion<TransferWriteOp>>(ctx, converter);
|
||||
VectorTransferConversion<TransferWriteOp>>(converter);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
Loading…
Reference in New Issue
Block a user