[mlir][ArmSME] Add vector to tile intrinsics

Add support for the following vector-to-tile (MOVA) intrinsics in the ArmSME
dialect:

  llvm.aarch64.sme.write.vert
  llvm.aarch64.sme.write.horiz

Includes the definition of new type predicate
'ScalableVectorOfRankAndLengthAndType' in OpBase.td.

Reviewed By: awarzynski, dcaballe

Differential Revision: https://reviews.llvm.org/D157004
This commit is contained in:
Cullen Rhodes 2023-08-03 09:52:56 +00:00
parent ba818c4019
commit 8ce23b8e5c
4 changed files with 142 additions and 0 deletions

View File

@@ -14,6 +14,7 @@
#ifndef ARMSME_OPS
#define ARMSME_OPS
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -61,6 +62,12 @@ def nxnxv2f64 : SMETileType<F64, [2, 2 ], "vector<[2]x[2]xf64>">;
// Any of the supported ZA "virtual tile" vector types (2-D scalable vectors,
// one per element type/width defined above).
def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64]>;
// A 1-D scalable vector whose element count and type correspond to a full
// SVE data register, e.g. vector<[16]xi8> through vector<[2]xf64>.
def SVEVector : ScalableVectorOfRankAndLengthAndType<
[1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>;
// A 1-D scalable vector of i1 elements, i.e. an SVE predicate.
def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
[1], [16, 8, 4, 2, 1], [I1]>;
// A type constraint that verifies the bitwidth of the scalar integer returned
// from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
def TileElementWidthMatchesTileID : TypesMatchWith<
@@ -496,6 +503,18 @@ def LLVM_aarch64_sme_str
Arguments<(ins Arg<I32, "Index">,
Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
// Vector to tile
// Base class for the 'aarch64.sme.write' (MOVA vector-to-tile) intrinsics.
// `direction` selects a horizontal ("horiz") or vertical ("vert") tile slice.
// The intrinsic is overloaded on the vector operand (operand index 3), and the
// AllShapesMatch constraint requires the predicate and the vector being
// written to have the same shape.
class LLVM_aarch64_sme_write<string direction>
: ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
[AllShapesMatch<["pg", "vector"]>]>,
Arguments<(ins Arg<I32, "Virtual tile ID">,
Arg<I32, "Tile slice">,
Arg<SVEPredicate, "Vector predicate">:$pg,
Arg<SVEVector, "Vector operand">:$vector)>;
def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
// Intrinsics that toggle access to the ZA array storage.
def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;

View File

@@ -533,6 +533,19 @@ class ScalableVectorOfLengthAndType<list<int> allowedLengths,
ScalableVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
// Any scalable vector where the rank is from the given `allowedRanks` list,
// the number of elements is from the given `allowedLengths` list, and the
// element type is from the given `allowedTypes` list.
// Implemented as the conjunction (AllOfType) of the three existing scalable
// vector constraints; the summary concatenates their individual summaries.
class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
list<int> allowedLengths,
list<Type> allowedTypes> : AllOfType<
[ScalableVectorOfRank<allowedRanks>, ScalableVectorOf<allowedTypes>,
ScalableVectorOfLength<allowedLengths>],
ScalableVectorOfRank<allowedRanks>.summary #
ScalableVectorOf<allowedTypes>.summary #
ScalableVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
def AnyVector : VectorOf<[AnyType]>;
// Temporary vector type clone that allows gradual transition to 0-D vectors.
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;

View File

@@ -0,0 +1,12 @@
// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
// Verify shape of predicate and vector must match
// Negative test: the predicate (vector<[4]xi1>) and the data vector
// (vector<[16]xi8>) have different shapes, so the AllShapesMatch trait on
// 'arm_sme.intr.write.horiz' must reject the op during verification.
llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
%nxv4i1 : vector<[4]xi1>,
%nxv16i8 : vector<[16]xi8>) {
%tile = llvm.mlir.constant(0 : index) : i32
// expected-error @+1 {{failed to verify that all of {pg, vector} have same shape}}
"arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv16i8) :
(i32, i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
llvm.return
}

View File

@@ -236,3 +236,101 @@ llvm.func @arm_sme_toggle_za() {
"arm_sme.intr.za.disable"() : () -> ()
llvm.return
}
// -----
// CHECK-LABEL: @arm_sme_vector_to_tile_horiz
// Exercises 'arm_sme.intr.write.horiz' for every supported element type and
// checks each op lowers to the correspondingly-overloaded
// @llvm.aarch64.sme.write.horiz.* LLVM intrinsic.
llvm.func @arm_sme_vector_to_tile_horiz(%tileslice : i32,
%nxv16i1 : vector<[16]xi1>,
%nxv8i1 : vector<[8]xi1>,
%nxv4i1 : vector<[4]xi1>,
%nxv2i1 : vector<[2]xi1>,
%nxv1i1 : vector<[1]xi1>,
%nxv16i8 : vector<[16]xi8>,
%nxv8i16 : vector<[8]xi16>,
%nxv4i32 : vector<[4]xi32>,
%nxv2i64 : vector<[2]xi64>,
%nxv1i128 : vector<[1]xi128>,
%nxv8f16 : vector<[8]xf16>,
%nxv8bf16 : vector<[8]xbf16>,
%nxv4f32 : vector<[4]xf32>,
%nxv2f64 : vector<[2]xf64>) {
%tile = llvm.mlir.constant(0 : index) : i32
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv16i8
"arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv16i1, %nxv16i8) :
(i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8i16
"arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8i16) :
(i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv4i32
"arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv4i32) :
(i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv2i64
"arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv2i1, %nxv2i64) :
(i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv1i128
"arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv1i1, %nxv1i128) :
(i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8f16
"arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8f16) :
(i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8bf16
"arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8bf16) :
(i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv4f32
"arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv4f32) :
(i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv2f64
"arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv2i1, %nxv2f64) :
(i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
llvm.return
}
// -----
// CHECK-LABEL: @arm_sme_vector_to_tile_vert
// Same coverage as the horizontal test above, but for the vertical variant:
// each 'arm_sme.intr.write.vert' must lower to the correspondingly-overloaded
// @llvm.aarch64.sme.write.vert.* LLVM intrinsic.
llvm.func @arm_sme_vector_to_tile_vert(%tileslice : i32,
%nxv16i1 : vector<[16]xi1>,
%nxv8i1 : vector<[8]xi1>,
%nxv4i1 : vector<[4]xi1>,
%nxv2i1 : vector<[2]xi1>,
%nxv1i1 : vector<[1]xi1>,
%nxv16i8 : vector<[16]xi8>,
%nxv8i16 : vector<[8]xi16>,
%nxv4i32 : vector<[4]xi32>,
%nxv2i64 : vector<[2]xi64>,
%nxv1i128 : vector<[1]xi128>,
%nxv8f16 : vector<[8]xf16>,
%nxv8bf16 : vector<[8]xbf16>,
%nxv4f32 : vector<[4]xf32>,
%nxv2f64 : vector<[2]xf64>) {
%tile = llvm.mlir.constant(0 : index) : i32
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv16i8
"arm_sme.intr.write.vert"(%tile, %tileslice, %nxv16i1, %nxv16i8) :
(i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv8i16
"arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8i16) :
(i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv4i32
"arm_sme.intr.write.vert"(%tile, %tileslice, %nxv4i1, %nxv4i32) :
(i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv2i64
"arm_sme.intr.write.vert"(%tile, %tileslice, %nxv2i1, %nxv2i64) :
(i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv1i128
"arm_sme.intr.write.vert"(%tile, %tileslice, %nxv1i1, %nxv1i128) :
(i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv8f16
"arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8f16) :
(i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv8bf16
"arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8bf16) :
(i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv4f32
"arm_sme.intr.write.vert"(%tile, %tileslice, %nxv4i1, %nxv4f32) :
(i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv2f64
"arm_sme.intr.write.vert"(%tile, %tileslice, %nxv2i1, %nxv2f64) :
(i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
llvm.return
}