mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-04-01 12:43:47 +00:00
[mlir][TilingInterface] Make the tiling set tile sizes function use OpFoldResult. (#66566)
This commit is contained in:
parent
75fdf2e7b6
commit
170a25a793
@ -26,7 +26,7 @@ namespace mlir {
|
||||
namespace scf {
|
||||
|
||||
using SCFTileSizeComputationFunction =
|
||||
std::function<SmallVector<Value>(OpBuilder &, Operation *)>;
|
||||
std::function<SmallVector<OpFoldResult>(OpBuilder &, Operation *)>;
|
||||
|
||||
/// Options to use to control tiling.
|
||||
struct SCFTilingOptions {
|
||||
@ -40,17 +40,10 @@ struct SCFTilingOptions {
|
||||
tileSizeComputationFunction = std::move(fun);
|
||||
return *this;
|
||||
}
|
||||
/// Set the `tileSizeComputationFunction` to return the values `ts`. The
|
||||
/// values must not fold away when tiling. Otherwise, use a more robust
|
||||
/// `tileSizeComputationFunction`.
|
||||
SCFTilingOptions &setTileSizes(const SmallVector<Value, 4> &ts) {
|
||||
tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
|
||||
return *this;
|
||||
}
|
||||
/// Convenience function to set the `tileSizeComputationFunction` to a
|
||||
/// function that computes tile sizes at the point they are needed. Allows
|
||||
/// proper interaction with folding.
|
||||
SCFTilingOptions &setTileSizes(ArrayRef<int64_t> ts);
|
||||
SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> ts);
|
||||
|
||||
/// The interchange vector to reorder the tiled loops.
|
||||
SmallVector<int64_t> interchangeVector = {};
|
||||
|
@ -473,7 +473,9 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
|
||||
|
||||
scf::SCFTilingOptions tilingOptions;
|
||||
tilingOptions.interchangeVector = tileInterchange;
|
||||
tilingOptions = tilingOptions.setTileSizes(tileSizes);
|
||||
SmallVector<OpFoldResult> tileSizesOfr =
|
||||
getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
|
||||
tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
|
||||
scf::SCFTileAndFuseOptions tileAndFuseOptions;
|
||||
tileAndFuseOptions.tilingOptions = tilingOptions;
|
||||
LogicalResult result = applyTilingToAll(
|
||||
@ -923,7 +925,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
|
||||
auto nextProducer = getNextProducer();
|
||||
if (failed(nextProducer)) {
|
||||
auto diag = mlir::emitSilenceableFailure(getLoc())
|
||||
<< "could not find next producer to fuse into container";
|
||||
<< "could not find next producer to fuse into container";
|
||||
diag.attachNote(containingOp->getLoc()) << "containing op";
|
||||
return diag;
|
||||
}
|
||||
@ -1999,7 +2001,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
|
||||
transform::TransformState &state) {
|
||||
scf::SCFTilingOptions tilingOptions;
|
||||
tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
|
||||
SmallVector<Value, 4> tileSizes;
|
||||
SmallVector<OpFoldResult> tileSizes;
|
||||
Location loc = target.getLoc();
|
||||
SmallVector<OpFoldResult> allShapeSizes =
|
||||
target.createFlatListOfOperandDims(b, loc);
|
||||
@ -2012,9 +2014,8 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
|
||||
// If the shape size is dynamic, tile by 1.
|
||||
// Otherwise, do not tile (i.e. tile size 0).
|
||||
for (OpFoldResult shapeSize : shapeSizes) {
|
||||
tileSizes.push_back(getConstantIntValue(shapeSize)
|
||||
? b.create<arith::ConstantIndexOp>(loc, 0)
|
||||
: b.create<arith::ConstantIndexOp>(loc, 1));
|
||||
tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
|
||||
: b.getIndexAttr(1));
|
||||
}
|
||||
return tileSizes;
|
||||
});
|
||||
@ -2549,7 +2550,7 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
|
||||
if (!tileSizes.empty()) {
|
||||
tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
|
||||
Operation *) {
|
||||
SmallVector<Value, 4> sizes;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
sizes.reserve(tileSizes.size());
|
||||
unsigned dynamicIdx = 0;
|
||||
|
||||
@ -2560,10 +2561,10 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
|
||||
getLoc(), attr.cast<IntegerAttr>().getInt());
|
||||
Value vscale =
|
||||
b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
|
||||
sizes.push_back(b.create<arith::MulIOp>(getLoc(), val, vscale));
|
||||
sizes.push_back(
|
||||
b.create<arith::MulIOp>(getLoc(), val, vscale).getResult());
|
||||
} else {
|
||||
sizes.push_back(b.create<arith::ConstantIndexOp>(
|
||||
getLoc(), cast<IntegerAttr>(attr).getInt()));
|
||||
sizes.push_back(attr);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
@ -2573,8 +2574,7 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
|
||||
assert((dynamicSizes.empty() ^ params.empty()) &&
|
||||
"expected either dynamic sizes or parameters");
|
||||
if (!params.empty()) {
|
||||
sizes.push_back(
|
||||
b.create<arith::ConstantIndexOp>(getLoc(), params[index]));
|
||||
sizes.push_back(b.getIndexAttr(params[index]));
|
||||
} else {
|
||||
sizes.push_back(dynamicSizes[index]->getResult(0));
|
||||
}
|
||||
|
@ -31,19 +31,11 @@
|
||||
using namespace mlir;
|
||||
|
||||
scf::SCFTilingOptions &
|
||||
scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
|
||||
scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
|
||||
assert(!tileSizeComputationFunction && "tile sizes already set");
|
||||
SmallVector<int64_t> tileSizes(ts.begin(), ts.end());
|
||||
auto tileSizes = llvm::to_vector(ts);
|
||||
tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
|
||||
OpBuilder::InsertionGuard guard(b);
|
||||
b.setInsertionPointToStart(
|
||||
&op->getParentWithTrait<OpTrait::IsIsolatedFromAbove>()
|
||||
->getRegion(0)
|
||||
.front());
|
||||
return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
|
||||
Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
|
||||
return v;
|
||||
}));
|
||||
return tileSizes;
|
||||
};
|
||||
return *this;
|
||||
}
|
||||
@ -108,17 +100,16 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
|
||||
|
||||
/// Generate an empty loop nest that represents the tiled loop nest shell.
|
||||
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
|
||||
/// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
|
||||
/// - `tileSizes` is the tile sizes to use. Zero represents untiled loops.
|
||||
/// - In `offsets` and `sizes` return the multi-dimensional offset and size of
|
||||
/// the
|
||||
/// tile processed within the inner most loop.
|
||||
static SmallVector<scf::ForOp>
|
||||
generateTileLoopNest(OpBuilder &builder, Location loc,
|
||||
ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
|
||||
SmallVector<OpFoldResult> &offsets,
|
||||
SmallVector<OpFoldResult> &sizes) {
|
||||
static SmallVector<scf::ForOp> generateTileLoopNest(
|
||||
OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges,
|
||||
ArrayRef<OpFoldResult> tileSizes, SmallVector<OpFoldResult> &offsets,
|
||||
SmallVector<OpFoldResult> &sizes) {
|
||||
assert(!loopRanges.empty() && "expected at least one loop range");
|
||||
assert(loopRanges.size() == tileSizeVals.size() &&
|
||||
assert(loopRanges.size() == tileSizes.size() &&
|
||||
"expected as many tile sizes as loop ranges");
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
SmallVector<scf::ForOp> loops;
|
||||
@ -130,7 +121,8 @@ generateTileLoopNest(OpBuilder &builder, Location loc,
|
||||
getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
|
||||
Value size =
|
||||
getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
|
||||
Value tileSize = tileSizeVals[loopRange.index()];
|
||||
Value tileSize = getValueOrCreateConstantIndexOp(
|
||||
builder, loc, tileSizes[loopRange.index()]);
|
||||
// No loops if tile size is zero. Set offset and size to the loop
|
||||
// offset and size.
|
||||
if (matchPattern(tileSize, m_Zero())) {
|
||||
@ -296,10 +288,10 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
|
||||
// skips tiling a particular dimension. This convention is significantly
|
||||
// simpler to handle instead of adjusting affine maps to account for missing
|
||||
// dimensions.
|
||||
SmallVector<Value> tileSizeVector =
|
||||
SmallVector<OpFoldResult> tileSizeVector =
|
||||
options.tileSizeComputationFunction(rewriter, op);
|
||||
if (tileSizeVector.size() < iterationDomain.size()) {
|
||||
auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
|
||||
auto zero = rewriter.getIndexAttr(0);
|
||||
tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
|
||||
}
|
||||
|
||||
@ -402,17 +394,17 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
|
||||
FailureOr<scf::SCFReductionTilingResult>
|
||||
mlir::scf::tileReductionUsingScf(RewriterBase &b,
|
||||
PartialReductionOpInterface op,
|
||||
ArrayRef<OpFoldResult> tileSize) {
|
||||
ArrayRef<OpFoldResult> tileSizes) {
|
||||
Location loc = op.getLoc();
|
||||
// Ops implementing PartialReductionOpInterface are expected to implement
|
||||
// TilingInterface.
|
||||
auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
|
||||
SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
|
||||
SmallVector<Value> tileSizeVector =
|
||||
getValueOrCreateConstantIndexOp(b, loc, tileSize);
|
||||
if (tileSizeVector.size() < iterationDomain.size()) {
|
||||
auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
|
||||
tileSizeVector.append(iterationDomain.size() - tileSizeVector.size(), zero);
|
||||
auto tileSizesVector = llvm::to_vector(tileSizes);
|
||||
if (tileSizesVector.size() < iterationDomain.size()) {
|
||||
auto zero = b.getIndexAttr(0);
|
||||
tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
|
||||
zero);
|
||||
}
|
||||
if (op->getNumResults() != 1)
|
||||
return b.notifyMatchFailure(
|
||||
@ -429,7 +421,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
|
||||
|
||||
// 1. Create the initial tensor value.
|
||||
FailureOr<Operation *> identityTensor =
|
||||
op.generateInitialTensorForPartialReduction(b, loc, tileSize,
|
||||
op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
|
||||
reductionDims);
|
||||
if (failed(identityTensor))
|
||||
return b.notifyMatchFailure(op,
|
||||
@ -437,7 +429,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
|
||||
// 2. Create the nested loops.
|
||||
SmallVector<OpFoldResult> offsets, sizes;
|
||||
SmallVector<scf::ForOp> loops = generateTileLoopNest(
|
||||
b, loc, iterationDomain, tileSizeVector, offsets, sizes);
|
||||
b, loc, iterationDomain, tileSizesVector, offsets, sizes);
|
||||
|
||||
// 3. Generate the tiled implementation within the inner most loop.
|
||||
b.setInsertionPoint(loops.back().getBody()->getTerminator());
|
||||
|
@ -190,16 +190,16 @@ transform.sequence failures(propagate) {
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @scalable_and_fixed_length_tile
|
||||
// CHECK: %[[STEP_0:.*]] = arith.constant 4 : index
|
||||
// CHECK: %[[STEP_1:.*]] = arith.constant 4 : index
|
||||
// CHECK: %[[C4:.*]] = arith.constant 4 : index
|
||||
// CHECK: %[[VS:.*]] = vector.vscale
|
||||
// CHECK: %[[STEP_2:.*]] = arith.muli %[[C4]], %[[VS]] : index
|
||||
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[C128:.*]] = arith.constant 128 : index
|
||||
// CHECK: %[[STEP_0:.*]] = arith.constant 4 : index
|
||||
// CHECK: scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C128]] step %[[STEP_0]]
|
||||
// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[C128_1:.*]] = arith.constant 128 : index
|
||||
// CHECK: %[[STEP_1:.*]] = arith.constant 4 : index
|
||||
// CHECK: scf.for %[[VAL_16:.*]] = %[[C0_1]] to %[[C128_1]] step %[[STEP_1]]
|
||||
// CHECK: %[[C0_2:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[C128_2:.*]] = arith.constant 128 : index
|
||||
|
@ -450,7 +450,9 @@ static void addPatternForTiling(MLIRContext *context,
|
||||
ArrayRef<int64_t> tileSizes,
|
||||
ArrayRef<int64_t> interchange = {}) {
|
||||
scf::SCFTilingOptions tilingOptions;
|
||||
tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
|
||||
SmallVector<OpFoldResult> tileSizesOfr =
|
||||
getAsIndexOpFoldResult(context, tileSizes);
|
||||
tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
|
||||
LinalgTransformationFilter filter(StringAttr::get(context, filterName),
|
||||
StringAttr::get(context, "tiled"));
|
||||
patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
|
||||
@ -462,7 +464,9 @@ static void addPatternForTileFuseAndYield(MLIRContext *context,
|
||||
ArrayRef<int64_t> tileSizes,
|
||||
ArrayRef<int64_t> interchange = {}) {
|
||||
scf::SCFTilingOptions tilingOptions;
|
||||
tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
|
||||
SmallVector<OpFoldResult> tileSizesOfr =
|
||||
getAsIndexOpFoldResult(context, tileSizes);
|
||||
tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
|
||||
LinalgTransformationFilter filter(StringAttr::get(context, filterName),
|
||||
StringAttr::get(context, "tiled"));
|
||||
patterns.add<TestTileConsumerFuseAndYieldProducerUsingSCFForOp>(
|
||||
@ -475,8 +479,10 @@ static void addPatternForTileAndFuse(MLIRContext *context,
|
||||
ArrayRef<int64_t> tileSizes,
|
||||
ArrayRef<int64_t> interchange = {}) {
|
||||
scf::SCFTileAndFuseOptions tileAndFuseOptions;
|
||||
tileAndFuseOptions.tilingOptions.setTileSizes(tileSizes).setInterchange(
|
||||
interchange);
|
||||
SmallVector<OpFoldResult> tileSizesOfr =
|
||||
getAsIndexOpFoldResult(context, tileSizes);
|
||||
tileAndFuseOptions.tilingOptions.setTileSizes(tileSizesOfr)
|
||||
.setInterchange(interchange);
|
||||
LinalgTransformationFilter filter(StringAttr::get(context, filterName),
|
||||
StringAttr::get(context, "tiled"));
|
||||
patterns.add<TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp>(
|
||||
|
Loading…
x
Reference in New Issue
Block a user