diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h index 14f8b29f6689..e318bd5cf3e1 100644 --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -19,11 +19,14 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Location.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/StringRef.h" namespace mlir { +class PatternRewriter; + /// Tests whether the given maps describe a row major matmul. The test is /// permutation-invariant. Note that this only checks the affine maps from an /// operation, so does not perform any checks on the math being performed within @@ -132,6 +135,60 @@ inline StringRef toString(IteratorType t) { llvm_unreachable("Unsupported IteratorType"); } +/// Helper StructuredGenerator class to manipulate and rewrite ops with +/// `StructuredOpInterface`. This is templated for now because VectorOps do not +/// yet implement the StructuredOpInterface itself. +template +class StructuredGenerator { +public: + using MapList = ArrayRef>; + + struct IteratorType { + IteratorType(StringRef strRef) : strRef(strRef) {} + bool isOfType(Attribute attr) const { + auto sAttr = attr.dyn_cast(); + return sAttr && sAttr.getValue() == strRef; + } + StringRef strRef; + }; + struct Par : public IteratorType { + Par() : IteratorType(getParallelIteratorTypeName()) {} + }; + struct Red : public IteratorType { + Red() : IteratorType(getReductionIteratorTypeName()) {} + }; + struct Win : public IteratorType { + Win() : IteratorType(getWindowIteratorTypeName()) {} + }; + + StructuredGenerator(PatternRewriter &rewriter, StructuredOpInterface op) + : rewriter(rewriter), ctx(op.getContext()), loc(op.getLoc()), + iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {} + + bool iters(ArrayRef its) { + if (its.size() != iterators.size()) + return false; + for (int i = 0, e = its.size(); i != e; ++i) { + if (!its[i].isOfType(iterators[i])) + return false; + } + return true; + } + + bool layout(MapList l) { + auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; + return maps == infer(l); + } + +protected: + PatternRewriter &rewriter; + MLIRContext *ctx; + Location loc; + ArrayAttr iterators; + SmallVector maps; + Operation *op; +}; + } // end namespace mlir #endif // MLIR_UTILS_STRUCTUREDOPSUTILS_H diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp index 2b83c23730aa..d1b9835452b7 100644 --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1252,35 +1252,22 @@ struct Red : public IteratorType { Red() : IteratorType(getReductionIteratorTypeName()) {} }; -// Unroll outer-products along reduction. -struct UnrolledOuterProductEmitter { - using MapList = ArrayRef>; +/// Generate a vector implementation for matmat, matvec and tmatvec. +/// This unrolls outer-products along the reduction dimension. +struct UnrolledOuterProductGenerator + : public StructuredGenerator { - UnrolledOuterProductEmitter(PatternRewriter &rewriter, - vector::ContractionOp op) - : rewriter(rewriter), loc(op.getLoc()), kind(op.kind()), - iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {} + UnrolledOuterProductGenerator(PatternRewriter &rewriter, + vector::ContractionOp op) + : StructuredGenerator(rewriter, op), + kind(op.kind()), lhs(op.lhs()), rhs(op.rhs()), res(op.acc()), + lhsType(op.getLhsType()) {} Value t(Value v) { static constexpr std::array perm = {1, 0}; return rewriter.create(loc, v, perm); } - bool iters(ArrayRef its) { - if (its.size() != iterators.size()) - return false; - for (int i = 0, e = its.size(); i != e; ++i) { - if (!its[i].isOfType(iterators[i])) - return false; - } - return true; - } - - bool layout(MapList l) { - auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; - return maps == infer(l); - } - LogicalResult outer_prod(Value lhs, Value rhs, Value res, int reductionSize) { assert(reductionSize > 0); for (int64_t k = 0; k < reductionSize; ++k) { @@ -1293,12 +1280,93 @@ struct UnrolledOuterProductEmitter { return success(); } - PatternRewriter &rewriter; - Location loc; + /// Two outer parallel, one inner reduction (matmat flavor). + LogicalResult matmat() { + if (!iters({Par(), Par(), Red()})) + return failure(); + // Set up the parallel/reduction structure in the right form. + AffineExpr m, n, k; + bindDims(rewriter.getContext(), m, n, k); + // Classical row-major matmul: Just permute the lhs. + if (layout({{m, k}, {k, n}, {m, n}})) + return outer_prod(t(lhs), rhs, res, lhsType.getDimSize(1)); + // TODO: may be better to fail and use some vector -> scalar reduction. + if (layout({{m, k}, {n, k}, {m, n}})) { + Value tlhs = t(lhs); + return outer_prod(tlhs, t(rhs), res, lhsType.getDimSize(1)); + } + // No need to permute anything. + if (layout({{k, m}, {k, n}, {m, n}})) + return outer_prod(lhs, rhs, res, lhsType.getDimSize(0)); + // Just permute the rhs. + if (layout({{k, m}, {n, k}, {m, n}})) + return outer_prod(lhs, t(rhs), res, lhsType.getDimSize(0)); + // Transposed output: swap RHS and LHS. + // Classical row-major matmul: permute the lhs. + if (layout({{m, k}, {k, n}, {n, m}})) + return outer_prod(rhs, t(lhs), res, lhsType.getDimSize(1)); + // TODO: may be better to fail and use some vector -> scalar reduction. + if (layout({{m, k}, {n, k}, {n, m}})) { + Value trhs = t(rhs); + return outer_prod(trhs, t(lhs), res, lhsType.getDimSize(1)); + } + if (layout({{k, m}, {k, n}, {n, m}})) + return outer_prod(rhs, lhs, res, lhsType.getDimSize(0)); + if (layout({{k, m}, {n, k}, {n, m}})) + return outer_prod(t(rhs), lhs, res, lhsType.getDimSize(0)); + return failure(); + } + + /// One outer parallel, one inner reduction (matvec flavor) + LogicalResult matvec() { + if (!iters({Par(), Red()})) + return failure(); + AffineExpr m, k; + bindDims(rewriter.getContext(), m, k); + + // Case mat-vec: transpose. + if (layout({{m, k}, {k}, {m}})) + return outer_prod(t(lhs), rhs, res, lhsType.getDimSize(1)); + // Case mat-trans-vec: ready to go. + if (layout({{k, m}, {k}, {m}})) + return outer_prod(lhs, rhs, res, lhsType.getDimSize(0)); + // Case vec-mat: swap and transpose. + if (layout({{k}, {m, k}, {m}})) + return outer_prod(t(rhs), lhs, res, lhsType.getDimSize(0)); + // Case vec-mat-trans: swap and ready to go. + if (layout({{k}, {k, m}, {m}})) + return outer_prod(rhs, lhs, res, lhsType.getDimSize(0)); + return failure(); + } + + // + // One outer reduction, one inner parallel (tmatvec flavor) + // + LogicalResult tmatvec() { + if (!iters({Red(), Par()})) + return failure(); + AffineExpr k, m; + bindDims(rewriter.getContext(), k, m); + + // Case mat-vec: transpose. + if (layout({{m, k}, {k}, {m}})) + return outer_prod(t(lhs), rhs, res, lhsType.getDimSize(1)); + // Case mat-trans-vec: ready to go. + if (layout({{k, m}, {k}, {m}})) + return outer_prod(lhs, rhs, res, lhsType.getDimSize(0)); + // Case vec-mat: swap and transpose. + if (layout({{k}, {m, k}, {m}})) + return outer_prod(t(rhs), lhs, res, lhsType.getDimSize(0)); + // Case vec-mat-trans: swap and ready to go. + if (layout({{k}, {k, m}, {m}})) + return outer_prod(rhs, lhs, res, lhsType.getDimSize(0)); + return failure(); + } + +private: vector::CombiningKind kind; - ArrayAttr iterators; - SmallVector maps; - Operation *op; + Value lhs, rhs, res; + VectorType lhsType; }; } // namespace @@ -1330,90 +1398,13 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite( if (failed(filter(op))) return failure(); - VectorType lhsType = op.getLhsType(); - Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc(); - - // - // Two outer parallel, one inner reduction (matmat flavor). - // - UnrolledOuterProductEmitter e(rewriter, op); - if (e.iters({Par(), Par(), Red()})) { - // Set up the parallel/reduction structure in right form. - AffineExpr m, n, k; - bindDims(rewriter.getContext(), m, n, k); - // Classical row-major matmul: Just permute the lhs. - if (e.layout({{m, k}, {k, n}, {m, n}})) - return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1)); - // TODO: may be better to fail and use some vector -> scalar reduction. - if (e.layout({{m, k}, {n, k}, {m, n}})) { - Value tlhs = e.t(lhs); - return e.outer_prod(tlhs, e.t(rhs), res, lhsType.getDimSize(1)); - } - // No need to permute anything. - if (e.layout({{k, m}, {k, n}, {m, n}})) - return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0)); - // Just permute the rhs. - if (e.layout({{k, m}, {n, k}, {m, n}})) - return e.outer_prod(lhs, e.t(rhs), res, lhsType.getDimSize(0)); - // Transposed output: swap RHS and LHS. - // Classical row-major matmul: permute the lhs. - if (e.layout({{m, k}, {k, n}, {n, m}})) - return e.outer_prod(rhs, e.t(lhs), res, lhsType.getDimSize(1)); - // TODO: may be better to fail and use some vector -> scalar reduction. - if (e.layout({{m, k}, {n, k}, {n, m}})) { - Value trhs = e.t(rhs); - return e.outer_prod(trhs, e.t(lhs), res, lhsType.getDimSize(1)); - } - if (e.layout({{k, m}, {k, n}, {n, m}})) - return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0)); - if (e.layout({{k, m}, {n, k}, {n, m}})) - return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0)); - return failure(); - } - - // - // One outer parallel, one inner reduction (matvec flavor) - // - if (e.iters({Par(), Red()})) { - AffineExpr m, k; - bindDims(rewriter.getContext(), m, k); - - // Case mat-vec: transpose. - if (e.layout({{m, k}, {k}, {m}})) - return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1)); - // Case mat-trans-vec: ready to go. - if (e.layout({{k, m}, {k}, {m}})) - return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0)); - // Case vec-mat: swap and transpose. - if (e.layout({{k}, {m, k}, {m}})) - return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0)); - // Case vec-mat-trans: swap and ready to go. - if (e.layout({{k}, {k, m}, {m}})) - return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0)); - return failure(); - } - - // - // One outer reduction, one inner parallel (tmatvec flavor) - // - if (e.iters({Red(), Par()})) { - AffineExpr k, m; - bindDims(rewriter.getContext(), k, m); - - // Case mat-vec: transpose. - if (e.layout({{m, k}, {k}, {m}})) - return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1)); - // Case mat-trans-vec: ready to go. - if (e.layout({{k, m}, {k}, {m}})) - return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0)); - // Case vec-mat: swap and transpose. - if (e.layout({{k}, {m, k}, {m}})) - return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0)); - // Case vec-mat-trans: swap and ready to go. - if (e.layout({{k}, {k, m}, {m}})) - return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0)); - return failure(); - } + UnrolledOuterProductGenerator e(rewriter, op); + if (succeeded(e.matmat())) + return success(); + if (succeeded(e.matvec())) + return success(); + if (succeeded(e.tmatvec())) + return success(); return failure(); }