mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-01-19 14:44:39 +00:00
[mlir][transform] Add support for expressing scalable tile sizes
This patch enables specifying scalable tile sizes when using the Transform dialect to drive tiling, e.g.: ``` %1, %loop = transform.structured.tile %0 [[4]] ``` This is implemented by extending the TileOp with a dedicated attribute for "scalability" and by updating various parsing hooks. At the moment, only the trailing tile size can be scalable. The following is not yet supported: ``` %1, %loop = transform.structured.tile %0 [[4], [4]] ``` This change is a part of larger effort to enable scalable vectorisation in Linalg. See this RFC for more context: * https://discourse.llvm.org/t/rfc-scalable-vectorisation-in-linalg/ Differential Revision: https://reviews.llvm.org/D150944
This commit is contained in:
parent
cd888e6ffe
commit
a5b3677ddc
@ -1528,7 +1528,8 @@ def TileOp : Op<Transform_Dialect, "structured.tile",
|
||||
let arguments = (ins TransformHandleTypeInterface:$target,
|
||||
Variadic<TransformParamTypeOrAnyHandle>:$dynamic_sizes,
|
||||
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sizes,
|
||||
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange);
|
||||
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange,
|
||||
DefaultValuedOptionalAttr<BoolAttr, "false">:$last_tile_size_scalable);
|
||||
let results = (outs TransformHandleTypeInterface:$tiled_linalg_op,
|
||||
Variadic<TransformHandleTypeInterface>:$loops);
|
||||
let builders = [
|
||||
|
@ -72,17 +72,27 @@ void printDynamicIndexList(
|
||||
/// 1. `result` is filled with the i64 ArrayAttr "[`kDynamic`, 7, 42,
|
||||
/// `kDynamic`]"
|
||||
/// 2. `ssa` is filled with "[%arg0, %arg1]".
|
||||
///
|
||||
/// Trailing indices can be scalable. For example, "42" in "[7, [42]]" is
|
||||
/// scalable. This notation is similar to how scalable dims are marked when
|
||||
/// defining Vectors. If /p isTrailingIdxScalable is null, scalable indices are
|
||||
/// not allowed/expected. When it's not null, this hook will set the
|
||||
/// corresponding value to:
|
||||
/// * true if the trailing idx is scalable,
|
||||
/// * false otherwise.
|
||||
ParseResult parseDynamicIndexList(
|
||||
OpAsmParser &parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
|
||||
DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
|
||||
DenseI64ArrayAttr &integers, bool *isTrailingIdxScalable = nullptr,
|
||||
SmallVectorImpl<Type> *valueTypes = nullptr,
|
||||
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
|
||||
inline ParseResult parseDynamicIndexList(
|
||||
OpAsmParser &parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
|
||||
DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes,
|
||||
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
|
||||
return parseDynamicIndexList(parser, values, integers, &valueTypes,
|
||||
return parseDynamicIndexList(parser, values, integers,
|
||||
/*isTrailingIdxScalable=*/nullptr, &valueTypes,
|
||||
delimiter);
|
||||
}
|
||||
|
||||
|
@ -2391,6 +2391,7 @@ transform::TileOp::apply(TransformResults &transformResults,
|
||||
SmallVector<Operation *> tiled;
|
||||
SmallVector<SmallVector<Operation *, 4>, 4> loops;
|
||||
loops.resize(getLoops().size());
|
||||
bool scalable = getLastTileSizeScalable();
|
||||
for (auto [i, op] : llvm::enumerate(targets)) {
|
||||
auto tilingInterface = dyn_cast<TilingInterface>(op);
|
||||
auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(op);
|
||||
@ -2409,10 +2410,21 @@ transform::TileOp::apply(TransformResults &transformResults,
|
||||
SmallVector<Value, 4> sizes;
|
||||
sizes.reserve(tileSizes.size());
|
||||
unsigned dynamicIdx = 0;
|
||||
for (OpFoldResult ofr : getMixedSizes()) {
|
||||
unsigned trailingIdx = getMixedSizes().size() - 1;
|
||||
|
||||
for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
|
||||
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
|
||||
sizes.push_back(b.create<arith::ConstantIndexOp>(
|
||||
getLoc(), cast<IntegerAttr>(attr).getInt()));
|
||||
// Only the trailing tile size is allowed to be scalable atm.
|
||||
if (scalable && (ofrIdx == trailingIdx)) {
|
||||
auto val = b.create<arith::ConstantIndexOp>(
|
||||
getLoc(), attr.cast<IntegerAttr>().getInt());
|
||||
Value vscale =
|
||||
b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
|
||||
sizes.push_back(b.create<arith::MulIOp>(getLoc(), val, vscale));
|
||||
} else {
|
||||
sizes.push_back(b.create<arith::ConstantIndexOp>(
|
||||
getLoc(), cast<IntegerAttr>(attr).getInt()));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
|
||||
@ -2507,8 +2519,9 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
|
||||
DenseI64ArrayAttr staticSizes;
|
||||
FunctionType functionalType;
|
||||
llvm::SMLoc operandLoc;
|
||||
bool scalable = false;
|
||||
if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) ||
|
||||
parseDynamicIndexList(parser, dynamicSizes, staticSizes) ||
|
||||
parseDynamicIndexList(parser, dynamicSizes, staticSizes, &scalable) ||
|
||||
parseOptionalInterchange(parser, result) ||
|
||||
parser.parseColonType(functionalType))
|
||||
return ParseResult::failure();
|
||||
@ -2531,6 +2544,10 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto scalableAttr = parser.getBuilder().getBoolAttr(scalable);
|
||||
result.addAttribute(getLastTileSizeScalableAttrName(result.name),
|
||||
scalableAttr);
|
||||
|
||||
result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
|
||||
result.addTypes(functionalType.getResults());
|
||||
return success();
|
||||
|
@ -1261,9 +1261,9 @@ ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
dynamicSteps;
|
||||
if (succeeded(parser.parseOptionalKeyword("in"))) {
|
||||
// Parse upper bounds.
|
||||
if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
|
||||
/*valueTypes=*/nullptr,
|
||||
OpAsmParser::Delimiter::Paren) ||
|
||||
if (parseDynamicIndexList(
|
||||
parser, dynamicUbs, staticUbs, /*scalable=*/nullptr,
|
||||
/*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
|
||||
parser.resolveOperands(dynamicUbs, indexType, result.operands))
|
||||
return failure();
|
||||
|
||||
@ -1273,26 +1273,26 @@ ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
} else {
|
||||
// Parse lower bounds.
|
||||
if (parser.parseEqual() ||
|
||||
parseDynamicIndexList(parser, dynamicLbs, staticLbs,
|
||||
/*valueTypes=*/nullptr,
|
||||
OpAsmParser::Delimiter::Paren) ||
|
||||
parseDynamicIndexList(
|
||||
parser, dynamicLbs, staticLbs, /*scalable=*/nullptr,
|
||||
/*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
|
||||
|
||||
parser.resolveOperands(dynamicLbs, indexType, result.operands))
|
||||
return failure();
|
||||
|
||||
// Parse upper bounds.
|
||||
if (parser.parseKeyword("to") ||
|
||||
parseDynamicIndexList(parser, dynamicUbs, staticUbs,
|
||||
/*valueTypes=*/nullptr,
|
||||
OpAsmParser::Delimiter::Paren) ||
|
||||
parseDynamicIndexList(
|
||||
parser, dynamicUbs, staticUbs, /*scalable=*/nullptr,
|
||||
/*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
|
||||
parser.resolveOperands(dynamicUbs, indexType, result.operands))
|
||||
return failure();
|
||||
|
||||
// Parse step values.
|
||||
if (parser.parseKeyword("step") ||
|
||||
parseDynamicIndexList(parser, dynamicSteps, staticSteps,
|
||||
/*valueTypes=*/nullptr,
|
||||
OpAsmParser::Delimiter::Paren) ||
|
||||
parseDynamicIndexList(
|
||||
parser, dynamicSteps, staticSteps, /*scalable=*/nullptr,
|
||||
/*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
|
||||
parser.resolveOperands(dynamicSteps, indexType, result.operands))
|
||||
return failure();
|
||||
}
|
||||
|
@ -42,5 +42,6 @@ ParseResult mlir::transform::parsePackedOrDynamicIndexList(
|
||||
return success();
|
||||
}
|
||||
|
||||
return parseDynamicIndexList(parser, values, integers, &valueTypes);
|
||||
return parseDynamicIndexList(parser, values, integers, /*scalable=*/nullptr,
|
||||
&valueTypes);
|
||||
}
|
||||
|
@ -128,13 +128,26 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
|
||||
ParseResult mlir::parseDynamicIndexList(
|
||||
OpAsmParser &parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
|
||||
DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes,
|
||||
AsmParser::Delimiter delimiter) {
|
||||
DenseI64ArrayAttr &integers, bool *isTrailingIdxScalable,
|
||||
SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
|
||||
|
||||
SmallVector<int64_t, 4> integerVals;
|
||||
bool foundScalable = false;
|
||||
auto parseIntegerOrValue = [&]() {
|
||||
OpAsmParser::UnresolvedOperand operand;
|
||||
auto res = parser.parseOptionalOperand(operand);
|
||||
|
||||
// If `foundScalable` has already been set to `true` then a non-trailing
|
||||
// tile size was identified as scalable.
|
||||
if (foundScalable) {
|
||||
parser.emitError(parser.getNameLoc())
|
||||
<< "non-trailing tile size cannot be scalable";
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (isTrailingIdxScalable && parser.parseOptionalLSquare().succeeded())
|
||||
foundScalable = true;
|
||||
|
||||
if (res.has_value() && succeeded(res.value())) {
|
||||
values.push_back(operand);
|
||||
integerVals.push_back(ShapedType::kDynamic);
|
||||
@ -146,6 +159,8 @@ ParseResult mlir::parseDynamicIndexList(
|
||||
return failure();
|
||||
integerVals.push_back(integer);
|
||||
}
|
||||
if (foundScalable && parser.parseOptionalRSquare().failed())
|
||||
return failure();
|
||||
return success();
|
||||
};
|
||||
if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue,
|
||||
@ -153,6 +168,8 @@ ParseResult mlir::parseDynamicIndexList(
|
||||
return parser.emitError(parser.getNameLoc())
|
||||
<< "expected SSA value or integer";
|
||||
integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
|
||||
if (isTrailingIdxScalable)
|
||||
*isTrailingIdxScalable = foundScalable;
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
|
||||
// RUN: mlir-opt --test-transform-dialect-interpreter --mlir-print-local-scope --split-input-file --verify-diagnostics %s | FileCheck %s
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg1: !transform.any_op):
|
||||
@ -149,3 +149,96 @@ transform.sequence failures(propagate) {
|
||||
transform.structured.tile_to_forall_op %0 tile_sizes[1, 1]
|
||||
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#map = affine_map<(d0) -> (d0)>
|
||||
|
||||
module {
|
||||
func.func @scalable_tile(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: f32) -> tensor<?xf32> {
|
||||
%0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>) outs(%arg2 : tensor<?xf32>) {
|
||||
^bb0(%in_1: f32, %in_2: f32, %out: f32):
|
||||
%1 = arith.addf %in_1, %in_2 : f32
|
||||
%2 = arith.mulf %arg3, %1 : f32
|
||||
linalg.yield %2 : f32
|
||||
} -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @scalable_tile(
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: tensor<?xf32>, %[[ARG_1:.*]]: tensor<?xf32>, %[[ARG_2:.*]]: tensor<?xf32>,
|
||||
// CHECK: %[[C4:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG_0]], %[[C4]] : tensor<?xf32>
|
||||
// CHECK: %[[VEC_SIZE:.*]] = arith.constant 4 : index
|
||||
// CHECK: %[[VS:.*]] = vector.vscale
|
||||
// CHECK: %[[STEP:.*]] = arith.muli %[[VEC_SIZE]], %[[VS]] : index
|
||||
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[DIM]] step %[[STEP]] iter_args(%[[VAL:.*]] = %[[ARG_2]]) -> (tensor<?xf32>) {
|
||||
// CHECK: %[[SIZE:.*]] = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%[[IV]])[%[[STEP]], %[[DIM]]]
|
||||
// CHECK: %[[SLICE_ARG0:.*]] = tensor.extract_slice %[[ARG_0]][%[[IV]]] [%[[SIZE]]] [1] : tensor<?xf32> to tensor<?xf32>
|
||||
// CHECK: %[[SLICE_ARG1:.*]] = tensor.extract_slice %[[ARG_1]][%[[IV]]] [%[[SIZE]]] [1] : tensor<?xf32> to tensor<?xf32>
|
||||
// CHECK: %[[SLICE_ARG2:.*]] = tensor.extract_slice %[[VAL]][%[[IV]]] [%[[SIZE]]] [1] : tensor<?xf32> to tensor<?xf32>
|
||||
// CHECK: linalg.generic {indexing_maps = {{.*}}, iterator_types = ["parallel"]} ins(%[[SLICE_ARG0]], %[[SLICE_ARG1]] : tensor<?xf32>, tensor<?xf32>) outs(%[[SLICE_ARG2]] : tensor<?xf32>) {
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg1: !transform.any_op):
|
||||
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
%1, %loop = transform.structured.tile %0 [[4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @scalable_and_fixed_length_tile
|
||||
// CHECK: %[[STEP_0:.*]] = arith.constant 4 : index
|
||||
// CHECK: %[[STEP_1:.*]] = arith.constant 4 : index
|
||||
// CHECK: %[[C4:.*]] = arith.constant 4 : index
|
||||
// CHECK: %[[VS:.*]] = vector.vscale
|
||||
// CHECK: %[[STEP_2:.*]] = arith.muli %[[C4]], %[[VS]] : index
|
||||
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[C128:.*]] = arith.constant 128 : index
|
||||
// CHECK: scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C128]] step %[[STEP_0]]
|
||||
// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[C128_1:.*]] = arith.constant 128 : index
|
||||
// CHECK: scf.for %[[VAL_16:.*]] = %[[C0_1]] to %[[C128_1]] step %[[STEP_1]]
|
||||
// CHECK: %[[C0_2:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[C128_2:.*]] = arith.constant 128 : index
|
||||
// CHECK: scf.for %{{.*}} = %[[C0_2]] to %[[C128_2]] step %[[STEP_2]]
|
||||
|
||||
func.func @scalable_and_fixed_length_tile(
|
||||
%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
|
||||
-> tensor<128x128xf32> {
|
||||
%0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
|
||||
outs(%arg2: tensor<128x128xf32>)
|
||||
-> tensor<128x128xf32>
|
||||
|
||||
return %0 : tensor<128x128xf32>
|
||||
}
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg1: !transform.any_op):
|
||||
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
%1, %loops:3 = transform.structured.tile %0 [4, 4, [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// TODO: Add support for for specyfying more than one scalable tile size
|
||||
|
||||
func.func @scalable_and_fixed_length_tile(
|
||||
%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
|
||||
-> tensor<128x128xf32> {
|
||||
%0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
|
||||
outs(%arg2: tensor<128x128xf32>)
|
||||
-> tensor<128x128xf32>
|
||||
|
||||
return %0 : tensor<128x128xf32>
|
||||
}
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg1: !transform.any_op):
|
||||
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
// expected-error @below {{non-trailing tile size cannot be scalable}}
|
||||
// expected-error @below {{expected SSA value or integer}}
|
||||
%1, %loops:3 = transform.structured.tile %0 [4, [4], [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user