diff --git a/src/video_core/host_shaders/astc_decoder.comp b/src/video_core/host_shaders/astc_decoder.comp index da21b4bde8..f0358f5330 100644 --- a/src/video_core/host_shaders/astc_decoder.comp +++ b/src/video_core/host_shaders/astc_decoder.comp @@ -83,6 +83,12 @@ int result_index = 0; uint result_vector_max_index; bool result_limit_reached = false; +// avoid intermediate result_vector storage during color decode phase +bool write_color_values = false; +uint color_values_direct[32]; +uint color_out_index = 0; +uint color_num_values = 0; + // EncodingData helpers uint Encoding(EncodingData val) { return bitfieldExtract(val.data, 0, 8); @@ -114,9 +120,110 @@ EncodingData CreateEncodingData(uint encoding, uint num_bits, uint bit_val, uint return EncodingData(((encoding) << 0u) | ((num_bits) << 8u) | ((bit_val) << 16u) | ((quint_trit_val) << 24u)); } +uint ReplicateBitTo9(uint bit); +uint FastReplicateTo8(uint value, uint num_bits); + +void EmitColorValue(EncodingData val) { + // write directly to color_values_direct[] + const uint encoding = Encoding(val); + const uint bitlen = NumBits(val); + const uint bitval = BitValue(val); + + if (encoding == JUST_BITS) { + color_values_direct[++color_out_index] = FastReplicateTo8(bitval, bitlen); + return; + } + + uint A = ReplicateBitTo9((bitval & 1)); + uint B = 0, C = 0, D = QuintTritValue(val); + + if (encoding == TRIT) { + switch (bitlen) { + case 1: + C = 204; + break; + case 2: { + C = 93; + const uint b = (bitval >> 1) & 1; + B = (b << 8) | (b << 4) | (b << 2) | (b << 1); + break; + } + case 3: { + C = 44; + const uint cb = (bitval >> 1) & 3; + B = (cb << 7) | (cb << 2) | cb; + break; + } + case 4: { + C = 22; + const uint dcb = (bitval >> 1) & 7; + B = (dcb << 6) | dcb; + break; + } + case 5: { + C = 11; + const uint edcb = (bitval >> 1) & 0xF; + B = (edcb << 5) | (edcb >> 2); + break; + } + case 6: { + C = 5; + const uint fedcb = (bitval >> 1) & 0x1F; + B = (fedcb << 4) | (fedcb >> 4); + break; + } + } + } else { // QUINT + switch (bitlen) { + case 1: + C = 113; + break; + case 2: { + C = 54; + const uint b = (bitval >> 1) & 1; + B = (b << 8) | (b << 3) | (b << 2); + break; + } + case 3: { + C = 26; + const uint cb = (bitval >> 1) & 3; + B = (cb << 7) | (cb << 1) | (cb >> 1); + break; + } + case 4: { + C = 13; + const uint dcb = (bitval >> 1) & 7; + B = (dcb << 6) | (dcb >> 1); + break; + } + case 5: { + C = 6; + const uint edcb = (bitval >> 1) & 0xF; + B = (edcb << 5) | (edcb >> 3); + break; + } + } + } + + uint T = (D * C) + B; + T ^= A; + T = (A & 0x80) | (T >> 2); + color_values_direct[++color_out_index] = T; +} + void ResultEmplaceBack(EncodingData val) { + if (write_color_values) { + if (color_out_index >= color_num_values) { + // avoid decoding more than needed by this phase + result_limit_reached = true; + return; + } + EmitColorValue(val); + return; + } + if (result_index >= result_vector_max_index) { // Alert callers to avoid decoding more than needed by this phase result_limit_reached = true; @@ -196,33 +303,36 @@ uint Hash52(uint p) { p ^= p >> 17; return p; } +struct PartitionTable { + uint s1, s2, s3, s4, s5, s6, s7, s8; + uint rnum; + bool small_block; +}; -uint Select2DPartition(uint seed, uint x, uint y, uint partition_count) { - if ((block_dims.y * block_dims.x) < 32) { - x <<= 1; - y <<= 1; - } +PartitionTable GetPartitionTable(uint seed, uint partition_count) { + PartitionTable pt; + pt.small_block = (block_dims.y * block_dims.x) < 32; seed += (partition_count - 1) * 1024; + uint rnum = Hash52(uint(seed)); + pt.rnum = rnum; - const uint rnum = Hash52(uint(seed)); - uint seed1 = uint(rnum & 0xF); - uint seed2 = uint((rnum >> 4) & 0xF); - uint seed3 = uint((rnum >> 8) & 0xF); - uint seed4 = uint((rnum >> 12) & 0xF); - uint seed5 = uint((rnum >> 16) & 0xF); - uint seed6 = uint((rnum >> 20) & 0xF); - uint seed7 = uint((rnum >> 24) & 0xF); - uint seed8 = uint((rnum >> 28) & 0xF); - - seed1 = (seed1 * seed1); - seed2 = (seed2 * seed2); - seed3 = (seed3 * seed3); - seed4 = (seed4 * seed4); - seed5 = (seed5 * seed5); - seed6 = (seed6 * seed6); - seed7 = (seed7 * seed7); - seed8 = (seed8 * seed8); + uint seed1 = (rnum & 0xF); + seed1 *= seed1; + uint seed2 = (rnum >> 4) & 0xF; + seed2 *= seed2; + uint seed3 = (rnum >> 8) & 0xF; + seed3 *= seed3; + uint seed4 = (rnum >> 12) & 0xF; + seed4 *= seed4; + uint seed5 = (rnum >> 16) & 0xF; + seed5 *= seed5; + uint seed6 = (rnum >> 20) & 0xF; + seed6 *= seed6; + uint seed7 = (rnum >> 24) & 0xF; + seed7 *= seed7; + uint seed8 = (rnum >> 28) & 0xF; + seed8 *= seed8; uint sh1, sh2; if ((seed & 1) > 0) { @@ -232,31 +342,37 @@ uint Select2DPartition(uint seed, uint x, uint y, uint partition_count) { sh1 = (partition_count == 3) ? 6 : 5; sh2 = (seed & 2) > 0 ? 4 : 5; } - seed1 >>= sh1; - seed2 >>= sh2; - seed3 >>= sh1; - seed4 >>= sh2; - seed5 >>= sh1; - seed6 >>= sh2; - seed7 >>= sh1; - seed8 >>= sh2; - uint a = seed1 * x + seed2 * y + (rnum >> 14); - uint b = seed3 * x + seed4 * y + (rnum >> 10); - uint c = seed5 * x + seed6 * y + (rnum >> 6); - uint d = seed7 * x + seed8 * y + (rnum >> 2); + pt.s1 = seed1 >> sh1; + pt.s2 = seed2 >> sh2; + pt.s3 = seed3 >> sh1; + pt.s4 = seed4 >> sh2; + pt.s5 = seed5 >> sh1; + pt.s6 = seed6 >> sh2; + pt.s7 = seed7 >> sh1; + pt.s8 = seed8 >> sh2; + + return pt; + } + +uint SelectPartition(PartitionTable pt, uint x, uint y, uint partition_count) { + if (pt.small_block) { + x <<= 1; + y <<= 1; + } + + uint a = pt.s1 * x + pt.s2 * y + (pt.rnum >> 14); + uint b = pt.s3 * x + pt.s4 * y + (pt.rnum >> 10); + uint c = pt.s5 * x + pt.s6 * y + (pt.rnum >> 6); + uint d = pt.s7 * x + pt.s8 * y + (pt.rnum >> 2); a &= 0x3F; b &= 0x3F; c &= 0x3F; d &= 0x3F; - if (partition_count < 4) { - d = 0; - } - if (partition_count < 3) { - c = 0; - } + if (partition_count < 4) d = 0; + if (partition_count < 3) c = 0; if (a >= b && a >= c && a >= d) { return 0; @@ -457,7 +573,7 @@ void DecodeIntegerSequence(uint max_range, uint num_values) { } } -void DecodeColorValues(uvec4 modes, uint num_partitions, uint color_data_bits, out uint color_values[32]) { +void DecodeColorValues(uvec4 modes, uint num_partitions, uint color_data_bits) { uint num_values = 0; for (uint i = 0; i < num_partitions; i++) { num_values += ((modes[i] >> 2) + 1) << 1; @@ -471,104 +587,21 @@ void DecodeColorValues(uvec4 modes, uint num_partitions, uint color_data_bits, o break; } } - DecodeIntegerSequence(range - 1, num_values); - uint out_index = 0; - for (int itr = 0; itr < result_index; ++itr) { - if (out_index >= num_values) { - break; - } - const EncodingData val = GetEncodingFromVector(itr); - const uint encoding = Encoding(val); - const uint bitlen = NumBits(val); - const uint bitval = BitValue(val); - uint A = 0, B = 0, C = 0, D = 0; - A = ReplicateBitTo9((bitval & 1)); - switch (encoding) { - case JUST_BITS: - color_values[++out_index] = FastReplicateTo8(bitval, bitlen); - break; - case TRIT: { - D = QuintTritValue(val); - switch (bitlen) { - case 1: - C = 204; - break; - case 2: { - C = 93; - const uint b = (bitval >> 1) & 1; - B = (b << 8) | (b << 4) | (b << 2) | (b << 1); - break; - } - case 3: { - C = 44; - const uint cb = (bitval >> 1) & 3; - B = (cb << 7) | (cb << 2) | cb; - break; - } - case 4: { - C = 22; - const uint dcb = (bitval >> 1) & 7; - B = (dcb << 6) | dcb; - break; - } - case 5: { - C = 11; - const uint edcb = (bitval >> 1) & 0xF; - B = (edcb << 5) | (edcb >> 2); - break; - } - case 6: { - C = 5; - const uint fedcb = (bitval >> 1) & 0x1F; - B = (fedcb << 4) | (fedcb >> 4); - break; - } - } - break; - } - case QUINT: { - D = QuintTritValue(val); - switch (bitlen) { - case 1: - C = 113; - break; - case 2: { - C = 54; - const uint b = (bitval >> 1) & 1; - B = (b << 8) | (b << 3) | (b << 2); - break; - } - case 3: { - C = 26; - const uint cb = (bitval >> 1) & 3; - B = (cb << 7) | (cb << 1) | (cb >> 1); - break; - } - case 4: { - C = 13; - const uint dcb = (bitval >> 1) & 7; - B = (dcb << 6) | (dcb >> 1); - break; - } - case 5: { - C = 6; - const uint edcb = (bitval >> 1) & 0xF; - B = (edcb << 5) | (edcb >> 3); - break; - } - } - break; - } - } - if (encoding != JUST_BITS) { - uint T = (D * C) + B; - T ^= A; - T = (A & 0x80) | (T >> 2); - color_values[++out_index] = T; - } + // Decode directly into color_values_direct[] + write_color_values = true; + color_out_index = 0; + color_num_values = num_values; + for (uint i = 0; i < 32; ++i) { + color_values_direct[i] = 0; } + + DecodeIntegerSequence(range - 1, num_values); + + write_color_values = false; } + + ivec2 BitTransferSigned(int a, int b) { ivec2 transferred; transferred.y = b >> 1; @@ -730,7 +763,7 @@ uint UnquantizeTexelWeight(EncodingData val) { uint encoding = Encoding(val), bitlen = NumBits(val), bitval = BitValue(val); if (encoding == JUST_BITS) { return (bitlen >= 1 && bitlen <= 5) - ? uint(floor(0.5f + float(bitval) * 64.0f / float((1 << bitlen) - 1))) + ? ((bitval * 64) + ((1 << bitlen) - 1) / 2) / ((1 << bitlen) - 1) : FastReplicateTo6(bitval, bitlen); } else if (encoding == TRIT || encoding == QUINT) { uint B = 0, C = 0, D = 0; @@ -1069,13 +1102,12 @@ void DecompressBlock(ivec3 coord) { uvec4 endpoints0[4]; uvec4 endpoints1[4]; { - // This decode phase should at most push 32 elements into the vector - result_vector_max_index = 32; - uint color_values[32]; + // Decode directly into color_values_direct[] (no intermediate result_vector storage) + result_limit_reached = false; uint colvals_index = 0; - DecodeColorValues(color_endpoint_mode, num_partitions, color_data_bits, color_values); + DecodeColorValues(color_endpoint_mode, num_partitions, color_data_bits); for (uint i = 0; i < num_partitions; i++) { - ComputeEndpoints(endpoints0[i], endpoints1[i], color_endpoint_mode[i], color_values, + ComputeEndpoints(endpoints0[i], endpoints1[i], color_endpoint_mode[i], color_values_direct, colvals_index); } } @@ -1106,11 +1138,15 @@ void DecompressBlock(ivec3 coord) { DecodeIntegerSequence(max_weight, GetNumWeightValues(size_params, dual_plane)); UnquantizeTexelWeights(size_params, dual_plane); + PartitionTable pt; + if (num_partitions > 1) { + pt = GetPartitionTable(partition_index, num_partitions); + } for (uint j = 0; j < block_dims.y; j++) { for (uint i = 0; i < block_dims.x; i++) { uint local_partition = 0; if (num_partitions > 1) { - local_partition = Select2DPartition(partition_index, i, j, num_partitions); + local_partition = SelectPartition(pt, i, j, num_partitions); } const uvec4 C0 = ReplicateByteTo16(endpoints0[local_partition]); const uvec4 C1 = ReplicateByteTo16(endpoints1[local_partition]);