IR: fix scalar FMA tied sources

The tied source (the operand whose upper elements are carried into the destination) needs to be modelled explicitly in the IR, or else we lose information when translating.

Signed-off-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
This commit is contained in:
Alyssa Rosenzweig 2024-09-04 07:37:56 -04:00
parent 8fe1e9562d
commit 6d4693cbc1
4 changed files with 31 additions and 25 deletions

View File

@ -332,7 +332,8 @@ private:
using ScalarFMAOpCaller =
std::function<void(ARMEmitter::VRegister Dst, ARMEmitter::VRegister Src1, ARMEmitter::VRegister Src2, ARMEmitter::VRegister Src3)>;
void VFScalarFMAOperation(uint8_t OpSize, uint8_t ElementSize, ScalarFMAOpCaller ScalarEmit, ARMEmitter::VRegister Dst,
ARMEmitter::VRegister Vector1, ARMEmitter::VRegister Vector2, ARMEmitter::VRegister Addend);
ARMEmitter::VRegister Upper, ARMEmitter::VRegister Vector1, ARMEmitter::VRegister Vector2,
ARMEmitter::VRegister Addend);
using ScalarBinaryOpCaller = std::function<void(ARMEmitter::VRegister Dst, ARMEmitter::VRegister Src1, ARMEmitter::VRegister Src2)>;
void VFScalarOperation(uint8_t OpSize, uint8_t ElementSize, bool ZeroUpperBits, ScalarBinaryOpCaller ScalarEmit,
ARMEmitter::VRegister Dst, ARMEmitter::VRegister Vector1, ARMEmitter::VRegister Vector2);

View File

@ -211,11 +211,12 @@ namespace FEXCore::CPU {
}; \
\
const auto Dst = GetVReg(Node); \
const auto Upper = GetVReg(Op->Upper.ID()); \
const auto Vector1 = GetVReg(Op->Vector1.ID()); \
const auto Vector2 = GetVReg(Op->Vector2.ID()); \
const auto Addend = GetVReg(Op->Addend.ID()); \
\
VFScalarFMAOperation(IROp->Size, ElementSize, ScalarEmit, Dst, Vector1, Vector2, Addend); \
VFScalarFMAOperation(IROp->Size, ElementSize, ScalarEmit, Dst, Upper, Vector1, Vector2, Addend); \
}
DEF_UNOP(VAbs, abs, true)
@ -260,17 +261,17 @@ DEF_FMAOP_SCALAR_INSERT(VFNMLAScalarInsert, fmsub)
DEF_FMAOP_SCALAR_INSERT(VFNMLSScalarInsert, fnmadd)
void Arm64JITCore::VFScalarFMAOperation(uint8_t OpSize, uint8_t ElementSize, ScalarFMAOpCaller ScalarEmit, ARMEmitter::VRegister Dst,
ARMEmitter::VRegister Vector1, ARMEmitter::VRegister Vector2, ARMEmitter::VRegister Addend) {
ARMEmitter::VRegister Upper, ARMEmitter::VRegister Vector1, ARMEmitter::VRegister Vector2,
ARMEmitter::VRegister Addend) {
LOGMAN_THROW_A_FMT(OpSize == Core::CPUState::XMM_SSE_REG_SIZE, "256-bit unsupported", __func__);
LOGMAN_THROW_AA_FMT(ElementSize == 2 || ElementSize == 4 || ElementSize == 8, "Invalid size");
const auto SubRegSize = ARMEmitter::ToVectorSizePair(ElementSize == 2 ? ARMEmitter::SubRegSize::i16Bit :
ElementSize == 4 ? ARMEmitter::SubRegSize::i32Bit :
ARMEmitter::SubRegSize::i64Bit);
if (Dst != Vector1 && Dst != Vector2 && Dst != Addend && HostSupportsAFP) {
// If destination doesnt overlap any incoming register then move the adder to the destination first.
mov(Dst.Q(), Addend.Q());
Dst = Addend;
if (Dst != Upper) {
// If destination is not tied, move the upper bits to the destination first.
mov(Dst.Q(), Upper.Q());
}
if (HostSupportsAFP && Dst == Addend) {
@ -278,7 +279,7 @@ void Arm64JITCore::VFScalarFMAOperation(uint8_t OpSize, uint8_t ElementSize, Sca
// If the host CPU supports AFP then scalar does an insert without modifying upper bits.
ScalarEmit(Dst, Vector1, Vector2, Addend);
} else {
// No overlap between addr and destination or host doesn't support AFP, need to emit in to a temporary then insert.
// Host doesn't support AFP, or the destination isn't the addend register; emit into a temporary then insert.
ScalarEmit(VTMP1, Vector1, Vector2, Addend);
ins(SubRegSize.Vector, Dst.Q(), 0, VTMP1.Q(), 0);
}

View File

@ -2486,14 +2486,14 @@ void OpDispatchBuilder::AVX128_VFMAScalarImpl(OpcodeArgs, IROps IROp, uint8_t Sr
const OpSize ElementSize = Op->Flags & X86Tables::DecodeFlags::FLAG_OPTION_AVX_W ? OpSize::i64Bit : OpSize::i32Bit;
auto Dest = AVX128_LoadSource_WithOpSize(Op, Op->Dest, Op->Flags, !Is128Bit);
auto Src1 = AVX128_LoadSource_WithOpSize(Op, Op->Src[0], Op->Flags, !Is128Bit);
auto Src2 = AVX128_LoadSource_WithOpSize(Op, Op->Src[1], Op->Flags, !Is128Bit);
auto Dest = AVX128_LoadSource_WithOpSize(Op, Op->Dest, Op->Flags, !Is128Bit).Low;
auto Src1 = AVX128_LoadSource_WithOpSize(Op, Op->Src[0], Op->Flags, !Is128Bit).Low;
auto Src2 = AVX128_LoadSource_WithOpSize(Op, Op->Src[1], Op->Flags, !Is128Bit).Low;
RefPair Sources[3] = {Dest, Src1, Src2};
Ref Sources[3] = {Dest, Src1, Src2};
DeriveOp(Result_Low, IROp,
_VFMLAScalarInsert(OpSize::i128Bit, ElementSize, Sources[Src1Idx - 1].Low, Sources[Src2Idx - 1].Low, Sources[AddendIdx - 1].Low));
_VFMLAScalarInsert(OpSize::i128Bit, ElementSize, Dest, Sources[Src1Idx - 1], Sources[Src2Idx - 1], Sources[AddendIdx - 1]));
AVX128_StoreResult_WithOpSize(Op, Op->Dest, AVX128_Zext(Result_Low));
}

View File

@ -1790,41 +1790,45 @@
"DestSize": "RegisterSize",
"NumElements": "RegisterSize / ElementSize"
},
"FPR = VFMLAScalarInsert u8:#RegisterSize, u8:#ElementSize, FPR:$Vector1, FPR:$Vector2, FPR:$Addend": {
"FPR = VFMLAScalarInsert u8:#RegisterSize, u8:#ElementSize, FPR:$Upper, FPR:$Vector1, FPR:$Vector2, FPR:$Addend": {
"Desc": [
"Dest = (Vector1 * Vector2) + Addend",
"This explicitly matches x86 FMA semantics because ARM semantics are mind-bending."
"This explicitly matches x86 FMA semantics because ARM semantics are mind-bending.",
"Upper elements copied from Upper"
],
"DestSize": "RegisterSize",
"NumElements": "RegisterSize / ElementSize",
"TiedSource": 2
"TiedSource": 0
},
"FPR = VFMLSScalarInsert u8:#RegisterSize, u8:#ElementSize, FPR:$Vector1, FPR:$Vector2, FPR:$Addend": {
"FPR = VFMLSScalarInsert u8:#RegisterSize, u8:#ElementSize, FPR:$Upper, FPR:$Vector1, FPR:$Vector2, FPR:$Addend": {
"Desc": [
"Dest = (Vector1 * Vector2) - Addend",
"This explicitly matches x86 FMA semantics because ARM semantics are mind-bending."
"This explicitly matches x86 FMA semantics because ARM semantics are mind-bending.",
"Upper elements copied from Upper"
],
"DestSize": "RegisterSize",
"NumElements": "RegisterSize / ElementSize",
"TiedSource": 2
"TiedSource": 0
},
"FPR = VFNMLAScalarInsert u8:#RegisterSize, u8:#ElementSize, FPR:$Vector1, FPR:$Vector2, FPR:$Addend": {
"FPR = VFNMLAScalarInsert u8:#RegisterSize, u8:#ElementSize, FPR:$Upper, FPR:$Vector1, FPR:$Vector2, FPR:$Addend": {
"Desc": [
"Dest = (-Vector1 * Vector2) + Addend",
"This explicitly matches x86 FMA semantics because ARM semantics are mind-bending."
"This explicitly matches x86 FMA semantics because ARM semantics are mind-bending.",
"Upper elements copied from Upper"
],
"DestSize": "RegisterSize",
"NumElements": "RegisterSize / ElementSize",
"TiedSource": 2
"TiedSource": 0
},
"FPR = VFNMLSScalarInsert u8:#RegisterSize, u8:#ElementSize, FPR:$Vector1, FPR:$Vector2, FPR:$Addend": {
"FPR = VFNMLSScalarInsert u8:#RegisterSize, u8:#ElementSize, FPR:$Upper, FPR:$Vector1, FPR:$Vector2, FPR:$Addend": {
"Desc": [
"Dest = (-Vector1 * Vector2) - Addend",
"This explicitly matches x86 FMA semantics because ARM semantics are mind-bending."
"This explicitly matches x86 FMA semantics because ARM semantics are mind-bending.",
"Upper elements copied from Upper"
],
"DestSize": "RegisterSize",
"NumElements": "RegisterSize / ElementSize",
"TiedSource": 2
"TiedSource": 0
}
},
"Vector": {