shader: Improve signed integer normalization (#3254)

This commit is contained in:
Macdu 2024-03-23 23:05:18 +01:00 committed by GitHub
parent 2424f92c2e
commit e671af683a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 56 additions and 95 deletions

View File

@ -59,7 +59,7 @@ spv::Id make_vector_or_scalar_type(spv::Builder &b, spv::Id component, int size)
spv::Id unwrap_type(spv::Builder &b, spv::Id type);
spv::Id convert_to_float(spv::Builder &b, spv::Id opr, DataType type, bool normal);
spv::Id convert_to_float(spv::Builder &b, const SpirvUtilFunctions &utils, spv::Id opr, DataType type, bool normal);
spv::Id convert_to_int(spv::Builder &b, const SpirvUtilFunctions &utils, spv::Id opr, DataType type, bool normal);
spv::Id add_uvec2_uint(spv::Builder &b, spv::Id vec, spv::Id to_add);

View File

@ -1388,7 +1388,7 @@ static spv::Function *make_frag_finalize_function(spv::Builder &b, const SpirvSh
color = utils::load(b, parameters, utils, features, color_val_operand, 0xF, reg_off);
if (!is_float_data_type(color_val_operand.type))
color = utils::convert_to_float(b, color, color_val_operand.type, true);
color = utils::convert_to_float(b, utils, color, color_val_operand.type, true);
if (program.is_frag_color_used() && features.should_use_shader_interlock()) {
spv::Id signed_i32 = b.makeIntType(32);

View File

@ -999,10 +999,10 @@ bool USSETranslatorVisitor::sop2(
spv::Id src1_alpha = load(inst.opr.src1, 0b1000, src1_repeat_offset);
spv::Id src2_alpha = load(inst.opr.src2, 0b1000, src2_repeat_offset);
src1_color = utils::convert_to_float(m_b, src1_color, DataType::UINT8, true);
src2_color = utils::convert_to_float(m_b, src2_color, DataType::UINT8, true);
src1_alpha = utils::convert_to_float(m_b, src1_alpha, DataType::UINT8, true);
src2_alpha = utils::convert_to_float(m_b, src2_alpha, DataType::UINT8, true);
src1_color = utils::convert_to_float(m_b, m_util_funcs, src1_color, DataType::UINT8, true);
src2_color = utils::convert_to_float(m_b, m_util_funcs, src2_color, DataType::UINT8, true);
src1_alpha = utils::convert_to_float(m_b, m_util_funcs, src1_alpha, DataType::UINT8, true);
src2_alpha = utils::convert_to_float(m_b, m_util_funcs, src2_alpha, DataType::UINT8, true);
spv::Id src_color_type = m_b.getTypeId(src1_color);
spv::Id src_alpha_type = m_b.getTypeId(src1_alpha);
@ -1191,8 +1191,8 @@ bool shader::usse::USSETranslatorVisitor::sop2m(Imm2 pred,
spv::Id src1 = load(inst.opr.src1, 0b1111, 0);
spv::Id src2 = load(inst.opr.src2, 0b1111, 0);
src1 = utils::convert_to_float(m_b, src1, DataType::UINT8, true);
src2 = utils::convert_to_float(m_b, src2, DataType::UINT8, true);
src1 = utils::convert_to_float(m_b, m_util_funcs, src1, DataType::UINT8, true);
src2 = utils::convert_to_float(m_b, m_util_funcs, src2, DataType::UINT8, true);
spv::Id src_type = m_b.getTypeId(src1);
@ -1442,12 +1442,12 @@ bool shader::usse::USSETranslatorVisitor::sop3(Imm2 pred,
spv::Id src1_alpha = load(inst.opr.src1, 0b1000);
spv::Id src2_alpha = load(inst.opr.src2, 0b1000);
src0_color = utils::convert_to_float(m_b, src0_color, DataType::UINT8, true);
src1_color = utils::convert_to_float(m_b, src1_color, DataType::UINT8, true);
src2_color = utils::convert_to_float(m_b, src2_color, DataType::UINT8, true);
src0_alpha = utils::convert_to_float(m_b, src0_alpha, DataType::UINT8, true);
src1_alpha = utils::convert_to_float(m_b, src1_alpha, DataType::UINT8, true);
src2_alpha = utils::convert_to_float(m_b, src2_alpha, DataType::UINT8, true);
src0_color = utils::convert_to_float(m_b, m_util_funcs, src0_color, DataType::UINT8, true);
src1_color = utils::convert_to_float(m_b, m_util_funcs, src1_color, DataType::UINT8, true);
src2_color = utils::convert_to_float(m_b, m_util_funcs, src2_color, DataType::UINT8, true);
src0_alpha = utils::convert_to_float(m_b, m_util_funcs, src0_alpha, DataType::UINT8, true);
src1_alpha = utils::convert_to_float(m_b, m_util_funcs, src1_alpha, DataType::UINT8, true);
src2_alpha = utils::convert_to_float(m_b, m_util_funcs, src2_alpha, DataType::UINT8, true);
spv::Id src_color_type = m_b.getTypeId(src0_color);
spv::Id src_alpha_type = m_b.getTypeId(src0_alpha);

View File

@ -512,7 +512,7 @@ bool USSETranslatorVisitor::vpck(
// source is int destination is float
if (is_float_data_type(inst.opr.dest.type) && !is_float_data_type(inst.opr.src1.type)) {
source = utils::convert_to_float(m_b, source, inst.opr.src1.type, scale);
source = utils::convert_to_float(m_b, m_util_funcs, source, inst.opr.src1.type, scale);
}
// source is float destination is int

View File

@ -172,6 +172,7 @@ static spv::Function *make_fx10_unpack_func(spv::Builder &b, const SpirvUtilFunc
spv::Id type_i32 = b.makeIntType(32);
spv::Id ivec3 = b.makeVectorType(type_i32, 3);
spv::Id uvec3 = b.makeVectorType(b.makeUintType(32), 3);
spv::Id type_f32 = b.makeFloatType(32);
spv::Id type_f32_v3 = b.makeVectorType(type_f32, 3);
@ -181,25 +182,25 @@ static spv::Function *make_fx10_unpack_func(spv::Builder &b, const SpirvUtilFunc
spv::Id extracted = fx10_unpack_func->getParamId(0);
// Cast to uint first
// Cast to int first
extracted = b.createUnaryOp(spv::OpBitcast, type_i32, extracted);
spv::Id vec = b.createCompositeConstruct(ivec3, { extracted, extracted, extracted });
// vec = vec >> ivec3(0,10,20);
// vec = vec >> uvec3(0,10,20);
// note: note entirely sure, I really hope the layout is the same as in a 32-bit little-endian integer
const spv::Id shift_amount = b.createCompositeConstruct(ivec3, { b.makeIntConstant(0), b.makeIntConstant(10), b.makeIntConstant(20) });
const spv::Id shift_amount = b.makeCompositeConstant(uvec3, { b.makeUintConstant(0), b.makeUintConstant(10), b.makeUintConstant(20) });
vec = b.createBinOp(spv::OpShiftRightLogical, ivec3, vec, shift_amount);
// sign-extend the 10-bit integer:
// vec <<= 22 (logical)
// vec >>= 22 (arithmetic)
spv::Id extend_amount = b.makeIntConstant(22);
extend_amount = b.createCompositeConstruct(ivec3, { extend_amount, extend_amount, extend_amount });
spv::Id extend_amount = b.makeUintConstant(22);
extend_amount = b.makeCompositeConstant(uvec3, { extend_amount, extend_amount, extend_amount });
vec = b.createBinOp(spv::OpShiftLeftLogical, ivec3, vec, extend_amount);
vec = b.createBinOp(spv::OpShiftRightArithmetic, ivec3, vec, extend_amount);
// normalize it
vec = convert_to_float(b, vec, DataType::C10, true);
vec = convert_to_float(b, utils, vec, DataType::C10, true);
b.makeReturn(false, vec);
b.setBuildPoint(last_build_point);
@ -261,39 +262,22 @@ static spv::Function *make_unpack_func(spv::Builder &b, const FeatureState &feat
decorations, &unpack_func_block);
spv::Id extracted = unpack_func->getParamId(0);
extracted = b.createUnaryOp(spv::OpBitcast, is_signed ? type_i32 : type_ui32, extracted);
std::vector<spv::Id> comps;
const spv::Id result_type = is_signed ? type_i32 : type_ui32;
extracted = b.createUnaryOp(spv::OpBitcast, result_type, extracted);
const auto comp_bits = 32 / comp_count;
spv::Id comp_bits_val = b.makeUintConstant(comp_bits);
std::vector<spv::Id> comps;
for (int i = 0; i < comp_count; ++i) {
spv::Id comp;
if (is_signed) {
comp = b.createTriOp(spv::OpBitFieldSExtract, type_i32, extracted, b.makeIntConstant(comp_bits * i), b.makeIntConstant(comp_bits));
} else {
comp = b.createTriOp(spv::OpBitFieldUExtract, type_ui32, extracted, b.makeIntConstant(comp_bits * i), b.makeIntConstant(comp_bits));
}
const spv::Op op = is_signed ? spv::OpBitFieldSExtract : spv::OpBitFieldUExtract;
spv::Id comp = b.createTriOp(op, result_type, extracted, b.makeUintConstant(comp_bits * i), comp_bits_val);
comps.push_back(comp);
}
auto output = b.createCompositeConstruct(output_type, comps);
if (is_signed) {
// Sign extended them. Thanks kd-11 for method.
spv::Id sign_check_vec_type = b.makeVectorType(b.makeBoolType(), comp_count);
std::vector<std::uint32_t> constants(comp_count, b.makeIntConstant(1 << (comp_bits - 1)));
std::vector<std::uint32_t> constant_bias(comp_count, b.makeIntConstant(1 << comp_bits));
spv::Id sign_check_vec = b.createBinOp(spv::OpSLessThan, sign_check_vec_type, output,
b.makeCompositeConstant(output_type, constants));
spv::Id bias_vec = b.makeCompositeConstant(output_type, constant_bias);
output = b.createTriOp(spv::OpSelect, output_type, sign_check_vec, output, b.createBinOp(spv::OpISub, output_type, output, bias_vec));
}
b.makeReturn(false, output);
b.setBuildPoint(last_build_point);
@ -358,15 +342,10 @@ static spv::Function *make_pack_func(spv::Builder &b, const FeatureState &featur
const spv::Id comp_type = b.getContainedTypeId(input_type);
auto output = b.makeUintConstant(0);
auto output = is_signed ? b.makeIntConstant(0) : b.makeUintConstant(0);
for (int i = 0; i < comp_count; ++i) {
auto comp = b.createBinOp(spv::OpVectorExtractDynamic, comp_type, extracted, b.makeIntConstant(i));
if (is_signed) {
comp = b.createUnaryOp(spv::OpBitcast, type_ui32, comp);
}
output = b.createOp(spv::OpBitFieldInsert, type_ui32, { output, comp, b.makeIntConstant(comp_bits * i), b.makeIntConstant(comp_bits) });
spv::Id comp = b.createBinOp(spv::OpVectorExtractDynamic, comp_type, extracted, b.makeIntConstant(i));
output = b.createOp(spv::OpBitFieldInsert, comp_type, { output, comp, b.makeIntConstant(comp_bits * i), b.makeIntConstant(comp_bits) });
}
output = b.createUnaryOp(spv::OpBitcast, type_f32, output);
@ -1442,27 +1421,26 @@ spv::Id unwrap_type(spv::Builder &b, spv::Id type) {
return type;
}
// will break in 32-bit host
static std::pair<float, float> get_int_normalize_range_constants(DataType type) {
static float get_int_normalize_range_constants(DataType type) {
switch (type) {
case DataType::UINT8:
return { 0.0f, 255.0f };
return 255.0f;
case DataType::INT8:
return { 128.0f, 127.0f };
return 127.0f;
case DataType::C10:
// signed 10-bit
return { 512.0f, 511.0f };
// signed 10-bit, with a range of [-2, 2]
return 255.0f;
case DataType::UINT16:
return { 0.0f, 65535.0f };
return 65535.0f;
case DataType::INT16:
return { 32768.0f, 32767.0f };
return 32767.0f;
case DataType::UINT32:
return { 0.0f, 4294967295.0f };
return 4294967295.0f;
case DataType::INT32:
return { 2147483648.0f, 2147483647.0f };
return 2147483647.0f;
default:
assert(false);
return { 0.0f, 0.0f };
return 0.0f;
}
}
@ -1477,7 +1455,7 @@ static spv::Id create_constant_vector_or_scalar(spv::Builder &b, spv::Id constan
return b.createCompositeConstruct(b.makeVectorType(b.getTypeId(constant), comp_count), oprs);
}
spv::Id convert_to_float(spv::Builder &b, spv::Id opr, DataType type, bool normal) {
spv::Id convert_to_float(spv::Builder &b, const SpirvUtilFunctions &utils, spv::Id opr, DataType type, bool normal) {
const auto spv_type = unwrap_type(b, b.getTypeId(opr));
const auto comp_count = b.isVector(opr) ? b.getNumComponents(opr) : 1;
const auto target_type = b.isVector(opr) ? b.makeVectorType(b.makeFloatType(32), comp_count) : b.makeFloatType(32);
@ -1492,22 +1470,15 @@ spv::Id convert_to_float(spv::Builder &b, spv::Id opr, DataType type, bool norma
}
if (normal) {
const auto constant_range = get_int_normalize_range_constants(type);
const auto normalizer = b.makeFloatConstant(constant_range.second);
const auto normalizer_vec = create_constant_vector_or_scalar(b, normalizer, comp_count);
const float normalizer = b.makeFloatConstant(get_int_normalize_range_constants(type));
const spv::Id normalizer_vec = create_constant_vector_or_scalar(b, normalizer, comp_count);
opr = b.createBinOp(spv::OpFDiv, target_type, opr, normalizer_vec);
if (is_sint) {
const auto zero_vec = create_constant_vector_or_scalar(b, b.makeFloatConstant(0.0f), comp_count);
const auto b_vec_type = make_vector_or_scalar_type(b, b.makeBoolType(), comp_count);
const auto normalizer_neg = b.makeFloatConstant(constant_range.first);
const auto normalize_vec_neg = create_constant_vector_or_scalar(b, normalizer_neg, comp_count);
opr = b.createTriOp(spv::OpSelect, target_type, b.createBinOp(spv::OpFOrdLessThan, b_vec_type, opr, zero_vec),
b.createBinOp(spv::OpFDiv, target_type, opr, normalize_vec_neg),
b.createBinOp(spv::OpFDiv, target_type, opr, normalizer_vec));
} else {
opr = b.createBinOp(spv::OpFDiv, target_type, opr, normalizer_vec);
// opr = max(-1.0f, opr) (or -2.0f for fx10)
float lower_bound = type == DataType::C10 ? -2.f : -1.f;
const spv::Id minus1 = create_constant_vector_or_scalar(b, b.makeFloatConstant(lower_bound), comp_count);
opr = b.createBuiltinCall(target_type, utils.std_builtins, GLSLstd450FMax, { opr, minus1 });
}
}
return opr;
@ -1523,26 +1494,16 @@ spv::Id convert_to_int(spv::Builder &b, const SpirvUtilFunctions &utils, spv::Id
const auto target_type = b.isVector(opr) ? b.makeVectorType(target_comp_type, comp_count) : target_comp_type;
if (normal) {
const auto constant_range = get_int_normalize_range_constants(type);
const auto normalizer = b.makeFloatConstant(constant_range.second);
const float constant_range = get_int_normalize_range_constants(type);
const spv::Id normalizer = b.makeFloatConstant(constant_range);
const auto normalizer_vec = create_constant_vector_or_scalar(b, normalizer, comp_count);
const auto range_begin_vec = create_constant_vector_or_scalar(b, b.makeFloatConstant(is_uint ? 0.f : -1.f), comp_count);
const auto range_end_vec = create_constant_vector_or_scalar(b, b.makeFloatConstant(1.f), comp_count);
const bool is_fx10 = type == DataType::C10; // fx10 range is [-2,2]
const auto range_begin_vec = create_constant_vector_or_scalar(b, b.makeFloatConstant(is_uint ? 0.f : (is_fx10 ? -2.f : -1.f)), comp_count);
const auto range_end_vec = create_constant_vector_or_scalar(b, b.makeFloatConstant(is_fx10 ? 2.f : 1.f), comp_count);
// opr = round(clamp(opr * norm), -1, 1)
opr = b.createBuiltinCall(opr_type, utils.std_builtins, GLSLstd450FClamp, { opr, range_begin_vec, range_end_vec });
if (is_uint) {
opr = b.createBinOp(spv::OpFMul, opr_type, opr, normalizer_vec);
} else {
const auto zero_vec = create_constant_vector_or_scalar(b, b.makeFloatConstant(0.f), comp_count);
const auto b_vec_type = make_vector_or_scalar_type(b, b.makeBoolType(), comp_count);
const auto normalizer_neg = b.makeFloatConstant(constant_range.first);
const auto normalize_vec_neg = create_constant_vector_or_scalar(b, normalizer_neg, comp_count);
opr = b.createTriOp(spv::OpSelect, opr_type, b.createBinOp(spv::OpFOrdLessThan, b_vec_type, opr, zero_vec),
b.createBinOp(spv::OpFMul, opr_type, opr, normalize_vec_neg),
b.createBinOp(spv::OpFMul, opr_type, opr, normalizer_vec));
}
opr = b.createBinOp(spv::OpFMul, opr_type, opr, normalizer_vec);
opr = b.createBuiltinCall(opr_type, utils.std_builtins, GLSLstd450Round, { opr });
}