[mlir][Linalg] NFC - Refactor vectorization to be more composable

Differential Revision: https://reviews.llvm.org/D96116
This commit is contained in:
Nicolas Vasilache 2021-02-05 11:48:16 +00:00
parent 7fe41ac3df
commit 0fcbbde2c7
4 changed files with 50 additions and 39 deletions

View File

@@ -31,13 +31,6 @@ struct LinalgTilingOptions;
//===----------------------------------------------------------------------===//
using LinalgLoops = SmallVector<Operation *, 4>;
/// Bag of results returned by a tiling transformation.
struct TiledLinalgOp {
/// The tiled Linalg op produced by the transformation.
LinalgOp op;
/// Loop nest generated around the tiled op (outermost first, presumably —
/// TODO confirm ordering against the tiling implementation).
SmallVector<Operation *, 8> loops;
/// Tensor-typed results; relevant only when tiling ops on tensors.
SmallVector<Value, 4> tensorResults;
TiledLinalgOp &operator=(const TiledLinalgOp &) = default;
};
/// Populates patterns for vectorization of all ConvN-D ops.
void populateConvVectorizationPatterns(
MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
@@ -63,6 +56,12 @@ void populateLinalgBufferizePatterns(MLIRContext *context,
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
/// integers, in the range 0..`tileSizes.size()` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation).
/// Bag of results returned by a tiling transformation.
struct TiledLinalgOp {
/// The tiled Linalg op produced by the transformation.
LinalgOp op;
/// Loop nest generated around the tiled op (outermost first, presumably —
/// TODO confirm ordering against the tiling implementation).
SmallVector<Operation *, 8> loops;
/// Tensor-typed results; relevant only when tiling ops on tensors.
SmallVector<Value, 4> tensorResults;
TiledLinalgOp &operator=(const TiledLinalgOp &) = default;
};
Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
const LinalgTilingOptions &options);
@@ -264,7 +263,12 @@ Optional<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
OperationFolder *folder = nullptr);
/// Emit a suitable vector form for a Linalg op with fully static shape.
void vectorizeLinalgOp(OpBuilder &builder, Operation *op);
/// Result of vectorizing a Linalg op.
struct VectorizedLinalgOp {
/// Replacement tensor results: when non-empty, callers replace the original
/// op's results with these values; when empty (buffer semantics), callers
/// simply erase the original op.
SmallVector<Value> tensorResults;
VectorizedLinalgOp &operator=(const VectorizedLinalgOp &) = default;
};
Optional<VectorizedLinalgOp> vectorizeLinalgOp(OpBuilder &builder,
Operation *op);
/// Emits a loop nest of `LoopTy` with the proper body for `op`.
template <typename LoopTy>

View File

@@ -468,10 +468,13 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
return failure();
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
return failure();
if (failed(vectorizeLinalgOpPrecondition(op)))
Optional<VectorizedLinalgOp> res = vectorizeLinalgOp(rewriter, op);
if (!res)
return failure();
vectorizeLinalgOp(rewriter, op);
rewriter.eraseOp(op);
if (!res->tensorResults.empty())
rewriter.replaceOp(op, res->tensorResults);
else
rewriter.eraseOp(op);
return success();
}

View File

@@ -248,8 +248,7 @@ vectorizeOneOp(OpBuilder &builder, Operation *op,
/// TODO: Reuse opportunities for RAR dependencies.
/// 4. Register CustomVectorizationHook for YieldOp to capture the results.
/// 5. Iteratively call vectorizeOneOp on the region operations.
/// 6. RAUW the linalg op by the results captured vectorizing the YieldOp.
static LogicalResult vectorizeAsLinalgGeneric(
static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
OpBuilder &builder, LinalgOp linalgOp,
ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
// 1. Certain Linalg ops do not have a region but only a region builder.
@@ -306,7 +305,7 @@ static LogicalResult vectorizeAsLinalgGeneric(
VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks);
if (result.status == VectorizationStatus::Failure) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
return failure();
return llvm::None;
}
if (result.status == VectorizationStatus::NewOp) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
@@ -315,10 +314,7 @@ static LogicalResult vectorizeAsLinalgGeneric(
}
}
// 6. RAUW the linalg op by the results captured vectorizing the YieldOp.
if (!results.empty())
linalgOp->replaceAllUsesWith(results);
return success();
return VectorizedLinalgOp{{results}};
}
/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
@@ -357,7 +353,8 @@ static bool isElementwise(Operation *op) {
return hasOnlyScalarElementwiseOp(genericOp.getRegion());
}
static void vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp) {
static Optional<VectorizedLinalgOp> vectorizeContraction(OpBuilder &builder,
LinalgOp linalgOp) {
assert(isaContractionOpInterface(linalgOp) &&
"expected vectorizeContraction preconditions to be met");
Location loc = linalgOp.getLoc();
@@ -384,11 +381,7 @@ static void vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp) {
linalgOp.indexing_maps(), linalgOp.iterator_types());
return VectorizationResult{VectorizationStatus::NewOp, contract};
};
auto status =
vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction});
(void)status;
assert(succeeded(status) &&
"Unexpected vectorization failed despite preconditions");
return vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction});
}
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
@@ -408,8 +401,10 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
return success(isaContractionOpInterface(linalgOp));
}
void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
assert(succeeded(vectorizeLinalgOpPrecondition(op)));
Optional<VectorizedLinalgOp> mlir::linalg::vectorizeLinalgOp(OpBuilder &builder,
Operation *op) {
if (failed(vectorizeLinalgOpPrecondition(op)))
return llvm::None;
edsc::ScopedContext scope(builder, op->getLoc());
// In the case of 0-D memrefs, return null and special case to scalar load or
@@ -418,8 +413,10 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
// Vectorize fill as a vector.broadcast.
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
<< "Rewrite linalg.fill as vector.broadcast: " << *op);
buildVectorWrite(builder, fillOp.value(), fillOp.output());
return;
VectorizedLinalgOp res;
if (Value v = buildVectorWrite(builder, fillOp.value(), fillOp.output()))
res.tensorResults.push_back(v);
return res;
}
if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
// Vectorize copy as a vector.transfer_read+vector.transfer_write.
@@ -428,21 +425,26 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
"vector.transfer_write: "
<< *op);
Value vector = buildVectorRead(builder, copyOp.input());
buildVectorWrite(builder, vector, copyOp.output());
return;
VectorizedLinalgOp res;
if (Value v = buildVectorWrite(builder, vector, copyOp.output()))
res.tensorResults.push_back(v);
return res;
}
if (isElementwise(op)) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
<< "Rewrite linalg op as vector.transfer_read + " << *op);
auto status = vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
(void)status;
assert(succeeded(status) &&
"Unexpected vectorization failed despite preconditions");
return;
<< "Vectorize linalg op as a generic: " << *op);
return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
}
vectorizeContraction(builder, cast<LinalgOp>(op));
// TODO: as soon as Copy and FillOp. get a region builder, replace all the
// above by:
// if (isa<FillOp, CopyOp>(op) || isElementwise(op)) {
// LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
// << "Vectorize linalg op as a generic: " << *op);
// return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
// }
return vectorizeContraction(builder, cast<LinalgOp>(op));
}
//----------------------------------------------------------------------------//

View File

@@ -1,4 +1,6 @@
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns -split-input-file -debug-only=linalg-vectorization | FileCheck %s
// -----