implement 32-bit SMIN + SMINP (+sminp u16)

This commit is contained in:
lizzie
2026-01-12 10:32:19 +00:00
committed by crueter
parent b3725ff014
commit 1834477a67
2 changed files with 231 additions and 85 deletions

View File

@@ -2875,42 +2875,6 @@ static void EmitVectorPairedMinMax16(BlockOfCode& code, EmitContext& ctx, IR::In
ctx.reg_alloc.DefineValue(code, inst, x);
}
// Emits the "paired min/max over the lower halves" operation for 16-bit lanes
// using the SSE4.1-era instruction passed in `fn` (pminsw/pmaxsw/pminuw/pmaxuw).
// Inputs are two 64-bit vectors of four halfwords each; the result's low qword
// is {fn(a0,a1), fn(a2,a3), fn(b0,b1), fn(b2,b3)} and the upper qword is zeroed
// (insertps's zero-mask clears dwords 2 and 3).
template<typename Function>
static void EmitVectorPairedMinMaxLower16(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, Function fn) {
    auto args = ctx.reg_alloc.GetArgumentInfo(inst);
    auto const x = ctx.reg_alloc.UseScratchXmm(code, args[0]);
    auto const y = ctx.reg_alloc.UseScratchXmm(code, args[1]);
    auto const tmp = ctx.reg_alloc.ScratchXmm(code);
    // swap words 1 and 2 so each low qword becomes {v0, v2, v1, v3}:
    // dword0 now holds the even-indexed elements, dword1 the odd-indexed ones
    code.pshuflw(x, x, 0b11'01'10'00);
    code.pshuflw(y, y, 0b11'01'10'00);
    // tmp = {x[1], x[3], y[1], y[3], 0...}: copy y, insert x's dword1 into
    // tmp's dword0, zero the upper qword (note: tmp receives the ODD-indexed
    // pairs and x below the EVEN-indexed ones; min/max is commutative so the
    // pairwise result is the same either way)
    code.movaps(tmp, y);
    code.insertps(tmp, x, 0b01001100);
    // x = {x[0], x[2], y[0], y[2], 0...}
    code.insertps(x, y, 0b00011100);
    // lane-wise fn of the even half against the odd half = pairwise result
    (code.*fn)(x, tmp);
    ctx.reg_alloc.DefineValue(code, inst, x);
}
// Emits the "paired min/max over the lower halves" operation for 32-bit lanes
// using the member-function pointer `fn` (pminsd/pmaxsd/pminud/pmaxud, SSE4.1).
// Inputs are two 64-bit vectors of two dwords each; the result's low qword is
// {fn(a0,a1), fn(b0,b1)} and the upper qword is zeroed by insertps's zero-mask.
static void EmitVectorPairedMinMaxLower32(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, void (Xbyak::CodeGenerator::*fn)(Xbyak::Xmm const&, const Xbyak::Operand&)) {
    auto args = ctx.reg_alloc.GetArgumentInfo(inst);
    auto const x = ctx.reg_alloc.UseScratchXmm(code, args[0]);
    auto const y = ctx.reg_alloc.UseXmm(code, args[1]);
    auto const tmp = ctx.reg_alloc.ScratchXmm(code);
    // tmp = {x[1], y[1], 0, 0}: copy y, insert x's dword1 into dword0, zero dwords 2-3
    code.movaps(tmp, y);
    code.insertps(tmp, x, 0b01001100);
    // x = {x[0], y[0], 0, 0}: insert y's dword0 into x's dword1, zero dwords 2-3
    code.insertps(x, y, 0b00011100);
    // lane-wise fn of the element-0s against the element-1s = pairwise result
    (code.*fn)(x, tmp);
    ctx.reg_alloc.DefineValue(code, inst, x);
}
void EmitX64::EmitVectorPairedMaxS8(EmitContext& ctx, IR::Inst* inst) {
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
@@ -3098,7 +3062,21 @@ void EmitX64::EmitVectorPairedMaxLowerS8(EmitContext& ctx, IR::Inst* inst) {
void EmitX64::EmitVectorPairedMaxLowerS16(EmitContext& ctx, IR::Inst* inst) {
if (code.HasHostFeature(HostFeature::SSE41)) {
EmitVectorPairedMinMaxLower16(code, ctx, inst, &Xbyak::CodeGenerator::pmaxsw);
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
auto const x = ctx.reg_alloc.UseScratchXmm(code, args[0]);
auto const y = ctx.reg_alloc.UseScratchXmm(code, args[1]);
auto const tmp = ctx.reg_alloc.ScratchXmm(code);
// swap idxs 1 and 2 so that both registers contain even then odd-indexed pairs of elements
code.pshuflw(x, x, 0b11'01'10'00);
code.pshuflw(y, y, 0b11'01'10'00);
// move pairs of even/odd-indexed elements into one register each
// tmp = x[0, 2], y[0, 2], 0s...
code.movaps(tmp, y);
code.insertps(tmp, x, 0b01001100);
// x = x[1, 3], y[1, 3], 0s...
code.insertps(x, y, 0b00011100);
code.pmaxsw(x, tmp);
ctx.reg_alloc.DefineValue(code, inst, x);
} else {
EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<s16>& result, const VectorArray<s16>& a, const VectorArray<s16>& b) {
LowerPairedMax(result, a, b);
@@ -3108,13 +3086,22 @@ void EmitX64::EmitVectorPairedMaxLowerS16(EmitContext& ctx, IR::Inst* inst) {
// Paired signed 32-bit maximum over the lower halves of two vectors:
// result's low qword = {max(a0,a1), max(b0,b1)}, upper qword zeroed.
// Fix: the block contained diff residue — a stale call to the removed
// EmitVectorPairedMinMaxLower32 helper followed by `return;` (which made the
// inline SSE4.1 sequence unreachable) and a second, unconditional fallback
// after the if/else (which would emit the op and DefineValue twice on the
// non-SSE4.1 path). Both removed.
void EmitX64::EmitVectorPairedMaxLowerS32(EmitContext& ctx, IR::Inst* inst) {
    if (code.HasHostFeature(HostFeature::SSE41)) {
        auto args = ctx.reg_alloc.GetArgumentInfo(inst);
        auto const x = ctx.reg_alloc.UseScratchXmm(code, args[0]);
        auto const y = ctx.reg_alloc.UseXmm(code, args[1]);
        auto const tmp = ctx.reg_alloc.ScratchXmm(code);
        // tmp = {x[1], y[1], 0, 0}
        code.movaps(tmp, y);
        code.insertps(tmp, x, 0b01001100);
        // x = {x[0], y[0], 0, 0}
        code.insertps(x, y, 0b00011100);
        code.pmaxsd(x, tmp);  // SSE4.1 lane-wise signed max
        ctx.reg_alloc.DefineValue(code, inst, x);
    } else {
        // No SSE4.1: interpreted fallback.
        EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<s32>& result, const VectorArray<s32>& a, const VectorArray<s32>& b) {
            LowerPairedMax(result, a, b);
        });
    }
}
void EmitX64::EmitVectorPairedMaxLowerU8(EmitContext& ctx, IR::Inst* inst) {
@@ -3130,24 +3117,46 @@ void EmitX64::EmitVectorPairedMaxLowerU8(EmitContext& ctx, IR::Inst* inst) {
// Paired unsigned 16-bit maximum over the lower halves of two vectors:
// result's low qword = {max(a0,a1), max(a2,a3), max(b0,b1), max(b2,b3)},
// upper qword zeroed.
// Fix: removed diff residue — a stale call to the removed
// EmitVectorPairedMinMaxLower16 helper plus `return;` (dead inline path) and a
// duplicated unconditional fallback after the if/else (double emission).
void EmitX64::EmitVectorPairedMaxLowerU16(EmitContext& ctx, IR::Inst* inst) {
    if (code.HasHostFeature(HostFeature::SSE41)) {
        auto args = ctx.reg_alloc.GetArgumentInfo(inst);
        auto const x = ctx.reg_alloc.UseScratchXmm(code, args[0]);
        auto const y = ctx.reg_alloc.UseScratchXmm(code, args[1]);
        auto const tmp = ctx.reg_alloc.ScratchXmm(code);
        // swap words 1 and 2 so each low qword is {v0, v2, v1, v3}:
        // dword0 = even-indexed elements, dword1 = odd-indexed elements
        code.pshuflw(x, x, 0b11'01'10'00);
        code.pshuflw(y, y, 0b11'01'10'00);
        // tmp = {x[1], x[3], y[1], y[3], 0...} (odd-indexed pairs)
        code.movaps(tmp, y);
        code.insertps(tmp, x, 0b01001100);
        // x = {x[0], x[2], y[0], y[2], 0...} (even-indexed pairs)
        code.insertps(x, y, 0b00011100);
        code.pmaxuw(x, tmp);  // SSE4.1 lane-wise unsigned max
        ctx.reg_alloc.DefineValue(code, inst, x);
    } else {
        // No SSE4.1: interpreted fallback.
        EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<u16>& result, const VectorArray<u16>& a, const VectorArray<u16>& b) {
            LowerPairedMax(result, a, b);
        });
    }
}
// Paired unsigned 32-bit maximum over the lower halves of two vectors:
// result's low qword = {max(a0,a1), max(b0,b1)}, upper qword zeroed.
// Fix: removed diff residue — stale helper call plus `return;` (dead inline
// path) and a duplicated unconditional fallback after the if/else.
void EmitX64::EmitVectorPairedMaxLowerU32(EmitContext& ctx, IR::Inst* inst) {
    if (code.HasHostFeature(HostFeature::SSE41)) {
        auto args = ctx.reg_alloc.GetArgumentInfo(inst);
        auto const x = ctx.reg_alloc.UseScratchXmm(code, args[0]);
        auto const y = ctx.reg_alloc.UseXmm(code, args[1]);
        auto const tmp = ctx.reg_alloc.ScratchXmm(code);
        // tmp = {x[1], y[1], 0, 0}
        code.movaps(tmp, y);
        code.insertps(tmp, x, 0b01001100);
        // x = {x[0], y[0], 0, 0}
        code.insertps(x, y, 0b00011100);
        code.pmaxud(x, tmp);  // SSE4.1 lane-wise unsigned max
        ctx.reg_alloc.DefineValue(code, inst, x);
    } else {
        // No SSE4.1: interpreted fallback.
        EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<u32>& result, const VectorArray<u32>& a, const VectorArray<u32>& b) {
            LowerPairedMax(result, a, b);
        });
    }
}
void EmitX64::EmitVectorPairedMinLowerS8(EmitContext& ctx, IR::Inst* inst) {
@@ -3167,24 +3176,55 @@ void EmitX64::EmitVectorPairedMinLowerS8(EmitContext& ctx, IR::Inst* inst) {
// Paired signed 16-bit minimum over the lower halves of two vectors:
// result's low qword = {min(a0,a1), min(a2,a3), min(b0,b1), min(b2,b3)},
// upper qword zeroed.
// Fix: removed diff residue — stale EmitVectorPairedMinMaxLower16 call plus
// `return;` (dead inline path) and a duplicated unconditional fallback after
// the if/else.
void EmitX64::EmitVectorPairedMinLowerS16(EmitContext& ctx, IR::Inst* inst) {
    if (code.HasHostFeature(HostFeature::SSE41)) {
        auto args = ctx.reg_alloc.GetArgumentInfo(inst);
        auto const x = ctx.reg_alloc.UseScratchXmm(code, args[0]);
        auto const y = ctx.reg_alloc.UseScratchXmm(code, args[1]);
        auto const tmp = ctx.reg_alloc.ScratchXmm(code);
        // swap words 1 and 2 so each low qword is {v0, v2, v1, v3}:
        // dword0 = even-indexed elements, dword1 = odd-indexed elements
        code.pshuflw(x, x, 0b11'01'10'00);
        code.pshuflw(y, y, 0b11'01'10'00);
        // tmp = {x[1], x[3], y[1], y[3], 0...} (odd-indexed pairs)
        code.movaps(tmp, y);
        code.insertps(tmp, x, 0b01001100);
        // x = {x[0], x[2], y[0], y[2], 0...} (even-indexed pairs)
        code.insertps(x, y, 0b00011100);
        code.pminsw(x, tmp);  // lane-wise signed min
        ctx.reg_alloc.DefineValue(code, inst, x);
    } else {
        // No SSE4.1: interpreted fallback.
        EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<s16>& result, const VectorArray<s16>& a, const VectorArray<s16>& b) {
            LowerPairedMin(result, a, b);
        });
    }
}
// Paired signed 32-bit minimum over the lower halves of two vectors:
// result's low qword = {min(a0,a1), min(b0,b1)}, upper qword zeroed.
// Fix: removed diff residue — stale EmitVectorPairedMinMaxLower32 call plus
// `return;` (dead inline path) and an unconditional fallback after the if/else
// that would have emitted the operation a second time.
void EmitX64::EmitVectorPairedMinLowerS32(EmitContext& ctx, IR::Inst* inst) {
    auto args = ctx.reg_alloc.GetArgumentInfo(inst);
    if (code.HasHostFeature(HostFeature::SSE41)) {
        auto const x = ctx.reg_alloc.UseScratchXmm(code, args[0]);
        auto const y = ctx.reg_alloc.UseXmm(code, args[1]);
        auto const tmp = ctx.reg_alloc.ScratchXmm(code);
        // tmp = {x[1], y[1], 0, 0}
        code.movaps(tmp, y);
        code.insertps(tmp, x, 0b01001100);
        // x = {x[0], y[0], 0, 0}
        code.insertps(x, y, 0b00011100);
        code.pminsd(x, tmp);  // SSE4.1 lane-wise signed min
        ctx.reg_alloc.DefineValue(code, inst, x);
    } else {
        // SSE2 path: interleave, split element-0s from element-1s, then signed
        // min via compare-and-blend.
        auto const tmp0 = ctx.reg_alloc.UseScratchXmm(code, args[0]);
        auto const tmp1 = ctx.reg_alloc.UseScratchXmm(code, args[1]);
        auto const tmp2 = ctx.reg_alloc.ScratchXmm(code);
        code.punpckldq(tmp0, tmp1);    // tmp0 = {a0, b0, a1, b1}
        code.pshufd(tmp1, tmp0, 238);  // tmp1 = {a1, b1, a1, b1} (0xEE)
        code.movdqa(tmp2, tmp0);
        code.pcmpgtd(tmp2, tmp1);      // mask = tmp0 > tmp1 (signed)
        code.pand(tmp1, tmp2);         // take tmp1 where tmp0 was greater
        code.pandn(tmp2, tmp0);        // take tmp0 where tmp0 <= tmp1
        code.por(tmp2, tmp1);          // lane-wise signed min
        code.movq(tmp0, tmp2);         // keep low qword, zero the upper
        ctx.reg_alloc.DefineValue(code, inst, tmp0);
    }
}
void EmitX64::EmitVectorPairedMinLowerU8(EmitContext& ctx, IR::Inst* inst) {
@@ -3199,39 +3239,80 @@ void EmitX64::EmitVectorPairedMinLowerU8(EmitContext& ctx, IR::Inst* inst) {
}
// Paired unsigned 16-bit minimum over the lower halves of two vectors:
// result's low qword = {min(a0,a1), min(a2,a3), min(b0,b1), min(b2,b3)},
// upper qword zeroed.
// Fix: removed diff residue — stale EmitVectorPairedMinMaxLower16 call plus
// `return;` (dead inline path) and an unconditional fallback after the if/else
// that would have emitted the operation a second time.
void EmitX64::EmitVectorPairedMinLowerU16(EmitContext& ctx, IR::Inst* inst) {
    auto args = ctx.reg_alloc.GetArgumentInfo(inst);
    if (code.HasHostFeature(HostFeature::SSE41)) {
        auto const x = ctx.reg_alloc.UseScratchXmm(code, args[0]);
        auto const y = ctx.reg_alloc.UseScratchXmm(code, args[1]);
        auto const tmp = ctx.reg_alloc.ScratchXmm(code);
        // swap words 1 and 2 so each low qword is {v0, v2, v1, v3}:
        // dword0 = even-indexed elements, dword1 = odd-indexed elements
        code.pshuflw(x, x, 0b11'01'10'00);
        code.pshuflw(y, y, 0b11'01'10'00);
        // tmp = {x[1], x[3], y[1], y[3], 0...} (odd-indexed pairs)
        code.movaps(tmp, y);
        code.insertps(tmp, x, 0b01001100);
        // x = {x[0], x[2], y[0], y[2], 0...} (even-indexed pairs)
        code.insertps(x, y, 0b00011100);
        code.pminuw(x, tmp);  // SSE4.1 lane-wise unsigned min
        ctx.reg_alloc.DefineValue(code, inst, x);
    } else {
        // SSE2 path: interleave, shuffle evens/odds apart, then unsigned min
        // via the saturating-subtract identity min(x, y) = x - sat(x - y).
        auto const tmp0 = ctx.reg_alloc.UseScratchXmm(code, args[0]);
        auto const tmp1 = ctx.reg_alloc.UseScratchXmm(code, args[1]);
        auto const tmp2 = ctx.reg_alloc.ScratchXmm(code);
        code.punpcklwd(tmp0, tmp1);      // tmp0 = {a0,b0,a1,b1,a2,b2,a3,b3}
        code.pshufd(tmp1, tmp0, 231);
        code.pshuflw(tmp1, tmp1, 114);   // tmp1 low = {a1, a3, b1, b3} (odds)
        code.pshufd(tmp0, tmp0, 232);
        code.pshuflw(tmp0, tmp0, 216);   // tmp0 low = {a0, a2, b0, b2} (evens)
        code.movdqa(tmp2, tmp1);
        code.psubusw(tmp2, tmp0);        // tmp2 = sat(tmp1 - tmp0)
        code.psubw(tmp1, tmp2);          // tmp1 = unsigned min(tmp1, tmp0)
        code.movq(tmp0, tmp1);           // keep low qword, zero the upper
        ctx.reg_alloc.DefineValue(code, inst, tmp0);
    }
}
// Paired unsigned 32-bit minimum over the lower halves of two vectors:
// result's low qword = {min(a0,a1), min(b0,b1)}, upper qword zeroed.
// Fix: removed diff residue — stale EmitVectorPairedMinMaxLower32 call plus
// `return;` (dead inline path) and an unconditional fallback after the if/else
// that would have emitted the operation a second time.
void EmitX64::EmitVectorPairedMinLowerU32(EmitContext& ctx, IR::Inst* inst) {
    auto args = ctx.reg_alloc.GetArgumentInfo(inst);
    if (code.HasHostFeature(HostFeature::SSE41)) {
        auto const x = ctx.reg_alloc.UseScratchXmm(code, args[0]);
        auto const y = ctx.reg_alloc.UseXmm(code, args[1]);
        auto const tmp = ctx.reg_alloc.ScratchXmm(code);
        // tmp = {x[1], y[1], 0, 0}
        code.movaps(tmp, y);
        code.insertps(tmp, x, 0b01001100);
        // x = {x[0], y[0], 0, 0}
        code.insertps(x, y, 0b00011100);
        code.pminud(x, tmp);  // SSE4.1 lane-wise unsigned min
        ctx.reg_alloc.DefineValue(code, inst, x);
    } else {
        // SSE2 path: no unsigned dword compare exists, so bias both operands by
        // 0x80000000 and use the signed compare, then blend.
        auto const tmp0 = ctx.reg_alloc.UseScratchXmm(code, args[0]);
        auto const tmp1 = ctx.reg_alloc.UseScratchXmm(code, args[1]);
        auto const tmp2 = ctx.reg_alloc.ScratchXmm(code);
        auto const tmp3 = ctx.reg_alloc.ScratchXmm(code);
        code.punpckldq(tmp0, tmp1);    // tmp0 = {a0, b0, a1, b1}
        code.pshufd(tmp1, tmp0, 238);  // tmp1 = {a1, b1, a1, b1} (0xEE)
        code.movdqa(tmp2, code.Const(xword, 0x8000'00008000'0000, 0x8000'00008000'0000));
        code.movdqa(tmp3, tmp0);
        code.pxor(tmp3, tmp2);         // tmp3 = tmp0 ^ sign-bias
        code.pxor(tmp2, tmp1);         // tmp2 = tmp1 ^ sign-bias
        code.pcmpgtd(tmp3, tmp2);      // mask = tmp0 > tmp1 (unsigned, via bias)
        code.pand(tmp1, tmp3);         // take tmp1 where tmp0 was greater
        code.pandn(tmp3, tmp0);        // take tmp0 where tmp0 <= tmp1
        code.por(tmp3, tmp1);          // lane-wise unsigned min
        code.movq(tmp0, tmp3);         // keep low qword, zero the upper
        ctx.reg_alloc.DefineValue(code, inst, tmp0);
    }
}
// Carry-less (GF(2)) polynomial multiplication: for every set bit i of lhs,
// XOR (rhs << i) into the result. D is the (wider) result type, T the operand
// type; bit_size is the width of T.
// Fix: the block contained a merged duplicate of the loop (an inner copy that
// shadowed `i` and re-ran the full range for every set outer bit, XORing terms
// multiple times and producing a wrong product). Restored the single loop.
template<typename D, typename T>
static D PolynomialMultiply(T lhs, T rhs) {
    constexpr size_t bit_size = mcl::bitsizeof<T>;
    const std::bitset<bit_size> operand(lhs);
    D res = 0;
    for (size_t i = 0; i < bit_size; i++) {
        if (operand[i]) {
            res ^= rhs << i;
        }
    }
    return res;
}

View File

@@ -448,6 +448,71 @@ TEST_CASE("A64: SQSHLU", "[a64]") {
CHECK(jit.GetVector(15) == Vector{0, 0x705cd04});
}
// Exercises the A64 vector SMIN (element-wise signed minimum) instruction for
// 8-, 16- and 32-bit lane widths against precomputed expected vectors.
TEST_CASE("A64: SMIN", "[a64]") {
    A64TestEnv env;
    A64::UserConfig jit_user_config{};
    jit_user_config.callbacks = &env;
    A64::Jit jit{jit_user_config};
    oaknut::VectorCodeGenerator code{env.code_mem, nullptr};
    code.SMIN(V8.B16(), V0.B16(), V3.B16());
    code.SMIN(V9.H8(), V1.H8(), V2.H8());
    code.SMIN(V10.S4(), V2.S4(), V3.S4());
    // NOTE(review): V11-V15 are computed but never REQUIREd below — either add
    // checks for them or drop these instructions.
    code.SMIN(V11.S4(), V3.S4(), V3.S4());
    code.SMIN(V12.S4(), V0.S4(), V3.S4());
    code.SMIN(V13.S4(), V1.S4(), V2.S4());
    code.SMIN(V14.S4(), V2.S4(), V1.S4());
    code.SMIN(V15.S4(), V3.S4(), V0.S4());
    jit.SetPC(0);
    // Inputs chosen to mix negative, positive and boundary lane values.
    jit.SetVector(0, Vector{0xffffffff'18ba6a6a, 0x7fffffff'943b954f});
    jit.SetVector(1, Vector{0x0000000b'0000000f, 0xffffffff'ffffffff});
    jit.SetVector(2, Vector{0x00000001'000000ff, 0x00000010'0000007f});
    jit.SetVector(3, Vector{0xffffffff'ffffffff, 0x96dc5c14'0705cd04});
    // NOTE(review): 8 instructions are emitted but ticks_left is 4 — confirm
    // the test environment executes the whole block regardless of tick budget.
    env.ticks_left = 4;
    CheckedRun([&]() { jit.Run(); });
    REQUIRE(jit.GetVector(8) == Vector{0xffffffffffbaffff, 0x96dcffff94059504});
    REQUIRE(jit.GetVector(9) == Vector{0x10000000f, 0xffffffffffffffff});
    REQUIRE(jit.GetVector(10) == Vector{0xffffffffffffffff, 0x96dc5c140000007f});
}
// Exercises the A64 vector SMINP (pairwise signed minimum: adjacent-element
// minima of the first operand fill the low half of the result, those of the
// second operand the high half) for 8-, 16- and 32-bit lanes, including the
// operand orderings and the same-register case (V11).
TEST_CASE("A64: SMINP", "[a64]") {
    A64TestEnv env;
    A64::UserConfig jit_user_config{};
    jit_user_config.callbacks = &env;
    A64::Jit jit{jit_user_config};
    oaknut::VectorCodeGenerator code{env.code_mem, nullptr};
    code.SMINP(V8.B16(), V0.B16(), V3.B16());
    code.SMINP(V9.H8(), V1.H8(), V2.H8());
    code.SMINP(V10.S4(), V2.S4(), V1.S4());
    code.SMINP(V11.S4(), V3.S4(), V3.S4());
    code.SMINP(V12.S4(), V0.S4(), V3.S4());
    code.SMINP(V13.S4(), V1.S4(), V2.S4());
    code.SMINP(V14.S4(), V2.S4(), V1.S4());
    code.SMINP(V15.S4(), V3.S4(), V0.S4());
    jit.SetPC(0);
    // Inputs chosen to mix negative, positive and boundary lane values.
    jit.SetVector(0, Vector{0xffffffff'18ba6a6a, 0x7fffffff'943b954f});
    jit.SetVector(1, Vector{0x0000000b'0000000f, 0xffffffff'ffffffff});
    jit.SetVector(2, Vector{0x00000001'000000ff, 0x00000010'0000007f});
    jit.SetVector(3, Vector{0xffffffff'ffffffff, 0x96dc5c14'0705cd04});
    // NOTE(review): 8 instructions are emitted but ticks_left is 4 — confirm
    // the test environment executes the whole block regardless of tick budget.
    env.ticks_left = 4;
    CheckedRun([&]() { jit.Run(); });
    REQUIRE(jit.GetVector(8) == Vector{0xffff9495ffffba6a, 0x961405cdffffffff});
    REQUIRE(jit.GetVector(9) == Vector{0xffffffff00000000, 0});
    REQUIRE(jit.GetVector(10) == Vector{0x1000000001, 0xffffffff0000000b});
    REQUIRE(jit.GetVector(11) == Vector{0x96dc5c14ffffffff, 0x96dc5c14ffffffff});
    REQUIRE(jit.GetVector(12) == Vector{0x943b954fffffffff, 0x96dc5c14ffffffff});
    REQUIRE(jit.GetVector(13) == Vector{0xffffffff0000000b, 0x1000000001});
    REQUIRE(jit.GetVector(14) == Vector{0x1000000001, 0xffffffff0000000b});
    REQUIRE(jit.GetVector(15) == Vector{0x96dc5c14ffffffff, 0x943b954fffffffff});
}
TEST_CASE("A64: XTN", "[a64]") {
A64TestEnv env;
A64::UserConfig jit_user_config{};