[AArch64] Adding Neon Sm3 & Sm4 Intrinsics

This adds SM3 and SM4 Intrinsics support for AArch64, specifically:
        vsm3ss1q_u32
        vsm3tt1aq_u32
        vsm3tt1bq_u32
        vsm3tt2aq_u32
        vsm3tt2bq_u32
        vsm3partw1q_u32
        vsm3partw2q_u32
        vsm4eq_u32
        vsm4ekeyq_u32

Reviewed By: labrinea

Differential Revision: https://reviews.llvm.org/D95655
This commit is contained in:
Pengxuan Zheng 2021-02-11 14:18:40 -08:00
parent 74916008a8
commit 61cca0f2e5
8 changed files with 286 additions and 2 deletions

View File

@ -1134,6 +1134,17 @@ def SHA1SU0 : SInst<"vsha1su0", "....", "QUi">;
def SHA256H : SInst<"vsha256h", "....", "QUi">;
def SHA256H2 : SInst<"vsha256h2", "....", "QUi">;
def SHA256SU1 : SInst<"vsha256su1", "....", "QUi">;
def SM3SS1 : SInst<"vsm3ss1", "....", "QUi">;
def SM3TT1A : SInst<"vsm3tt1a", "....I", "QUi">;
def SM3TT1B : SInst<"vsm3tt1b", "....I", "QUi">;
def SM3TT2A : SInst<"vsm3tt2a", "....I", "QUi">;
def SM3TT2B : SInst<"vsm3tt2b", "....I", "QUi">;
def SM3PARTW1 : SInst<"vsm3partw1", "....", "QUi">;
def SM3PARTW2 : SInst<"vsm3partw2", "....", "QUi">;
def SM4E : SInst<"vsm4e", "...", "QUi">;
def SM4EKEY : SInst<"vsm4ekey", "...", "QUi">;
}
////////////////////////////////////////////////////////////////////////////////

View File

@ -5832,6 +5832,15 @@ static const ARMVectorIntrinsicInfo AArch64SIMDIntrinsicMap[] = {
NEONMAP0(vshr_n_v),
NEONMAP0(vshrn_n_v),
NEONMAP0(vshrq_n_v),
NEONMAP1(vsm3partw1q_v, aarch64_crypto_sm3partw1, 0),
NEONMAP1(vsm3partw2q_v, aarch64_crypto_sm3partw2, 0),
NEONMAP1(vsm3ss1q_v, aarch64_crypto_sm3ss1, 0),
NEONMAP1(vsm3tt1aq_v, aarch64_crypto_sm3tt1a, 0),
NEONMAP1(vsm3tt1bq_v, aarch64_crypto_sm3tt1b, 0),
NEONMAP1(vsm3tt2aq_v, aarch64_crypto_sm3tt2a, 0),
NEONMAP1(vsm3tt2bq_v, aarch64_crypto_sm3tt2b, 0),
NEONMAP1(vsm4ekeyq_v, aarch64_crypto_sm4ekey, 0),
NEONMAP1(vsm4eq_v, aarch64_crypto_sm4e, 0),
NEONMAP1(vst1_x2_v, aarch64_neon_st1x2, 0),
NEONMAP1(vst1_x3_v, aarch64_neon_st1x3, 0),
NEONMAP1(vst1_x4_v, aarch64_neon_st1x4, 0),
@ -6710,6 +6719,22 @@ Value *CodeGenFunction::EmitCommonNeonBuiltinExpr(
Ops.push_back(getAlignmentValue32(PtrOp0));
return EmitNeonCall(CGM.getIntrinsic(Int, Tys), Ops, "");
}
case NEON::BI__builtin_neon_vsm3partw1q_v:
case NEON::BI__builtin_neon_vsm3partw2q_v:
case NEON::BI__builtin_neon_vsm3ss1q_v:
case NEON::BI__builtin_neon_vsm4ekeyq_v:
case NEON::BI__builtin_neon_vsm4eq_v: {
Function *F = CGM.getIntrinsic(Int);
return EmitNeonCall(F, Ops, "");
}
case NEON::BI__builtin_neon_vsm3tt1aq_v:
case NEON::BI__builtin_neon_vsm3tt1bq_v:
case NEON::BI__builtin_neon_vsm3tt2aq_v:
case NEON::BI__builtin_neon_vsm3tt2bq_v: {
Function *F = CGM.getIntrinsic(Int);
Ops[3] = Builder.CreateZExt(Ops[3], Int64Ty);
return EmitNeonCall(F, Ops, "");
}
case NEON::BI__builtin_neon_vst1_x2_v:
case NEON::BI__builtin_neon_vst1q_x2_v:
case NEON::BI__builtin_neon_vst1_x3_v:

View File

@ -0,0 +1,32 @@
// RUN: %clang_cc1 -triple aarch64-linux-gnu -target-feature +neon -target-feature +crypto -verify %s
#include <arm_neon.h>
void test_range_check_vsm3tt1a(uint32x4_t a, uint32x4_t b, uint32x4_t c) {
vsm3tt1aq_u32(a, b, c, 4); // expected-error {{argument value 4 is outside the valid range [0, 3]}}
vsm3tt1aq_u32(a, b, c, -1); // expected-error {{argument value -1 is outside the valid range [0, 3]}}
vsm3tt1aq_u32(a, b, c, 3);
vsm3tt1aq_u32(a, b, c, 0);
}
void test_range_check_vsm3tt1b(uint32x4_t a, uint32x4_t b, uint32x4_t c) {
vsm3tt1bq_u32(a, b, c, 4);// expected-error {{argument value 4 is outside the valid range [0, 3]}}
vsm3tt1bq_u32(a, b, c, -1); // expected-error {{argument value -1 is outside the valid range [0, 3]}}
vsm3tt1bq_u32(a, b, c, 3);
vsm3tt1bq_u32(a, b, c, 0);
}
void test_range_check_vsm3tt2a(uint32x4_t a, uint32x4_t b, uint32x4_t c) {
vsm3tt2aq_u32(a, b, c, 4);// expected-error {{argument value 4 is outside the valid range [0, 3]}}
vsm3tt2aq_u32(a, b, c, -1); // expected-error {{argument value -1 is outside the valid range [0, 3]}}
vsm3tt2aq_u32(a, b, c, 3);
vsm3tt2aq_u32(a, b, c, 0);
}
void test_range_check_vsm3tt2b(uint32x4_t a, uint32x4_t b, uint32x4_t c) {
vsm3tt2bq_u32(a, b, c, 4);// expected-error {{argument value 4 is outside the valid range [0, 3]}}
vsm3tt2bq_u32(a, b, c, -1); // expected-error {{argument value -1 is outside the valid range [0, 3]}}
vsm3tt2bq_u32(a, b, c, 3);
vsm3tt2bq_u32(a, b, c, 0);
}

View File

@ -0,0 +1,66 @@
// RUN: %clang_cc1 -triple aarch64-linux-gnu -target-feature +neon \
// RUN: -target-feature +crypto -S -emit-llvm -o - %s \
// RUN: | FileCheck %s
// RUN: not %clang_cc1 -triple aarch64-linux-gnu -target-feature +neon \
// RUN: -S -emit-llvm -o - %s 2>&1 | FileCheck --check-prefix=CHECK-NO-CRYPTO %s
//The front-end requires the addition of both +crypto and +sm4 in the
// command line, however the back-end requires only +sm4 (includes sm4&sm3)
#include <arm_neon.h>
void test_vsm3partw1(uint32x4_t a, uint32x4_t b, uint32x4_t c) {
// CHECK-LABEL: @test_vsm3partw1(
// CHECK-NO-CRYPTO: warning: implicit declaration of function 'vsm3partw1q_u32' is invalid in C99
// CHECK: call <4 x i32> @llvm.aarch64.crypto.sm3partw1
uint32x4_t result = vsm3partw1q_u32(a, b, c);
}
void test_vsm3partw2(uint32x4_t a, uint32x4_t b, uint32x4_t c) {
// CHECK-LABEL: @test_vsm3partw2(
// CHECK: call <4 x i32> @llvm.aarch64.crypto.sm3partw2
uint32x4_t result = vsm3partw2q_u32(a, b, c);
}
void test_vsm3ss1(uint32x4_t a, uint32x4_t b, uint32x4_t c) {
// CHECK-LABEL: @test_vsm3ss1(
// CHECK: call <4 x i32> @llvm.aarch64.crypto.sm3ss1
uint32x4_t result = vsm3ss1q_u32(a, b, c);
}
void test_vsm3tt1a(uint32x4_t a, uint32x4_t b, uint32x4_t c) {
// CHECK-LABEL: @test_vsm3tt1a(
// CHECK: call <4 x i32> @llvm.aarch64.crypto.sm3tt1a
uint32x4_t result = vsm3tt1aq_u32(a, b, c, 2);
}
void test_vsm3tt1b(uint32x4_t a, uint32x4_t b, uint32x4_t c) {
// CHECK-LABEL: @test_vsm3tt1b(
// CHECK: call <4 x i32> @llvm.aarch64.crypto.sm3tt1b
uint32x4_t result = vsm3tt1bq_u32(a, b, c, 2);
}
void test_vsm3tt2a(uint32x4_t a, uint32x4_t b, uint32x4_t c) {
// CHECK-LABEL: @test_vsm3tt2a(
// CHECK: call <4 x i32> @llvm.aarch64.crypto.sm3tt2a
uint32x4_t result = vsm3tt2aq_u32(a, b, c, 2);
}
void test_vsm3tt2b(uint32x4_t a, uint32x4_t b, uint32x4_t c) {
// CHECK-LABEL: @test_vsm3tt2b(
// CHECK: call <4 x i32> @llvm.aarch64.crypto.sm3tt2b
uint32x4_t result = vsm3tt2bq_u32(a, b, c, 2);
}
void test_vsm4e(uint32x4_t a, uint32x4_t b) {
// CHECK-LABEL: @test_vsm4e(
// CHECK: call <4 x i32> @llvm.aarch64.crypto.sm4e
uint32x4_t result = vsm4eq_u32(a, b);
}
void test_vsm4ekey(uint32x4_t a, uint32x4_t b) {
// CHECK-LABEL: @test_vsm4ekey(
// CHECK: call <4 x i32> @llvm.aarch64.crypto.sm4ekey
uint32x4_t result = vsm4ekeyq_u32(a, b);
}

View File

@ -711,6 +711,17 @@ let TargetPrefix = "aarch64" in {
class Crypto_SHA_8Hash4Schedule_Intrinsic
: DefaultAttrsIntrinsic<[llvm_v4i32_ty], [llvm_v4i32_ty, llvm_v4i32_ty, llvm_v4i32_ty],
[IntrNoMem]>;
class Crypto_SM3_3Vector_Intrinsic
: Intrinsic<[llvm_v4i32_ty], [llvm_v4i32_ty, llvm_v4i32_ty, llvm_v4i32_ty],
[IntrNoMem]>;
class Crypto_SM3_3VectorIndexed_Intrinsic
: Intrinsic<[llvm_v4i32_ty], [llvm_v4i32_ty, llvm_v4i32_ty, llvm_v4i32_ty, llvm_i64_ty],
[IntrNoMem, ImmArg<ArgIndex<3>>]>;
class Crypto_SM4_2Vector_Intrinsic
: Intrinsic<[llvm_v4i32_ty], [llvm_v4i32_ty, llvm_v4i32_ty], [IntrNoMem]>;
}
// AES
@ -734,6 +745,17 @@ def int_aarch64_crypto_sha256h2 : Crypto_SHA_8Hash4Schedule_Intrinsic;
def int_aarch64_crypto_sha256su0 : Crypto_SHA_8Schedule_Intrinsic;
def int_aarch64_crypto_sha256su1 : Crypto_SHA_12Schedule_Intrinsic;
//SM3 & SM4
def int_aarch64_crypto_sm3partw1 : Crypto_SM3_3Vector_Intrinsic;
def int_aarch64_crypto_sm3partw2 : Crypto_SM3_3Vector_Intrinsic;
def int_aarch64_crypto_sm3ss1 : Crypto_SM3_3Vector_Intrinsic;
def int_aarch64_crypto_sm3tt1a : Crypto_SM3_3VectorIndexed_Intrinsic;
def int_aarch64_crypto_sm3tt1b : Crypto_SM3_3VectorIndexed_Intrinsic;
def int_aarch64_crypto_sm3tt2a : Crypto_SM3_3VectorIndexed_Intrinsic;
def int_aarch64_crypto_sm3tt2b : Crypto_SM3_3VectorIndexed_Intrinsic;
def int_aarch64_crypto_sm4e : Crypto_SM4_2Vector_Intrinsic;
def int_aarch64_crypto_sm4ekey : Crypto_SM4_2Vector_Intrinsic;
//===----------------------------------------------------------------------===//
// CRC32

View File

@ -10912,8 +10912,8 @@ class BaseCryptoV82<dag oops, dag iops, string asm, string asmops, string cst,
}
class CryptoRRTied<bits<1>op0, bits<2>op1, string asm, string asmops>
: BaseCryptoV82<(outs V128:$Vd), (ins V128:$Vn, V128:$Vm), asm, asmops,
"$Vm = $Vd", []> {
: BaseCryptoV82<(outs V128:$Vdst), (ins V128:$Vd, V128:$Vn), asm, asmops,
"$Vd = $Vdst", []> {
let Inst{31-25} = 0b1100111;
let Inst{24-21} = 0b0110;
let Inst{20-15} = 0b000001;

View File

@ -936,6 +936,32 @@ def SM3PARTW1 : CryptoRRRTied_4S<0b1, 0b00, "sm3partw1">;
def SM3PARTW2 : CryptoRRRTied_4S<0b1, 0b01, "sm3partw2">;
def SM4ENCKEY : CryptoRRR_4S<0b1, 0b10, "sm4ekey">;
def SM4E : CryptoRRTied_4S<0b0, 0b01, "sm4e">;
def : Pat<(v4i32 (int_aarch64_crypto_sm3ss1 (v4i32 V128:$Vn), (v4i32 V128:$Vm), (v4i32 V128:$Va))),
(SM3SS1 (v4i32 V128:$Vn), (v4i32 V128:$Vm), (v4i32 V128:$Va))>;
class SM3PARTW_pattern<Instruction INST, Intrinsic OpNode>
: Pat<(v4i32 (OpNode (v4i32 V128:$Vd), (v4i32 V128:$Vn), (v4i32 V128:$Vm))),
(INST (v4i32 V128:$Vd), (v4i32 V128:$Vn), (v4i32 V128:$Vm))>;
class SM3TT_pattern<Instruction INST, Intrinsic OpNode>
: Pat<(v4i32 (OpNode (v4i32 V128:$Vd), (v4i32 V128:$Vn), (v4i32 V128:$Vm), (i64 VectorIndexS_timm:$imm) )),
(INST (v4i32 V128:$Vd), (v4i32 V128:$Vn), (v4i32 V128:$Vm), (VectorIndexS_timm:$imm))>;
class SM4_pattern<Instruction INST, Intrinsic OpNode>
: Pat<(v4i32 (OpNode (v4i32 V128:$Vn), (v4i32 V128:$Vm))),
(INST (v4i32 V128:$Vn), (v4i32 V128:$Vm))>;
def : SM3PARTW_pattern<SM3PARTW1, int_aarch64_crypto_sm3partw1>;
def : SM3PARTW_pattern<SM3PARTW2, int_aarch64_crypto_sm3partw2>;
def : SM3TT_pattern<SM3TT1A, int_aarch64_crypto_sm3tt1a>;
def : SM3TT_pattern<SM3TT1B, int_aarch64_crypto_sm3tt1b>;
def : SM3TT_pattern<SM3TT2A, int_aarch64_crypto_sm3tt2a>;
def : SM3TT_pattern<SM3TT2B, int_aarch64_crypto_sm3tt2b>;
def : SM4_pattern<SM4ENCKEY, int_aarch64_crypto_sm4ekey>;
def : SM4_pattern<SM4E, int_aarch64_crypto_sm4e>;
} // HasSM4
let Predicates = [HasRCPC] in {

View File

@ -0,0 +1,102 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc %s -mtriple=aarch64 -mattr=+v8.3a,+sm4 -o - | FileCheck %s
define <4 x i32> @test_vsm3partw1(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c) {
; CHECK-LABEL: test_vsm3partw1:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sm3partw1 v0.4s, v1.4s, v2.4s
; CHECK-NEXT: ret
entry:
%vsm3partw1.i = tail call <4 x i32> @llvm.aarch64.crypto.sm3partw1(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c)
ret <4 x i32> %vsm3partw1.i
}
define <4 x i32> @test_vsm3partw2(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c) {
; CHECK-LABEL: test_vsm3partw2:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sm3partw2 v0.4s, v1.4s, v2.4s
; CHECK-NEXT: ret
entry:
%vsm3partw2.i = tail call <4 x i32> @llvm.aarch64.crypto.sm3partw2(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c)
ret <4 x i32> %vsm3partw2.i
}
define <4 x i32> @test_vsm3ss1(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c) {
; CHECK-LABEL: test_vsm3ss1:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sm3ss1 v0.4s, v0.4s, v1.4s, v2.4s
; CHECK-NEXT: ret
entry:
%vsm3ss1.i = tail call <4 x i32> @llvm.aarch64.crypto.sm3ss1(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c)
ret <4 x i32> %vsm3ss1.i
}
define <4 x i32> @test_vsm3tt1a(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c) {
; CHECK-LABEL: test_vsm3tt1a:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sm3tt1a v0.4s, v1.4s, v2.s[2]
; CHECK-NEXT: ret
entry:
%vsm3tt1a.i = tail call <4 x i32> @llvm.aarch64.crypto.sm3tt1a(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c, i64 2)
ret <4 x i32> %vsm3tt1a.i
}
define <4 x i32> @test_vsm3tt1b(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c) {
; CHECK-LABEL: test_vsm3tt1b:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sm3tt1b v0.4s, v1.4s, v2.s[2]
; CHECK-NEXT: ret
entry:
%vsm3tt1b.i = tail call <4 x i32> @llvm.aarch64.crypto.sm3tt1b(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c, i64 2)
ret <4 x i32> %vsm3tt1b.i
}
define <4 x i32> @test_vsm3tt2a(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c) {
; CHECK-LABEL: test_vsm3tt2a:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sm3tt2a v0.4s, v1.4s, v2.s[2]
; CHECK-NEXT: ret
entry:
%vsm3tt2a.i = tail call <4 x i32> @llvm.aarch64.crypto.sm3tt2a(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c, i64 2)
ret <4 x i32> %vsm3tt2a.i
}
define <4 x i32> @test_vsm3tt2b(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c) {
; CHECK-LABEL: test_vsm3tt2b:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sm3tt2b v0.4s, v1.4s, v2.s[2]
; CHECK-NEXT: ret
entry:
%vsm3tt2b.i = tail call <4 x i32> @llvm.aarch64.crypto.sm3tt2b(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c, i64 2)
ret <4 x i32> %vsm3tt2b.i
}
define <4 x i32> @test_vsm4e(<4 x i32> %a, <4 x i32> %b) {
; CHECK-LABEL: test_vsm4e:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sm4e v0.4s, v1.4s
; CHECK-NEXT: ret
entry:
%vsm4e.i = tail call <4 x i32> @llvm.aarch64.crypto.sm4e(<4 x i32> %a, <4 x i32> %b)
ret <4 x i32> %vsm4e.i
}
define <4 x i32> @test_vsm4ekey(<4 x i32> %a, <4 x i32> %b) {
; CHECK-LABEL: test_vsm4ekey:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sm4ekey v0.4s, v0.4s, v1.4s
; CHECK-NEXT: ret
entry:
%vsm4ekey.i = tail call <4 x i32> @llvm.aarch64.crypto.sm4ekey(<4 x i32> %a, <4 x i32> %b)
ret <4 x i32> %vsm4ekey.i
}
declare <4 x i32> @llvm.aarch64.crypto.sm3partw1(<4 x i32>, <4 x i32>, <4 x i32>)
declare <4 x i32> @llvm.aarch64.crypto.sm3partw2(<4 x i32>, <4 x i32>, <4 x i32>)
declare <4 x i32> @llvm.aarch64.crypto.sm3ss1(<4 x i32>, <4 x i32>, <4 x i32>)
declare <4 x i32> @llvm.aarch64.crypto.sm3tt1a(<4 x i32>, <4 x i32>, <4 x i32>, i64 immarg)
declare <4 x i32> @llvm.aarch64.crypto.sm3tt2b(<4 x i32>, <4 x i32>, <4 x i32>, i64 immarg)
declare <4 x i32> @llvm.aarch64.crypto.sm3tt2a(<4 x i32>, <4 x i32>, <4 x i32>, i64 immarg)
declare <4 x i32> @llvm.aarch64.crypto.sm3tt1b(<4 x i32>, <4 x i32>, <4 x i32>, i64 immarg)
declare <4 x i32> @llvm.aarch64.crypto.sm4e(<4 x i32>, <4 x i32>)
declare <4 x i32> @llvm.aarch64.crypto.sm4ekey(<4 x i32>, <4 x i32>)