Mirror of https://github.com/capstone-engine/llvm-capstone.git (last synced 2025-02-21 10:42:35 +00:00).
[mlir][ArmSME] Lower transfer_write + transpose to vertical store (#71181)
This patch extends the lowering of vector.transfer_write in VectorToArmSME to support in-flight transpose via SME vertical store.
This commit is contained in:
parent
6206817380
commit
4240b1790f
@ -136,13 +136,31 @@ struct TransferReadToArmSMELowering
|
||||
|
||||
/// Conversion pattern for vector.transfer_write.
|
||||
///
|
||||
/// vector.transfer_write %vector, %source[%c0, %c0] : vector<[16]x[16]xi8>,
|
||||
/// memref<?x?xi8>
|
||||
/// ---
|
||||
///
|
||||
/// Example 1: op with identity permutation map to horizontal
|
||||
/// arm_sme.tile_store:
|
||||
///
|
||||
/// vector.transfer_write %vector, %source[%c0, %c0]
|
||||
/// {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
|
||||
///
|
||||
/// is converted to:
|
||||
///
|
||||
/// arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
|
||||
/// vector<[16]x[16]xi8>
|
||||
/// ---
|
||||
///
|
||||
/// Example 2: op with transpose permutation map to vertical arm_sme.tile_store
|
||||
/// (in-flight transpose):
|
||||
///
|
||||
/// vector.transfer_write %vector, %source[%c0, %c0]
|
||||
/// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
|
||||
/// in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
|
||||
///
|
||||
/// is converted to:
|
||||
///
|
||||
/// arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
|
||||
/// : memref<?x?xi8>, vector<[16]x[16]xi8>
|
||||
struct TransferWriteToArmSMELowering
|
||||
: public OpRewritePattern<vector::TransferWriteOp> {
|
||||
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
|
||||
@ -156,9 +174,28 @@ struct TransferWriteToArmSMELowering
|
||||
if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
|
||||
return failure();
|
||||
|
||||
// Out-of-bounds dims are not supported.
|
||||
if (writeOp.hasOutOfBoundsDim())
|
||||
return rewriter.notifyMatchFailure(writeOp,
|
||||
"not inbounds transfer write");
|
||||
|
||||
AffineExpr d0, d1;
|
||||
bindDims(writeOp.getContext(), d0, d1);
|
||||
AffineMap map = writeOp.getPermutationMap();
|
||||
bool isTranspose = (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
|
||||
writeOp.getContext()));
|
||||
|
||||
if (!map.isIdentity() && !isTranspose)
|
||||
return rewriter.notifyMatchFailure(writeOp,
|
||||
"unsupported permutation map");
|
||||
|
||||
arm_sme::TileSliceLayout layout =
|
||||
isTranspose ? arm_sme::TileSliceLayout::Vertical
|
||||
: arm_sme::TileSliceLayout::Horizontal;
|
||||
|
||||
rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
|
||||
writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
|
||||
writeOp.getMask());
|
||||
writeOp.getMask(), layout);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -337,6 +337,37 @@ func.func @transfer_write_2d_with_mask_f64(%vector : vector<[2]x[2]xf64>, %dest
|
||||
|
||||
// -----
|
||||
|
||||
/// In-flight transpose via vertical store: a transfer_write with a
/// (d0, d1) -> (d1, d0) permutation map lowers to a single
/// arm_sme.tile_store with layout<vertical> (no separate transpose op).

// CHECK-LABEL: func.func @transfer_write_2d_transpose_i64(
// CHECK-SAME:      %[[VECTOR:.*]]: vector<[2]x[2]xi64>,
// CHECK-SAME:      %[[DEST:.*]]: memref<?x?xi64>) {
// CHECK:         %[[C0:.*]] = arith.constant 0 : index
// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
func.func @transfer_write_2d_transpose_i64(%vector : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
  %c0 = arith.constant 0 : index
  vector.transfer_write %vector, %dest[%c0, %c0] {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : vector<[2]x[2]xi64>, memref<?x?xi64>
  return
}
|
||||
|
||||
// -----
|
||||
|
||||
/// In-flight transpose via vertical store, with mask: the mask operand is
/// forwarded unchanged to the arm_sme.tile_store (it applies after the
/// permutation).

// CHECK-LABEL: func.func @transfer_write_2d_transpose_with_mask_bf16(
// CHECK-SAME:      %[[VECTOR:.*]]: vector<[8]x[8]xbf16>,
// CHECK-SAME:      %[[DEST:.*]]: memref<?x?xbf16>,
// CHECK-SAME:      %[[MASK:.*]]: vector<[8]x[8]xi1>) {
// CHECK:         %[[C0:.*]] = arith.constant 0 : index
// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]], %[[MASK]] layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>, %mask : vector<[8]x[8]xi1>) {
  %c0 = arith.constant 0 : index
  vector.transfer_write %vector, %dest[%c0, %c0], %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : vector<[8]x[8]xbf16>, memref<?x?xbf16>
  return
}
|
||||
|
||||
// -----
|
||||
|
||||
// The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
|
||||
// lowering only occurs for vector types of correct rank, shape, element size
|
||||
// and number of scalable dims.
|
||||
@ -398,6 +429,17 @@ func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: out-of-bounds dims are not supported by the ArmSME
// lowering, so the vector.transfer_write must remain and no
// arm_sme.tile_store may be emitted.

// CHECK-LABEL: @transfer_write_2d__out_of_bounds
// CHECK: vector.transfer_write
// CHECK-NOT: arm_sme.tile_store
func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
  %c0 = arith.constant 0 : index
  vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [false, false]} : vector<[4]x[4]xf32>, memref<?x?xf32>
  return
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// vector.broadcast
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -32,6 +32,25 @@ func.func @transfer_write_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: i
|
||||
return
|
||||
}
|
||||
|
||||
// Vector transpose + store: reloads the tile from %A, then writes it back
// transposed via a transfer_write with a (d0, d1) -> (d1, d0) permutation
// map (lowered to an SME vertical tile store).
func.func @transfer_write_2d_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
  %0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
  vector.transfer_write %0, %A[%base1, %base2] {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} :
    vector<[4]x[4]xf32>, memref<?x?xf32>
  return
}
|
||||
|
||||
// Vector transpose + masked store: as above, but with a
// vector.create_mask (nrows=4, ncols=2) applied to the transposed write.
// The mask applies after the permutation.
func.func @transfer_write_2d_mask_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
  %c2 = arith.constant 2 : index
  %c4 = arith.constant 4 : index
  %mask = vector.create_mask %c4, %c2 : vector<[4]x[4]xi1>
  %0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
  vector.transfer_write %0, %A[%base1, %base2], %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} :
    vector<[4]x[4]xf32>, memref<?x?xf32>
  return
}
|
||||
|
||||
// Vector load + print.
|
||||
func.func @load_and_print(%A : memref<?x?xf32>, %base1: index, %base2: index) {
|
||||
%0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
|
||||
@ -116,6 +135,26 @@ func.func @entry() {
|
||||
call @transfer_write_2d_mask(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
|
||||
call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
|
||||
|
||||
// 4. Reload 3. + transpose + store.
|
||||
// CHECK-LABEL: TILE BEGIN:
|
||||
// CHECK-NEXT: ( 0, 0, 20, 30
|
||||
// CHECK-NEXT: ( 0, 0, 21, 31
|
||||
// CHECK-NEXT: ( 0, 0, 0, 0
|
||||
// CHECK-NEXT: ( 3, 13, 0, 0
|
||||
call @transfer_write_2d_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
|
||||
call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
|
||||
|
||||
// 5. Reload 4. + transpose + masked store (nrows=4, ncols=2).
|
||||
// The mask applies after permutation. Columns 2 and 3 (from 4.) are
|
||||
// preserved.
|
||||
// CHECK-LABEL: TILE BEGIN:
|
||||
// CHECK-NEXT: ( 0, 0, 20, 30
|
||||
// CHECK-NEXT: ( 0, 0, 21, 31
|
||||
// CHECK-NEXT: ( 20, 21, 0, 0
|
||||
// CHECK-NEXT: ( 30, 31, 0, 0
|
||||
call @transfer_write_2d_mask_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
|
||||
call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
|
||||
|
||||
memref.dealloc %A : memref<?x?xf32>
|
||||
|
||||
return
|
||||
|
Loading…
x
Reference in New Issue
Block a user