[NFC] Use ConvertOpToLLVMPattern instead of ConvertToLLVMPattern.

- use ConvertOpToLLVMPattern to avoid explicit casting and in most cases the
  constructor can be reused to save a few lines of code.

Differential Revision: https://reviews.llvm.org/D92989
This commit is contained in:
Rahul Joshi 2020-12-09 18:18:35 -08:00
parent a1ae3c6ac9
commit 563879b6f9
12 changed files with 238 additions and 333 deletions

View File

@ -18,8 +18,7 @@ template <typename T> class OperationPass;
/// Populate the given list with patterns that convert from Linalg to LLVM.
void populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter,
OwningRewritePatternList &patterns,
MLIRContext *ctx);
OwningRewritePatternList &patterns);
/// Create a pass to convert Linalg operations to the LLVMIR dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertLinalgToLLVMPass();

View File

@ -19,8 +19,7 @@ class OperationPass;
class OwningRewritePatternList;
/// Populate the given list with patterns that convert from OpenMP to LLVM.
void populateOpenMPToLLVMConversionPatterns(MLIRContext *context,
LLVMTypeConverter &converter,
void populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
OwningRewritePatternList &patterns);
/// Create a pass to convert OpenMP operations to the LLVMIR dialect.

View File

@ -565,8 +565,8 @@ protected:
template <typename SourceOp>
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
public:
ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
explicit ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertToLLVMPattern(SourceOp::getOperationName(),
&typeConverter.getContext(), typeConverter,
benefit) {}

View File

@ -34,8 +34,7 @@ static Type getSrcVectorElementType(OpTy op) {
/// operands as is, preserve attributes.
template <typename SourceOp, typename TargetOp>
static LogicalResult
matchAndRewriteOneToOne(const ConvertToLLVMPattern &lowering,
LLVMTypeConverter &typeConverter, Operation *op,
matchAndRewriteOneToOne(LLVMTypeConverter &typeConverter, Operation *op,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
unsigned numResults = op->getNumResults();
@ -73,71 +72,61 @@ namespace {
// TODO: Patterns are too verbose due to the fact that we have 1 op (e.g.
// MaskRndScaleOp) and different possible target ops. It would be better to take
// a Functor so that all these conversions become 1-liners.
struct MaskRndScaleOpPS512Conversion : public ConvertToLLVMPattern {
explicit MaskRndScaleOpPS512Conversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context,
typeConverter) {}
struct MaskRndScaleOpPS512Conversion
: public ConvertOpToLLVMPattern<MaskRndScaleOp> {
using ConvertOpToLLVMPattern<MaskRndScaleOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(MaskRndScaleOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (!getSrcVectorElementType(cast<MaskRndScaleOp>(op)).isF32())
if (!getSrcVectorElementType(op).isF32())
return failure();
return matchAndRewriteOneToOne<MaskRndScaleOp,
LLVM::x86_avx512_mask_rndscale_ps_512>(
*this, *getTypeConverter(), op, operands, rewriter);
*getTypeConverter(), op, operands, rewriter);
}
};
struct MaskRndScaleOpPD512Conversion : public ConvertToLLVMPattern {
explicit MaskRndScaleOpPD512Conversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context,
typeConverter) {}
struct MaskRndScaleOpPD512Conversion
: public ConvertOpToLLVMPattern<MaskRndScaleOp> {
using ConvertOpToLLVMPattern<MaskRndScaleOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(MaskRndScaleOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (!getSrcVectorElementType(cast<MaskRndScaleOp>(op)).isF64())
if (!getSrcVectorElementType(op).isF64())
return failure();
return matchAndRewriteOneToOne<MaskRndScaleOp,
LLVM::x86_avx512_mask_rndscale_pd_512>(
*this, *getTypeConverter(), op, operands, rewriter);
*getTypeConverter(), op, operands, rewriter);
}
};
struct ScaleFOpPS512Conversion : public ConvertToLLVMPattern {
explicit ScaleFOpPS512Conversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context,
typeConverter) {}
struct ScaleFOpPS512Conversion : public ConvertOpToLLVMPattern<MaskScaleFOp> {
using ConvertOpToLLVMPattern<MaskScaleFOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(MaskScaleFOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (!getSrcVectorElementType(cast<MaskScaleFOp>(op)).isF32())
if (!getSrcVectorElementType(op).isF32())
return failure();
return matchAndRewriteOneToOne<MaskScaleFOp,
LLVM::x86_avx512_mask_scalef_ps_512>(
*this, *getTypeConverter(), op, operands, rewriter);
*getTypeConverter(), op, operands, rewriter);
}
};
struct ScaleFOpPD512Conversion : public ConvertToLLVMPattern {
explicit ScaleFOpPD512Conversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context,
typeConverter) {}
struct ScaleFOpPD512Conversion : public ConvertOpToLLVMPattern<MaskScaleFOp> {
using ConvertOpToLLVMPattern<MaskScaleFOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(MaskScaleFOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (!getSrcVectorElementType(cast<MaskScaleFOp>(op)).isF64())
if (!getSrcVectorElementType(op).isF64())
return failure();
return matchAndRewriteOneToOne<MaskScaleFOp,
LLVM::x86_avx512_mask_scalef_pd_512>(
*this, *getTypeConverter(), op, operands, rewriter);
*getTypeConverter(), op, operands, rewriter);
}
};
} // namespace
@ -145,11 +134,10 @@ struct ScaleFOpPD512Conversion : public ConvertToLLVMPattern {
/// Populate the given list with patterns that convert from AVX512 to LLVM.
void mlir::populateAVX512ToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
MLIRContext *ctx = converter.getDialect()->getContext();
// clang-format off
patterns.insert<MaskRndScaleOpPS512Conversion,
MaskRndScaleOpPD512Conversion,
ScaleFOpPS512Conversion,
ScaleFOpPD512Conversion>(ctx, converter);
ScaleFOpPD512Conversion>(converter);
// clang-format on
}

View File

@ -18,17 +18,13 @@
namespace mlir {
template <unsigned AllocaAddrSpace>
struct GPUFuncOpLowering : ConvertToLLVMPattern {
explicit GPUFuncOpLowering(LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(gpu::GPUFuncOp::getOperationName(),
typeConverter.getDialect()->getContext(),
typeConverter) {}
struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
using ConvertOpToLLVMPattern<gpu::GPUFuncOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
assert(operands.empty() && "func op is not expected to have operands");
auto gpuFuncOp = cast<gpu::GPUFuncOp>(op);
Location loc = gpuFuncOp.getLoc();
SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
@ -154,14 +150,11 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
}
};
struct GPUReturnOpLowering : public ConvertToLLVMPattern {
GPUReturnOpLowering(LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(gpu::ReturnOp::getOperationName(),
typeConverter.getDialect()->getContext(),
typeConverter) {}
struct GPUReturnOpLowering : public ConvertOpToLLVMPattern<gpu::ReturnOp> {
using ConvertOpToLLVMPattern<gpu::ReturnOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(gpu::ReturnOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
return success();

View File

@ -21,7 +21,7 @@ namespace mlir {
// `indexBitwidth`, sign-extend or truncate the resulting value to match the
// bitwidth expected by the consumers of the value.
template <typename Op, typename XOp, typename YOp, typename ZOp>
struct GPUIndexIntrinsicOpLowering : public ConvertToLLVMPattern {
struct GPUIndexIntrinsicOpLowering : public ConvertOpToLLVMPattern<Op> {
private:
enum dimension { X = 0, Y = 1, Z = 2, invalid };
unsigned indexBitwidth;
@ -36,19 +36,17 @@ private:
public:
explicit GPUIndexIntrinsicOpLowering(LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(Op::getOperationName(),
typeConverter.getDialect()->getContext(),
typeConverter),
: ConvertOpToLLVMPattern<Op>(typeConverter),
indexBitwidth(typeConverter.getIndexTypeBitwidth()) {}
// Convert the kernel arguments to an LLVM type, preserve the rest.
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(Op op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
MLIRContext *context = rewriter.getContext();
Value newOp;
switch (dimensionToIndex(cast<Op>(op))) {
switch (dimensionToIndex(op)) {
case X:
newOp = rewriter.create<XOp>(loc, LLVM::LLVMType::getInt32Ty(context));
break;

View File

@ -29,16 +29,15 @@ namespace mlir {
/// will be transformed into
/// llvm.call @__nv_expf(%arg_f32) : (!llvm.float) -> !llvm.float
template <typename SourceOp>
struct OpToFuncCallLowering : public ConvertToLLVMPattern {
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
public:
explicit OpToFuncCallLowering(LLVMTypeConverter &lowering_, StringRef f32Func,
StringRef f64Func)
: ConvertToLLVMPattern(SourceOp::getOperationName(),
lowering_.getDialect()->getContext(), lowering_),
f32Func(f32Func), f64Func(f64Func) {}
: ConvertOpToLLVMPattern<SourceOp>(lowering_), f32Func(f32Func),
f64Func(f64Func) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
using LLVM::LLVMFuncOp;
using LLVM::LLVMType;

View File

@ -31,10 +31,8 @@ using namespace mlir;
namespace {
struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
explicit GPUShuffleOpLowering(LLVMTypeConverter &lowering_)
: ConvertToLLVMPattern(gpu::ShuffleOp::getOperationName(),
lowering_.getDialect()->getContext(), lowering_) {}
struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
/// Lowers a shuffle to the corresponding NVVM op.
///
@ -53,7 +51,7 @@ struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
/// %shfl_pred = llvm.extractvalue %shfl[1 : index] :
/// !llvm<"{ float, i1 }">
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(gpu::ShuffleOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
gpu::ShuffleOpAdaptor adaptor(operands);

View File

@ -126,19 +126,17 @@ private:
};
// RangeOp creates a new range descriptor.
class RangeOpConversion : public ConvertToLLVMPattern {
class RangeOpConversion : public ConvertOpToLLVMPattern<RangeOp> {
public:
explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
: ConvertToLLVMPattern(RangeOp::getOperationName(), context, lowering_) {}
using ConvertOpToLLVMPattern<RangeOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(RangeOp rangeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto rangeOp = cast<RangeOp>(op);
auto rangeDescriptorTy = convertRangeType(
rangeOp.getType().cast<RangeType>(), *getTypeConverter());
edsc::ScopedContext context(rewriter, op->getLoc());
edsc::ScopedContext context(rewriter, rangeOp->getLoc());
// Fill in an aggregate value of the descriptor.
RangeOpAdaptor adaptor(operands);
@ -146,7 +144,7 @@ public:
desc = llvm_insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0));
desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1));
desc = llvm_insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2));
rewriter.replaceOp(op, desc);
rewriter.replaceOp(rangeOp, desc);
return success();
}
};
@ -154,17 +152,13 @@ public:
// ReshapeOp creates a new view descriptor of the proper rank.
// For now, the only conversion supported is for target MemRef with static sizes
// and strides.
class ReshapeOpConversion : public ConvertToLLVMPattern {
class ReshapeOpConversion : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
explicit ReshapeOpConversion(MLIRContext *context,
LLVMTypeConverter &lowering_)
: ConvertToLLVMPattern(ReshapeOp::getOperationName(), context,
lowering_) {}
using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(ReshapeOp reshapeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto reshapeOp = cast<ReshapeOp>(op);
MemRefType dstType = reshapeOp.getResultType();
if (!dstType.hasStaticShape())
@ -178,7 +172,7 @@ public:
}))
return failure();
edsc::ScopedContext context(rewriter, op->getLoc());
edsc::ScopedContext context(rewriter, reshapeOp->getLoc());
ReshapeOpAdaptor adaptor(operands);
BaseViewConversionHelper baseDesc(adaptor.src());
BaseViewConversionHelper desc(typeConverter->convertType(dstType));
@ -189,7 +183,7 @@ public:
desc.setConstantSize(en.index(), en.value());
for (auto en : llvm::enumerate(strides))
desc.setConstantStride(en.index(), en.value());
rewriter.replaceOp(op, {desc});
rewriter.replaceOp(reshapeOp, {desc});
return success();
}
};
@ -200,19 +194,17 @@ public:
/// and stride corresponding to the region of memory within the bounds of
/// the parent view.
/// The linalg.slice op is replaced by the alloca'ed pointer.
class SliceOpConversion : public ConvertToLLVMPattern {
class SliceOpConversion : public ConvertOpToLLVMPattern<SliceOp> {
public:
explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
: ConvertToLLVMPattern(SliceOp::getOperationName(), context, lowering_) {}
using ConvertOpToLLVMPattern<SliceOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(SliceOp sliceOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
edsc::ScopedContext context(rewriter, op->getLoc());
edsc::ScopedContext context(rewriter, sliceOp->getLoc());
SliceOpAdaptor adaptor(operands);
BaseViewConversionHelper baseDesc(adaptor.view());
auto sliceOp = cast<SliceOp>(op);
auto memRefType = sliceOp.getBaseViewType();
auto int64Ty = typeConverter->convertType(rewriter.getIntegerType(64))
.cast<LLVM::LLVMType>();
@ -248,7 +240,7 @@ public:
// Corner case, no sizes or strides: early return the descriptor.
if (sliceOp.getShapedType().getRank() == 0)
return rewriter.replaceOp(op, {desc}), success();
return rewriter.replaceOp(sliceOp, {desc}), success();
Value zero = llvm_constant(
int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
@ -279,20 +271,18 @@ public:
}
}
rewriter.replaceOp(op, {desc});
rewriter.replaceOp(sliceOp, {desc});
return success();
}
};
// YieldOp produces and LLVM::ReturnOp.
class YieldOpConversion : public ConvertToLLVMPattern {
class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
public:
explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
: ConvertToLLVMPattern(linalg::YieldOp::getOperationName(), context,
lowering_) {}
using ConvertOpToLLVMPattern<linalg::YieldOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(linalg::YieldOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
return success();
@ -302,10 +292,9 @@ public:
/// Populate the given list with patterns that convert from Linalg to LLVM.
void mlir::populateLinalgToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
MLIRContext *ctx) {
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
YieldOpConversion>(ctx, converter);
YieldOpConversion>(converter);
// Populate the type conversions for the linalg types.
converter.addConversion(
@ -331,7 +320,7 @@ void ConvertLinalgToLLVMPass::runOnOperation() {
populateVectorToSCFConversionPatterns(patterns, &getContext());
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateVectorToLLVMConversionPatterns(converter, patterns);
populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
populateLinalgToLLVMConversionPatterns(converter, patterns);
LLVMConversionTarget target(getContext());
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();

View File

@ -21,34 +21,30 @@ namespace {
/// expected to either be processed by the conversion infrastructure or already
/// contain ops compatible with LLVM dialect types.
template <typename OpType>
struct RegionOpConversion : public ConvertToLLVMPattern {
explicit RegionOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(OpType::getOperationName(), context,
typeConverter) {}
struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(OpType curOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto curOp = cast<OpType>(op);
auto newOp = rewriter.create<OpType>(curOp.getLoc(), TypeRange(), operands,
curOp.getAttrs());
rewriter.inlineRegionBefore(curOp.region(), newOp.region(),
newOp.region().end());
if (failed(rewriter.convertRegionTypes(&newOp.region(), *typeConverter)))
if (failed(rewriter.convertRegionTypes(&newOp.region(),
*this->getTypeConverter())))
return failure();
rewriter.eraseOp(op);
rewriter.eraseOp(curOp);
return success();
}
};
} // namespace
void mlir::populateOpenMPToLLVMConversionPatterns(
MLIRContext *context, LLVMTypeConverter &converter,
OwningRewritePatternList &patterns) {
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
patterns.insert<RegionOpConversion<omp::ParallelOp>,
RegionOpConversion<omp::WsLoopOp>>(context, converter);
RegionOpConversion<omp::WsLoopOp>>(converter);
}
namespace {
@ -60,13 +56,12 @@ struct ConvertOpenMPToLLVMPass
void ConvertOpenMPToLLVMPass::runOnOperation() {
auto module = getOperation();
MLIRContext *context = &getContext();
// Convert to OpenMP operations with LLVM IR dialect
OwningRewritePatternList patterns;
LLVMTypeConverter converter(&getContext());
populateStdToLLVMConversionPatterns(converter, patterns);
populateOpenMPToLLVMConversionPatterns(context, converter, patterns);
populateOpenMPToLLVMConversionPatterns(converter, patterns);
LLVMConversionTarget target(getContext());
target.addDynamicallyLegalOp<omp::ParallelOp, omp::WsLoopOp>(

View File

@ -296,39 +296,33 @@ namespace {
/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
class VectorMatmulOpConversion : public ConvertToLLVMPattern {
class VectorMatmulOpConversion
: public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
explicit VectorMatmulOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
typeConverter) {}
using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto matmulOp = cast<vector::MatmulOp>(op);
auto adaptor = vector::MatmulOpAdaptor(operands);
rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
op, typeConverter->convertType(matmulOp.res().getType()), adaptor.lhs(),
adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
matmulOp.rhs_columns());
matmulOp, typeConverter->convertType(matmulOp.res().getType()),
adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
matmulOp.lhs_columns(), matmulOp.rhs_columns());
return success();
}
};
/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
class VectorFlatTransposeOpConversion
: public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
explicit VectorFlatTransposeOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
context, typeConverter) {}
using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto transOp = cast<vector::FlatTransposeOp>(op);
auto adaptor = vector::FlatTransposeOpAdaptor(operands);
rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
transOp, typeConverter->convertType(transOp.res().getType()),
@ -338,18 +332,15 @@ public:
};
/// Conversion pattern for a vector.maskedload.
class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern {
class VectorMaskedLoadOpConversion
: public ConvertOpToLLVMPattern<vector::MaskedLoadOp> {
public:
explicit VectorMaskedLoadOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context,
typeConverter) {}
using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto load = cast<vector::MaskedLoadOp>(op);
auto loc = load->getLoc();
auto adaptor = vector::MaskedLoadOpAdaptor(operands);
// Resolve alignment.
@ -371,18 +362,15 @@ public:
};
/// Conversion pattern for a vector.maskedstore.
class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern {
class VectorMaskedStoreOpConversion
: public ConvertOpToLLVMPattern<vector::MaskedStoreOp> {
public:
explicit VectorMaskedStoreOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context,
typeConverter) {}
using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto store = cast<vector::MaskedStoreOp>(op);
auto loc = store->getLoc();
auto adaptor = vector::MaskedStoreOpAdaptor(operands);
// Resolve alignment.
@ -404,18 +392,15 @@ public:
};
/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion : public ConvertToLLVMPattern {
class VectorGatherOpConversion
: public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
explicit VectorGatherOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context,
typeConverter) {}
using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto gather = cast<vector::GatherOp>(op);
auto loc = gather->getLoc();
auto adaptor = vector::GatherOpAdaptor(operands);
// Resolve alignment.
@ -440,18 +425,15 @@ public:
};
/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion : public ConvertToLLVMPattern {
class VectorScatterOpConversion
: public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
explicit VectorScatterOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context,
typeConverter) {}
using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto scatter = cast<vector::ScatterOp>(op);
auto loc = scatter->getLoc();
auto adaptor = vector::ScatterOpAdaptor(operands);
// Resolve alignment.
@ -476,18 +458,15 @@ public:
};
/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
class VectorExpandLoadOpConversion
: public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
explicit VectorExpandLoadOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
typeConverter) {}
using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto expand = cast<vector::ExpandLoadOp>(op);
auto loc = expand->getLoc();
auto adaptor = vector::ExpandLoadOpAdaptor(operands);
Value ptr;
@ -497,25 +476,22 @@ public:
auto vType = expand.getResultVectorType();
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
op, typeConverter->convertType(vType), ptr, adaptor.mask(),
expand, typeConverter->convertType(vType), ptr, adaptor.mask(),
adaptor.pass_thru());
return success();
}
};
/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
class VectorCompressStoreOpConversion
: public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
explicit VectorCompressStoreOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
context, typeConverter) {}
using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto compress = cast<vector::CompressStoreOp>(op);
auto loc = compress->getLoc();
auto adaptor = vector::CompressStoreOpAdaptor(operands);
Value ptr;
@ -524,25 +500,23 @@ public:
return failure();
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
op, adaptor.value(), ptr, adaptor.mask());
compress, adaptor.value(), ptr, adaptor.mask());
return success();
}
};
/// Conversion pattern for all vector reductions.
class VectorReductionOpConversion : public ConvertToLLVMPattern {
class VectorReductionOpConversion
: public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
explicit VectorReductionOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter,
explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
bool reassociateFPRed)
: ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
typeConverter),
: ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
reassociateFPReductions(reassociateFPRed) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto reductionOp = cast<vector::ReductionOp>(op);
auto kind = reductionOp.kind();
Type eltType = reductionOp.dest().getType();
Type llvmType = typeConverter->convertType(eltType);
@ -550,33 +524,33 @@ public:
// Integer reductions: add/mul/min/max/and/or/xor.
if (kind == "add")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
op, llvmType, operands[0]);
reductionOp, llvmType, operands[0]);
else if (kind == "mul")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
op, llvmType, operands[0]);
reductionOp, llvmType, operands[0]);
else if (kind == "min" &&
(eltType.isIndex() || eltType.isUnsignedInteger()))
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
op, llvmType, operands[0]);
reductionOp, llvmType, operands[0]);
else if (kind == "min")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
op, llvmType, operands[0]);
reductionOp, llvmType, operands[0]);
else if (kind == "max" &&
(eltType.isIndex() || eltType.isUnsignedInteger()))
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
op, llvmType, operands[0]);
reductionOp, llvmType, operands[0]);
else if (kind == "max")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
op, llvmType, operands[0]);
reductionOp, llvmType, operands[0]);
else if (kind == "and")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
op, llvmType, operands[0]);
reductionOp, llvmType, operands[0]);
else if (kind == "or")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
op, llvmType, operands[0]);
reductionOp, llvmType, operands[0]);
else if (kind == "xor")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
op, llvmType, operands[0]);
reductionOp, llvmType, operands[0]);
else
return failure();
return success();
@ -590,27 +564,27 @@ public:
// Optional accumulator (or zero).
Value acc = operands.size() > 1 ? operands[1]
: rewriter.create<LLVM::ConstantOp>(
op->getLoc(), llvmType,
reductionOp->getLoc(), llvmType,
rewriter.getZeroAttr(eltType));
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
op, llvmType, acc, operands[0],
reductionOp, llvmType, acc, operands[0],
rewriter.getBoolAttr(reassociateFPReductions));
} else if (kind == "mul") {
// Optional accumulator (or one).
Value acc = operands.size() > 1
? operands[1]
: rewriter.create<LLVM::ConstantOp>(
op->getLoc(), llvmType,
reductionOp->getLoc(), llvmType,
rewriter.getFloatAttr(eltType, 1.0));
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
op, llvmType, acc, operands[0],
reductionOp, llvmType, acc, operands[0],
rewriter.getBoolAttr(reassociateFPReductions));
} else if (kind == "min")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(op, llvmType,
operands[0]);
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
reductionOp, llvmType, operands[0]);
else if (kind == "max")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(op, llvmType,
operands[0]);
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
reductionOp, llvmType, operands[0]);
else
return failure();
return success();
@ -621,17 +595,16 @@ private:
};
/// Conversion pattern for a vector.create_mask (1-D only).
class VectorCreateMaskOpConversion : public ConvertToLLVMPattern {
class VectorCreateMaskOpConversion
: public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
public:
explicit VectorCreateMaskOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter,
explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
bool enableIndexOpt)
: ConvertToLLVMPattern(vector::CreateMaskOp::getOperationName(), context,
typeConverter),
: ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
enableIndexOptimizations(enableIndexOpt) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto dstType = op->getResult(0).getType().cast<VectorType>();
int64_t rank = dstType.getRank();
@ -648,19 +621,16 @@ private:
const bool enableIndexOptimizations;
};
class VectorShuffleOpConversion : public ConvertToLLVMPattern {
class VectorShuffleOpConversion
: public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
explicit VectorShuffleOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
typeConverter) {}
using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto loc = shuffleOp->getLoc();
auto adaptor = vector::ShuffleOpAdaptor(operands);
auto shuffleOp = cast<vector::ShuffleOp>(op);
auto v1Type = shuffleOp.getV1VectorType();
auto v2Type = shuffleOp.getV2VectorType();
auto vectorType = shuffleOp.getVectorType();
@ -680,9 +650,9 @@ public:
// For rank 1, where both operands have *exactly* the same vector type,
// there is direct shuffle support in LLVM. Use it!
if (rank == 1 && v1Type == v2Type) {
Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
rewriter.replaceOp(op, shuffle);
rewriter.replaceOp(shuffleOp, llvmShuffleOp);
return success();
}
@ -701,23 +671,22 @@ public:
insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
llvmType, rank, insPos++);
}
rewriter.replaceOp(op, insert);
rewriter.replaceOp(shuffleOp, insert);
return success();
}
};
class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
class VectorExtractElementOpConversion
: public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
explicit VectorExtractElementOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
context, typeConverter) {}
using ConvertOpToLLVMPattern<
vector::ExtractElementOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(vector::ExtractElementOp extractEltOp,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::ExtractElementOpAdaptor(operands);
auto extractEltOp = cast<vector::ExtractElementOp>(op);
auto vectorType = extractEltOp.getVectorType();
auto llvmType = typeConverter->convertType(vectorType.getElementType());
@ -726,24 +695,21 @@ public:
return failure();
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
op, llvmType, adaptor.vector(), adaptor.position());
extractEltOp, llvmType, adaptor.vector(), adaptor.position());
return success();
}
};
class VectorExtractOpConversion : public ConvertToLLVMPattern {
class VectorExtractOpConversion
: public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
explicit VectorExtractOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
typeConverter) {}
using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto loc = extractOp->getLoc();
auto adaptor = vector::ExtractOpAdaptor(operands);
auto extractOp = cast<vector::ExtractOp>(op);
auto vectorType = extractOp.getVectorType();
auto resultType = extractOp.getResult().getType();
auto llvmResultType = typeConverter->convertType(resultType);
@ -757,12 +723,12 @@ public:
if (resultType.isa<VectorType>()) {
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, llvmResultType, adaptor.vector(), positionArrayAttr);
rewriter.replaceOp(op, extracted);
rewriter.replaceOp(extractOp, extracted);
return success();
}
// Potential extraction of 1-D vector from array.
auto *context = op->getContext();
auto *context = extractOp->getContext();
Value extracted = adaptor.vector();
auto positionAttrs = positionArrayAttr.getValue();
if (positionAttrs.size() > 1) {
@ -780,7 +746,7 @@ public:
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
extracted =
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
rewriter.replaceOp(op, extracted);
rewriter.replaceOp(extractOp, extracted);
return success();
}
@ -800,39 +766,32 @@ public:
/// (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
/// -> !llvm<"<8 x float>">
/// ```
class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
explicit VectorFMAOp1DConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
typeConverter) {}
using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::FMAOpAdaptor(operands);
vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
VectorType vType = fmaOp.getVectorType();
if (vType.getRank() != 1)
return failure();
rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(),
rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
adaptor.rhs(), adaptor.acc());
return success();
}
};
class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
class VectorInsertElementOpConversion
: public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
explicit VectorInsertElementOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
context, typeConverter) {}
using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::InsertElementOpAdaptor(operands);
auto insertEltOp = cast<vector::InsertElementOp>(op);
auto vectorType = insertEltOp.getDestVectorType();
auto llvmType = typeConverter->convertType(vectorType);
@ -841,24 +800,22 @@ public:
return failure();
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
adaptor.position());
return success();
}
};
class VectorInsertOpConversion : public ConvertToLLVMPattern {
class VectorInsertOpConversion
: public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
explicit VectorInsertOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
typeConverter) {}
using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto loc = insertOp->getLoc();
auto adaptor = vector::InsertOpAdaptor(operands);
auto insertOp = cast<vector::InsertOp>(op);
auto sourceType = insertOp.getSourceType();
auto destVectorType = insertOp.getDestVectorType();
auto llvmResultType = typeConverter->convertType(destVectorType);
@ -873,12 +830,12 @@ public:
Value inserted = rewriter.create<LLVM::InsertValueOp>(
loc, llvmResultType, adaptor.dest(), adaptor.source(),
positionArrayAttr);
rewriter.replaceOp(op, inserted);
rewriter.replaceOp(insertOp, inserted);
return success();
}
// Potential extraction of 1-D vector from array.
auto *context = op->getContext();
auto *context = insertOp->getContext();
Value extracted = adaptor.dest();
auto positionAttrs = positionArrayAttr.getValue();
auto position = positionAttrs.back().cast<IntegerAttr>();
@ -908,7 +865,7 @@ public:
nMinusOnePositionAttrs);
}
rewriter.replaceOp(op, inserted);
rewriter.replaceOp(insertOp, inserted);
return success();
}
};
@ -1117,18 +1074,15 @@ computeContiguousStrides(MemRefType memRefType) {
return strides;
}
class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
class VectorTypeCastOpConversion
: public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
explicit VectorTypeCastOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
typeConverter) {}
using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
auto loc = castOp->getLoc();
MemRefType sourceMemRefType =
castOp.getOperand().getType().cast<MemRefType>();
MemRefType targetMemRefType =
@ -1195,7 +1149,7 @@ public:
desc.setStride(rewriter, loc, index, stride);
}
rewriter.replaceOp(op, {desc});
rewriter.replaceOp(castOp, {desc});
return success();
}
};
@ -1208,18 +1162,16 @@ public:
/// 4. Create a mask where offsetVector is compared against memref upper bound.
/// 5. Rewrite op as a masked read or write.
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertToLLVMPattern {
class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
public:
explicit VectorTransferConversion(MLIRContext *context,
LLVMTypeConverter &typeConv,
explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
bool enableIndexOpt)
: ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, typeConv),
: ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
enableIndexOptimizations(enableIndexOpt) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto xferOp = cast<ConcreteOp>(op);
auto adaptor = getTransferOpAdapter(xferOp, operands);
if (xferOp.getVectorType().getRank() > 1 ||
@ -1228,16 +1180,18 @@ public:
if (xferOp.permutation_map() !=
AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
xferOp.getVectorType().getRank(),
op->getContext()))
xferOp->getContext()))
return failure();
// Only contiguous source tensors supported atm.
auto strides = computeContiguousStrides(xferOp.getMemRefType());
if (!strides)
return failure();
auto toLLVMTy = [&](Type t) { return typeConverter->convertType(t); };
auto toLLVMTy = [&](Type t) {
return this->getTypeConverter()->convertType(t);
};
Location loc = op->getLoc();
Location loc = xferOp->getLoc();
MemRefType memRefType = xferOp.getMemRefType();
if (auto memrefVectorElementType =
@ -1267,8 +1221,8 @@ public:
// addrspacecast shall be used when source/dst memrefs are not on
// address space 0.
// TODO: support alignment when possible.
Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
adaptor.indices(), rewriter);
Value dataPtr = this->getStridedElementPtr(
loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter);
auto vecTy =
toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
Value vectorDataPtr;
@ -1280,8 +1234,9 @@ public:
loc, vecTy.getPointerTo(), dataPtr);
if (!xferOp.isMaskedDim(0))
return replaceTransferOpWithLoadOrStore(
rewriter, *getTypeConverter(), loc, xferOp, operands, vectorDataPtr);
return replaceTransferOpWithLoadOrStore(rewriter,
*this->getTypeConverter(), loc,
xferOp, operands, vectorDataPtr);
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
@ -1294,11 +1249,11 @@ public:
unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
Value off = xferOp.indices()[lastIndex];
Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
Value mask = buildVectorComparison(rewriter, op, enableIndexOptimizations,
vecWidth, dim, &off);
Value mask = buildVectorComparison(
rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
// 5. Rewrite as a masked read / write.
return replaceTransferOpWithMasked(rewriter, *getTypeConverter(), loc,
return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
xferOp, operands, vectorDataPtr, mask);
}
@ -1306,12 +1261,9 @@ private:
const bool enableIndexOptimizations;
};
class VectorPrintOpConversion : public ConvertToLLVMPattern {
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
explicit VectorPrintOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
typeConverter) {}
using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
// Proof-of-concept lowering implementation that relies on a small
// runtime support library, which only needs to provide a few
@ -1326,9 +1278,8 @@ public:
// TODO: rely solely on libc in future? something else?
//
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto printOp = cast<vector::PrintOp>(op);
auto adaptor = vector::PrintOpAdaptor(operands);
Type printType = printOp.getPrintType();
@ -1341,11 +1292,11 @@ public:
Type eltType = vectorType ? vectorType.getElementType() : printType;
Operation *printer;
if (eltType.isF32()) {
printer = getPrintFloat(op);
printer = getPrintFloat(printOp);
} else if (eltType.isF64()) {
printer = getPrintDouble(op);
printer = getPrintDouble(printOp);
} else if (eltType.isIndex()) {
printer = getPrintU64(op);
printer = getPrintU64(printOp);
} else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
// Integers need a zero or sign extension on the operand
// (depending on the source type) as well as a signed or
@ -1355,7 +1306,7 @@ public:
if (width <= 64) {
if (width < 64)
conversion = PrintConversion::ZeroExt64;
printer = getPrintU64(op);
printer = getPrintU64(printOp);
} else {
return failure();
}
@ -1368,7 +1319,7 @@ public:
conversion = PrintConversion::ZeroExt64;
else if (width < 64)
conversion = PrintConversion::SignExt64;
printer = getPrintI64(op);
printer = getPrintI64(printOp);
} else {
return failure();
}
@ -1379,10 +1330,10 @@ public:
// Unroll vector into elementary print calls.
int64_t rank = vectorType ? vectorType.getRank() : 0;
emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank,
emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
conversion);
emitCall(rewriter, op->getLoc(), getPrintNewline(op));
rewriter.eraseOp(op);
emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp));
rewriter.eraseOp(printOp);
return success();
}
@ -1560,11 +1511,11 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorInsertStridedSliceOpSameRankRewritePattern,
VectorExtractStridedSliceOpConversion>(ctx);
patterns.insert<VectorReductionOpConversion>(
ctx, converter, reassociateFPReductions);
converter, reassociateFPReductions);
patterns.insert<VectorCreateMaskOpConversion,
VectorTransferConversion<TransferReadOp>,
VectorTransferConversion<TransferWriteOp>>(
ctx, converter, enableIndexOptimizations);
converter, enableIndexOptimizations);
patterns
.insert<VectorShuffleOpConversion,
VectorExtractElementOpConversion,
@ -1579,13 +1530,12 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorGatherOpConversion,
VectorScatterOpConversion,
VectorExpandLoadOpConversion,
VectorCompressStoreOpConversion>(ctx, converter);
VectorCompressStoreOpConversion>(converter);
// clang-format on
}
void mlir::populateVectorToLLVMMatrixConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
MLIRContext *ctx = converter.getDialect()->getContext();
patterns.insert<VectorMatmulOpConversion>(ctx, converter);
patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
patterns.insert<VectorMatmulOpConversion>(converter);
patterns.insert<VectorFlatTransposeOpConversion>(converter);
}

View File

@ -55,17 +55,13 @@ namespace {
/// types. For unsupported cases, they will fall back to the vector to
/// llvm conversion pattern.
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertToLLVMPattern {
class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
public:
explicit VectorTransferConversion(MLIRContext *context,
LLVMTypeConverter &typeConv)
: ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
typeConv) {}
using ConvertOpToLLVMPattern<ConcreteOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto xferOp = cast<ConcreteOp>(op);
typename ConcreteOp::Adaptor adaptor(operands);
if (xferOp.getVectorType().getRank() > 1 ||
@ -79,11 +75,13 @@ public:
if (!xferOp.isMaskedDim(0))
return failure();
auto toLLVMTy = [&](Type t) { return typeConverter->convertType(t); };
auto toLLVMTy = [&](Type t) {
return this->getTypeConverter()->convertType(t);
};
LLVM::LLVMType vecTy =
toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
unsigned vecWidth = vecTy.getVectorNumElements();
Location loc = op->getLoc();
Location loc = xferOp->getLoc();
// The backend result vector scalarization have trouble scalarize
// <1 x ty> result, exclude the x1 width from the lowering.
@ -102,8 +100,8 @@ public:
// Note that the dataPtr starts at the offset address specified by
// indices, so no need to calculate offset size in bytes again in
// the MUBUF instruction.
Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
adaptor.indices(), rewriter);
Value dataPtr = this->getStridedElementPtr(
loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter);
// 1. Create and fill a <4 x i32> dwordConfig with:
// 1st two elements holding the address of dataPtr.
@ -126,7 +124,7 @@ public:
constConfig);
Value dataPtrAsI64 = rewriter.create<LLVM::PtrToIntOp>(
loc, toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), dataPtr);
Value zero = createIndexConstant(rewriter, loc, 0);
Value zero = this->createIndexConstant(rewriter, loc, 0);
Value dwordConfig = rewriter.create<LLVM::InsertElementOp>(
loc,
LLVM::LLVMType::getVectorTy(
@ -143,7 +141,7 @@ public:
loc, toLLVMTy(i32Ty),
rewriter.getIntegerAttr(rewriter.getIntegerType(32), 0));
return replaceTransferOpWithMubuf(
rewriter, operands, *getTypeConverter(), loc, xferOp, vecTy,
rewriter, operands, *this->getTypeConverter(), loc, xferOp, vecTy,
dwordConfig, int32Zero, int32Zero, int1False, int1False);
}
};
@ -151,9 +149,8 @@ public:
void mlir::populateVectorToROCDLConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
MLIRContext *ctx = converter.getDialect()->getContext();
patterns.insert<VectorTransferConversion<TransferReadOp>,
VectorTransferConversion<TransferWriteOp>>(ctx, converter);
VectorTransferConversion<TransferWriteOp>>(converter);
}
namespace {