mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-01-31 22:25:56 +00:00
[mlir][transform] Add support for expressing scalable tile sizes
This patch enables specifying scalable tile sizes when using the Transform dialect to drive tiling, e.g.: ``` %1, %loop = transform.structured.tile %0 [[4]] ``` This is implemented by extending the TileOp with a dedicated attribute for "scalability" and by updating various parsing hooks. At the moment, only the trailing tile size can be scalable. The following is not yet supported: ``` %1, %loop = transform.structured.tile %0 [[4], [4]] ``` This change is a part of larger effort to enable scalable vectorisation in Linalg. See this RFC for more context: * https://discourse.llvm.org/t/rfc-scalable-vectorisation-in-linalg/ Differential Revision: https://reviews.llvm.org/D150944
This commit is contained in:
parent
cd888e6ffe
commit
a5b3677ddc
@ -1528,7 +1528,8 @@ def TileOp : Op<Transform_Dialect, "structured.tile",
|
|||||||
let arguments = (ins TransformHandleTypeInterface:$target,
|
let arguments = (ins TransformHandleTypeInterface:$target,
|
||||||
Variadic<TransformParamTypeOrAnyHandle>:$dynamic_sizes,
|
Variadic<TransformParamTypeOrAnyHandle>:$dynamic_sizes,
|
||||||
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sizes,
|
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sizes,
|
||||||
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange);
|
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange,
|
||||||
|
DefaultValuedOptionalAttr<BoolAttr, "false">:$last_tile_size_scalable);
|
||||||
let results = (outs TransformHandleTypeInterface:$tiled_linalg_op,
|
let results = (outs TransformHandleTypeInterface:$tiled_linalg_op,
|
||||||
Variadic<TransformHandleTypeInterface>:$loops);
|
Variadic<TransformHandleTypeInterface>:$loops);
|
||||||
let builders = [
|
let builders = [
|
||||||
|
@ -72,17 +72,27 @@ void printDynamicIndexList(
|
|||||||
/// 1. `result` is filled with the i64 ArrayAttr "[`kDynamic`, 7, 42,
|
/// 1. `result` is filled with the i64 ArrayAttr "[`kDynamic`, 7, 42,
|
||||||
/// `kDynamic`]"
|
/// `kDynamic`]"
|
||||||
/// 2. `ssa` is filled with "[%arg0, %arg1]".
|
/// 2. `ssa` is filled with "[%arg0, %arg1]".
|
||||||
|
///
|
||||||
|
/// Trailing indices can be scalable. For example, "42" in "[7, [42]]" is
|
||||||
|
/// scalable. This notation is similar to how scalable dims are marked when
|
||||||
|
/// defining Vectors. If /p isTrailingIdxScalable is null, scalable indices are
|
||||||
|
/// not allowed/expected. When it's not null, this hook will set the
|
||||||
|
/// corresponding value to:
|
||||||
|
/// * true if the trailing idx is scalable,
|
||||||
|
/// * false otherwise.
|
||||||
ParseResult parseDynamicIndexList(
|
ParseResult parseDynamicIndexList(
|
||||||
OpAsmParser &parser,
|
OpAsmParser &parser,
|
||||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
|
||||||
DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
|
DenseI64ArrayAttr &integers, bool *isTrailingIdxScalable = nullptr,
|
||||||
|
SmallVectorImpl<Type> *valueTypes = nullptr,
|
||||||
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
|
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
|
||||||
inline ParseResult parseDynamicIndexList(
|
inline ParseResult parseDynamicIndexList(
|
||||||
OpAsmParser &parser,
|
OpAsmParser &parser,
|
||||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
|
||||||
DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes,
|
DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes,
|
||||||
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
|
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
|
||||||
return parseDynamicIndexList(parser, values, integers, &valueTypes,
|
return parseDynamicIndexList(parser, values, integers,
|
||||||
|
/*isTrailingIdxScalable=*/nullptr, &valueTypes,
|
||||||
delimiter);
|
delimiter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2391,6 +2391,7 @@ transform::TileOp::apply(TransformResults &transformResults,
|
|||||||
SmallVector<Operation *> tiled;
|
SmallVector<Operation *> tiled;
|
||||||
SmallVector<SmallVector<Operation *, 4>, 4> loops;
|
SmallVector<SmallVector<Operation *, 4>, 4> loops;
|
||||||
loops.resize(getLoops().size());
|
loops.resize(getLoops().size());
|
||||||
|
bool scalable = getLastTileSizeScalable();
|
||||||
for (auto [i, op] : llvm::enumerate(targets)) {
|
for (auto [i, op] : llvm::enumerate(targets)) {
|
||||||
auto tilingInterface = dyn_cast<TilingInterface>(op);
|
auto tilingInterface = dyn_cast<TilingInterface>(op);
|
||||||
auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(op);
|
auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(op);
|
||||||
@ -2409,10 +2410,21 @@ transform::TileOp::apply(TransformResults &transformResults,
|
|||||||
SmallVector<Value, 4> sizes;
|
SmallVector<Value, 4> sizes;
|
||||||
sizes.reserve(tileSizes.size());
|
sizes.reserve(tileSizes.size());
|
||||||
unsigned dynamicIdx = 0;
|
unsigned dynamicIdx = 0;
|
||||||
for (OpFoldResult ofr : getMixedSizes()) {
|
unsigned trailingIdx = getMixedSizes().size() - 1;
|
||||||
|
|
||||||
|
for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
|
||||||
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
|
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
|
||||||
sizes.push_back(b.create<arith::ConstantIndexOp>(
|
// Only the trailing tile size is allowed to be scalable atm.
|
||||||
getLoc(), cast<IntegerAttr>(attr).getInt()));
|
if (scalable && (ofrIdx == trailingIdx)) {
|
||||||
|
auto val = b.create<arith::ConstantIndexOp>(
|
||||||
|
getLoc(), attr.cast<IntegerAttr>().getInt());
|
||||||
|
Value vscale =
|
||||||
|
b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
|
||||||
|
sizes.push_back(b.create<arith::MulIOp>(getLoc(), val, vscale));
|
||||||
|
} else {
|
||||||
|
sizes.push_back(b.create<arith::ConstantIndexOp>(
|
||||||
|
getLoc(), cast<IntegerAttr>(attr).getInt()));
|
||||||
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
|
ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
|
||||||
@ -2507,8 +2519,9 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
|
|||||||
DenseI64ArrayAttr staticSizes;
|
DenseI64ArrayAttr staticSizes;
|
||||||
FunctionType functionalType;
|
FunctionType functionalType;
|
||||||
llvm::SMLoc operandLoc;
|
llvm::SMLoc operandLoc;
|
||||||
|
bool scalable = false;
|
||||||
if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) ||
|
if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) ||
|
||||||
parseDynamicIndexList(parser, dynamicSizes, staticSizes) ||
|
parseDynamicIndexList(parser, dynamicSizes, staticSizes, &scalable) ||
|
||||||
parseOptionalInterchange(parser, result) ||
|
parseOptionalInterchange(parser, result) ||
|
||||||
parser.parseColonType(functionalType))
|
parser.parseColonType(functionalType))
|
||||||
return ParseResult::failure();
|
return ParseResult::failure();
|
||||||
@ -2531,6 +2544,10 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
|
|||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto scalableAttr = parser.getBuilder().getBoolAttr(scalable);
|
||||||
|
result.addAttribute(getLastTileSizeScalableAttrName(result.name),
|
||||||
|
scalableAttr);
|
||||||
|
|
||||||
result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
|
result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
|
||||||
result.addTypes(functionalType.getResults());
|
result.addTypes(functionalType.getResults());
|
||||||
return success();
|
return success();
|
||||||
|
@ -1261,9 +1261,9 @@ ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||||||
dynamicSteps;
|
dynamicSteps;
|
||||||
if (succeeded(parser.parseOptionalKeyword("in"))) {
|
if (succeeded(parser.parseOptionalKeyword("in"))) {
|
||||||
// Parse upper bounds.
|
// Parse upper bounds.
|
||||||
if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
|
if (parseDynamicIndexList(
|
||||||
/*valueTypes=*/nullptr,
|
parser, dynamicUbs, staticUbs, /*scalable=*/nullptr,
|
||||||
OpAsmParser::Delimiter::Paren) ||
|
/*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
|
||||||
parser.resolveOperands(dynamicUbs, indexType, result.operands))
|
parser.resolveOperands(dynamicUbs, indexType, result.operands))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@ -1273,26 +1273,26 @@ ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||||||
} else {
|
} else {
|
||||||
// Parse lower bounds.
|
// Parse lower bounds.
|
||||||
if (parser.parseEqual() ||
|
if (parser.parseEqual() ||
|
||||||
parseDynamicIndexList(parser, dynamicLbs, staticLbs,
|
parseDynamicIndexList(
|
||||||
/*valueTypes=*/nullptr,
|
parser, dynamicLbs, staticLbs, /*scalable=*/nullptr,
|
||||||
OpAsmParser::Delimiter::Paren) ||
|
/*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
|
||||||
|
|
||||||
parser.resolveOperands(dynamicLbs, indexType, result.operands))
|
parser.resolveOperands(dynamicLbs, indexType, result.operands))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Parse upper bounds.
|
// Parse upper bounds.
|
||||||
if (parser.parseKeyword("to") ||
|
if (parser.parseKeyword("to") ||
|
||||||
parseDynamicIndexList(parser, dynamicUbs, staticUbs,
|
parseDynamicIndexList(
|
||||||
/*valueTypes=*/nullptr,
|
parser, dynamicUbs, staticUbs, /*scalable=*/nullptr,
|
||||||
OpAsmParser::Delimiter::Paren) ||
|
/*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
|
||||||
parser.resolveOperands(dynamicUbs, indexType, result.operands))
|
parser.resolveOperands(dynamicUbs, indexType, result.operands))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Parse step values.
|
// Parse step values.
|
||||||
if (parser.parseKeyword("step") ||
|
if (parser.parseKeyword("step") ||
|
||||||
parseDynamicIndexList(parser, dynamicSteps, staticSteps,
|
parseDynamicIndexList(
|
||||||
/*valueTypes=*/nullptr,
|
parser, dynamicSteps, staticSteps, /*scalable=*/nullptr,
|
||||||
OpAsmParser::Delimiter::Paren) ||
|
/*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
|
||||||
parser.resolveOperands(dynamicSteps, indexType, result.operands))
|
parser.resolveOperands(dynamicSteps, indexType, result.operands))
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
@ -42,5 +42,6 @@ ParseResult mlir::transform::parsePackedOrDynamicIndexList(
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
return parseDynamicIndexList(parser, values, integers, &valueTypes);
|
return parseDynamicIndexList(parser, values, integers, /*scalable=*/nullptr,
|
||||||
|
&valueTypes);
|
||||||
}
|
}
|
||||||
|
@ -128,13 +128,26 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
|
|||||||
ParseResult mlir::parseDynamicIndexList(
|
ParseResult mlir::parseDynamicIndexList(
|
||||||
OpAsmParser &parser,
|
OpAsmParser &parser,
|
||||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
|
||||||
DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes,
|
DenseI64ArrayAttr &integers, bool *isTrailingIdxScalable,
|
||||||
AsmParser::Delimiter delimiter) {
|
SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
|
||||||
|
|
||||||
SmallVector<int64_t, 4> integerVals;
|
SmallVector<int64_t, 4> integerVals;
|
||||||
|
bool foundScalable = false;
|
||||||
auto parseIntegerOrValue = [&]() {
|
auto parseIntegerOrValue = [&]() {
|
||||||
OpAsmParser::UnresolvedOperand operand;
|
OpAsmParser::UnresolvedOperand operand;
|
||||||
auto res = parser.parseOptionalOperand(operand);
|
auto res = parser.parseOptionalOperand(operand);
|
||||||
|
|
||||||
|
// If `foundScalable` has already been set to `true` then a non-trailing
|
||||||
|
// tile size was identified as scalable.
|
||||||
|
if (foundScalable) {
|
||||||
|
parser.emitError(parser.getNameLoc())
|
||||||
|
<< "non-trailing tile size cannot be scalable";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isTrailingIdxScalable && parser.parseOptionalLSquare().succeeded())
|
||||||
|
foundScalable = true;
|
||||||
|
|
||||||
if (res.has_value() && succeeded(res.value())) {
|
if (res.has_value() && succeeded(res.value())) {
|
||||||
values.push_back(operand);
|
values.push_back(operand);
|
||||||
integerVals.push_back(ShapedType::kDynamic);
|
integerVals.push_back(ShapedType::kDynamic);
|
||||||
@ -146,6 +159,8 @@ ParseResult mlir::parseDynamicIndexList(
|
|||||||
return failure();
|
return failure();
|
||||||
integerVals.push_back(integer);
|
integerVals.push_back(integer);
|
||||||
}
|
}
|
||||||
|
if (foundScalable && parser.parseOptionalRSquare().failed())
|
||||||
|
return failure();
|
||||||
return success();
|
return success();
|
||||||
};
|
};
|
||||||
if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue,
|
if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue,
|
||||||
@ -153,6 +168,8 @@ ParseResult mlir::parseDynamicIndexList(
|
|||||||
return parser.emitError(parser.getNameLoc())
|
return parser.emitError(parser.getNameLoc())
|
||||||
<< "expected SSA value or integer";
|
<< "expected SSA value or integer";
|
||||||
integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
|
integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
|
||||||
|
if (isTrailingIdxScalable)
|
||||||
|
*isTrailingIdxScalable = foundScalable;
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
|
// RUN: mlir-opt --test-transform-dialect-interpreter --mlir-print-local-scope --split-input-file --verify-diagnostics %s | FileCheck %s
|
||||||
|
|
||||||
transform.sequence failures(propagate) {
|
transform.sequence failures(propagate) {
|
||||||
^bb0(%arg1: !transform.any_op):
|
^bb0(%arg1: !transform.any_op):
|
||||||
@ -149,3 +149,96 @@ transform.sequence failures(propagate) {
|
|||||||
transform.structured.tile_to_forall_op %0 tile_sizes[1, 1]
|
transform.structured.tile_to_forall_op %0 tile_sizes[1, 1]
|
||||||
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
#map = affine_map<(d0) -> (d0)>
|
||||||
|
|
||||||
|
module {
|
||||||
|
func.func @scalable_tile(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: f32) -> tensor<?xf32> {
|
||||||
|
%0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>) outs(%arg2 : tensor<?xf32>) {
|
||||||
|
^bb0(%in_1: f32, %in_2: f32, %out: f32):
|
||||||
|
%1 = arith.addf %in_1, %in_2 : f32
|
||||||
|
%2 = arith.mulf %arg3, %1 : f32
|
||||||
|
linalg.yield %2 : f32
|
||||||
|
} -> tensor<?xf32>
|
||||||
|
return %0 : tensor<?xf32>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @scalable_tile(
|
||||||
|
// CHECK-SAME: %[[ARG_0:.*]]: tensor<?xf32>, %[[ARG_1:.*]]: tensor<?xf32>, %[[ARG_2:.*]]: tensor<?xf32>,
|
||||||
|
// CHECK: %[[C4:.*]] = arith.constant 0 : index
|
||||||
|
// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG_0]], %[[C4]] : tensor<?xf32>
|
||||||
|
// CHECK: %[[VEC_SIZE:.*]] = arith.constant 4 : index
|
||||||
|
// CHECK: %[[VS:.*]] = vector.vscale
|
||||||
|
// CHECK: %[[STEP:.*]] = arith.muli %[[VEC_SIZE]], %[[VS]] : index
|
||||||
|
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||||
|
// CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[DIM]] step %[[STEP]] iter_args(%[[VAL:.*]] = %[[ARG_2]]) -> (tensor<?xf32>) {
|
||||||
|
// CHECK: %[[SIZE:.*]] = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%[[IV]])[%[[STEP]], %[[DIM]]]
|
||||||
|
// CHECK: %[[SLICE_ARG0:.*]] = tensor.extract_slice %[[ARG_0]][%[[IV]]] [%[[SIZE]]] [1] : tensor<?xf32> to tensor<?xf32>
|
||||||
|
// CHECK: %[[SLICE_ARG1:.*]] = tensor.extract_slice %[[ARG_1]][%[[IV]]] [%[[SIZE]]] [1] : tensor<?xf32> to tensor<?xf32>
|
||||||
|
// CHECK: %[[SLICE_ARG2:.*]] = tensor.extract_slice %[[VAL]][%[[IV]]] [%[[SIZE]]] [1] : tensor<?xf32> to tensor<?xf32>
|
||||||
|
// CHECK: linalg.generic {indexing_maps = {{.*}}, iterator_types = ["parallel"]} ins(%[[SLICE_ARG0]], %[[SLICE_ARG1]] : tensor<?xf32>, tensor<?xf32>) outs(%[[SLICE_ARG2]] : tensor<?xf32>) {
|
||||||
|
|
||||||
|
transform.sequence failures(propagate) {
|
||||||
|
^bb0(%arg1: !transform.any_op):
|
||||||
|
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
|
%1, %loop = transform.structured.tile %0 [[4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @scalable_and_fixed_length_tile
|
||||||
|
// CHECK: %[[STEP_0:.*]] = arith.constant 4 : index
|
||||||
|
// CHECK: %[[STEP_1:.*]] = arith.constant 4 : index
|
||||||
|
// CHECK: %[[C4:.*]] = arith.constant 4 : index
|
||||||
|
// CHECK: %[[VS:.*]] = vector.vscale
|
||||||
|
// CHECK: %[[STEP_2:.*]] = arith.muli %[[C4]], %[[VS]] : index
|
||||||
|
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||||
|
// CHECK: %[[C128:.*]] = arith.constant 128 : index
|
||||||
|
// CHECK: scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C128]] step %[[STEP_0]]
|
||||||
|
// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
|
||||||
|
// CHECK: %[[C128_1:.*]] = arith.constant 128 : index
|
||||||
|
// CHECK: scf.for %[[VAL_16:.*]] = %[[C0_1]] to %[[C128_1]] step %[[STEP_1]]
|
||||||
|
// CHECK: %[[C0_2:.*]] = arith.constant 0 : index
|
||||||
|
// CHECK: %[[C128_2:.*]] = arith.constant 128 : index
|
||||||
|
// CHECK: scf.for %{{.*}} = %[[C0_2]] to %[[C128_2]] step %[[STEP_2]]
|
||||||
|
|
||||||
|
func.func @scalable_and_fixed_length_tile(
|
||||||
|
%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
|
||||||
|
-> tensor<128x128xf32> {
|
||||||
|
%0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
|
||||||
|
outs(%arg2: tensor<128x128xf32>)
|
||||||
|
-> tensor<128x128xf32>
|
||||||
|
|
||||||
|
return %0 : tensor<128x128xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
transform.sequence failures(propagate) {
|
||||||
|
^bb0(%arg1: !transform.any_op):
|
||||||
|
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
|
%1, %loops:3 = transform.structured.tile %0 [4, 4, [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// TODO: Add support for for specyfying more than one scalable tile size
|
||||||
|
|
||||||
|
func.func @scalable_and_fixed_length_tile(
|
||||||
|
%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
|
||||||
|
-> tensor<128x128xf32> {
|
||||||
|
%0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
|
||||||
|
outs(%arg2: tensor<128x128xf32>)
|
||||||
|
-> tensor<128x128xf32>
|
||||||
|
|
||||||
|
return %0 : tensor<128x128xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
transform.sequence failures(propagate) {
|
||||||
|
^bb0(%arg1: !transform.any_op):
|
||||||
|
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
|
// expected-error @below {{non-trailing tile size cannot be scalable}}
|
||||||
|
// expected-error @below {{expected SSA value or integer}}
|
||||||
|
%1, %loops:3 = transform.structured.tile %0 [4, [4], [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user