mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-04-04 06:12:19 +00:00

This helps to understand what the problem is when vectorization of structured ops failes due to mismatching vector sizes.
3174 lines
127 KiB
C++
3174 lines
127 KiB
C++
//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements the linalg dialect Vectorization transformations.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
#include "mlir/Dialect/Affine/Utils.h"
|
|
|
|
#include "mlir/Analysis/SliceAnalysis.h"
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
|
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
|
|
#include "mlir/IR/AffineExpr.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
#include "mlir/Transforms/RegionUtils.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/Sequence.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
#include <optional>
|
|
#include <type_traits>
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::linalg;
|
|
|
|
#define DEBUG_TYPE "linalg-vectorization"
|
|
|
|
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
|
|
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
|
|
|
|
/// Try to vectorize `convOp` as a convolution.
|
|
static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
|
|
LinalgOp convOp);
|
|
|
|
/// Return the unique instance of OpType in `block` if it is indeed unique.
|
|
/// Return null if none or more than 1 instances exist.
|
|
template <typename OpType>
|
|
static OpType getSingleOpOfType(Block &block) {
|
|
OpType res;
|
|
block.walk([&](OpType op) {
|
|
if (res) {
|
|
res = nullptr;
|
|
return WalkResult::interrupt();
|
|
}
|
|
res = op;
|
|
return WalkResult::advance();
|
|
});
|
|
return res;
|
|
}
|
|
|
|
/// Helper function to extract the input slices after filter is unrolled along
|
|
/// kw.
|
|
static SmallVector<Value>
|
|
extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
|
|
int64_t nSize, int64_t wSize, int64_t cSize,
|
|
int64_t kwSize, int strideW, int dilationW,
|
|
int64_t wSizeStep, bool isSingleChanneled) {
|
|
SmallVector<Value> result;
|
|
if (isSingleChanneled) {
|
|
// Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
|
|
// convolution.
|
|
SmallVector<int64_t> sizes{wSizeStep};
|
|
SmallVector<int64_t> strides{1};
|
|
for (int64_t kw = 0; kw < kwSize; ++kw) {
|
|
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
|
result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
|
|
loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
|
|
}
|
|
}
|
|
} else {
|
|
// Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
|
|
// for channeled convolution.
|
|
SmallVector<int64_t> sizes{nSize, wSizeStep, cSize};
|
|
SmallVector<int64_t> strides{1, 1, 1};
|
|
for (int64_t kw = 0; kw < kwSize; ++kw) {
|
|
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
|
result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
|
|
loc, input,
|
|
/*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
|
|
sizes, strides));
|
|
}
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/// Helper function to extract the filter slices after filter is unrolled along
|
|
/// kw.
|
|
static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
|
|
Location loc, Value filter,
|
|
int64_t kwSize) {
|
|
SmallVector<Value> result;
|
|
// Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
|
|
// non-chanelled convolution] @ [kw].
|
|
for (int64_t kw = 0; kw < kwSize; ++kw) {
|
|
result.push_back(rewriter.create<vector::ExtractOp>(
|
|
loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/// Helper function to extract the result slices after filter is unrolled along
|
|
/// kw.
|
|
static SmallVector<Value>
|
|
extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
|
|
int64_t nSize, int64_t wSize, int64_t fSize,
|
|
int64_t wSizeStep, bool isSingleChanneled) {
|
|
SmallVector<Value> result;
|
|
if (isSingleChanneled) {
|
|
// Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
|
|
SmallVector<int64_t> sizes{wSizeStep};
|
|
SmallVector<int64_t> strides{1};
|
|
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
|
result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
|
|
loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
|
|
}
|
|
} else {
|
|
// Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
|
|
// convolution.
|
|
SmallVector<int64_t> sizes{nSize, wSizeStep, fSize};
|
|
SmallVector<int64_t> strides{1, 1, 1};
|
|
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
|
result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
|
|
loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/// Helper function to insert the computed result slices.
|
|
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
|
|
Value res, int64_t wSize, int64_t wSizeStep,
|
|
SmallVectorImpl<Value> &resVals,
|
|
bool isSingleChanneled) {
|
|
|
|
if (isSingleChanneled) {
|
|
// Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
|
|
// This does not depend on kw.
|
|
SmallVector<int64_t> strides{1};
|
|
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
|
res = rewriter.create<vector::InsertStridedSliceOp>(
|
|
loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
|
|
}
|
|
} else {
|
|
// Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
|
|
// convolution. This does not depend on kw.
|
|
SmallVector<int64_t> strides{1, 1, 1};
|
|
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
|
res = rewriter.create<vector::InsertStridedSliceOp>(
|
|
loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
|
|
strides);
|
|
}
|
|
}
|
|
return res;
|
|
}
|
|
|
|
/// Contains the vectorization state and related methods used across the
|
|
/// vectorization process of a given operation.
|
|
struct VectorizationState {
|
|
VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
|
|
|
|
/// Initializes the vectorization state, including the computation of the
|
|
/// canonical vector shape for vectorization.
|
|
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
|
|
ArrayRef<int64_t> inputVectorSizes,
|
|
ArrayRef<bool> inputScalableVecDims);
|
|
|
|
/// Returns the canonical vector shape used to vectorize the iteration space.
|
|
ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
|
|
|
|
/// Returns a vector type of the provided `elementType` with the canonical
|
|
/// vector shape and the corresponding fixed/scalable dimensions bit. If
|
|
/// `dimPermutation` is provided, the canonical vector dimensions are permuted
|
|
/// accordingly.
|
|
VectorType getCanonicalVecType(
|
|
Type elementType,
|
|
std::optional<AffineMap> dimPermutation = std::nullopt) const {
|
|
SmallVector<int64_t> vectorShape;
|
|
SmallVector<bool> scalableDims;
|
|
if (dimPermutation.has_value()) {
|
|
vectorShape =
|
|
applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
|
|
scalableDims =
|
|
applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
|
|
} else {
|
|
vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
|
|
scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
|
|
}
|
|
|
|
return VectorType::get(vectorShape, elementType, scalableDims);
|
|
}
|
|
|
|
/// Masks an operation with the canonical vector mask if the operation needs
|
|
/// masking. Returns the masked operation or the original operation if masking
|
|
/// is not needed. If provided, the canonical mask for this operation is
|
|
/// permuted using `maybeMaskingMap`.
|
|
Operation *
|
|
maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
|
|
std::optional<AffineMap> maybeMaskingMap = std::nullopt);
|
|
|
|
private:
|
|
/// Initializes the iteration space static sizes using the Linalg op
|
|
/// information. This may become more complicated in the future.
|
|
void initIterSpaceStaticSizes(LinalgOp linalgOp) {
|
|
iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
|
|
}
|
|
|
|
/// Generates 'arith.constant' and 'tensor/memref.dim' operations for
|
|
/// all the static and dynamic dimensions of the iteration space to be
|
|
/// vectorized and store them in `iterSpaceValueSizes`.
|
|
LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
|
|
LinalgOp linalgOp);
|
|
|
|
/// Create or retrieve an existing mask value to mask `opToMask` in the
|
|
/// canonical vector iteration space. If `maybeMaskingMap` the mask is
|
|
/// permuted using that permutation map. If a new mask is created, it will be
|
|
/// cached for future users.
|
|
Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
|
|
LinalgOp linalgOp,
|
|
std::optional<AffineMap> maybeMaskingMap);
|
|
|
|
// Holds the compile-time static sizes of the iteration space to vectorize.
|
|
// Dynamic dimensions are represented using ShapedType::kDynamic.
|
|
SmallVector<int64_t> iterSpaceStaticSizes;
|
|
|
|
/// Holds the value sizes of the iteration space to vectorize. Static
|
|
/// dimensions are represented by 'arith.constant' and dynamic
|
|
/// dimensions by 'tensor/memref.dim'.
|
|
SmallVector<Value> iterSpaceValueSizes;
|
|
|
|
/// Holds the canonical vector shape used to vectorize the iteration space.
|
|
SmallVector<int64_t> canonicalVecShape;
|
|
|
|
/// Holds the vector dimensions that are scalable in the canonical vector
|
|
/// shape.
|
|
SmallVector<bool> scalableVecDims;
|
|
|
|
/// Holds the active masks for permutations of the canonical vector iteration
|
|
/// space.
|
|
DenseMap<AffineMap, Value> activeMaskCache;
|
|
|
|
/// Global vectorization guard for the incoming rewriter. It's initialized
|
|
/// when the vectorization state is initialized.
|
|
OpBuilder::InsertionGuard rewriterGuard;
|
|
};
|
|
|
|
LogicalResult
|
|
VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
|
|
LinalgOp linalgOp) {
|
|
// TODO: Support 0-d vectors.
|
|
for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
|
|
if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
|
|
// Create constant index op for static dimensions.
|
|
iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
|
|
linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
|
|
continue;
|
|
}
|
|
|
|
// Find an operand defined on this dimension of the iteration space to
|
|
// extract the runtime dimension size.
|
|
Value operand;
|
|
unsigned operandDimPos;
|
|
if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
|
|
operandDimPos)))
|
|
return failure();
|
|
|
|
Value dynamicDim = linalgOp.hasTensorSemantics()
|
|
? (Value)rewriter.create<tensor::DimOp>(
|
|
linalgOp.getLoc(), operand, operandDimPos)
|
|
: (Value)rewriter.create<memref::DimOp>(
|
|
linalgOp.getLoc(), operand, operandDimPos);
|
|
iterSpaceValueSizes.push_back(dynamicDim);
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Initializes the vectorization state, including the computation of the
|
|
/// canonical vector shape for vectorization.
|
|
// TODO: Move this to the constructor when we can remove the failure cases.
|
|
LogicalResult
|
|
VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
|
|
ArrayRef<int64_t> inputVectorSizes,
|
|
ArrayRef<bool> inputScalableVecDims) {
|
|
// Initialize the insertion point.
|
|
rewriter.setInsertionPoint(linalgOp);
|
|
|
|
if (!inputVectorSizes.empty()) {
|
|
// Get the canonical vector shape from the input vector sizes provided. This
|
|
// path should be taken to vectorize code with dynamic shapes and when using
|
|
// vector sizes greater than the iteration space sizes.
|
|
canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
|
|
scalableVecDims.append(inputScalableVecDims.begin(),
|
|
inputScalableVecDims.end());
|
|
} else {
|
|
// Compute the canonical vector shape from the operation shape. If there are
|
|
// dynamic shapes, the operation won't be vectorized. We assume all the
|
|
// vector dimensions are fixed.
|
|
canonicalVecShape = linalgOp.getStaticLoopRanges();
|
|
scalableVecDims.append(linalgOp.getNumLoops(), false);
|
|
}
|
|
|
|
LDBG("Canonical vector shape: ");
|
|
LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
|
|
LLVM_DEBUG(llvm::dbgs() << "\n");
|
|
LDBG("Scalable vector dims: ");
|
|
LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
|
|
LLVM_DEBUG(llvm::dbgs() << "\n");
|
|
|
|
if (ShapedType::isDynamicShape(canonicalVecShape))
|
|
return failure();
|
|
|
|
// Initialize iteration space static sizes.
|
|
initIterSpaceStaticSizes(linalgOp);
|
|
|
|
// Generate 'arith.constant' and 'tensor/memref.dim' operations for
|
|
// all the static and dynamic dimensions of the iteration space, needed to
|
|
// compute a mask during vectorization.
|
|
if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
|
|
return failure();
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Create or retrieve an existing mask value to mask `opToMask` in the
|
|
/// canonical vector iteration space. If `maybeMaskingMap` the mask is permuted
|
|
/// using that permutation map. If a new mask is created, it will be cached for
|
|
/// future users.
|
|
Value VectorizationState::getOrCreateMaskFor(
|
|
RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
|
|
std::optional<AffineMap> maybeMaskingMap) {
|
|
// No mask is needed if the operation is not maskable.
|
|
auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
|
|
if (!maskableOp)
|
|
return Value();
|
|
|
|
assert(!maskableOp.isMasked() &&
|
|
"Masking an operation that is already masked");
|
|
|
|
// If no masking map was provided, use an identity map with the loop dims.
|
|
assert((!maybeMaskingMap || *maybeMaskingMap) &&
|
|
"Unexpected null mask permutation map");
|
|
AffineMap maskingMap =
|
|
maybeMaskingMap ? *maybeMaskingMap
|
|
: AffineMap::getMultiDimIdentityMap(
|
|
linalgOp.getNumLoops(), rewriter.getContext());
|
|
|
|
LDBG("Masking map: " << maskingMap << "\n");
|
|
|
|
// Return the active mask for the masking map of this operation if it was
|
|
// already created.
|
|
auto activeMaskIt = activeMaskCache.find(maskingMap);
|
|
if (activeMaskIt != activeMaskCache.end()) {
|
|
Value mask = activeMaskIt->second;
|
|
LDBG("Reusing mask: " << mask << "\n");
|
|
return mask;
|
|
}
|
|
|
|
// Compute permuted projection of the iteration space to be masked and the
|
|
// corresponding mask shape. If the resulting iteration space dimensions are
|
|
// static and identical to the mask shape, masking is not needed for this
|
|
// operation.
|
|
// TODO: Improve this check. Only projected permutation indexing maps are
|
|
// supported.
|
|
SmallVector<int64_t> permutedStaticSizes =
|
|
applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
|
|
auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
|
|
auto maskShape = maskType.getShape();
|
|
|
|
LDBG("Mask shape: ");
|
|
LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
|
|
LLVM_DEBUG(llvm::dbgs() << "\n");
|
|
|
|
if (permutedStaticSizes == maskShape) {
|
|
LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
|
|
activeMaskCache[maskingMap] = Value();
|
|
return Value();
|
|
}
|
|
|
|
// Permute the iteration space value sizes to compute the mask upper bounds.
|
|
SmallVector<Value> upperBounds =
|
|
applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
|
|
assert(!maskShape.empty() && !upperBounds.empty() &&
|
|
"Masked 0-d vectors are not supported yet");
|
|
|
|
// Create the mask based on the dimension values.
|
|
Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
|
|
maskType, upperBounds);
|
|
LDBG("Creating new mask: " << mask << "\n");
|
|
activeMaskCache[maskingMap] = mask;
|
|
return mask;
|
|
}
|
|
|
|
/// Masks an operation with the canonical vector mask if the operation needs
|
|
/// masking. Returns the masked operation or the original operation if masking
|
|
/// is not needed. If provided, the canonical mask for this operation is
|
|
/// permuted using `maybeMaskingMap`.
|
|
Operation *
|
|
VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
|
|
LinalgOp linalgOp,
|
|
std::optional<AffineMap> maybeMaskingMap) {
|
|
LDBG("Trying to mask: " << *opToMask << "\n");
|
|
|
|
// Create or retrieve mask for this operation.
|
|
Value mask =
|
|
getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
|
|
|
|
if (!mask) {
|
|
LDBG("No mask required\n");
|
|
return opToMask;
|
|
}
|
|
|
|
// Wrap the operation with a new `vector.mask` and update D-U chain.
|
|
assert(opToMask && "Expected a valid operation to mask");
|
|
auto maskOp = cast<vector::MaskOp>(
|
|
mlir::vector::maskOperation(rewriter, opToMask, mask));
|
|
Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
|
|
|
|
for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
|
|
rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
|
|
maskOpTerminator);
|
|
|
|
LDBG("Masked operation: " << *maskOp << "\n");
|
|
return maskOp;
|
|
}
|
|
|
|
/// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
|
|
/// projectedPermutation, compress the unused dimensions to serve as a
|
|
/// permutation_map for a vector transfer operation.
|
|
/// For example, given a linalg op such as:
|
|
///
|
|
/// ```
|
|
/// %0 = linalg.generic {
|
|
/// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
|
|
/// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
|
|
/// }
|
|
/// ins(%0 : tensor<2x3x4xf32>)
|
|
/// outs(%1 : tensor<5x6xf32>)
|
|
/// ```
|
|
///
|
|
/// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
|
|
/// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
|
|
/// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
|
|
static AffineMap reindexIndexingMap(AffineMap map) {
|
|
assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
|
|
"expected projected permutation");
|
|
auto res = compressUnusedDims(map);
|
|
assert(res.getNumDims() == res.getNumResults() &&
|
|
"expected reindexed map with same number of dims and results");
|
|
return res;
|
|
}
|
|
|
|
/// Helper enum to represent conv1d input traversal order.
|
|
enum class Conv1DOpOrder {
|
|
W, // Corresponds to non-channeled 1D convolution operation.
|
|
Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
|
|
Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
|
|
};
|
|
|
|
/// Helper data structure to represent the result of vectorization.
|
|
/// In certain specific cases, like terminators, we do not want to propagate/
|
|
enum VectorizationStatus {
|
|
/// Op failed to vectorize.
|
|
Failure = 0,
|
|
/// Op vectorized and custom function took care of replacement logic
|
|
NoReplace,
|
|
/// Op vectorized into a new Op whose results will replace original Op's
|
|
/// results.
|
|
NewOp
|
|
// TODO: support values if Op vectorized to Many-Ops whose results we need to
|
|
// aggregate for replacement.
|
|
};
|
|
struct VectorizationResult {
|
|
/// Return status from vectorizing the current op.
|
|
enum VectorizationStatus status = VectorizationStatus::Failure;
|
|
/// New vectorized operation to replace the current op.
|
|
/// Replacement behavior is specified by `status`.
|
|
Operation *newOp;
|
|
};
|
|
|
|
std::optional<vector::CombiningKind>
|
|
mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
|
|
using ::mlir::vector::CombiningKind;
|
|
|
|
if (!combinerOp)
|
|
return std::nullopt;
|
|
return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
|
|
.Case<arith::AddIOp, arith::AddFOp>(
|
|
[&](auto op) { return CombiningKind::ADD; })
|
|
.Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
|
|
.Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
|
|
.Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
|
|
.Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
|
|
.Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
|
|
.Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
|
|
.Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
|
|
.Case<arith::MulIOp, arith::MulFOp>(
|
|
[&](auto op) { return CombiningKind::MUL; })
|
|
.Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
|
|
.Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
|
|
.Default([&](auto op) { return std::nullopt; });
|
|
}
|
|
|
|
/// Check whether `outputOperand` is a reduction with a single combiner
|
|
/// operation. Return the combiner operation of the reduction. Return
|
|
/// nullptr otherwise. Multiple reduction operations would impose an
|
|
/// ordering between reduction dimensions and is currently unsupported in
|
|
/// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
|
|
/// max(min(X))
|
|
// TODO: use in LinalgOp verification, there is a circular dependency atm.
|
|
static Operation *matchLinalgReduction(OpOperand *outputOperand) {
|
|
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
|
|
unsigned outputPos =
|
|
outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
|
|
// Only single combiner operations are supported for now.
|
|
SmallVector<Operation *, 4> combinerOps;
|
|
if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
|
|
combinerOps.size() != 1)
|
|
return nullptr;
|
|
|
|
// Return the combiner operation.
|
|
return combinerOps[0];
|
|
}
|
|
|
|
/// Broadcast `value` to a vector of `shape` if possible. Return value
|
|
/// otherwise.
|
|
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
|
|
auto dstVecType = dyn_cast<VectorType>(dstType);
|
|
// If no shape to broadcast to, just return `value`.
|
|
if (dstVecType.getRank() == 0)
|
|
return value;
|
|
if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
|
|
vector::BroadcastableToResult::Success)
|
|
return value;
|
|
Location loc = b.getInsertionPoint()->getLoc();
|
|
return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
|
|
}
|
|
|
|
/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
|
|
/// assumes that `reductionOp` has two operands and one of them is the reduction
|
|
/// initial value.buildMultiDimReduce
|
|
// Note: this is a true builder that notifies the OpBuilder listener.
|
|
// TODO: Consider moving as a static helper on the ReduceOp.
|
|
static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
|
|
Value valueToReduce, Value acc,
|
|
ArrayRef<bool> dimsToMask) {
|
|
auto maybeKind = getCombinerOpKind(reduceOp);
|
|
assert(maybeKind && "Failed precondition: could not get reduction kind");
|
|
return b.create<vector::MultiDimReductionOp>(
|
|
reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
|
|
}
|
|
|
|
static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
|
|
return llvm::to_vector(
|
|
llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
|
|
}
|
|
|
|
/// Build a vector.transfer_write of `value` into `outputOperand` at indices set
|
|
/// to all `0`; where `outputOperand` is an output operand of the LinalgOp
|
|
/// currently being vectorized. If `dest` has null rank, build an memref.store.
|
|
/// Return the produced value or null if no value is produced.
|
|
// Note: this is a true builder that notifies the OpBuilder listener.
|
|
// TODO: Consider moving as a static helper on the ReduceOp.
|
|
static Value buildVectorWrite(RewriterBase &rewriter, Value value,
|
|
OpOperand *outputOperand,
|
|
VectorizationState &state) {
|
|
Location loc = value.getLoc();
|
|
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
|
|
AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
|
|
|
|
// Compute the vector type of the value to store. This type should be an
|
|
// identity or projection of the canonical vector type without any permutation
|
|
// applied, given that any permutation in a transfer write happens as part of
|
|
// the write itself.
|
|
AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
|
|
opOperandMap.getContext(), opOperandMap.getNumInputs(),
|
|
[&](AffineDimExpr dimExpr) -> bool {
|
|
return llvm::is_contained(opOperandMap.getResults(), dimExpr);
|
|
});
|
|
auto vectorType = state.getCanonicalVecType(
|
|
getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
|
|
|
|
Operation *write;
|
|
if (vectorType.getRank() > 0) {
|
|
AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
|
|
SmallVector<Value> indices(linalgOp.getRank(outputOperand),
|
|
rewriter.create<arith::ConstantIndexOp>(loc, 0));
|
|
value = broadcastIfNeeded(rewriter, value, vectorType);
|
|
assert(value.getType() == vectorType && "Incorrect type");
|
|
write = rewriter.create<vector::TransferWriteOp>(
|
|
loc, value, outputOperand->get(), indices, writeMap);
|
|
} else {
|
|
// 0-d case is still special: do not invert the reindexing writeMap.
|
|
if (!isa<VectorType>(value.getType()))
|
|
value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
|
|
assert(value.getType() == vectorType && "Incorrect type");
|
|
write = rewriter.create<vector::TransferWriteOp>(
|
|
loc, value, outputOperand->get(), ValueRange{});
|
|
}
|
|
|
|
write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
|
|
|
|
// If masked, set in-bounds to true. Masking guarantees that the access will
|
|
// be in-bounds.
|
|
if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
|
|
auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
|
|
SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
|
|
maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
|
|
}
|
|
|
|
LDBG("vectorized op: " << *write << "\n");
|
|
if (!write->getResults().empty())
|
|
return write->getResult(0);
|
|
return Value();
|
|
}
|
|
|
|
// Custom vectorization precondition function type. This is intented to be used
|
|
// with CustomVectorizationHook. Returns success if the corresponding custom
|
|
// hook can vectorize the op.
|
|
using CustomVectorizationPrecondition =
|
|
std::function<LogicalResult(Operation *, bool)>;
|
|
|
|
// Custom vectorization function type. Produce a vector form of Operation*
|
|
// assuming all its vectorized operands are already in the IRMapping.
|
|
// Return nullptr if the Operation cannot be vectorized.
|
|
using CustomVectorizationHook =
|
|
std::function<VectorizationResult(Operation *, const IRMapping &)>;
|
|
|
|
/// Helper function to vectorize the terminator of a `linalgOp`. New result
|
|
/// vector values are appended to `newResults`. Return
|
|
/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
|
|
/// should not try to map produced operations and instead return the results
|
|
/// using the `newResults` vector making them available to the vectorization
|
|
/// algorithm for RAUW. This function is meant to be used as a
|
|
/// CustomVectorizationHook.
|
|
static VectorizationResult
|
|
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
|
|
const IRMapping &bvm, VectorizationState &state,
|
|
LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
|
|
auto yieldOp = dyn_cast<linalg::YieldOp>(op);
|
|
if (!yieldOp)
|
|
return VectorizationResult{VectorizationStatus::Failure, nullptr};
|
|
for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
|
|
// TODO: Scan for an opportunity for reuse.
|
|
// TODO: use a map.
|
|
Value vectorValue = bvm.lookup(output.value());
|
|
Value newResult =
|
|
buildVectorWrite(rewriter, vectorValue,
|
|
linalgOp.getDpsInitOperand(output.index()), state);
|
|
if (newResult)
|
|
newResults.push_back(newResult);
|
|
}
|
|
|
|
return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
|
|
}
|
|
|
|
/// Helper function to vectorize the index operations of a `linalgOp`. Return
|
|
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
|
|
/// should map the produced operations. This function is meant to be used as a
|
|
/// CustomVectorizationHook.
|
|
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
|
|
VectorizationState &state,
|
|
Operation *op,
|
|
LinalgOp linalgOp) {
|
|
IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
|
|
if (!indexOp)
|
|
return VectorizationResult{VectorizationStatus::Failure, nullptr};
|
|
auto loc = indexOp.getLoc();
|
|
// Compute the static loop sizes of the index op.
|
|
auto targetShape = state.getCanonicalVecShape();
|
|
// Compute a one-dimensional index vector for the index op dimension.
|
|
auto constantSeq =
|
|
llvm::to_vector(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
|
|
auto indexSteps = rewriter.create<arith::ConstantOp>(
|
|
loc, rewriter.getIndexVectorAttr(constantSeq));
|
|
// Return the one-dimensional index vector if it lives in the trailing
|
|
// dimension of the iteration space since the vectorization algorithm in this
|
|
// case can handle the broadcast.
|
|
if (indexOp.getDim() == targetShape.size() - 1)
|
|
return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
|
|
// Otherwise permute the targetShape to move the index dimension last,
|
|
// broadcast the one-dimensional index vector to the permuted shape, and
|
|
// finally transpose the broadcasted index vector to undo the permutation.
|
|
auto permPattern =
|
|
llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
|
|
std::swap(permPattern[indexOp.getDim()], permPattern.back());
|
|
auto permMap =
|
|
AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
|
|
|
|
auto broadCastOp = rewriter.create<vector::BroadcastOp>(
|
|
loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
|
|
indexSteps);
|
|
SmallVector<int64_t> transposition =
|
|
llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
|
|
std::swap(transposition.back(), transposition[indexOp.getDim()]);
|
|
auto transposeOp =
|
|
rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
|
|
return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
|
|
}
|
|
|
|
/// Helper function to check if the tensor.extract can be vectorized by the
|
|
/// custom hook vectorizeTensorExtract.
|
|
static LogicalResult
|
|
tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
|
|
tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
|
|
if (!extractOp)
|
|
return failure();
|
|
|
|
if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
|
|
return failure();
|
|
|
|
// Check the index type, but only for non 0-d tensors (for which we do need
|
|
// access indices).
|
|
if (not extractOp.getIndices().empty()) {
|
|
if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
|
|
return failure();
|
|
}
|
|
|
|
if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
|
|
return !VectorType::isValidElementType(type);
|
|
})) {
|
|
return failure();
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
|
|
/// generated from `tensor.extract`. The offset is calculated as follows
|
|
/// (example using scalar values):
|
|
///
|
|
/// offset = extractOp.indices[0]
|
|
/// for (i = 1; i < numIndices; i++)
|
|
/// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
|
|
///
|
|
/// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
|
|
/// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
|
|
static Value calculateGatherOffset(RewriterBase &rewriter,
|
|
VectorizationState &state,
|
|
tensor::ExtractOp extractOp,
|
|
const IRMapping &bvm) {
|
|
// The vector of indices for GatherOp should be shaped as the output vector.
|
|
auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
|
|
auto loc = extractOp.getLoc();
|
|
|
|
Value offset = broadcastIfNeeded(
|
|
rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
|
|
|
|
const size_t numIndices = extractOp.getIndices().size();
|
|
for (size_t i = 1; i < numIndices; i++) {
|
|
Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
|
|
|
|
auto dimSize = broadcastIfNeeded(
|
|
rewriter,
|
|
rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
|
|
indexVecType);
|
|
|
|
offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
|
|
|
|
auto extractOpIndex = broadcastIfNeeded(
|
|
rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
|
|
|
|
offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
|
|
}
|
|
|
|
return offset;
|
|
}
|
|
|
|
enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
|
|
|
|
/// Checks whether /p val can be used for calculating a loop invariant index.
|
|
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
|
|
|
|
auto targetShape = linalgOp.getStaticLoopRanges();
|
|
assert(((llvm::count_if(targetShape,
|
|
[](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
|
|
"n-D vectors are not yet supported");
|
|
assert(targetShape.back() != 1 &&
|
|
"1-D vectors with the trailing dim eqaual 1 are not yet supported");
|
|
|
|
// Blocks outside _this_ linalg.generic are effectively loop invariant.
|
|
// However, analysing block arguments for _this_ linalg.generic Op is a bit
|
|
// tricky. Just bail out in the latter case.
|
|
// TODO: We could try analysing the corresponding affine map here.
|
|
auto *block = linalgOp.getBlock();
|
|
if (isa<BlockArgument>(val))
|
|
return llvm::all_of(block->getArguments(),
|
|
[&val](Value v) { return (v != val); });
|
|
|
|
Operation *defOp = val.getDefiningOp();
|
|
assert(defOp && "This is neither a block argument nor an operation result");
|
|
|
|
// IndexOp is loop invariant as long as its result remains constant across
|
|
// iterations. Given the assumptions on the loop ranges above, only the
|
|
// trailing loop dim ever changes.
|
|
auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
|
|
if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
|
|
return (indexOp.getDim() != trailingLoopDim);
|
|
|
|
auto *ancestor = block->findAncestorOpInBlock(*defOp);
|
|
|
|
// Values define outside `linalgOp` are loop invariant.
|
|
if (!ancestor)
|
|
return true;
|
|
|
|
// Values defined inside `linalgOp`, which are constant, are loop invariant.
|
|
if (isa<arith::ConstantOp>(ancestor))
|
|
return true;
|
|
|
|
bool result = true;
|
|
for (auto op : ancestor->getOperands())
|
|
result &= isLoopInvariantIdx(linalgOp, op);
|
|
|
|
return result;
|
|
}
|
|
|
|
/// Check whether \p val could be used for calculating the trailing index for a
|
|
/// contiguous load operation.
|
|
///
|
|
/// There are currently 3 types of values that are allowed here:
|
|
/// 1. loop-invariant values,
|
|
/// 2. values that increment by 1 with every loop iteration,
|
|
/// 3. results of basic arithmetic operations (linear and continuous)
|
|
/// involving 1., 2. and 3.
|
|
/// This method returns True if indeed only such values are used in calculating
|
|
/// \p val.
|
|
///
|
|
/// Additionally, the trailing index for a contiguous load operation should
|
|
/// increment by 1 with every loop iteration, i.e. be based on:
|
|
/// * `linalg.index <dim>` ,
|
|
/// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
|
|
/// updated to `true` when such an op is found.
|
|
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
|
|
bool &foundIndexOp) {
|
|
|
|
auto targetShape = linalgOp.getStaticLoopRanges();
|
|
assert(((llvm::count_if(targetShape,
|
|
[](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
|
|
"n-D vectors are not yet supported");
|
|
assert(targetShape.back() != 1 &&
|
|
"1-D vectors with the trailing dim 1 are not yet supported");
|
|
|
|
// Blocks outside _this_ linalg.generic are effectively loop invariant.
|
|
// However, analysing block arguments for _this_ linalg.generic Op is a bit
|
|
// tricky. Just bail out in the latter case.
|
|
// TODO: We could try analysing the corresponding affine map here.
|
|
auto *block = linalgOp.getBlock();
|
|
if (isa<BlockArgument>(val))
|
|
return llvm::all_of(block->getArguments(),
|
|
[&val](Value v) { return (v != val); });
|
|
|
|
Operation *defOp = val.getDefiningOp();
|
|
assert(defOp && "This is neither a block argument nor an operation result");
|
|
|
|
// Given the assumption on the loop ranges above, only the trailing loop
|
|
// index is not constant.
|
|
auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
|
|
if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
|
|
foundIndexOp = (indexOp.getDim() == trailingLoopDim);
|
|
return true;
|
|
}
|
|
|
|
auto *ancestor = block->findAncestorOpInBlock(*defOp);
|
|
|
|
if (!ancestor)
|
|
return false;
|
|
|
|
// Conservatively reject Ops that could lead to indices with stride other
|
|
// than 1.
|
|
if (!isa<arith::AddIOp, arith::SubIOp, arith::ConstantOp, linalg::IndexOp>(
|
|
ancestor))
|
|
return false;
|
|
|
|
bool result = false;
|
|
for (auto op : ancestor->getOperands())
|
|
result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);
|
|
|
|
return result;
|
|
}
|
|
|
|
/// Check whether \p extractOp would be a gather or a contiguous load Op after
|
|
/// vectorising \p linalgOp. Note that it is always safe to use gather load
|
|
/// operations for contiguous loads (albeit slow), but not vice-versa. When in
|
|
/// doubt, bail out and assume that \p extractOp is a gather load.
|
|
static VectorMemoryAccessKind
|
|
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
|
|
LinalgOp &linalgOp) {
|
|
|
|
auto targetShape = linalgOp.getStaticLoopRanges();
|
|
auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
|
|
|
|
// 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast.
|
|
if (inputShape.getShape().empty())
|
|
return VectorMemoryAccessKind::ScalarBroadcast;
|
|
|
|
// 0.2 In the case of dynamic shapes just bail-out and assume that it's a
|
|
// gather load.
|
|
// TODO: Relax this condition.
|
|
if (linalgOp.hasDynamicShape())
|
|
return VectorMemoryAccessKind::Gather;
|
|
|
|
// 1. Assume that it's a gather load when reading _into_:
|
|
// * an n-D vector, like`tensor<1x2x4xi32` or`tensor<2x1x4xi32>`, or
|
|
// * a 1-D vector with the trailing dim equal 1, e.g. `tensor<1x4x1xi32`.
|
|
// TODO: Relax these conditions.
|
|
// FIXME: This condition assumes non-dynamic sizes.
|
|
if ((llvm::count_if(targetShape,
|
|
[](int64_t dimSize) { return dimSize > 1; }) != 1) ||
|
|
targetShape.back() == 1)
|
|
return VectorMemoryAccessKind::Gather;
|
|
|
|
// 2. Assume that it's a gather load when reading _from_ a tensor for which
|
|
// the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
|
|
// TODO: Relax this condition.
|
|
if (inputShape.getShape().back() == 1)
|
|
return VectorMemoryAccessKind::Gather;
|
|
|
|
bool leadingIdxsLoopInvariant = true;
|
|
|
|
// 3. Analyze the leading indices of `extractOp`.
|
|
// Look at the way each index is calculated and decide whether it is suitable
|
|
// for a contiguous load, i.e. whether it's loop invariant.
|
|
auto indices = extractOp.getIndices();
|
|
auto leadIndices = indices.drop_back(1);
|
|
|
|
for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
|
|
if (inputShape.getShape()[i] == 1)
|
|
continue;
|
|
|
|
leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
|
|
}
|
|
|
|
if (!leadingIdxsLoopInvariant) {
|
|
LDBG("Found gather load: " << extractOp);
|
|
return VectorMemoryAccessKind::Gather;
|
|
}
|
|
|
|
// 4. Analyze the trailing index for `extractOp`.
|
|
// At this point we know that the leading indices are loop invariant. This
|
|
// means that is potentially a scalar or a contiguous load. We can decide
|
|
// based on the trailing idx.
|
|
auto extractOpTrailingIdx = indices.back();
|
|
|
|
// 4a. Scalar broadcast load
|
|
// If the trailing index is loop invariant then this is a scalar load.
|
|
if (leadingIdxsLoopInvariant &&
|
|
isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
|
|
LDBG("Found scalar broadcast load: " << extractOp);
|
|
|
|
return VectorMemoryAccessKind::ScalarBroadcast;
|
|
}
|
|
|
|
// 4b. Contiguous loads
|
|
// The trailing `extractOp` index should increment with every loop iteration.
|
|
// This effectively means that it must be based on the trailing loop index.
|
|
// This is what the following bool captures.
|
|
bool foundIndexOp = false;
|
|
bool isContiguousLoad =
|
|
isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
|
|
isContiguousLoad &= foundIndexOp;
|
|
|
|
if (isContiguousLoad) {
|
|
LDBG("Found contigous load: " << extractOp);
|
|
return VectorMemoryAccessKind::Contiguous;
|
|
}
|
|
|
|
// 5. Fallback case - gather load.
|
|
LDBG("Found gather load: " << extractOp);
|
|
return VectorMemoryAccessKind::Gather;
|
|
}
|
|
|
|
/// Helper function to vectorize the tensor.extract operations. Returns
|
|
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
|
|
/// should map the produced operations. This function is meant to be used as a
|
|
/// CustomVectorizationHook.
|
|
static VectorizationResult
|
|
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
|
|
Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
|
|
tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
|
|
if (!extractOp)
|
|
return VectorizationResult{VectorizationStatus::Failure, nullptr};
|
|
auto loc = extractOp.getLoc();
|
|
|
|
// Compute the static loop sizes of the extract op.
|
|
auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
|
|
auto maskConstantOp = rewriter.create<arith::ConstantOp>(
|
|
loc,
|
|
DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
|
|
/*value=*/true));
|
|
auto passThruConstantOp =
|
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
|
|
|
|
// Base indices are currently set to 0. We will need to re-visit if more
|
|
// generic scenarios are to be supported.
|
|
SmallVector<Value> baseIndices(
|
|
extractOp.getIndices().size(),
|
|
rewriter.create<arith::ConstantIndexOp>(loc, 0));
|
|
|
|
VectorMemoryAccessKind memAccessKind =
|
|
getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
|
|
|
|
// 1. Handle gather access
|
|
if (memAccessKind == VectorMemoryAccessKind::Gather) {
|
|
Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
|
|
|
|
// Generate the gather load
|
|
Operation *gatherOp = rewriter.create<vector::GatherOp>(
|
|
loc, resultType, extractOp.getTensor(), baseIndices, offset,
|
|
maskConstantOp, passThruConstantOp);
|
|
gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
|
|
|
|
LDBG("Vectorised as gather load: " << extractOp << "\n");
|
|
return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
|
|
}
|
|
|
|
// 2. Handle:
|
|
// a. scalar loads + broadcast,
|
|
// b. contiguous loads.
|
|
// Both cases use vector.transfer_read.
|
|
|
|
// Collect indices for `vector.transfer_read`. At this point, the indices will
|
|
// either be scalars or would have been broadcast to vectors matching the
|
|
// result type. For indices that are vectors, there are two options:
|
|
// * for non-trailing indices, all elements are identical (contiguous
|
|
// loads are identified by looking for non-trailing indices that are
|
|
// invariant with respect to the corresponding linalg.generic), or
|
|
// * for trailing indices, the index vector will contain values with stride
|
|
// one, but for `vector.transfer_read` only the first (i.e. 0th) index is
|
|
// needed.
|
|
// This means that
|
|
// * for scalar indices - just re-use it,
|
|
// * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
|
|
// (0th) element and use that.
|
|
SmallVector<Value> transferReadIdxs;
|
|
auto resTrailingDim = resultType.getShape().back();
|
|
auto zero = rewriter.create<arith::ConstantOp>(
|
|
loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
|
|
for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
|
|
auto idx = bvm.lookup(extractOp.getIndices()[i]);
|
|
if (idx.getType().isIndex()) {
|
|
transferReadIdxs.push_back(idx);
|
|
continue;
|
|
}
|
|
|
|
auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
|
|
loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
|
|
bvm.lookup(extractOp.getIndices()[i]));
|
|
transferReadIdxs.push_back(
|
|
rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
|
|
}
|
|
|
|
// `tensor.extract_element` is always in-bounds, hence the following holds.
|
|
auto dstRank = resultType.getRank();
|
|
auto srcRank = extractOp.getTensor().getType().getRank();
|
|
SmallVector<bool> inBounds(dstRank, true);
|
|
|
|
// 2a. Handle scalar broadcast access.
|
|
if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
|
|
MLIRContext *ctx = rewriter.getContext();
|
|
SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
|
|
auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
|
|
|
|
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
|
|
loc, resultType, extractOp.getTensor(), transferReadIdxs,
|
|
permutationMap, inBounds);
|
|
|
|
LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
|
|
return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
|
|
}
|
|
|
|
// 2b. Handle contiguous access.
|
|
auto permutationMap = AffineMap::getMinorIdentityMap(
|
|
srcRank, std::min(dstRank, srcRank), rewriter.getContext());
|
|
|
|
int32_t rankDiff = dstRank - srcRank;
|
|
// When dstRank > srcRank, broadcast the source tensor to the unitary leading
|
|
// dims so that the ranks match. This is done by extending the map with 0s.
|
|
// For example, for dstRank = 3, srcRank = 2, the following map created
|
|
// above:
|
|
// (d0, d1) --> (d0, d1)
|
|
// is extended as:
|
|
// (d0, d1) --> (0, d0, d1)
|
|
while (rankDiff > 0) {
|
|
permutationMap = permutationMap.insertResult(
|
|
mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
|
|
rankDiff--;
|
|
}
|
|
|
|
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
|
|
loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
|
|
inBounds);
|
|
|
|
LDBG("Vectorised as contiguous load: " << extractOp);
|
|
return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
|
|
}
|
|
|
|
/// Emit reduction operations if the shapes of the value to reduce is different
|
|
/// that the result shape.
|
|
// Note: this is a true builder that notifies the OpBuilder listener.
|
|
// TODO: Consider moving as a static helper on the ReduceOp.
|
|
static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
|
|
Value reduceValue, Value initialValue,
|
|
const IRMapping &bvm) {
|
|
Value reduceVec = bvm.lookup(reduceValue);
|
|
Value outputVec = bvm.lookup(initialValue);
|
|
auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
|
|
auto outputType = dyn_cast<VectorType>(outputVec.getType());
|
|
// Reduce only if needed as the value may already have been reduce for
|
|
// contraction vectorization.
|
|
if (!reduceType ||
|
|
(outputType && reduceType.getShape() == outputType.getShape()))
|
|
return nullptr;
|
|
SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
|
|
return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
|
|
}
|
|
|
|
/// Generic vectorization for a single operation `op`, given already vectorized
|
|
/// operands carried by `bvm`. Vectorization occurs as follows:
|
|
/// 1. Try to apply any of the `customVectorizationHooks` and return its
|
|
/// result on success.
|
|
/// 2. Clone any constant in the current scope without vectorization: each
|
|
/// consumer of the constant will later determine the shape to which the
|
|
/// constant needs to be broadcast to.
|
|
/// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
|
|
/// of the `customVectorizationHooks` to cover such cases.
|
|
/// 4. Clone `op` in vector form to a vector of shape prescribed by the first
|
|
/// operand of maximal rank. Other operands have smaller rank and are
|
|
/// broadcast accordingly. It is assumed this broadcast is always legal,
|
|
/// otherwise, it means one of the `customVectorizationHooks` is incorrect.
|
|
///
|
|
/// This function assumes all operands of `op` have been vectorized and are in
|
|
/// the `bvm` mapping. As a consequence, this function is meant to be called on
|
|
/// a topologically-sorted list of ops.
|
|
/// This function does not update `bvm` but returns a VectorizationStatus that
|
|
/// instructs the caller what `bvm` update needs to occur.
|
|
static VectorizationResult
|
|
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
|
|
LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
|
|
ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
|
|
LDBG("vectorize op " << *op << "\n");
|
|
|
|
// 1. Try to apply any CustomVectorizationHook.
|
|
if (!customVectorizationHooks.empty()) {
|
|
for (auto &customFunc : customVectorizationHooks) {
|
|
VectorizationResult result = customFunc(op, bvm);
|
|
if (result.status == VectorizationStatus::Failure)
|
|
continue;
|
|
return result;
|
|
}
|
|
}
|
|
|
|
// 2. Constant ops don't get vectorized but rather broadcasted at their users.
|
|
// Clone so that the constant is not confined to the linalgOp block .
|
|
if (isa<arith::ConstantOp, func::ConstantOp>(op))
|
|
return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
|
|
|
|
// 3. Only ElementwiseMappable are allowed in the generic vectorization.
|
|
if (!OpTrait::hasElementwiseMappableTraits(op))
|
|
return VectorizationResult{VectorizationStatus::Failure, nullptr};
|
|
|
|
// 4 . Check if the operation is a reduction.
|
|
SmallVector<std::pair<Value, Value>> reductionOperands;
|
|
for (Value operand : op->getOperands()) {
|
|
auto blockArg = dyn_cast<BlockArgument>(operand);
|
|
if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
|
|
blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
|
|
continue;
|
|
SmallVector<Operation *> reductionOps;
|
|
Value reduceValue = matchReduction(
|
|
linalgOp.getRegionOutputArgs(),
|
|
blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
|
|
if (!reduceValue)
|
|
continue;
|
|
reductionOperands.push_back(std::make_pair(reduceValue, operand));
|
|
}
|
|
if (!reductionOperands.empty()) {
|
|
assert(reductionOperands.size() == 1);
|
|
Operation *reduceOp =
|
|
reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
|
|
reductionOperands[0].second, bvm);
|
|
if (reduceOp)
|
|
return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
|
|
}
|
|
|
|
// 5. Generic vectorization path for ElementwiseMappable ops.
|
|
// a. Get the first max ranked shape.
|
|
VectorType firstMaxRankedType;
|
|
for (Value operand : op->getOperands()) {
|
|
auto vecOperand = bvm.lookup(operand);
|
|
assert(vecOperand && "Vector operand couldn't be found");
|
|
|
|
auto vecType = dyn_cast<VectorType>(vecOperand.getType());
|
|
if (vecType && (!firstMaxRankedType ||
|
|
firstMaxRankedType.getRank() < vecType.getRank()))
|
|
firstMaxRankedType = vecType;
|
|
}
|
|
// b. Broadcast each op if needed.
|
|
SmallVector<Value> vecOperands;
|
|
for (Value scalarOperand : op->getOperands()) {
|
|
Value vecOperand = bvm.lookup(scalarOperand);
|
|
assert(vecOperand && "Vector operand couldn't be found");
|
|
|
|
if (firstMaxRankedType) {
|
|
auto vecType = VectorType::get(firstMaxRankedType.getShape(),
|
|
getElementTypeOrSelf(vecOperand.getType()),
|
|
firstMaxRankedType.getScalableDims());
|
|
vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
|
|
} else {
|
|
vecOperands.push_back(vecOperand);
|
|
}
|
|
}
|
|
// c. for elementwise, the result is the vector with the firstMaxRankedShape
|
|
SmallVector<Type> resultTypes;
|
|
for (Type resultType : op->getResultTypes()) {
|
|
resultTypes.push_back(
|
|
firstMaxRankedType
|
|
? VectorType::get(firstMaxRankedType.getShape(), resultType,
|
|
firstMaxRankedType.getScalableDims())
|
|
: resultType);
|
|
}
|
|
// d. Build and return the new op.
|
|
return VectorizationResult{
|
|
VectorizationStatus::NewOp,
|
|
rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
|
|
resultTypes, op->getAttrs())};
|
|
}
|
|
|
|
/// Generic vectorization function that rewrites the body of a `linalgOp` into
|
|
/// vector form. Generic vectorization proceeds as follows:
|
|
/// 1. Verify the `linalgOp` has one non-empty region.
|
|
/// 2. Values defined above the region are mapped to themselves and will be
|
|
/// broadcasted on a per-need basis by their consumers.
|
|
/// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
|
|
/// load).
|
|
/// TODO: Reuse opportunities for RAR dependencies.
|
|
/// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
|
|
/// 4rewriter. Register CustomVectorizationHook for IndexOp to access the
|
|
/// iteration indices.
|
|
/// 5. Iteratively call vectorizeOneOp on the region operations.
|
|
///
|
|
/// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
|
|
/// performed to the maximal common vector size implied by the `linalgOp`
|
|
/// iteration space. This eager broadcasting is introduced in the
|
|
/// permutation_map of the vector.transfer_read operations. The eager
|
|
/// broadcasting makes it trivial to detrmine where broadcast, transposes and
|
|
/// reductions should occur, without any bookkeeping. The tradeoff is that, in
|
|
/// the absence of good canonicalizations, the amount of work increases.
|
|
/// This is not deemed a problem as we expect canonicalizations and foldings to
|
|
/// aggressively clean up the useless work.
|
|
static LogicalResult
|
|
vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
|
|
LinalgOp linalgOp,
|
|
SmallVectorImpl<Value> &newResults) {
|
|
LDBG("Vectorizing operation as linalg generic\n");
|
|
Block *block = linalgOp.getBlock();
|
|
|
|
// 2. Values defined above the region can only be broadcast for now. Make them
|
|
// map to themselves.
|
|
IRMapping bvm;
|
|
SetVector<Value> valuesSet;
|
|
mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
|
|
bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
|
|
|
|
if (linalgOp.getNumDpsInits() == 0)
|
|
return failure();
|
|
|
|
// 3. Turn all BBArgs into vector.transfer_read / load.
|
|
Location loc = linalgOp.getLoc();
|
|
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
|
for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
|
|
BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
|
|
if (linalgOp.isScalar(opOperand)) {
|
|
bvm.map(bbarg, opOperand->get());
|
|
continue;
|
|
}
|
|
|
|
// 3.a. Convert the indexing map for this input/output to a transfer read
|
|
// permutation map and masking map.
|
|
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
|
|
|
|
// Remove zeros from indexing map to use it as masking map.
|
|
SmallVector<int64_t> zeroPos;
|
|
auto results = indexingMap.getResults();
|
|
for (const auto &result : llvm::enumerate(results)) {
|
|
if (result.value().isa<AffineConstantExpr>()) {
|
|
zeroPos.push_back(result.index());
|
|
}
|
|
}
|
|
AffineMap maskingMap = indexingMap.dropResults(zeroPos);
|
|
|
|
AffineMap readMap;
|
|
VectorType readType;
|
|
Type elemType = getElementTypeOrSelf(opOperand->get());
|
|
if (linalgOp.isDpsInput(opOperand)) {
|
|
// 3.a.i. For input reads we use the canonical vector shape.
|
|
readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
|
|
readType = state.getCanonicalVecType(elemType);
|
|
} else {
|
|
// 3.a.ii. For output reads (iteration-carried dependence, e.g.,
|
|
// reductions), the vector shape is computed by mapping the canonical
|
|
// vector shape to the output domain and back to the canonical domain.
|
|
readMap = inversePermutation(reindexIndexingMap(indexingMap));
|
|
readType =
|
|
state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
|
|
}
|
|
|
|
SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
|
|
|
|
Operation *read = rewriter.create<vector::TransferReadOp>(
|
|
loc, readType, opOperand->get(), indices, readMap);
|
|
read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
|
|
Value readValue = read->getResult(0);
|
|
|
|
// 3.b. If masked, set in-bounds to true. Masking guarantees that the access
|
|
// will be in-bounds.
|
|
if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
|
|
SmallVector<bool> inBounds(readType.getRank(), true);
|
|
cast<vector::TransferReadOp>(maskOp.getMaskableOp())
|
|
.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
|
|
}
|
|
|
|
// 3.c. Not all ops support 0-d vectors, extract the scalar for now.
|
|
// TODO: remove this.
|
|
if (readType.getRank() == 0)
|
|
readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
|
|
|
|
LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
|
|
<< "\n");
|
|
bvm.map(bbarg, readValue);
|
|
bvm.map(opOperand->get(), readValue);
|
|
}
|
|
|
|
SmallVector<CustomVectorizationHook> hooks;
|
|
// 4a. Register CustomVectorizationHook for yieldOp.
|
|
CustomVectorizationHook vectorizeYield =
|
|
[&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
|
|
return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
|
|
};
|
|
hooks.push_back(vectorizeYield);
|
|
|
|
// 4b. Register CustomVectorizationHook for indexOp.
|
|
CustomVectorizationHook vectorizeIndex =
|
|
[&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
|
|
return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
|
|
};
|
|
hooks.push_back(vectorizeIndex);
|
|
|
|
// 4c. Register CustomVectorizationHook for extractOp.
|
|
CustomVectorizationHook vectorizeExtract =
|
|
[&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
|
|
return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
|
|
};
|
|
hooks.push_back(vectorizeExtract);
|
|
|
|
// 5. Iteratively call `vectorizeOneOp` to each op in the slice.
|
|
for (Operation &op : block->getOperations()) {
|
|
VectorizationResult result =
|
|
vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
|
|
if (result.status == VectorizationStatus::Failure) {
|
|
LDBG("failed to vectorize: " << op << "\n");
|
|
return failure();
|
|
}
|
|
if (result.status == VectorizationStatus::NewOp) {
|
|
Operation *maybeMaskedOp =
|
|
state.maskOperation(rewriter, result.newOp, linalgOp);
|
|
LDBG("New vector op: " << *maybeMaskedOp << "\n");
|
|
bvm.map(op.getResults(), maybeMaskedOp->getResults());
|
|
}
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
|
|
/// and (3) all-zero lowPad to
|
|
/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
|
|
static LogicalResult
|
|
vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
|
|
ArrayRef<int64_t> inputVectorSizes,
|
|
SmallVectorImpl<Value> &newResults) {
|
|
auto padValue = padOp.getConstantPaddingValue();
|
|
Location loc = padOp.getLoc();
|
|
int64_t rank = inputVectorSizes.size();
|
|
auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
|
|
auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
|
|
|
|
// transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
|
|
OpBuilder::InsertionGuard g(rewriter);
|
|
rewriter.setInsertionPoint(padOp);
|
|
|
|
ReifiedRankedShapedTypeDims reifiedReturnShapes;
|
|
LogicalResult status =
|
|
cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
|
|
.reifyResultShapes(rewriter, reifiedReturnShapes);
|
|
(void)status; // prevent unused variable warning on non-assert builds
|
|
assert(succeeded(status) && "failed to reify result shapes");
|
|
auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, reifiedReturnShapes[0],
|
|
padValue.getType());
|
|
SmallVector<OpFoldResult> mixedSourceDims =
|
|
tensor::getMixedSizes(rewriter, loc, padOp.getSource());
|
|
Value mask =
|
|
rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
|
|
auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
|
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
|
|
loc,
|
|
/*vectorType=*/vectorType,
|
|
/*source=*/padOp.getSource(),
|
|
/*indices=*/SmallVector<Value>(rank, zero),
|
|
/*padding=*/padValue,
|
|
/*inBounds=*/SmallVector<bool>(rank, true));
|
|
auto maskedOp = cast<vector::MaskOp>(
|
|
mlir::vector::maskOperation(rewriter, transferReadOp, mask));
|
|
Operation *write = rewriter.create<vector::TransferWriteOp>(
|
|
loc,
|
|
/*vector=*/maskedOp->getResult(0),
|
|
/*source=*/emptyOp,
|
|
/*indices=*/SmallVector<Value>(rank, zero),
|
|
/*inBounds=*/SmallVector<bool>(rank, true));
|
|
bool needMaskForWrite = llvm::any_of(
|
|
llvm::zip_equal(inputVectorSizes, padOp.getResultType().getShape()),
|
|
[](auto it) { return std::get<0>(it) != std::get<1>(it); });
|
|
if (needMaskForWrite) {
|
|
Value maskForWrite = rewriter.create<vector::CreateMaskOp>(
|
|
loc, maskType, reifiedReturnShapes[0]);
|
|
write = mlir::vector::maskOperation(rewriter, write, maskForWrite);
|
|
}
|
|
newResults.push_back(write->getResult(0));
|
|
return success();
|
|
}
|
|
|
|
// TODO: probably need some extra checks for reduction followed by consumer
|
|
// ops that may not commute (e.g. linear reduction + non-linear instructions).
|
|
static LogicalResult reductionPreconditions(LinalgOp op) {
|
|
if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
|
|
LDBG("reduction precondition failed: no reduction iterator\n");
|
|
return failure();
|
|
}
|
|
for (OpOperand *opOperand : op.getDpsInitOperands()) {
|
|
AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
|
|
if (indexingMap.isPermutation())
|
|
continue;
|
|
|
|
Operation *reduceOp = matchLinalgReduction(opOperand);
|
|
if (!reduceOp || !getCombinerOpKind(reduceOp)) {
|
|
LDBG("reduction precondition failed: reduction detection failed\n");
|
|
return failure();
|
|
}
|
|
}
|
|
return success();
|
|
}
|
|
|
|
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
|
|
// TODO: Masking only supports dynamic generic ops for now.
|
|
if (!isa<linalg::GenericOp, linalg::FillOp, linalg::CopyOp,
|
|
linalg::ContractionOpInterface>(op.getOperation()))
|
|
return failure();
|
|
|
|
LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
|
|
return success();
|
|
}
|
|
|
|
/// Returns success if `inputVectorSizes` is a valid masking configuraion for
|
|
/// given `shape`, i.e., it meets:
|
|
/// 1. The numbers of elements in both array are equal.
|
|
/// 2. `inputVectorSizes` does nos have dynamic dimensions.
|
|
/// 3. All the values in `inputVectorSizes` are greater than or equal to
|
|
/// static sizes in `shape`.
|
|
static LogicalResult
|
|
isValidMaskedInputVector(ArrayRef<int64_t> shape,
|
|
ArrayRef<int64_t> inputVectorSizes) {
|
|
LDBG("Iteration space static sizes:");
|
|
LLVM_DEBUG(llvm::interleaveComma(shape, llvm::dbgs()));
|
|
LLVM_DEBUG(llvm::dbgs() << "\n");
|
|
|
|
if (inputVectorSizes.size() != shape.size()) {
|
|
LDBG("Input vector sizes don't match the number of loops");
|
|
return failure();
|
|
}
|
|
if (ShapedType::isDynamicShape(inputVectorSizes)) {
|
|
LDBG("Input vector sizes can't have dynamic dimensions");
|
|
return failure();
|
|
}
|
|
if (!llvm::all_of(llvm::zip(shape, inputVectorSizes),
|
|
[](std::tuple<int64_t, int64_t> sizePair) {
|
|
int64_t staticSize = std::get<0>(sizePair);
|
|
int64_t inputSize = std::get<1>(sizePair);
|
|
return ShapedType::isDynamic(staticSize) ||
|
|
staticSize <= inputSize;
|
|
})) {
|
|
LDBG("Input vector sizes must be greater than or equal to iteration space "
|
|
"static sizes");
|
|
return failure();
|
|
}
|
|
return success();
|
|
}
|
|
|
|
static LogicalResult
|
|
vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
|
|
ArrayRef<int64_t> inputVectorSizes,
|
|
bool vectorizeNDExtract) {
|
|
// tensor with dimension of 0 cannot be vectorized.
|
|
if (llvm::any_of(linalgOp.getStaticShape(),
|
|
[](int64_t dim) { return dim == 0; }))
|
|
return failure();
|
|
// Check API contract for input vector sizes.
|
|
if (!inputVectorSizes.empty() &&
|
|
failed(isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
|
|
inputVectorSizes)))
|
|
return failure();
|
|
|
|
if (linalgOp.hasDynamicShape() &&
|
|
failed(vectorizeDynamicLinalgOpPrecondition(linalgOp))) {
|
|
LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
|
|
return failure();
|
|
}
|
|
|
|
SmallVector<CustomVectorizationPrecondition> customPreconditions;
|
|
|
|
// Register CustomVectorizationPrecondition for extractOp.
|
|
customPreconditions.push_back(tensorExtractVectorizationPrecondition);
|
|
|
|
// All types in the body should be a supported element type for VectorType.
|
|
for (Operation &innerOp : linalgOp->getRegion(0).front()) {
|
|
// Check if any custom hook can vectorize the inner op.
|
|
if (llvm::any_of(
|
|
customPreconditions,
|
|
[&](const CustomVectorizationPrecondition &customPrecondition) {
|
|
return succeeded(
|
|
customPrecondition(&innerOp, vectorizeNDExtract));
|
|
})) {
|
|
continue;
|
|
}
|
|
if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
|
|
return !VectorType::isValidElementType(type);
|
|
})) {
|
|
return failure();
|
|
}
|
|
if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
|
|
return !VectorType::isValidElementType(type);
|
|
})) {
|
|
return failure();
|
|
}
|
|
}
|
|
if (isElementwise(linalgOp))
|
|
return success();
|
|
// TODO: isaConvolutionOpInterface that can also infer from generic features.
|
|
// But we will still need stride/dilation attributes that will be annoying to
|
|
// reverse-engineer...
|
|
if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
|
|
return success();
|
|
// TODO: the common vector shape is equal to the static loop sizes only when
|
|
// all indexing maps are projected permutations. For convs and stencils the
|
|
// logic will need to evolve.
|
|
if (!allIndexingsAreProjectedPermutation(linalgOp)) {
|
|
LDBG("precondition failed: not projected permutations\n");
|
|
return failure();
|
|
}
|
|
if (failed(reductionPreconditions(linalgOp))) {
|
|
LDBG("precondition failed: reduction preconditions\n");
|
|
return failure();
|
|
}
|
|
return success();
|
|
}
|
|
|
|
static LogicalResult
|
|
vectorizePadOpPrecondition(tensor::PadOp padOp,
|
|
ArrayRef<int64_t> inputVectorSizes) {
|
|
auto padValue = padOp.getConstantPaddingValue();
|
|
if (!padValue) {
|
|
LDBG("pad value is not constant: " << padOp << "\n");
|
|
return failure();
|
|
}
|
|
|
|
ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
|
|
if (failed(isValidMaskedInputVector(resultTensorShape, inputVectorSizes)))
|
|
return failure();
|
|
|
|
if (llvm::any_of(padOp.getLow(), [](Value v) {
|
|
std::optional<int64_t> res = getConstantIntValue(v);
|
|
return !res.has_value() || res.value() != 0;
|
|
})) {
|
|
LDBG("low pad must all be zero: " << padOp << "\n");
|
|
return failure();
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Preconditions for scalable vectors.
|
|
static LogicalResult
|
|
vectorizeScalableVectorPrecondition(Operation *op,
|
|
ArrayRef<int64_t> inputVectorSizes,
|
|
ArrayRef<bool> inputScalableVecDims) {
|
|
assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
|
|
"Number of input vector sizes and scalable dims doesn't match");
|
|
|
|
if (inputVectorSizes.empty())
|
|
return success();
|
|
|
|
bool isScalable = inputScalableVecDims.back();
|
|
if (!isScalable)
|
|
return success();
|
|
|
|
// Only element-wise ops supported in the presence of scalable dims.
|
|
auto linalgOp = dyn_cast<LinalgOp>(op);
|
|
return success(linalgOp && isElementwise(linalgOp));
|
|
}
|
|
|
|
LogicalResult mlir::linalg::vectorizeOpPrecondition(
|
|
Operation *op, ArrayRef<int64_t> inputVectorSizes,
|
|
ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract) {
|
|
if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
|
|
inputScalableVecDims)))
|
|
return failure();
|
|
|
|
return TypeSwitch<Operation *, LogicalResult>(op)
|
|
.Case<linalg::LinalgOp>([&](auto linalgOp) {
|
|
return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
|
|
vectorizeNDExtract);
|
|
})
|
|
.Case<tensor::PadOp>([&](auto padOp) {
|
|
return vectorizePadOpPrecondition(padOp, inputVectorSizes);
|
|
})
|
|
.Default([](auto) { return failure(); });
|
|
}
|
|
|
|
/// Converts affine.apply Ops to arithmetic operations.
|
|
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
|
|
OpBuilder::InsertionGuard g(rewriter);
|
|
auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
|
|
|
|
for (auto op : make_early_inc_range(toReplace)) {
|
|
rewriter.setInsertionPoint(op);
|
|
auto expanded = affine::expandAffineExpr(
|
|
rewriter, op->getLoc(), op.getAffineMap().getResult(0),
|
|
op.getOperands().take_front(op.getAffineMap().getNumDims()),
|
|
op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
|
|
rewriter.replaceOp(op, expanded);
|
|
}
|
|
}
|
|
|
|
/// Emit a suitable vector form for an operation. If provided,
|
|
/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
|
|
/// must match the rank of the iteration space of the operation and the input
|
|
/// vector sizes must be greater than or equal to their counterpart iteration
|
|
/// space sizes, if static. `inputVectorShapes` also allows the vectorization of
|
|
/// operations with dynamic shapes.
|
|
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
|
|
ArrayRef<int64_t> inputVectorSizes,
|
|
ArrayRef<bool> inputScalableVecDims,
|
|
bool vectorizeNDExtract) {
|
|
LDBG("Attempting to vectorize:\n" << *op << "\n");
|
|
LDBG("Input vector sizes: ");
|
|
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
|
|
LLVM_DEBUG(llvm::dbgs() << "\n");
|
|
LDBG("Input scalable vector dims: ");
|
|
LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
|
|
LLVM_DEBUG(llvm::dbgs() << "\n");
|
|
|
|
if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
|
|
vectorizeNDExtract))) {
|
|
LDBG("Vectorization pre-conditions failed\n");
|
|
return failure();
|
|
}
|
|
|
|
// Initialize vectorization state.
|
|
VectorizationState state(rewriter);
|
|
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
|
|
if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
|
|
inputScalableVecDims))) {
|
|
LDBG("Vectorization state couldn't be initialized\n");
|
|
return failure();
|
|
}
|
|
}
|
|
|
|
SmallVector<Value> results;
|
|
auto vectorizeResult =
|
|
TypeSwitch<Operation *, LogicalResult>(op)
|
|
.Case<linalg::LinalgOp>([&](auto linalgOp) {
|
|
// TODO: isaConvolutionOpInterface that can also infer from generic
|
|
// features. Will require stride/dilation attributes inference.
|
|
FailureOr<Operation *> convOr =
|
|
vectorizeConvolution(rewriter, linalgOp);
|
|
if (succeeded(convOr)) {
|
|
llvm::append_range(results, (*convOr)->getResults());
|
|
return success();
|
|
}
|
|
|
|
LDBG("Vectorize generic by broadcasting to the canonical vector "
|
|
"shape\n");
|
|
|
|
// Pre-process before proceeding.
|
|
convertAffineApply(rewriter, linalgOp);
|
|
|
|
// TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
|
|
// to 'OpBuilder' when it is passed over to some methods like
|
|
// 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
|
|
// erase an op within these methods, the actual rewriter won't be
|
|
// notified and we will end up with read-after-free issues!
|
|
return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
|
|
})
|
|
.Case<tensor::PadOp>([&](auto padOp) {
|
|
return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
|
|
results);
|
|
})
|
|
.Default([](auto) { return failure(); });
|
|
|
|
if (failed(vectorizeResult)) {
|
|
LDBG("Vectorization failed\n");
|
|
return failure();
|
|
}
|
|
|
|
if (!results.empty())
|
|
rewriter.replaceOp(op, results);
|
|
else
|
|
rewriter.eraseOp(op);
|
|
|
|
return success();
|
|
}
|
|
|
|
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
|
|
memref::CopyOp copyOp) {
|
|
|
|
auto srcType = cast<MemRefType>(copyOp.getSource().getType());
|
|
auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
|
|
if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
|
|
return failure();
|
|
|
|
auto srcElementType = getElementTypeOrSelf(srcType);
|
|
auto dstElementType = getElementTypeOrSelf(dstType);
|
|
if (!VectorType::isValidElementType(srcElementType) ||
|
|
!VectorType::isValidElementType(dstElementType))
|
|
return failure();
|
|
|
|
auto readType = VectorType::get(srcType.getShape(), srcElementType);
|
|
auto writeType = VectorType::get(dstType.getShape(), dstElementType);
|
|
|
|
Location loc = copyOp->getLoc();
|
|
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
|
SmallVector<Value> indices(srcType.getRank(), zero);
|
|
|
|
Value readValue = rewriter.create<vector::TransferReadOp>(
|
|
loc, readType, copyOp.getSource(), indices,
|
|
rewriter.getMultiDimIdentityMap(srcType.getRank()));
|
|
if (cast<VectorType>(readValue.getType()).getRank() == 0) {
|
|
readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
|
|
readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
|
|
}
|
|
Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
|
|
loc, readValue, copyOp.getTarget(), indices,
|
|
rewriter.getMultiDimIdentityMap(srcType.getRank()));
|
|
rewriter.replaceOp(copyOp, writeValue->getResults());
|
|
return success();
|
|
}
|
|
|
|
//----------------------------------------------------------------------------//
|
|
// Misc. vectorization patterns.
|
|
//----------------------------------------------------------------------------//
|
|
|
|
/// Helper function that retrieves the value of an IntegerAttr.
|
|
static int64_t getIntFromAttr(Attribute attr) {
|
|
return cast<IntegerAttr>(attr).getInt();
|
|
}
|
|
|
|
/// Given an ArrayRef of OpFoldResults, return a vector of Values.
|
|
/// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
|
|
/// not supported.
|
|
static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
|
|
ArrayRef<OpFoldResult> ofrs) {
|
|
SmallVector<Value> result;
|
|
for (auto o : ofrs) {
|
|
if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
|
|
result.push_back(val);
|
|
} else {
|
|
result.push_back(rewriter.create<arith::ConstantIndexOp>(
|
|
loc, getIntFromAttr(o.template get<Attribute>())));
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
|
|
/// InsertSliceOp. For now, only constant padding values are supported.
|
|
/// If there is enough static type information, TransferReadOps and
|
|
/// TransferWriteOps may be generated instead of InsertSliceOps.
|
|
struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
|
|
GenericPadOpVectorizationPattern(MLIRContext *context,
|
|
PatternBenefit benefit = 1)
|
|
: GeneralizePadOpPattern(context, tryVectorizeCopy, benefit) {}
|
|
/// Vectorize the copying of a tensor::PadOp's source. This is possible if
|
|
/// each dimension size is statically know in the source type or the result
|
|
/// type (or both).
|
|
static LogicalResult tryVectorizeCopy(RewriterBase &rewriter,
|
|
tensor::PadOp padOp, Value dest) {
|
|
auto sourceType = padOp.getSourceType();
|
|
auto resultType = padOp.getResultType();
|
|
if (!VectorType::isValidElementType(sourceType.getElementType()))
|
|
return failure();
|
|
|
|
// Copy cannot be vectorized if pad value is non-constant and source shape
|
|
// is dynamic. In case of a dynamic source shape, padding must be appended
|
|
// by TransferReadOp, but TransferReadOp supports only constant padding.
|
|
auto padValue = padOp.getConstantPaddingValue();
|
|
if (!padValue) {
|
|
if (!sourceType.hasStaticShape())
|
|
return failure();
|
|
// Create dummy padding value.
|
|
auto elemType = sourceType.getElementType();
|
|
padValue = rewriter.create<arith::ConstantOp>(
|
|
padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
|
|
}
|
|
|
|
SmallVector<int64_t> vecShape;
|
|
SmallVector<bool> readInBounds;
|
|
SmallVector<bool> writeInBounds;
|
|
for (unsigned i = 0; i < sourceType.getRank(); ++i) {
|
|
if (!sourceType.isDynamicDim(i)) {
|
|
vecShape.push_back(sourceType.getDimSize(i));
|
|
// Source shape is statically known: Neither read nor write are
|
|
// out-of- bounds.
|
|
readInBounds.push_back(true);
|
|
writeInBounds.push_back(true);
|
|
} else if (!resultType.isDynamicDim(i)) {
|
|
// Source shape is not statically known, but result shape is.
|
|
// Vectorize with size of result shape. This may be larger than the
|
|
// source size.
|
|
vecShape.push_back(resultType.getDimSize(i));
|
|
// Read may be out-of-bounds because the result size could be larger
|
|
// than the source size.
|
|
readInBounds.push_back(false);
|
|
// Write is out-of-bounds if low padding > 0.
|
|
writeInBounds.push_back(
|
|
getConstantIntValue(padOp.getMixedLowPad()[i]) ==
|
|
static_cast<int64_t>(0));
|
|
} else {
|
|
// Neither source nor result dim of padOp is static. Cannot vectorize
|
|
// the copy.
|
|
return failure();
|
|
}
|
|
}
|
|
auto vecType = VectorType::get(vecShape, sourceType.getElementType());
|
|
|
|
// Generate TransferReadOp.
|
|
SmallVector<Value> readIndices(
|
|
vecType.getRank(),
|
|
rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
|
|
auto read = rewriter.create<vector::TransferReadOp>(
|
|
padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
|
|
ArrayRef<bool>{readInBounds});
|
|
|
|
// If `dest` is a FillOp and the TransferWriteOp would overwrite the
|
|
// entire tensor, write directly to the FillOp's operand.
|
|
if (llvm::equal(vecShape, resultType.getShape()) &&
|
|
llvm::all_of(writeInBounds, [](bool b) { return b; }))
|
|
if (auto fill = dest.getDefiningOp<FillOp>())
|
|
dest = fill.output();
|
|
|
|
// Generate TransferWriteOp.
|
|
auto writeIndices =
|
|
ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
|
|
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
|
|
padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
|
|
/// given operation type OpTy.
|
|
template <typename OpTy>
|
|
struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
|
|
using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::PadOp padOp,
|
|
PatternRewriter &rewriter) const final {
|
|
bool changed = false;
|
|
// Insert users in vector, because some users may be replaced/removed.
|
|
for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
|
|
if (auto op = dyn_cast<OpTy>(user))
|
|
changed |= rewriteUser(rewriter, padOp, op).succeeded();
|
|
return success(changed);
|
|
}
|
|
|
|
protected:
|
|
virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
|
|
tensor::PadOp padOp, OpTy op) const = 0;
|
|
};
|
|
|
|
/// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
|
|
/// ```
|
|
/// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
|
|
/// %r = vector.transfer_read %0[%c0, %c0], %cst
|
|
/// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
|
|
/// ```
|
|
/// is rewritten to:
|
|
/// ```
|
|
/// %r = vector.transfer_read %src[%c0, %c0], %padding
|
|
/// {in_bounds = [true, true]}
|
|
/// : tensor<?x?xf32>, vector<17x5xf32>
|
|
/// ```
|
|
/// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
|
|
/// sure that the original padding value %cst was never used.
|
|
///
|
|
/// This rewrite is possible if:
|
|
/// - `xferOp` has no out-of-bounds dims or mask.
|
|
/// - Low padding is static 0.
|
|
/// - Single, scalar padding value.
|
|
struct PadOpVectorizationWithTransferReadPattern
|
|
: public VectorizePadOpUserPattern<vector::TransferReadOp> {
|
|
using VectorizePadOpUserPattern<
|
|
vector::TransferReadOp>::VectorizePadOpUserPattern;
|
|
|
|
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
|
|
vector::TransferReadOp xferOp) const override {
|
|
// Low padding must be static 0.
|
|
if (!padOp.hasZeroLowPad())
|
|
return failure();
|
|
// Pad value must be a constant.
|
|
auto padValue = padOp.getConstantPaddingValue();
|
|
if (!padValue)
|
|
return failure();
|
|
// Padding value of existing `xferOp` is unused.
|
|
if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
|
|
return failure();
|
|
|
|
rewriter.updateRootInPlace(xferOp, [&]() {
|
|
SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
|
|
xferOp->setAttr(xferOp.getInBoundsAttrName(),
|
|
rewriter.getBoolArrayAttr(inBounds));
|
|
xferOp.getSourceMutable().assign(padOp.getSource());
|
|
xferOp.getPaddingMutable().assign(padValue);
|
|
});
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Rewrite use of tensor::PadOp result in TransferWriteOp.
|
|
/// This pattern rewrites TransferWriteOps that write to a padded tensor
|
|
/// value, where the same amount of padding is immediately removed again after
|
|
/// the write. In such cases, the TransferWriteOp can write to the non-padded
|
|
/// tensor value and apply out-of-bounds masking. E.g.:
|
|
/// ```
|
|
/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
|
|
/// : tensor<...> to tensor<?x?xf32>
|
|
/// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
|
|
/// %2 = vector.transfer_write %vec, %1[...]
|
|
/// : vector<17x5xf32>, tensor<17x5xf32>
|
|
/// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
|
|
/// : tensor<17x5xf32> to tensor<?x?xf32>
|
|
/// ```
|
|
/// is rewritten to:
|
|
/// ```
|
|
/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
|
|
/// : tensor<...> to tensor<?x?xf32>
|
|
/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
|
|
/// tensor<?x?xf32>
|
|
/// ```
|
|
/// Note: It is important that the ExtractSliceOp %r resizes the result of the
|
|
/// TransferWriteOp to the same size as the input of the TensorPadOp (or an
|
|
/// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
|
|
/// from %r's old dimensions.
|
|
///
|
|
/// This rewrite is possible if:
|
|
/// - Low padding is static 0.
|
|
/// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
|
|
/// ExtractSliceOp trims the same amount of padding that was added
|
|
/// beforehand.
|
|
/// - Single, scalar padding value.
|
|
struct PadOpVectorizationWithTransferWritePattern
|
|
: public VectorizePadOpUserPattern<vector::TransferWriteOp> {
|
|
using VectorizePadOpUserPattern<
|
|
vector::TransferWriteOp>::VectorizePadOpUserPattern;
|
|
|
|
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
|
|
vector::TransferWriteOp xferOp) const override {
|
|
// TODO: support 0-d corner case.
|
|
if (xferOp.getTransferRank() == 0)
|
|
return failure();
|
|
|
|
// Low padding must be static 0.
|
|
if (!padOp.hasZeroLowPad())
|
|
return failure();
|
|
// Pad value must be a constant.
|
|
auto padValue = padOp.getConstantPaddingValue();
|
|
if (!padValue)
|
|
return failure();
|
|
// TransferWriteOp result must be directly consumed by an ExtractSliceOp.
|
|
if (!xferOp->hasOneUse())
|
|
return failure();
|
|
auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
|
|
if (!trimPadding)
|
|
return failure();
|
|
// Only static zero offsets supported when trimming padding.
|
|
if (!trimPadding.hasZeroOffset())
|
|
return failure();
|
|
// trimPadding must remove the amount of padding that was added earlier.
|
|
if (!hasSameTensorSize(padOp.getSource(), trimPadding))
|
|
return failure();
|
|
|
|
// Insert the new TransferWriteOp at position of the old TransferWriteOp.
|
|
rewriter.setInsertionPoint(xferOp);
|
|
|
|
SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
|
|
auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
|
|
xferOp, padOp.getSource().getType(), xferOp.getVector(),
|
|
padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
|
|
xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
|
|
rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Check if `beforePadding` and `afterTrimming` have the same tensor size,
|
|
/// i.e., same dimensions.
|
|
///
|
|
/// Dimensions may be static, dynamic or mix of both. In case of dynamic
|
|
/// dimensions, this function tries to infer the (static) tensor size by
|
|
/// looking at the defining op and utilizing op-specific knowledge.
|
|
///
|
|
/// This is a conservative analysis. In case equal tensor sizes cannot be
|
|
/// proven statically, this analysis returns `false` even though the tensor
|
|
/// sizes may turn out to be equal at runtime.
|
|
bool hasSameTensorSize(Value beforePadding,
|
|
tensor::ExtractSliceOp afterTrimming) const {
|
|
// If the input to tensor::PadOp is a CastOp, try with both CastOp
|
|
// result and CastOp operand.
|
|
if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
|
|
if (hasSameTensorSize(castOp.getSource(), afterTrimming))
|
|
return true;
|
|
|
|
auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
|
|
auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
|
|
// Only RankedTensorType supported.
|
|
if (!t1 || !t2)
|
|
return false;
|
|
// Rank of both values must be the same.
|
|
if (t1.getRank() != t2.getRank())
|
|
return false;
|
|
|
|
// All static dimensions must be the same. Mixed cases (e.g., dimension
|
|
// static in `t1` but dynamic in `t2`) are not supported.
|
|
for (unsigned i = 0; i < t1.getRank(); ++i) {
|
|
if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
|
|
return false;
|
|
if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
|
|
return false;
|
|
}
|
|
|
|
// Nothing more to check if all dimensions are static.
|
|
if (t1.getNumDynamicDims() == 0)
|
|
return true;
|
|
|
|
// All dynamic sizes must be the same. The only supported case at the
|
|
// moment is when `beforePadding` is an ExtractSliceOp (or a cast
|
|
// thereof).
|
|
|
|
// Apart from CastOp, only ExtractSliceOp is supported.
|
|
auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
|
|
if (!beforeSlice)
|
|
return false;
|
|
|
|
assert(static_cast<size_t>(t1.getRank()) ==
|
|
beforeSlice.getMixedSizes().size());
|
|
assert(static_cast<size_t>(t2.getRank()) ==
|
|
afterTrimming.getMixedSizes().size());
|
|
|
|
for (unsigned i = 0; i < t1.getRank(); ++i) {
|
|
// Skip static dimensions.
|
|
if (!t1.isDynamicDim(i))
|
|
continue;
|
|
auto size1 = beforeSlice.getMixedSizes()[i];
|
|
auto size2 = afterTrimming.getMixedSizes()[i];
|
|
|
|
// Case 1: Same value or same constant int.
|
|
if (isEqualConstantIntOrValue(size1, size2))
|
|
continue;
|
|
|
|
// Other cases: Take a deeper look at defining ops of values.
|
|
auto v1 = llvm::dyn_cast_if_present<Value>(size1);
|
|
auto v2 = llvm::dyn_cast_if_present<Value>(size2);
|
|
if (!v1 || !v2)
|
|
return false;
|
|
|
|
// Case 2: Both values are identical AffineMinOps. (Should not happen if
|
|
// CSE is run.)
|
|
auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
|
|
auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
|
|
if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
|
|
minOp1.getOperands() == minOp2.getOperands())
|
|
continue;
|
|
|
|
// Add additional cases as needed.
|
|
}
|
|
|
|
// All tests passed.
|
|
return true;
|
|
}
|
|
};
|
|
|
|
/// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
|
|
/// ```
|
|
/// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
|
|
/// %r = tensor.insert_slice %0
|
|
/// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
|
|
/// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
|
|
/// ```
|
|
/// is rewritten to:
|
|
/// ```
|
|
/// %0 = vector.transfer_read %src[%c0, %c0], %padding
|
|
/// : tensor<?x?xf32>, vector<17x5xf32>
|
|
/// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
|
|
/// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
|
|
/// ```
|
|
///
|
|
/// This rewrite is possible if:
|
|
/// - Low padding is static 0.
|
|
/// - `padOp` result shape is static.
|
|
/// - The entire padded tensor is inserted.
|
|
/// (Implies that sizes of `insertOp` are all static.)
|
|
/// - Only unit strides in `insertOp`.
|
|
/// - Single, scalar padding value.
|
|
/// - `padOp` result not used as destination.
|
|
struct PadOpVectorizationWithInsertSlicePattern
|
|
: public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
|
|
using VectorizePadOpUserPattern<
|
|
tensor::InsertSliceOp>::VectorizePadOpUserPattern;
|
|
|
|
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
|
|
tensor::InsertSliceOp insertOp) const override {
|
|
// Low padding must be static 0.
|
|
if (!padOp.hasZeroLowPad())
|
|
return failure();
|
|
// Only unit stride supported.
|
|
if (!insertOp.hasUnitStride())
|
|
return failure();
|
|
// Pad value must be a constant.
|
|
auto padValue = padOp.getConstantPaddingValue();
|
|
if (!padValue)
|
|
return failure();
|
|
// Dynamic shapes not supported.
|
|
if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
|
|
return failure();
|
|
// Pad result not used as destination.
|
|
if (insertOp.getDest() == padOp.getResult())
|
|
return failure();
|
|
|
|
auto vecType = VectorType::get(padOp.getType().getShape(),
|
|
padOp.getType().getElementType());
|
|
unsigned vecRank = vecType.getRank();
|
|
unsigned tensorRank = insertOp.getType().getRank();
|
|
|
|
// Check if sizes match: Insert the entire tensor into most minor dims.
|
|
// (No permutations allowed.)
|
|
SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
|
|
expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
|
|
if (!llvm::all_of(
|
|
llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
|
|
return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
|
|
}))
|
|
return failure();
|
|
|
|
// Insert the TransferReadOp and TransferWriteOp at the position of the
|
|
// InsertSliceOp.
|
|
rewriter.setInsertionPoint(insertOp);
|
|
|
|
// Generate TransferReadOp: Read entire source tensor and add high
|
|
// padding.
|
|
SmallVector<Value> readIndices(
|
|
vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
|
|
auto read = rewriter.create<vector::TransferReadOp>(
|
|
padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
|
|
|
|
// Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
|
|
// specified offsets. Write is fully in-bounds because a InsertSliceOp's
|
|
// source must fit into the destination at the specified offsets.
|
|
auto writeIndices =
|
|
ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
|
|
SmallVector<bool> inBounds(vecRank, true);
|
|
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
|
|
insertOp, read, insertOp.getDest(), writeIndices,
|
|
ArrayRef<bool>{inBounds});
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
void mlir::linalg::populatePadOpVectorizationPatterns(
|
|
RewritePatternSet &patterns, PatternBenefit baseBenefit) {
|
|
patterns.add<GenericPadOpVectorizationPattern>(patterns.getContext(),
|
|
baseBenefit);
|
|
// Try these specialized patterns first before resorting to the generic one.
|
|
patterns.add<PadOpVectorizationWithTransferReadPattern,
|
|
PadOpVectorizationWithTransferWritePattern,
|
|
PadOpVectorizationWithInsertSlicePattern>(
|
|
patterns.getContext(), baseBenefit.getBenefit() + 1);
|
|
}
|
|
|
|
//----------------------------------------------------------------------------//
|
|
// Forwarding patterns
|
|
//----------------------------------------------------------------------------//
|
|
|
|
/// Check whether there is any interleaved use of any `values` between
|
|
/// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
|
|
/// is in a different block.
|
|
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
|
|
ValueRange values) {
|
|
if (firstOp->getBlock() != secondOp->getBlock() ||
|
|
!firstOp->isBeforeInBlock(secondOp)) {
|
|
LDBG("interleavedUses precondition failed, firstOp: "
|
|
<< *firstOp << ", second op: " << *secondOp << "\n");
|
|
return true;
|
|
}
|
|
for (auto v : values) {
|
|
for (auto &u : v.getUses()) {
|
|
Operation *owner = u.getOwner();
|
|
if (owner == firstOp || owner == secondOp)
|
|
continue;
|
|
// TODO: this is too conservative, use dominance info in the future.
|
|
if (owner->getBlock() == firstOp->getBlock() &&
|
|
(owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
|
|
continue;
|
|
LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
|
|
<< ", second op: " << *secondOp << "\n");
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/// Return the unique subview use of `v` if it is indeed unique, null
|
|
/// otherwise.
|
|
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
|
|
memref::SubViewOp subViewOp;
|
|
for (auto &u : v.getUses()) {
|
|
if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
|
|
if (subViewOp)
|
|
return memref::SubViewOp();
|
|
subViewOp = newSubViewOp;
|
|
}
|
|
}
|
|
return subViewOp;
|
|
}
|
|
|
|
/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
|
|
/// when available.
|
|
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
|
|
vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
|
|
|
|
// TODO: support mask.
|
|
if (xferOp.getMask())
|
|
return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
|
|
|
|
// Transfer into `view`.
|
|
Value viewOrAlloc = xferOp.getSource();
|
|
if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
|
|
!viewOrAlloc.getDefiningOp<memref::AllocOp>())
|
|
return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
|
|
|
|
// Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
|
|
memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
|
|
if (!subViewOp)
|
|
return rewriter.notifyMatchFailure(xferOp, "no subview found");
|
|
Value subView = subViewOp.getResult();
|
|
|
|
// Find the copy into `subView` without interleaved uses.
|
|
memref::CopyOp copyOp;
|
|
for (auto &u : subView.getUses()) {
|
|
if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
|
|
assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
|
|
if (newCopyOp.getTarget() != subView)
|
|
continue;
|
|
if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
|
|
continue;
|
|
copyOp = newCopyOp;
|
|
break;
|
|
}
|
|
}
|
|
if (!copyOp)
|
|
return rewriter.notifyMatchFailure(xferOp, "no copy found");
|
|
|
|
// Find the fill into `viewOrAlloc` without interleaved uses before the
|
|
// copy.
|
|
FillOp maybeFillOp;
|
|
for (auto &u : viewOrAlloc.getUses()) {
|
|
if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
|
|
assert(isa<MemRefType>(newFillOp.output().getType()));
|
|
if (newFillOp.output() != viewOrAlloc)
|
|
continue;
|
|
if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
|
|
continue;
|
|
maybeFillOp = newFillOp;
|
|
break;
|
|
}
|
|
}
|
|
// Ensure padding matches.
|
|
if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
|
|
return rewriter.notifyMatchFailure(xferOp,
|
|
"padding value does not match fill");
|
|
|
|
// `in` is the subview that memref.copy reads. Replace it.
|
|
Value in = copyOp.getSource();
|
|
|
|
// memref.copy + linalg.fill can be used to create a padded local buffer.
|
|
// The `masked` attribute is only valid on this padded buffer.
|
|
// When forwarding to vector.transfer_read, the attribute must be reset
|
|
// conservatively.
|
|
Value res = rewriter.create<vector::TransferReadOp>(
|
|
xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.getIndices(),
|
|
xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
|
|
// in_bounds is explicitly reset
|
|
/*inBoundsAttr=*/ArrayAttr());
|
|
|
|
if (maybeFillOp)
|
|
rewriter.eraseOp(maybeFillOp);
|
|
rewriter.eraseOp(copyOp);
|
|
rewriter.replaceOp(xferOp, res);
|
|
|
|
return success();
|
|
}
|
|
|
|
/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
|
|
/// when available.
|
|
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
|
|
vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
|
|
// TODO: support mask.
|
|
if (xferOp.getMask())
|
|
return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
|
|
|
|
// Transfer into `viewOrAlloc`.
|
|
Value viewOrAlloc = xferOp.getSource();
|
|
if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
|
|
!viewOrAlloc.getDefiningOp<memref::AllocOp>())
|
|
return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
|
|
|
|
// Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
|
|
memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
|
|
if (!subViewOp)
|
|
return rewriter.notifyMatchFailure(xferOp, "no subview found");
|
|
Value subView = subViewOp.getResult();
|
|
|
|
// Find the copy from `subView` without interleaved uses.
|
|
memref::CopyOp copyOp;
|
|
for (auto &u : subViewOp.getResult().getUses()) {
|
|
if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
|
|
if (newCopyOp.getSource() != subView)
|
|
continue;
|
|
if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
|
|
continue;
|
|
copyOp = newCopyOp;
|
|
break;
|
|
}
|
|
}
|
|
if (!copyOp)
|
|
return rewriter.notifyMatchFailure(xferOp, "no copy found");
|
|
|
|
// `out` is the subview copied into that we replace.
|
|
assert(isa<MemRefType>(copyOp.getTarget().getType()));
|
|
Value out = copyOp.getTarget();
|
|
|
|
// Forward vector.transfer into copy.
|
|
// memref.copy + linalg.fill can be used to create a padded local buffer.
|
|
// The `masked` attribute is only valid on this padded buffer.
|
|
// When forwarding to vector.transfer_write, the attribute must be reset
|
|
// conservatively.
|
|
rewriter.create<vector::TransferWriteOp>(
|
|
xferOp.getLoc(), xferOp.getVector(), out, xferOp.getIndices(),
|
|
xferOp.getPermutationMapAttr(), xferOp.getMask(),
|
|
// in_bounds is explicitly reset
|
|
/*inBoundsAttr=*/ArrayAttr());
|
|
|
|
rewriter.eraseOp(copyOp);
|
|
rewriter.eraseOp(xferOp);
|
|
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Convolution vectorization patterns
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
template <int N>
|
|
static void bindShapeDims(ShapedType shapedType) {}
|
|
|
|
template <int N, typename IntTy, typename... IntTy2>
|
|
static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
|
|
val = shapedType.getShape()[N];
|
|
bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
|
|
}
|
|
|
|
/// Bind a pack of int& to the leading dimensions of shapedType.getShape().
|
|
template <typename... IntTy>
|
|
static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
|
|
bindShapeDims<0>(shapedType, vals...);
|
|
}
|
|
|
|
namespace {
|
|
bool isCastOfBlockArgument(Operation *op) {
|
|
return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
|
|
isa<BlockArgument>(op->getOperand(0));
|
|
}
|
|
|
|
bool isSupportedPoolKind(vector::CombiningKind kind) {
|
|
switch (kind) {
|
|
case vector::CombiningKind::ADD:
|
|
case vector::CombiningKind::MAXF:
|
|
case vector::CombiningKind::MAXIMUMF:
|
|
case vector::CombiningKind::MAXSI:
|
|
case vector::CombiningKind::MAXUI:
|
|
case vector::CombiningKind::MINF:
|
|
case vector::CombiningKind::MINIMUMF:
|
|
case vector::CombiningKind::MINSI:
|
|
case vector::CombiningKind::MINUI:
|
|
return true;
|
|
default:
|
|
return false;
|
|
}
|
|
}
|
|
|
|
/// Generate a vector implementation for either:
|
|
/// ```
|
|
/// Op def: ( w, kw )
|
|
/// Iters: ({Par(), Red()})
|
|
/// Layout: {{w + kw}, {kw}, {w}}
|
|
/// ```
|
|
/// kw is unrolled.
|
|
///
|
|
/// or
|
|
///
|
|
/// ```
|
|
/// Op def: ( n, w, c, kw, f )
|
|
/// Iters: ({Par(), Par(), Par(), Red(), Red()})
|
|
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
|
|
/// ```
|
|
/// kw is unrolled, w is unrolled iff dilationW > 1.
|
|
///
|
|
/// or
|
|
///
|
|
/// ```
|
|
/// Op def: ( n, c, w, f, kw )
|
|
/// Iters: ({Par(), Par(), Par(), Red(), Red()})
|
|
/// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
|
|
/// ```
|
|
/// kw is unrolled, w is unrolled iff dilationW > 1.
|
|
///
|
|
/// or
|
|
///
|
|
/// ```
|
|
/// Op def: ( n, w, c, kw )
|
|
/// Iters: ({Par(), Par(), Par(), Red()})
|
|
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
|
|
/// ```
|
|
/// kw is unrolled, w is unrolled iff dilationW > 1.
|
|
struct Conv1DGenerator
|
|
: public StructuredGenerator<LinalgOp, utils::IteratorType> {
|
|
Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
|
|
int dilationW)
|
|
: StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
|
|
strideW(strideW), dilationW(dilationW) {
|
|
// Determine whether `linalgOp` can be generated with this generator
|
|
if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
|
|
return;
|
|
lhsShaped = linalgOp.getDpsInputOperand(0)->get();
|
|
rhsShaped = linalgOp.getDpsInputOperand(1)->get();
|
|
resShaped = linalgOp.getDpsInitOperand(0)->get();
|
|
lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
|
|
rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
|
|
resShapedType = dyn_cast<ShapedType>(resShaped.getType());
|
|
if (!lhsShapedType || !rhsShapedType || !resShapedType)
|
|
return;
|
|
// (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
|
|
// (non-channeled convolution -> LHS and RHS both have single dimensions).
|
|
if (!((lhsShapedType.getRank() == 3 && resShapedType.getRank() == 3) ||
|
|
(lhsShapedType.getRank() == 1 && resShapedType.getRank() == 1)))
|
|
return;
|
|
|
|
Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
|
|
if (!reduceOp)
|
|
return;
|
|
redOp = reduceOp->getName().getIdentifier();
|
|
|
|
if (!setOperKind(reduceOp))
|
|
return;
|
|
auto maybeKind = getCombinerOpKind(reduceOp);
|
|
if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
|
|
(oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
|
|
return;
|
|
}
|
|
|
|
auto rhsRank = rhsShapedType.getRank();
|
|
switch (oper) {
|
|
case Conv:
|
|
if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
|
|
return;
|
|
break;
|
|
case Pool:
|
|
if (rhsRank != 1)
|
|
return;
|
|
break;
|
|
}
|
|
// The op is now known to be valid.
|
|
valid = true;
|
|
}
|
|
|
|
/// Generate a vector implementation for:
|
|
/// ```
|
|
/// Op def: ( w, kw )
|
|
/// Iters: ({Par(), Red()})
|
|
/// Layout: {{w + kw}, {kw}, {w}}
|
|
/// ```
|
|
/// kw is always unrolled.
|
|
///
|
|
/// or
|
|
///
|
|
/// ```
|
|
/// Op def: ( n, w, c, kw, f )
|
|
/// Iters: ({Par(), Par(), Par(), Red(), Red()})
|
|
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
|
|
/// ```
|
|
/// kw is always unrolled.
|
|
/// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
|
|
/// > 1.
|
|
FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
|
|
if (!valid)
|
|
return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
|
|
|
|
int64_t nSize, wSize, cSize, kwSize, fSize;
|
|
SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
|
|
bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
|
|
switch (conv1DOpOrder) {
|
|
case Conv1DOpOrder::W:
|
|
// Initialize unused dimensions
|
|
nSize = fSize = cSize = 0;
|
|
// out{W}
|
|
bindShapeDims(resShapedType, wSize);
|
|
// kernel{kw}
|
|
bindShapeDims(rhsShapedType, kwSize);
|
|
lhsShape = {// iw = ow + kw - 1
|
|
// (i.e. 16 convolved with 3 -> 14)
|
|
(wSize + kwSize - 1)};
|
|
rhsShape = {kwSize};
|
|
resShape = {wSize};
|
|
break;
|
|
case Conv1DOpOrder::Nwc:
|
|
// out{n, w, f}
|
|
bindShapeDims(resShapedType, nSize, wSize, fSize);
|
|
switch (oper) {
|
|
case Conv:
|
|
// kernel{kw, c, f}
|
|
bindShapeDims(rhsShapedType, kwSize, cSize);
|
|
break;
|
|
case Pool:
|
|
// kernel{kw}
|
|
bindShapeDims(rhsShapedType, kwSize);
|
|
cSize = fSize;
|
|
break;
|
|
}
|
|
lhsShape = {nSize,
|
|
// iw = ow * sw + kw * dw - 1
|
|
// (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
|
|
// Perform the proper inclusive -> exclusive -> inclusive.
|
|
((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
|
|
1,
|
|
cSize};
|
|
switch (oper) {
|
|
case Conv:
|
|
rhsShape = {kwSize, cSize, fSize};
|
|
break;
|
|
case Pool:
|
|
rhsShape = {kwSize};
|
|
break;
|
|
}
|
|
resShape = {nSize, wSize, fSize};
|
|
break;
|
|
case Conv1DOpOrder::Ncw:
|
|
// out{n, f, w}
|
|
bindShapeDims(resShapedType, nSize, fSize, wSize);
|
|
switch (oper) {
|
|
case Conv:
|
|
// kernel{f, c, kw}
|
|
bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
|
|
break;
|
|
case Pool:
|
|
// kernel{kw}
|
|
bindShapeDims(rhsShapedType, kwSize);
|
|
cSize = fSize;
|
|
break;
|
|
}
|
|
lhsShape = {nSize, cSize,
|
|
// iw = ow * sw + kw * dw - 1
|
|
// (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
|
|
// Perform the proper inclusive -> exclusive -> inclusive.
|
|
((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
|
|
1};
|
|
switch (oper) {
|
|
case Conv:
|
|
rhsShape = {fSize, cSize, kwSize};
|
|
break;
|
|
case Pool:
|
|
rhsShape = {kwSize};
|
|
break;
|
|
}
|
|
resShape = {nSize, fSize, wSize};
|
|
break;
|
|
}
|
|
|
|
vector::TransferWriteOp write;
|
|
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
|
|
|
// w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
|
|
// When strideW == 1, we can batch the contiguous loads and avoid
|
|
// unrolling
|
|
int64_t wSizeStep = strideW == 1 ? wSize : 1;
|
|
|
|
Type lhsEltType = lhsShapedType.getElementType();
|
|
Type rhsEltType = rhsShapedType.getElementType();
|
|
Type resEltType = resShapedType.getElementType();
|
|
auto lhsType = VectorType::get(lhsShape, lhsEltType);
|
|
auto rhsType = VectorType::get(rhsShape, rhsEltType);
|
|
auto resType = VectorType::get(resShape, resEltType);
|
|
// Zero padding with the corresponding dimensions for lhs, rhs and res.
|
|
SmallVector<Value> lhsPadding(lhsShape.size(), zero);
|
|
SmallVector<Value> rhsPadding(rhsShape.size(), zero);
|
|
SmallVector<Value> resPadding(resShape.size(), zero);
|
|
|
|
// Read the whole lhs, rhs and res in one shot (with zero padding).
|
|
Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
|
|
lhsPadding);
|
|
// This is needed only for Conv.
|
|
Value rhs = nullptr;
|
|
if (oper == Conv)
|
|
rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
|
|
rhsPadding);
|
|
Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
|
|
resPadding);
|
|
|
|
// The base vectorization case for channeled convolution is input: {n,w,c},
|
|
// weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
|
|
// vectorization case, we do pre transpose on input, weight, and output.
|
|
switch (conv1DOpOrder) {
|
|
case Conv1DOpOrder::W:
|
|
case Conv1DOpOrder::Nwc:
|
|
// Base case, so no transposes necessary.
|
|
break;
|
|
case Conv1DOpOrder::Ncw: {
|
|
// To match base vectorization case, we pre-transpose current case.
|
|
// ncw -> nwc
|
|
static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
|
|
lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
|
|
// fcw -> wcf
|
|
static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
|
|
|
|
// This is needed only for Conv.
|
|
if (oper == Conv)
|
|
rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
|
|
// nfw -> nwf
|
|
static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
|
|
res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
|
|
break;
|
|
}
|
|
}
|
|
|
|
//===------------------------------------------------------------------===//
|
|
// Begin vector-only rewrite part
|
|
//===------------------------------------------------------------------===//
|
|
// Unroll along kw and read slices of lhs and rhs.
|
|
SmallVector<Value> lhsVals, rhsVals, resVals;
|
|
lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
|
|
kwSize, strideW, dilationW, wSizeStep,
|
|
isSingleChanneled);
|
|
// Do not do for pooling.
|
|
if (oper == Conv)
|
|
rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
|
|
resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
|
|
wSizeStep, isSingleChanneled);
|
|
|
|
auto linearIndex = [&](int64_t kw, int64_t w) {
|
|
return kw * (wSize / wSizeStep) + w;
|
|
};
|
|
|
|
// Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f} or
|
|
// perform outerproduct for non-channeled convolution or
|
|
// perform simple arith operation for pooling
|
|
for (int64_t kw = 0; kw < kwSize; ++kw) {
|
|
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
|
switch (oper) {
|
|
case Conv:
|
|
if (isSingleChanneled) {
|
|
resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
|
|
lhsVals[linearIndex(kw, w)],
|
|
rhsVals[kw], resVals[w]);
|
|
} else {
|
|
resVals[w] = conv1dSliceAsContraction(rewriter, loc,
|
|
lhsVals[linearIndex(kw, w)],
|
|
rhsVals[kw], resVals[w]);
|
|
}
|
|
break;
|
|
case Pool:
|
|
resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
|
|
resVals[w]);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
|
|
isSingleChanneled);
|
|
//===------------------------------------------------------------------===//
|
|
// End vector-only rewrite part
|
|
//===------------------------------------------------------------------===//
|
|
|
|
// The base vectorization case for channeled convolution is output: {n,w,f}
|
|
// To reuse the result from base pattern vectorization case, we post
|
|
// transpose the base case result.
|
|
switch (conv1DOpOrder) {
|
|
case Conv1DOpOrder::W:
|
|
case Conv1DOpOrder::Nwc:
|
|
// Base case, so no transposes necessary.
|
|
break;
|
|
case Conv1DOpOrder::Ncw: {
|
|
// nwf -> nfw
|
|
static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
|
|
res = rewriter.create<vector::TransposeOp>(loc, res, perm);
|
|
break;
|
|
}
|
|
}
|
|
|
|
return rewriter
|
|
.create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
|
|
.getOperation();
|
|
}
|
|
|
|
// Take a value and widen to have the same element type as `ty`.
|
|
Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
|
|
const Type srcElementType = getElementTypeOrSelf(val.getType());
|
|
const Type dstElementType = getElementTypeOrSelf(ty);
|
|
assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
|
|
if (srcElementType == dstElementType)
|
|
return val;
|
|
|
|
const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
|
|
const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
|
|
const Type dstType =
|
|
cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
|
|
|
|
if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
|
|
return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
|
|
}
|
|
|
|
if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
|
|
srcWidth < dstWidth)
|
|
return rewriter.create<arith::ExtFOp>(loc, dstType, val);
|
|
|
|
if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
|
|
srcWidth < dstWidth)
|
|
return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
|
|
|
|
assert(false && "unhandled promotion case");
|
|
return nullptr;
|
|
}
|
|
|
|
// Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
|
|
Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
|
|
Value lhs, Value rhs, Value res) {
|
|
vector::IteratorType par = vector::IteratorType::parallel;
|
|
vector::IteratorType red = vector::IteratorType::reduction;
|
|
AffineExpr n, w, f, c;
|
|
bindDims(ctx, n, w, f, c);
|
|
lhs = promote(rewriter, loc, lhs, res.getType());
|
|
rhs = promote(rewriter, loc, rhs, res.getType());
|
|
return rewriter.create<vector::ContractionOp>(
|
|
loc, lhs, rhs, res,
|
|
/*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
|
|
/*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
|
|
}
|
|
|
|
// Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
|
|
// convolution.
|
|
Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
|
|
Value lhs, Value rhs, Value res) {
|
|
return rewriter.create<vector::OuterProductOp>(
|
|
loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
|
|
}
|
|
|
|
// Create a reduction: lhs{n, w, c} -> res{n, w, c}
|
|
Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
|
|
Value res) {
|
|
if (isPoolExt)
|
|
lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
|
|
return rewriter
|
|
.create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
|
|
->getResult(0);
|
|
}
|
|
|
|
/// Generate a vector implementation for:
|
|
/// ```
|
|
/// Op def: ( n, w, c, kw)
|
|
/// Iters: ({Par(), Par(), Par(), Red()})
|
|
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
|
|
/// ```
|
|
/// kw is always unrolled.
|
|
/// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
|
|
/// > 1.
|
|
FailureOr<Operation *> depthwiseConv() {
|
|
if (!valid)
|
|
return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
|
|
|
|
int64_t nSize, wSize, cSize, kwSize;
|
|
// kernel{kw, c}
|
|
bindShapeDims(rhsShapedType, kwSize, cSize);
|
|
// out{n, w, c}
|
|
bindShapeDims(resShapedType, nSize, wSize);
|
|
|
|
vector::TransferWriteOp write;
|
|
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
|
|
|
// w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
|
|
// When strideW == 1, we can batch the contiguous loads and avoid
|
|
// unrolling
|
|
int64_t wSizeStep = strideW == 1 ? wSize : 1;
|
|
|
|
Type lhsEltType = lhsShapedType.getElementType();
|
|
Type rhsEltType = rhsShapedType.getElementType();
|
|
Type resEltType = resShapedType.getElementType();
|
|
VectorType lhsType = VectorType::get(
|
|
{nSize,
|
|
// iw = ow * sw + kw * dw - 1
|
|
// (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
|
|
((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
|
|
cSize},
|
|
lhsEltType);
|
|
VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
|
|
VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
|
|
|
|
// Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
|
|
// 0].
|
|
Value lhs = rewriter.create<vector::TransferReadOp>(
|
|
loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
|
|
// Read rhs slice of size {kw, c} @ [0, 0].
|
|
Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
|
|
ValueRange{zero, zero});
|
|
// Read res slice of size {n, w, c} @ [0, 0, 0].
|
|
Value res = rewriter.create<vector::TransferReadOp>(
|
|
loc, resType, resShaped, ValueRange{zero, zero, zero});
|
|
|
|
//===------------------------------------------------------------------===//
|
|
// Begin vector-only rewrite part
|
|
//===------------------------------------------------------------------===//
|
|
// Unroll along kw and read slices of lhs and rhs.
|
|
SmallVector<Value> lhsVals, rhsVals, resVals;
|
|
// Extract lhs slice of size {n, wSizeStep, c}
|
|
// @ [0, sw * w + dw * kw, 0].
|
|
for (int64_t kw = 0; kw < kwSize; ++kw) {
|
|
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
|
lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
|
|
loc, lhs,
|
|
/*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
|
|
/*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
|
|
/*strides=*/ArrayRef<int64_t>{1, 1, 1}));
|
|
}
|
|
}
|
|
// Extract rhs slice of size {c} @ [kw].
|
|
for (int64_t kw = 0; kw < kwSize; ++kw) {
|
|
rhsVals.push_back(rewriter.create<vector::ExtractOp>(
|
|
loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
|
|
}
|
|
// Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
|
|
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
|
resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
|
|
loc, res,
|
|
/*offsets=*/ArrayRef<int64_t>{0, w, 0},
|
|
/*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
|
|
/*strides=*/ArrayRef<int64_t>{1, 1, 1}));
|
|
}
|
|
|
|
auto linearIndex = [&](int64_t kw, int64_t w) {
|
|
return kw * (wSize / wSizeStep) + w;
|
|
};
|
|
|
|
// Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
|
|
for (int64_t kw = 0; kw < kwSize; ++kw) {
|
|
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
|
resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc,
|
|
lhsVals[linearIndex(kw, w)],
|
|
rhsVals[kw], resVals[w]);
|
|
}
|
|
}
|
|
|
|
// Its possible we failed to create the Fma.
|
|
if (!llvm::all_of(resVals, [](Value v) { return v; })) {
|
|
// Manually revert (in reverse order) to avoid leaving a bad IR state.
|
|
for (auto &collection :
|
|
{resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
|
|
for (Value v : collection)
|
|
rewriter.eraseOp(v.getDefiningOp());
|
|
return rewriter.notifyMatchFailure(op, "failed to create FMA");
|
|
}
|
|
|
|
// Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
|
|
// This does not depend on kw.
|
|
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
|
res = rewriter.create<vector::InsertStridedSliceOp>(
|
|
loc, resVals[w], res,
|
|
/*offsets=*/ArrayRef<int64_t>{0, w, 0},
|
|
/*strides=*/ArrayRef<int64_t>{1, 1, 1});
|
|
}
|
|
//===------------------------------------------------------------------===//
|
|
// End vector-only rewrite part
|
|
//===------------------------------------------------------------------===//
|
|
|
|
// Write back res slice of size {n, w, c} @ [0, 0, 0].
|
|
return rewriter
|
|
.create<vector::TransferWriteOp>(loc, res, resShaped,
|
|
ValueRange{zero, zero, zero})
|
|
.getOperation();
|
|
}
|
|
|
|
/// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
|
|
Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
|
|
Value lhs, Value rhs, Value res) {
|
|
auto rhsTy = cast<ShapedType>(rhs.getType());
|
|
auto resTy = cast<ShapedType>(res.getType());
|
|
|
|
// TODO(suderman): Change this to use a vector.ima intrinsic.
|
|
lhs = promote(rewriter, loc, lhs, resTy);
|
|
|
|
rhs = rewriter.create<vector::BroadcastOp>(
|
|
loc, resTy.clone(rhsTy.getElementType()), rhs);
|
|
rhs = promote(rewriter, loc, rhs, resTy);
|
|
|
|
if (!lhs || !rhs)
|
|
return nullptr;
|
|
|
|
if (isa<FloatType>(resTy.getElementType()))
|
|
return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
|
|
|
|
auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
|
|
return rewriter.create<arith::AddIOp>(loc, mul, res);
|
|
}
|
|
|
|
/// Entry point for non-channeled convolution:
|
|
/// {{w + kw}, {kw}, {w}}
|
|
FailureOr<Operation *> generateNonChanneledConv() {
|
|
AffineExpr w, kw;
|
|
bindDims(ctx, w, kw);
|
|
if (!iters({Par(), Red()}))
|
|
return rewriter.notifyMatchFailure(op,
|
|
"failed to match conv::W 1-par 1-red");
|
|
|
|
// No transposition needed.
|
|
if (layout({/*lhsIndex*/ {w + kw},
|
|
/*rhsIndex*/ {kw},
|
|
/*resIndex*/ {w}}))
|
|
return conv(Conv1DOpOrder::W);
|
|
|
|
return rewriter.notifyMatchFailure(op, "not a conv::W layout");
|
|
}
|
|
|
|
/// Entry point that transposes into the common form:
|
|
/// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
|
|
FailureOr<Operation *> generateNwcConv() {
|
|
AffineExpr n, w, f, kw, c;
|
|
bindDims(ctx, n, w, f, kw, c);
|
|
if (!iters({Par(), Par(), Par(), Red(), Red()}))
|
|
return rewriter.notifyMatchFailure(
|
|
op, "failed to match conv::Nwc 3-par 2-red");
|
|
|
|
// No transposition needed.
|
|
if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
|
|
/*rhsIndex*/ {kw, c, f},
|
|
/*resIndex*/ {n, w, f}}))
|
|
return conv(Conv1DOpOrder::Nwc);
|
|
|
|
return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
|
|
}
|
|
|
|
/// Entry point that transposes into the common form:
|
|
/// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
|
|
FailureOr<Operation *> generateNcwConv() {
|
|
AffineExpr n, w, f, kw, c;
|
|
bindDims(ctx, n, f, w, c, kw);
|
|
if (!iters({Par(), Par(), Par(), Red(), Red()}))
|
|
return rewriter.notifyMatchFailure(
|
|
op, "failed to match conv::Ncw 3-par 2-red");
|
|
|
|
if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
|
|
/*rhsIndex*/ {f, c, kw},
|
|
/*resIndex*/ {n, f, w}}))
|
|
return conv(Conv1DOpOrder::Ncw);
|
|
|
|
return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
|
|
}
|
|
|
|
/// Entry point that transposes into the common form:
|
|
/// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
|
|
FailureOr<Operation *> generateNwcPooling() {
|
|
AffineExpr n, w, c, kw;
|
|
bindDims(ctx, n, w, c, kw);
|
|
if (!iters({Par(), Par(), Par(), Red()}))
|
|
return rewriter.notifyMatchFailure(op,
|
|
"failed to match pooling 3-par 1-red");
|
|
|
|
// No transposition needed.
|
|
if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
|
|
/*rhsIndex*/ {kw},
|
|
/*resIndex*/ {n, w, c}}))
|
|
return conv(Conv1DOpOrder::Nwc);
|
|
|
|
return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
|
|
}
|
|
|
|
/// Entry point that transposes into the common form:
|
|
/// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
|
|
FailureOr<Operation *> generateNcwPooling() {
|
|
AffineExpr n, w, c, kw;
|
|
bindDims(ctx, n, c, w, kw);
|
|
if (!iters({Par(), Par(), Par(), Red()}))
|
|
return rewriter.notifyMatchFailure(op,
|
|
"failed to match pooling 3-par 1-red");
|
|
|
|
if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
|
|
/*rhsIndex*/ {kw},
|
|
/*resIndex*/ {n, c, w}}))
|
|
return conv(Conv1DOpOrder::Ncw);
|
|
|
|
return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
|
|
}
|
|
|
|
/// Entry point that transposes into the common form:
|
|
/// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
|
|
FailureOr<Operation *> generateDilatedConv() {
|
|
AffineExpr n, w, c, kw;
|
|
bindDims(ctx, n, w, c, kw);
|
|
if (!iters({Par(), Par(), Par(), Red()}))
|
|
return rewriter.notifyMatchFailure(
|
|
op, "failed to match depthwise::Nwc conv 3-par 1-red");
|
|
|
|
// No transposition needed.
|
|
if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
|
|
/*rhsIndex*/ {kw, c},
|
|
/*resIndex*/ {n, w, c}}))
|
|
return depthwiseConv();
|
|
|
|
return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
|
|
}
|
|
|
|
private:
|
|
enum OperKind { Conv, Pool };
|
|
bool valid = false;
|
|
OperKind oper = Conv;
|
|
StringAttr redOp;
|
|
StringAttr poolExtOp;
|
|
bool isPoolExt = false;
|
|
int strideW, dilationW;
|
|
Value lhsShaped, rhsShaped, resShaped;
|
|
ShapedType lhsShapedType, rhsShapedType, resShapedType;
|
|
|
|
// Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
|
|
// Returns true iff it is a valid conv/pooling op.
|
|
// If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
|
|
// + yield) and rhs is not used) then it is the body of a pooling
|
|
// If conv, check for single `mul` predecessor. The `mul` operands must be
|
|
// block arguments or extension of block arguments.
|
|
// Otherwise, check for one or zero `ext` predecessor. The `ext` operands
|
|
// must be block arguments or extension of block arguments.
|
|
bool setOperKind(Operation *reduceOp) {
|
|
int numBlockArguments = llvm::count_if(
|
|
reduceOp->getOperands(), [](Value v) { return isa<BlockArgument>(v); });
|
|
switch (numBlockArguments) {
|
|
case 1: {
|
|
// Will be convolution if feeder is a MulOp.
|
|
// Otherwise, if it can be pooling.
|
|
auto feedValIt = llvm::find_if(reduceOp->getOperands(), [](Value v) {
|
|
return !isa<BlockArgument>(v);
|
|
});
|
|
Operation *feedOp = (*feedValIt).getDefiningOp();
|
|
if (isCastOfBlockArgument(feedOp)) {
|
|
oper = Pool;
|
|
isPoolExt = true;
|
|
poolExtOp = feedOp->getName().getIdentifier();
|
|
} else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
|
|
llvm::all_of(feedOp->getOperands(), [](Value v) {
|
|
if (isa<BlockArgument>(v))
|
|
return true;
|
|
if (Operation *op = v.getDefiningOp())
|
|
return isCastOfBlockArgument(op);
|
|
return false;
|
|
}))) {
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
case 2:
|
|
// Must be pooling
|
|
oper = Pool;
|
|
isPoolExt = false;
|
|
return true;
|
|
default:
|
|
return false;
|
|
}
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Helper function to vectorize a LinalgOp with convolution semantics.
|
|
// TODO: extend the generic vectorization to support windows and drop this.
|
|
static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
|
|
LinalgOp op) {
|
|
// The ConvolutionOpInterface gives us guarantees of existence for
|
|
// strides/dilations. However, we do not need to rely on those, we can simply
|
|
// use them if present, otherwise use the default and let the generic conv.
|
|
// matcher in the ConvGenerator succeed or fail.
|
|
auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
|
|
auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
|
|
auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
|
|
auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
|
|
Conv1DGenerator e(rewriter, op, stride, dilation);
|
|
auto res = e.generateNonChanneledConv();
|
|
if (succeeded(res))
|
|
return res;
|
|
res = e.generateNwcConv();
|
|
if (succeeded(res))
|
|
return res;
|
|
res = e.generateNcwConv();
|
|
if (succeeded(res))
|
|
return res;
|
|
res = e.generateNwcPooling();
|
|
if (succeeded(res))
|
|
return res;
|
|
res = e.generateNcwPooling();
|
|
if (succeeded(res))
|
|
return res;
|
|
return e.generateDilatedConv();
|
|
}
|
|
|
|
struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
|
|
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(LinalgOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
|
|
if (failed(resultOrFail))
|
|
return failure();
|
|
Operation *newOp = *resultOrFail;
|
|
if (newOp->getNumResults() == 0) {
|
|
rewriter.eraseOp(op.getOperation());
|
|
return success();
|
|
}
|
|
assert(newOp->getNumResults() == 1 && "expected single result");
|
|
rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
|
|
return success();
|
|
}
|
|
};
|
|
|
|
void mlir::linalg::populateConvolutionVectorizationPatterns(
|
|
RewritePatternSet &patterns, PatternBenefit benefit) {
|
|
patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
|
|
}
|