llvm-capstone/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Aart Bik 1c81adf362 [VectorOps] Add lowering of vector.shuffle to LLVM IR
For example, a shuffle

%1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2xf32>, vector<2xf32>

becomes a direct LLVM shuffle

%0 = llvm.shufflevector %arg0, %arg1 [0 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">

but

%1 = vector.shuffle %a, %b [1 : i32, 0 : i32, 2 : i32] : vector<1x4xf32>, vector<2x4xf32>

becomes the more elaborate (note the index permutation that drives
argument selection for the extract operations)

%0 = llvm.mlir.undef : !llvm<"[3 x <4 x float>]">
%1 = llvm.extractvalue %arg1[0] : !llvm<"[2 x <4 x float>]">
%2 = llvm.insertvalue %1, %0[0] : !llvm<"[3 x <4 x float>]">
%3 = llvm.extractvalue %arg0[0] : !llvm<"[1 x <4 x float>]">
%4 = llvm.insertvalue %3, %2[1] : !llvm<"[3 x <4 x float>]">
%5 = llvm.extractvalue %arg1[1] : !llvm<"[2 x <4 x float>]">
%6 = llvm.insertvalue %5, %4[2] : !llvm<"[3 x <4 x float>]">

PiperOrigin-RevId: 285268164
2019-12-12 14:11:56 -08:00

604 lines
25 KiB
C++

//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/VectorOps/VectorOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ErrorHandling.h"
using namespace mlir;
// Helper that converts the element type of `containerType` through the given
// type converter and returns the LLVM pointer type to that converted type.
template <typename T>
static LLVM::LLVMType getPtrToElementType(T containerType,
                                          LLVMTypeConverter &lowering) {
  auto convertedElementType =
      lowering.convertType(containerType.getElementType());
  return convertedElementType.template cast<LLVM::LLVMType>().getPointerTo();
}
// Helper that drops the leading dimension of a vector type while keeping the
// element type. Only valid for vector types of rank greater than one.
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  auto trailingShape = tp.getShape().drop_front();
  auto elementType = tp.getElementType();
  return VectorType::get(trailingShape, elementType);
}
// Helper that keeps only the trailing dimension of a vector type (i.e. builds
// the 1-D vector type of the innermost dimension) with the same element type.
// Only valid for vector types of rank greater than one.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  auto innermostShape = tp.getShape().take_back();
  auto elementType = tp.getElementType();
  return VectorType::get(innermostShape, elementType);
}
// Helper that emits the insertion of `val2` into `val1` at position `pos`,
// producing a value of type `llvmType`:
//   - for rank != 1, a single llvm.insertvalue with `pos` as its position;
//   - for rank == 1, an index-typed llvm constant holding `pos` followed by
//     an llvm.insertelement.
static Value *insertOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &lowering, Location loc, Value *val1,
                        Value *val2, Type llvmType, int64_t rank, int64_t pos) {
  if (rank != 1) {
    return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
                                                rewriter.getI64ArrayAttr(pos));
  }
  auto indexType = rewriter.getIndexType();
  auto llvmIndexType = lowering.convertType(indexType);
  auto positionAttr = rewriter.getIntegerAttr(indexType, pos);
  auto position =
      rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType, positionAttr);
  return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                position);
}
// Helper that emits the extraction of position `pos` from `val`, producing a
// value of type `llvmType`:
//   - for rank != 1, a single llvm.extractvalue with `pos` as its position;
//   - for rank == 1, an index-typed llvm constant holding `pos` followed by
//     an llvm.extractelement.
static Value *extractOne(ConversionPatternRewriter &rewriter,
                         LLVMTypeConverter &lowering, Location loc, Value *val,
                         Type llvmType, int64_t rank, int64_t pos) {
  if (rank != 1) {
    return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
                                                 rewriter.getI64ArrayAttr(pos));
  }
  auto indexType = rewriter.getIndexType();
  auto llvmIndexType = lowering.convertType(indexType);
  auto positionAttr = rewriter.getIntegerAttr(indexType, pos);
  auto position =
      rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType, positionAttr);
  return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, position);
}
// Conversion pattern that lowers vector.broadcast into LLVM dialect
// undef/insert/extract/shufflevector operations by recursively expanding the
// source value rank by rank.
class VectorBroadcastOpConversion : public LLVMOpLowering {
public:
  explicit VectorBroadcastOpConversion(MLIRContext *context,
                                       LLVMTypeConverter &typeConverter)
      : LLVMOpLowering(vector::BroadcastOp::getOperationName(), context,
                       typeConverter) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto broadcastOp = cast<vector::BroadcastOp>(op);
    VectorType resultVectorType = broadcastOp.getVectorType();
    // Only rewrite when the complete result type is lowerable; this implies
    // that every 'reduced' type visited during expansion is lowerable too.
    if (!lowering.convertType(resultVectorType))
      return matchFailure();

    auto adaptor = vector::BroadcastOpOperandAdaptor(operands);
    // A null source vector type means the source is a scalar.
    VectorType sourceVectorType =
        broadcastOp.getSourceType().dyn_cast<VectorType>();
    Value *expanded =
        expandRanks(adaptor.source(), // source value to be expanded
                    op->getLoc(),     // location of original broadcast
                    sourceVectorType, resultVectorType, rewriter);
    rewriter.replaceOp(op, expanded);
    return matchSuccess();
  }

private:
  // Expands `value` over all the ranks described by the source and the
  // destination vector types. A null source type denotes expansion from a
  // scalar value into a vector.
  //
  // TODO(ajcbik): consider replacing this one-pattern lowering
  //               with a two-pattern lowering using other vector
  //               ops once all insert/extract/shuffle operations
  //               are available with lowering implementation.
  //
  Value *expandRanks(Value *value, Location loc, VectorType srcVectorType,
                     VectorType dstVectorType,
                     ConversionPatternRewriter &rewriter) const {
    assert((dstVectorType != nullptr) && "invalid result type in broadcast");
    // A scalar source counts as rank zero.
    int64_t sourceRank = srcVectorType ? srcVectorType.getRank() : 0;
    int64_t resultRank = dstVectorType.getRank();
    int64_t leadingDim = dstVectorType.getDimSize(0);

    // Fewer source ranks than result ranks: duplicate along the leading rank.
    if (sourceRank < resultRank)
      return duplicateOneRank(value, loc, srcVectorType, dstVectorType,
                              resultRank, leadingDim, rewriter);

    // Ranks are equal from here on. When every dimension matches, the
    // broadcast simply forwards the source value. The first mismatching
    // dimension forces a stretch along this rank instead.
    assert((srcVectorType != nullptr) && (srcRank > 0) &&
           (srcRank == dstRank) && "invalid rank in broadcast");
    for (int64_t r = 0; r < resultRank; ++r) {
      if (srcVectorType.getDimSize(r) == dstVectorType.getDimSize(r))
        continue;
      return stretchOneRank(value, loc, srcVectorType, dstVectorType,
                            resultRank, leadingDim, rewriter);
    }
    return value;
  }

  // Duplicates `value` along a single (leading) rank of the destination type.
  // In the 1-D case one insert-element followed by an all-zero shuffle does
  // the whole expansion. For higher ranks the value is first expanded to the
  // destination type with its leading dimension dropped (a "recursive"
  // broadcast), and that result is inserted `dim` times.
  // For example:
  //   v = broadcast s : f32 to vector<4x2xf32>
  // becomes:
  //   x = broadcast s : f32 to vector<2xf32>
  //   v = [x,x,x,x]
  // becomes:
  //   x = [s,s]
  //   v = [x,x,x,x]
  Value *duplicateOneRank(Value *value, Location loc, VectorType srcVectorType,
                          VectorType dstVectorType, int64_t rank, int64_t dim,
                          ConversionPatternRewriter &rewriter) const {
    Type llvmType = lowering.convertType(dstVectorType);
    assert((llvmType != nullptr) && "unlowerable vector type");

    if (rank != 1) {
      // Expand the lower ranks first, then replicate the outcome `dim` times.
      VectorType innerDstType = reducedVectorTypeFront(dstVectorType);
      Value *inner =
          expandRanks(value, loc, srcVectorType, innerDstType, rewriter);
      Value *aggregate = rewriter.create<LLVM::UndefOp>(loc, llvmType);
      for (int64_t d = 0; d < dim; ++d)
        aggregate = insertOne(rewriter, lowering, loc, aggregate, inner,
                              llvmType, rank, d);
      return aggregate;
    }

    // 1-D: place the value in lane 0 and splat lane 0 with a shuffle.
    Value *undef = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    Value *seeded =
        insertOne(rewriter, lowering, loc, undef, value, llvmType, rank, 0);
    SmallVector<int32_t, 4> zeroMask(dim, 0);
    return rewriter.create<LLVM::ShuffleVectorOp>(
        loc, seeded, undef, rewriter.getI32ArrayAttr(zeroMask));
  }

  // Stretches `value` along a single rank. In the 1-D case (which must be at
  // a stretch) lane 0 is extracted, re-inserted and splatted with an all-zero
  // shuffle. Otherwise each leading position of the destination is produced
  // by extracting the matching source slice (always slice 0 when the leading
  // dimension itself is stretched), expanding it, and inserting the result.
  // For example:
  //   v = broadcast w : vector<4x1x2xf32> to vector<4x2x2xf32>
  // becomes:
  //   a = broadcast w[0] : vector<1x2xf32> to vector<2x2xf32>
  //   b = broadcast w[1] : vector<1x2xf32> to vector<2x2xf32>
  //   c = broadcast w[2] : vector<1x2xf32> to vector<2x2xf32>
  //   d = broadcast w[3] : vector<1x2xf32> to vector<2x2xf32>
  //   v = [a,b,c,d]
  // becomes:
  //   x = broadcast w[0][0] : vector<2xf32> to vector <2x2xf32>
  //   y = broadcast w[1][0] : vector<2xf32> to vector <2x2xf32>
  //   a = [x, y]
  //   etc.
  Value *stretchOneRank(Value *value, Location loc, VectorType srcVectorType,
                        VectorType dstVectorType, int64_t rank, int64_t dim,
                        ConversionPatternRewriter &rewriter) const {
    Type llvmType = lowering.convertType(dstVectorType);
    assert((llvmType != nullptr) && "unlowerable vector type");
    Value *aggregate = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    bool atStretch = dim != srcVectorType.getDimSize(0);

    if (rank != 1) {
      VectorType innerSrcType = reducedVectorTypeFront(srcVectorType);
      VectorType innerDstType = reducedVectorTypeFront(dstVectorType);
      Type llvmInnerSrcType = lowering.convertType(innerSrcType);
      for (int64_t d = 0; d < dim; ++d) {
        // When stretching this very rank, every slice comes from position 0.
        int64_t srcPos = atStretch ? 0 : d;
        Value *slice = extractOne(rewriter, lowering, loc, value,
                                  llvmInnerSrcType, rank, srcPos);
        Value *inner =
            expandRanks(slice, loc, innerSrcType, innerDstType, rewriter);
        aggregate = insertOne(rewriter, lowering, loc, aggregate, inner,
                              llvmType, rank, d);
      }
      return aggregate;
    }

    // 1-D: extract lane 0 of the source and splat it over the destination.
    assert(atStretch);
    Type llvmElementType =
        lowering.convertType(dstVectorType.getElementType());
    Value *lane0 =
        extractOne(rewriter, lowering, loc, value, llvmElementType, rank, 0);
    Value *seeded =
        insertOne(rewriter, lowering, loc, aggregate, lane0, llvmType, rank, 0);
    SmallVector<int32_t, 4> zeroMask(dim, 0);
    return rewriter.create<LLVM::ShuffleVectorOp>(
        loc, seeded, aggregate, rewriter.getI32ArrayAttr(zeroMask));
  }
};
// Conversion pattern that lowers vector.shuffle to the LLVM dialect.
//
// A rank-1 shuffle whose two operands have exactly the same vector type maps
// directly onto llvm.shufflevector. Every other case is lowered generically:
// for each mask entry, the selected piece is extracted from the first or the
// second operand and inserted at the next position of the result.
class VectorShuffleOpConversion : public LLVMOpLowering {
public:
  explicit VectorShuffleOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : LLVMOpLowering(vector::ShuffleOp::getOperationName(), context,
                       typeConverter) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ShuffleOpOperandAdaptor(operands);
    auto shuffleOp = cast<vector::ShuffleOp>(op);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = lowering.convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return matchFailure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value *shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(op, shuffle);
      return matchSuccess();
    }

    // Type of a single extracted piece: the scalar element type for rank 1,
    // otherwise the vector type with the leading dimension dropped. This (and
    // not the type of the whole result) is the result type of each extract.
    Type llvmPieceType =
        rank == 1 ? lowering.convertType(vectorType.getElementType())
                  : lowering.convertType(reducedVectorTypeFront(vectorType));
    // Bail before any IR has been created if the piece type is not lowerable.
    if (!llvmPieceType)
      return matchFailure();

    // For all other cases, insert the individual values individually.
    Value *insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto maskAttr : maskArrayAttr) {
      int64_t extPos = maskAttr.cast<IntegerAttr>().getInt();
      // Mask values below v1's leading dimension select from v1; the
      // remaining values select from v2 after rebasing the position.
      Value *value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value *extract = extractOne(rewriter, lowering, loc, value,
                                  llvmPieceType, rank, extPos);
      insert = insertOne(rewriter, lowering, loc, insert, extract, llvmType,
                         rank, insPos++);
    }
    rewriter.replaceOp(op, insert);
    return matchSuccess();
  }
};
// Conversion pattern that lowers vector.extract to the LLVM dialect.
//
// When the result is itself a vector, a single llvm.extractvalue with the full
// position suffices. When the result is a scalar, all but the last position
// (if any) select a 1-D vector through llvm.extractvalue, and the last
// position selects the element through llvm.extractelement.
class VectorExtractOpConversion : public LLVMOpLowering {
public:
  explicit VectorExtractOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : LLVMOpLowering(vector::ExtractOp::getOperationName(), context,
                       typeConverter) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto extractOp = cast<vector::ExtractOp>(op);
    auto adaptor = vector::ExtractOpOperandAdaptor(operands);
    auto loc = op->getLoc();
    auto sourceVectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult()->getType();
    auto fullPositionAttr = extractOp.position();

    // Nothing can be done when the result type is not lowerable.
    auto llvmResultType = lowering.convertType(resultType);
    if (!llvmResultType)
      return matchFailure();

    // Vector result: one extractvalue with the complete position.
    if (resultType.isa<VectorType>()) {
      Value *subVector = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), fullPositionAttr);
      rewriter.replaceOp(op, subVector);
      return matchSuccess();
    }

    // Scalar result: first peel off the enclosing array levels, if any.
    auto *context = op->getContext();
    auto positions = fullPositionAttr.getValue();
    Value *current = adaptor.vector();
    if (positions.size() > 1) {
      auto innermostVectorType = reducedVectorTypeBack(sourceVectorType);
      auto leadingPositions = ArrayAttr::get(positions.drop_back(), context);
      current = rewriter.create<LLVM::ExtractValueOp>(
          loc, lowering.convertType(innermostVectorType), current,
          leadingPositions);
    }

    // Then extract the element from the remaining 1-D LLVM vector, using an
    // i32 constant built from the last position attribute as the index.
    auto lastPosition = positions.back().cast<IntegerAttr>();
    auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect());
    auto index = rewriter.create<LLVM::ConstantOp>(loc, i32Type, lastPosition);
    current = rewriter.create<LLVM::ExtractElementOp>(loc, current, index);
    rewriter.replaceOp(op, current);
    return matchSuccess();
  }
};
// Conversion pattern that lowers vector.insert to the LLVM dialect.
//
// When the source is itself a vector, a single llvm.insertvalue with the full
// position suffices. When the source is a scalar, the enclosing 1-D vector is
// first extracted (if the position has more than one entry), the scalar is
// inserted into it through llvm.insertelement, and the updated 1-D vector is
// inserted back into the destination through llvm.insertvalue.
class VectorInsertOpConversion : public LLVMOpLowering {
public:
  explicit VectorInsertOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : LLVMOpLowering(vector::InsertOp::getOperationName(), context,
                       typeConverter) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto insertOp = cast<vector::InsertOp>(op);
    auto adaptor = vector::InsertOpOperandAdaptor(operands);
    auto loc = op->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto fullPositionAttr = insertOp.position();

    // Nothing can be done when the result type is not lowerable.
    auto llvmResultType = lowering.convertType(destVectorType);
    if (!llvmResultType)
      return matchFailure();

    // Vector source: one insertvalue with the complete position.
    if (sourceType.isa<VectorType>()) {
      Value *updated = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          fullPositionAttr);
      rewriter.replaceOp(op, updated);
      return matchSuccess();
    }

    // Scalar source from here on.
    auto *context = op->getContext();
    auto positions = fullPositionAttr.getValue();
    auto lastPosition = positions.back().cast<IntegerAttr>();
    const bool isNested = positions.size() > 1;
    // Builds the position attribute made of all entries except the last one.
    auto leadingPositions = [&] {
      return ArrayAttr::get(positions.drop_back(), context);
    };

    // Extract the 1-D vector that will receive the element, when the
    // destination is an array of vectors.
    Value *oneDVector = adaptor.dest();
    auto oneDVectorType = destVectorType;
    if (isNested) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto leadingAttr = leadingPositions();
      oneDVector = rewriter.create<LLVM::ExtractValueOp>(
          loc, lowering.convertType(oneDVectorType), oneDVector, leadingAttr);
    }

    // Insert the element into the 1-D LLVM vector, using an i32 constant
    // built from the last position attribute as the index.
    auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect());
    auto index = rewriter.create<LLVM::ConstantOp>(loc, i32Type, lastPosition);
    Value *updated = rewriter.create<LLVM::InsertElementOp>(
        loc, lowering.convertType(oneDVectorType), oneDVector,
        adaptor.source(), index);

    // Put the updated 1-D vector back into the destination array, if needed.
    if (isNested) {
      auto leadingAttr = leadingPositions();
      updated = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), updated, leadingAttr);
    }
    rewriter.replaceOp(op, updated);
    return matchSuccess();
  }
};
// Conversion pattern that lowers vector.outerproduct to the LLVM dialect.
//
// For every element position `d` of the (already lowered) lhs vector, the
// element lhs[d] is broadcast with a shufflevector to the width of rhs,
// multiplied with rhs (fused with the matching accumulator row when an
// accumulator operand is present), and inserted as row `d` of the result.
class VectorOuterProductOpConversion : public LLVMOpLowering {
public:
  explicit VectorOuterProductOpConversion(MLIRContext *context,
                                          LLVMTypeConverter &typeConverter)
      : LLVMOpLowering(vector::OuterProductOp::getOperationName(), context,
                       typeConverter) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
    auto *ctx = op->getContext();
    auto vLHS = adaptor.lhs()->getType().cast<LLVM::LLVMType>();
    auto vRHS = adaptor.rhs()->getType().cast<LLVM::LLVMType>();
    // Number of elements of the lowered lhs and rhs vectors.
    auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements();
    auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
    auto llvmArrayOfVectType = lowering.convertType(
        cast<vector::OuterProductOp>(op).getResult()->getType());
    Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
    Value *a = adaptor.lhs(), *b = adaptor.rhs();
    // The accumulator operand is optional.
    Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
    for (unsigned d = 0, e = rankLHS; d < e; ++d) {
      // shufflevector explicitly requires i32.
      auto attr = rewriter.getI32IntegerAttr(d);
      SmallVector<Attribute, 4> bcastAttr(rankRHS, attr);
      auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx);
      Value *aD = nullptr, *accD = nullptr;
      // 1. Broadcast the element a[d] into vector aD.
      aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr);
      // 2. If acc is present, extract 1-d vector acc[d] into accD.
      if (acc)
        accD = rewriter.create<LLVM::ExtractValueOp>(
            loc, vRHS, acc, rewriter.getI64ArrayAttr(d));
      // 3. Compute aD outer b (plus accD, if relevant).
      Value *aOuterbD =
          accD ? rewriter.create<LLVM::FMulAddOp>(loc, vRHS, aD, b, accD)
                     .getResult()
               : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult();
      // 4. Insert as value `d` in the descriptor.
      desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType,
                                                  desc, aOuterbD,
                                                  rewriter.getI64ArrayAttr(d));
    }
    rewriter.replaceOp(op, desc);
    return matchSuccess();
  }
};
// Conversion pattern that lowers vector.type_cast to the LLVM dialect by
// building a new memref descriptor for the target type: the allocated and
// aligned pointers of the source descriptor are bitcast to the target element
// pointer type, the offset is set to zero, and sizes/strides are filled in as
// constants. Only statically shaped, contiguous source memrefs are supported.
class VectorTypeCastOpConversion : public LLVMOpLowering {
public:
  explicit VectorTypeCastOpConversion(MLIRContext *context,
                                      LLVMTypeConverter &typeConverter)
      : LLVMOpLowering(vector::TypeCastOp::getOperationName(), context,
                       typeConverter) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
    MemRefType sourceMemRefType =
        castOp.getOperand()->getType().cast<MemRefType>();
    MemRefType targetMemRefType =
        castOp.getResult()->getType().cast<MemRefType>();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return matchFailure();

    auto llvmSourceDescriptorTy =
        operands[0]->getType().dyn_cast<LLVM::LLVMType>();
    if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
      return matchFailure();
    MemRefDescriptor sourceMemRef(operands[0]);

    auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMType>();
    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
      return matchFailure();

    // The strides are only meaningful when their computation succeeded, so
    // check for failure before looking at them.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (failed(getStridesAndOffset(sourceMemRefType, strides, offset)))
      return matchFailure();

    // Only contiguous source tensors supported atm: the innermost stride must
    // be 1 and every other stride must equal the next inner stride times the
    // next inner size. An empty stride list is treated as contiguous.
    if (!strides.empty() && strides.back() != 1)
      return matchFailure();
    auto sizes = sourceMemRefType.getShape();
    const int64_t numStrides = strides.size();
    for (int64_t index = 0; index + 1 < numStrides; ++index) {
      if (strides[index] != strides[index + 1] * sizes[index + 1])
        return matchFailure();
    }

    // The descriptor strides below are read from the source strides by target
    // dimension index; bail out rather than read out of bounds.
    if (targetMemRefType.getRank() > numStrides)
      return matchFailure();

    auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementType();
    // Set allocated ptr.
    Value *allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value *ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(op, {desc});
    return matchSuccess();
  }
};
/// Adds to `patterns` all the conversion patterns defined in this file that
/// lower Vector dialect operations (broadcast, shuffle, extract, insert,
/// outerproduct and type_cast) to the LLVM dialect, using `converter` for
/// type conversion.
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  auto *context = converter.getDialect()->getContext();
  patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
                  VectorExtractOpConversion, VectorInsertOpConversion,
                  VectorOuterProductOpConversion, VectorTypeCastOpConversion>(
      context, converter);
}
namespace {
/// Module pass that converts Vector dialect operations to the LLVM dialect.
/// The conversion itself is implemented in `runOnModule`.
struct LowerVectorToLLVMPass : ModulePass<LowerVectorToLLVMPass> {
  void runOnModule() override;
};
} // namespace
/// Runs a partial conversion of the module to the LLVM dialect, using the
/// Vector-to-LLVM patterns followed by the Standard-to-LLVM patterns. The pass
/// is marked as failed when the conversion fails.
void LowerVectorToLLVMPass::runOnModule() {
  // Collect the rewrite patterns: vector patterns first, then standard ones.
  OwningRewritePatternList patterns;
  LLVMTypeConverter converter(&getContext());
  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateStdToLLVMConversionPatterns(converter, patterns);

  // The LLVM dialect is legal; functions are legal only once their signature
  // is legal according to the type converter.
  ConversionTarget target(getContext());
  target.addLegalDialect<LLVM::LLVMDialect>();
  target.addDynamicallyLegalOp<FuncOp>(
      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });

  auto status =
      applyPartialConversion(getModule(), target, patterns, &converter);
  if (failed(status))
    signalPassFailure();
}
/// Creates a new instance of the Vector-to-LLVM lowering pass and returns it
/// as a raw pointer.
OpPassBase<ModuleOp> *mlir::createLowerVectorToLLVMPass() {
  auto *vectorToLLVMPass = new LowerVectorToLLVMPass();
  return vectorToLLVMPass;
}
// Static registration object for LowerVectorToLLVMPass, constructed with the
// argument string "convert-vector-to-llvm" and the description below.
static PassRegistration<LowerVectorToLLVMPass>
pass("convert-vector-to-llvm",
"Lower the operations from the vector dialect into the LLVM dialect");