[ARM][BFloat] Implement bf16 get/set_lane without casts to i16 vectors

Currently, in order to extract an element from a bf16 vector, we cast
the vector to an i16 vector, perform the extraction, and cast the result to
bfloat. This behavior was copied from the old fp16 implementation.
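
To illustrate, the old expansion behaved roughly like the following C
sketch (the helper name and the memcpy-based scalar bitcast are
illustrative, not the actual arm_neon.h code):

  #include <arm_neon.h>
  #include <string.h>

  /* Old, cast-based lane extraction for bf16 (sketch). */
  static inline bfloat16_t old_style_vget_lane_bf16(bfloat16x4_t v) {
    int16x4_t bits = vreinterpret_s16_bf16(v); /* cast vector to i16x4 */
    int16_t lane1 = vget_lane_s16(bits, 1);    /* extract lane as i16 */
    bfloat16_t result;                         /* bitcast scalar back */
    memcpy(&result, &lane1, sizeof result);
    return result;
  }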

The goal of this patch is to enable optimal code generation for the
lane-copying intrinsics added in a subsequent patch: LLVM fails to fold
certain combinations of bitcast, insertelement, extractelement and
shufflevector instructions, which leads to suboptimal code.
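
As a concrete example of the affected pattern (the helper below is
hypothetical, not part of this patch), a lane copy composed from these
intrinsics used to produce bitcast/extractelement/insertelement chains
that the optimizer would not reliably collapse into a single lane move:

  #include <arm_neon.h>

  /* Copy lane 0 of src into lane 1 of dst. */
  static inline bfloat16x4_t copy_lane_bf16(bfloat16x4_t dst,
                                            bfloat16x4_t src) {
    return vset_lane_bf16(vget_lane_bf16(src, 0), dst, 1);
  }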

Differential Revision: https://reviews.llvm.org/D82206
Author: Mikhail Maltsev
Date:   2020-06-22 17:15:11 +00:00
parent 6ae0f5f3e1
commit 3a4feb1d53
3 changed files with 48 additions and 48 deletions

clang/include/clang/Basic/arm_neon.td

@@ -190,28 +190,20 @@ def OP_SCALAR_QRDMLAH_LN : Op<(call "vqadd", $p0, (call "vqrdmulh", $p1,
 def OP_SCALAR_QRDMLSH_LN : Op<(call "vqsub", $p0, (call "vqrdmulh", $p1,
                                 (call "vget_lane", $p2, $p3)))>;
-multiclass ScalarGetSetLaneOpsF16<string scalarTy,
-                                  string vectorTy4, string vectorTy8> {
-  def _GET_LN  : Op<(bitcast scalarTy,
-                        (call "vget_lane",
-                              (bitcast "int16x4_t", $p0), $p1))>;
-  def _GET_LNQ : Op<(bitcast scalarTy,
-                        (call "vget_lane",
-                              (bitcast "int16x8_t", $p0), $p1))>;
-  def _SET_LN  : Op<(bitcast vectorTy4,
-                        (call "vset_lane",
-                              (bitcast "int16_t", $p0),
-                              (bitcast "int16x4_t", $p1), $p2))>;
-  def _SET_LNQ : Op<(bitcast vectorTy8,
-                        (call "vset_lane",
-                              (bitcast "int16_t", $p0),
-                              (bitcast "int16x8_t", $p1), $p2))>;
-}
-
-defm OP_SCALAR_HALF: ScalarGetSetLaneOpsF16<"float16_t",
-                                            "float16x4_t", "float16x8_t">;
-defm OP_SCALAR_BF16: ScalarGetSetLaneOpsF16<"bfloat16_t",
-                                            "bfloat16x4_t", "bfloat16x8_t">;
+def OP_SCALAR_HALF_GET_LN : Op<(bitcast "float16_t",
+                                   (call "vget_lane",
+                                         (bitcast "int16x4_t", $p0), $p1))>;
+def OP_SCALAR_HALF_GET_LNQ : Op<(bitcast "float16_t",
+                                    (call "vget_lane",
+                                          (bitcast "int16x8_t", $p0), $p1))>;
+def OP_SCALAR_HALF_SET_LN : Op<(bitcast "float16x4_t",
+                                   (call "vset_lane",
+                                         (bitcast "int16_t", $p0),
+                                         (bitcast "int16x4_t", $p1), $p2))>;
+def OP_SCALAR_HALF_SET_LNQ : Op<(bitcast "float16x8_t",
+                                    (call "vset_lane",
+                                          (bitcast "int16_t", $p0),
+                                          (bitcast "int16x8_t", $p1), $p2))>;
 def OP_DOT_LN
     : Op<(call "vdot", $p0, $p1,
@@ -1918,10 +1910,12 @@ let ArchGuard = "defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)" in {
   def VGET_HIGH_BF : NoTestOpInst<"vget_high", ".Q", "b", OP_HI>;
   def VGET_LOW_BF  : NoTestOpInst<"vget_low", ".Q", "b", OP_LO>;
-  def VGET_LANE_BF : IOpInst<"vget_lane", "1.I", "b", OP_SCALAR_BF16_GET_LN>;
-  def VSET_LANE_BF : IOpInst<"vset_lane", ".1.I", "b", OP_SCALAR_BF16_SET_LN>;
-  def VGET_LANEQ_BF : IOpInst<"vget_lane", "1.I", "Qb", OP_SCALAR_BF16_GET_LNQ>;
-  def VSET_LANEQ_BF : IOpInst<"vset_lane", ".1.I", "Qb", OP_SCALAR_BF16_SET_LNQ>;
+  def VGET_LANE_BF : IInst<"vget_lane", "1.I", "bQb">;
+  def VSET_LANE_BF : IInst<"vset_lane", ".1.I", "bQb">;
+  def SCALAR_VDUP_LANE_BF : IInst<"vdup_lane", "1.I", "Sb">;
+  def SCALAR_VDUP_LANEQ_BF : IInst<"vdup_laneq", "1QI", "Sb"> {
+    let isLaneQ = 1;
+  }
   def VLD1_BF : WInst<"vld1", ".(c*!)", "bQb">;
   def VLD2_BF : WInst<"vld2", "2(c*!)", "bQb">;
@@ -1957,18 +1951,6 @@ let ArchGuard = "defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)" in {
 }
-
-let ArchGuard = "defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) && !defined(__aarch64__)" in {
-  def SCALAR_VDUP_LANE_BF_A32 : IOpInst<"vduph_lane", "1.I", "b", OP_SCALAR_BF16_GET_LN>;
-  def SCALAR_VDUP_LANEQ_BF_A32 : IOpInst<"vduph_laneq", "1.I", "Hb", OP_SCALAR_BF16_GET_LNQ>;
-}
-
-let ArchGuard = "defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) && defined(__aarch64__)" in {
-  def SCALAR_VDUP_LANE_BF_A64 : IInst<"vdup_lane", "1.I", "Sb">;
-  def SCALAR_VDUP_LANEQ_BF_A64 : IInst<"vdup_laneq", "1QI", "Sb"> {
-    let isLaneQ = 1;
-  }
-}
-
 let ArchGuard = "defined(__ARM_FEATURE_BF16) && !defined(__aarch64__)" in {
   let BigEndianSafe = 1 in {
     defm VREINTERPRET_BF : REINTERPRET_CROSS_TYPES<
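
Note: with the OP_SCALAR_BF16_* operations gone, the bf16 lane intrinsics
are generated as plain builtins (IInst) under the common
__ARM_FEATURE_BF16_VECTOR_ARITHMETIC guard. The user-facing signatures
are unchanged; as a reminder (lane indices must be compile-time
constants, function names below are illustrative):

  #include <arm_neon.h>

  bfloat16_t   get_lane2(bfloat16x8_t v)               { return vgetq_lane_bf16(v, 2); }
  bfloat16x4_t set_lane3(bfloat16x4_t v, bfloat16_t x) { return vset_lane_bf16(x, v, 3); }
  bfloat16_t   dup_lane1(bfloat16x4_t v)               { return vduph_lane_bf16(v, 1); }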

clang/lib/CodeGen/CGBuiltin.cpp

@@ -6416,21 +6416,27 @@ static bool HasExtraNeonArgument(unsigned BuiltinID) {
   default: break;
   case NEON::BI__builtin_neon_vget_lane_i8:
   case NEON::BI__builtin_neon_vget_lane_i16:
+  case NEON::BI__builtin_neon_vget_lane_bf16:
   case NEON::BI__builtin_neon_vget_lane_i32:
   case NEON::BI__builtin_neon_vget_lane_i64:
   case NEON::BI__builtin_neon_vget_lane_f32:
   case NEON::BI__builtin_neon_vgetq_lane_i8:
   case NEON::BI__builtin_neon_vgetq_lane_i16:
+  case NEON::BI__builtin_neon_vgetq_lane_bf16:
   case NEON::BI__builtin_neon_vgetq_lane_i32:
   case NEON::BI__builtin_neon_vgetq_lane_i64:
   case NEON::BI__builtin_neon_vgetq_lane_f32:
+  case NEON::BI__builtin_neon_vduph_lane_bf16:
+  case NEON::BI__builtin_neon_vduph_laneq_bf16:
   case NEON::BI__builtin_neon_vset_lane_i8:
   case NEON::BI__builtin_neon_vset_lane_i16:
+  case NEON::BI__builtin_neon_vset_lane_bf16:
   case NEON::BI__builtin_neon_vset_lane_i32:
   case NEON::BI__builtin_neon_vset_lane_i64:
   case NEON::BI__builtin_neon_vset_lane_f32:
   case NEON::BI__builtin_neon_vsetq_lane_i8:
   case NEON::BI__builtin_neon_vsetq_lane_i16:
+  case NEON::BI__builtin_neon_vsetq_lane_bf16:
   case NEON::BI__builtin_neon_vsetq_lane_i32:
   case NEON::BI__builtin_neon_vsetq_lane_i64:
   case NEON::BI__builtin_neon_vsetq_lane_f32:
@@ -6876,12 +6882,16 @@ Value *CodeGenFunction::EmitARMBuiltinExpr(unsigned BuiltinID,
   case NEON::BI__builtin_neon_vget_lane_i8:
   case NEON::BI__builtin_neon_vget_lane_i16:
   case NEON::BI__builtin_neon_vget_lane_i32:
   case NEON::BI__builtin_neon_vget_lane_i64:
+  case NEON::BI__builtin_neon_vget_lane_bf16:
   case NEON::BI__builtin_neon_vget_lane_f32:
   case NEON::BI__builtin_neon_vgetq_lane_i8:
   case NEON::BI__builtin_neon_vgetq_lane_i16:
   case NEON::BI__builtin_neon_vgetq_lane_i32:
   case NEON::BI__builtin_neon_vgetq_lane_i64:
+  case NEON::BI__builtin_neon_vgetq_lane_bf16:
   case NEON::BI__builtin_neon_vgetq_lane_f32:
+  case NEON::BI__builtin_neon_vduph_lane_bf16:
+  case NEON::BI__builtin_neon_vduph_laneq_bf16:
     return Builder.CreateExtractElement(Ops[0], Ops[1], "vget_lane");
   case NEON::BI__builtin_neon_vrndns_f32: {
@@ -6894,11 +6904,13 @@ Value *CodeGenFunction::EmitARMBuiltinExpr(unsigned BuiltinID,
   case NEON::BI__builtin_neon_vset_lane_i8:
   case NEON::BI__builtin_neon_vset_lane_i16:
   case NEON::BI__builtin_neon_vset_lane_i32:
   case NEON::BI__builtin_neon_vset_lane_i64:
+  case NEON::BI__builtin_neon_vset_lane_bf16:
   case NEON::BI__builtin_neon_vset_lane_f32:
   case NEON::BI__builtin_neon_vsetq_lane_i8:
   case NEON::BI__builtin_neon_vsetq_lane_i16:
   case NEON::BI__builtin_neon_vsetq_lane_i32:
   case NEON::BI__builtin_neon_vsetq_lane_i64:
+  case NEON::BI__builtin_neon_vsetq_lane_bf16:
   case NEON::BI__builtin_neon_vsetq_lane_f32:
     return Builder.CreateInsertElement(Ops[1], Ops[0], Ops[2], "vset_lane");
@@ -9309,11 +9321,13 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
   case NEON::BI__builtin_neon_vset_lane_i16:
   case NEON::BI__builtin_neon_vset_lane_i32:
   case NEON::BI__builtin_neon_vset_lane_i64:
+  case NEON::BI__builtin_neon_vset_lane_bf16:
   case NEON::BI__builtin_neon_vset_lane_f32:
   case NEON::BI__builtin_neon_vsetq_lane_i8:
   case NEON::BI__builtin_neon_vsetq_lane_i16:
   case NEON::BI__builtin_neon_vsetq_lane_i32:
   case NEON::BI__builtin_neon_vsetq_lane_i64:
+  case NEON::BI__builtin_neon_vsetq_lane_bf16:
   case NEON::BI__builtin_neon_vsetq_lane_f32:
     Ops.push_back(EmitScalarExpr(E->getArg(2)));
     return Builder.CreateInsertElement(Ops[1], Ops[0], Ops[2], "vset_lane");
@@ -9592,11 +9606,13 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
                                         : Intrinsic::aarch64_neon_sqsub;
     return EmitNeonCall(CGM.getIntrinsic(AccInt, Int64Ty), Ops, "vqdmlXl");
   }
+  case NEON::BI__builtin_neon_vget_lane_bf16:
   case NEON::BI__builtin_neon_vduph_lane_bf16:
   case NEON::BI__builtin_neon_vduph_lane_f16: {
     return Builder.CreateExtractElement(Ops[0], EmitScalarExpr(E->getArg(1)),
                                         "vget_lane");
   }
+  case NEON::BI__builtin_neon_vgetq_lane_bf16:
   case NEON::BI__builtin_neon_vduph_laneq_bf16:
   case NEON::BI__builtin_neon_vduph_laneq_f16: {
     return Builder.CreateExtractElement(Ops[0], EmitScalarExpr(E->getArg(1)),
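
Taken together, these cases route the bf16 builtins through the same
direct extractelement/insertelement lowering as the other element types.
For example (a sketch; the expected IR is taken from the updated test
checks below):

  #include <arm_neon.h>

  /* Lowers to a single IR instruction:
   *   %vget_lane = extractelement <8 x bfloat> %v, i32 7
   *   ret bfloat %vget_lane
   */
  bfloat16_t lane7(bfloat16x8_t v) { return vgetq_lane_bf16(v, 7); }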

clang/test/CodeGen/arm-bf16-getset-intrinsics.c

@@ -1,6 +1,8 @@
 // NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py
 // RUN: %clang_cc1 -triple armv8.6a-arm-none-eabi -target-feature +neon -target-feature +bf16 -mfloat-abi hard \
 // RUN:   -disable-O0-optnone -emit-llvm %s -o - | opt -S -mem2reg -instcombine | FileCheck %s
+// RUN: %clang_cc1 -triple armv8.6a-arm-none-eabi -target-feature +neon -target-feature +bf16 -mfloat-abi soft \
+// RUN:   -disable-O0-optnone -emit-llvm %s -o - | opt -S -mem2reg -instcombine | FileCheck %s
 
 #include <arm_neon.h>
@@ -98,8 +100,8 @@ bfloat16x4_t test_vget_low_bf16(bfloat16x8_t a) {
 // CHECK-LABEL: @test_vget_lane_bf16(
 // CHECK-NEXT:  entry:
-// CHECK-NEXT:    [[DOTCAST1:%.*]] = extractelement <4 x bfloat> [[V:%.*]], i32 1
-// CHECK-NEXT:    ret bfloat [[DOTCAST1]]
+// CHECK-NEXT:    [[VGET_LANE:%.*]] = extractelement <4 x bfloat> [[V:%.*]], i32 1
+// CHECK-NEXT:    ret bfloat [[VGET_LANE]]
 //
 bfloat16_t test_vget_lane_bf16(bfloat16x4_t v) {
   return vget_lane_bf16(v, 1);
 }
@@ -107,8 +109,8 @@ bfloat16_t test_vget_lane_bf16(bfloat16x4_t v) {
 // CHECK-LABEL: @test_vgetq_lane_bf16(
 // CHECK-NEXT:  entry:
-// CHECK-NEXT:    [[DOTCAST1:%.*]] = extractelement <8 x bfloat> [[V:%.*]], i32 7
-// CHECK-NEXT:    ret bfloat [[DOTCAST1]]
+// CHECK-NEXT:    [[VGET_LANE:%.*]] = extractelement <8 x bfloat> [[V:%.*]], i32 7
+// CHECK-NEXT:    ret bfloat [[VGET_LANE]]
 //
 bfloat16_t test_vgetq_lane_bf16(bfloat16x8_t v) {
   return vgetq_lane_bf16(v, 7);
 }
@@ -116,8 +118,8 @@ bfloat16_t test_vgetq_lane_bf16(bfloat16x8_t v) {
 // CHECK-LABEL: @test_vset_lane_bf16(
 // CHECK-NEXT:  entry:
-// CHECK-NEXT:    [[TMP0:%.*]] = insertelement <4 x bfloat> [[V:%.*]], bfloat [[A:%.*]], i32 1
-// CHECK-NEXT:    ret <4 x bfloat> [[TMP0]]
+// CHECK-NEXT:    [[VSET_LANE:%.*]] = insertelement <4 x bfloat> [[V:%.*]], bfloat [[A:%.*]], i32 1
+// CHECK-NEXT:    ret <4 x bfloat> [[VSET_LANE]]
 //
 bfloat16x4_t test_vset_lane_bf16(bfloat16_t a, bfloat16x4_t v) {
   return vset_lane_bf16(a, v, 1);
 }
@@ -125,8 +127,8 @@ bfloat16x4_t test_vset_lane_bf16(bfloat16_t a, bfloat16x4_t v) {
 // CHECK-LABEL: @test_vsetq_lane_bf16(
 // CHECK-NEXT:  entry:
-// CHECK-NEXT:    [[TMP0:%.*]] = insertelement <8 x bfloat> [[V:%.*]], bfloat [[A:%.*]], i32 7
-// CHECK-NEXT:    ret <8 x bfloat> [[TMP0]]
+// CHECK-NEXT:    [[VSET_LANE:%.*]] = insertelement <8 x bfloat> [[V:%.*]], bfloat [[A:%.*]], i32 7
+// CHECK-NEXT:    ret <8 x bfloat> [[VSET_LANE]]
 //
 bfloat16x8_t test_vsetq_lane_bf16(bfloat16_t a, bfloat16x8_t v) {
   return vsetq_lane_bf16(a, v, 7);
 }
@@ -143,8 +145,8 @@ bfloat16_t test_vduph_lane_bf16(bfloat16x4_t v) {
 // CHECK-LABEL: @test_vduph_laneq_bf16(
 // CHECK-NEXT:  entry:
-// CHECK-NEXT:    [[VGETQ_LANE:%.*]] = extractelement <8 x bfloat> [[V:%.*]], i32 7
-// CHECK-NEXT:    ret bfloat [[VGETQ_LANE]]
+// CHECK-NEXT:    [[VGET_LANE:%.*]] = extractelement <8 x bfloat> [[V:%.*]], i32 7
+// CHECK-NEXT:    ret bfloat [[VGET_LANE]]
 //
 bfloat16_t test_vduph_laneq_bf16(bfloat16x8_t v) {
   return vduph_laneq_bf16(v, 7);
 }