mirror of https://github.com/capstone-engine/llvm-capstone.git (synced 2024-12-11 08:48:12 +00:00)
[mlir][ArmSME] Add vector to tile intrinsics
Add support for the following vector-to-tile (MOVA) intrinsics in the ArmSME dialect:
  llvm.aarch64.sme.write.vert
  llvm.aarch64.sme.write.horiz
This also includes the definition of a new type predicate, 'ScalableVectorOfRankAndLengthAndType', in OpBase.td.
Reviewed By: awarzynski, dcaballe
Differential Revision: https://reviews.llvm.org/D157004
parent ba818c4019
commit 8ce23b8e5c
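For orientation, here is a minimal usage sketch of the new intrinsics at the LLVM dialect level. It mirrors the tests added below; the function and value names are illustrative only, not part of the patch:

// Write one horizontal slice of an SVE vector into ZA tile 0 under a
// predicate; the predicate and vector shapes must match (AllShapesMatch).
llvm.func @write_row_sketch(%tileslice : i32,
                            %pg : vector<[16]xi1>,
                            %vec : vector<[16]xi8>) {
  %tile = llvm.mlir.constant(0 : index) : i32
  "arm_sme.intr.write.horiz"(%tile, %tileslice, %pg, %vec) :
    (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
  // "arm_sme.intr.write.vert" takes the same operands and writes a column.
  llvm.return
}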
@@ -14,6 +14,7 @@
#ifndef ARMSME_OPS
#define ARMSME_OPS

include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"

@@ -61,6 +62,12 @@ def nxnxv2f64 : SMETileType<F64, [2, 2], "vector<[2]x[2]xf64>">;
def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
                         nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64]>;

def SVEVector : ScalableVectorOfRankAndLengthAndType<
  [1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>;

def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
  [1], [16, 8, 4, 2, 1], [I1]>;

// A type constraint that verifies the bitwidth of the scalar integer returned
// from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
def TileElementWidthMatchesTileID : TypesMatchWith<

@@ -496,6 +503,18 @@ def LLVM_aarch64_sme_str
      Arguments<(ins Arg<I32, "Index">,
                 Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;

// Vector to tile
class LLVM_aarch64_sme_write<string direction>
    : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
                    [AllShapesMatch<["pg", "vector"]>]>,
      Arguments<(ins Arg<I32, "Virtual tile ID">,
                 Arg<I32, "Tile slice">,
                 Arg<SVEPredicate, "Vector predicate">:$pg,
                 Arg<SVEVector, "Vector operand">:$vector)>;

def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;

def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
@@ -533,6 +533,19 @@ class ScalableVectorOfLengthAndType<list<int> allowedLengths,
  ScalableVectorOfLength<allowedLengths>.summary,
  "::mlir::VectorType">;

// Any scalable vector where the rank is from the given `allowedRanks` list and
// the number of elements is from the given `allowedLengths` list and the type
// is from the given `allowedTypes` list
class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
                                           list<int> allowedLengths,
                                           list<Type> allowedTypes> : AllOfType<
  [ScalableVectorOfRank<allowedRanks>, ScalableVectorOf<allowedTypes>,
   ScalableVectorOfLength<allowedLengths>],
  ScalableVectorOfRank<allowedRanks>.summary #
  ScalableVectorOf<allowedTypes>.summary #
  ScalableVectorOfLength<allowedLengths>.summary,
  "::mlir::VectorType">;

def AnyVector : VectorOf<[AnyType]>;
// Temporary vector type clone that allows gradual transition to 0-D vectors.
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
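The SVEVector and SVEPredicate definitions in the ArmSME hunk above show the intended use of this predicate: it simply composes the existing rank, length, and element-type constraints. As a hypothetical one-line sketch (the name QuadF32Vector is illustrative and not part of this patch), another constrained scalable-vector type could be defined the same way:

// Accepts only rank-1 scalable vectors of 4 f32 elements, i.e. vector<[4]xf32>.
def QuadF32Vector : ScalableVectorOfRankAndLengthAndType<[1], [4], [F32]>;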
mlir/test/Target/LLVMIR/arm-sme-invalid.mlir (new file, 12 lines)
@@ -0,0 +1,12 @@
// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s

// Verify shape of predicate and vector must match
llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
                                                %nxv4i1 : vector<[4]xi1>,
                                                %nxv16i8 : vector<[16]xi8>) {
  %tile = llvm.mlir.constant(0 : index) : i32
  // expected-error @+1 {{failed to verify that all of {pg, vector} have same shape}}
  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv16i8) :
    (i32, i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
  llvm.return
}
@@ -236,3 +236,101 @@ llvm.func @arm_sme_toggle_za() {
  "arm_sme.intr.za.disable"() : () -> ()
  llvm.return
}

// -----

// CHECK-LABEL: @arm_sme_vector_to_tile_horiz
llvm.func @arm_sme_vector_to_tile_horiz(%tileslice : i32,
                                        %nxv16i1 : vector<[16]xi1>,
                                        %nxv8i1 : vector<[8]xi1>,
                                        %nxv4i1 : vector<[4]xi1>,
                                        %nxv2i1 : vector<[2]xi1>,
                                        %nxv1i1 : vector<[1]xi1>,
                                        %nxv16i8 : vector<[16]xi8>,
                                        %nxv8i16 : vector<[8]xi16>,
                                        %nxv4i32 : vector<[4]xi32>,
                                        %nxv2i64 : vector<[2]xi64>,
                                        %nxv1i128 : vector<[1]xi128>,
                                        %nxv8f16 : vector<[8]xf16>,
                                        %nxv8bf16 : vector<[8]xbf16>,
                                        %nxv4f32 : vector<[4]xf32>,
                                        %nxv2f64 : vector<[2]xf64>) {
  %tile = llvm.mlir.constant(0 : index) : i32
  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv16i8
  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv16i1, %nxv16i8) :
    (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8i16
  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8i16) :
    (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv4i32
  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv4i32) :
    (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv2i64
  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv2i1, %nxv2i64) :
    (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv1i128
  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv1i1, %nxv1i128) :
    (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8f16
  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8f16) :
    (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8bf16
  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8bf16) :
    (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv4f32
  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv4f32) :
    (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv2f64
  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv2i1, %nxv2f64) :
    (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
  llvm.return
}

// -----

// CHECK-LABEL: @arm_sme_vector_to_tile_vert
llvm.func @arm_sme_vector_to_tile_vert(%tileslice : i32,
                                       %nxv16i1 : vector<[16]xi1>,
                                       %nxv8i1 : vector<[8]xi1>,
                                       %nxv4i1 : vector<[4]xi1>,
                                       %nxv2i1 : vector<[2]xi1>,
                                       %nxv1i1 : vector<[1]xi1>,
                                       %nxv16i8 : vector<[16]xi8>,
                                       %nxv8i16 : vector<[8]xi16>,
                                       %nxv4i32 : vector<[4]xi32>,
                                       %nxv2i64 : vector<[2]xi64>,
                                       %nxv1i128 : vector<[1]xi128>,
                                       %nxv8f16 : vector<[8]xf16>,
                                       %nxv8bf16 : vector<[8]xbf16>,
                                       %nxv4f32 : vector<[4]xf32>,
                                       %nxv2f64 : vector<[2]xf64>) {
  %tile = llvm.mlir.constant(0 : index) : i32
  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv16i8
  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv16i1, %nxv16i8) :
    (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv8i16
  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8i16) :
    (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv4i32
  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv4i1, %nxv4i32) :
    (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv2i64
  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv2i1, %nxv2i64) :
    (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv1i128
  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv1i1, %nxv1i128) :
    (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv8f16
  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8f16) :
    (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv8bf16
  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8bf16) :
    (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv4f32
  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv4i1, %nxv4f32) :
    (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv2f64
  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv2i1, %nxv2f64) :
    (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
  llvm.return
}