mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-01-15 20:51:35 +00:00
[MLIR][Vector] Add support for TupleGetOp folding through InsertSlicesOp and ExtractSlicesOp.
Summary: Add support for TupleGetOp folding through InsertSlicesOp and ExtractSlicesOp. Vector-to-vector transformations for unrolling and lowering to hardware vectors can generate chains of structured vector operations (InsertSlicesOp, ExtractSlicesOp and ShapeCastOp) between the producer of a hardware vector value and its consumer. Because InsertSlicesOp, ExtractSlicesOp and ShapeCastOp are structured, we can track the location (tuple index and vector offsets) of the consumer vector value through the chain of structured operations to the producer, enabling a much more powerful producer-consumer fowarding of values through structured ops and tuple, which in turn enables a more powerful TupleGetOp folding transformation. Reviewers: nicolasvasilache, aartbik Reviewed By: aartbik Subscribers: grosul1, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D76889
This commit is contained in:
parent
34756a1c70
commit
31a346cc35
@ -31,6 +31,9 @@ class VectorType;
|
||||
SmallVector<int64_t, 4> computeStrides(ArrayRef<int64_t> shape,
|
||||
ArrayRef<int64_t> sizes);
|
||||
|
||||
/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
|
||||
int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis);
|
||||
|
||||
/// Given the slice strides together with a linear index in the dimension
|
||||
/// space, returns the vector-space offsets in each dimension for a
|
||||
/// de-linearized index.
|
||||
|
@ -69,15 +69,6 @@ static int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis) {
|
||||
return res;
|
||||
}
|
||||
|
||||
/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
|
||||
static int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
|
||||
assert(offsets.size() == basis.size());
|
||||
int64_t linearIndex = 0;
|
||||
for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
|
||||
linearIndex += offsets[idx] * basis[idx];
|
||||
return linearIndex;
|
||||
}
|
||||
|
||||
// Clones `op` into a new operations that takes `operands` and returns
|
||||
// `resultTypes`.
|
||||
static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
|
||||
@ -683,6 +674,99 @@ struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
|
||||
}
|
||||
};
|
||||
|
||||
/// Returns the producer Value of the same type as 'consumerValue', by tracking
|
||||
/// the tuple index and offsets of the consumer vector value through the
|
||||
/// chain of operations (TupleGetOp, InsertSlicesOp, ExtractSlicesOp, TupleOp)
|
||||
/// from consumer to producer. Each operation in the chain is structured, and
|
||||
/// so the tuple index and offsets can be mapped from result to input, while
|
||||
/// visiting each operation in the chain.
|
||||
/// Returns nullptr on failure.
|
||||
static Value getProducerValue(Value consumerValue) {
|
||||
auto consumerVectorType = consumerValue.getType().cast<VectorType>();
|
||||
// A tupleIndex == -1 indicates that 'offsets' are w.r.t a vector type.
|
||||
int64_t tupleIndex = -1;
|
||||
SmallVector<int64_t, 4> offsets(consumerVectorType.getRank(), 0);
|
||||
auto *op = consumerValue.getDefiningOp();
|
||||
while (op != nullptr) {
|
||||
if (auto tupleGetOp = dyn_cast<vector::TupleGetOp>(op)) {
|
||||
assert(tupleIndex == -1 && "TupleGetOp must have vector result type");
|
||||
|
||||
// Update 'tupleIndex' and next defining 'op' to visit.
|
||||
tupleIndex = tupleGetOp.getIndex();
|
||||
op = tupleGetOp.vectors().getDefiningOp();
|
||||
} else if (auto extractSlicesOp = dyn_cast<vector::ExtractSlicesOp>(op)) {
|
||||
assert(tupleIndex >= 0);
|
||||
|
||||
// Compute slice strides for 'extractSlicesOp'.
|
||||
SmallVector<int64_t, 4> sizes;
|
||||
extractSlicesOp.getSizes(sizes);
|
||||
auto sliceStrides = computeStrides(
|
||||
extractSlicesOp.getSourceVectorType().getShape(), sizes);
|
||||
|
||||
// Compute 'elementOffsets' into 'extractSlicesOp' input vector type,
|
||||
// of 'extractSlicesOp' result vector tuple element at 'tupleIndex'.
|
||||
auto vectorOffsets = delinearize(sliceStrides, tupleIndex);
|
||||
auto elementOffsets =
|
||||
computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
|
||||
|
||||
// Add 'elementOffsets' to 'offsets' so that 'offsets' are now relative
|
||||
// to the 'extractSlicesOp' input vector type.
|
||||
assert(offsets.size() == elementOffsets.size());
|
||||
for (unsigned i = 0, e = offsets.size(); i < e; ++i)
|
||||
offsets[i] += elementOffsets[i];
|
||||
|
||||
// Clear 'tupleIndex' and update next defining 'op' to visit.
|
||||
tupleIndex = -1;
|
||||
op = extractSlicesOp.vector().getDefiningOp();
|
||||
} else if (auto insertSlicesOp = dyn_cast<vector::InsertSlicesOp>(op)) {
|
||||
assert(tupleIndex == -1);
|
||||
|
||||
// Compute slice strides for 'insertSlicesOp'.
|
||||
SmallVector<int64_t, 4> sizes;
|
||||
insertSlicesOp.getSizes(sizes);
|
||||
auto sliceStrides = computeStrides(
|
||||
insertSlicesOp.getResultVectorType().getShape(), sizes);
|
||||
|
||||
// Compute 'vectorOffsets' of 'insertSlicesOp' input vector slice,
|
||||
// of 'insertSlicesOp' result vector type at 'offsets'.
|
||||
SmallVector<int64_t, 4> vectorOffsets(offsets.size());
|
||||
assert(offsets.size() == sizes.size());
|
||||
for (unsigned i = 0, e = offsets.size(); i < e; ++i)
|
||||
vectorOffsets[i] = offsets[i] / sizes[i];
|
||||
|
||||
// Compute the source tuple element index.
|
||||
tupleIndex = linearize(vectorOffsets, sliceStrides);
|
||||
|
||||
// Subtract 'elementOffsets' from 'offsets' so that 'offsets' are now
|
||||
// relative to input tuple element vector type at 'tupleIndex'.
|
||||
auto elementOffsets =
|
||||
computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
|
||||
assert(offsets.size() == elementOffsets.size());
|
||||
for (unsigned i = 0, e = offsets.size(); i < e; ++i) {
|
||||
offsets[i] -= elementOffsets[i];
|
||||
assert(offsets[i] >= 0);
|
||||
}
|
||||
|
||||
// Update next defining 'op' to visit.
|
||||
op = insertSlicesOp.vectors().getDefiningOp();
|
||||
} else if (auto tupleOp = dyn_cast<vector::TupleOp>(op)) {
|
||||
assert(tupleIndex >= 0);
|
||||
|
||||
// Return tuple element 'value' at 'tupleIndex' if it matches type.
|
||||
auto value = tupleOp.getOperand(tupleIndex);
|
||||
if (value.getType() == consumerVectorType)
|
||||
return value;
|
||||
|
||||
// Update 'tupleIndex' and next defining 'op' to visit.
|
||||
tupleIndex = -1;
|
||||
op = value.getDefiningOp();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
|
||||
//
|
||||
// Example:
|
||||
@ -740,28 +824,11 @@ struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
|
||||
|
||||
LogicalResult matchAndRewrite(vector::TupleGetOp tupleGetOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Return if 'tupleGetOp.vectors' arg was not defined by ExtractSlicesOp.
|
||||
auto extractSlicesOp = dyn_cast_or_null<vector::ExtractSlicesOp>(
|
||||
tupleGetOp.vectors().getDefiningOp());
|
||||
if (!extractSlicesOp)
|
||||
return failure();
|
||||
|
||||
// Return if 'extractSlicesOp.vector' arg was not defined by InsertSlicesOp.
|
||||
auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(
|
||||
extractSlicesOp.vector().getDefiningOp());
|
||||
if (!insertSlicesOp)
|
||||
return failure();
|
||||
|
||||
// Return if 'insertSlicesOp.vectors' arg was not defined by TupleOp.
|
||||
auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
|
||||
insertSlicesOp.vectors().getDefiningOp());
|
||||
if (!tupleOp)
|
||||
return failure();
|
||||
|
||||
// Forward Value from 'tupleOp' at 'tupleGetOp.index'.
|
||||
Value tupleValue = tupleOp.getOperand(tupleGetOp.getIndex());
|
||||
rewriter.replaceOp(tupleGetOp, tupleValue);
|
||||
return success();
|
||||
if (auto producer = getProducerValue(tupleGetOp.getResult())) {
|
||||
rewriter.replaceOp(tupleGetOp, producer);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -28,10 +28,10 @@
|
||||
|
||||
using llvm::SetVector;
|
||||
|
||||
namespace mlir {
|
||||
using namespace mlir;
|
||||
|
||||
SmallVector<int64_t, 4> computeStrides(ArrayRef<int64_t> shape,
|
||||
ArrayRef<int64_t> sizes) {
|
||||
SmallVector<int64_t, 4> mlir::computeStrides(ArrayRef<int64_t> shape,
|
||||
ArrayRef<int64_t> sizes) {
|
||||
int64_t rank = shape.size();
|
||||
// Compute the count for each dimension.
|
||||
SmallVector<int64_t, 4> sliceDimCounts(rank);
|
||||
@ -45,8 +45,16 @@ SmallVector<int64_t, 4> computeStrides(ArrayRef<int64_t> shape,
|
||||
return sliceStrides;
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 4> delinearize(ArrayRef<int64_t> sliceStrides,
|
||||
int64_t index) {
|
||||
int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
|
||||
assert(offsets.size() == basis.size());
|
||||
int64_t linearIndex = 0;
|
||||
for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
|
||||
linearIndex += offsets[idx] * basis[idx];
|
||||
return linearIndex;
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 4> mlir::delinearize(ArrayRef<int64_t> sliceStrides,
|
||||
int64_t index) {
|
||||
int64_t rank = sliceStrides.size();
|
||||
SmallVector<int64_t, 4> vectorOffsets(rank);
|
||||
for (int64_t r = 0; r < rank; ++r) {
|
||||
@ -57,16 +65,15 @@ SmallVector<int64_t, 4> delinearize(ArrayRef<int64_t> sliceStrides,
|
||||
return vectorOffsets;
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 4>
|
||||
computeElementOffsetsFromVectorSliceOffsets(ArrayRef<int64_t> sizes,
|
||||
ArrayRef<int64_t> vectorOffsets) {
|
||||
SmallVector<int64_t, 4> mlir::computeElementOffsetsFromVectorSliceOffsets(
|
||||
ArrayRef<int64_t> sizes, ArrayRef<int64_t> vectorOffsets) {
|
||||
return functional::zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
|
||||
vectorOffsets, sizes);
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 4> computeSliceSizes(ArrayRef<int64_t> shape,
|
||||
ArrayRef<int64_t> sizes,
|
||||
ArrayRef<int64_t> elementOffsets) {
|
||||
SmallVector<int64_t, 4>
|
||||
mlir::computeSliceSizes(ArrayRef<int64_t> shape, ArrayRef<int64_t> sizes,
|
||||
ArrayRef<int64_t> elementOffsets) {
|
||||
int64_t rank = shape.size();
|
||||
SmallVector<int64_t, 4> sliceSizes(rank);
|
||||
for (unsigned r = 0; r < rank; ++r)
|
||||
@ -74,8 +81,8 @@ SmallVector<int64_t, 4> computeSliceSizes(ArrayRef<int64_t> shape,
|
||||
return sliceSizes;
|
||||
}
|
||||
|
||||
Optional<SmallVector<int64_t, 4>> shapeRatio(ArrayRef<int64_t> superShape,
|
||||
ArrayRef<int64_t> subShape) {
|
||||
Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(ArrayRef<int64_t> superShape,
|
||||
ArrayRef<int64_t> subShape) {
|
||||
if (superShape.size() < subShape.size()) {
|
||||
return Optional<SmallVector<int64_t, 4>>();
|
||||
}
|
||||
@ -114,8 +121,8 @@ Optional<SmallVector<int64_t, 4>> shapeRatio(ArrayRef<int64_t> superShape,
|
||||
return SmallVector<int64_t, 4>{result.rbegin(), result.rend()};
|
||||
}
|
||||
|
||||
Optional<SmallVector<int64_t, 4>> shapeRatio(VectorType superVectorType,
|
||||
VectorType subVectorType) {
|
||||
Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(VectorType superVectorType,
|
||||
VectorType subVectorType) {
|
||||
assert(superVectorType.getElementType() == subVectorType.getElementType() &&
|
||||
"vector types must be of the same elemental type");
|
||||
return shapeRatio(superVectorType.getShape(), subVectorType.getShape());
|
||||
@ -201,9 +208,9 @@ static SetVector<Operation *> getEnclosingforOps(Operation *op) {
|
||||
return getParentsOfType<AffineForOp>(op);
|
||||
}
|
||||
|
||||
AffineMap
|
||||
makePermutationMap(Operation *op, ArrayRef<Value> indices,
|
||||
const DenseMap<Operation *, unsigned> &loopToVectorDim) {
|
||||
AffineMap mlir::makePermutationMap(
|
||||
Operation *op, ArrayRef<Value> indices,
|
||||
const DenseMap<Operation *, unsigned> &loopToVectorDim) {
|
||||
DenseMap<Operation *, unsigned> enclosingLoopToVectorDim;
|
||||
auto enclosingLoops = getEnclosingforOps(op);
|
||||
for (auto *forInst : enclosingLoops) {
|
||||
@ -212,7 +219,7 @@ makePermutationMap(Operation *op, ArrayRef<Value> indices,
|
||||
enclosingLoopToVectorDim.insert(*it);
|
||||
}
|
||||
}
|
||||
return makePermutationMap(indices, enclosingLoopToVectorDim);
|
||||
return ::makePermutationMap(indices, enclosingLoopToVectorDim);
|
||||
}
|
||||
|
||||
bool matcher::operatesOnSuperVectorsOf(Operation &op,
|
||||
@ -275,4 +282,3 @@ bool matcher::operatesOnSuperVectorsOf(Operation &op,
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
@ -313,6 +313,95 @@ func @tuple_get(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
|
||||
return %1 : vector<8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @tuple_get_producer_consumer
|
||||
// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>,
|
||||
// CHECK-SAME: %[[A1:.*1]]: vector<2x4xf32>,
|
||||
// CHECK-SAME: %[[A2:.*2]]: vector<2x4xf32>,
|
||||
// CHECK-SAME: %[[A3:.*3]]: vector<2x4xf32>,
|
||||
// CHECK-SAME: %[[A4:.*4]]: vector<2x4xf32>,
|
||||
// CHECK-SAME: %[[A5:.*5]]: vector<2x4xf32>,
|
||||
// CHECK-SAME: %[[A6:.*6]]: vector<2x4xf32>,
|
||||
// CHECK-SAME: %[[A7:.*7]]: vector<2x4xf32>
|
||||
// CHECK: return %[[A7]] : vector<2x4xf32>
|
||||
|
||||
func @tuple_get_producer_consumer(
|
||||
%arg0 : vector<2x4xf32>, %arg1 : vector<2x4xf32>,
|
||||
%arg2 : vector<2x4xf32>, %arg3 : vector<2x4xf32>,
|
||||
%arg4 : vector<2x4xf32>, %arg5 : vector<2x4xf32>,
|
||||
%arg6 : vector<2x4xf32>, %arg7 : vector<2x4xf32>) -> vector<2x4xf32> {
|
||||
%0 = vector.tuple %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7
|
||||
: vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
|
||||
vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>
|
||||
// %arg7 == %0 at tupleIndex = 7, offsets = [0, 0]
|
||||
%1 = vector.insert_slices %0, [2, 4], [1, 1]
|
||||
: tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
|
||||
vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
|
||||
into vector<4x16xf32>
|
||||
// %arg7 == %1 at tupleIndex = -1, offsets = [2, 12]
|
||||
%2 = vector.extract_slices %1, [4, 8], [1, 1]
|
||||
: vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
|
||||
// %arg7 == %2 at tupleIndex = 1, offsets = [2, 4]
|
||||
%3 = vector.tuple_get %2, 1 : tuple<vector<4x8xf32>, vector<4x8xf32>>
|
||||
// %arg7 == %3 at tupleIndex = -1, offsets = [2, 4]
|
||||
%4 = vector.extract_slices %3, [2, 4], [1, 1]
|
||||
: vector<4x8xf32> into
|
||||
tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
|
||||
// %arg7 == %4 at tupleIndex = 3, offsets = [0, 0]
|
||||
%5 = vector.tuple_get %4, 3
|
||||
: tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
|
||||
// %arg7 == %5
|
||||
return %5 : vector<2x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @tuple_get_producer_consumer_swizzle
|
||||
// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>,
|
||||
// CHECK-SAME: %[[A1:.*1]]: vector<2x4xf32>,
|
||||
// CHECK-SAME: %[[A2:.*2]]: vector<2x4xf32>,
|
||||
// CHECK-SAME: %[[A3:.*3]]: vector<2x4xf32>,
|
||||
// CHECK-SAME: %[[A4:.*4]]: vector<2x4xf32>,
|
||||
// CHECK-SAME: %[[A5:.*5]]: vector<2x4xf32>,
|
||||
// CHECK-SAME: %[[A6:.*6]]: vector<2x4xf32>,
|
||||
// CHECK-SAME: %[[A7:.*7]]: vector<2x4xf32>
|
||||
// CHECK: return %[[A7]] : vector<2x4xf32>
|
||||
|
||||
func @tuple_get_producer_consumer_swizzle(
|
||||
%arg0 : vector<2x4xf32>, %arg1 : vector<2x4xf32>,
|
||||
%arg2 : vector<2x4xf32>, %arg3 : vector<2x4xf32>,
|
||||
%arg4 : vector<2x4xf32>, %arg5 : vector<2x4xf32>,
|
||||
%arg6 : vector<2x4xf32>, %arg7 : vector<2x4xf32>) -> vector<2x4xf32> {
|
||||
%0 = vector.tuple %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7
|
||||
: vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
|
||||
vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>
|
||||
// %arg7 == %0 at tupleIndex = 7, offsets = [0, 0]
|
||||
%1 = vector.insert_slices %0, [2, 4], [1, 1]
|
||||
: tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
|
||||
vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
|
||||
into vector<4x16xf32>
|
||||
// %arg7 == %1 at tupleIndex = -1, offsets = [2, 12]
|
||||
%2 = vector.extract_slices %1, [4, 8], [1, 1]
|
||||
: vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
|
||||
// %arg7 == %2 at tupleIndex = 1, offsets = [2, 4]
|
||||
|
||||
// Extract tuple elements.
|
||||
%3 = vector.tuple_get %2, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
|
||||
%4 = vector.tuple_get %2, 1 : tuple<vector<4x8xf32>, vector<4x8xf32>>
|
||||
// %arg7 == %4 at tupleIndex = -1, offsets = [2, 4]
|
||||
|
||||
// Swizzle tuple elements.
|
||||
%5 = vector.tuple %4, %3 : vector<4x8xf32>, vector<4x8xf32>
|
||||
// %arg7 == %5 at tupleIndex = 0, offsets = [2, 4]
|
||||
%6 = vector.tuple_get %5, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
|
||||
// %arg7 == %6 at tupleIndex = -1, offsets = [2, 4]
|
||||
%7 = vector.extract_slices %6, [2, 4], [1, 1]
|
||||
: vector<4x8xf32> into
|
||||
tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
|
||||
// %arg7 == %7 at tupleIndex = 3, offsets = [0, 0]
|
||||
%8 = vector.tuple_get %7, 3
|
||||
: tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
|
||||
// %arg7 == %8
|
||||
return %8 : vector<2x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @vector_transfers_vector_element_type
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
|
Loading…
x
Reference in New Issue
Block a user