mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-29 16:41:27 +00:00
[mlir][Linalg] NFC - Refactor vectorization to be more composable
Differential Revision: https://reviews.llvm.org/D96116
This commit is contained in:
parent
7fe41ac3df
commit
0fcbbde2c7
@ -31,13 +31,6 @@ struct LinalgTilingOptions;
|
||||
//===----------------------------------------------------------------------===//
|
||||
using LinalgLoops = SmallVector<Operation *, 4>;
|
||||
|
||||
struct TiledLinalgOp {
|
||||
LinalgOp op;
|
||||
SmallVector<Operation *, 8> loops;
|
||||
SmallVector<Value, 4> tensorResults;
|
||||
TiledLinalgOp &operator=(const TiledLinalgOp &) = default;
|
||||
};
|
||||
|
||||
/// Populates patterns for vectorization of all ConvN-D ops.
|
||||
void populateConvVectorizationPatterns(
|
||||
MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
|
||||
@ -63,6 +56,12 @@ void populateLinalgBufferizePatterns(MLIRContext *context,
|
||||
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
|
||||
/// integers, in the range 0..`tileSizes.size()` without duplications
|
||||
/// (i.e. `[1,1,2]` is an invalid permutation).
|
||||
struct TiledLinalgOp {
|
||||
LinalgOp op;
|
||||
SmallVector<Operation *, 8> loops;
|
||||
SmallVector<Value, 4> tensorResults;
|
||||
TiledLinalgOp &operator=(const TiledLinalgOp &) = default;
|
||||
};
|
||||
Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
|
||||
const LinalgTilingOptions &options);
|
||||
|
||||
@ -264,7 +263,12 @@ Optional<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
|
||||
OperationFolder *folder = nullptr);
|
||||
|
||||
/// Emit a suitable vector form for a Linalg op with fully static shape.
|
||||
void vectorizeLinalgOp(OpBuilder &builder, Operation *op);
|
||||
struct VectorizedLinalgOp {
|
||||
SmallVector<Value> tensorResults;
|
||||
VectorizedLinalgOp &operator=(const VectorizedLinalgOp &) = default;
|
||||
};
|
||||
Optional<VectorizedLinalgOp> vectorizeLinalgOp(OpBuilder &builder,
|
||||
Operation *op);
|
||||
|
||||
/// Emits a loop nest of `LoopTy` with the proper body for `op`.
|
||||
template <typename LoopTy>
|
||||
|
@ -468,10 +468,13 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
|
||||
return failure();
|
||||
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
|
||||
return failure();
|
||||
if (failed(vectorizeLinalgOpPrecondition(op)))
|
||||
Optional<VectorizedLinalgOp> res = vectorizeLinalgOp(rewriter, op);
|
||||
if (!res)
|
||||
return failure();
|
||||
vectorizeLinalgOp(rewriter, op);
|
||||
rewriter.eraseOp(op);
|
||||
if (!res->tensorResults.empty())
|
||||
rewriter.replaceOp(op, res->tensorResults);
|
||||
else
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -248,8 +248,7 @@ vectorizeOneOp(OpBuilder &builder, Operation *op,
|
||||
/// TODO: Reuse opportunities for RAR dependencies.
|
||||
/// 4. Register CustomVectorizationHook for YieldOp to capture the results.
|
||||
/// 5. Iteratively call vectorizeOneOp on the region operations.
|
||||
/// 6. RAUW the linalg op by the results captured vectorizing the YieldOp.
|
||||
static LogicalResult vectorizeAsLinalgGeneric(
|
||||
static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
|
||||
OpBuilder &builder, LinalgOp linalgOp,
|
||||
ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
|
||||
// 1. Certain Linalg ops do not have a region but only a region builder.
|
||||
@ -306,7 +305,7 @@ static LogicalResult vectorizeAsLinalgGeneric(
|
||||
VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks);
|
||||
if (result.status == VectorizationStatus::Failure) {
|
||||
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
|
||||
return failure();
|
||||
return llvm::None;
|
||||
}
|
||||
if (result.status == VectorizationStatus::NewOp) {
|
||||
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
|
||||
@ -315,10 +314,7 @@ static LogicalResult vectorizeAsLinalgGeneric(
|
||||
}
|
||||
}
|
||||
|
||||
// 6. RAUW the linalg op by the results captured vectorizing the YieldOp.
|
||||
if (!results.empty())
|
||||
linalgOp->replaceAllUsesWith(results);
|
||||
return success();
|
||||
return VectorizedLinalgOp{{results}};
|
||||
}
|
||||
|
||||
/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
|
||||
@ -357,7 +353,8 @@ static bool isElementwise(Operation *op) {
|
||||
return hasOnlyScalarElementwiseOp(genericOp.getRegion());
|
||||
}
|
||||
|
||||
static void vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp) {
|
||||
static Optional<VectorizedLinalgOp> vectorizeContraction(OpBuilder &builder,
|
||||
LinalgOp linalgOp) {
|
||||
assert(isaContractionOpInterface(linalgOp) &&
|
||||
"expected vectorizeContraction preconditions to be met");
|
||||
Location loc = linalgOp.getLoc();
|
||||
@ -384,11 +381,7 @@ static void vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp) {
|
||||
linalgOp.indexing_maps(), linalgOp.iterator_types());
|
||||
return VectorizationResult{VectorizationStatus::NewOp, contract};
|
||||
};
|
||||
auto status =
|
||||
vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction});
|
||||
(void)status;
|
||||
assert(succeeded(status) &&
|
||||
"Unexpected vectorization failed despite preconditions");
|
||||
return vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction});
|
||||
}
|
||||
|
||||
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
|
||||
@ -408,8 +401,10 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
|
||||
return success(isaContractionOpInterface(linalgOp));
|
||||
}
|
||||
|
||||
void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
|
||||
assert(succeeded(vectorizeLinalgOpPrecondition(op)));
|
||||
Optional<VectorizedLinalgOp> mlir::linalg::vectorizeLinalgOp(OpBuilder &builder,
|
||||
Operation *op) {
|
||||
if (failed(vectorizeLinalgOpPrecondition(op)))
|
||||
return llvm::None;
|
||||
|
||||
edsc::ScopedContext scope(builder, op->getLoc());
|
||||
// In the case of 0-D memrefs, return null and special case to scalar load or
|
||||
@ -418,8 +413,10 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
|
||||
// Vectorize fill as a vector.broadcast.
|
||||
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
|
||||
<< "Rewrite linalg.fill as vector.broadcast: " << *op);
|
||||
buildVectorWrite(builder, fillOp.value(), fillOp.output());
|
||||
return;
|
||||
VectorizedLinalgOp res;
|
||||
if (Value v = buildVectorWrite(builder, fillOp.value(), fillOp.output()))
|
||||
res.tensorResults.push_back(v);
|
||||
return res;
|
||||
}
|
||||
if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
|
||||
// Vectorize copy as a vector.transfer_read+vector.transfer_write.
|
||||
@ -428,21 +425,26 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
|
||||
"vector.transfer_write: "
|
||||
<< *op);
|
||||
Value vector = buildVectorRead(builder, copyOp.input());
|
||||
buildVectorWrite(builder, vector, copyOp.output());
|
||||
return;
|
||||
VectorizedLinalgOp res;
|
||||
if (Value v = buildVectorWrite(builder, vector, copyOp.output()))
|
||||
res.tensorResults.push_back(v);
|
||||
return res;
|
||||
}
|
||||
|
||||
if (isElementwise(op)) {
|
||||
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
|
||||
<< "Rewrite linalg op as vector.transfer_read + " << *op);
|
||||
auto status = vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
|
||||
(void)status;
|
||||
assert(succeeded(status) &&
|
||||
"Unexpected vectorization failed despite preconditions");
|
||||
return;
|
||||
<< "Vectorize linalg op as a generic: " << *op);
|
||||
return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
|
||||
}
|
||||
|
||||
vectorizeContraction(builder, cast<LinalgOp>(op));
|
||||
// TODO: as soon as Copy and FillOp. get a region builder, replace all the
|
||||
// above by:
|
||||
// if (isa<FillOp, CopyOp>(op) || isElementwise(op)) {
|
||||
// LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
|
||||
// << "Vectorize linalg op as a generic: " << *op);
|
||||
// return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
|
||||
// }
|
||||
|
||||
return vectorizeContraction(builder, cast<LinalgOp>(op));
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
|
@ -1,4 +1,6 @@
|
||||
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns -split-input-file | FileCheck %s
|
||||
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns -split-input-file -debug-only=linalg-vectorization
|
||||
|
||||
//| FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user