[mlir][Vector] Add basic scalable vectorization support to Linalg vectorizer

For now, only elementwise operations are supported. Operations that perform any
kind of data permutation require changes in the representation of scalable
dimensions in VectorType.

Differential Revision: https://reviews.llvm.org/D152599
Diego Caballero 2023-06-10 00:36:33 +00:00
parent 9d5466849a
commit 77a5ea2e67
6 changed files with 339 additions and 113 deletions
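As a minimal sketch of the new API surface (hypothetical caller, assuming an elementwise linalg op handle `op`; per this patch, only the trailing vector dimension may be scalable):

    // Request an 8x[16] canonical vector shape: a fixed leading dimension
    // and a scalable trailing one.
    IRRewriter rewriter(op->getContext());
    SmallVector<int64_t> vectorSizes = {8, 16};
    SmallVector<bool> scalableVecDims = {false, true};
    if (failed(linalg::vectorize(rewriter, op, vectorSizes, scalableVecDims,
                                 /*vectorizeNDExtract=*/false)))
      return failure();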


@@ -299,6 +299,7 @@ LogicalResult promoteSubviewsPrecondition(Operation *op,
/// Return success if the operation can be vectorized.
LogicalResult vectorizeOpPrecondition(Operation *op,
ArrayRef<int64_t> inputVectorSizes = {},
ArrayRef<bool> inputScalableVecDims = {},
bool vectorizeNDExtract = false);
//===----------------------------------------------------------------------===//
@@ -592,8 +593,8 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
/// dynamic shapes.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes = {},
bool vectorizeNDExtract = false,
bool lastVectorSizeScalable = false);
ArrayRef<bool> inputScalableVecDims = {},
bool vectorizeNDExtract = false);
/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);


@@ -3036,7 +3036,7 @@ struct VectorizationPattern : public RewritePattern {
if (!linalgOp)
return rewriter.notifyMatchFailure(op, "expected Linalg Op");
return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
vectorizeNDExtract);
/*scalableVecDims=*/{}, vectorizeNDExtract);
}
private:
@@ -3137,16 +3137,16 @@ DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
}
// TODO: Check that the correct number of vectorSizes was provided.
SmallVector<bool> scalableVecDims(vectorSizes.size(), false);
scalableVecDims.back() = getLastVectorSizeScalable();
for (Operation *target : targets) {
if (!isa<linalg::LinalgOp, tensor::PadOp>(target)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Unsupported Op, cannot vectorize";
}
if (failed(linalg::vectorize(rewriter, target, vectorSizes,
getVectorizeNdExtract(),
getLastVectorSizeScalable()))) {
if (failed(linalg::vectorize(rewriter, target, vectorSizes, scalableVecDims,
getVectorizeNdExtract()))) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Attempted to vectorize, but failed";
}
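For example, masked_vectorize %0 vector_sizes [8, [32]] makes getLastVectorSizeScalable() return true, so the scalableVecDims vector built above is {false, true} and is forwarded to linalg::vectorize alongside the vector sizes.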


@@ -169,6 +169,21 @@ static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
return res;
}
/// Return true if the scalable vector dimensions are supported. For now, we
/// only support scalable vectors in the trailing dimension.
static bool areValidScalableVecDims(ArrayRef<bool> scalableVecDims) {
if (scalableVecDims.empty())
return true;
auto isScalable = [](bool isScalableVecSize) { return isScalableVecSize; };
if (std::any_of(scalableVecDims.begin(), scalableVecDims.end() - 1,
isScalable)) {
return false;
}
return true;
}
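A few hypothetical checks (not part of the patch) showing the accepted patterns:

    assert(areValidScalableVecDims({}));                    // no vector dims
    assert(areValidScalableVecDims({false, false, true}));  // trailing scalable
    assert(!areValidScalableVecDims({true, false, false})); // leading scalable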
/// Contains the vectorization state and related methods used across the
/// vectorization process of a given operation.
struct VectorizationState {
@@ -177,11 +192,42 @@ struct VectorizationState {
/// Initializes the vectorization state, including the computation of the
/// canonical vector shape for vectorization.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
ArrayRef<int64_t> inputVectorSizes);
ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims);
/// Returns the canonical vector shape used to vectorize the iteration space.
ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
/// Returns a vector type of the provided `elementType` with the canonical
/// vector shape and the corresponding fixed/scalable dimensions bit. If
/// `dimPermutation` is provided, the canonical vector dimensions are permuted
/// accordingly.
VectorType getCanonicalVecType(
Type elementType,
std::optional<AffineMap> dimPermutation = std::nullopt) const {
SmallVector<int64_t> vectorShape;
SmallVector<bool> scalableDims;
if (dimPermutation.has_value()) {
vectorShape =
applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
scalableDims =
applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
} else {
vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
}
// Make sure we don't end up with unsupported scalable vector dimensions
// after the permutation. If so, we should bail out on that operation in the
// scalable preconditions.
assert(areValidScalableVecDims(scalableDims) &&
"Permuted scalable vector dimensions are not supported");
// TODO: Extend scalable vector type to support a bit map.
bool numScalableDims = !scalableVecDims.empty() && scalableVecDims.back();
return VectorType::get(vectorShape, elementType, numScalableDims);
}
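For illustration, assuming canonicalVecShape = {8, 16} and scalableVecDims = {false, true} (canonical shape 8x[16]):

    // Hypothetical uses; resulting types shown in comments.
    VectorType vecTy = state.getCanonicalVecType(rewriter.getF32Type());
    // vecTy == vector<8x[16]xf32>
    VectorType maskTy = state.getCanonicalVecType(rewriter.getI1Type());
    // maskTy == vector<8x[16]xi1>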
/// Masks an operation with the canonical vector mask if the operation needs
/// masking. Returns the masked operation or the original operation if masking
/// is not needed. If provided, the canonical mask for this operation is
@@ -223,6 +269,10 @@ private:
/// Holds the canonical vector shape used to vectorize the iteration space.
SmallVector<int64_t> canonicalVecShape;
/// Holds the vector dimensions that are scalable in the canonical vector
/// shape.
SmallVector<bool> scalableVecDims;
/// Holds the active masks for permutations of the canonical vector iteration
/// space.
DenseMap<AffineMap, Value> activeMaskCache;
@@ -268,7 +318,8 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
// TODO: Move this to the constructor when we can remove the failure cases.
LogicalResult
VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
ArrayRef<int64_t> inputVectorSizes) {
ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims) {
// Initialize the insertion point.
rewriter.setInsertionPoint(linalgOp);
@@ -277,15 +328,22 @@ VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
// path should be taken to vectorize code with dynamic shapes and when using
// vector sizes greater than the iteration space sizes.
canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
scalableVecDims.append(inputScalableVecDims.begin(),
inputScalableVecDims.end());
} else {
// Compute the canonical vector shape from the operation shape. If there are
// dynamic shapes, the operation won't be vectorized.
// dynamic shapes, the operation won't be vectorized. We assume all the
// vector dimensions are fixed.
canonicalVecShape = linalgOp.getStaticLoopRanges();
scalableVecDims.append(linalgOp.getNumLoops(), false);
}
LDBG("Canonical vector shape: ");
LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << "\n");
LDBG("Scalable vector dims: ");
LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << "\n");
if (ShapedType::isDynamicShape(canonicalVecShape))
return failure();
@@ -343,9 +401,10 @@ Value VectorizationState::getOrCreateMaskFor(
// TODO: Improve this check. Only projected permutation indexing maps are
// supported.
SmallVector<int64_t> permutedStaticSizes =
applyPermutationMap(maskingMap, ArrayRef<int64_t>(iterSpaceStaticSizes));
SmallVector<int64_t> maskShape =
applyPermutationMap(maskingMap, ArrayRef<int64_t>(canonicalVecShape));
applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
auto maskShape = maskType.getShape();
LDBG("Mask shape: ");
LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << "\n");
@@ -362,8 +421,7 @@ Value VectorizationState::getOrCreateMaskFor(
assert(!maskShape.empty() && !upperBounds.empty() &&
"Masked 0-d vectors are not supported yet");
// Create the mask based on the dimension size values.
auto maskType = VectorType::get(maskShape, rewriter.getI1Type());
// Create the mask based on the dimension values.
Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
maskType, upperBounds);
LDBG("Creating new mask: " << mask << "\n");
@@ -504,18 +562,16 @@ static Operation *matchLinalgReduction(OpOperand *outputOperand) {
/// Broadcast `value` to a vector of `shape` if possible. Return value
/// otherwise.
static Value broadcastIfNeeded(OpBuilder &b, Value value,
ArrayRef<int64_t> shape) {
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
auto dstVecType = dyn_cast<VectorType>(dstType);
// If no shape to broadcast to, just return `value`.
if (shape.empty())
if (dstVecType.getRank() == 0)
return value;
VectorType targetVectorType =
VectorType::get(shape, getElementTypeOrSelf(value));
if (vector::isBroadcastableTo(value.getType(), targetVectorType) !=
if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
vector::BroadcastableToResult::Success)
return value;
Location loc = b.getInsertionPoint()->getLoc();
return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
}
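For example, broadcasting a scalar f32 to vector<8x[16]xf32> now works: passing the destination vector type through directly, rather than rebuilding it from a fixed-size shape, preserves the scalable dimensions.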
/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
@@ -549,16 +605,15 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
Location loc = value.getLoc();
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
auto vectorType =
VectorType::get(opOperandMap.compose(state.getCanonicalVecShape()),
getElementTypeOrSelf(outputOperand->get().getType()));
auto vectorType = state.getCanonicalVecType(
getElementTypeOrSelf(outputOperand->get().getType()), opOperandMap);
Operation *write;
if (vectorType.getRank() > 0) {
AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
SmallVector<Value> indices(linalgOp.getRank(outputOperand),
rewriter.create<arith::ConstantIndexOp>(loc, 0));
value = broadcastIfNeeded(rewriter, value, vectorType.getShape());
value = broadcastIfNeeded(rewriter, value, vectorType);
write = rewriter.create<vector::TransferWriteOp>(
loc, value, outputOperand->get(), indices, writeMap);
} else {
@@ -639,10 +694,10 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
return VectorizationResult{VectorizationStatus::Failure, nullptr};
auto loc = indexOp.getLoc();
// Compute the static loop sizes of the index op.
auto targetShape = llvm::to_vector(state.getCanonicalVecShape());
auto targetShape = state.getCanonicalVecShape();
// Compute a one-dimensional index vector for the index op dimension.
SmallVector<int64_t> constantSeq =
llvm::to_vector<16>(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
auto constantSeq =
llvm::to_vector(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
auto indexSteps = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexVectorAttr(constantSeq));
// Return the one-dimensional index vector if it lives in the trailing
@@ -653,9 +708,15 @@
// Otherwise permute the targetShape to move the index dimension last,
// broadcast the one-dimensional index vector to the permuted shape, and
// finally transpose the broadcasted index vector to undo the permutation.
std::swap(targetShape[indexOp.getDim()], targetShape.back());
auto permPattern =
llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
std::swap(permPattern[indexOp.getDim()], permPattern.back());
auto permMap =
AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
auto broadCastOp = rewriter.create<vector::BroadcastOp>(
loc, VectorType::get(targetShape, rewriter.getIndexType()), indexSteps);
loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
indexSteps);
SmallVector<int64_t> transposition =
llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
std::swap(transposition.back(), transposition[indexOp.getDim()]);
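To make the permutation concrete, assume (hypothetically) a fixed canonical vector shape of 4x8 and indexOp.getDim() == 0: permPattern becomes {1, 0} after the swap, so permMap is (d0, d1) -> (d1, d0); the one-dimensional index vector [0, 1, 2, 3] of type vector<4xindex> is broadcast to the permuted type vector<8x4xindex>; and transposition = {1, 0} undoes the permutation, yielding a vector<4x8xindex> whose element (i, j) equals i.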
@@ -698,15 +759,15 @@ tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
/// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
/// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
static Value calculateGatherOffset(RewriterBase &rewriter,
VectorizationState &state,
tensor::ExtractOp extractOp,
const IRMapping &bvm,
const ArrayRef<int64_t> targetShape) {
// The vector of indices for GatherOp should be shaped as the output vector
auto indexVecType = VectorType::get(targetShape, rewriter.getIndexType());
const IRMapping &bvm) {
// The vector of indices for GatherOp should be shaped as the output vector.
auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
auto loc = extractOp.getLoc();
Value offset = broadcastIfNeeded(
rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType.getShape());
rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
const size_t numIndices = extractOp.getIndices().size();
for (size_t i = 1; i < numIndices; i++) {
@@ -715,13 +776,12 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
auto dimSize = broadcastIfNeeded(
rewriter,
rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
indexVecType.getShape());
indexVecType);
offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
auto extractOpIndex =
broadcastIfNeeded(rewriter, bvm.lookup(extractOp.getIndices()[i]),
indexVecType.getShape());
auto extractOpIndex = broadcastIfNeeded(
rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
}
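A scalar reference for this linearization (hypothetical helper, not in the patch), matching the doc-comment example above:

    // linearOffset({45, 80, 15}, {1, 2, 3}) == ((1 * 80) + 2) * 15 + 3 == 1233
    static int64_t linearOffset(ArrayRef<int64_t> dims,
                                ArrayRef<int64_t> indices) {
      int64_t offset = indices[0];
      for (size_t i = 1, e = indices.size(); i < e; ++i)
        offset = offset * dims[i] + indices[i];
      return offset;
    }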
@@ -935,14 +995,11 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
auto loc = extractOp.getLoc();
// Compute the static loop sizes of the extract op.
auto targetShape = state.getCanonicalVecShape();
auto resultType =
VectorType::get(targetShape, extractOp.getResult().getType());
auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
auto maskConstantOp = rewriter.create<arith::ConstantOp>(
loc, DenseIntElementsAttr::get(
VectorType::get(targetShape, rewriter.getI1Type()),
/*value=*/true));
loc,
DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
/*value=*/true));
auto passThruConstantOp =
rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
@@ -957,7 +1014,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
// 1. Handle gather access
if (memAccessKind == VectorMemoryAccessKind::Gather) {
Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
// Generate the gather load
Operation *gatherOp = rewriter.create<vector::GatherOp>(
@@ -1090,8 +1147,8 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
/// This function does not update `bvm` but returns a VectorizationStatus that
/// instructs the caller what `bvm` update needs to occur.
static VectorizationResult
vectorizeOneOp(RewriterBase &rewriter, LinalgOp linalgOp, Operation *op,
const IRMapping &bvm,
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
LDBG("vectorize op " << *op << "\n");
@@ -1139,33 +1196,41 @@ vectorizeOneOp(RewriterBase &rewriter, LinalgOp linalgOp, Operation *op,
}
// 5. Generic vectorization path for ElementwiseMappable ops.
// a. first get the first max ranked shape.
SmallVector<int64_t, 4> firstMaxRankedShape;
// a. Get the first max ranked shape.
VectorType firstMaxRankedType;
for (Value operand : op->getOperands()) {
auto vt = dyn_cast<VectorType>(bvm.lookup(operand).getType());
if (vt && firstMaxRankedShape.size() < vt.getShape().size())
firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
auto vecType = dyn_cast<VectorType>(bvm.lookup(operand).getType());
if (vecType && (!firstMaxRankedType ||
firstMaxRankedType.getRank() < vecType.getRank()))
firstMaxRankedType = vecType;
}
// b. Broadcast each op if needed.
SmallVector<Value> vectorizedOperands;
for (Value scalarOperand : op->getOperands()) {
Value vectorizedOperand = bvm.lookup(scalarOperand);
auto vecType =
VectorType::get(firstMaxRankedType.getShape(),
getElementTypeOrSelf(vectorizedOperand.getType()),
firstMaxRankedType.getNumScalableDims());
vectorizedOperands.push_back(
!firstMaxRankedType
? vectorizedOperand
: broadcastIfNeeded(rewriter, vectorizedOperand, vecType));
}
// b. broadcast each op if needed.
auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) {
return firstMaxRankedShape.empty()
? bvm.lookup(v)
: broadcastIfNeeded(rewriter, bvm.lookup(v),
firstMaxRankedShape);
});
// c. for elementwise, the result is the vector with the firstMaxRankedShape
auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) {
return firstMaxRankedShape.empty()
? t
: VectorType::get(firstMaxRankedShape, t);
});
// Build and return the new op.
SmallVector<Type> resultTypes;
for (Type resultType : op->getResultTypes()) {
resultTypes.push_back(
!firstMaxRankedType
? resultType
: VectorType::get(firstMaxRankedType.getShape(), resultType,
firstMaxRankedType.getNumScalableDims()));
}
// d. Build and return the new op.
return VectorizationResult{
VectorizationStatus::NewOp,
rewriter.create(op->getLoc(), op->getName().getIdentifier(),
llvm::to_vector<4>(vectorizedOperands),
llvm::to_vector<4>(returnTypes), op->getAttrs())};
vectorizedOperands, resultTypes, op->getAttrs())};
}
/// Generic vectorization function that rewrites the body of a `linalgOp` into
@@ -1232,22 +1297,21 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
AffineMap maskingMap = indexingMap.dropResults(zeroPos);
AffineMap readMap;
SmallVector<int64_t> readVecShape;
VectorType readType;
Type elemType = getElementTypeOrSelf(opOperand->get());
if (linalgOp.isDpsInput(opOperand)) {
// 3.a.i. For input reads we use the canonical vector shape.
readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
readVecShape = llvm::to_vector(state.getCanonicalVecShape());
readType = state.getCanonicalVecType(elemType);
} else {
// 3.a.ii. For output reads (iteration-carried dependence, e.g.,
// reductions), the vector shape is computed by mapping the canonical
// vector shape to the output domain and back to the canonical domain.
readMap = inversePermutation(reindexIndexingMap(indexingMap));
readVecShape =
readMap.compose(indexingMap.compose(state.getCanonicalVecShape()));
readType =
state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
}
auto readType =
VectorType::get(readVecShape, getElementTypeOrSelf(opOperand->get()));
SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
Operation *read = rewriter.create<vector::TransferReadOp>(
@@ -1265,7 +1329,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
// 3.c. Not all ops support 0-d vectors, extract the scalar for now.
// TODO: remove this.
if (cast<VectorType>(readValue.getType()).getRank() == 0)
if (readType.getRank() == 0)
readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
@@ -1299,7 +1363,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
// 5. Iteratively call `vectorizeOneOp` to each op in the slice.
for (Operation &op : block->getOperations()) {
VectorizationResult result =
vectorizeOneOp(rewriter, linalgOp, &op, bvm, hooks);
vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
if (result.status == VectorizationStatus::Failure) {
LDBG("failed to vectorize: " << op << "\n");
return failure();
@@ -1526,10 +1590,38 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
return success();
}
LogicalResult
mlir::linalg::vectorizeOpPrecondition(Operation *op,
ArrayRef<int64_t> inputVectorSizes,
bool vectorizeNDExtract) {
/// Preconditions for scalable vectors.
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,
ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims) {
assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
"Number of input vector sizes and scalable dims doesn't match");
if (inputVectorSizes.empty())
return success();
if (!areValidScalableVecDims(inputScalableVecDims)) {
LDBG("Non-trailing scalable vector dimensions are not supported\n");
return failure();
}
bool isScalable = inputScalableVecDims.back();
if (!isScalable)
return success();
// Only element-wise ops supported in the presence of scalable dims.
auto linalgOp = dyn_cast<LinalgOp>(op);
return success(linalgOp && isElementwise(linalgOp));
}
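For inputVectorSizes = {8, 16}: all-fixed scalable dims always pass this check; {true, false} fails (non-trailing scalable dimension); {false, true} passes only for elementwise LinalgOps, so e.g. a linalg.matmul with a scalable vector size is rejected here rather than emitting the old warning.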
LogicalResult mlir::linalg::vectorizeOpPrecondition(
Operation *op, ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract) {
if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
inputScalableVecDims)))
return failure();
return TypeSwitch<Operation *, LogicalResult>(op)
.Case<linalg::LinalgOp>([&](auto linalgOp) {
return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
@@ -1564,19 +1656,18 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
/// operations with dynamic shapes.
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes,
bool vectorizeNDExtract,
bool lastVectorSizeScalable) {
ArrayRef<bool> inputScalableVecDims,
bool vectorizeNDExtract) {
LDBG("Attempting to vectorize:\n" << *op << "\n");
LDBG("Input vector sizes: ");
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << "\n");
LDBG("Scalable vectorisation: " << lastVectorSizeScalable << "\n");
LDBG("Input scalable vector dims: ");
LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << "\n");
if (lastVectorSizeScalable)
op->emitWarning("Scalable vectorization is not supported yet");
if (failed(
vectorizeOpPrecondition(op, inputVectorSizes, vectorizeNDExtract))) {
if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
vectorizeNDExtract))) {
LDBG("Vectorization pre-conditions failed\n");
return failure();
}
@@ -1584,7 +1675,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
// Initialize vectorization state.
VectorizationState state(rewriter);
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
if (failed(state.initState(rewriter, linalgOp, inputVectorSizes))) {
if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
inputScalableVecDims))) {
LDBG("Vectorization state couldn't be initialized\n");
return failure();
}


@@ -346,7 +346,8 @@ LogicalResult MultiDimReductionOp::verify() {
Type MultiDimReductionOp::getExpectedMaskType() {
auto vecType = getSourceVectorType();
return VectorType::get(vecType.getShape(),
IntegerType::get(vecType.getContext(), /*width=*/1));
IntegerType::get(vecType.getContext(), /*width=*/1),
vecType.getNumScalableDims());
}
namespace {
@@ -483,8 +484,9 @@ void ReductionOp::print(OpAsmPrinter &p) {
/// Returns the mask type expected by this operation.
Type ReductionOp::getExpectedMaskType() {
auto vecType = getSourceVectorType();
return vecType.cloneWith(std::nullopt,
IntegerType::get(vecType.getContext(), /*width=*/1));
return VectorType::get(vecType.getShape(),
IntegerType::get(vecType.getContext(), /*width=*/1),
vecType.getNumScalableDims());
}
Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
@@ -926,6 +928,10 @@ Type ContractionOp::getExpectedMaskType() {
assert(!ShapedType::isDynamicShape(maskShape) &&
"Mask shape couldn't be computed");
// TODO: Extend the scalable vector type representation with a bit map.
assert(lhsType.getNumScalableDims() == 0 &&
rhsType.getNumScalableDims() == 0 &&
"Scalable vectors are not supported yet");
return VectorType::get(maskShape,
IntegerType::get(lhsType.getContext(), /*width=*/1));
@@ -2856,7 +2862,8 @@ LogicalResult OuterProductOp::verify() {
Type OuterProductOp::getExpectedMaskType() {
auto vecType = this->getResultVectorType();
return VectorType::get(vecType.getShape(),
IntegerType::get(vecType.getContext(), /*width=*/1));
IntegerType::get(vecType.getContext(), /*width=*/1),
vecType.getNumScalableDims());
}
//===----------------------------------------------------------------------===//
@@ -3509,9 +3516,12 @@ static VectorType inferTransferOpMaskType(VectorType vecType,
AffineMap permMap) {
auto i1Type = IntegerType::get(permMap.getContext(), 1);
AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
// TODO: Extend the scalable vector type representation with a bit map.
assert((permMap.isMinorIdentity() || vecType.getNumScalableDims() == 0) &&
"Scalable vectors are not supported yet");
assert(invPermMap && "Inversed permutation map couldn't be computed");
SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
return VectorType::get(maskShape, i1Type);
return VectorType::get(maskShape, i1Type, vecType.getNumScalableDims());
}
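For instance, with the minor-identity permutation map (d0, d1) -> (d0, d1) and vecType = vector<8x[16]xf32>, invPermMap is also the identity, maskShape = {8, 16}, and the inferred mask type is vector<8x[16]xi1>, preserving the scalable trailing dimension.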
ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -4470,7 +4480,8 @@ LogicalResult GatherOp::verify() {
Type GatherOp::getExpectedMaskType() {
auto vecType = this->getIndexVectorType();
return VectorType::get(vecType.getShape(),
IntegerType::get(vecType.getContext(), /*width=*/1));
IntegerType::get(vecType.getContext(), /*width=*/1),
vecType.getNumScalableDims());
}
std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {


@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file --verify-diagnostics | FileCheck %s
// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s
func.func @vectorize_dynamic_identity(%arg0: tensor<?xf32>,
%arg1: tensor<?xf32>,
@@ -485,17 +485,3 @@ transform.sequence failures(propagate) {
transform.structured.masked_vectorize %0 vector_sizes [8, 16, 4] : !transform.any_op
}
// -----
func.func @vectorize_dynamic_matmul_scalable(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
// expected-warning @+1 {{Scalable vectorization is not supported yet}}
linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
outs(%C: memref<?x?xf32>)
return
}
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.masked_vectorize %0 vector_sizes [8, 16, [4]] : !transform.any_op
}


@@ -0,0 +1,136 @@
// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s
func.func @vectorize_dynamic_identity(%arg0: tensor<?xf32>,
%arg1: tensor<?xf32>,
%arg2: tensor<?xf32>) -> tensor<?xf32> {
%0 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"] }
ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
outs(%arg2 : tensor<?xf32>) {
^bb(%in0: f32, %in1: f32, %out: f32) :
%0 = arith.addf %in0, %in1 : f32
linalg.yield %0 : f32
} -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// CHECK-LABEL: @vectorize_dynamic_identity
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_4:.*]] = tensor.dim %{{.*}}, %[[VAL_3]] : tensor<?xf32>
// CHECK: %[[VAL_7:.*]] = vector.create_mask %[[VAL_4]] : vector<[4]xi1>
// CHECK: %[[VAL_8:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
// CHECK: %[[VAL_10:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
// CHECK: %[[VAL_12:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
// CHECK: %[[VAL_13:.*]] = arith.addf %[[VAL_8]], %[[VAL_10]] : vector<[4]xf32>
// CHECK: %[[VAL_14:.*]] = vector.mask %[[VAL_7]] { vector.transfer_write %{{.*}} {in_bounds = [true]} : vector<[4]xf32>, tensor<?xf32> } : vector<[4]xi1> -> tensor<?xf32>
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.masked_vectorize %0 vector_sizes [[4]] : !transform.any_op
}
// -----
func.func @vectorize_partial_dynamic_identity(%arg0: tensor<8x?xf32>,
%arg1: tensor<8x?xf32>,
%arg2: tensor<8x?xf32>) -> tensor<8x?xf32> {
%0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"] }
ins(%arg0, %arg1 : tensor<8x?xf32>, tensor<8x?xf32>)
outs(%arg2 : tensor<8x?xf32>) {
^bb(%in0: f32, %in1: f32, %out: f32) :
%0 = arith.addf %in0, %in1 : f32
linalg.yield %0 : f32
} -> tensor<8x?xf32>
return %0 : tensor<8x?xf32>
}
// CHECK-LABEL: func.func @vectorize_partial_dynamic_identity(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x?xf32>, %[[VAL_1:.*]]: tensor<8x?xf32>, %[[VAL_2:.*]]: tensor<8x?xf32>) -> tensor<8x?xf32> {
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<8x?xf32>
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 8 : index
// CHECK: %[[VAL_8:.*]] = vector.create_mask %[[VAL_7]], %[[VAL_4]] : vector<8x[32]xi1>
// CHECK: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read %[[VAL_0]][%[[VAL_5]], %[[VAL_5]]], %[[VAL_6]] {in_bounds = [true, true]} : tensor<8x?xf32>, vector<8x[32]xf32> } : vector<8x[32]xi1> -> vector<8x[32]xf32>
// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_11:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read %[[VAL_1]][%[[VAL_5]], %[[VAL_5]]], %[[VAL_10]] {in_bounds = [true, true]} : tensor<8x?xf32>, vector<8x[32]xf32> } : vector<8x[32]xi1> -> vector<8x[32]xf32>
// CHECK: %[[VAL_12:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read %[[VAL_2]][%[[VAL_5]], %[[VAL_5]]], %[[VAL_12]] {in_bounds = [true, true]} : tensor<8x?xf32>, vector<8x[32]xf32> } : vector<8x[32]xi1> -> vector<8x[32]xf32>
// CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_9]], %[[VAL_11]] : vector<8x[32]xf32>
// CHECK: %[[VAL_15:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_16:.*]] = vector.mask %[[VAL_8]] { vector.transfer_write %[[VAL_14]], %[[VAL_2]][%[[VAL_15]], %[[VAL_15]]] {in_bounds = [true, true]} : vector<8x[32]xf32>, tensor<8x?xf32> } : vector<8x[32]xi1> -> tensor<8x?xf32>
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.masked_vectorize %0 vector_sizes [8, [32]] : !transform.any_op
}
// -----
func.func @vectorize_static_shape_with_mask(%arg0: tensor<8x30xf32>,
%arg1: tensor<8x30xf32>,
%arg2: tensor<8x30xf32>) -> tensor<8x30xf32> {
%0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"] }
ins(%arg0, %arg1 : tensor<8x30xf32>, tensor<8x30xf32>)
outs(%arg2 : tensor<8x30xf32>) {
^bb(%in0: f32, %in1: f32, %out: f32) :
%0 = arith.addf %in0, %in1 : f32
linalg.yield %0 : f32
} -> tensor<8x30xf32>
return %0 : tensor<8x30xf32>
}
// CHECK-LABEL: func.func @vectorize_static_shape_with_mask(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x30xf32>, %[[VAL_1:.*]]: tensor<8x30xf32>, %[[VAL_2:.*]]: tensor<8x30xf32>) -> tensor<8x30xf32> {
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 30 : index
// CHECK: %[[VAL_7:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_6]] : vector<8x[32]xi1>
// CHECK: %[[VAL_8:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %[[VAL_0]][%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<8x30xf32>, vector<8x[32]xf32> } : vector<8x[32]xi1> -> vector<8x[32]xf32>
// CHECK: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_10:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %[[VAL_1]][%[[VAL_3]], %[[VAL_3]]], %[[VAL_9]] {in_bounds = [true, true]} : tensor<8x30xf32>, vector<8x[32]xf32> } : vector<8x[32]xi1> -> vector<8x[32]xf32>
// CHECK: %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_12:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %[[VAL_2]][%[[VAL_3]], %[[VAL_3]]], %[[VAL_11]] {in_bounds = [true, true]} : tensor<8x30xf32>, vector<8x[32]xf32> } : vector<8x[32]xi1> -> vector<8x[32]xf32>
// CHECK: %[[VAL_13:.*]] = arith.addf %[[VAL_8]], %[[VAL_10]] : vector<8x[32]xf32>
// CHECK: %[[VAL_14:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_15:.*]] = vector.mask %[[VAL_7]] { vector.transfer_write %[[VAL_13]], %[[VAL_2]][%[[VAL_14]], %[[VAL_14]]] {in_bounds = [true, true]} : vector<8x[32]xf32>, tensor<8x30xf32> } : vector<8x[32]xi1> -> tensor<8x30xf32>
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.masked_vectorize %0 vector_sizes [8, [32]] : !transform.any_op
}
// -----
func.func @vectorize_dynamic_fill(%A : tensor<?x?xf32>, %arg0 : f32) -> tensor<?x?xf32> {
%0 = linalg.fill ins(%arg0 : f32) outs(%A : tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: func.func @vectorize_dynamic_fill
// CHECK: %[[DIM0:.*]] = tensor.dim
// CHECK: %[[DIM1:.*]] = tensor.dim
// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM1]] : vector<8x[16]xi1>
// CHECK: %[[BCAST:.*]] = vector.broadcast %{{.*}} : f32 to vector<8x[16]xf32>
// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[BCAST]], {{.*}} {in_bounds = [true, true]} : vector<8x[16]xf32>, tensor<?x?xf32> } : vector<8x[16]xi1>
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.masked_vectorize %0 vector_sizes [8, [16]] : !transform.any_op
}