JIT: rewrite pdep implementation

- use a better algorithm that is O(# set bits) instead of O(# total bits)
- eliminate spilling by careful management of our temporaries
- fix NZCV clobber bug (whoops)

Signed-off-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Alyssa Rosenzweig 2024-01-23 18:30:08 -04:00
parent c0be974272
commit 04805f351b

@@ -713,64 +713,58 @@ DEF_OP(PDep) {
   LOGMAN_THROW_AA_FMT(OpSize == 4 || OpSize == 8, "Unsupported {} size: {}", __func__, OpSize);
   const auto EmitSize = OpSize == 8 ? ARMEmitter::Size::i64Bit : ARMEmitter::Size::i32Bit;
 
-  const auto Input = GetReg(Op->Input.ID());
-  const auto Mask = GetReg(Op->Mask.ID());
-  const auto Dest = GetReg(Node);
-
-  const auto ShiftedBitReg = TMP1.R();
-  const auto BitReg = TMP2.R();
-  const auto SubMaskReg = TMP3.R();
-  const auto IndexReg = TMP4.R();
-  const auto ZeroReg = ARMEmitter::Reg::zr;
-
-  const auto InputReg = StaticRegisters[0];
-  const auto MaskReg = StaticRegisters[1];
-  const auto DestReg = StaticRegisters[2];
-
-  const auto SpillCode = 1U << InputReg.Idx() |
-                         1U << MaskReg.Idx() |
-                         1U << DestReg.Idx();
-
-  ARMEmitter::SingleUseForwardLabel EarlyExit;
+  // PDep implementation follows the ideas from
+  // http://0x80.pl/articles/pdep-soft-emu.html ... Basically, iterate the *set*
+  // bits only, which will be faster than the naive implementation as long as
+  // there are enough holes in the mask.
+  //
+  // The specific arm64 assembly used is based on the sequence that clang
+  // generates for the C code, giving context to the scheduling yielding better
+  // ILP than I would do by hand. The registers are allocated by hand however,
+  // to fit within the tight constraints we have here without spilling. Also, we
+  // use cbz/cbnz for conditional branching to avoid clobbering NZCV.
+
+  // We can't clobber these
+  const auto OrigInput = GetReg(Op->Input.ID());
+  const auto OrigMask = GetReg(Op->Mask.ID());
+
+  // So we have shadow copies as temporaries
+  const auto Input = TMP1.R();
+  const auto Mask = TMP2.R();
+
+  // These get used variously as scratch
+  const auto T0 = TMP3.R();
+  const auto T1 = TMP4.R();
+
+  const auto Dest = GetReg(Node);
+
   ARMEmitter::BackwardLabel NextBit;
   ARMEmitter::SingleUseForwardLabel Done;
 
-  cbz(EmitSize, Mask, &EarlyExit);
-  mov(EmitSize, IndexReg, ZeroReg);
-
-  // We sadly need to spill regs for this for the time being
-  // TODO: Remove when scratch registers can be allocated
-  // explicitly.
-  SpillStaticRegs(TMP1, false, SpillCode);
-  mov(EmitSize, InputReg, Input);
-  mov(EmitSize, MaskReg, Mask);
-  mov(EmitSize, DestReg, ZeroReg);
+  // First, copy the input/mask, since we'll be clobbering them. Copy as 64-bit
+  // to make this 0-uop on Firestorm.
+  mov(ARMEmitter::Size::i64Bit, Input, OrigInput);
+  mov(ARMEmitter::Size::i64Bit, Mask, OrigMask);
+
+  // Now they're copied, so we can start setting Dest (even if it overlaps with
+  // one of them). Handle the early exit case.
+  mov(EmitSize, Dest, 0);
+  cbz(EmitSize, OrigMask, &Done);
+
+  // Setup for first iteration
+  neg(EmitSize, T0, Mask);
+  and_(EmitSize, T0, T0, Mask);
 
   // Main loop
   Bind(&NextBit);
-  rbit(EmitSize, ShiftedBitReg, MaskReg);
-  clz(EmitSize, ShiftedBitReg, ShiftedBitReg);
-  lsrv(EmitSize, BitReg, InputReg, IndexReg);
-  and_(EmitSize, BitReg, BitReg, 1);
-  sub(EmitSize, SubMaskReg, MaskReg, 1);
-  add(EmitSize, IndexReg, IndexReg, 1);
-  ands(EmitSize, MaskReg, MaskReg, SubMaskReg);
-  lslv(EmitSize, ShiftedBitReg, BitReg, ShiftedBitReg);
-  orr(EmitSize, DestReg, DestReg, ShiftedBitReg);
-  b(ARMEmitter::Condition::CC_NE, &NextBit);
-
-  // Store the result in a temp so it doesn't get clobbered,
-  // and restore it after the re-fill below.
-  mov(EmitSize, IndexReg, DestReg);
-
-  // Restore our registers before leaving
-  // TODO: Also remove along with above TODO.
-  FillStaticRegs(false, SpillCode);
-  mov(EmitSize, Dest, IndexReg);
-  b(&Done);
-
-  // Early exit
-  Bind(&EarlyExit);
-  mov(EmitSize, Dest, ZeroReg);
+  sbfx(EmitSize, T1, Input, 0, 1);
+  eor(EmitSize, Mask, Mask, T0);
+  and_(EmitSize, T0, T1, T0);
+  neg(EmitSize, T1, Mask);
+  orr(EmitSize, Dest, Dest, T0);
+  lsr(EmitSize, Input, Input, 1);
+  and_(EmitSize, T0, Mask, T1);
+  cbnz(EmitSize, T0, &NextBit);
 
   // All done with nothing to do.
   Bind(&Done);
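
For reference, the new assembly implements the set-bit-iteration loop described in the linked write-up (http://0x80.pl/articles/pdep-soft-emu.html). Below is a minimal C++ sketch of that idea, not code from the FEX tree; the function name pdep_soft and the test values are illustrative only.

#include <cassert>
#include <cstdint>

// Software pdep: deposit the low bits of `input`, in order, at the positions
// of the set bits of `mask`. One loop iteration per set bit of the mask,
// hence O(# set bits) rather than O(# total bits).
// Illustrative sketch only; not the FEX implementation.
static uint64_t pdep_soft(uint64_t input, uint64_t mask) {
  uint64_t dest = 0;
  while (mask != 0) {
    uint64_t lowest = mask & -mask;  // isolate the lowest set bit (the neg/and pair in the asm)
    if (input & 1) {
      dest |= lowest;                // deposit the next input bit at that position
    }
    mask ^= lowest;                  // clear that mask bit (the eor in the asm)
    input >>= 1;                     // advance to the next input bit
  }
  return dest;
}

int main() {
  // Input bits 0b101 land at the mask's set-bit positions 1, 3, and 4.
  assert(pdep_soft(0b101, 0b11010) == 0b10010);
  // A zero mask deposits nothing, matching the early-exit path above.
  assert(pdep_soft(0xdeadbeef, 0) == 0);
  return 0;
}

The conditional in the sketch corresponds to the branchless sbfx/and pair in the emitted sequence, and the mask update maps onto the eor plus the neg/and computation of the next lowest set bit, whose result feeds the flag-free cbnz loop test.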