Allow ROUND_CUR_DIRECTION|ROUND_NO_EXC (12) as an SAE argument on X86.

Clang: for builtins without rounding control, accept 12 in addition to 4
(ROUND_CUR_DIRECTION) and 8 (ROUND_NO_EXC). LLVM: isRoundModeSAE accepts
NO_EXC either alone or combined with CUR_DIRECTION.

diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 7d17b0ecd49a..c43f656c2f42 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -3546,9 +3546,11 @@ bool Sema::CheckX86BuiltinRoundingOrSAE(unsigned BuiltinID, CallExpr *TheCall) {
 
   // Make sure rounding mode is either ROUND_CUR_DIRECTION or ROUND_NO_EXC bit
   // is set. If the intrinsic has rounding control(bits 1:0), make sure its only
-  // combined with ROUND_NO_EXC.
+  // combined with ROUND_NO_EXC. If the intrinsic does not have rounding
+  // control, allow ROUND_NO_EXC and ROUND_CUR_DIRECTION together.
   if (Result == 4/*ROUND_CUR_DIRECTION*/ ||
       Result == 8/*ROUND_NO_EXC*/ ||
+      (!HasRC && Result == 12/*ROUND_CUR_DIRECTION|ROUND_NO_EXC*/) ||
       (HasRC && Result.getZExtValue() >= 8 && Result.getZExtValue() <= 11))
     return false;
 
diff --git a/clang/test/Sema/builtins-x86.c b/clang/test/Sema/builtins-x86.c
index 6a2a47d7792c..dca0bdc720a0 100644
--- a/clang/test/Sema/builtins-x86.c
+++ b/clang/test/Sema/builtins-x86.c
@@ -81,6 +81,19 @@ __mmask16 test__builtin_ia32_cmpps512_mask_rounding(__m512 __a, __m512 __b, __mm
   return __builtin_ia32_cmpps512_mask(__a, __b, 0, __u, 0); // expected-error {{invalid rounding argument}}
 }
 
+// Make sure we allow 4(CUR_DIRECTION), 8(NO_EXC), and 12(CUR_DIRECTION|NO_EXC) for SAE arguments.
+__mmask16 test__builtin_ia32_cmpps512_mask_rounding_cur_dir(__m512 __a, __m512 __b, __mmask16 __u) {
+  return __builtin_ia32_cmpps512_mask(__a, __b, 0, __u, 4); // no-error
+}
+
+__mmask16 test__builtin_ia32_cmpps512_mask_rounding_sae1(__m512 __a, __m512 __b, __mmask16 __u) {
+  return __builtin_ia32_cmpps512_mask(__a, __b, 0, __u, 8); // no-error
+}
+
+__mmask16 test__builtin_ia32_cmpps512_mask_rounding_sae2(__m512 __a, __m512 __b, __mmask16 __u) {
+  return __builtin_ia32_cmpps512_mask(__a, __b, 0, __u, 12); // no-error
+}
+
 __m512 test__builtin_ia32_getmantps512_mask(__m512 a, __m512 b) {
   return __builtin_ia32_getmantps512_mask(a, 0, b, (__mmask16)-1, 10); // expected-error {{invalid rounding argument}}
 }
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 0a8219214f46..2195f40c247a 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -22706,8 +22706,16 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
     return false;
   };
   auto isRoundModeSAE = [](SDValue Rnd) {
-    if (auto *C = dyn_cast<ConstantSDNode>(Rnd))
-      return C->getAPIntValue() == X86::STATIC_ROUNDING::NO_EXC;
+    if (auto *C = dyn_cast<ConstantSDNode>(Rnd)) {
+      unsigned RC = C->getZExtValue();
+      if (RC & X86::STATIC_ROUNDING::NO_EXC) {
+        // Clear the NO_EXC bit and check remaining bits.
+        RC ^= X86::STATIC_ROUNDING::NO_EXC;
+        // As a convenience we allow no other bits or explicitly
+        // current direction.
+        return RC == 0 || RC == X86::STATIC_ROUNDING::CUR_DIRECTION;
+      }
+    }
 
     return false;
   };
diff --git a/llvm/test/CodeGen/X86/avx512-intrinsics.ll b/llvm/test/CodeGen/X86/avx512-intrinsics.ll
index 6f0aba31cf15..b2d6ce4dfc9d 100644
--- a/llvm/test/CodeGen/X86/avx512-intrinsics.ll
+++ b/llvm/test/CodeGen/X86/avx512-intrinsics.ll
@@ -755,7 +755,7 @@ define <8 x double> @test_getexp_round_pd_512(<8 x double> %a0) {
 ; CHECK: # %bb.0:
 ; CHECK-NEXT: vgetexppd {sae}, %zmm0, %zmm0
 ; CHECK-NEXT: ret{{[l|q]}}
-  %res = call <8 x double> @llvm.x86.avx512.mask.getexp.pd.512(<8 x double> %a0, <8 x double> zeroinitializer, i8 -1, i32 8)
+  %res = call <8 x double> @llvm.x86.avx512.mask.getexp.pd.512(<8 x double> %a0, <8 x double> zeroinitializer, i8 -1, i32 12)
   ret <8 x double> %res
 }
 declare <8 x double> @llvm.x86.avx512.mask.getexp.pd.512(<8 x double>, <8 x double>, i8, i32) nounwind readnone