mirror of
https://gitee.com/openharmony/third_party_spirv-tools
synced 2024-11-23 07:20:28 +00:00
Remove MarkV and Stats code. (#2576)
* Remove MarkV and Stats code. This Cl removes the MarkV and Stats code from SPIRV-Tools. This code was unused and currently un-maintained.
This commit is contained in:
parent
3b5ab540ca
commit
42abaa099a
@ -59,7 +59,7 @@ build:
|
||||
|
||||
build_script:
|
||||
- mkdir build && cd build
|
||||
- cmake -GNinja -DSPIRV_BUILD_COMPRESSION=ON -DCMAKE_BUILD_TYPE=%CONFIGURATION% -DCMAKE_INSTALL_PREFIX=install -DRE2_BUILD_TESTING=OFF ..
|
||||
- cmake -GNinja -DCMAKE_BUILD_TYPE=%CONFIGURATION% -DCMAKE_INSTALL_PREFIX=install -DRE2_BUILD_TESTING=OFF ..
|
||||
- ninja install
|
||||
|
||||
test_script:
|
||||
|
@ -13,7 +13,6 @@ SPVTOOLS_SRC_FILES := \
|
||||
source/ext_inst.cpp \
|
||||
source/enum_string_mapping.cpp \
|
||||
source/extensions.cpp \
|
||||
source/id_descriptor.cpp \
|
||||
source/libspirv.cpp \
|
||||
source/name_mapper.cpp \
|
||||
source/opcode.cpp \
|
||||
|
@ -69,6 +69,10 @@ if(NOT ${SKIP_SPIRV_TOOLS_INSTALL})
|
||||
endif()
|
||||
|
||||
option(SPIRV_BUILD_COMPRESSION "Build SPIR-V compressing codec" OFF)
|
||||
if(SPIRV_BUILD_COMPRESSION)
|
||||
message(FATAL_ERROR "SPIR-V compression codec has been removed from SPIR-V tools. "
|
||||
"Please remove SPIRV_BUILD_COMPRESSION from your build options.")
|
||||
endif(SPIRV_BUILD_COMPRESSION)
|
||||
|
||||
option(SPIRV_WERROR "Enable error on warning" ON)
|
||||
if(("${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU") OR (("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") AND (NOT CMAKE_CXX_SIMULATE_ID STREQUAL "MSVC")))
|
||||
@ -257,9 +261,6 @@ endif()
|
||||
|
||||
set(SPIRV_LIBRARIES "-lSPIRV-Tools -lSPIRV-Tools-link -lSPIRV-Tools-opt")
|
||||
set(SPIRV_SHARED_LIBRARIES "-lSPIRV-Tools-shared")
|
||||
if(SPIRV_BUILD_COMPRESSION)
|
||||
set(SPIRV_LIBRARIES "${SPIRV_LIBRARIES} -lSPIRV-Tools-comp")
|
||||
endif(SPIRV_BUILD_COMPRESSION)
|
||||
|
||||
# Build pkg-config file
|
||||
# Use a first-class target so it's regenerated when relevant files are updated.
|
||||
|
@ -307,8 +307,6 @@ The following CMake options are supported:
|
||||
the command line tools. This will prevent the tests from being built.
|
||||
* `SPIRV_SKIP_EXECUTABLES={ON|OFF}`, default `OFF`- Build only the library, not
|
||||
the command line tools and tests.
|
||||
* `SPIRV_BUILD_COMPRESSION={ON|OFF}`, default `OFF`- Build SPIR-V compressing
|
||||
codec.
|
||||
* `SPIRV_USE_SANITIZER=<sanitizer>`, default is no sanitizing - On UNIX
|
||||
platforms with an appropriate version of `clang` this option enables the use
|
||||
of the sanitizers documented [here][clang-sanitizers].
|
||||
|
@ -44,7 +44,7 @@ mkdir build && cd $SRC/build
|
||||
# Invoke the build.
|
||||
BUILD_SHA=${KOKORO_GITHUB_COMMIT:-$KOKORO_GITHUB_PULL_REQUEST_COMMIT}
|
||||
echo $(date): Starting build...
|
||||
cmake -DCMAKE_BUILD_TYPE=Release -DANDROID_NATIVE_API_LEVEL=android-14 -DANDROID_ABI="armeabi-v7a with NEON" -DSPIRV_BUILD_COMPRESSION=ON -DSPIRV_SKIP_TESTS=ON -DCMAKE_TOOLCHAIN_FILE=$TOOLCHAIN_PATH -GNinja -DANDROID_NDK=$ANDROID_NDK ..
|
||||
cmake -DCMAKE_BUILD_TYPE=Release -DANDROID_NATIVE_API_LEVEL=android-14 -DANDROID_ABI="armeabi-v7a with NEON" -DSPIRV_SKIP_TESTS=ON -DCMAKE_TOOLCHAIN_FILE=$TOOLCHAIN_PATH -GNinja -DANDROID_NDK=$ANDROID_NDK ..
|
||||
|
||||
echo $(date): Build everything...
|
||||
ninja
|
||||
|
@ -63,7 +63,7 @@ if "%KOKORO_GITHUB_COMMIT%." == "." (
|
||||
set BUILD_SHA=%KOKORO_GITHUB_COMMIT%
|
||||
)
|
||||
|
||||
set CMAKE_FLAGS=-DCMAKE_INSTALL_PREFIX=%KOKORO_ARTIFACTS_DIR%\install -GNinja -DSPIRV_BUILD_COMPRESSION=ON -DCMAKE_BUILD_TYPE=%BUILD_TYPE% -DRE2_BUILD_TESTING=OFF -DCMAKE_C_COMPILER=cl.exe -DCMAKE_CXX_COMPILER=cl.exe
|
||||
set CMAKE_FLAGS=-DCMAKE_INSTALL_PREFIX=%KOKORO_ARTIFACTS_DIR%\install -GNinja -DCMAKE_BUILD_TYPE=%BUILD_TYPE% -DRE2_BUILD_TESTING=OFF -DCMAKE_C_COMPILER=cl.exe -DCMAKE_CXX_COMPILER=cl.exe
|
||||
|
||||
:: Skip building tests for VS2013
|
||||
if %VS_VERSION% == 2013 (
|
||||
|
@ -196,7 +196,6 @@ set_source_files_properties(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/pch_source.cpp
|
||||
PROPERTIES OBJECT_DEPENDS "${PCH_DEPENDS}")
|
||||
|
||||
add_subdirectory(comp)
|
||||
add_subdirectory(opt)
|
||||
add_subdirectory(reduce)
|
||||
add_subdirectory(link)
|
||||
@ -221,7 +220,6 @@ set(SPIRV_SOURCES
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/enum_string_mapping.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ext_inst.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extensions.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/id_descriptor.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/instruction.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/latest_version_glsl_std_450_header.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/latest_version_opencl_std_header.h
|
||||
@ -254,7 +252,6 @@ set(SPIRV_SOURCES
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/enum_string_mapping.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ext_inst.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extensions.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/id_descriptor.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/libspirv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/name_mapper.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/opcode.cpp
|
||||
|
@ -1,52 +0,0 @@
|
||||
# Copyright (c) 2017 Google Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
if(SPIRV_BUILD_COMPRESSION)
|
||||
add_library(SPIRV-Tools-comp
|
||||
bit_stream.cpp
|
||||
bit_stream.h
|
||||
huffman_codec.h
|
||||
markv_codec.cpp
|
||||
markv_codec.h
|
||||
markv.cpp
|
||||
markv.h
|
||||
markv_decoder.cpp
|
||||
markv_decoder.h
|
||||
markv_encoder.cpp
|
||||
markv_encoder.h
|
||||
markv_logger.h
|
||||
move_to_front.h
|
||||
move_to_front.cpp)
|
||||
|
||||
spvtools_default_compile_options(SPIRV-Tools-comp)
|
||||
target_include_directories(SPIRV-Tools-comp
|
||||
PUBLIC ${spirv-tools_SOURCE_DIR}/include
|
||||
PUBLIC ${SPIRV_HEADER_INCLUDE_DIR}
|
||||
PRIVATE ${spirv-tools_BINARY_DIR}
|
||||
)
|
||||
|
||||
target_link_libraries(SPIRV-Tools-comp
|
||||
PUBLIC ${SPIRV_TOOLS})
|
||||
|
||||
set_property(TARGET SPIRV-Tools-comp PROPERTY FOLDER "SPIRV-Tools libraries")
|
||||
spvtools_check_symbol_exports(SPIRV-Tools-comp)
|
||||
|
||||
if(ENABLE_SPIRV_TOOLS_INSTALL)
|
||||
install(TARGETS SPIRV-Tools-comp
|
||||
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
|
||||
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
|
||||
endif(ENABLE_SPIRV_TOOLS_INSTALL)
|
||||
|
||||
endif(SPIRV_BUILD_COMPRESSION)
|
@ -1,348 +0,0 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cstring>
|
||||
#include <sstream>
|
||||
#include <type_traits>
|
||||
|
||||
#include "source/comp/bit_stream.h"
|
||||
|
||||
namespace spvtools {
|
||||
namespace comp {
|
||||
namespace {
|
||||
|
||||
// Returns if the system is little-endian. Unfortunately only works during
|
||||
// runtime.
|
||||
bool IsLittleEndian() {
|
||||
// This constant value allows the detection of the host machine's endianness.
|
||||
// Accessing it as an array of bytes is valid due to C++11 section 3.10
|
||||
// paragraph 10.
|
||||
static const uint16_t kFF00 = 0xff00;
|
||||
return reinterpret_cast<const unsigned char*>(&kFF00)[0] == 0;
|
||||
}
|
||||
|
||||
// Copies bytes from the given buffer to a uint64_t buffer.
|
||||
// Motivation: casting uint64_t* to uint8_t* is ok. Casting in the other
|
||||
// direction is only advisable if uint8_t* is aligned to 64-bit word boundary.
|
||||
std::vector<uint64_t> ToBuffer64(const void* buffer, size_t num_bytes) {
|
||||
std::vector<uint64_t> out;
|
||||
out.resize((num_bytes + 7) / 8, 0);
|
||||
memcpy(out.data(), buffer, num_bytes);
|
||||
return out;
|
||||
}
|
||||
|
||||
// Copies uint8_t buffer to a uint64_t buffer.
|
||||
std::vector<uint64_t> ToBuffer64(const std::vector<uint8_t>& in) {
|
||||
return ToBuffer64(in.data(), in.size());
|
||||
}
|
||||
|
||||
// Returns uint64_t containing the same bits as |val|.
|
||||
// Type size must be less than 8 bytes.
|
||||
template <typename T>
|
||||
uint64_t ToU64(T val) {
|
||||
static_assert(sizeof(T) <= 8, "Type size too big");
|
||||
uint64_t val64 = 0;
|
||||
std::memcpy(&val64, &val, sizeof(T));
|
||||
return val64;
|
||||
}
|
||||
|
||||
// Returns value of type T containing the same bits as |val64|.
|
||||
// Type size must be less than 8 bytes. Upper (unused) bits of |val64| must be
|
||||
// zero (irrelevant, but is checked with assertion).
|
||||
template <typename T>
|
||||
T FromU64(uint64_t val64) {
|
||||
assert(sizeof(T) == 8 || (val64 >> (sizeof(T) * 8)) == 0);
|
||||
static_assert(sizeof(T) <= 8, "Type size too big");
|
||||
T val = 0;
|
||||
std::memcpy(&val, &val64, sizeof(T));
|
||||
return val;
|
||||
}
|
||||
|
||||
// Writes bits from |val| to |writer| in chunks of size |chunk_length|.
|
||||
// Signal bit is used to signal if the reader should expect another chunk:
|
||||
// 0 - no more chunks to follow
|
||||
// 1 - more chunks to follow
|
||||
// If number of written bits reaches |max_payload| last chunk is truncated.
|
||||
void WriteVariableWidthInternal(BitWriterInterface* writer, uint64_t val,
|
||||
size_t chunk_length, size_t max_payload) {
|
||||
assert(chunk_length > 0);
|
||||
assert(chunk_length < max_payload);
|
||||
assert(max_payload == 64 || (val >> max_payload) == 0);
|
||||
|
||||
if (val == 0) {
|
||||
// Split in two writes for more readable logging.
|
||||
writer->WriteBits(0, chunk_length);
|
||||
writer->WriteBits(0, 1);
|
||||
return;
|
||||
}
|
||||
|
||||
size_t payload_written = 0;
|
||||
|
||||
while (val) {
|
||||
if (payload_written + chunk_length >= max_payload) {
|
||||
// This has to be the last chunk.
|
||||
// There is no need for the signal bit and the chunk can be truncated.
|
||||
const size_t left_to_write = max_payload - payload_written;
|
||||
assert((val >> left_to_write) == 0);
|
||||
writer->WriteBits(val, left_to_write);
|
||||
break;
|
||||
}
|
||||
|
||||
writer->WriteBits(val, chunk_length);
|
||||
payload_written += chunk_length;
|
||||
val = val >> chunk_length;
|
||||
|
||||
// Write a single bit to signal if there is more to come.
|
||||
writer->WriteBits(val ? 1 : 0, 1);
|
||||
}
|
||||
}
|
||||
|
||||
// Reads data written with WriteVariableWidthInternal. |chunk_length| and
|
||||
// |max_payload| should be identical to those used to write the data.
|
||||
// Returns false if the stream ends prematurely.
|
||||
bool ReadVariableWidthInternal(BitReaderInterface* reader, uint64_t* val,
|
||||
size_t chunk_length, size_t max_payload) {
|
||||
assert(chunk_length > 0);
|
||||
assert(chunk_length <= max_payload);
|
||||
size_t payload_read = 0;
|
||||
|
||||
while (payload_read + chunk_length < max_payload) {
|
||||
uint64_t bits = 0;
|
||||
if (reader->ReadBits(&bits, chunk_length) != chunk_length) return false;
|
||||
|
||||
*val |= bits << payload_read;
|
||||
payload_read += chunk_length;
|
||||
|
||||
uint64_t more_to_come = 0;
|
||||
if (reader->ReadBits(&more_to_come, 1) != 1) return false;
|
||||
|
||||
if (!more_to_come) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Need to read the last chunk which may be truncated. No signal bit follows.
|
||||
uint64_t bits = 0;
|
||||
const size_t left_to_read = max_payload - payload_read;
|
||||
if (reader->ReadBits(&bits, left_to_read) != left_to_read) return false;
|
||||
|
||||
*val |= bits << payload_read;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Calls WriteVariableWidthInternal with the right max_payload argument.
|
||||
template <typename T>
|
||||
void WriteVariableWidthUnsigned(BitWriterInterface* writer, T val,
|
||||
size_t chunk_length) {
|
||||
static_assert(std::is_unsigned<T>::value, "Type must be unsigned");
|
||||
static_assert(std::is_integral<T>::value, "Type must be integral");
|
||||
WriteVariableWidthInternal(writer, val, chunk_length, sizeof(T) * 8);
|
||||
}
|
||||
|
||||
// Calls ReadVariableWidthInternal with the right max_payload argument.
|
||||
template <typename T>
|
||||
bool ReadVariableWidthUnsigned(BitReaderInterface* reader, T* val,
|
||||
size_t chunk_length) {
|
||||
static_assert(std::is_unsigned<T>::value, "Type must be unsigned");
|
||||
static_assert(std::is_integral<T>::value, "Type must be integral");
|
||||
uint64_t val64 = 0;
|
||||
if (!ReadVariableWidthInternal(reader, &val64, chunk_length, sizeof(T) * 8))
|
||||
return false;
|
||||
*val = static_cast<T>(val64);
|
||||
assert(*val == val64);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Encodes signed |val| to an unsigned value and calls
|
||||
// WriteVariableWidthInternal with the right max_payload argument.
|
||||
template <typename T>
|
||||
void WriteVariableWidthSigned(BitWriterInterface* writer, T val,
|
||||
size_t chunk_length, size_t zigzag_exponent) {
|
||||
static_assert(std::is_signed<T>::value, "Type must be signed");
|
||||
static_assert(std::is_integral<T>::value, "Type must be integral");
|
||||
WriteVariableWidthInternal(writer, EncodeZigZag(val, zigzag_exponent),
|
||||
chunk_length, sizeof(T) * 8);
|
||||
}
|
||||
|
||||
// Calls ReadVariableWidthInternal with the right max_payload argument
|
||||
// and decodes the value.
|
||||
template <typename T>
|
||||
bool ReadVariableWidthSigned(BitReaderInterface* reader, T* val,
|
||||
size_t chunk_length, size_t zigzag_exponent) {
|
||||
static_assert(std::is_signed<T>::value, "Type must be signed");
|
||||
static_assert(std::is_integral<T>::value, "Type must be integral");
|
||||
uint64_t encoded = 0;
|
||||
if (!ReadVariableWidthInternal(reader, &encoded, chunk_length, sizeof(T) * 8))
|
||||
return false;
|
||||
|
||||
const int64_t decoded = DecodeZigZag(encoded, zigzag_exponent);
|
||||
|
||||
*val = static_cast<T>(decoded);
|
||||
assert(*val == decoded);
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void BitWriterInterface::WriteVariableWidthU64(uint64_t val,
|
||||
size_t chunk_length) {
|
||||
WriteVariableWidthUnsigned(this, val, chunk_length);
|
||||
}
|
||||
|
||||
void BitWriterInterface::WriteVariableWidthU32(uint32_t val,
|
||||
size_t chunk_length) {
|
||||
WriteVariableWidthUnsigned(this, val, chunk_length);
|
||||
}
|
||||
|
||||
void BitWriterInterface::WriteVariableWidthU16(uint16_t val,
|
||||
size_t chunk_length) {
|
||||
WriteVariableWidthUnsigned(this, val, chunk_length);
|
||||
}
|
||||
|
||||
void BitWriterInterface::WriteVariableWidthS64(int64_t val, size_t chunk_length,
|
||||
size_t zigzag_exponent) {
|
||||
WriteVariableWidthSigned(this, val, chunk_length, zigzag_exponent);
|
||||
}
|
||||
|
||||
BitWriterWord64::BitWriterWord64(size_t reserve_bits) : end_(0) {
|
||||
buffer_.reserve(NumBitsToNumWords<64>(reserve_bits));
|
||||
}
|
||||
|
||||
void BitWriterWord64::WriteBits(uint64_t bits, size_t num_bits) {
|
||||
// Check that |bits| and |num_bits| are valid and consistent.
|
||||
assert(num_bits <= 64);
|
||||
const bool is_little_endian = IsLittleEndian();
|
||||
assert(is_little_endian && "Big-endian architecture support not implemented");
|
||||
if (!is_little_endian) return;
|
||||
|
||||
if (num_bits == 0) return;
|
||||
|
||||
bits = GetLowerBits(bits, num_bits);
|
||||
|
||||
EmitSequence(bits, num_bits);
|
||||
|
||||
// Offset from the start of the current word.
|
||||
const size_t offset = end_ % 64;
|
||||
|
||||
if (offset == 0) {
|
||||
// If no offset, simply add |bits| as a new word to the buffer_.
|
||||
buffer_.push_back(bits);
|
||||
} else {
|
||||
// Shift bits and add them to the current word after offset.
|
||||
const uint64_t first_word = bits << offset;
|
||||
buffer_.back() |= first_word;
|
||||
|
||||
// If we don't overflow to the next word, there is nothing more to do.
|
||||
|
||||
if (offset + num_bits > 64) {
|
||||
// We overflow to the next word.
|
||||
const uint64_t second_word = bits >> (64 - offset);
|
||||
// Add remaining bits as a new word to buffer_.
|
||||
buffer_.push_back(second_word);
|
||||
}
|
||||
}
|
||||
|
||||
// Move end_ into position for next write.
|
||||
end_ += num_bits;
|
||||
assert(buffer_.size() * 64 >= end_);
|
||||
}
|
||||
|
||||
bool BitReaderInterface::ReadVariableWidthU64(uint64_t* val,
|
||||
size_t chunk_length) {
|
||||
return ReadVariableWidthUnsigned(this, val, chunk_length);
|
||||
}
|
||||
|
||||
bool BitReaderInterface::ReadVariableWidthU32(uint32_t* val,
|
||||
size_t chunk_length) {
|
||||
return ReadVariableWidthUnsigned(this, val, chunk_length);
|
||||
}
|
||||
|
||||
bool BitReaderInterface::ReadVariableWidthU16(uint16_t* val,
|
||||
size_t chunk_length) {
|
||||
return ReadVariableWidthUnsigned(this, val, chunk_length);
|
||||
}
|
||||
|
||||
bool BitReaderInterface::ReadVariableWidthS64(int64_t* val, size_t chunk_length,
|
||||
size_t zigzag_exponent) {
|
||||
return ReadVariableWidthSigned(this, val, chunk_length, zigzag_exponent);
|
||||
}
|
||||
|
||||
BitReaderWord64::BitReaderWord64(std::vector<uint64_t>&& buffer)
|
||||
: buffer_(std::move(buffer)), pos_(0) {}
|
||||
|
||||
BitReaderWord64::BitReaderWord64(const std::vector<uint8_t>& buffer)
|
||||
: buffer_(ToBuffer64(buffer)), pos_(0) {}
|
||||
|
||||
BitReaderWord64::BitReaderWord64(const void* buffer, size_t num_bytes)
|
||||
: buffer_(ToBuffer64(buffer, num_bytes)), pos_(0) {}
|
||||
|
||||
size_t BitReaderWord64::ReadBits(uint64_t* bits, size_t num_bits) {
|
||||
assert(num_bits <= 64);
|
||||
const bool is_little_endian = IsLittleEndian();
|
||||
assert(is_little_endian && "Big-endian architecture support not implemented");
|
||||
if (!is_little_endian) return 0;
|
||||
|
||||
if (ReachedEnd()) return 0;
|
||||
|
||||
// Index of the current word.
|
||||
const size_t index = pos_ / 64;
|
||||
|
||||
// Bit position in the current word where we start reading.
|
||||
const size_t offset = pos_ % 64;
|
||||
|
||||
// Read all bits from the current word (it might be too much, but
|
||||
// excessive bits will be removed later).
|
||||
*bits = buffer_[index] >> offset;
|
||||
|
||||
const size_t num_read_from_first_word = std::min(64 - offset, num_bits);
|
||||
pos_ += num_read_from_first_word;
|
||||
|
||||
if (pos_ >= buffer_.size() * 64) {
|
||||
// Reached end of buffer_.
|
||||
EmitSequence(*bits, num_read_from_first_word);
|
||||
return num_read_from_first_word;
|
||||
}
|
||||
|
||||
if (offset + num_bits > 64) {
|
||||
// Requested |num_bits| overflows to next word.
|
||||
// Write all bits from the beginning of next word to *bits after offset.
|
||||
*bits |= buffer_[index + 1] << (64 - offset);
|
||||
pos_ += offset + num_bits - 64;
|
||||
}
|
||||
|
||||
// We likely have written more bits than requested. Clear excessive bits.
|
||||
*bits = GetLowerBits(*bits, num_bits);
|
||||
EmitSequence(*bits, num_bits);
|
||||
return num_bits;
|
||||
}
|
||||
|
||||
bool BitReaderWord64::ReachedEnd() const { return pos_ >= buffer_.size() * 64; }
|
||||
|
||||
bool BitReaderWord64::OnlyZeroesLeft() const {
|
||||
if (ReachedEnd()) return true;
|
||||
|
||||
const size_t index = pos_ / 64;
|
||||
if (index < buffer_.size() - 1) return false;
|
||||
|
||||
assert(index == buffer_.size() - 1);
|
||||
|
||||
const size_t offset = pos_ % 64;
|
||||
const uint64_t remaining_bits = buffer_[index] >> offset;
|
||||
return !remaining_bits;
|
||||
}
|
||||
|
||||
} // namespace comp
|
||||
} // namespace spvtools
|
@ -1,280 +0,0 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Contains utils for reading, writing and debug printing bit streams.
|
||||
|
||||
#ifndef SOURCE_COMP_BIT_STREAM_H_
|
||||
#define SOURCE_COMP_BIT_STREAM_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <bitset>
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace spvtools {
|
||||
namespace comp {
|
||||
|
||||
// Terminology:
|
||||
// Bits - usually used for a uint64 word, first bit is the lowest.
|
||||
// Stream - std::string of '0' and '1', read left-to-right,
|
||||
// i.e. first bit is at the front and not at the end as in
|
||||
// std::bitset::to_string().
|
||||
// Bitset - std::bitset corresponding to uint64 bits and to reverse(stream).
|
||||
|
||||
// Converts number of bits to a respective number of chunks of size N.
|
||||
// For example NumBitsToNumWords<8> returns how many bytes are needed to store
|
||||
// |num_bits|.
|
||||
template <size_t N>
|
||||
inline size_t NumBitsToNumWords(size_t num_bits) {
|
||||
return (num_bits + (N - 1)) / N;
|
||||
}
|
||||
|
||||
// Returns value of the same type as |in|, where all but the first |num_bits|
|
||||
// are set to zero.
|
||||
template <typename T>
|
||||
inline T GetLowerBits(T in, size_t num_bits) {
|
||||
return sizeof(T) * 8 == num_bits ? in : in & T((T(1) << num_bits) - T(1));
|
||||
}
|
||||
|
||||
// Encodes signed integer as unsigned. This is a generalized version of
|
||||
// EncodeZigZag, designed to favor small positive numbers.
|
||||
// Values are transformed in blocks of 2^|block_exponent|.
|
||||
// If |block_exponent| is zero, then this degenerates into normal EncodeZigZag.
|
||||
// Example when |block_exponent| is 1 (return value is the index):
|
||||
// 0, 1, -1, -2, 2, 3, -3, -4, 4, 5, -5, -6, 6, 7, -7, -8
|
||||
// Example when |block_exponent| is 2:
|
||||
// 0, 1, 2, 3, -1, -2, -3, -4, 4, 5, 6, 7, -5, -6, -7, -8
|
||||
inline uint64_t EncodeZigZag(int64_t val, size_t block_exponent) {
|
||||
assert(block_exponent < 64);
|
||||
const uint64_t uval = static_cast<uint64_t>(val >= 0 ? val : -val - 1);
|
||||
const uint64_t block_num =
|
||||
((uval >> block_exponent) << 1) + (val >= 0 ? 0 : 1);
|
||||
const uint64_t pos = GetLowerBits(uval, block_exponent);
|
||||
return (block_num << block_exponent) + pos;
|
||||
}
|
||||
|
||||
// Decodes signed integer encoded with EncodeZigZag. |block_exponent| must be
|
||||
// the same.
|
||||
inline int64_t DecodeZigZag(uint64_t val, size_t block_exponent) {
|
||||
assert(block_exponent < 64);
|
||||
const uint64_t block_num = val >> block_exponent;
|
||||
const uint64_t pos = GetLowerBits(val, block_exponent);
|
||||
if (block_num & 1) {
|
||||
// Negative.
|
||||
return -1LL - ((block_num >> 1) << block_exponent) - pos;
|
||||
} else {
|
||||
// Positive.
|
||||
return ((block_num >> 1) << block_exponent) + pos;
|
||||
}
|
||||
}
|
||||
|
||||
// Converts first |num_bits| stored in uint64 to a left-to-right stream of bits.
|
||||
inline std::string BitsToStream(uint64_t bits, size_t num_bits = 64) {
|
||||
std::bitset<64> bitset(bits);
|
||||
std::string str = bitset.to_string().substr(64 - num_bits);
|
||||
std::reverse(str.begin(), str.end());
|
||||
return str;
|
||||
}
|
||||
|
||||
// Base class for writing sequences of bits.
|
||||
class BitWriterInterface {
|
||||
public:
|
||||
BitWriterInterface() = default;
|
||||
virtual ~BitWriterInterface() = default;
|
||||
|
||||
// Writes lower |num_bits| in |bits| to the stream.
|
||||
// |num_bits| must be no greater than 64.
|
||||
virtual void WriteBits(uint64_t bits, size_t num_bits) = 0;
|
||||
|
||||
// Writes bits from value of type |T| to the stream. No encoding is done.
|
||||
// Always writes 8 * sizeof(T) bits.
|
||||
template <typename T>
|
||||
void WriteUnencoded(T val) {
|
||||
static_assert(sizeof(T) <= 64, "Type size too large");
|
||||
uint64_t bits = 0;
|
||||
memcpy(&bits, &val, sizeof(T));
|
||||
WriteBits(bits, sizeof(T) * 8);
|
||||
}
|
||||
|
||||
// Writes |val| in chunks of size |chunk_length| followed by a signal bit:
|
||||
// 0 - no more chunks to follow
|
||||
// 1 - more chunks to follow
|
||||
// for example 255 is encoded into 1111 1 1111 0 for chunk length 4.
|
||||
// The last chunk can be truncated and signal bit omitted, if the entire
|
||||
// payload (for example 16 bit for uint16_t has already been written).
|
||||
void WriteVariableWidthU64(uint64_t val, size_t chunk_length);
|
||||
void WriteVariableWidthU32(uint32_t val, size_t chunk_length);
|
||||
void WriteVariableWidthU16(uint16_t val, size_t chunk_length);
|
||||
void WriteVariableWidthS64(int64_t val, size_t chunk_length,
|
||||
size_t zigzag_exponent);
|
||||
|
||||
// Returns number of bits written.
|
||||
virtual size_t GetNumBits() const = 0;
|
||||
|
||||
// Provides direct access to the buffer data if implemented.
|
||||
virtual const uint8_t* GetData() const { return nullptr; }
|
||||
|
||||
// Returns buffer size in bytes.
|
||||
size_t GetDataSizeBytes() const { return NumBitsToNumWords<8>(GetNumBits()); }
|
||||
|
||||
// Generates and returns byte array containing written bits.
|
||||
virtual std::vector<uint8_t> GetDataCopy() const = 0;
|
||||
|
||||
BitWriterInterface(const BitWriterInterface&) = delete;
|
||||
BitWriterInterface& operator=(const BitWriterInterface&) = delete;
|
||||
};
|
||||
|
||||
// This class is an implementation of BitWriterInterface, using
|
||||
// std::vector<uint64_t> to store written bits.
|
||||
class BitWriterWord64 : public BitWriterInterface {
|
||||
public:
|
||||
explicit BitWriterWord64(size_t reserve_bits = 64);
|
||||
|
||||
void WriteBits(uint64_t bits, size_t num_bits) override;
|
||||
|
||||
size_t GetNumBits() const override { return end_; }
|
||||
|
||||
const uint8_t* GetData() const override {
|
||||
return reinterpret_cast<const uint8_t*>(buffer_.data());
|
||||
}
|
||||
|
||||
std::vector<uint8_t> GetDataCopy() const override {
|
||||
return std::vector<uint8_t>(GetData(), GetData() + GetDataSizeBytes());
|
||||
}
|
||||
|
||||
// Sets callback to emit bit sequences after every write.
|
||||
void SetCallback(std::function<void(const std::string&)> callback) {
|
||||
callback_ = callback;
|
||||
}
|
||||
|
||||
protected:
|
||||
// Sends string generated from arguments to callback_ if defined.
|
||||
void EmitSequence(uint64_t bits, size_t num_bits) const {
|
||||
if (callback_) callback_(BitsToStream(bits, num_bits));
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<uint64_t> buffer_;
|
||||
// Total number of bits written so far. Named 'end' as analogy to std::end().
|
||||
size_t end_;
|
||||
|
||||
// If not null, the writer will use the callback to emit the written bit
|
||||
// sequence as a string of '0' and '1'.
|
||||
std::function<void(const std::string&)> callback_;
|
||||
};
|
||||
|
||||
// Base class for reading sequences of bits.
|
||||
class BitReaderInterface {
|
||||
public:
|
||||
BitReaderInterface() {}
|
||||
virtual ~BitReaderInterface() {}
|
||||
|
||||
// Reads |num_bits| from the stream, stores them in |bits|.
|
||||
// Returns number of read bits. |num_bits| must be no greater than 64.
|
||||
virtual size_t ReadBits(uint64_t* bits, size_t num_bits) = 0;
|
||||
|
||||
// Reads 8 * sizeof(T) bits and stores them in |val|.
|
||||
template <typename T>
|
||||
bool ReadUnencoded(T* val) {
|
||||
static_assert(sizeof(T) <= 64, "Type size too large");
|
||||
uint64_t bits = 0;
|
||||
const size_t num_read = ReadBits(&bits, sizeof(T) * 8);
|
||||
if (num_read != sizeof(T) * 8) return false;
|
||||
memcpy(val, &bits, sizeof(T));
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns number of bits already read.
|
||||
virtual size_t GetNumReadBits() const = 0;
|
||||
|
||||
// These two functions define 'hard' and 'soft' EOF.
|
||||
//
|
||||
// Returns true if the end of the buffer was reached.
|
||||
virtual bool ReachedEnd() const = 0;
|
||||
// Returns true if we reached the end of the buffer or are nearing it and only
|
||||
// zero bits are left to read. Implementations of this function are allowed to
|
||||
// commit a "false negative" error if the end of the buffer was not reached,
|
||||
// i.e. it can return false even if indeed only zeroes are left.
|
||||
// It is assumed that the consumer expects that
|
||||
// the buffer stream ends with padding zeroes, and would accept this as a
|
||||
// 'soft' EOF. Implementations of this class do not necessarily need to
|
||||
// implement this, default behavior can simply delegate to ReachedEnd().
|
||||
virtual bool OnlyZeroesLeft() const { return ReachedEnd(); }
|
||||
|
||||
// Reads value encoded with WriteVariableWidthXXX (see BitWriterInterface).
|
||||
// Reader and writer must use the same |chunk_length| and variable type.
|
||||
// Returns true on success, false if the bit stream ends prematurely.
|
||||
bool ReadVariableWidthU64(uint64_t* val, size_t chunk_length);
|
||||
bool ReadVariableWidthU32(uint32_t* val, size_t chunk_length);
|
||||
bool ReadVariableWidthU16(uint16_t* val, size_t chunk_length);
|
||||
bool ReadVariableWidthS64(int64_t* val, size_t chunk_length,
|
||||
size_t zigzag_exponent);
|
||||
|
||||
BitReaderInterface(const BitReaderInterface&) = delete;
|
||||
BitReaderInterface& operator=(const BitReaderInterface&) = delete;
|
||||
};
|
||||
|
||||
// This class is an implementation of BitReaderInterface which accepts both
|
||||
// uint8_t and uint64_t buffers as input. uint64_t buffers are consumed and
|
||||
// owned. uint8_t buffers are copied.
|
||||
class BitReaderWord64 : public BitReaderInterface {
|
||||
public:
|
||||
// Consumes and owns the buffer.
|
||||
explicit BitReaderWord64(std::vector<uint64_t>&& buffer);
|
||||
|
||||
// Copies the buffer and casts it to uint64.
|
||||
// Consuming the original buffer and casting it to uint64 is difficult,
|
||||
// as it would potentially cause data misalignment and poor performance.
|
||||
explicit BitReaderWord64(const std::vector<uint8_t>& buffer);
|
||||
BitReaderWord64(const void* buffer, size_t num_bytes);
|
||||
|
||||
size_t ReadBits(uint64_t* bits, size_t num_bits) override;
|
||||
|
||||
size_t GetNumReadBits() const override { return pos_; }
|
||||
|
||||
bool ReachedEnd() const override;
|
||||
bool OnlyZeroesLeft() const override;
|
||||
|
||||
BitReaderWord64() = delete;
|
||||
|
||||
// Sets callback to emit bit sequences after every read.
|
||||
void SetCallback(std::function<void(const std::string&)> callback) {
|
||||
callback_ = callback;
|
||||
}
|
||||
|
||||
protected:
|
||||
// Sends string generated from arguments to callback_ if defined.
|
||||
void EmitSequence(uint64_t bits, size_t num_bits) const {
|
||||
if (callback_) callback_(BitsToStream(bits, num_bits));
|
||||
}
|
||||
|
||||
private:
|
||||
const std::vector<uint64_t> buffer_;
|
||||
size_t pos_;
|
||||
|
||||
// If not null, the reader will use the callback to emit the read bit
|
||||
// sequence as a string of '0' and '1'.
|
||||
std::function<void(const std::string&)> callback_;
|
||||
};
|
||||
|
||||
} // namespace comp
|
||||
} // namespace spvtools
|
||||
|
||||
#endif // SOURCE_COMP_BIT_STREAM_H_
|
@ -1,389 +0,0 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Contains utils for reading, writing and debug printing bit streams.
|
||||
|
||||
#ifndef SOURCE_COMP_HUFFMAN_CODEC_H_
|
||||
#define SOURCE_COMP_HUFFMAN_CODEC_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <functional>
|
||||
#include <iomanip>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <ostream>
|
||||
#include <queue>
|
||||
#include <sstream>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace spvtools {
|
||||
namespace comp {
|
||||
|
||||
// Used to generate and apply a Huffman coding scheme.
|
||||
// |Val| is the type of variable being encoded (for example a string or a
|
||||
// literal).
|
||||
template <class Val>
|
||||
class HuffmanCodec {
|
||||
public:
|
||||
// Huffman tree node.
|
||||
struct Node {
|
||||
Node() {}
|
||||
|
||||
// Creates Node from serialization leaving weight and id undefined.
|
||||
Node(const Val& in_value, uint32_t in_left, uint32_t in_right)
|
||||
: value(in_value), left(in_left), right(in_right) {}
|
||||
|
||||
Val value = Val();
|
||||
uint32_t weight = 0;
|
||||
// Ids are issued sequentially starting from 1. Ids are used as an ordering
|
||||
// tie-breaker, to make sure that the ordering (and resulting coding scheme)
|
||||
// are consistent accross multiple platforms.
|
||||
uint32_t id = 0;
|
||||
// Handles of children.
|
||||
uint32_t left = 0;
|
||||
uint32_t right = 0;
|
||||
};
|
||||
|
||||
// Creates Huffman codec from a histogramm.
|
||||
// Histogramm counts must not be zero.
|
||||
explicit HuffmanCodec(const std::map<Val, uint32_t>& hist) {
|
||||
if (hist.empty()) return;
|
||||
|
||||
// Heuristic estimate.
|
||||
nodes_.reserve(3 * hist.size());
|
||||
|
||||
// Create NIL.
|
||||
CreateNode();
|
||||
|
||||
// The queue is sorted in ascending order by weight (or by node id if
|
||||
// weights are equal).
|
||||
std::vector<uint32_t> queue_vector;
|
||||
queue_vector.reserve(hist.size());
|
||||
std::priority_queue<uint32_t, std::vector<uint32_t>,
|
||||
std::function<bool(uint32_t, uint32_t)>>
|
||||
queue(std::bind(&HuffmanCodec::LeftIsBigger, this,
|
||||
std::placeholders::_1, std::placeholders::_2),
|
||||
std::move(queue_vector));
|
||||
|
||||
// Put all leaves in the queue.
|
||||
for (const auto& pair : hist) {
|
||||
const uint32_t node = CreateNode();
|
||||
MutableValueOf(node) = pair.first;
|
||||
MutableWeightOf(node) = pair.second;
|
||||
assert(WeightOf(node));
|
||||
queue.push(node);
|
||||
}
|
||||
|
||||
// Form the tree by combining two subtrees with the least weight,
|
||||
// and pushing the root of the new tree in the queue.
|
||||
while (true) {
|
||||
// We push a node at the end of each iteration, so the queue is never
|
||||
// supposed to be empty at this point, unless there are no leaves, but
|
||||
// that case was already handled.
|
||||
assert(!queue.empty());
|
||||
const uint32_t right = queue.top();
|
||||
queue.pop();
|
||||
|
||||
// If the queue is empty at this point, then the last node is
|
||||
// the root of the complete Huffman tree.
|
||||
if (queue.empty()) {
|
||||
root_ = right;
|
||||
break;
|
||||
}
|
||||
|
||||
const uint32_t left = queue.top();
|
||||
queue.pop();
|
||||
|
||||
// Combine left and right into a new tree and push it into the queue.
|
||||
const uint32_t parent = CreateNode();
|
||||
MutableWeightOf(parent) = WeightOf(right) + WeightOf(left);
|
||||
MutableLeftOf(parent) = left;
|
||||
MutableRightOf(parent) = right;
|
||||
queue.push(parent);
|
||||
}
|
||||
|
||||
// Traverse the tree and form encoding table.
|
||||
CreateEncodingTable();
|
||||
}
|
||||
|
||||
// Creates Huffman codec from saved tree structure.
|
||||
// |nodes| is the list of nodes of the tree, nodes[0] being NIL.
|
||||
// |root_handle| is the index of the root node.
|
||||
HuffmanCodec(uint32_t root_handle, std::vector<Node>&& nodes) {
|
||||
nodes_ = std::move(nodes);
|
||||
assert(!nodes_.empty());
|
||||
assert(root_handle > 0 && root_handle < nodes_.size());
|
||||
assert(!LeftOf(0) && !RightOf(0));
|
||||
|
||||
root_ = root_handle;
|
||||
|
||||
// Traverse the tree and form encoding table.
|
||||
CreateEncodingTable();
|
||||
}
|
||||
|
||||
// Serializes the codec in the following text format:
|
||||
// (<root_handle>, {
|
||||
// {0, 0, 0},
|
||||
// {val1, left1, right1},
|
||||
// {val2, left2, right2},
|
||||
// ...
|
||||
// })
|
||||
std::string SerializeToText(int indent_num_whitespaces) const {
|
||||
const bool value_is_text = std::is_same<Val, std::string>::value;
|
||||
|
||||
const std::string indent1 = std::string(indent_num_whitespaces, ' ');
|
||||
const std::string indent2 = std::string(indent_num_whitespaces + 2, ' ');
|
||||
|
||||
std::stringstream code;
|
||||
code << "(" << root_ << ", {\n";
|
||||
|
||||
for (const Node& node : nodes_) {
|
||||
code << indent2 << "{";
|
||||
if (value_is_text) code << "\"";
|
||||
code << node.value;
|
||||
if (value_is_text) code << "\"";
|
||||
code << ", " << node.left << ", " << node.right << "},\n";
|
||||
}
|
||||
|
||||
code << indent1 << "})";
|
||||
|
||||
return code.str();
|
||||
}
|
||||
|
||||
// Prints the Huffman tree in the following format:
|
||||
// w------w------'x'
|
||||
// w------'y'
|
||||
// Where w stands for the weight of the node.
|
||||
// Right tree branches appear above left branches. Taking the right path
|
||||
// adds 1 to the code, taking the left adds 0.
|
||||
void PrintTree(std::ostream& out) const { PrintTreeInternal(out, root_, 0); }
|
||||
|
||||
// Traverses the tree and prints the Huffman table: value, code
|
||||
// and optionally node weight for every leaf.
|
||||
void PrintTable(std::ostream& out, bool print_weights = true) {
|
||||
std::queue<std::pair<uint32_t, std::string>> queue;
|
||||
queue.emplace(root_, "");
|
||||
|
||||
while (!queue.empty()) {
|
||||
const uint32_t node = queue.front().first;
|
||||
const std::string code = queue.front().second;
|
||||
queue.pop();
|
||||
if (!RightOf(node) && !LeftOf(node)) {
|
||||
out << ValueOf(node);
|
||||
if (print_weights) out << " " << WeightOf(node);
|
||||
out << " " << code << std::endl;
|
||||
} else {
|
||||
if (LeftOf(node)) queue.emplace(LeftOf(node), code + "0");
|
||||
|
||||
if (RightOf(node)) queue.emplace(RightOf(node), code + "1");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the Huffman table. The table was built at at construction time,
|
||||
// this function just returns a const reference.
|
||||
const std::unordered_map<Val, std::pair<uint64_t, size_t>>& GetEncodingTable()
|
||||
const {
|
||||
return encoding_table_;
|
||||
}
|
||||
|
||||
// Encodes |val| and stores its Huffman code in the lower |num_bits| of
|
||||
// |bits|. Returns false of |val| is not in the Huffman table.
|
||||
bool Encode(const Val& val, uint64_t* bits, size_t* num_bits) const {
|
||||
auto it = encoding_table_.find(val);
|
||||
if (it == encoding_table_.end()) return false;
|
||||
*bits = it->second.first;
|
||||
*num_bits = it->second.second;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Reads bits one-by-one using callback |read_bit| until a match is found.
|
||||
// Matching value is stored in |val|. Returns false if |read_bit| terminates
|
||||
// before a code was mathced.
|
||||
// |read_bit| has type bool func(bool* bit). When called, the next bit is
|
||||
// stored in |bit|. |read_bit| returns false if the stream terminates
|
||||
// prematurely.
|
||||
bool DecodeFromStream(const std::function<bool(bool*)>& read_bit,
|
||||
Val* val) const {
|
||||
uint32_t node = root_;
|
||||
while (true) {
|
||||
assert(node);
|
||||
|
||||
if (!RightOf(node) && !LeftOf(node)) {
|
||||
*val = ValueOf(node);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool go_right;
|
||||
if (!read_bit(&go_right)) return false;
|
||||
|
||||
if (go_right)
|
||||
node = RightOf(node);
|
||||
else
|
||||
node = LeftOf(node);
|
||||
}
|
||||
|
||||
assert(0);
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
// Returns value of the node referenced by |handle|.
|
||||
Val ValueOf(uint32_t node) const { return nodes_.at(node).value; }
|
||||
|
||||
// Returns left child of |node|.
|
||||
uint32_t LeftOf(uint32_t node) const { return nodes_.at(node).left; }
|
||||
|
||||
// Returns right child of |node|.
|
||||
uint32_t RightOf(uint32_t node) const { return nodes_.at(node).right; }
|
||||
|
||||
// Returns weight of |node|.
|
||||
uint32_t WeightOf(uint32_t node) const { return nodes_.at(node).weight; }
|
||||
|
||||
// Returns id of |node|.
|
||||
uint32_t IdOf(uint32_t node) const { return nodes_.at(node).id; }
|
||||
|
||||
// Returns mutable reference to value of |node|.
|
||||
Val& MutableValueOf(uint32_t node) {
|
||||
assert(node);
|
||||
return nodes_.at(node).value;
|
||||
}
|
||||
|
||||
// Returns mutable reference to handle of left child of |node|.
|
||||
uint32_t& MutableLeftOf(uint32_t node) {
|
||||
assert(node);
|
||||
return nodes_.at(node).left;
|
||||
}
|
||||
|
||||
// Returns mutable reference to handle of right child of |node|.
|
||||
uint32_t& MutableRightOf(uint32_t node) {
|
||||
assert(node);
|
||||
return nodes_.at(node).right;
|
||||
}
|
||||
|
||||
// Returns mutable reference to weight of |node|.
|
||||
uint32_t& MutableWeightOf(uint32_t node) { return nodes_.at(node).weight; }
|
||||
|
||||
// Returns mutable reference to id of |node|.
|
||||
uint32_t& MutableIdOf(uint32_t node) { return nodes_.at(node).id; }
|
||||
|
||||
// Returns true if |left| has bigger weight than |right|. Node ids are
|
||||
// used as tie-breaker.
|
||||
bool LeftIsBigger(uint32_t left, uint32_t right) const {
|
||||
if (WeightOf(left) == WeightOf(right)) {
|
||||
assert(IdOf(left) != IdOf(right));
|
||||
return IdOf(left) > IdOf(right);
|
||||
}
|
||||
return WeightOf(left) > WeightOf(right);
|
||||
}
|
||||
|
||||
// Prints subtree (helper function used by PrintTree).
|
||||
void PrintTreeInternal(std::ostream& out, uint32_t node, size_t depth) const {
|
||||
if (!node) return;
|
||||
|
||||
const size_t kTextFieldWidth = 7;
|
||||
|
||||
if (!RightOf(node) && !LeftOf(node)) {
|
||||
out << ValueOf(node) << std::endl;
|
||||
} else {
|
||||
if (RightOf(node)) {
|
||||
std::stringstream label;
|
||||
label << std::setfill('-') << std::left << std::setw(kTextFieldWidth)
|
||||
<< WeightOf(RightOf(node));
|
||||
out << label.str();
|
||||
PrintTreeInternal(out, RightOf(node), depth + 1);
|
||||
}
|
||||
|
||||
if (LeftOf(node)) {
|
||||
out << std::string(depth * kTextFieldWidth, ' ');
|
||||
std::stringstream label;
|
||||
label << std::setfill('-') << std::left << std::setw(kTextFieldWidth)
|
||||
<< WeightOf(LeftOf(node));
|
||||
out << label.str();
|
||||
PrintTreeInternal(out, LeftOf(node), depth + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Traverses the Huffman tree and saves paths to the leaves as bit
|
||||
// sequences to encoding_table_.
|
||||
void CreateEncodingTable() {
|
||||
struct Context {
|
||||
Context(uint32_t in_node, uint64_t in_bits, size_t in_depth)
|
||||
: node(in_node), bits(in_bits), depth(in_depth) {}
|
||||
uint32_t node;
|
||||
// Huffman tree depth cannot exceed 64 as histogramm counts are expected
|
||||
// to be positive and limited by numeric_limits<uint32_t>::max().
|
||||
// For practical applications tree depth would be much smaller than 64.
|
||||
uint64_t bits;
|
||||
size_t depth;
|
||||
};
|
||||
|
||||
std::queue<Context> queue;
|
||||
queue.emplace(root_, 0, 0);
|
||||
|
||||
while (!queue.empty()) {
|
||||
const Context& context = queue.front();
|
||||
const uint32_t node = context.node;
|
||||
const uint64_t bits = context.bits;
|
||||
const size_t depth = context.depth;
|
||||
queue.pop();
|
||||
|
||||
if (!RightOf(node) && !LeftOf(node)) {
|
||||
auto insertion_result = encoding_table_.emplace(
|
||||
ValueOf(node), std::pair<uint64_t, size_t>(bits, depth));
|
||||
assert(insertion_result.second);
|
||||
(void)insertion_result;
|
||||
} else {
|
||||
if (LeftOf(node)) queue.emplace(LeftOf(node), bits, depth + 1);
|
||||
|
||||
if (RightOf(node))
|
||||
queue.emplace(RightOf(node), bits | (1ULL << depth), depth + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Creates new Huffman tree node and stores it in the deleter array.
|
||||
uint32_t CreateNode() {
|
||||
const uint32_t handle = static_cast<uint32_t>(nodes_.size());
|
||||
nodes_.emplace_back(Node());
|
||||
nodes_.back().id = next_node_id_++;
|
||||
return handle;
|
||||
}
|
||||
|
||||
// Huffman tree root handle.
|
||||
uint32_t root_ = 0;
|
||||
|
||||
// Huffman tree deleter.
|
||||
std::vector<Node> nodes_;
|
||||
|
||||
// Encoding table value -> {bits, num_bits}.
|
||||
// Huffman codes are expected to never exceed 64 bit length (this is in fact
|
||||
// impossible if frequencies are stored as uint32_t).
|
||||
std::unordered_map<Val, std::pair<uint64_t, size_t>> encoding_table_;
|
||||
|
||||
// Next node id issued by CreateNode();
|
||||
uint32_t next_node_id_ = 1;
|
||||
};
|
||||
|
||||
} // namespace comp
|
||||
} // namespace spvtools
|
||||
|
||||
#endif // SOURCE_COMP_HUFFMAN_CODEC_H_
|
@ -1,112 +0,0 @@
|
||||
// Copyright (c) 2018 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "source/comp/markv.h"
|
||||
|
||||
#include "source/comp/markv_decoder.h"
|
||||
#include "source/comp/markv_encoder.h"
|
||||
|
||||
namespace spvtools {
|
||||
namespace comp {
|
||||
namespace {
|
||||
|
||||
spv_result_t EncodeHeader(void* user_data, spv_endianness_t endian,
|
||||
uint32_t magic, uint32_t version, uint32_t generator,
|
||||
uint32_t id_bound, uint32_t schema) {
|
||||
MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data);
|
||||
return encoder->EncodeHeader(endian, magic, version, generator, id_bound,
|
||||
schema);
|
||||
}
|
||||
|
||||
spv_result_t EncodeInstruction(void* user_data,
|
||||
const spv_parsed_instruction_t* inst) {
|
||||
MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data);
|
||||
return encoder->EncodeInstruction(*inst);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
spv_result_t SpirvToMarkv(
|
||||
spv_const_context context, const std::vector<uint32_t>& spirv,
|
||||
const MarkvCodecOptions& options, const MarkvModel& markv_model,
|
||||
MessageConsumer message_consumer, MarkvLogConsumer log_consumer,
|
||||
MarkvDebugConsumer debug_consumer, std::vector<uint8_t>* markv) {
|
||||
spv_context_t hijack_context = *context;
|
||||
SetContextMessageConsumer(&hijack_context, message_consumer);
|
||||
|
||||
spv_validator_options validator_options =
|
||||
MarkvDecoder::GetValidatorOptions(options);
|
||||
if (validator_options) {
|
||||
spv_const_binary_t spirv_binary = {spirv.data(), spirv.size()};
|
||||
const spv_result_t result = spvValidateWithOptions(
|
||||
&hijack_context, validator_options, &spirv_binary, nullptr);
|
||||
if (result != SPV_SUCCESS) return result;
|
||||
}
|
||||
|
||||
MarkvEncoder encoder(&hijack_context, options, &markv_model);
|
||||
|
||||
spv_position_t position = {};
|
||||
if (log_consumer || debug_consumer) {
|
||||
encoder.CreateLogger(log_consumer, debug_consumer);
|
||||
|
||||
spv_text text = nullptr;
|
||||
if (spvBinaryToText(&hijack_context, spirv.data(), spirv.size(),
|
||||
SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, &text,
|
||||
nullptr) != SPV_SUCCESS) {
|
||||
return DiagnosticStream(position, hijack_context.consumer, "",
|
||||
SPV_ERROR_INVALID_BINARY)
|
||||
<< "Failed to disassemble SPIR-V binary.";
|
||||
}
|
||||
assert(text);
|
||||
encoder.SetDisassembly(std::string(text->str, text->length));
|
||||
spvTextDestroy(text);
|
||||
}
|
||||
|
||||
if (spvBinaryParse(&hijack_context, &encoder, spirv.data(), spirv.size(),
|
||||
EncodeHeader, EncodeInstruction, nullptr) != SPV_SUCCESS) {
|
||||
return DiagnosticStream(position, hijack_context.consumer, "",
|
||||
SPV_ERROR_INVALID_BINARY)
|
||||
<< "Unable to encode to MARK-V.";
|
||||
}
|
||||
|
||||
*markv = encoder.GetMarkvBinary();
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t MarkvToSpirv(
|
||||
spv_const_context context, const std::vector<uint8_t>& markv,
|
||||
const MarkvCodecOptions& options, const MarkvModel& markv_model,
|
||||
MessageConsumer message_consumer, MarkvLogConsumer log_consumer,
|
||||
MarkvDebugConsumer debug_consumer, std::vector<uint32_t>* spirv) {
|
||||
spv_position_t position = {};
|
||||
spv_context_t hijack_context = *context;
|
||||
SetContextMessageConsumer(&hijack_context, message_consumer);
|
||||
|
||||
MarkvDecoder decoder(&hijack_context, markv, options, &markv_model);
|
||||
|
||||
if (log_consumer || debug_consumer)
|
||||
decoder.CreateLogger(log_consumer, debug_consumer);
|
||||
|
||||
if (decoder.DecodeModule(spirv) != SPV_SUCCESS) {
|
||||
return DiagnosticStream(position, hijack_context.consumer, "",
|
||||
SPV_ERROR_INVALID_BINARY)
|
||||
<< "Unable to decode MARK-V.";
|
||||
}
|
||||
|
||||
assert(!spirv->empty());
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
} // namespace comp
|
||||
} // namespace spvtools
|
@ -1,74 +0,0 @@
|
||||
// Copyright (c) 2018 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// MARK-V is a compression format for SPIR-V binaries. It strips away
|
||||
// non-essential information (such as result ids which can be regenerated) and
|
||||
// uses various bit reduction techiniques to reduce the size of the binary and
|
||||
// make it more similar to other compressed SPIR-V files to further improve
|
||||
// compression of the dataset.
|
||||
|
||||
#ifndef SOURCE_COMP_MARKV_H_
|
||||
#define SOURCE_COMP_MARKV_H_
|
||||
|
||||
#include "spirv-tools/libspirv.hpp"
|
||||
|
||||
namespace spvtools {
|
||||
namespace comp {
|
||||
|
||||
class MarkvModel;
|
||||
|
||||
struct MarkvCodecOptions {
|
||||
bool validate_spirv_binary = false;
|
||||
};
|
||||
|
||||
// Debug callback. Called once per instruction.
|
||||
// |words| is instruction SPIR-V words.
|
||||
// |bits| is a textual representation of the MARK-V bit sequence used to encode
|
||||
// the instruction (char '0' for 0, char '1' for 1).
|
||||
// |comment| contains all logs generated while processing the instruction.
|
||||
using MarkvDebugConsumer =
|
||||
std::function<bool(const std::vector<uint32_t>& words,
|
||||
const std::string& bits, const std::string& comment)>;
|
||||
|
||||
// Logging callback. Called often (if decoder reads a single bit, the log
|
||||
// consumer will receive 1 character string with that bit).
|
||||
// This callback is more suitable for continous output than MarkvDebugConsumer,
|
||||
// for example if the codec crashes it would allow to pinpoint on which operand
|
||||
// or bit the crash happened.
|
||||
// |snippet| could be any atomic fragment of text logged by the codec. It can
|
||||
// contain a paragraph of text with newlines, or can be just one character.
|
||||
using MarkvLogConsumer = std::function<void(const std::string& snippet)>;
|
||||
|
||||
// Encodes the given SPIR-V binary to MARK-V binary.
|
||||
// |log_consumer| is optional (pass MarkvLogConsumer() to disable).
|
||||
// |debug_consumer| is optional (pass MarkvDebugConsumer() to disable).
|
||||
spv_result_t SpirvToMarkv(
|
||||
spv_const_context context, const std::vector<uint32_t>& spirv,
|
||||
const MarkvCodecOptions& options, const MarkvModel& markv_model,
|
||||
MessageConsumer message_consumer, MarkvLogConsumer log_consumer,
|
||||
MarkvDebugConsumer debug_consumer, std::vector<uint8_t>* markv);
|
||||
|
||||
// Decodes a SPIR-V binary from the given MARK-V binary.
|
||||
// |log_consumer| is optional (pass MarkvLogConsumer() to disable).
|
||||
// |debug_consumer| is optional (pass MarkvDebugConsumer() to disable).
|
||||
spv_result_t MarkvToSpirv(
|
||||
spv_const_context context, const std::vector<uint8_t>& markv,
|
||||
const MarkvCodecOptions& options, const MarkvModel& markv_model,
|
||||
MessageConsumer message_consumer, MarkvLogConsumer log_consumer,
|
||||
MarkvDebugConsumer debug_consumer, std::vector<uint32_t>* spirv);
|
||||
|
||||
} // namespace comp
|
||||
} // namespace spvtools
|
||||
|
||||
#endif // SOURCE_COMP_MARKV_H_
|
@ -1,793 +0,0 @@
|
||||
// Copyright (c) 2018 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// MARK-V is a compression format for SPIR-V binaries. It strips away
|
||||
// non-essential information (such as result IDs which can be regenerated) and
|
||||
// uses various bit reduction techniques to reduce the size of the binary.
|
||||
|
||||
#include "source/comp/markv_codec.h"
|
||||
|
||||
#include "source/comp/markv_logger.h"
|
||||
#include "source/latest_version_glsl_std_450_header.h"
|
||||
#include "source/latest_version_opencl_std_header.h"
|
||||
#include "source/opcode.h"
|
||||
#include "source/util/make_unique.h"
|
||||
|
||||
namespace spvtools {
|
||||
namespace comp {
|
||||
namespace {
|
||||
|
||||
// Custom hash function used to produce short descriptors.
|
||||
uint32_t ShortHashU32Array(const std::vector<uint32_t>& words) {
|
||||
// The hash function is a sum of hashes of each word seeded by word index.
|
||||
// Knuth's multiplicative hash is used to hash the words.
|
||||
const uint32_t kKnuthMulHash = 2654435761;
|
||||
uint32_t val = 0;
|
||||
for (uint32_t i = 0; i < words.size(); ++i) {
|
||||
val += (words[i] + i + 123) * kKnuthMulHash;
|
||||
}
|
||||
return 1 + val % ((1 << MarkvCodec::kShortDescriptorNumBits) - 1);
|
||||
}
|
||||
|
||||
// Returns a set of mtf rank codecs based on a plausible hand-coded
|
||||
// distribution.
|
||||
std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>>
|
||||
GetMtfHuffmanCodecs() {
|
||||
std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>> codecs;
|
||||
|
||||
std::unique_ptr<HuffmanCodec<uint32_t>> codec;
|
||||
|
||||
codec = MakeUnique<HuffmanCodec<uint32_t>>(std::map<uint32_t, uint32_t>({
|
||||
{0, 5},
|
||||
{1, 40},
|
||||
{2, 10},
|
||||
{3, 5},
|
||||
{4, 5},
|
||||
{5, 5},
|
||||
{6, 3},
|
||||
{7, 3},
|
||||
{8, 3},
|
||||
{9, 3},
|
||||
{MarkvCodec::kMtfRankEncodedByValueSignal, 10},
|
||||
}));
|
||||
codecs.emplace(kMtfAll, std::move(codec));
|
||||
|
||||
codec = MakeUnique<HuffmanCodec<uint32_t>>(std::map<uint32_t, uint32_t>({
|
||||
{1, 50},
|
||||
{2, 20},
|
||||
{3, 5},
|
||||
{4, 5},
|
||||
{5, 2},
|
||||
{6, 1},
|
||||
{7, 1},
|
||||
{8, 1},
|
||||
{9, 1},
|
||||
{MarkvCodec::kMtfRankEncodedByValueSignal, 10},
|
||||
}));
|
||||
codecs.emplace(kMtfGenericNonZeroRank, std::move(codec));
|
||||
|
||||
return codecs;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
const uint32_t MarkvCodec::kMarkvMagicNumber = 0x07230303;
|
||||
|
||||
const uint32_t MarkvCodec::kMtfSmallestRankEncodedByValue = 10;
|
||||
|
||||
const uint32_t MarkvCodec::kMtfRankEncodedByValueSignal =
|
||||
std::numeric_limits<uint32_t>::max();
|
||||
|
||||
const uint32_t MarkvCodec::kShortDescriptorNumBits = 8;
|
||||
|
||||
const size_t MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte = 8;
|
||||
|
||||
MarkvCodec::MarkvCodec(spv_const_context context,
|
||||
spv_validator_options validator_options,
|
||||
const MarkvModel* model)
|
||||
: validator_options_(validator_options),
|
||||
grammar_(context),
|
||||
model_(model),
|
||||
short_id_descriptors_(ShortHashU32Array),
|
||||
mtf_huffman_codecs_(GetMtfHuffmanCodecs()),
|
||||
context_(context) {}
|
||||
|
||||
MarkvCodec::~MarkvCodec() { spvValidatorOptionsDestroy(validator_options_); }
|
||||
|
||||
MarkvCodec::MarkvHeader::MarkvHeader()
|
||||
: magic_number(MarkvCodec::kMarkvMagicNumber),
|
||||
markv_version(MarkvCodec::GetMarkvVersion()) {}
|
||||
|
||||
// Defines and returns current MARK-V version.
|
||||
// static
|
||||
uint32_t MarkvCodec::GetMarkvVersion() {
|
||||
const uint32_t kVersionMajor = 1;
|
||||
const uint32_t kVersionMinor = 4;
|
||||
return kVersionMinor | (kVersionMajor << 16);
|
||||
}
|
||||
|
||||
size_t MarkvCodec::GetNumBitsToNextByte(size_t bit_pos) const {
|
||||
return (8 - (bit_pos % 8)) % 8;
|
||||
}
|
||||
|
||||
// Returns true if the opcode has a fixed number of operands. May return a
|
||||
// false negative.
|
||||
bool MarkvCodec::OpcodeHasFixedNumberOfOperands(SpvOp opcode) const {
|
||||
switch (opcode) {
|
||||
// TODO(atgoo@github.com) This is not a complete list.
|
||||
case SpvOpNop:
|
||||
case SpvOpName:
|
||||
case SpvOpUndef:
|
||||
case SpvOpSizeOf:
|
||||
case SpvOpLine:
|
||||
case SpvOpNoLine:
|
||||
case SpvOpDecorationGroup:
|
||||
case SpvOpExtension:
|
||||
case SpvOpExtInstImport:
|
||||
case SpvOpMemoryModel:
|
||||
case SpvOpCapability:
|
||||
case SpvOpTypeVoid:
|
||||
case SpvOpTypeBool:
|
||||
case SpvOpTypeInt:
|
||||
case SpvOpTypeFloat:
|
||||
case SpvOpTypeVector:
|
||||
case SpvOpTypeMatrix:
|
||||
case SpvOpTypeSampler:
|
||||
case SpvOpTypeSampledImage:
|
||||
case SpvOpTypeArray:
|
||||
case SpvOpTypePointer:
|
||||
case SpvOpConstantTrue:
|
||||
case SpvOpConstantFalse:
|
||||
case SpvOpLabel:
|
||||
case SpvOpBranch:
|
||||
case SpvOpFunction:
|
||||
case SpvOpFunctionParameter:
|
||||
case SpvOpFunctionEnd:
|
||||
case SpvOpBitcast:
|
||||
case SpvOpCopyObject:
|
||||
case SpvOpTranspose:
|
||||
case SpvOpSNegate:
|
||||
case SpvOpFNegate:
|
||||
case SpvOpIAdd:
|
||||
case SpvOpFAdd:
|
||||
case SpvOpISub:
|
||||
case SpvOpFSub:
|
||||
case SpvOpIMul:
|
||||
case SpvOpFMul:
|
||||
case SpvOpUDiv:
|
||||
case SpvOpSDiv:
|
||||
case SpvOpFDiv:
|
||||
case SpvOpUMod:
|
||||
case SpvOpSRem:
|
||||
case SpvOpSMod:
|
||||
case SpvOpFRem:
|
||||
case SpvOpFMod:
|
||||
case SpvOpVectorTimesScalar:
|
||||
case SpvOpMatrixTimesScalar:
|
||||
case SpvOpVectorTimesMatrix:
|
||||
case SpvOpMatrixTimesVector:
|
||||
case SpvOpMatrixTimesMatrix:
|
||||
case SpvOpOuterProduct:
|
||||
case SpvOpDot:
|
||||
return true;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void MarkvCodec::ProcessCurInstruction() {
|
||||
instructions_.emplace_back(new val::Instruction(&inst_));
|
||||
|
||||
const SpvOp opcode = SpvOp(inst_.opcode);
|
||||
|
||||
if (inst_.result_id) {
|
||||
id_to_def_instruction_.emplace(inst_.result_id, instructions_.back().get());
|
||||
|
||||
// Collect ids local to the current function.
|
||||
if (cur_function_id_) {
|
||||
ids_local_to_cur_function_.push_back(inst_.result_id);
|
||||
}
|
||||
|
||||
// Starting new function.
|
||||
if (opcode == SpvOpFunction) {
|
||||
cur_function_id_ = inst_.result_id;
|
||||
cur_function_return_type_ = inst_.type_id;
|
||||
if (model_->id_fallback_strategy() ==
|
||||
MarkvModel::IdFallbackStrategy::kRuleBased) {
|
||||
multi_mtf_.Insert(GetMtfFunctionWithReturnType(inst_.type_id),
|
||||
inst_.result_id);
|
||||
}
|
||||
|
||||
// Store function parameter types in a queue, so that we know which types
|
||||
// to expect in the following OpFunctionParameter instructions.
|
||||
const val::Instruction* def_inst = FindDef(inst_.words[4]);
|
||||
assert(def_inst);
|
||||
assert(def_inst->opcode() == SpvOpTypeFunction);
|
||||
for (uint32_t i = 3; i < def_inst->words().size(); ++i) {
|
||||
remaining_function_parameter_types_.push_back(def_inst->word(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove local ids from MTFs if function end.
|
||||
if (opcode == SpvOpFunctionEnd) {
|
||||
cur_function_id_ = 0;
|
||||
for (uint32_t id : ids_local_to_cur_function_) multi_mtf_.RemoveFromAll(id);
|
||||
ids_local_to_cur_function_.clear();
|
||||
assert(remaining_function_parameter_types_.empty());
|
||||
}
|
||||
|
||||
if (!inst_.result_id) return;
|
||||
|
||||
{
|
||||
// Save the result ID to type ID mapping.
|
||||
// In the grammar, type ID always appears before result ID.
|
||||
// A regular value maps to its type. Some instructions (e.g. OpLabel)
|
||||
// have no type Id, and will map to 0. The result Id for a
|
||||
// type-generating instruction (e.g. OpTypeInt) maps to itself.
|
||||
auto insertion_result = id_to_type_id_.emplace(
|
||||
inst_.result_id, spvOpcodeGeneratesType(SpvOp(inst_.opcode))
|
||||
? inst_.result_id
|
||||
: inst_.type_id);
|
||||
(void)insertion_result;
|
||||
assert(insertion_result.second);
|
||||
}
|
||||
|
||||
// Add result_id to MTFs.
|
||||
if (model_->id_fallback_strategy() ==
|
||||
MarkvModel::IdFallbackStrategy::kRuleBased) {
|
||||
switch (opcode) {
|
||||
case SpvOpTypeFloat:
|
||||
case SpvOpTypeInt:
|
||||
case SpvOpTypeBool:
|
||||
case SpvOpTypeVector:
|
||||
case SpvOpTypePointer:
|
||||
case SpvOpExtInstImport:
|
||||
case SpvOpTypeSampledImage:
|
||||
case SpvOpTypeImage:
|
||||
case SpvOpTypeSampler:
|
||||
multi_mtf_.Insert(GetMtfIdGeneratedByOpcode(opcode), inst_.result_id);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
if (spvOpcodeIsComposite(opcode)) {
|
||||
multi_mtf_.Insert(kMtfTypeComposite, inst_.result_id);
|
||||
}
|
||||
|
||||
if (opcode == SpvOpLabel) {
|
||||
multi_mtf_.InsertOrPromote(kMtfLabel, inst_.result_id);
|
||||
}
|
||||
|
||||
if (opcode == SpvOpTypeInt) {
|
||||
multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
|
||||
multi_mtf_.Insert(kMtfTypeIntScalarOrVector, inst_.result_id);
|
||||
}
|
||||
|
||||
if (opcode == SpvOpTypeFloat) {
|
||||
multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
|
||||
multi_mtf_.Insert(kMtfTypeFloatScalarOrVector, inst_.result_id);
|
||||
}
|
||||
|
||||
if (opcode == SpvOpTypeBool) {
|
||||
multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
|
||||
multi_mtf_.Insert(kMtfTypeBoolScalarOrVector, inst_.result_id);
|
||||
}
|
||||
|
||||
if (opcode == SpvOpTypeVector) {
|
||||
const uint32_t component_type_id = inst_.words[2];
|
||||
const uint32_t size = inst_.words[3];
|
||||
if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeFloat),
|
||||
component_type_id)) {
|
||||
multi_mtf_.Insert(kMtfTypeFloatScalarOrVector, inst_.result_id);
|
||||
} else if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeInt),
|
||||
component_type_id)) {
|
||||
multi_mtf_.Insert(kMtfTypeIntScalarOrVector, inst_.result_id);
|
||||
} else if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeBool),
|
||||
component_type_id)) {
|
||||
multi_mtf_.Insert(kMtfTypeBoolScalarOrVector, inst_.result_id);
|
||||
}
|
||||
multi_mtf_.Insert(GetMtfTypeVectorOfSize(size), inst_.result_id);
|
||||
}
|
||||
|
||||
if (inst_.opcode == SpvOpTypeFunction) {
|
||||
const uint32_t return_type = inst_.words[2];
|
||||
multi_mtf_.Insert(kMtfTypeReturnedByFunction, return_type);
|
||||
multi_mtf_.Insert(GetMtfFunctionTypeWithReturnType(return_type),
|
||||
inst_.result_id);
|
||||
}
|
||||
|
||||
if (inst_.type_id) {
|
||||
const val::Instruction* type_inst = FindDef(inst_.type_id);
|
||||
assert(type_inst);
|
||||
|
||||
multi_mtf_.Insert(kMtfObject, inst_.result_id);
|
||||
|
||||
multi_mtf_.Insert(GetMtfIdOfType(inst_.type_id), inst_.result_id);
|
||||
|
||||
if (multi_mtf_.HasValue(kMtfTypeFloatScalarOrVector, inst_.type_id)) {
|
||||
multi_mtf_.Insert(kMtfFloatScalarOrVector, inst_.result_id);
|
||||
}
|
||||
|
||||
if (multi_mtf_.HasValue(kMtfTypeIntScalarOrVector, inst_.type_id))
|
||||
multi_mtf_.Insert(kMtfIntScalarOrVector, inst_.result_id);
|
||||
|
||||
if (multi_mtf_.HasValue(kMtfTypeBoolScalarOrVector, inst_.type_id))
|
||||
multi_mtf_.Insert(kMtfBoolScalarOrVector, inst_.result_id);
|
||||
|
||||
if (multi_mtf_.HasValue(kMtfTypeComposite, inst_.type_id))
|
||||
multi_mtf_.Insert(kMtfComposite, inst_.result_id);
|
||||
|
||||
switch (type_inst->opcode()) {
|
||||
case SpvOpTypeInt:
|
||||
case SpvOpTypeBool:
|
||||
case SpvOpTypePointer:
|
||||
case SpvOpTypeVector:
|
||||
case SpvOpTypeImage:
|
||||
case SpvOpTypeSampledImage:
|
||||
case SpvOpTypeSampler:
|
||||
multi_mtf_.Insert(
|
||||
GetMtfIdWithTypeGeneratedByOpcode(type_inst->opcode()),
|
||||
inst_.result_id);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
if (type_inst->opcode() == SpvOpTypeVector) {
|
||||
const uint32_t component_type = type_inst->word(2);
|
||||
multi_mtf_.Insert(GetMtfVectorOfComponentType(component_type),
|
||||
inst_.result_id);
|
||||
}
|
||||
|
||||
if (type_inst->opcode() == SpvOpTypePointer) {
|
||||
assert(type_inst->operands().size() > 2);
|
||||
assert(type_inst->words().size() > type_inst->operands()[2].offset);
|
||||
const uint32_t data_type =
|
||||
type_inst->word(type_inst->operands()[2].offset);
|
||||
multi_mtf_.Insert(GetMtfPointerToType(data_type), inst_.result_id);
|
||||
|
||||
if (multi_mtf_.HasValue(kMtfTypeComposite, data_type))
|
||||
multi_mtf_.Insert(kMtfTypePointerToComposite, inst_.result_id);
|
||||
}
|
||||
}
|
||||
|
||||
if (spvOpcodeGeneratesType(opcode)) {
|
||||
if (opcode != SpvOpTypeFunction) {
|
||||
multi_mtf_.Insert(kMtfTypeNonFunction, inst_.result_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (model_->AnyDescriptorHasCodingScheme()) {
|
||||
const uint32_t long_descriptor =
|
||||
long_id_descriptors_.ProcessInstruction(inst_);
|
||||
if (model_->DescriptorHasCodingScheme(long_descriptor))
|
||||
multi_mtf_.Insert(GetMtfLongIdDescriptor(long_descriptor),
|
||||
inst_.result_id);
|
||||
}
|
||||
|
||||
if (model_->id_fallback_strategy() ==
|
||||
MarkvModel::IdFallbackStrategy::kShortDescriptor) {
|
||||
const uint32_t short_descriptor =
|
||||
short_id_descriptors_.ProcessInstruction(inst_);
|
||||
multi_mtf_.Insert(GetMtfShortIdDescriptor(short_descriptor),
|
||||
inst_.result_id);
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t MarkvCodec::GetRuleBasedMtf() {
|
||||
// This function is only called for id operands (but not result ids).
|
||||
assert(spvIsIdType(operand_.type) ||
|
||||
operand_.type == SPV_OPERAND_TYPE_OPTIONAL_ID);
|
||||
assert(operand_.type != SPV_OPERAND_TYPE_RESULT_ID);
|
||||
|
||||
const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
|
||||
|
||||
// All operand slots which expect label id.
|
||||
if ((inst_.opcode == SpvOpLoopMerge && operand_index_ <= 1) ||
|
||||
(inst_.opcode == SpvOpSelectionMerge && operand_index_ == 0) ||
|
||||
(inst_.opcode == SpvOpBranch && operand_index_ == 0) ||
|
||||
(inst_.opcode == SpvOpBranchConditional &&
|
||||
(operand_index_ == 1 || operand_index_ == 2)) ||
|
||||
(inst_.opcode == SpvOpPhi && operand_index_ >= 3 &&
|
||||
operand_index_ % 2 == 1) ||
|
||||
(inst_.opcode == SpvOpSwitch && operand_index_ > 0)) {
|
||||
return kMtfLabel;
|
||||
}
|
||||
|
||||
switch (opcode) {
|
||||
case SpvOpFAdd:
|
||||
case SpvOpFSub:
|
||||
case SpvOpFMul:
|
||||
case SpvOpFDiv:
|
||||
case SpvOpFRem:
|
||||
case SpvOpFMod:
|
||||
case SpvOpFNegate: {
|
||||
if (operand_index_ == 0) return kMtfTypeFloatScalarOrVector;
|
||||
return GetMtfIdOfType(inst_.type_id);
|
||||
}
|
||||
|
||||
case SpvOpISub:
|
||||
case SpvOpIAdd:
|
||||
case SpvOpIMul:
|
||||
case SpvOpSDiv:
|
||||
case SpvOpUDiv:
|
||||
case SpvOpSMod:
|
||||
case SpvOpUMod:
|
||||
case SpvOpSRem:
|
||||
case SpvOpSNegate: {
|
||||
if (operand_index_ == 0) return kMtfTypeIntScalarOrVector;
|
||||
|
||||
return kMtfIntScalarOrVector;
|
||||
}
|
||||
|
||||
// TODO(atgoo@github.com) Add OpConvertFToU and other opcodes.
|
||||
|
||||
case SpvOpFOrdEqual:
|
||||
case SpvOpFUnordEqual:
|
||||
case SpvOpFOrdNotEqual:
|
||||
case SpvOpFUnordNotEqual:
|
||||
case SpvOpFOrdLessThan:
|
||||
case SpvOpFUnordLessThan:
|
||||
case SpvOpFOrdGreaterThan:
|
||||
case SpvOpFUnordGreaterThan:
|
||||
case SpvOpFOrdLessThanEqual:
|
||||
case SpvOpFUnordLessThanEqual:
|
||||
case SpvOpFOrdGreaterThanEqual:
|
||||
case SpvOpFUnordGreaterThanEqual: {
|
||||
if (operand_index_ == 0) return kMtfTypeBoolScalarOrVector;
|
||||
if (operand_index_ == 2) return kMtfFloatScalarOrVector;
|
||||
if (operand_index_ == 3) {
|
||||
const uint32_t first_operand_id = GetInstWords()[3];
|
||||
const uint32_t first_operand_type = id_to_type_id_.at(first_operand_id);
|
||||
return GetMtfIdOfType(first_operand_type);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpVectorShuffle: {
|
||||
if (operand_index_ == 0) {
|
||||
assert(inst_.num_operands > 4);
|
||||
return GetMtfTypeVectorOfSize(inst_.num_operands - 4);
|
||||
}
|
||||
|
||||
assert(inst_.type_id);
|
||||
if (operand_index_ == 2 || operand_index_ == 3)
|
||||
return GetMtfVectorOfComponentType(
|
||||
GetVectorComponentType(inst_.type_id));
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpVectorTimesScalar: {
|
||||
if (operand_index_ == 0) {
|
||||
// TODO(atgoo@github.com) Could be narrowed to vector of floats.
|
||||
return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
|
||||
}
|
||||
|
||||
assert(inst_.type_id);
|
||||
if (operand_index_ == 2) return GetMtfIdOfType(inst_.type_id);
|
||||
if (operand_index_ == 3)
|
||||
return GetMtfIdOfType(GetVectorComponentType(inst_.type_id));
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpDot: {
|
||||
if (operand_index_ == 0) return GetMtfIdGeneratedByOpcode(SpvOpTypeFloat);
|
||||
|
||||
assert(inst_.type_id);
|
||||
if (operand_index_ == 2)
|
||||
return GetMtfVectorOfComponentType(inst_.type_id);
|
||||
if (operand_index_ == 3) {
|
||||
const uint32_t vector_id = GetInstWords()[3];
|
||||
const uint32_t vector_type = id_to_type_id_.at(vector_id);
|
||||
return GetMtfIdOfType(vector_type);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpTypeVector: {
|
||||
if (operand_index_ == 1) {
|
||||
return kMtfTypeScalar;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpTypeMatrix: {
|
||||
if (operand_index_ == 1) {
|
||||
return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpTypePointer: {
|
||||
if (operand_index_ == 2) {
|
||||
return kMtfTypeNonFunction;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpTypeStruct: {
|
||||
if (operand_index_ >= 1) {
|
||||
return kMtfTypeNonFunction;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpTypeFunction: {
|
||||
if (operand_index_ == 1) {
|
||||
return kMtfTypeNonFunction;
|
||||
}
|
||||
|
||||
if (operand_index_ >= 2) {
|
||||
return kMtfTypeNonFunction;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpLoad: {
|
||||
if (operand_index_ == 0) return kMtfTypeNonFunction;
|
||||
|
||||
if (operand_index_ == 2) {
|
||||
assert(inst_.type_id);
|
||||
return GetMtfPointerToType(inst_.type_id);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpStore: {
|
||||
if (operand_index_ == 0)
|
||||
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypePointer);
|
||||
if (operand_index_ == 1) {
|
||||
const uint32_t pointer_id = GetInstWords()[1];
|
||||
const uint32_t pointer_type = id_to_type_id_.at(pointer_id);
|
||||
const val::Instruction* pointer_inst = FindDef(pointer_type);
|
||||
assert(pointer_inst);
|
||||
assert(pointer_inst->opcode() == SpvOpTypePointer);
|
||||
const uint32_t data_type =
|
||||
pointer_inst->word(pointer_inst->operands()[2].offset);
|
||||
return GetMtfIdOfType(data_type);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpVariable: {
|
||||
if (operand_index_ == 0)
|
||||
return GetMtfIdGeneratedByOpcode(SpvOpTypePointer);
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpAccessChain: {
|
||||
if (operand_index_ == 0)
|
||||
return GetMtfIdGeneratedByOpcode(SpvOpTypePointer);
|
||||
if (operand_index_ == 2) return kMtfTypePointerToComposite;
|
||||
if (operand_index_ >= 3)
|
||||
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeInt);
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpCompositeConstruct: {
|
||||
if (operand_index_ == 0) return kMtfTypeComposite;
|
||||
if (operand_index_ >= 2) {
|
||||
const uint32_t composite_type = GetInstWords()[1];
|
||||
if (multi_mtf_.HasValue(kMtfTypeFloatScalarOrVector, composite_type))
|
||||
return kMtfFloatScalarOrVector;
|
||||
if (multi_mtf_.HasValue(kMtfTypeIntScalarOrVector, composite_type))
|
||||
return kMtfIntScalarOrVector;
|
||||
if (multi_mtf_.HasValue(kMtfTypeBoolScalarOrVector, composite_type))
|
||||
return kMtfBoolScalarOrVector;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpCompositeExtract: {
|
||||
if (operand_index_ == 2) return kMtfComposite;
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpConstantComposite: {
|
||||
if (operand_index_ == 0) return kMtfTypeComposite;
|
||||
if (operand_index_ >= 2) {
|
||||
const val::Instruction* composite_type_inst = FindDef(inst_.type_id);
|
||||
assert(composite_type_inst);
|
||||
if (composite_type_inst->opcode() == SpvOpTypeVector) {
|
||||
return GetMtfIdOfType(composite_type_inst->word(2));
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpExtInst: {
|
||||
if (operand_index_ == 2)
|
||||
return GetMtfIdGeneratedByOpcode(SpvOpExtInstImport);
|
||||
if (operand_index_ >= 4) {
|
||||
const uint32_t return_type = GetInstWords()[1];
|
||||
const uint32_t ext_inst_type = inst_.ext_inst_type;
|
||||
const uint32_t ext_inst_index = GetInstWords()[4];
|
||||
// TODO(atgoo@github.com) The list of extended instructions is
|
||||
// incomplete. Only common instructions and low-hanging fruits listed.
|
||||
if (ext_inst_type == SPV_EXT_INST_TYPE_GLSL_STD_450) {
|
||||
switch (ext_inst_index) {
|
||||
case GLSLstd450FAbs:
|
||||
case GLSLstd450FClamp:
|
||||
case GLSLstd450FMax:
|
||||
case GLSLstd450FMin:
|
||||
case GLSLstd450FMix:
|
||||
case GLSLstd450Step:
|
||||
case GLSLstd450SmoothStep:
|
||||
case GLSLstd450Fma:
|
||||
case GLSLstd450Pow:
|
||||
case GLSLstd450Exp:
|
||||
case GLSLstd450Exp2:
|
||||
case GLSLstd450Log:
|
||||
case GLSLstd450Log2:
|
||||
case GLSLstd450Sqrt:
|
||||
case GLSLstd450InverseSqrt:
|
||||
case GLSLstd450Fract:
|
||||
case GLSLstd450Floor:
|
||||
case GLSLstd450Ceil:
|
||||
case GLSLstd450Radians:
|
||||
case GLSLstd450Degrees:
|
||||
case GLSLstd450Sin:
|
||||
case GLSLstd450Cos:
|
||||
case GLSLstd450Tan:
|
||||
case GLSLstd450Sinh:
|
||||
case GLSLstd450Cosh:
|
||||
case GLSLstd450Tanh:
|
||||
case GLSLstd450Asin:
|
||||
case GLSLstd450Acos:
|
||||
case GLSLstd450Atan:
|
||||
case GLSLstd450Atan2:
|
||||
case GLSLstd450Asinh:
|
||||
case GLSLstd450Acosh:
|
||||
case GLSLstd450Atanh:
|
||||
case GLSLstd450MatrixInverse:
|
||||
case GLSLstd450Cross:
|
||||
case GLSLstd450Normalize:
|
||||
case GLSLstd450Reflect:
|
||||
case GLSLstd450FaceForward:
|
||||
return GetMtfIdOfType(return_type);
|
||||
case GLSLstd450Length:
|
||||
case GLSLstd450Distance:
|
||||
case GLSLstd450Refract:
|
||||
return kMtfFloatScalarOrVector;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
} else if (ext_inst_type == SPV_EXT_INST_TYPE_OPENCL_STD) {
|
||||
switch (ext_inst_index) {
|
||||
case OpenCLLIB::Fabs:
|
||||
case OpenCLLIB::FClamp:
|
||||
case OpenCLLIB::Fmax:
|
||||
case OpenCLLIB::Fmin:
|
||||
case OpenCLLIB::Step:
|
||||
case OpenCLLIB::Smoothstep:
|
||||
case OpenCLLIB::Fma:
|
||||
case OpenCLLIB::Pow:
|
||||
case OpenCLLIB::Exp:
|
||||
case OpenCLLIB::Exp2:
|
||||
case OpenCLLIB::Log:
|
||||
case OpenCLLIB::Log2:
|
||||
case OpenCLLIB::Sqrt:
|
||||
case OpenCLLIB::Rsqrt:
|
||||
case OpenCLLIB::Fract:
|
||||
case OpenCLLIB::Floor:
|
||||
case OpenCLLIB::Ceil:
|
||||
case OpenCLLIB::Radians:
|
||||
case OpenCLLIB::Degrees:
|
||||
case OpenCLLIB::Sin:
|
||||
case OpenCLLIB::Cos:
|
||||
case OpenCLLIB::Tan:
|
||||
case OpenCLLIB::Sinh:
|
||||
case OpenCLLIB::Cosh:
|
||||
case OpenCLLIB::Tanh:
|
||||
case OpenCLLIB::Asin:
|
||||
case OpenCLLIB::Acos:
|
||||
case OpenCLLIB::Atan:
|
||||
case OpenCLLIB::Atan2:
|
||||
case OpenCLLIB::Asinh:
|
||||
case OpenCLLIB::Acosh:
|
||||
case OpenCLLIB::Atanh:
|
||||
case OpenCLLIB::Cross:
|
||||
case OpenCLLIB::Normalize:
|
||||
return GetMtfIdOfType(return_type);
|
||||
case OpenCLLIB::Length:
|
||||
case OpenCLLIB::Distance:
|
||||
return kMtfFloatScalarOrVector;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpFunction: {
|
||||
if (operand_index_ == 0) return kMtfTypeReturnedByFunction;
|
||||
|
||||
if (operand_index_ == 3) {
|
||||
const uint32_t return_type = GetInstWords()[1];
|
||||
return GetMtfFunctionTypeWithReturnType(return_type);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpFunctionCall: {
|
||||
if (operand_index_ == 0) return kMtfTypeReturnedByFunction;
|
||||
|
||||
if (operand_index_ == 2) {
|
||||
const uint32_t return_type = GetInstWords()[1];
|
||||
return GetMtfFunctionWithReturnType(return_type);
|
||||
}
|
||||
|
||||
if (operand_index_ >= 3) {
|
||||
const uint32_t function_id = GetInstWords()[3];
|
||||
const val::Instruction* function_inst = FindDef(function_id);
|
||||
if (!function_inst) return kMtfObject;
|
||||
|
||||
assert(function_inst->opcode() == SpvOpFunction);
|
||||
|
||||
const uint32_t function_type_id = function_inst->word(4);
|
||||
const val::Instruction* function_type_inst = FindDef(function_type_id);
|
||||
assert(function_type_inst);
|
||||
assert(function_type_inst->opcode() == SpvOpTypeFunction);
|
||||
|
||||
const uint32_t argument_type = function_type_inst->word(operand_index_);
|
||||
return GetMtfIdOfType(argument_type);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpReturnValue: {
|
||||
if (operand_index_ == 0) return GetMtfIdOfType(cur_function_return_type_);
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpBranchConditional: {
|
||||
if (operand_index_ == 0)
|
||||
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeBool);
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpSampledImage: {
|
||||
if (operand_index_ == 0)
|
||||
return GetMtfIdGeneratedByOpcode(SpvOpTypeSampledImage);
|
||||
if (operand_index_ == 2)
|
||||
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeImage);
|
||||
if (operand_index_ == 3)
|
||||
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeSampler);
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpImageSampleImplicitLod: {
|
||||
if (operand_index_ == 0)
|
||||
return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
|
||||
if (operand_index_ == 2)
|
||||
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeSampledImage);
|
||||
if (operand_index_ == 3)
|
||||
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeVector);
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return kMtfNone;
|
||||
}
|
||||
|
||||
} // namespace comp
|
||||
} // namespace spvtools
|
@ -1,337 +0,0 @@
|
||||
// Copyright (c) 2018 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SOURCE_COMP_MARKV_CODEC_H_
|
||||
#define SOURCE_COMP_MARKV_CODEC_H_
|
||||
|
||||
#include <list>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "source/assembly_grammar.h"
|
||||
#include "source/comp/huffman_codec.h"
|
||||
#include "source/comp/markv_model.h"
|
||||
#include "source/comp/move_to_front.h"
|
||||
#include "source/diagnostic.h"
|
||||
#include "source/id_descriptor.h"
|
||||
|
||||
#include "source/val/instruction.h"
|
||||
|
||||
// Base class for MARK-V encoder and decoder. Contains common functionality
|
||||
// such as:
|
||||
// - Validator connection and validation state.
|
||||
// - SPIR-V grammar and helper functions.
|
||||
|
||||
namespace spvtools {
|
||||
namespace comp {
|
||||
|
||||
class MarkvLogger;
|
||||
|
||||
// Handles for move-to-front sequences. Enums which end with "Begin" define
|
||||
// handle spaces which start at that value and span 16 or 32 bit wide.
|
||||
enum : uint64_t {
|
||||
kMtfNone = 0,
|
||||
// All ids.
|
||||
kMtfAll,
|
||||
// All forward declared ids.
|
||||
kMtfForwardDeclared,
|
||||
// All type ids except for generated by OpTypeFunction.
|
||||
kMtfTypeNonFunction,
|
||||
// All labels.
|
||||
kMtfLabel,
|
||||
// All ids created by instructions which had type_id.
|
||||
kMtfObject,
|
||||
// All types generated by OpTypeFloat, OpTypeInt, OpTypeBool.
|
||||
kMtfTypeScalar,
|
||||
// All composite types.
|
||||
kMtfTypeComposite,
|
||||
// Boolean type or any vector type of it.
|
||||
kMtfTypeBoolScalarOrVector,
|
||||
// All float types or any vector floats type.
|
||||
kMtfTypeFloatScalarOrVector,
|
||||
// All int types or any vector int type.
|
||||
kMtfTypeIntScalarOrVector,
|
||||
// All types declared as return types in OpTypeFunction.
|
||||
kMtfTypeReturnedByFunction,
|
||||
// All composite objects.
|
||||
kMtfComposite,
|
||||
// All bool objects or vectors of bools.
|
||||
kMtfBoolScalarOrVector,
|
||||
// All float objects or vectors of float.
|
||||
kMtfFloatScalarOrVector,
|
||||
// All int objects or vectors of int.
|
||||
kMtfIntScalarOrVector,
|
||||
// All pointer types which point to composited.
|
||||
kMtfTypePointerToComposite,
|
||||
// Used by EncodeMtfRankHuffman.
|
||||
kMtfGenericNonZeroRank,
|
||||
// Handle space for ids of specific type.
|
||||
kMtfIdOfTypeBegin = 0x10000,
|
||||
// Handle space for ids generated by specific opcode.
|
||||
kMtfIdGeneratedByOpcode = 0x20000,
|
||||
// Handle space for ids of objects with type generated by specific opcode.
|
||||
kMtfIdWithTypeGeneratedByOpcodeBegin = 0x30000,
|
||||
// All vectors of specific component type.
|
||||
kMtfVectorOfComponentTypeBegin = 0x40000,
|
||||
// All vector types of specific size.
|
||||
kMtfTypeVectorOfSizeBegin = 0x50000,
|
||||
// All pointer types to specific type.
|
||||
kMtfPointerToTypeBegin = 0x60000,
|
||||
// All function types which return specific type.
|
||||
kMtfFunctionTypeWithReturnTypeBegin = 0x70000,
|
||||
// All function objects which return specific type.
|
||||
kMtfFunctionWithReturnTypeBegin = 0x80000,
|
||||
// Short id descriptor space (max 16-bit).
|
||||
kMtfShortIdDescriptorSpaceBegin = 0x90000,
|
||||
// Long id descriptor space (32-bit).
|
||||
kMtfLongIdDescriptorSpaceBegin = 0x100000000,
|
||||
};
|
||||
|
||||
class MarkvCodec {
|
||||
public:
|
||||
static const uint32_t kMarkvMagicNumber;
|
||||
|
||||
// Mtf ranks smaller than this are encoded with Huffman coding.
|
||||
static const uint32_t kMtfSmallestRankEncodedByValue;
|
||||
|
||||
// Signals that the mtf rank is too large to be encoded with Huffman.
|
||||
static const uint32_t kMtfRankEncodedByValueSignal;
|
||||
|
||||
static const uint32_t kShortDescriptorNumBits;
|
||||
|
||||
static const size_t kByteBreakAfterInstIfLessThanUntilNextByte;
|
||||
|
||||
static uint32_t GetMarkvVersion();
|
||||
|
||||
virtual ~MarkvCodec();
|
||||
|
||||
protected:
|
||||
struct MarkvHeader {
|
||||
MarkvHeader();
|
||||
|
||||
uint32_t magic_number;
|
||||
uint32_t markv_version;
|
||||
// Magic number to identify or verify MarkvModel used for encoding.
|
||||
uint32_t markv_model = 0;
|
||||
uint32_t markv_length_in_bits = 0;
|
||||
uint32_t spirv_version = 0;
|
||||
uint32_t spirv_generator = 0;
|
||||
};
|
||||
|
||||
// |model| is owned by the caller, must be not null and valid during the
|
||||
// lifetime of the codec.
|
||||
MarkvCodec(spv_const_context context, spv_validator_options validator_options,
|
||||
const MarkvModel* model);
|
||||
|
||||
// Returns instruction which created |id| or nullptr if such instruction was
|
||||
// not registered.
|
||||
const val::Instruction* FindDef(uint32_t id) const {
|
||||
const auto it = id_to_def_instruction_.find(id);
|
||||
if (it == id_to_def_instruction_.end()) return nullptr;
|
||||
return it->second;
|
||||
}
|
||||
|
||||
size_t GetNumBitsToNextByte(size_t bit_pos) const;
|
||||
bool OpcodeHasFixedNumberOfOperands(SpvOp opcode) const;
|
||||
|
||||
// Returns type id of vector type component.
|
||||
uint32_t GetVectorComponentType(uint32_t vector_type_id) const {
|
||||
const val::Instruction* type_inst = FindDef(vector_type_id);
|
||||
assert(type_inst);
|
||||
assert(type_inst->opcode() == SpvOpTypeVector);
|
||||
|
||||
const uint32_t component_type =
|
||||
type_inst->word(type_inst->operands()[1].offset);
|
||||
return component_type;
|
||||
}
|
||||
|
||||
// Returns mtf handle for ids of given type.
|
||||
uint64_t GetMtfIdOfType(uint32_t type_id) const {
|
||||
return kMtfIdOfTypeBegin + type_id;
|
||||
}
|
||||
|
||||
// Returns mtf handle for ids generated by given opcode.
|
||||
uint64_t GetMtfIdGeneratedByOpcode(SpvOp opcode) const {
|
||||
return kMtfIdGeneratedByOpcode + opcode;
|
||||
}
|
||||
|
||||
// Returns mtf handle for ids of type generated by given opcode.
|
||||
uint64_t GetMtfIdWithTypeGeneratedByOpcode(SpvOp opcode) const {
|
||||
return kMtfIdWithTypeGeneratedByOpcodeBegin + opcode;
|
||||
}
|
||||
|
||||
// Returns mtf handle for vectors of specific component type.
|
||||
uint64_t GetMtfVectorOfComponentType(uint32_t type_id) const {
|
||||
return kMtfVectorOfComponentTypeBegin + type_id;
|
||||
}
|
||||
|
||||
// Returns mtf handle for vector type of specific size.
|
||||
uint64_t GetMtfTypeVectorOfSize(uint32_t size) const {
|
||||
return kMtfTypeVectorOfSizeBegin + size;
|
||||
}
|
||||
|
||||
// Returns mtf handle for pointers to specific size.
|
||||
uint64_t GetMtfPointerToType(uint32_t type_id) const {
|
||||
return kMtfPointerToTypeBegin + type_id;
|
||||
}
|
||||
|
||||
// Returns mtf handle for function types with given return type.
|
||||
uint64_t GetMtfFunctionTypeWithReturnType(uint32_t type_id) const {
|
||||
return kMtfFunctionTypeWithReturnTypeBegin + type_id;
|
||||
}
|
||||
|
||||
// Returns mtf handle for functions with given return type.
|
||||
uint64_t GetMtfFunctionWithReturnType(uint32_t type_id) const {
|
||||
return kMtfFunctionWithReturnTypeBegin + type_id;
|
||||
}
|
||||
|
||||
// Returns mtf handle for the given long id descriptor.
|
||||
uint64_t GetMtfLongIdDescriptor(uint32_t descriptor) const {
|
||||
return kMtfLongIdDescriptorSpaceBegin + descriptor;
|
||||
}
|
||||
|
||||
// Returns mtf handle for the given short id descriptor.
|
||||
uint64_t GetMtfShortIdDescriptor(uint32_t descriptor) const {
|
||||
return kMtfShortIdDescriptorSpaceBegin + descriptor;
|
||||
}
|
||||
|
||||
// Process data from the current instruction. This would update MTFs and
|
||||
// other data containers.
|
||||
void ProcessCurInstruction();
|
||||
|
||||
// Returns move-to-front handle to be used for the current operand slot.
|
||||
// Mtf handle is chosen based on a set of rules defined by SPIR-V grammar.
|
||||
uint64_t GetRuleBasedMtf();
|
||||
|
||||
// Returns words of the current instruction. Decoder has a different
|
||||
// implementation and the array is valid only until the previously decoded
|
||||
// word.
|
||||
virtual const uint32_t* GetInstWords() const { return inst_.words; }
|
||||
|
||||
// Returns the opcode of the previous instruction.
|
||||
SpvOp GetPrevOpcode() const {
|
||||
if (instructions_.empty()) return SpvOpNop;
|
||||
|
||||
return instructions_.back()->opcode();
|
||||
}
|
||||
|
||||
// Returns diagnostic stream, position index is set to instruction number.
|
||||
DiagnosticStream Diag(spv_result_t error_code) const {
|
||||
return DiagnosticStream({0, 0, instructions_.size()}, context_->consumer,
|
||||
"", error_code);
|
||||
}
|
||||
|
||||
// Returns current id bound.
|
||||
uint32_t GetIdBound() const { return id_bound_; }
|
||||
|
||||
// Sets current id bound, expected to be no lower than the previous one.
|
||||
void SetIdBound(uint32_t id_bound) {
|
||||
assert(id_bound >= id_bound_);
|
||||
id_bound_ = id_bound;
|
||||
}
|
||||
|
||||
// Returns Huffman codec for ranks of the mtf with given |handle|.
|
||||
// Different mtfs can use different rank distributions.
|
||||
// May return nullptr if the codec doesn't exist.
|
||||
const HuffmanCodec<uint32_t>* GetMtfHuffmanCodec(uint64_t handle) const {
|
||||
const auto it = mtf_huffman_codecs_.find(handle);
|
||||
if (it == mtf_huffman_codecs_.end()) return nullptr;
|
||||
return it->second.get();
|
||||
}
|
||||
|
||||
// Promotes id in all move-to-front sequences if ids can be shared by multiple
|
||||
// sequences.
|
||||
void PromoteIfNeeded(uint32_t id) {
|
||||
if (!model_->AnyDescriptorHasCodingScheme() &&
|
||||
model_->id_fallback_strategy() ==
|
||||
MarkvModel::IdFallbackStrategy::kShortDescriptor) {
|
||||
// Move-to-front sequences do not share ids. Nothing to do.
|
||||
return;
|
||||
}
|
||||
multi_mtf_.Promote(id);
|
||||
}
|
||||
|
||||
spv_validator_options validator_options_ = nullptr;
|
||||
const AssemblyGrammar grammar_;
|
||||
MarkvHeader header_;
|
||||
|
||||
// MARK-V model, not owned.
|
||||
const MarkvModel* model_ = nullptr;
|
||||
|
||||
// Current instruction, current operand and current operand index.
|
||||
spv_parsed_instruction_t inst_;
|
||||
spv_parsed_operand_t operand_;
|
||||
uint32_t operand_index_;
|
||||
|
||||
// Maps a result ID to its type ID. By convention:
|
||||
// - a result ID that is a type definition maps to itself.
|
||||
// - a result ID without a type maps to 0. (E.g. for OpLabel)
|
||||
std::unordered_map<uint32_t, uint32_t> id_to_type_id_;
|
||||
|
||||
// Container for all move-to-front sequences.
|
||||
MultiMoveToFront multi_mtf_;
|
||||
|
||||
// Id of the current function or zero if outside of function.
|
||||
uint32_t cur_function_id_ = 0;
|
||||
|
||||
// Return type of the current function.
|
||||
uint32_t cur_function_return_type_ = 0;
|
||||
|
||||
// Remaining function parameter types. This container is filled on OpFunction,
|
||||
// and drained on OpFunctionParameter.
|
||||
std::list<uint32_t> remaining_function_parameter_types_;
|
||||
|
||||
// List of ids local to the current function.
|
||||
std::vector<uint32_t> ids_local_to_cur_function_;
|
||||
|
||||
// List of instructions in the order they are given in the module.
|
||||
std::vector<std::unique_ptr<const val::Instruction>> instructions_;
|
||||
|
||||
// Container/computer for long (32-bit) id descriptors.
|
||||
IdDescriptorCollection long_id_descriptors_;
|
||||
|
||||
// Container/computer for short id descriptors.
|
||||
// Short descriptors are stored in uint32_t, but their actual bit width is
|
||||
// defined with kShortDescriptorNumBits.
|
||||
// It doesn't seem logical to have a different computer for short id
|
||||
// descriptors, since one could actually map/truncate long descriptors.
|
||||
// But as short descriptors have collisions, the efficiency of
|
||||
// compression depends on the collision pattern, and short descriptors
|
||||
// produced by function ShortHashU32Array have been empirically proven to
|
||||
// produce better results.
|
||||
IdDescriptorCollection short_id_descriptors_;
|
||||
|
||||
// Huffman codecs for move-to-front ranks. The map key is mtf handle. Doesn't
|
||||
// need to contain a different codec for every handle as most use one and the
|
||||
// same.
|
||||
std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>>
|
||||
mtf_huffman_codecs_;
|
||||
|
||||
// If not nullptr, codec will log comments on the compression process.
|
||||
std::unique_ptr<MarkvLogger> logger_;
|
||||
|
||||
spv_const_context context_ = nullptr;
|
||||
|
||||
private:
|
||||
// Maps result id to the instruction which defined it.
|
||||
std::unordered_map<uint32_t, const val::Instruction*> id_to_def_instruction_;
|
||||
|
||||
uint32_t id_bound_ = 1;
|
||||
};
|
||||
|
||||
} // namespace comp
|
||||
} // namespace spvtools
|
||||
|
||||
#endif // SOURCE_COMP_MARKV_CODEC_H_
|
@ -1,925 +0,0 @@
|
||||
// Copyright (c) 2018 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "source/comp/markv_decoder.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
|
||||
#include "source/ext_inst.h"
|
||||
#include "source/opcode.h"
|
||||
#include "spirv-tools/libspirv.hpp"
|
||||
|
||||
namespace spvtools {
|
||||
namespace comp {
|
||||
|
||||
spv_result_t MarkvDecoder::DecodeNonIdWord(uint32_t* word) {
|
||||
auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_);
|
||||
|
||||
if (codec) {
|
||||
uint64_t decoded_value = 0;
|
||||
if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "Failed to decode non-id word with Huffman";
|
||||
|
||||
if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) {
|
||||
// The word decoded successfully.
|
||||
*word = uint32_t(decoded_value);
|
||||
assert(*word == decoded_value);
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
// Received kMarkvNoneOfTheAbove signal, use fallback decoding.
|
||||
}
|
||||
|
||||
const size_t chunk_length =
|
||||
model_->GetOperandVariableWidthChunkLength(operand_.type);
|
||||
if (chunk_length) {
|
||||
if (!reader_.ReadVariableWidthU32(word, chunk_length))
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "Failed to decode non-id word with varint";
|
||||
} else {
|
||||
if (!reader_.ReadUnencoded(word))
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "Failed to read unencoded non-id word";
|
||||
}
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t MarkvDecoder::DecodeOpcodeAndNumberOfOperands(
|
||||
uint32_t* opcode, uint32_t* num_operands) {
|
||||
// First try to use the Markov chain codec.
|
||||
auto* codec =
|
||||
model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode());
|
||||
if (codec) {
|
||||
uint64_t decoded_value = 0;
|
||||
if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
|
||||
return Diag(SPV_ERROR_INTERNAL)
|
||||
<< "Failed to decode opcode_and_num_operands, previous opcode is "
|
||||
<< spvOpcodeString(GetPrevOpcode());
|
||||
|
||||
if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) {
|
||||
// The word was successfully decoded.
|
||||
*opcode = uint32_t(decoded_value & 0xFFFF);
|
||||
*num_operands = uint32_t(decoded_value >> 16);
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
// Received kMarkvNoneOfTheAbove signal, use fallback decoding.
|
||||
}
|
||||
|
||||
// Fallback to base-rate codec.
|
||||
codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop);
|
||||
assert(codec);
|
||||
uint64_t decoded_value = 0;
|
||||
if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
|
||||
return Diag(SPV_ERROR_INTERNAL)
|
||||
<< "Failed to decode opcode_and_num_operands with global codec";
|
||||
|
||||
if (decoded_value == MarkvModel::GetMarkvNoneOfTheAbove()) {
|
||||
// Received kMarkvNoneOfTheAbove signal, fallback further.
|
||||
return SPV_UNSUPPORTED;
|
||||
}
|
||||
|
||||
*opcode = uint32_t(decoded_value & 0xFFFF);
|
||||
*num_operands = uint32_t(decoded_value >> 16);
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t MarkvDecoder::DecodeMtfRankHuffman(uint64_t mtf,
|
||||
uint32_t fallback_method,
|
||||
uint32_t* rank) {
|
||||
const auto* codec = GetMtfHuffmanCodec(mtf);
|
||||
if (!codec) {
|
||||
assert(fallback_method != kMtfNone);
|
||||
codec = GetMtfHuffmanCodec(fallback_method);
|
||||
}
|
||||
|
||||
if (!codec) return Diag(SPV_ERROR_INTERNAL) << "No codec to decode MTF rank";
|
||||
|
||||
uint32_t decoded_value = 0;
|
||||
if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
|
||||
return Diag(SPV_ERROR_INTERNAL) << "Failed to decode MTF rank with Huffman";
|
||||
|
||||
if (decoded_value == kMtfRankEncodedByValueSignal) {
|
||||
// Decode by value.
|
||||
if (!reader_.ReadVariableWidthU32(rank, model_->mtf_rank_chunk_length()))
|
||||
return Diag(SPV_ERROR_INTERNAL)
|
||||
<< "Failed to decode MTF rank with varint";
|
||||
*rank += MarkvCodec::kMtfSmallestRankEncodedByValue;
|
||||
} else {
|
||||
// Decode using Huffman coding.
|
||||
assert(decoded_value < MarkvCodec::kMtfSmallestRankEncodedByValue);
|
||||
*rank = decoded_value;
|
||||
}
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t MarkvDecoder::DecodeIdWithDescriptor(uint32_t* id) {
|
||||
auto* codec =
|
||||
model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_);
|
||||
|
||||
uint64_t mtf = kMtfNone;
|
||||
if (codec) {
|
||||
uint64_t decoded_value = 0;
|
||||
if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
|
||||
return Diag(SPV_ERROR_INTERNAL)
|
||||
<< "Failed to decode descriptor with Huffman";
|
||||
|
||||
if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) {
|
||||
const uint32_t long_descriptor = uint32_t(decoded_value);
|
||||
mtf = GetMtfLongIdDescriptor(long_descriptor);
|
||||
}
|
||||
}
|
||||
|
||||
if (mtf == kMtfNone) {
|
||||
if (model_->id_fallback_strategy() !=
|
||||
MarkvModel::IdFallbackStrategy::kShortDescriptor) {
|
||||
return SPV_UNSUPPORTED;
|
||||
}
|
||||
|
||||
uint64_t decoded_value = 0;
|
||||
if (!reader_.ReadBits(&decoded_value, MarkvCodec::kShortDescriptorNumBits))
|
||||
return Diag(SPV_ERROR_INTERNAL) << "Failed to read short descriptor";
|
||||
const uint32_t short_descriptor = uint32_t(decoded_value);
|
||||
if (short_descriptor == 0) {
|
||||
// Forward declared id.
|
||||
return SPV_UNSUPPORTED;
|
||||
}
|
||||
mtf = GetMtfShortIdDescriptor(short_descriptor);
|
||||
}
|
||||
|
||||
return DecodeExistingId(mtf, id);
|
||||
}
|
||||
|
||||
spv_result_t MarkvDecoder::DecodeExistingId(uint64_t mtf, uint32_t* id) {
|
||||
assert(multi_mtf_.GetSize(mtf) > 0);
|
||||
*id = 0;
|
||||
|
||||
uint32_t rank = 0;
|
||||
|
||||
if (multi_mtf_.GetSize(mtf) == 1) {
|
||||
rank = 1;
|
||||
} else {
|
||||
const spv_result_t result =
|
||||
DecodeMtfRankHuffman(mtf, kMtfGenericNonZeroRank, &rank);
|
||||
if (result != SPV_SUCCESS) return result;
|
||||
}
|
||||
|
||||
assert(rank);
|
||||
if (!multi_mtf_.ValueFromRank(mtf, rank, id))
|
||||
return Diag(SPV_ERROR_INTERNAL) << "MTF rank is out of bounds";
|
||||
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t MarkvDecoder::DecodeRefId(uint32_t* id) {
|
||||
{
|
||||
const spv_result_t result = DecodeIdWithDescriptor(id);
|
||||
if (result != SPV_UNSUPPORTED) return result;
|
||||
}
|
||||
|
||||
const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction(
|
||||
SpvOp(inst_.opcode))(operand_index_);
|
||||
uint32_t rank = 0;
|
||||
*id = 0;
|
||||
|
||||
if (model_->id_fallback_strategy() ==
|
||||
MarkvModel::IdFallbackStrategy::kRuleBased) {
|
||||
uint64_t mtf = GetRuleBasedMtf();
|
||||
if (mtf != kMtfNone && !can_forward_declare) {
|
||||
return DecodeExistingId(mtf, id);
|
||||
}
|
||||
|
||||
if (mtf == kMtfNone) mtf = kMtfAll;
|
||||
{
|
||||
const spv_result_t result = DecodeMtfRankHuffman(mtf, kMtfAll, &rank);
|
||||
if (result != SPV_SUCCESS) return result;
|
||||
}
|
||||
|
||||
if (rank == 0) {
|
||||
// This is the first occurrence of a forward declared id.
|
||||
*id = GetIdBound();
|
||||
SetIdBound(*id + 1);
|
||||
multi_mtf_.Insert(kMtfAll, *id);
|
||||
multi_mtf_.Insert(kMtfForwardDeclared, *id);
|
||||
if (mtf != kMtfAll) multi_mtf_.Insert(mtf, *id);
|
||||
} else {
|
||||
if (!multi_mtf_.ValueFromRank(mtf, rank, id))
|
||||
return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds";
|
||||
}
|
||||
} else {
|
||||
assert(can_forward_declare);
|
||||
|
||||
if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length()))
|
||||
return Diag(SPV_ERROR_INTERNAL)
|
||||
<< "Failed to decode MTF rank with varint";
|
||||
|
||||
if (rank == 0) {
|
||||
// This is the first occurrence of a forward declared id.
|
||||
*id = GetIdBound();
|
||||
SetIdBound(*id + 1);
|
||||
multi_mtf_.Insert(kMtfForwardDeclared, *id);
|
||||
} else {
|
||||
if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank, id))
|
||||
return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds";
|
||||
}
|
||||
}
|
||||
assert(*id);
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t MarkvDecoder::DecodeTypeId() {
|
||||
if (inst_.opcode == SpvOpFunctionParameter) {
|
||||
assert(!remaining_function_parameter_types_.empty());
|
||||
inst_.type_id = remaining_function_parameter_types_.front();
|
||||
remaining_function_parameter_types_.pop_front();
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
{
|
||||
const spv_result_t result = DecodeIdWithDescriptor(&inst_.type_id);
|
||||
if (result != SPV_UNSUPPORTED) return result;
|
||||
}
|
||||
|
||||
assert(model_->id_fallback_strategy() ==
|
||||
MarkvModel::IdFallbackStrategy::kRuleBased);
|
||||
|
||||
uint64_t mtf = GetRuleBasedMtf();
|
||||
assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))(
|
||||
operand_index_));
|
||||
|
||||
if (mtf == kMtfNone) {
|
||||
mtf = kMtfTypeNonFunction;
|
||||
// Function types should have been handled by GetRuleBasedMtf.
|
||||
assert(inst_.opcode != SpvOpFunction);
|
||||
}
|
||||
|
||||
return DecodeExistingId(mtf, &inst_.type_id);
|
||||
}
|
||||
|
||||
spv_result_t MarkvDecoder::DecodeResultId() {
|
||||
uint32_t rank = 0;
|
||||
|
||||
const uint64_t num_still_forward_declared =
|
||||
multi_mtf_.GetSize(kMtfForwardDeclared);
|
||||
|
||||
if (num_still_forward_declared) {
|
||||
// Some ids were forward declared. Check if this id is one of them.
|
||||
uint64_t id_was_forward_declared;
|
||||
if (!reader_.ReadBits(&id_was_forward_declared, 1))
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "Failed to read id_was_forward_declared flag";
|
||||
|
||||
if (id_was_forward_declared) {
|
||||
if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length()))
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "Failed to read MTF rank of forward declared id";
|
||||
|
||||
if (rank) {
|
||||
// The id was forward declared, recover it from kMtfForwardDeclared.
|
||||
if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank,
|
||||
&inst_.result_id))
|
||||
return Diag(SPV_ERROR_INTERNAL)
|
||||
<< "Forward declared MTF rank is out of bounds";
|
||||
|
||||
// We can now remove the id from kMtfForwardDeclared.
|
||||
if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id))
|
||||
return Diag(SPV_ERROR_INTERNAL)
|
||||
<< "Failed to remove id from kMtfForwardDeclared";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (inst_.result_id == 0) {
|
||||
// The id was not forward declared, issue a new id.
|
||||
inst_.result_id = GetIdBound();
|
||||
SetIdBound(inst_.result_id + 1);
|
||||
}
|
||||
|
||||
if (model_->id_fallback_strategy() ==
|
||||
MarkvModel::IdFallbackStrategy::kRuleBased) {
|
||||
if (!rank) {
|
||||
multi_mtf_.Insert(kMtfAll, inst_.result_id);
|
||||
}
|
||||
}
|
||||
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t MarkvDecoder::DecodeLiteralNumber(
|
||||
const spv_parsed_operand_t& operand) {
|
||||
if (operand.number_bit_width <= 32) {
|
||||
uint32_t word = 0;
|
||||
const spv_result_t result = DecodeNonIdWord(&word);
|
||||
if (result != SPV_SUCCESS) return result;
|
||||
inst_words_.push_back(word);
|
||||
} else {
|
||||
assert(operand.number_bit_width <= 64);
|
||||
uint64_t word = 0;
|
||||
if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
|
||||
if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length()))
|
||||
return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal U64";
|
||||
} else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
|
||||
int64_t val = 0;
|
||||
if (!reader_.ReadVariableWidthS64(&val, model_->s64_chunk_length(),
|
||||
model_->s64_block_exponent()))
|
||||
return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal S64";
|
||||
std::memcpy(&word, &val, 8);
|
||||
} else if (operand.number_kind == SPV_NUMBER_FLOATING) {
|
||||
if (!reader_.ReadUnencoded(&word))
|
||||
return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal F64";
|
||||
} else {
|
||||
return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length";
|
||||
}
|
||||
inst_words_.push_back(static_cast<uint32_t>(word));
|
||||
inst_words_.push_back(static_cast<uint32_t>(word >> 32));
|
||||
}
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
bool MarkvDecoder::ReadToByteBreak(size_t byte_break_if_less_than) {
|
||||
const size_t num_bits_to_next_byte =
|
||||
GetNumBitsToNextByte(reader_.GetNumReadBits());
|
||||
if (num_bits_to_next_byte == 0 ||
|
||||
num_bits_to_next_byte > byte_break_if_less_than)
|
||||
return true;
|
||||
|
||||
uint64_t bits = 0;
|
||||
if (!reader_.ReadBits(&bits, num_bits_to_next_byte)) return false;
|
||||
|
||||
assert(bits == 0);
|
||||
if (bits != 0) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
spv_result_t MarkvDecoder::DecodeModule(std::vector<uint32_t>* spirv_binary) {
|
||||
const bool header_read_success =
|
||||
reader_.ReadUnencoded(&header_.magic_number) &&
|
||||
reader_.ReadUnencoded(&header_.markv_version) &&
|
||||
reader_.ReadUnencoded(&header_.markv_model) &&
|
||||
reader_.ReadUnencoded(&header_.markv_length_in_bits) &&
|
||||
reader_.ReadUnencoded(&header_.spirv_version) &&
|
||||
reader_.ReadUnencoded(&header_.spirv_generator);
|
||||
|
||||
if (!header_read_success)
|
||||
return Diag(SPV_ERROR_INVALID_BINARY) << "Unable to read MARK-V header";
|
||||
|
||||
if (header_.markv_length_in_bits == 0)
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "Header markv_length_in_bits field is zero";
|
||||
|
||||
if (header_.magic_number != MarkvCodec::kMarkvMagicNumber)
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "MARK-V binary has incorrect magic number";
|
||||
|
||||
// TODO(atgoo@github.com): Print version strings.
|
||||
if (header_.markv_version != MarkvCodec::GetMarkvVersion())
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "MARK-V binary and the codec have different versions";
|
||||
|
||||
const uint32_t model_type = header_.markv_model >> 16;
|
||||
const uint32_t model_version = header_.markv_model & 0xFFFF;
|
||||
if (model_type != model_->model_type())
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "MARK-V binary and the codec use different MARK-V models";
|
||||
|
||||
if (model_version != model_->model_version())
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "MARK-V binary and the codec use different versions if the same "
|
||||
<< "MARK-V model";
|
||||
|
||||
spirv_.reserve(header_.markv_length_in_bits / 2); // Heuristic.
|
||||
spirv_.resize(5, 0);
|
||||
spirv_[0] = SpvMagicNumber;
|
||||
spirv_[1] = header_.spirv_version;
|
||||
spirv_[2] = header_.spirv_generator;
|
||||
|
||||
if (logger_) {
|
||||
reader_.SetCallback(
|
||||
[this](const std::string& str) { logger_->AppendBitSequence(str); });
|
||||
}
|
||||
|
||||
while (reader_.GetNumReadBits() < header_.markv_length_in_bits) {
|
||||
inst_ = {};
|
||||
const spv_result_t decode_result = DecodeInstruction();
|
||||
if (decode_result != SPV_SUCCESS) return decode_result;
|
||||
}
|
||||
|
||||
if (validator_options_) {
|
||||
spv_const_binary_t validation_binary = {spirv_.data(), spirv_.size()};
|
||||
const spv_result_t result = spvValidateWithOptions(
|
||||
context_, validator_options_, &validation_binary, nullptr);
|
||||
if (result != SPV_SUCCESS) return result;
|
||||
}
|
||||
|
||||
// Validate the decode binary
|
||||
if (reader_.GetNumReadBits() != header_.markv_length_in_bits ||
|
||||
!reader_.OnlyZeroesLeft()) {
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "MARK-V binary has wrong stated bit length "
|
||||
<< reader_.GetNumReadBits() << " " << header_.markv_length_in_bits;
|
||||
}
|
||||
|
||||
// Decoding of the module is finished, validation state should have correct
|
||||
// id bound.
|
||||
spirv_[3] = GetIdBound();
|
||||
|
||||
*spirv_binary = std::move(spirv_);
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
// TODO(atgoo@github.com): The implementation borrows heavily from
|
||||
// Parser::parseOperand.
|
||||
// Consider coupling them together in some way once MARK-V codec is more mature.
|
||||
// For now it's better to keep the code independent for experimentation
|
||||
// purposes.
|
||||
spv_result_t MarkvDecoder::DecodeOperand(
|
||||
size_t operand_offset, const spv_operand_type_t type,
|
||||
spv_operand_pattern_t* expected_operands) {
|
||||
const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
|
||||
|
||||
memset(&operand_, 0, sizeof(operand_));
|
||||
|
||||
assert((operand_offset >> 16) == 0);
|
||||
operand_.offset = static_cast<uint16_t>(operand_offset);
|
||||
operand_.type = type;
|
||||
|
||||
// Set default values, may be updated later.
|
||||
operand_.number_kind = SPV_NUMBER_NONE;
|
||||
operand_.number_bit_width = 0;
|
||||
|
||||
const size_t first_word_index = inst_words_.size();
|
||||
|
||||
switch (type) {
|
||||
case SPV_OPERAND_TYPE_RESULT_ID: {
|
||||
const spv_result_t result = DecodeResultId();
|
||||
if (result != SPV_SUCCESS) return result;
|
||||
|
||||
inst_words_.push_back(inst_.result_id);
|
||||
SetIdBound(std::max(GetIdBound(), inst_.result_id + 1));
|
||||
PromoteIfNeeded(inst_.result_id);
|
||||
break;
|
||||
}
|
||||
|
||||
case SPV_OPERAND_TYPE_TYPE_ID: {
|
||||
const spv_result_t result = DecodeTypeId();
|
||||
if (result != SPV_SUCCESS) return result;
|
||||
|
||||
inst_words_.push_back(inst_.type_id);
|
||||
SetIdBound(std::max(GetIdBound(), inst_.type_id + 1));
|
||||
PromoteIfNeeded(inst_.type_id);
|
||||
break;
|
||||
}
|
||||
|
||||
case SPV_OPERAND_TYPE_ID:
|
||||
case SPV_OPERAND_TYPE_OPTIONAL_ID:
|
||||
case SPV_OPERAND_TYPE_SCOPE_ID:
|
||||
case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
|
||||
uint32_t id = 0;
|
||||
const spv_result_t result = DecodeRefId(&id);
|
||||
if (result != SPV_SUCCESS) return result;
|
||||
|
||||
if (id == 0) return Diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0";
|
||||
|
||||
if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) {
|
||||
operand_.type = SPV_OPERAND_TYPE_ID;
|
||||
|
||||
if (opcode == SpvOpExtInst && operand_.offset == 3) {
|
||||
// The current word is the extended instruction set id.
|
||||
// Set the extended instruction set type for the current
|
||||
// instruction.
|
||||
auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id);
|
||||
if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) {
|
||||
return Diag(SPV_ERROR_INVALID_ID)
|
||||
<< "OpExtInst set id " << id
|
||||
<< " does not reference an OpExtInstImport result Id";
|
||||
}
|
||||
inst_.ext_inst_type = ext_inst_type_iter->second;
|
||||
}
|
||||
}
|
||||
|
||||
inst_words_.push_back(id);
|
||||
SetIdBound(std::max(GetIdBound(), id + 1));
|
||||
PromoteIfNeeded(id);
|
||||
break;
|
||||
}
|
||||
|
||||
case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: {
|
||||
uint32_t word = 0;
|
||||
const spv_result_t result = DecodeNonIdWord(&word);
|
||||
if (result != SPV_SUCCESS) return result;
|
||||
|
||||
inst_words_.push_back(word);
|
||||
|
||||
assert(SpvOpExtInst == opcode);
|
||||
assert(inst_.ext_inst_type != SPV_EXT_INST_TYPE_NONE);
|
||||
spv_ext_inst_desc ext_inst;
|
||||
if (grammar_.lookupExtInst(inst_.ext_inst_type, word, &ext_inst))
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "Invalid extended instruction number: " << word;
|
||||
spvPushOperandTypes(ext_inst->operandTypes, expected_operands);
|
||||
break;
|
||||
}
|
||||
|
||||
case SPV_OPERAND_TYPE_LITERAL_INTEGER:
|
||||
case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: {
|
||||
// These are regular single-word literal integer operands.
|
||||
// Post-parsing validation should check the range of the parsed value.
|
||||
operand_.type = SPV_OPERAND_TYPE_LITERAL_INTEGER;
|
||||
// It turns out they are always unsigned integers!
|
||||
operand_.number_kind = SPV_NUMBER_UNSIGNED_INT;
|
||||
operand_.number_bit_width = 32;
|
||||
|
||||
uint32_t word = 0;
|
||||
const spv_result_t result = DecodeNonIdWord(&word);
|
||||
if (result != SPV_SUCCESS) return result;
|
||||
|
||||
inst_words_.push_back(word);
|
||||
break;
|
||||
}
|
||||
|
||||
case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
|
||||
case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: {
|
||||
operand_.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER;
|
||||
if (opcode == SpvOpSwitch) {
|
||||
// The literal operands have the same type as the value
|
||||
// referenced by the selector Id.
|
||||
const uint32_t selector_id = inst_words_.at(1);
|
||||
const auto type_id_iter = id_to_type_id_.find(selector_id);
|
||||
if (type_id_iter == id_to_type_id_.end() || type_id_iter->second == 0) {
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "Invalid OpSwitch: selector id " << selector_id
|
||||
<< " has no type";
|
||||
}
|
||||
uint32_t type_id = type_id_iter->second;
|
||||
|
||||
if (selector_id == type_id) {
|
||||
// Recall that by convention, a result ID that is a type definition
|
||||
// maps to itself.
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "Invalid OpSwitch: selector id " << selector_id
|
||||
<< " is a type, not a value";
|
||||
}
|
||||
if (auto error = SetNumericTypeInfoForType(&operand_, type_id))
|
||||
return error;
|
||||
if (operand_.number_kind != SPV_NUMBER_UNSIGNED_INT &&
|
||||
operand_.number_kind != SPV_NUMBER_SIGNED_INT) {
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "Invalid OpSwitch: selector id " << selector_id
|
||||
<< " is not a scalar integer";
|
||||
}
|
||||
} else {
|
||||
assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant);
|
||||
// The literal number type is determined by the type Id for the
|
||||
// constant.
|
||||
assert(inst_.type_id);
|
||||
if (auto error = SetNumericTypeInfoForType(&operand_, inst_.type_id))
|
||||
return error;
|
||||
}
|
||||
|
||||
if (auto error = DecodeLiteralNumber(operand_)) return error;
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case SPV_OPERAND_TYPE_LITERAL_STRING:
|
||||
case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: {
|
||||
operand_.type = SPV_OPERAND_TYPE_LITERAL_STRING;
|
||||
std::vector<char> str;
|
||||
auto* codec = model_->GetLiteralStringHuffmanCodec(inst_.opcode);
|
||||
|
||||
if (codec) {
|
||||
std::string decoded_string;
|
||||
const bool huffman_result =
|
||||
codec->DecodeFromStream(GetReadBitCallback(), &decoded_string);
|
||||
assert(huffman_result);
|
||||
if (!huffman_result)
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "Failed to read literal string";
|
||||
|
||||
if (decoded_string != "kMarkvNoneOfTheAbove") {
|
||||
std::copy(decoded_string.begin(), decoded_string.end(),
|
||||
std::back_inserter(str));
|
||||
str.push_back('\0');
|
||||
}
|
||||
}
|
||||
|
||||
// The loop is expected to terminate once we encounter '\0' or exhaust
|
||||
// the bit stream.
|
||||
if (str.empty()) {
|
||||
while (true) {
|
||||
char ch = 0;
|
||||
if (!reader_.ReadUnencoded(&ch))
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "Failed to read literal string";
|
||||
|
||||
str.push_back(ch);
|
||||
|
||||
if (ch == '\0') break;
|
||||
}
|
||||
}
|
||||
|
||||
while (str.size() % 4 != 0) str.push_back('\0');
|
||||
|
||||
inst_words_.resize(inst_words_.size() + str.size() / 4);
|
||||
std::memcpy(&inst_words_[first_word_index], str.data(), str.size());
|
||||
|
||||
if (SpvOpExtInstImport == opcode) {
|
||||
// Record the extended instruction type for the ID for this import.
|
||||
// There is only one string literal argument to OpExtInstImport,
|
||||
// so it's sufficient to guard this just on the opcode.
|
||||
const spv_ext_inst_type_t ext_inst_type =
|
||||
spvExtInstImportTypeGet(str.data());
|
||||
if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) {
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "Invalid extended instruction import '" << str.data()
|
||||
<< "'";
|
||||
}
|
||||
// We must have parsed a valid result ID. It's a condition
|
||||
// of the grammar, and we only accept non-zero result Ids.
|
||||
assert(inst_.result_id);
|
||||
const bool inserted =
|
||||
import_id_to_ext_inst_type_.emplace(inst_.result_id, ext_inst_type)
|
||||
.second;
|
||||
(void)inserted;
|
||||
assert(inserted);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case SPV_OPERAND_TYPE_CAPABILITY:
|
||||
case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
|
||||
case SPV_OPERAND_TYPE_EXECUTION_MODEL:
|
||||
case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
|
||||
case SPV_OPERAND_TYPE_MEMORY_MODEL:
|
||||
case SPV_OPERAND_TYPE_EXECUTION_MODE:
|
||||
case SPV_OPERAND_TYPE_STORAGE_CLASS:
|
||||
case SPV_OPERAND_TYPE_DIMENSIONALITY:
|
||||
case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
|
||||
case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
|
||||
case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
|
||||
case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
|
||||
case SPV_OPERAND_TYPE_LINKAGE_TYPE:
|
||||
case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
|
||||
case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
|
||||
case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
|
||||
case SPV_OPERAND_TYPE_DECORATION:
|
||||
case SPV_OPERAND_TYPE_BUILT_IN:
|
||||
case SPV_OPERAND_TYPE_GROUP_OPERATION:
|
||||
case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
|
||||
case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: {
|
||||
// A single word that is a plain enum value.
|
||||
uint32_t word = 0;
|
||||
const spv_result_t result = DecodeNonIdWord(&word);
|
||||
if (result != SPV_SUCCESS) return result;
|
||||
|
||||
inst_words_.push_back(word);
|
||||
|
||||
// Map an optional operand type to its corresponding concrete type.
|
||||
if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER)
|
||||
operand_.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER;
|
||||
|
||||
spv_operand_desc entry;
|
||||
if (grammar_.lookupOperand(type, word, &entry)) {
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "Invalid " << spvOperandTypeStr(operand_.type)
|
||||
<< " operand: " << word;
|
||||
}
|
||||
|
||||
// Prepare to accept operands to this operand, if needed.
|
||||
spvPushOperandTypes(entry->operandTypes, expected_operands);
|
||||
break;
|
||||
}
|
||||
|
||||
case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
|
||||
case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
|
||||
case SPV_OPERAND_TYPE_LOOP_CONTROL:
|
||||
case SPV_OPERAND_TYPE_IMAGE:
|
||||
case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
|
||||
case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
|
||||
case SPV_OPERAND_TYPE_SELECTION_CONTROL: {
|
||||
// This operand is a mask.
|
||||
uint32_t word = 0;
|
||||
const spv_result_t result = DecodeNonIdWord(&word);
|
||||
if (result != SPV_SUCCESS) return result;
|
||||
|
||||
inst_words_.push_back(word);
|
||||
|
||||
// Map an optional operand type to its corresponding concrete type.
|
||||
if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE)
|
||||
operand_.type = SPV_OPERAND_TYPE_IMAGE;
|
||||
else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
|
||||
operand_.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;
|
||||
|
||||
// Check validity of set mask bits. Also prepare for operands for those
|
||||
// masks if they have any. To get operand order correct, scan from
|
||||
// MSB to LSB since we can only prepend operands to a pattern.
|
||||
// The only case in the grammar where you have more than one mask bit
|
||||
// having an operand is for image operands. See SPIR-V 3.14 Image
|
||||
// Operands.
|
||||
uint32_t remaining_word = word;
|
||||
for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) {
|
||||
if (remaining_word & mask) {
|
||||
spv_operand_desc entry;
|
||||
if (grammar_.lookupOperand(type, mask, &entry)) {
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "Invalid " << spvOperandTypeStr(operand_.type)
|
||||
<< " operand: " << word << " has invalid mask component "
|
||||
<< mask;
|
||||
}
|
||||
remaining_word ^= mask;
|
||||
spvPushOperandTypes(entry->operandTypes, expected_operands);
|
||||
}
|
||||
}
|
||||
if (word == 0) {
|
||||
// An all-zeroes mask *might* also be valid.
|
||||
spv_operand_desc entry;
|
||||
if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) {
|
||||
// Prepare for its operands, if any.
|
||||
spvPushOperandTypes(entry->operandTypes, expected_operands);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "Internal error: Unhandled operand type: " << type;
|
||||
}
|
||||
|
||||
operand_.num_words = uint16_t(inst_words_.size() - first_word_index);
|
||||
|
||||
assert(spvOperandIsConcrete(operand_.type));
|
||||
|
||||
parsed_operands_.push_back(operand_);
|
||||
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t MarkvDecoder::DecodeInstruction() {
|
||||
parsed_operands_.clear();
|
||||
inst_words_.clear();
|
||||
|
||||
// Opcode/num_words placeholder, the word will be filled in later.
|
||||
inst_words_.push_back(0);
|
||||
|
||||
bool num_operands_still_unknown = true;
|
||||
{
|
||||
uint32_t opcode = 0;
|
||||
uint32_t num_operands = 0;
|
||||
|
||||
const spv_result_t opcode_decoding_result =
|
||||
DecodeOpcodeAndNumberOfOperands(&opcode, &num_operands);
|
||||
if (opcode_decoding_result < 0) return opcode_decoding_result;
|
||||
|
||||
if (opcode_decoding_result == SPV_SUCCESS) {
|
||||
inst_.num_operands = static_cast<uint16_t>(num_operands);
|
||||
num_operands_still_unknown = false;
|
||||
} else {
|
||||
if (!reader_.ReadVariableWidthU32(&opcode,
|
||||
model_->opcode_chunk_length())) {
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "Failed to read opcode of instruction";
|
||||
}
|
||||
}
|
||||
|
||||
inst_.opcode = static_cast<uint16_t>(opcode);
|
||||
}
|
||||
|
||||
const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
|
||||
|
||||
spv_opcode_desc opcode_desc;
|
||||
if (grammar_.lookupOpcode(opcode, &opcode_desc) != SPV_SUCCESS) {
|
||||
return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode";
|
||||
}
|
||||
|
||||
spv_operand_pattern_t expected_operands;
|
||||
expected_operands.reserve(opcode_desc->numTypes);
|
||||
for (auto i = 0; i < opcode_desc->numTypes; i++) {
|
||||
expected_operands.push_back(
|
||||
opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]);
|
||||
}
|
||||
|
||||
if (num_operands_still_unknown) {
|
||||
if (!OpcodeHasFixedNumberOfOperands(opcode)) {
|
||||
if (!reader_.ReadVariableWidthU16(&inst_.num_operands,
|
||||
model_->num_operands_chunk_length()))
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "Failed to read num_operands of instruction";
|
||||
} else {
|
||||
inst_.num_operands = static_cast<uint16_t>(expected_operands.size());
|
||||
}
|
||||
}
|
||||
|
||||
for (operand_index_ = 0;
|
||||
operand_index_ < static_cast<size_t>(inst_.num_operands);
|
||||
++operand_index_) {
|
||||
assert(!expected_operands.empty());
|
||||
const spv_operand_type_t type =
|
||||
spvTakeFirstMatchableOperand(&expected_operands);
|
||||
|
||||
const size_t operand_offset = inst_words_.size();
|
||||
|
||||
const spv_result_t decode_result =
|
||||
DecodeOperand(operand_offset, type, &expected_operands);
|
||||
|
||||
if (decode_result != SPV_SUCCESS) return decode_result;
|
||||
}
|
||||
|
||||
assert(inst_.num_operands == parsed_operands_.size());
|
||||
|
||||
// Only valid while inst_words_ and parsed_operands_ remain unchanged (until
|
||||
// next DecodeInstruction call).
|
||||
inst_.words = inst_words_.data();
|
||||
inst_.operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data();
|
||||
inst_.num_words = static_cast<uint16_t>(inst_words_.size());
|
||||
inst_words_[0] = spvOpcodeMake(inst_.num_words, SpvOp(inst_.opcode));
|
||||
|
||||
std::copy(inst_words_.begin(), inst_words_.end(), std::back_inserter(spirv_));
|
||||
|
||||
assert(inst_.num_words ==
|
||||
std::accumulate(
|
||||
parsed_operands_.begin(), parsed_operands_.end(), 1,
|
||||
[](int num_words, const spv_parsed_operand_t& operand) {
|
||||
return num_words += operand.num_words;
|
||||
}) &&
|
||||
"num_words in instruction doesn't correspond to the sum of num_words"
|
||||
"in the operands");
|
||||
|
||||
RecordNumberType();
|
||||
ProcessCurInstruction();
|
||||
|
||||
if (!ReadToByteBreak(MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte))
|
||||
return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read to byte break";
|
||||
|
||||
if (logger_) {
|
||||
logger_->NewLine();
|
||||
std::stringstream ss;
|
||||
ss << spvOpcodeString(opcode) << " ";
|
||||
for (size_t index = 1; index < inst_words_.size(); ++index)
|
||||
ss << inst_words_[index] << " ";
|
||||
logger_->AppendText(ss.str());
|
||||
logger_->NewLine();
|
||||
logger_->NewLine();
|
||||
if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION;
|
||||
}
|
||||
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t MarkvDecoder::SetNumericTypeInfoForType(
|
||||
spv_parsed_operand_t* parsed_operand, uint32_t type_id) {
|
||||
assert(type_id != 0);
|
||||
auto type_info_iter = type_id_to_number_type_info_.find(type_id);
|
||||
if (type_info_iter == type_id_to_number_type_info_.end()) {
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "Type Id " << type_id << " is not a type";
|
||||
}
|
||||
|
||||
const NumberType& info = type_info_iter->second;
|
||||
if (info.type == SPV_NUMBER_NONE) {
|
||||
// This is a valid type, but for something other than a scalar number.
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "Type Id " << type_id << " is not a scalar numeric type";
|
||||
}
|
||||
|
||||
parsed_operand->number_kind = info.type;
|
||||
parsed_operand->number_bit_width = info.bit_width;
|
||||
// Round up the word count.
|
||||
parsed_operand->num_words = static_cast<uint16_t>((info.bit_width + 31) / 32);
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
void MarkvDecoder::RecordNumberType() {
|
||||
const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
|
||||
if (spvOpcodeGeneratesType(opcode)) {
|
||||
NumberType info = {SPV_NUMBER_NONE, 0};
|
||||
if (SpvOpTypeInt == opcode) {
|
||||
info.bit_width = inst_.words[inst_.operands[1].offset];
|
||||
info.type = inst_.words[inst_.operands[2].offset]
|
||||
? SPV_NUMBER_SIGNED_INT
|
||||
: SPV_NUMBER_UNSIGNED_INT;
|
||||
} else if (SpvOpTypeFloat == opcode) {
|
||||
info.bit_width = inst_.words[inst_.operands[1].offset];
|
||||
info.type = SPV_NUMBER_FLOATING;
|
||||
}
|
||||
// The *result* Id of a type generating instruction is the type Id.
|
||||
type_id_to_number_type_info_[inst_.result_id] = info;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace comp
|
||||
} // namespace spvtools
|
@ -1,175 +0,0 @@
|
||||
// Copyright (c) 2018 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "source/comp/bit_stream.h"
|
||||
#include "source/comp/markv.h"
|
||||
#include "source/comp/markv_codec.h"
|
||||
#include "source/comp/markv_logger.h"
|
||||
#include "source/util/make_unique.h"
|
||||
|
||||
#ifndef SOURCE_COMP_MARKV_DECODER_H_
|
||||
#define SOURCE_COMP_MARKV_DECODER_H_
|
||||
|
||||
namespace spvtools {
|
||||
namespace comp {
|
||||
|
||||
class MarkvLogger;
|
||||
|
||||
// Decodes MARK-V buffers written by MarkvEncoder.
|
||||
class MarkvDecoder : public MarkvCodec {
|
||||
public:
|
||||
// |model| is owned by the caller, must be not null and valid during the
|
||||
// lifetime of MarkvEncoder.
|
||||
MarkvDecoder(spv_const_context context, const std::vector<uint8_t>& markv,
|
||||
const MarkvCodecOptions& options, const MarkvModel* model)
|
||||
: MarkvCodec(context, GetValidatorOptions(options), model),
|
||||
options_(options),
|
||||
reader_(markv) {
|
||||
SetIdBound(1);
|
||||
parsed_operands_.reserve(25);
|
||||
inst_words_.reserve(25);
|
||||
}
|
||||
~MarkvDecoder() = default;
|
||||
|
||||
// Creates an internal logger which writes comments on the decoding process.
|
||||
void CreateLogger(MarkvLogConsumer log_consumer,
|
||||
MarkvDebugConsumer debug_consumer) {
|
||||
logger_ = MakeUnique<MarkvLogger>(log_consumer, debug_consumer);
|
||||
}
|
||||
|
||||
// Decodes SPIR-V from MARK-V and stores the words in |spirv_binary|.
|
||||
// Can be called only once. Fails if data of wrong format or ends prematurely,
|
||||
// of if validation fails.
|
||||
spv_result_t DecodeModule(std::vector<uint32_t>* spirv_binary);
|
||||
|
||||
// Creates and returns validator options. Returned value owned by the caller.
|
||||
static spv_validator_options GetValidatorOptions(
|
||||
const MarkvCodecOptions& options) {
|
||||
return options.validate_spirv_binary ? spvValidatorOptionsCreate()
|
||||
: nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
// Describes the format of a typed literal number.
|
||||
struct NumberType {
|
||||
spv_number_kind_t type;
|
||||
uint32_t bit_width;
|
||||
};
|
||||
|
||||
// Reads a single bit from reader_. The read bit is stored in |bit|.
|
||||
// Returns false iff reader_ fails.
|
||||
bool ReadBit(bool* bit) {
|
||||
uint64_t bits = 0;
|
||||
const bool result = reader_.ReadBits(&bits, 1);
|
||||
if (result) *bit = bits ? true : false;
|
||||
return result;
|
||||
};
|
||||
|
||||
// Returns ReadBit bound to the class object.
|
||||
std::function<bool(bool*)> GetReadBitCallback() {
|
||||
return std::bind(&MarkvDecoder::ReadBit, this, std::placeholders::_1);
|
||||
}
|
||||
|
||||
// Reads a single non-id word from bit stream. operand_.type determines if
|
||||
// the word needs to be decoded and how.
|
||||
spv_result_t DecodeNonIdWord(uint32_t* word);
|
||||
|
||||
// Reads and decodes both opcode and num_operands as a single code.
|
||||
// Returns SPV_UNSUPPORTED iff no suitable codec was found.
|
||||
spv_result_t DecodeOpcodeAndNumberOfOperands(uint32_t* opcode,
|
||||
uint32_t* num_operands);
|
||||
|
||||
// Reads mtf rank from bit stream. |mtf| is used to determine the codec
|
||||
// scheme. |fallback_method| is used if no codec defined for |mtf|.
|
||||
spv_result_t DecodeMtfRankHuffman(uint64_t mtf, uint32_t fallback_method,
|
||||
uint32_t* rank);
|
||||
|
||||
// Reads id using coding based on mtf associated with the id descriptor.
|
||||
// Returns SPV_UNSUPPORTED iff fallback method needs to be used.
|
||||
spv_result_t DecodeIdWithDescriptor(uint32_t* id);
|
||||
|
||||
// Reads id using coding based on the given |mtf|, which is expected to
|
||||
// contain the needed |id|.
|
||||
spv_result_t DecodeExistingId(uint64_t mtf, uint32_t* id);
|
||||
|
||||
// Reads type id of the current instruction if can't be inferred.
|
||||
spv_result_t DecodeTypeId();
|
||||
|
||||
// Reads result id of the current instruction if can't be inferred.
|
||||
spv_result_t DecodeResultId();
|
||||
|
||||
// Reads id which is neither type nor result id.
|
||||
spv_result_t DecodeRefId(uint32_t* id);
|
||||
|
||||
// Reads and discards bits until the beginning of the next byte if the
|
||||
// number of bits until the next byte is less than |byte_break_if_less_than|.
|
||||
bool ReadToByteBreak(size_t byte_break_if_less_than);
|
||||
|
||||
// Returns instruction words decoded up to this point.
|
||||
const uint32_t* GetInstWords() const override { return inst_words_.data(); }
|
||||
|
||||
// Reads a literal number as it is described in |operand| from the bit stream,
|
||||
// decodes and writes it to spirv_.
|
||||
spv_result_t DecodeLiteralNumber(const spv_parsed_operand_t& operand);
|
||||
|
||||
// Reads instruction from bit stream, decodes and validates it.
|
||||
// Decoded instruction is valid until the next call of DecodeInstruction().
|
||||
spv_result_t DecodeInstruction();
|
||||
|
||||
// Read operand from the stream decodes and validates it.
|
||||
spv_result_t DecodeOperand(size_t operand_offset,
|
||||
const spv_operand_type_t type,
|
||||
spv_operand_pattern_t* expected_operands);
|
||||
|
||||
// Records the numeric type for an operand according to the type information
|
||||
// associated with the given non-zero type Id. This can fail if the type Id
|
||||
// is not a type Id, or if the type Id does not reference a scalar numeric
|
||||
// type. On success, return SPV_SUCCESS and populates the num_words,
|
||||
// number_kind, and number_bit_width fields of parsed_operand.
|
||||
spv_result_t SetNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand,
|
||||
uint32_t type_id);
|
||||
|
||||
// Records the number type for the current instruction, if it generates a
|
||||
// type. For types that aren't scalar numbers, record something with number
|
||||
// kind SPV_NUMBER_NONE.
|
||||
void RecordNumberType();
|
||||
|
||||
MarkvCodecOptions options_;
|
||||
|
||||
// Temporary sink where decoded SPIR-V words are written. Once it contains the
|
||||
// entire module, the container is moved and returned.
|
||||
std::vector<uint32_t> spirv_;
|
||||
|
||||
// Bit stream containing encoded data.
|
||||
BitReaderWord64 reader_;
|
||||
|
||||
// Temporary storage for operands of the currently parsed instruction.
|
||||
// Valid until next DecodeInstruction call.
|
||||
std::vector<spv_parsed_operand_t> parsed_operands_;
|
||||
|
||||
// Temporary storage for current instruction words.
|
||||
// Valid until next DecodeInstruction call.
|
||||
std::vector<uint32_t> inst_words_;
|
||||
|
||||
// Maps a type ID to its number type description.
|
||||
std::unordered_map<uint32_t, NumberType> type_id_to_number_type_info_;
|
||||
|
||||
// Maps an ExtInstImport id to the extended instruction type.
|
||||
std::unordered_map<uint32_t, spv_ext_inst_type_t> import_id_to_ext_inst_type_;
|
||||
};
|
||||
|
||||
} // namespace comp
|
||||
} // namespace spvtools
|
||||
|
||||
#endif // SOURCE_COMP_MARKV_DECODER_H_
|
@ -1,486 +0,0 @@
|
||||
// Copyright (c) 2018 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "source/comp/markv_encoder.h"
|
||||
|
||||
#include "source/binary.h"
|
||||
#include "source/opcode.h"
|
||||
#include "spirv-tools/libspirv.hpp"
|
||||
|
||||
namespace spvtools {
|
||||
namespace comp {
|
||||
namespace {
|
||||
|
||||
const size_t kCommentNumWhitespaces = 2;
|
||||
|
||||
} // namespace
|
||||
|
||||
spv_result_t MarkvEncoder::EncodeNonIdWord(uint32_t word) {
|
||||
auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_);
|
||||
|
||||
if (codec) {
|
||||
uint64_t bits = 0;
|
||||
size_t num_bits = 0;
|
||||
if (codec->Encode(word, &bits, &num_bits)) {
|
||||
// Encoding successful.
|
||||
writer_.WriteBits(bits, num_bits);
|
||||
return SPV_SUCCESS;
|
||||
} else {
|
||||
// Encoding failed, write kMarkvNoneOfTheAbove flag.
|
||||
if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits,
|
||||
&num_bits))
|
||||
return Diag(SPV_ERROR_INTERNAL)
|
||||
<< "Non-id word Huffman table for "
|
||||
<< spvOpcodeString(SpvOp(inst_.opcode)) << " operand index "
|
||||
<< operand_index_ << " is missing kMarkvNoneOfTheAbove";
|
||||
writer_.WriteBits(bits, num_bits);
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback encoding.
|
||||
const size_t chunk_length =
|
||||
model_->GetOperandVariableWidthChunkLength(operand_.type);
|
||||
if (chunk_length) {
|
||||
writer_.WriteVariableWidthU32(word, chunk_length);
|
||||
} else {
|
||||
writer_.WriteUnencoded(word);
|
||||
}
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t MarkvEncoder::EncodeOpcodeAndNumOperands(uint32_t opcode,
|
||||
uint32_t num_operands) {
|
||||
uint64_t bits = 0;
|
||||
size_t num_bits = 0;
|
||||
|
||||
const uint32_t word = opcode | (num_operands << 16);
|
||||
|
||||
// First try to use the Markov chain codec.
|
||||
auto* codec =
|
||||
model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode());
|
||||
if (codec) {
|
||||
if (codec->Encode(word, &bits, &num_bits)) {
|
||||
// The word was successfully encoded into bits/num_bits.
|
||||
writer_.WriteBits(bits, num_bits);
|
||||
return SPV_SUCCESS;
|
||||
} else {
|
||||
// The word is not in the Huffman table. Write kMarkvNoneOfTheAbove
|
||||
// and use fallback encoding.
|
||||
if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits,
|
||||
&num_bits))
|
||||
return Diag(SPV_ERROR_INTERNAL)
|
||||
<< "opcode_and_num_operands Huffman table for "
|
||||
<< spvOpcodeString(GetPrevOpcode())
|
||||
<< "is missing kMarkvNoneOfTheAbove";
|
||||
writer_.WriteBits(bits, num_bits);
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to base-rate codec.
|
||||
codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop);
|
||||
assert(codec);
|
||||
if (codec->Encode(word, &bits, &num_bits)) {
|
||||
// The word was successfully encoded into bits/num_bits.
|
||||
writer_.WriteBits(bits, num_bits);
|
||||
return SPV_SUCCESS;
|
||||
} else {
|
||||
// The word is not in the Huffman table. Write kMarkvNoneOfTheAbove
|
||||
// and return false.
|
||||
if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits, &num_bits))
|
||||
return Diag(SPV_ERROR_INTERNAL)
|
||||
<< "Global opcode_and_num_operands Huffman table is missing "
|
||||
<< "kMarkvNoneOfTheAbove";
|
||||
writer_.WriteBits(bits, num_bits);
|
||||
return SPV_UNSUPPORTED;
|
||||
}
|
||||
}
|
||||
|
||||
spv_result_t MarkvEncoder::EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf,
|
||||
uint64_t fallback_method) {
|
||||
const auto* codec = GetMtfHuffmanCodec(mtf);
|
||||
if (!codec) {
|
||||
assert(fallback_method != kMtfNone);
|
||||
codec = GetMtfHuffmanCodec(fallback_method);
|
||||
}
|
||||
|
||||
if (!codec) return Diag(SPV_ERROR_INTERNAL) << "No codec to encode MTF rank";
|
||||
|
||||
uint64_t bits = 0;
|
||||
size_t num_bits = 0;
|
||||
if (rank < MarkvCodec::kMtfSmallestRankEncodedByValue) {
|
||||
// Encode using Huffman coding.
|
||||
if (!codec->Encode(rank, &bits, &num_bits))
|
||||
return Diag(SPV_ERROR_INTERNAL)
|
||||
<< "Failed to encode MTF rank with Huffman";
|
||||
|
||||
writer_.WriteBits(bits, num_bits);
|
||||
} else {
|
||||
// Encode by value.
|
||||
if (!codec->Encode(MarkvCodec::kMtfRankEncodedByValueSignal, &bits,
|
||||
&num_bits))
|
||||
return Diag(SPV_ERROR_INTERNAL)
|
||||
<< "Failed to encode kMtfRankEncodedByValueSignal";
|
||||
|
||||
writer_.WriteBits(bits, num_bits);
|
||||
writer_.WriteVariableWidthU32(
|
||||
rank - MarkvCodec::kMtfSmallestRankEncodedByValue,
|
||||
model_->mtf_rank_chunk_length());
|
||||
}
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t MarkvEncoder::EncodeIdWithDescriptor(uint32_t id) {
|
||||
// Get the descriptor for id.
|
||||
const uint32_t long_descriptor = long_id_descriptors_.GetDescriptor(id);
|
||||
auto* codec =
|
||||
model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_);
|
||||
uint64_t bits = 0;
|
||||
size_t num_bits = 0;
|
||||
uint64_t mtf = kMtfNone;
|
||||
if (long_descriptor && codec &&
|
||||
codec->Encode(long_descriptor, &bits, &num_bits)) {
|
||||
// If the descriptor exists and is in the table, write the descriptor and
|
||||
// proceed to encoding the rank.
|
||||
writer_.WriteBits(bits, num_bits);
|
||||
mtf = GetMtfLongIdDescriptor(long_descriptor);
|
||||
} else {
|
||||
if (codec) {
|
||||
// The descriptor doesn't exist or we have no coding for it. Write
|
||||
// kMarkvNoneOfTheAbove and go to fallback method.
|
||||
if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits,
|
||||
&num_bits))
|
||||
return Diag(SPV_ERROR_INTERNAL)
|
||||
<< "Descriptor Huffman table for "
|
||||
<< spvOpcodeString(SpvOp(inst_.opcode)) << " operand index "
|
||||
<< operand_index_ << " is missing kMarkvNoneOfTheAbove";
|
||||
|
||||
writer_.WriteBits(bits, num_bits);
|
||||
}
|
||||
|
||||
if (model_->id_fallback_strategy() !=
|
||||
MarkvModel::IdFallbackStrategy::kShortDescriptor) {
|
||||
return SPV_UNSUPPORTED;
|
||||
}
|
||||
|
||||
const uint32_t short_descriptor = short_id_descriptors_.GetDescriptor(id);
|
||||
writer_.WriteBits(short_descriptor, MarkvCodec::kShortDescriptorNumBits);
|
||||
|
||||
if (short_descriptor == 0) {
|
||||
// Forward declared id.
|
||||
return SPV_UNSUPPORTED;
|
||||
}
|
||||
|
||||
mtf = GetMtfShortIdDescriptor(short_descriptor);
|
||||
}
|
||||
|
||||
// Descriptor has been encoded. Now encode the rank of the id in the
|
||||
// associated mtf sequence.
|
||||
return EncodeExistingId(mtf, id);
|
||||
}
|
||||
|
||||
spv_result_t MarkvEncoder::EncodeExistingId(uint64_t mtf, uint32_t id) {
|
||||
assert(multi_mtf_.GetSize(mtf) > 0);
|
||||
if (multi_mtf_.GetSize(mtf) == 1) {
|
||||
// If the sequence has only one element no need to write rank, the decoder
|
||||
// would make the same decision.
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
uint32_t rank = 0;
|
||||
if (!multi_mtf_.RankFromValue(mtf, id, &rank))
|
||||
return Diag(SPV_ERROR_INTERNAL) << "Id is not in the MTF sequence";
|
||||
|
||||
return EncodeMtfRankHuffman(rank, mtf, kMtfGenericNonZeroRank);
|
||||
}
|
||||
|
||||
spv_result_t MarkvEncoder::EncodeRefId(uint32_t id) {
|
||||
{
|
||||
// Try to encode using id descriptor mtfs.
|
||||
const spv_result_t result = EncodeIdWithDescriptor(id);
|
||||
if (result != SPV_UNSUPPORTED) return result;
|
||||
// If can't be done continue with other methods.
|
||||
}
|
||||
|
||||
const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction(
|
||||
SpvOp(inst_.opcode))(operand_index_);
|
||||
uint32_t rank = 0;
|
||||
|
||||
if (model_->id_fallback_strategy() ==
|
||||
MarkvModel::IdFallbackStrategy::kRuleBased) {
|
||||
// Encode using rule-based mtf.
|
||||
uint64_t mtf = GetRuleBasedMtf();
|
||||
|
||||
if (mtf != kMtfNone && !can_forward_declare) {
|
||||
assert(multi_mtf_.HasValue(kMtfAll, id));
|
||||
return EncodeExistingId(mtf, id);
|
||||
}
|
||||
|
||||
if (mtf == kMtfNone) mtf = kMtfAll;
|
||||
|
||||
if (!multi_mtf_.RankFromValue(mtf, id, &rank)) {
|
||||
// This is the first occurrence of a forward declared id.
|
||||
multi_mtf_.Insert(kMtfAll, id);
|
||||
multi_mtf_.Insert(kMtfForwardDeclared, id);
|
||||
if (mtf != kMtfAll) multi_mtf_.Insert(mtf, id);
|
||||
rank = 0;
|
||||
}
|
||||
|
||||
return EncodeMtfRankHuffman(rank, mtf, kMtfAll);
|
||||
} else {
|
||||
assert(can_forward_declare);
|
||||
|
||||
if (!multi_mtf_.RankFromValue(kMtfForwardDeclared, id, &rank)) {
|
||||
// This is the first occurrence of a forward declared id.
|
||||
multi_mtf_.Insert(kMtfForwardDeclared, id);
|
||||
rank = 0;
|
||||
}
|
||||
|
||||
writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length());
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
}
|
||||
|
||||
spv_result_t MarkvEncoder::EncodeTypeId() {
|
||||
if (inst_.opcode == SpvOpFunctionParameter) {
|
||||
assert(!remaining_function_parameter_types_.empty());
|
||||
assert(inst_.type_id == remaining_function_parameter_types_.front());
|
||||
remaining_function_parameter_types_.pop_front();
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
{
|
||||
// Try to encode using id descriptor mtfs.
|
||||
const spv_result_t result = EncodeIdWithDescriptor(inst_.type_id);
|
||||
if (result != SPV_UNSUPPORTED) return result;
|
||||
// If can't be done continue with other methods.
|
||||
}
|
||||
|
||||
assert(model_->id_fallback_strategy() ==
|
||||
MarkvModel::IdFallbackStrategy::kRuleBased);
|
||||
|
||||
uint64_t mtf = GetRuleBasedMtf();
|
||||
assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))(
|
||||
operand_index_));
|
||||
|
||||
if (mtf == kMtfNone) {
|
||||
mtf = kMtfTypeNonFunction;
|
||||
// Function types should have been handled by GetRuleBasedMtf.
|
||||
assert(inst_.opcode != SpvOpFunction);
|
||||
}
|
||||
|
||||
return EncodeExistingId(mtf, inst_.type_id);
|
||||
}
|
||||
|
||||
spv_result_t MarkvEncoder::EncodeResultId() {
|
||||
uint32_t rank = 0;
|
||||
|
||||
const uint64_t num_still_forward_declared =
|
||||
multi_mtf_.GetSize(kMtfForwardDeclared);
|
||||
|
||||
if (num_still_forward_declared) {
|
||||
// We write the rank only if kMtfForwardDeclared is not empty. If it is
|
||||
// empty the decoder knows that there are no forward declared ids to expect.
|
||||
if (multi_mtf_.RankFromValue(kMtfForwardDeclared, inst_.result_id, &rank)) {
|
||||
// This is a definition of a forward declared id. We can remove the id
|
||||
// from kMtfForwardDeclared.
|
||||
if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id))
|
||||
return Diag(SPV_ERROR_INTERNAL)
|
||||
<< "Failed to remove id from kMtfForwardDeclared";
|
||||
writer_.WriteBits(1, 1);
|
||||
writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length());
|
||||
} else {
|
||||
rank = 0;
|
||||
writer_.WriteBits(0, 1);
|
||||
}
|
||||
}
|
||||
|
||||
if (model_->id_fallback_strategy() ==
|
||||
MarkvModel::IdFallbackStrategy::kRuleBased) {
|
||||
if (!rank) {
|
||||
multi_mtf_.Insert(kMtfAll, inst_.result_id);
|
||||
}
|
||||
}
|
||||
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t MarkvEncoder::EncodeLiteralNumber(
|
||||
const spv_parsed_operand_t& operand) {
|
||||
if (operand.number_bit_width <= 32) {
|
||||
const uint32_t word = inst_.words[operand.offset];
|
||||
return EncodeNonIdWord(word);
|
||||
} else {
|
||||
assert(operand.number_bit_width <= 64);
|
||||
const uint64_t word = uint64_t(inst_.words[operand.offset]) |
|
||||
(uint64_t(inst_.words[operand.offset + 1]) << 32);
|
||||
if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
|
||||
writer_.WriteVariableWidthU64(word, model_->u64_chunk_length());
|
||||
} else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
|
||||
int64_t val = 0;
|
||||
std::memcpy(&val, &word, 8);
|
||||
writer_.WriteVariableWidthS64(val, model_->s64_chunk_length(),
|
||||
model_->s64_block_exponent());
|
||||
} else if (operand.number_kind == SPV_NUMBER_FLOATING) {
|
||||
writer_.WriteUnencoded(word);
|
||||
} else {
|
||||
return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length";
|
||||
}
|
||||
}
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
void MarkvEncoder::AddByteBreak(size_t byte_break_if_less_than) {
|
||||
const size_t num_bits_to_next_byte =
|
||||
GetNumBitsToNextByte(writer_.GetNumBits());
|
||||
if (num_bits_to_next_byte == 0 ||
|
||||
num_bits_to_next_byte > byte_break_if_less_than)
|
||||
return;
|
||||
|
||||
if (logger_) {
|
||||
logger_->AppendWhitespaces(kCommentNumWhitespaces);
|
||||
logger_->AppendText("<byte break>");
|
||||
}
|
||||
|
||||
writer_.WriteBits(0, num_bits_to_next_byte);
|
||||
}
|
||||
|
||||
spv_result_t MarkvEncoder::EncodeInstruction(
|
||||
const spv_parsed_instruction_t& inst) {
|
||||
SpvOp opcode = SpvOp(inst.opcode);
|
||||
inst_ = inst;
|
||||
|
||||
LogDisassemblyInstruction();
|
||||
|
||||
const spv_result_t opcode_encodig_result =
|
||||
EncodeOpcodeAndNumOperands(opcode, inst.num_operands);
|
||||
if (opcode_encodig_result < 0) return opcode_encodig_result;
|
||||
|
||||
if (opcode_encodig_result != SPV_SUCCESS) {
|
||||
// Fallback encoding for opcode and num_operands.
|
||||
writer_.WriteVariableWidthU32(opcode, model_->opcode_chunk_length());
|
||||
|
||||
if (!OpcodeHasFixedNumberOfOperands(opcode)) {
|
||||
// If the opcode has a variable number of operands, encode the number of
|
||||
// operands with the instruction.
|
||||
|
||||
if (logger_) logger_->AppendWhitespaces(kCommentNumWhitespaces);
|
||||
|
||||
writer_.WriteVariableWidthU16(inst.num_operands,
|
||||
model_->num_operands_chunk_length());
|
||||
}
|
||||
}
|
||||
|
||||
// Write operands.
|
||||
const uint32_t num_operands = inst_.num_operands;
|
||||
for (operand_index_ = 0; operand_index_ < num_operands; ++operand_index_) {
|
||||
operand_ = inst_.operands[operand_index_];
|
||||
|
||||
if (logger_) {
|
||||
logger_->AppendWhitespaces(kCommentNumWhitespaces);
|
||||
logger_->AppendText("<");
|
||||
logger_->AppendText(spvOperandTypeStr(operand_.type));
|
||||
logger_->AppendText(">");
|
||||
}
|
||||
|
||||
switch (operand_.type) {
|
||||
case SPV_OPERAND_TYPE_RESULT_ID:
|
||||
case SPV_OPERAND_TYPE_TYPE_ID:
|
||||
case SPV_OPERAND_TYPE_ID:
|
||||
case SPV_OPERAND_TYPE_OPTIONAL_ID:
|
||||
case SPV_OPERAND_TYPE_SCOPE_ID:
|
||||
case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
|
||||
const uint32_t id = inst_.words[operand_.offset];
|
||||
if (operand_.type == SPV_OPERAND_TYPE_TYPE_ID) {
|
||||
const spv_result_t result = EncodeTypeId();
|
||||
if (result != SPV_SUCCESS) return result;
|
||||
} else if (operand_.type == SPV_OPERAND_TYPE_RESULT_ID) {
|
||||
const spv_result_t result = EncodeResultId();
|
||||
if (result != SPV_SUCCESS) return result;
|
||||
} else {
|
||||
const spv_result_t result = EncodeRefId(id);
|
||||
if (result != SPV_SUCCESS) return result;
|
||||
}
|
||||
|
||||
PromoteIfNeeded(id);
|
||||
break;
|
||||
}
|
||||
|
||||
case SPV_OPERAND_TYPE_LITERAL_INTEGER: {
|
||||
const spv_result_t result =
|
||||
EncodeNonIdWord(inst_.words[operand_.offset]);
|
||||
if (result != SPV_SUCCESS) return result;
|
||||
break;
|
||||
}
|
||||
|
||||
case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: {
|
||||
const spv_result_t result = EncodeLiteralNumber(operand_);
|
||||
if (result != SPV_SUCCESS) return result;
|
||||
break;
|
||||
}
|
||||
|
||||
case SPV_OPERAND_TYPE_LITERAL_STRING: {
|
||||
const char* src =
|
||||
reinterpret_cast<const char*>(&inst_.words[operand_.offset]);
|
||||
|
||||
auto* codec = model_->GetLiteralStringHuffmanCodec(opcode);
|
||||
if (codec) {
|
||||
uint64_t bits = 0;
|
||||
size_t num_bits = 0;
|
||||
const std::string str = src;
|
||||
if (codec->Encode(str, &bits, &num_bits)) {
|
||||
writer_.WriteBits(bits, num_bits);
|
||||
break;
|
||||
} else {
|
||||
bool result =
|
||||
codec->Encode("kMarkvNoneOfTheAbove", &bits, &num_bits);
|
||||
(void)result;
|
||||
assert(result);
|
||||
writer_.WriteBits(bits, num_bits);
|
||||
}
|
||||
}
|
||||
|
||||
const size_t length = spv_strnlen_s(src, operand_.num_words * 4);
|
||||
if (length == operand_.num_words * 4)
|
||||
return Diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "Failed to find terminal character of literal string";
|
||||
for (size_t i = 0; i < length + 1; ++i) writer_.WriteUnencoded(src[i]);
|
||||
break;
|
||||
}
|
||||
|
||||
default: {
|
||||
for (int i = 0; i < operand_.num_words; ++i) {
|
||||
const uint32_t word = inst_.words[operand_.offset + i];
|
||||
const spv_result_t result = EncodeNonIdWord(word);
|
||||
if (result != SPV_SUCCESS) return result;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
AddByteBreak(MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte);
|
||||
|
||||
if (logger_) {
|
||||
logger_->NewLine();
|
||||
logger_->NewLine();
|
||||
if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION;
|
||||
}
|
||||
|
||||
ProcessCurInstruction();
|
||||
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
} // namespace comp
|
||||
} // namespace spvtools
|
@ -1,167 +0,0 @@
|
||||
// Copyright (c) 2018 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "source/comp/bit_stream.h"
|
||||
#include "source/comp/markv.h"
|
||||
#include "source/comp/markv_codec.h"
|
||||
#include "source/comp/markv_logger.h"
|
||||
#include "source/util/make_unique.h"
|
||||
|
||||
#ifndef SOURCE_COMP_MARKV_ENCODER_H_
|
||||
#define SOURCE_COMP_MARKV_ENCODER_H_
|
||||
|
||||
#include <cstring>
|
||||
|
||||
namespace spvtools {
|
||||
namespace comp {
|
||||
|
||||
// SPIR-V to MARK-V encoder. Exposes functions EncodeHeader and
|
||||
// EncodeInstruction which can be used as callback by spvBinaryParse.
|
||||
// Encoded binary is written to an internally maintained bitstream.
|
||||
// After the last instruction is encoded, the resulting MARK-V binary can be
|
||||
// acquired by calling GetMarkvBinary().
|
||||
//
|
||||
// The encoder uses SPIR-V validator to keep internal state, therefore
|
||||
// SPIR-V binary needs to be able to pass validator checks.
|
||||
// CreateCommentsLogger() can be used to enable the encoder to write comments
|
||||
// on how encoding was done, which can later be accessed with GetComments().
|
||||
class MarkvEncoder : public MarkvCodec {
|
||||
public:
|
||||
// |model| is owned by the caller, must be not null and valid during the
|
||||
// lifetime of MarkvEncoder.
|
||||
MarkvEncoder(spv_const_context context, const MarkvCodecOptions& options,
|
||||
const MarkvModel* model)
|
||||
: MarkvCodec(context, GetValidatorOptions(options), model),
|
||||
options_(options) {}
|
||||
~MarkvEncoder() override = default;
|
||||
|
||||
// Writes data from SPIR-V header to MARK-V header.
|
||||
spv_result_t EncodeHeader(spv_endianness_t /* endian */, uint32_t /* magic */,
|
||||
uint32_t version, uint32_t generator,
|
||||
uint32_t id_bound, uint32_t /* schema */) {
|
||||
SetIdBound(id_bound);
|
||||
header_.spirv_version = version;
|
||||
header_.spirv_generator = generator;
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
// Creates an internal logger which writes comments on the encoding process.
|
||||
void CreateLogger(MarkvLogConsumer log_consumer,
|
||||
MarkvDebugConsumer debug_consumer) {
|
||||
logger_ = MakeUnique<MarkvLogger>(log_consumer, debug_consumer);
|
||||
writer_.SetCallback(
|
||||
[this](const std::string& str) { logger_->AppendBitSequence(str); });
|
||||
}
|
||||
|
||||
// Encodes SPIR-V instruction to MARK-V and writes to bit stream.
|
||||
// Operation can fail if the instruction fails to pass the validator or if
|
||||
// the encoder stubmles on something unexpected.
|
||||
spv_result_t EncodeInstruction(const spv_parsed_instruction_t& inst);
|
||||
|
||||
// Concatenates MARK-V header and the bit stream with encoded instructions
|
||||
// into a single buffer and returns it as spv_markv_binary. The returned
|
||||
// value is owned by the caller and needs to be destroyed with
|
||||
// spvMarkvBinaryDestroy().
|
||||
std::vector<uint8_t> GetMarkvBinary() {
|
||||
header_.markv_length_in_bits =
|
||||
static_cast<uint32_t>(sizeof(header_) * 8 + writer_.GetNumBits());
|
||||
header_.markv_model =
|
||||
(model_->model_type() << 16) | model_->model_version();
|
||||
|
||||
const size_t num_bytes = sizeof(header_) + writer_.GetDataSizeBytes();
|
||||
std::vector<uint8_t> markv(num_bytes);
|
||||
|
||||
assert(writer_.GetData());
|
||||
std::memcpy(markv.data(), &header_, sizeof(header_));
|
||||
std::memcpy(markv.data() + sizeof(header_), writer_.GetData(),
|
||||
writer_.GetDataSizeBytes());
|
||||
return markv;
|
||||
}
|
||||
|
||||
// Optionally adds disassembly to the comments.
|
||||
// Disassembly should contain all instructions in the module separated by
|
||||
// \n, and no header.
|
||||
void SetDisassembly(std::string&& disassembly) {
|
||||
disassembly_ = MakeUnique<std::stringstream>(std::move(disassembly));
|
||||
}
|
||||
|
||||
// Extracts the next instruction line from the disassembly and logs it.
|
||||
void LogDisassemblyInstruction() {
|
||||
if (logger_ && disassembly_) {
|
||||
std::string line;
|
||||
std::getline(*disassembly_, line, '\n');
|
||||
logger_->AppendTextNewLine(line);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// Creates and returns validator options. Returned value owned by the caller.
|
||||
static spv_validator_options GetValidatorOptions(
|
||||
const MarkvCodecOptions& options) {
|
||||
return options.validate_spirv_binary ? spvValidatorOptionsCreate()
|
||||
: nullptr;
|
||||
}
|
||||
|
||||
// Writes a single word to bit stream. operand_.type determines if the word is
|
||||
// encoded and how.
|
||||
spv_result_t EncodeNonIdWord(uint32_t word);
|
||||
|
||||
// Writes both opcode and num_operands as a single code.
|
||||
// Returns SPV_UNSUPPORTED iff no suitable codec was found.
|
||||
spv_result_t EncodeOpcodeAndNumOperands(uint32_t opcode,
|
||||
uint32_t num_operands);
|
||||
|
||||
// Writes mtf rank to bit stream. |mtf| is used to determine the codec
|
||||
// scheme. |fallback_method| is used if no codec defined for |mtf|.
|
||||
spv_result_t EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf,
|
||||
uint64_t fallback_method);
|
||||
|
||||
// Writes id using coding based on mtf associated with the id descriptor.
|
||||
// Returns SPV_UNSUPPORTED iff fallback method needs to be used.
|
||||
spv_result_t EncodeIdWithDescriptor(uint32_t id);
|
||||
|
||||
// Writes id using coding based on the given |mtf|, which is expected to
|
||||
// contain the given |id|.
|
||||
spv_result_t EncodeExistingId(uint64_t mtf, uint32_t id);
|
||||
|
||||
// Writes type id of the current instruction if can't be inferred.
|
||||
spv_result_t EncodeTypeId();
|
||||
|
||||
// Writes result id of the current instruction if can't be inferred.
|
||||
spv_result_t EncodeResultId();
|
||||
|
||||
// Writes ids which are neither type nor result ids.
|
||||
spv_result_t EncodeRefId(uint32_t id);
|
||||
|
||||
// Writes bits to the stream until the beginning of the next byte if the
|
||||
// number of bits until the next byte is less than |byte_break_if_less_than|.
|
||||
void AddByteBreak(size_t byte_break_if_less_than);
|
||||
|
||||
// Encodes a literal number operand and writes it to the bit stream.
|
||||
spv_result_t EncodeLiteralNumber(const spv_parsed_operand_t& operand);
|
||||
|
||||
MarkvCodecOptions options_;
|
||||
|
||||
// Bit stream where encoded instructions are written.
|
||||
BitWriterWord64 writer_;
|
||||
|
||||
// If not nullptr, disassembled instruction lines will be written to comments.
|
||||
// Format: \n separated instruction lines, no header.
|
||||
std::unique_ptr<std::stringstream> disassembly_;
|
||||
};
|
||||
|
||||
} // namespace comp
|
||||
} // namespace spvtools
|
||||
|
||||
#endif // SOURCE_COMP_MARKV_ENCODER_H_
|
@ -1,93 +0,0 @@
|
||||
// Copyright (c) 2018 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SOURCE_COMP_MARKV_LOGGER_H_
|
||||
#define SOURCE_COMP_MARKV_LOGGER_H_
|
||||
|
||||
#include "source/comp/markv.h"
|
||||
|
||||
namespace spvtools {
|
||||
namespace comp {
|
||||
|
||||
class MarkvLogger {
|
||||
public:
|
||||
MarkvLogger(MarkvLogConsumer log_consumer, MarkvDebugConsumer debug_consumer)
|
||||
: log_consumer_(log_consumer), debug_consumer_(debug_consumer) {}
|
||||
|
||||
void AppendText(const std::string& str) {
|
||||
Append(str);
|
||||
use_delimiter_ = false;
|
||||
}
|
||||
|
||||
void AppendTextNewLine(const std::string& str) {
|
||||
Append(str);
|
||||
Append("\n");
|
||||
use_delimiter_ = false;
|
||||
}
|
||||
|
||||
void AppendBitSequence(const std::string& str) {
|
||||
if (debug_consumer_) instruction_bits_ << str;
|
||||
if (use_delimiter_) Append("-");
|
||||
Append(str);
|
||||
use_delimiter_ = true;
|
||||
}
|
||||
|
||||
void AppendWhitespaces(size_t num) {
|
||||
Append(std::string(num, ' '));
|
||||
use_delimiter_ = false;
|
||||
}
|
||||
|
||||
void NewLine() {
|
||||
Append("\n");
|
||||
use_delimiter_ = false;
|
||||
}
|
||||
|
||||
bool DebugInstruction(const spv_parsed_instruction_t& inst) {
|
||||
bool result = true;
|
||||
if (debug_consumer_) {
|
||||
result = debug_consumer_(
|
||||
std::vector<uint32_t>(inst.words, inst.words + inst.num_words),
|
||||
instruction_bits_.str(), instruction_comment_.str());
|
||||
instruction_bits_.str(std::string());
|
||||
instruction_comment_.str(std::string());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
MarkvLogger(const MarkvLogger&) = delete;
|
||||
MarkvLogger(MarkvLogger&&) = delete;
|
||||
MarkvLogger& operator=(const MarkvLogger&) = delete;
|
||||
MarkvLogger& operator=(MarkvLogger&&) = delete;
|
||||
|
||||
void Append(const std::string& str) {
|
||||
if (log_consumer_) log_consumer_(str);
|
||||
if (debug_consumer_) instruction_comment_ << str;
|
||||
}
|
||||
|
||||
MarkvLogConsumer log_consumer_;
|
||||
MarkvDebugConsumer debug_consumer_;
|
||||
|
||||
std::stringstream instruction_bits_;
|
||||
std::stringstream instruction_comment_;
|
||||
|
||||
// If true a delimiter will be appended before the next bit sequence.
|
||||
// Used to generate outputs like: 1100-0 1110-1-1100-1-1111-0 110-0.
|
||||
bool use_delimiter_ = false;
|
||||
};
|
||||
|
||||
} // namespace comp
|
||||
} // namespace spvtools
|
||||
|
||||
#endif // SOURCE_COMP_MARKV_LOGGER_H_
|
@ -1,232 +0,0 @@
|
||||
// Copyright (c) 2018 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SOURCE_COMP_MARKV_MODEL_H_
|
||||
#define SOURCE_COMP_MARKV_MODEL_H_
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
#include "source/comp/huffman_codec.h"
|
||||
#include "source/latest_version_spirv_header.h"
|
||||
#include "spirv-tools/libspirv.hpp"
|
||||
|
||||
namespace spvtools {
|
||||
namespace comp {
|
||||
|
||||
// Base class for MARK-V models.
|
||||
// The class contains encoding/decoding model with various constants and
|
||||
// codecs used by the compression algorithm.
|
||||
class MarkvModel {
|
||||
public:
|
||||
MarkvModel()
|
||||
: operand_chunk_lengths_(
|
||||
static_cast<size_t>(SPV_OPERAND_TYPE_NUM_OPERAND_TYPES), 0) {
|
||||
// Set default values.
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_TYPE_ID] = 4;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_RESULT_ID] = 8;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_ID] = 8;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_SCOPE_ID] = 8;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID] = 8;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_LITERAL_INTEGER] = 6;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER] = 6;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_CAPABILITY] = 6;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_SOURCE_LANGUAGE] = 3;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_EXECUTION_MODEL] = 3;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_ADDRESSING_MODEL] = 2;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_MEMORY_MODEL] = 2;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_EXECUTION_MODE] = 6;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_STORAGE_CLASS] = 4;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_DIMENSIONALITY] = 3;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE] = 3;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE] = 2;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT] = 6;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_FP_ROUNDING_MODE] = 2;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_LINKAGE_TYPE] = 2;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_ACCESS_QUALIFIER] = 2;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER] = 2;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE] = 3;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_DECORATION] = 6;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_BUILT_IN] = 6;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_GROUP_OPERATION] = 2;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS] = 2;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO] = 2;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_FP_FAST_MATH_MODE] = 4;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_FUNCTION_CONTROL] = 4;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_LOOP_CONTROL] = 4;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_IMAGE] = 4;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_OPTIONAL_IMAGE] = 4;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS] = 4;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_SELECTION_CONTROL] = 4;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER] = 6;
|
||||
operand_chunk_lengths_[SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER] = 6;
|
||||
}
|
||||
|
||||
uint32_t model_type() const { return model_type_; }
|
||||
uint32_t model_version() const { return model_version_; }
|
||||
|
||||
uint32_t opcode_chunk_length() const { return opcode_chunk_length_; }
|
||||
uint32_t num_operands_chunk_length() const {
|
||||
return num_operands_chunk_length_;
|
||||
}
|
||||
uint32_t mtf_rank_chunk_length() const { return mtf_rank_chunk_length_; }
|
||||
|
||||
uint32_t u64_chunk_length() const { return u64_chunk_length_; }
|
||||
uint32_t s64_chunk_length() const { return s64_chunk_length_; }
|
||||
uint32_t s64_block_exponent() const { return s64_block_exponent_; }
|
||||
|
||||
enum class IdFallbackStrategy {
|
||||
kRuleBased = 0,
|
||||
kShortDescriptor,
|
||||
};
|
||||
|
||||
IdFallbackStrategy id_fallback_strategy() const {
|
||||
return id_fallback_strategy_;
|
||||
}
|
||||
|
||||
// Returns a codec for common opcode_and_num_operands words for the given
|
||||
// previous opcode. May return nullptr if the codec doesn't exist.
|
||||
const HuffmanCodec<uint64_t>* GetOpcodeAndNumOperandsMarkovHuffmanCodec(
|
||||
uint32_t prev_opcode) const {
|
||||
if (prev_opcode == SpvOpNop)
|
||||
return opcode_and_num_operands_huffman_codec_.get();
|
||||
|
||||
const auto it =
|
||||
opcode_and_num_operands_markov_huffman_codecs_.find(prev_opcode);
|
||||
if (it == opcode_and_num_operands_markov_huffman_codecs_.end())
|
||||
return nullptr;
|
||||
return it->second.get();
|
||||
}
|
||||
|
||||
// Returns a codec for common non-id words used for given operand slot.
|
||||
// Operand slot is defined by the opcode and the operand index.
|
||||
// May return nullptr if the codec doesn't exist.
|
||||
const HuffmanCodec<uint64_t>* GetNonIdWordHuffmanCodec(
|
||||
uint32_t opcode, uint32_t operand_index) const {
|
||||
const auto it = non_id_word_huffman_codecs_.find(
|
||||
std::pair<uint32_t, uint32_t>(opcode, operand_index));
|
||||
if (it == non_id_word_huffman_codecs_.end()) return nullptr;
|
||||
return it->second.get();
|
||||
}
|
||||
|
||||
// Returns a codec for common id descriptos used for given operand slot.
|
||||
// Operand slot is defined by the opcode and the operand index.
|
||||
// May return nullptr if the codec doesn't exist.
|
||||
const HuffmanCodec<uint64_t>* GetIdDescriptorHuffmanCodec(
|
||||
uint32_t opcode, uint32_t operand_index) const {
|
||||
const auto it = id_descriptor_huffman_codecs_.find(
|
||||
std::pair<uint32_t, uint32_t>(opcode, operand_index));
|
||||
if (it == id_descriptor_huffman_codecs_.end()) return nullptr;
|
||||
return it->second.get();
|
||||
}
|
||||
|
||||
// Returns a codec for common strings used by the given opcode.
|
||||
// Operand slot is defined by the opcode and the operand index.
|
||||
// May return nullptr if the codec doesn't exist.
|
||||
const HuffmanCodec<std::string>* GetLiteralStringHuffmanCodec(
|
||||
uint32_t opcode) const {
|
||||
const auto it = literal_string_huffman_codecs_.find(opcode);
|
||||
if (it == literal_string_huffman_codecs_.end()) return nullptr;
|
||||
return it->second.get();
|
||||
}
|
||||
|
||||
// Checks if |descriptor| has a coding scheme in any of
|
||||
// id_descriptor_huffman_codecs_.
|
||||
bool DescriptorHasCodingScheme(uint32_t descriptor) const {
|
||||
return descriptors_with_coding_scheme_.count(descriptor);
|
||||
}
|
||||
|
||||
// Checks if any descriptor has a coding scheme.
|
||||
bool AnyDescriptorHasCodingScheme() const {
|
||||
return !descriptors_with_coding_scheme_.empty();
|
||||
}
|
||||
|
||||
// Returns chunk length used for variable length encoding of spirv operand
|
||||
// words.
|
||||
uint32_t GetOperandVariableWidthChunkLength(spv_operand_type_t type) const {
|
||||
return operand_chunk_lengths_.at(static_cast<size_t>(type));
|
||||
}
|
||||
|
||||
// Sets model type.
|
||||
void SetModelType(uint32_t in_model_type) { model_type_ = in_model_type; }
|
||||
|
||||
// Sets model version.
|
||||
void SetModelVersion(uint32_t in_model_version) {
|
||||
model_version_ = in_model_version;
|
||||
}
|
||||
|
||||
// Returns value used by Huffman codecs as a signal that a value is not in the
|
||||
// coding table.
|
||||
static uint64_t GetMarkvNoneOfTheAbove() {
|
||||
// Magic number.
|
||||
return 1111111111111111111;
|
||||
}
|
||||
|
||||
MarkvModel(const MarkvModel&) = delete;
|
||||
const MarkvModel& operator=(const MarkvModel&) = delete;
|
||||
|
||||
protected:
|
||||
// Huffman codec for base-rate of opcode_and_num_operands.
|
||||
std::unique_ptr<HuffmanCodec<uint64_t>>
|
||||
opcode_and_num_operands_huffman_codec_;
|
||||
|
||||
// Huffman codecs for opcode_and_num_operands. The map key is previous opcode.
|
||||
std::map<uint32_t, std::unique_ptr<HuffmanCodec<uint64_t>>>
|
||||
opcode_and_num_operands_markov_huffman_codecs_;
|
||||
|
||||
// Huffman codecs for non-id single-word operand values.
|
||||
// The map key is pair <opcode, operand_index>.
|
||||
std::map<std::pair<uint32_t, uint32_t>,
|
||||
std::unique_ptr<HuffmanCodec<uint64_t>>>
|
||||
non_id_word_huffman_codecs_;
|
||||
|
||||
// Huffman codecs for id descriptors. The map key is pair
|
||||
// <opcode, operand_index>.
|
||||
std::map<std::pair<uint32_t, uint32_t>,
|
||||
std::unique_ptr<HuffmanCodec<uint64_t>>>
|
||||
id_descriptor_huffman_codecs_;
|
||||
|
||||
// Set of all descriptors which have a coding scheme in any of
|
||||
// id_descriptor_huffman_codecs_.
|
||||
std::unordered_set<uint32_t> descriptors_with_coding_scheme_;
|
||||
|
||||
// Huffman codecs for literal strings. The map key is the opcode of the
|
||||
// current instruction. This assumes, that there is no more than one literal
|
||||
// string operand per instruction, but would still work even if this is not
|
||||
// the case. Names and debug information strings are not collected.
|
||||
std::map<uint32_t, std::unique_ptr<HuffmanCodec<std::string>>>
|
||||
literal_string_huffman_codecs_;
|
||||
|
||||
// Chunk lengths used for variable width encoding of operands (index is
|
||||
// spv_operand_type of the operand).
|
||||
std::vector<uint32_t> operand_chunk_lengths_;
|
||||
|
||||
uint32_t opcode_chunk_length_ = 7;
|
||||
uint32_t num_operands_chunk_length_ = 3;
|
||||
uint32_t mtf_rank_chunk_length_ = 5;
|
||||
|
||||
uint32_t u64_chunk_length_ = 8;
|
||||
uint32_t s64_chunk_length_ = 8;
|
||||
uint32_t s64_block_exponent_ = 10;
|
||||
|
||||
IdFallbackStrategy id_fallback_strategy_ =
|
||||
IdFallbackStrategy::kShortDescriptor;
|
||||
|
||||
uint32_t model_type_ = 0;
|
||||
uint32_t model_version_ = 0;
|
||||
};
|
||||
|
||||
} // namespace comp
|
||||
} // namespace spvtools
|
||||
|
||||
#endif // SOURCE_COMP_MARKV_MODEL_H_
|
@ -1,456 +0,0 @@
|
||||
// Copyright (c) 2018 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "source/comp/move_to_front.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
|
||||
namespace spvtools {
|
||||
namespace comp {
|
||||
|
||||
bool MoveToFront::Insert(uint32_t value) {
|
||||
auto it = value_to_node_.find(value);
|
||||
if (it != value_to_node_.end() && IsInTree(it->second)) return false;
|
||||
|
||||
const uint32_t old_size = GetSize();
|
||||
(void)old_size;
|
||||
|
||||
InsertNode(CreateNode(next_timestamp_++, value));
|
||||
|
||||
last_accessed_value_ = value;
|
||||
last_accessed_value_valid_ = true;
|
||||
|
||||
assert(value_to_node_.count(value));
|
||||
assert(old_size + 1 == GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MoveToFront::Remove(uint32_t value) {
|
||||
auto it = value_to_node_.find(value);
|
||||
if (it == value_to_node_.end()) return false;
|
||||
|
||||
if (!IsInTree(it->second)) return false;
|
||||
|
||||
if (last_accessed_value_ == value) last_accessed_value_valid_ = false;
|
||||
|
||||
const uint32_t orphan = RemoveNode(it->second);
|
||||
(void)orphan;
|
||||
// The node of |value| is still alive but it's orphaned now. Can still be
|
||||
// reused later.
|
||||
assert(!IsInTree(orphan));
|
||||
assert(ValueOf(orphan) == value);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MoveToFront::RankFromValue(uint32_t value, uint32_t* rank) {
|
||||
if (last_accessed_value_valid_ && last_accessed_value_ == value) {
|
||||
*rank = 1;
|
||||
return true;
|
||||
}
|
||||
|
||||
const uint32_t old_size = GetSize();
|
||||
if (old_size == 1) {
|
||||
if (ValueOf(root_) == value) {
|
||||
*rank = 1;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const auto it = value_to_node_.find(value);
|
||||
if (it == value_to_node_.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t target = it->second;
|
||||
|
||||
if (!IsInTree(target)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t node = target;
|
||||
*rank = 1 + SizeOf(LeftOf(node));
|
||||
while (node) {
|
||||
if (IsRightChild(node)) *rank += 1 + SizeOf(LeftOf(ParentOf(node)));
|
||||
node = ParentOf(node);
|
||||
}
|
||||
|
||||
// Don't update timestamp if the node has rank 1.
|
||||
if (*rank != 1) {
|
||||
// Update timestamp and reposition the node.
|
||||
target = RemoveNode(target);
|
||||
assert(ValueOf(target) == value);
|
||||
assert(old_size == GetSize() + 1);
|
||||
MutableTimestampOf(target) = next_timestamp_++;
|
||||
InsertNode(target);
|
||||
assert(old_size == GetSize());
|
||||
}
|
||||
|
||||
last_accessed_value_ = value;
|
||||
last_accessed_value_valid_ = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MoveToFront::HasValue(uint32_t value) const {
|
||||
const auto it = value_to_node_.find(value);
|
||||
if (it == value_to_node_.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return IsInTree(it->second);
|
||||
}
|
||||
|
||||
bool MoveToFront::Promote(uint32_t value) {
|
||||
if (last_accessed_value_valid_ && last_accessed_value_ == value) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const uint32_t old_size = GetSize();
|
||||
if (old_size == 1) return ValueOf(root_) == value;
|
||||
|
||||
const auto it = value_to_node_.find(value);
|
||||
if (it == value_to_node_.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t target = it->second;
|
||||
|
||||
if (!IsInTree(target)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Update timestamp and reposition the node.
|
||||
target = RemoveNode(target);
|
||||
assert(ValueOf(target) == value);
|
||||
assert(old_size == GetSize() + 1);
|
||||
|
||||
MutableTimestampOf(target) = next_timestamp_++;
|
||||
InsertNode(target);
|
||||
assert(old_size == GetSize());
|
||||
|
||||
last_accessed_value_ = value;
|
||||
last_accessed_value_valid_ = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MoveToFront::ValueFromRank(uint32_t rank, uint32_t* value) {
|
||||
if (last_accessed_value_valid_ && rank == 1) {
|
||||
*value = last_accessed_value_;
|
||||
return true;
|
||||
}
|
||||
|
||||
const uint32_t old_size = GetSize();
|
||||
if (rank <= 0 || rank > old_size) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (old_size == 1) {
|
||||
*value = ValueOf(root_);
|
||||
return true;
|
||||
}
|
||||
|
||||
const bool update_timestamp = (rank != 1);
|
||||
|
||||
uint32_t node = root_;
|
||||
while (node) {
|
||||
const uint32_t left_subtree_num_nodes = SizeOf(LeftOf(node));
|
||||
if (rank == left_subtree_num_nodes + 1) {
|
||||
// This is the node we are looking for.
|
||||
// Don't update timestamp if the node has rank 1.
|
||||
if (update_timestamp) {
|
||||
node = RemoveNode(node);
|
||||
assert(old_size == GetSize() + 1);
|
||||
MutableTimestampOf(node) = next_timestamp_++;
|
||||
InsertNode(node);
|
||||
assert(old_size == GetSize());
|
||||
}
|
||||
*value = ValueOf(node);
|
||||
last_accessed_value_ = *value;
|
||||
last_accessed_value_valid_ = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (rank < left_subtree_num_nodes + 1) {
|
||||
// Descend into the left subtree. The rank is still valid.
|
||||
node = LeftOf(node);
|
||||
} else {
|
||||
// Descend into the right subtree. We leave behind the left subtree and
|
||||
// the current node, adjust the |rank| accordingly.
|
||||
rank -= left_subtree_num_nodes + 1;
|
||||
node = RightOf(node);
|
||||
}
|
||||
}
|
||||
|
||||
assert(0);
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t MoveToFront::CreateNode(uint32_t timestamp, uint32_t value) {
|
||||
uint32_t handle = static_cast<uint32_t>(nodes_.size());
|
||||
const auto result = value_to_node_.emplace(value, handle);
|
||||
if (result.second) {
|
||||
// Create new node.
|
||||
nodes_.emplace_back(Node());
|
||||
Node& node = nodes_.back();
|
||||
node.timestamp = timestamp;
|
||||
node.value = value;
|
||||
node.size = 1;
|
||||
// Non-NIL nodes start with height 1 because their NIL children are
|
||||
// leaves.
|
||||
node.height = 1;
|
||||
} else {
|
||||
// Reuse old node.
|
||||
handle = result.first->second;
|
||||
assert(!IsInTree(handle));
|
||||
assert(ValueOf(handle) == value);
|
||||
assert(SizeOf(handle) == 1);
|
||||
assert(HeightOf(handle) == 1);
|
||||
MutableTimestampOf(handle) = timestamp;
|
||||
}
|
||||
|
||||
return handle;
|
||||
}
|
||||
|
||||
void MoveToFront::InsertNode(uint32_t node) {
|
||||
assert(!IsInTree(node));
|
||||
assert(SizeOf(node) == 1);
|
||||
assert(HeightOf(node) == 1);
|
||||
assert(TimestampOf(node));
|
||||
|
||||
if (!root_) {
|
||||
root_ = node;
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t iter = root_;
|
||||
uint32_t parent = 0;
|
||||
|
||||
// Will determine if |node| will become the right or left child after
|
||||
// insertion (but before balancing).
|
||||
bool right_child = true;
|
||||
|
||||
// Find the node which will become |node|'s parent after insertion
|
||||
// (but before balancing).
|
||||
while (iter) {
|
||||
parent = iter;
|
||||
assert(TimestampOf(iter) != TimestampOf(node));
|
||||
right_child = TimestampOf(iter) > TimestampOf(node);
|
||||
iter = right_child ? RightOf(iter) : LeftOf(iter);
|
||||
}
|
||||
|
||||
assert(parent);
|
||||
|
||||
// Connect node and parent.
|
||||
MutableParentOf(node) = parent;
|
||||
if (right_child)
|
||||
MutableRightOf(parent) = node;
|
||||
else
|
||||
MutableLeftOf(parent) = node;
|
||||
|
||||
// Insertion is finished. Start the balancing process.
|
||||
bool needs_rebalancing = true;
|
||||
parent = ParentOf(node);
|
||||
|
||||
while (parent) {
|
||||
UpdateNode(parent);
|
||||
|
||||
if (needs_rebalancing) {
|
||||
const int parent_balance = BalanceOf(parent);
|
||||
|
||||
if (RightOf(parent) == node) {
|
||||
// Added node to the right subtree.
|
||||
if (parent_balance > 1) {
|
||||
// Parent is right heavy, rotate left.
|
||||
if (BalanceOf(node) < 0) RotateRight(node);
|
||||
parent = RotateLeft(parent);
|
||||
} else if (parent_balance == 0 || parent_balance == -1) {
|
||||
// Parent is balanced or left heavy, no need to balance further.
|
||||
needs_rebalancing = false;
|
||||
}
|
||||
} else {
|
||||
// Added node to the left subtree.
|
||||
if (parent_balance < -1) {
|
||||
// Parent is left heavy, rotate right.
|
||||
if (BalanceOf(node) > 0) RotateLeft(node);
|
||||
parent = RotateRight(parent);
|
||||
} else if (parent_balance == 0 || parent_balance == 1) {
|
||||
// Parent is balanced or right heavy, no need to balance further.
|
||||
needs_rebalancing = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert(BalanceOf(parent) >= -1 && (BalanceOf(parent) <= 1));
|
||||
|
||||
node = parent;
|
||||
parent = ParentOf(parent);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t MoveToFront::RemoveNode(uint32_t node) {
|
||||
if (LeftOf(node) && RightOf(node)) {
|
||||
// If |node| has two children, then use another node as scapegoat and swap
|
||||
// their contents. We pick the scapegoat on the side of the tree which has
|
||||
// more nodes.
|
||||
const uint32_t scapegoat = SizeOf(LeftOf(node)) >= SizeOf(RightOf(node))
|
||||
? RightestDescendantOf(LeftOf(node))
|
||||
: LeftestDescendantOf(RightOf(node));
|
||||
assert(scapegoat);
|
||||
std::swap(MutableValueOf(node), MutableValueOf(scapegoat));
|
||||
std::swap(MutableTimestampOf(node), MutableTimestampOf(scapegoat));
|
||||
value_to_node_[ValueOf(node)] = node;
|
||||
value_to_node_[ValueOf(scapegoat)] = scapegoat;
|
||||
node = scapegoat;
|
||||
}
|
||||
|
||||
// |node| may have only one child at this point.
|
||||
assert(!RightOf(node) || !LeftOf(node));
|
||||
|
||||
uint32_t parent = ParentOf(node);
|
||||
uint32_t child = RightOf(node) ? RightOf(node) : LeftOf(node);
|
||||
|
||||
// Orphan |node| and reconnect parent and child.
|
||||
if (child) MutableParentOf(child) = parent;
|
||||
|
||||
if (parent) {
|
||||
if (LeftOf(parent) == node)
|
||||
MutableLeftOf(parent) = child;
|
||||
else
|
||||
MutableRightOf(parent) = child;
|
||||
}
|
||||
|
||||
MutableParentOf(node) = 0;
|
||||
MutableLeftOf(node) = 0;
|
||||
MutableRightOf(node) = 0;
|
||||
UpdateNode(node);
|
||||
const uint32_t orphan = node;
|
||||
|
||||
if (root_ == node) root_ = child;
|
||||
|
||||
// Removal is finished. Start the balancing process.
|
||||
bool needs_rebalancing = true;
|
||||
node = child;
|
||||
|
||||
while (parent) {
|
||||
UpdateNode(parent);
|
||||
|
||||
if (needs_rebalancing) {
|
||||
const int parent_balance = BalanceOf(parent);
|
||||
|
||||
if (parent_balance == 1 || parent_balance == -1) {
|
||||
// The height of the subtree was not changed.
|
||||
needs_rebalancing = false;
|
||||
} else {
|
||||
if (RightOf(parent) == node) {
|
||||
// Removed node from the right subtree.
|
||||
if (parent_balance < -1) {
|
||||
// Parent is left heavy, rotate right.
|
||||
const uint32_t sibling = LeftOf(parent);
|
||||
if (BalanceOf(sibling) > 0) RotateLeft(sibling);
|
||||
parent = RotateRight(parent);
|
||||
}
|
||||
} else {
|
||||
// Removed node from the left subtree.
|
||||
if (parent_balance > 1) {
|
||||
// Parent is right heavy, rotate left.
|
||||
const uint32_t sibling = RightOf(parent);
|
||||
if (BalanceOf(sibling) < 0) RotateRight(sibling);
|
||||
parent = RotateLeft(parent);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert(BalanceOf(parent) >= -1 && (BalanceOf(parent) <= 1));
|
||||
|
||||
node = parent;
|
||||
parent = ParentOf(parent);
|
||||
}
|
||||
|
||||
return orphan;
|
||||
}
|
||||
|
||||
uint32_t MoveToFront::RotateLeft(const uint32_t node) {
|
||||
const uint32_t pivot = RightOf(node);
|
||||
assert(pivot);
|
||||
|
||||
// LeftOf(pivot) gets attached to node in place of pivot.
|
||||
MutableRightOf(node) = LeftOf(pivot);
|
||||
if (RightOf(node)) MutableParentOf(RightOf(node)) = node;
|
||||
|
||||
// Pivot gets attached to ParentOf(node) in place of node.
|
||||
MutableParentOf(pivot) = ParentOf(node);
|
||||
if (!ParentOf(node))
|
||||
root_ = pivot;
|
||||
else if (IsLeftChild(node))
|
||||
MutableLeftOf(ParentOf(node)) = pivot;
|
||||
else
|
||||
MutableRightOf(ParentOf(node)) = pivot;
|
||||
|
||||
// Node is child of pivot.
|
||||
MutableLeftOf(pivot) = node;
|
||||
MutableParentOf(node) = pivot;
|
||||
|
||||
// Update both node and pivot. Pivot is the new parent of node, so node should
|
||||
// be updated first.
|
||||
UpdateNode(node);
|
||||
UpdateNode(pivot);
|
||||
|
||||
return pivot;
|
||||
}
|
||||
|
||||
uint32_t MoveToFront::RotateRight(const uint32_t node) {
|
||||
const uint32_t pivot = LeftOf(node);
|
||||
assert(pivot);
|
||||
|
||||
// RightOf(pivot) gets attached to node in place of pivot.
|
||||
MutableLeftOf(node) = RightOf(pivot);
|
||||
if (LeftOf(node)) MutableParentOf(LeftOf(node)) = node;
|
||||
|
||||
// Pivot gets attached to ParentOf(node) in place of node.
|
||||
MutableParentOf(pivot) = ParentOf(node);
|
||||
if (!ParentOf(node))
|
||||
root_ = pivot;
|
||||
else if (IsLeftChild(node))
|
||||
MutableLeftOf(ParentOf(node)) = pivot;
|
||||
else
|
||||
MutableRightOf(ParentOf(node)) = pivot;
|
||||
|
||||
// Node is child of pivot.
|
||||
MutableRightOf(pivot) = node;
|
||||
MutableParentOf(node) = pivot;
|
||||
|
||||
// Update both node and pivot. Pivot is the new parent of node, so node should
|
||||
// be updated first.
|
||||
UpdateNode(node);
|
||||
UpdateNode(pivot);
|
||||
|
||||
return pivot;
|
||||
}
|
||||
|
||||
void MoveToFront::UpdateNode(uint32_t node) {
|
||||
MutableSizeOf(node) = 1 + SizeOf(LeftOf(node)) + SizeOf(RightOf(node));
|
||||
MutableHeightOf(node) =
|
||||
1 + std::max(HeightOf(LeftOf(node)), HeightOf(RightOf(node)));
|
||||
}
|
||||
|
||||
} // namespace comp
|
||||
} // namespace spvtools
|
@ -1,384 +0,0 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SOURCE_COMP_MOVE_TO_FRONT_H_
|
||||
#define SOURCE_COMP_MOVE_TO_FRONT_H_
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace spvtools {
|
||||
namespace comp {
|
||||
|
||||
// Log(n) move-to-front implementation. Implements the following functions:
|
||||
// Insert - pushes value to the front of the mtf sequence
|
||||
// (only unique values allowed).
|
||||
// Remove - remove value from the sequence.
|
||||
// ValueFromRank - access value by its 1-indexed rank in the sequence.
|
||||
// RankFromValue - get the rank of the given value in the sequence.
|
||||
// Accessing a value with ValueFromRank or RankFromValue moves the value to the
|
||||
// front of the sequence (rank of 1).
|
||||
//
|
||||
// The implementation is based on an AVL-based order statistic tree. The tree
|
||||
// is ordered by timestamps issued when values are inserted or accessed (recent
|
||||
// values go to the left side of the tree, old values are gradually rotated to
|
||||
// the right side).
|
||||
//
|
||||
// Terminology
|
||||
// rank: 1-indexed rank showing how recently the value was inserted or accessed.
|
||||
// node: handle used internally to access node data.
|
||||
// size: size of the subtree of a node (including the node).
|
||||
// height: distance from a node to the farthest leaf.
|
||||
class MoveToFront {
|
||||
public:
|
||||
explicit MoveToFront(size_t reserve_capacity = 4) {
|
||||
nodes_.reserve(reserve_capacity);
|
||||
|
||||
// Create NIL node.
|
||||
nodes_.emplace_back(Node());
|
||||
}
|
||||
|
||||
virtual ~MoveToFront() = default;
|
||||
|
||||
// Inserts value in the move-to-front sequence. Does nothing if the value is
|
||||
// already in the sequence. Returns true if insertion was successful.
|
||||
// The inserted value is placed at the front of the sequence (rank 1).
|
||||
bool Insert(uint32_t value);
|
||||
|
||||
// Removes value from move-to-front sequence. Returns false iff the value
|
||||
// was not found.
|
||||
bool Remove(uint32_t value);
|
||||
|
||||
// Computes 1-indexed rank of value in the move-to-front sequence and moves
|
||||
// the value to the front. Example:
|
||||
// Before the call: 4 8 2 1 7
|
||||
// RankFromValue(8) returns 2
|
||||
// After the call: 8 4 2 1 7
|
||||
// Returns true iff the value was found in the sequence.
|
||||
bool RankFromValue(uint32_t value, uint32_t* rank);
|
||||
|
||||
// Returns value corresponding to a 1-indexed rank in the move-to-front
|
||||
// sequence and moves the value to the front. Example:
|
||||
// Before the call: 4 8 2 1 7
|
||||
// ValueFromRank(2) returns 8
|
||||
// After the call: 8 4 2 1 7
|
||||
// Returns true iff the rank is within bounds [1, GetSize()].
|
||||
bool ValueFromRank(uint32_t rank, uint32_t* value);
|
||||
|
||||
// Moves the value to the front of the sequence.
|
||||
// Returns false iff value is not in the sequence.
|
||||
bool Promote(uint32_t value);
|
||||
|
||||
// Returns true iff the move-to-front sequence contains the value.
|
||||
bool HasValue(uint32_t value) const;
|
||||
|
||||
// Returns the number of elements in the move-to-front sequence.
|
||||
uint32_t GetSize() const { return SizeOf(root_); }
|
||||
|
||||
protected:
|
||||
// Internal tree data structure uses handles instead of pointers. Leaves and
|
||||
// root parent reference a singleton under handle 0. Although dereferencing
|
||||
// a null pointer is not possible, inappropriate access to handle 0 would
|
||||
// cause an assertion. Handles are not garbage collected if value was
|
||||
// deprecated
|
||||
// with DeprecateValue(). But handles are recycled when a node is
|
||||
// repositioned.
|
||||
|
||||
// Internal tree data structure node.
|
||||
struct Node {
|
||||
// Timestamp from a logical clock which updates every time the element is
|
||||
// accessed through ValueFromRank or RankFromValue.
|
||||
uint32_t timestamp = 0;
|
||||
// The size of the node's subtree, including the node.
|
||||
// SizeOf(LeftOf(node)) + SizeOf(RightOf(node)) + 1.
|
||||
uint32_t size = 0;
|
||||
// Handles to connected nodes.
|
||||
uint32_t left = 0;
|
||||
uint32_t right = 0;
|
||||
uint32_t parent = 0;
|
||||
// Distance to the farthest leaf.
|
||||
// Leaves have height 0, real nodes at least 1.
|
||||
uint32_t height = 0;
|
||||
// Stored value.
|
||||
uint32_t value = 0;
|
||||
};
|
||||
|
||||
// Creates node and sets correct values. Non-NIL nodes should be created only
|
||||
// through this function. If the node with this value has been created
|
||||
// previously
|
||||
// and since orphaned, reuses the old node instead of creating a new one.
|
||||
uint32_t CreateNode(uint32_t timestamp, uint32_t value);
|
||||
|
||||
// Node accessor methods. Naming is designed to be similar to natural
|
||||
// language as these functions tend to be used in sequences, for example:
|
||||
// ParentOf(LeftestDescendentOf(RightOf(node)))
|
||||
|
||||
// Returns value of the node referenced by |handle|.
|
||||
uint32_t ValueOf(uint32_t node) const { return nodes_.at(node).value; }
|
||||
|
||||
// Returns left child of |node|.
|
||||
uint32_t LeftOf(uint32_t node) const { return nodes_.at(node).left; }
|
||||
|
||||
// Returns right child of |node|.
|
||||
uint32_t RightOf(uint32_t node) const { return nodes_.at(node).right; }
|
||||
|
||||
// Returns parent of |node|.
|
||||
uint32_t ParentOf(uint32_t node) const { return nodes_.at(node).parent; }
|
||||
|
||||
// Returns timestamp of |node|.
|
||||
uint32_t TimestampOf(uint32_t node) const {
|
||||
assert(node);
|
||||
return nodes_.at(node).timestamp;
|
||||
}
|
||||
|
||||
// Returns size of |node|.
|
||||
uint32_t SizeOf(uint32_t node) const { return nodes_.at(node).size; }
|
||||
|
||||
// Returns height of |node|.
|
||||
uint32_t HeightOf(uint32_t node) const { return nodes_.at(node).height; }
|
||||
|
||||
// Returns mutable reference to value of |node|.
|
||||
uint32_t& MutableValueOf(uint32_t node) {
|
||||
assert(node);
|
||||
return nodes_.at(node).value;
|
||||
}
|
||||
|
||||
// Returns mutable reference to handle of left child of |node|.
|
||||
uint32_t& MutableLeftOf(uint32_t node) {
|
||||
assert(node);
|
||||
return nodes_.at(node).left;
|
||||
}
|
||||
|
||||
// Returns mutable reference to handle of right child of |node|.
|
||||
uint32_t& MutableRightOf(uint32_t node) {
|
||||
assert(node);
|
||||
return nodes_.at(node).right;
|
||||
}
|
||||
|
||||
// Returns mutable reference to handle of parent of |node|.
|
||||
uint32_t& MutableParentOf(uint32_t node) {
|
||||
assert(node);
|
||||
return nodes_.at(node).parent;
|
||||
}
|
||||
|
||||
// Returns mutable reference to timestamp of |node|.
|
||||
uint32_t& MutableTimestampOf(uint32_t node) {
|
||||
assert(node);
|
||||
return nodes_.at(node).timestamp;
|
||||
}
|
||||
|
||||
// Returns mutable reference to size of |node|.
|
||||
uint32_t& MutableSizeOf(uint32_t node) {
|
||||
assert(node);
|
||||
return nodes_.at(node).size;
|
||||
}
|
||||
|
||||
// Returns mutable reference to height of |node|.
|
||||
uint32_t& MutableHeightOf(uint32_t node) {
|
||||
assert(node);
|
||||
return nodes_.at(node).height;
|
||||
}
|
||||
|
||||
// Returns true iff |node| is left child of its parent.
|
||||
bool IsLeftChild(uint32_t node) const {
|
||||
assert(node);
|
||||
return LeftOf(ParentOf(node)) == node;
|
||||
}
|
||||
|
||||
// Returns true iff |node| is right child of its parent.
|
||||
bool IsRightChild(uint32_t node) const {
|
||||
assert(node);
|
||||
return RightOf(ParentOf(node)) == node;
|
||||
}
|
||||
|
||||
// Returns true iff |node| has no relatives.
|
||||
bool IsOrphan(uint32_t node) const {
|
||||
assert(node);
|
||||
return !ParentOf(node) && !LeftOf(node) && !RightOf(node);
|
||||
}
|
||||
|
||||
// Returns true iff |node| is in the tree.
|
||||
bool IsInTree(uint32_t node) const {
|
||||
assert(node);
|
||||
return node == root_ || !IsOrphan(node);
|
||||
}
|
||||
|
||||
// Returns the height difference between right and left subtrees.
|
||||
int BalanceOf(uint32_t node) const {
|
||||
return int(HeightOf(RightOf(node))) - int(HeightOf(LeftOf(node)));
|
||||
}
|
||||
|
||||
// Updates size and height of the node, assuming that the children have
|
||||
// correct values.
|
||||
void UpdateNode(uint32_t node);
|
||||
|
||||
// Returns the most LeftOf(LeftOf(... descendent which is not leaf.
|
||||
uint32_t LeftestDescendantOf(uint32_t node) const {
|
||||
uint32_t parent = 0;
|
||||
while (node) {
|
||||
parent = node;
|
||||
node = LeftOf(node);
|
||||
}
|
||||
return parent;
|
||||
}
|
||||
|
||||
// Returns the most RightOf(RightOf(... descendent which is not leaf.
|
||||
uint32_t RightestDescendantOf(uint32_t node) const {
|
||||
uint32_t parent = 0;
|
||||
while (node) {
|
||||
parent = node;
|
||||
node = RightOf(node);
|
||||
}
|
||||
return parent;
|
||||
}
|
||||
|
||||
// Inserts node in the tree. The node must be an orphan.
|
||||
void InsertNode(uint32_t node);
|
||||
|
||||
// Removes node from the tree. May change value_to_node_ if removal uses a
|
||||
// scapegoat. Returns the removed (orphaned) handle for recycling. The
|
||||
// returned handle may not be equal to |node| if scapegoat was used.
|
||||
uint32_t RemoveNode(uint32_t node);
|
||||
|
||||
// Rotates |node| left, reassigns all connections and returns the node
|
||||
// which takes place of the |node|.
|
||||
uint32_t RotateLeft(const uint32_t node);
|
||||
|
||||
// Rotates |node| right, reassigns all connections and returns the node
|
||||
// which takes place of the |node|.
|
||||
uint32_t RotateRight(const uint32_t node);
|
||||
|
||||
// Root node handle. The tree is empty if root_ is 0.
|
||||
uint32_t root_ = 0;
|
||||
|
||||
// Incremented counters for next timestamp and value.
|
||||
uint32_t next_timestamp_ = 1;
|
||||
|
||||
// Holds all tree nodes. Indices of this vector are node handles.
|
||||
std::vector<Node> nodes_;
|
||||
|
||||
// Maps ids to node handles.
|
||||
std::unordered_map<uint32_t, uint32_t> value_to_node_;
|
||||
|
||||
// Cache for the last accessed value in the sequence.
|
||||
uint32_t last_accessed_value_ = 0;
|
||||
bool last_accessed_value_valid_ = false;
|
||||
};
|
||||
|
||||
class MultiMoveToFront {
|
||||
public:
|
||||
// Inserts |value| to sequence with handle |mtf|.
|
||||
// Returns false if |mtf| already has |value|.
|
||||
bool Insert(uint64_t mtf, uint32_t value) {
|
||||
if (GetMtf(mtf).Insert(value)) {
|
||||
val_to_mtfs_[value].insert(mtf);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Removes |value| from sequence with handle |mtf|.
|
||||
// Returns false if |mtf| doesn't have |value|.
|
||||
bool Remove(uint64_t mtf, uint32_t value) {
|
||||
if (GetMtf(mtf).Remove(value)) {
|
||||
val_to_mtfs_[value].erase(mtf);
|
||||
return true;
|
||||
}
|
||||
assert(val_to_mtfs_[value].count(mtf) == 0);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Removes |value| from all sequences which have it.
|
||||
void RemoveFromAll(uint32_t value) {
|
||||
auto it = val_to_mtfs_.find(value);
|
||||
if (it == val_to_mtfs_.end()) return;
|
||||
|
||||
auto& mtfs_containing_value = it->second;
|
||||
for (uint64_t mtf : mtfs_containing_value) {
|
||||
GetMtf(mtf).Remove(value);
|
||||
}
|
||||
|
||||
val_to_mtfs_.erase(value);
|
||||
}
|
||||
|
||||
// Computes rank of |value| in sequence |mtf|.
|
||||
// Returns false if |mtf| doesn't have |value|.
|
||||
bool RankFromValue(uint64_t mtf, uint32_t value, uint32_t* rank) {
|
||||
return GetMtf(mtf).RankFromValue(value, rank);
|
||||
}
|
||||
|
||||
// Finds |value| with |rank| in sequence |mtf|.
|
||||
// Returns false if |rank| is out of bounds.
|
||||
bool ValueFromRank(uint64_t mtf, uint32_t rank, uint32_t* value) {
|
||||
return GetMtf(mtf).ValueFromRank(rank, value);
|
||||
}
|
||||
|
||||
// Returns size of |mtf| sequence.
|
||||
uint32_t GetSize(uint64_t mtf) { return GetMtf(mtf).GetSize(); }
|
||||
|
||||
// Promotes |value| in all sequences which have it.
|
||||
void Promote(uint32_t value) {
|
||||
const auto it = val_to_mtfs_.find(value);
|
||||
if (it == val_to_mtfs_.end()) return;
|
||||
|
||||
const auto& mtfs_containing_value = it->second;
|
||||
for (uint64_t mtf : mtfs_containing_value) {
|
||||
GetMtf(mtf).Promote(value);
|
||||
}
|
||||
}
|
||||
|
||||
// Inserts |value| in sequence |mtf| or promotes if it's already there.
|
||||
void InsertOrPromote(uint64_t mtf, uint32_t value) {
|
||||
if (!Insert(mtf, value)) {
|
||||
GetMtf(mtf).Promote(value);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns if |mtf| sequence has |value|.
|
||||
bool HasValue(uint64_t mtf, uint32_t value) {
|
||||
return GetMtf(mtf).HasValue(value);
|
||||
}
|
||||
|
||||
private:
|
||||
// Returns actual MoveToFront object corresponding to |handle|.
|
||||
// As multiple operations are often performed consecutively for the same
|
||||
// sequence, the last returned value is cached.
|
||||
MoveToFront& GetMtf(uint64_t handle) {
|
||||
if (!cached_mtf_ || cached_handle_ != handle) {
|
||||
cached_handle_ = handle;
|
||||
cached_mtf_ = &mtfs_[handle];
|
||||
}
|
||||
|
||||
return *cached_mtf_;
|
||||
}
|
||||
|
||||
// Container holding MoveToFront objects. Map key is sequence handle.
|
||||
std::map<uint64_t, MoveToFront> mtfs_;
|
||||
|
||||
// Container mapping value to sequences which contain that value.
|
||||
std::unordered_map<uint32_t, std::set<uint64_t>> val_to_mtfs_;
|
||||
|
||||
// Cache for the last accessed sequence.
|
||||
uint64_t cached_handle_ = 0;
|
||||
MoveToFront* cached_mtf_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace comp
|
||||
} // namespace spvtools
|
||||
|
||||
#endif // SOURCE_COMP_MOVE_TO_FRONT_H_
|
@ -1,78 +0,0 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "source/id_descriptor.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
|
||||
#include "source/opcode.h"
|
||||
#include "source/operand.h"
|
||||
|
||||
namespace spvtools {
|
||||
namespace {
|
||||
|
||||
// Hashes an array of words. Order of words is important.
|
||||
uint32_t HashU32Array(const std::vector<uint32_t>& words) {
|
||||
// The hash function is a sum of hashes of each word seeded by word index.
|
||||
// Knuth's multiplicative hash is used to hash the words.
|
||||
const uint32_t kKnuthMulHash = 2654435761;
|
||||
uint32_t val = 0;
|
||||
for (uint32_t i = 0; i < words.size(); ++i) {
|
||||
val += (words[i] + i + 123) * kKnuthMulHash;
|
||||
}
|
||||
return val;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
uint32_t IdDescriptorCollection::ProcessInstruction(
|
||||
const spv_parsed_instruction_t& inst) {
|
||||
if (!inst.result_id) return 0;
|
||||
|
||||
assert(words_.empty());
|
||||
words_.push_back(inst.words[0]);
|
||||
|
||||
for (size_t operand_index = 0; operand_index < inst.num_operands;
|
||||
++operand_index) {
|
||||
const auto& operand = inst.operands[operand_index];
|
||||
if (spvIsIdType(operand.type)) {
|
||||
const uint32_t id = inst.words[operand.offset];
|
||||
const auto it = id_to_descriptor_.find(id);
|
||||
// Forward declared ids are not hashed.
|
||||
if (it != id_to_descriptor_.end()) {
|
||||
words_.push_back(it->second);
|
||||
}
|
||||
} else {
|
||||
for (size_t operand_word_index = 0;
|
||||
operand_word_index < operand.num_words; ++operand_word_index) {
|
||||
words_.push_back(inst.words[operand.offset + operand_word_index]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t descriptor =
|
||||
custom_hash_func_ ? custom_hash_func_(words_) : HashU32Array(words_);
|
||||
if (descriptor == 0) descriptor = 1;
|
||||
assert(descriptor);
|
||||
|
||||
words_.clear();
|
||||
|
||||
const auto result = id_to_descriptor_.emplace(inst.result_id, descriptor);
|
||||
assert(result.second);
|
||||
(void)result;
|
||||
return descriptor;
|
||||
}
|
||||
|
||||
} // namespace spvtools
|
@ -1,63 +0,0 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SOURCE_ID_DESCRIPTOR_H_
|
||||
#define SOURCE_ID_DESCRIPTOR_H_
|
||||
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "spirv-tools/libspirv.hpp"
|
||||
|
||||
namespace spvtools {
|
||||
|
||||
using CustomHashFunc = std::function<uint32_t(const std::vector<uint32_t>&)>;
|
||||
|
||||
// Computes and stores id descriptors.
|
||||
//
|
||||
// Descriptors are computed as hash of all words in the instruction where ids
|
||||
// were substituted with previously computed descriptors.
|
||||
class IdDescriptorCollection {
|
||||
public:
|
||||
explicit IdDescriptorCollection(
|
||||
CustomHashFunc custom_hash_func = CustomHashFunc())
|
||||
: custom_hash_func_(custom_hash_func) {
|
||||
words_.reserve(16);
|
||||
}
|
||||
|
||||
// Computes descriptor for the result id of the given instruction and
|
||||
// registers it in id_to_descriptor_. Returns the computed descriptor.
|
||||
// This function needs to be sequentially called for every instruction in the
|
||||
// module.
|
||||
uint32_t ProcessInstruction(const spv_parsed_instruction_t& inst);
|
||||
|
||||
// Returns a previously computed descriptor id.
|
||||
uint32_t GetDescriptor(uint32_t id) const {
|
||||
const auto it = id_to_descriptor_.find(id);
|
||||
if (it == id_to_descriptor_.end()) return 0;
|
||||
return it->second;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<uint32_t, uint32_t> id_to_descriptor_;
|
||||
|
||||
std::function<uint32_t(const std::vector<uint32_t>&)> custom_hash_func_;
|
||||
|
||||
// Scratch buffer used for hashing. Class member to optimize on allocation.
|
||||
std::vector<uint32_t> words_;
|
||||
};
|
||||
|
||||
} // namespace spvtools
|
||||
|
||||
#endif // SOURCE_ID_DESCRIPTOR_H_
|
@ -183,33 +183,9 @@ add_spvtools_unittest(
|
||||
endif()
|
||||
|
||||
|
||||
add_spvtools_unittest(
|
||||
TARGET bit_stream
|
||||
SRCS bit_stream.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/bit_stream.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/bit_stream.h
|
||||
LIBS ${SPIRV_TOOLS})
|
||||
|
||||
add_spvtools_unittest(
|
||||
TARGET huffman_codec
|
||||
SRCS huffman_codec.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/bit_stream.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/bit_stream.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/huffman_codec.h
|
||||
LIBS ${SPIRV_TOOLS})
|
||||
|
||||
add_spvtools_unittest(
|
||||
TARGET move_to_front
|
||||
SRCS move_to_front_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/move_to_front.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/move_to_front.cpp
|
||||
LIBS ${SPIRV_TOOLS})
|
||||
|
||||
add_subdirectory(comp)
|
||||
add_subdirectory(link)
|
||||
add_subdirectory(opt)
|
||||
add_subdirectory(reduce)
|
||||
add_subdirectory(stats)
|
||||
add_subdirectory(tools)
|
||||
add_subdirectory(util)
|
||||
add_subdirectory(val)
|
||||
|
1025
test/bit_stream.cpp
1025
test/bit_stream.cpp
File diff suppressed because it is too large
Load Diff
@ -1,29 +0,0 @@
|
||||
# Copyright (c) 2017 Google Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
set(VAL_TEST_COMMON_SRCS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../test_fixture.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../unit_spirv.h
|
||||
)
|
||||
|
||||
if(SPIRV_BUILD_COMPRESSION)
|
||||
add_spvtools_unittest(TARGET markv_codec
|
||||
SRCS
|
||||
markv_codec_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../tools/comp/markv_model_factory.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../tools/comp/markv_model_shader.cpp
|
||||
${VAL_TEST_COMMON_SRCS}
|
||||
LIBS SPIRV-Tools-comp ${SPIRV_TOOLS}
|
||||
)
|
||||
endif(SPIRV_BUILD_COMPRESSION)
|
@ -1,829 +0,0 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Tests for unique type declaration rules validator.
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
#include "source/comp/markv.h"
|
||||
#include "test/test_fixture.h"
|
||||
#include "test/unit_spirv.h"
|
||||
#include "tools/comp/markv_model_factory.h"
|
||||
|
||||
namespace spvtools {
|
||||
namespace comp {
|
||||
namespace {
|
||||
|
||||
using spvtest::ScopedContext;
|
||||
using MarkvTest = ::testing::TestWithParam<MarkvModelType>;
|
||||
|
||||
void DiagnosticsMessageHandler(spv_message_level_t level, const char*,
|
||||
const spv_position_t& position,
|
||||
const char* message) {
|
||||
switch (level) {
|
||||
case SPV_MSG_FATAL:
|
||||
case SPV_MSG_INTERNAL_ERROR:
|
||||
case SPV_MSG_ERROR:
|
||||
std::cerr << "error: " << position.index << ": " << message << std::endl;
|
||||
break;
|
||||
case SPV_MSG_WARNING:
|
||||
std::cout << "warning: " << position.index << ": " << message
|
||||
<< std::endl;
|
||||
break;
|
||||
case SPV_MSG_INFO:
|
||||
std::cout << "info: " << position.index << ": " << message << std::endl;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Compiles |code| to SPIR-V |words|.
|
||||
void Compile(const std::string& code, std::vector<uint32_t>* words,
|
||||
uint32_t options = SPV_TEXT_TO_BINARY_OPTION_NONE,
|
||||
spv_target_env env = SPV_ENV_UNIVERSAL_1_2) {
|
||||
spvtools::Context ctx(env);
|
||||
ctx.SetMessageConsumer(DiagnosticsMessageHandler);
|
||||
|
||||
spv_binary spirv_binary;
|
||||
ASSERT_EQ(SPV_SUCCESS, spvTextToBinaryWithOptions(
|
||||
ctx.CContext(), code.c_str(), code.size(), options,
|
||||
&spirv_binary, nullptr));
|
||||
|
||||
*words = std::vector<uint32_t>(spirv_binary->code,
|
||||
spirv_binary->code + spirv_binary->wordCount);
|
||||
|
||||
spvBinaryDestroy(spirv_binary);
|
||||
}
|
||||
|
||||
// Disassembles SPIR-V |words| to |out_text|.
|
||||
void Disassemble(const std::vector<uint32_t>& words, std::string* out_text,
|
||||
spv_target_env env = SPV_ENV_UNIVERSAL_1_2) {
|
||||
spvtools::Context ctx(env);
|
||||
ctx.SetMessageConsumer(DiagnosticsMessageHandler);
|
||||
|
||||
spv_text text = nullptr;
|
||||
ASSERT_EQ(SPV_SUCCESS, spvBinaryToText(ctx.CContext(), words.data(),
|
||||
words.size(), 0, &text, nullptr));
|
||||
assert(text);
|
||||
|
||||
*out_text = std::string(text->str, text->length);
|
||||
spvTextDestroy(text);
|
||||
}
|
||||
|
||||
// Encodes/decodes |original|, assembles/dissasembles |original|, then compares
|
||||
// the results of the two operations.
|
||||
void TestEncodeDecode(MarkvModelType model_type,
|
||||
const std::string& original_text) {
|
||||
spvtools::Context ctx(SPV_ENV_UNIVERSAL_1_2);
|
||||
std::unique_ptr<MarkvModel> model = CreateMarkvModel(model_type);
|
||||
MarkvCodecOptions options;
|
||||
|
||||
std::vector<uint32_t> expected_binary;
|
||||
Compile(original_text, &expected_binary);
|
||||
ASSERT_FALSE(expected_binary.empty());
|
||||
|
||||
std::string expected_text;
|
||||
Disassemble(expected_binary, &expected_text);
|
||||
ASSERT_FALSE(expected_text.empty());
|
||||
|
||||
std::vector<uint32_t> binary_to_encode;
|
||||
Compile(original_text, &binary_to_encode,
|
||||
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
|
||||
ASSERT_FALSE(binary_to_encode.empty());
|
||||
|
||||
std::stringstream encoder_comments;
|
||||
const auto output_to_string_stream =
|
||||
[&encoder_comments](const std::string& str) { encoder_comments << str; };
|
||||
|
||||
std::vector<uint8_t> markv;
|
||||
ASSERT_EQ(SPV_SUCCESS,
|
||||
SpirvToMarkv(ctx.CContext(), binary_to_encode, options, *model,
|
||||
DiagnosticsMessageHandler, output_to_string_stream,
|
||||
MarkvDebugConsumer(), &markv));
|
||||
ASSERT_FALSE(markv.empty());
|
||||
|
||||
std::vector<uint32_t> decoded_binary;
|
||||
ASSERT_EQ(SPV_SUCCESS,
|
||||
MarkvToSpirv(ctx.CContext(), markv, options, *model,
|
||||
DiagnosticsMessageHandler, MarkvLogConsumer(),
|
||||
MarkvDebugConsumer(), &decoded_binary));
|
||||
ASSERT_FALSE(decoded_binary.empty());
|
||||
|
||||
EXPECT_EQ(expected_binary, decoded_binary) << encoder_comments.str();
|
||||
|
||||
std::string decoded_text;
|
||||
Disassemble(decoded_binary, &decoded_text);
|
||||
ASSERT_FALSE(decoded_text.empty());
|
||||
|
||||
EXPECT_EQ(expected_text, decoded_text) << encoder_comments.str();
|
||||
}
|
||||
|
||||
void TestEncodeDecodeShaderMainBody(MarkvModelType model_type,
|
||||
const std::string& body) {
|
||||
const std::string prefix =
|
||||
R"(
|
||||
OpCapability Shader
|
||||
OpCapability Int64
|
||||
OpCapability Float64
|
||||
%ext_inst = OpExtInstImport "GLSL.std.450"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint Fragment %main "main"
|
||||
%void = OpTypeVoid
|
||||
%func = OpTypeFunction %void
|
||||
%bool = OpTypeBool
|
||||
%f32 = OpTypeFloat 32
|
||||
%u32 = OpTypeInt 32 0
|
||||
%s32 = OpTypeInt 32 1
|
||||
%f64 = OpTypeFloat 64
|
||||
%u64 = OpTypeInt 64 0
|
||||
%s64 = OpTypeInt 64 1
|
||||
%boolvec2 = OpTypeVector %bool 2
|
||||
%s32vec2 = OpTypeVector %s32 2
|
||||
%u32vec2 = OpTypeVector %u32 2
|
||||
%f32vec2 = OpTypeVector %f32 2
|
||||
%f64vec2 = OpTypeVector %f64 2
|
||||
%boolvec3 = OpTypeVector %bool 3
|
||||
%u32vec3 = OpTypeVector %u32 3
|
||||
%s32vec3 = OpTypeVector %s32 3
|
||||
%f32vec3 = OpTypeVector %f32 3
|
||||
%f64vec3 = OpTypeVector %f64 3
|
||||
%boolvec4 = OpTypeVector %bool 4
|
||||
%u32vec4 = OpTypeVector %u32 4
|
||||
%s32vec4 = OpTypeVector %s32 4
|
||||
%f32vec4 = OpTypeVector %f32 4
|
||||
%f64vec4 = OpTypeVector %f64 4
|
||||
|
||||
%f32_0 = OpConstant %f32 0
|
||||
%f32_1 = OpConstant %f32 1
|
||||
%f32_2 = OpConstant %f32 2
|
||||
%f32_3 = OpConstant %f32 3
|
||||
%f32_4 = OpConstant %f32 4
|
||||
%f32_pi = OpConstant %f32 3.14159
|
||||
|
||||
%s32_0 = OpConstant %s32 0
|
||||
%s32_1 = OpConstant %s32 1
|
||||
%s32_2 = OpConstant %s32 2
|
||||
%s32_3 = OpConstant %s32 3
|
||||
%s32_4 = OpConstant %s32 4
|
||||
%s32_m1 = OpConstant %s32 -1
|
||||
|
||||
%u32_0 = OpConstant %u32 0
|
||||
%u32_1 = OpConstant %u32 1
|
||||
%u32_2 = OpConstant %u32 2
|
||||
%u32_3 = OpConstant %u32 3
|
||||
%u32_4 = OpConstant %u32 4
|
||||
|
||||
%u32vec2_01 = OpConstantComposite %u32vec2 %u32_0 %u32_1
|
||||
%u32vec2_12 = OpConstantComposite %u32vec2 %u32_1 %u32_2
|
||||
%u32vec3_012 = OpConstantComposite %u32vec3 %u32_0 %u32_1 %u32_2
|
||||
%u32vec3_123 = OpConstantComposite %u32vec3 %u32_1 %u32_2 %u32_3
|
||||
%u32vec4_0123 = OpConstantComposite %u32vec4 %u32_0 %u32_1 %u32_2 %u32_3
|
||||
%u32vec4_1234 = OpConstantComposite %u32vec4 %u32_1 %u32_2 %u32_3 %u32_4
|
||||
|
||||
%s32vec2_01 = OpConstantComposite %s32vec2 %s32_0 %s32_1
|
||||
%s32vec2_12 = OpConstantComposite %s32vec2 %s32_1 %s32_2
|
||||
%s32vec3_012 = OpConstantComposite %s32vec3 %s32_0 %s32_1 %s32_2
|
||||
%s32vec3_123 = OpConstantComposite %s32vec3 %s32_1 %s32_2 %s32_3
|
||||
%s32vec4_0123 = OpConstantComposite %s32vec4 %s32_0 %s32_1 %s32_2 %s32_3
|
||||
%s32vec4_1234 = OpConstantComposite %s32vec4 %s32_1 %s32_2 %s32_3 %s32_4
|
||||
|
||||
%f32vec2_01 = OpConstantComposite %f32vec2 %f32_0 %f32_1
|
||||
%f32vec2_12 = OpConstantComposite %f32vec2 %f32_1 %f32_2
|
||||
%f32vec3_012 = OpConstantComposite %f32vec3 %f32_0 %f32_1 %f32_2
|
||||
%f32vec3_123 = OpConstantComposite %f32vec3 %f32_1 %f32_2 %f32_3
|
||||
%f32vec4_0123 = OpConstantComposite %f32vec4 %f32_0 %f32_1 %f32_2 %f32_3
|
||||
%f32vec4_1234 = OpConstantComposite %f32vec4 %f32_1 %f32_2 %f32_3 %f32_4
|
||||
|
||||
%main = OpFunction %void None %func
|
||||
%main_entry = OpLabel)";
|
||||
|
||||
const std::string suffix =
|
||||
R"(
|
||||
OpReturn
|
||||
OpFunctionEnd)";
|
||||
|
||||
TestEncodeDecode(model_type, prefix + body + suffix);
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, U32Literal) {
|
||||
TestEncodeDecode(GetParam(), R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpMemoryModel Logical GLSL450
|
||||
%u32 = OpTypeInt 32 0
|
||||
%100 = OpConstant %u32 0
|
||||
%200 = OpConstant %u32 1
|
||||
%300 = OpConstant %u32 4294967295
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, S32Literal) {
|
||||
TestEncodeDecode(GetParam(), R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpMemoryModel Logical GLSL450
|
||||
%s32 = OpTypeInt 32 1
|
||||
%100 = OpConstant %s32 0
|
||||
%200 = OpConstant %s32 1
|
||||
%300 = OpConstant %s32 -1
|
||||
%400 = OpConstant %s32 2147483647
|
||||
%500 = OpConstant %s32 -2147483648
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, U64Literal) {
|
||||
TestEncodeDecode(GetParam(), R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpCapability Int64
|
||||
OpMemoryModel Logical GLSL450
|
||||
%u64 = OpTypeInt 64 0
|
||||
%100 = OpConstant %u64 0
|
||||
%200 = OpConstant %u64 1
|
||||
%300 = OpConstant %u64 18446744073709551615
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, S64Literal) {
|
||||
TestEncodeDecode(GetParam(), R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpCapability Int64
|
||||
OpMemoryModel Logical GLSL450
|
||||
%s64 = OpTypeInt 64 1
|
||||
%100 = OpConstant %s64 0
|
||||
%200 = OpConstant %s64 1
|
||||
%300 = OpConstant %s64 -1
|
||||
%400 = OpConstant %s64 9223372036854775807
|
||||
%500 = OpConstant %s64 -9223372036854775808
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, U16Literal) {
|
||||
TestEncodeDecode(GetParam(), R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpCapability Int16
|
||||
OpMemoryModel Logical GLSL450
|
||||
%u16 = OpTypeInt 16 0
|
||||
%100 = OpConstant %u16 0
|
||||
%200 = OpConstant %u16 1
|
||||
%300 = OpConstant %u16 65535
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, S16Literal) {
|
||||
TestEncodeDecode(GetParam(), R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpCapability Int16
|
||||
OpMemoryModel Logical GLSL450
|
||||
%s16 = OpTypeInt 16 1
|
||||
%100 = OpConstant %s16 0
|
||||
%200 = OpConstant %s16 1
|
||||
%300 = OpConstant %s16 -1
|
||||
%400 = OpConstant %s16 32767
|
||||
%500 = OpConstant %s16 -32768
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, F32Literal) {
|
||||
TestEncodeDecode(GetParam(), R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpMemoryModel Logical GLSL450
|
||||
%f32 = OpTypeFloat 32
|
||||
%100 = OpConstant %f32 0
|
||||
%200 = OpConstant %f32 1
|
||||
%300 = OpConstant %f32 0.1
|
||||
%400 = OpConstant %f32 -0.1
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, F64Literal) {
|
||||
TestEncodeDecode(GetParam(), R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpCapability Float64
|
||||
OpMemoryModel Logical GLSL450
|
||||
%f64 = OpTypeFloat 64
|
||||
%100 = OpConstant %f64 0
|
||||
%200 = OpConstant %f64 1
|
||||
%300 = OpConstant %f64 0.1
|
||||
%400 = OpConstant %f64 -0.1
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, F16Literal) {
|
||||
TestEncodeDecode(GetParam(), R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpCapability Float16
|
||||
OpMemoryModel Logical GLSL450
|
||||
%f16 = OpTypeFloat 16
|
||||
%100 = OpConstant %f16 0
|
||||
%200 = OpConstant %f16 1
|
||||
%300 = OpConstant %f16 0.1
|
||||
%400 = OpConstant %f16 -0.1
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, StringLiteral) {
|
||||
TestEncodeDecode(GetParam(), R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpExtension "SPV_KHR_16bit_storage"
|
||||
OpExtension "xxx"
|
||||
OpExtension "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||
OpExtension ""
|
||||
OpMemoryModel Logical GLSL450
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, WithFunction) {
|
||||
TestEncodeDecode(GetParam(), R"(
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpExtension "SPV_KHR_16bit_storage"
|
||||
OpMemoryModel Physical32 OpenCL
|
||||
%f32 = OpTypeFloat 32
|
||||
%u32 = OpTypeInt 32 0
|
||||
%void = OpTypeVoid
|
||||
%void_func = OpTypeFunction %void
|
||||
%100 = OpConstant %u32 1
|
||||
%200 = OpConstant %u32 2
|
||||
%main = OpFunction %void None %void_func
|
||||
%entry_main = OpLabel
|
||||
%300 = OpIAdd %u32 %100 %200
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, WithMultipleFunctions) {
|
||||
TestEncodeDecode(GetParam(), R"(
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpMemoryModel Physical32 OpenCL
|
||||
%f32 = OpTypeFloat 32
|
||||
%one = OpConstant %f32 1
|
||||
%void = OpTypeVoid
|
||||
%void_func = OpTypeFunction %void
|
||||
%f32_func = OpTypeFunction %f32 %f32
|
||||
%sqr_plus_one = OpFunction %f32 None %f32_func
|
||||
%x = OpFunctionParameter %f32
|
||||
%100 = OpLabel
|
||||
%x2 = OpFMul %f32 %x %x
|
||||
%x2p1 = OpFunctionCall %f32 %plus_one %x2
|
||||
OpReturnValue %x2p1
|
||||
OpFunctionEnd
|
||||
%plus_one = OpFunction %f32 None %f32_func
|
||||
%y = OpFunctionParameter %f32
|
||||
%200 = OpLabel
|
||||
%yp1 = OpFAdd %f32 %y %one
|
||||
OpReturnValue %yp1
|
||||
OpFunctionEnd
|
||||
%main = OpFunction %void None %void_func
|
||||
%entry_main = OpLabel
|
||||
%1p1 = OpFunctionCall %f32 %sqr_plus_one %one
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, ForwardDeclaredId) {
|
||||
TestEncodeDecode(GetParam(), R"(
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpMemoryModel Physical32 OpenCL
|
||||
OpEntryPoint Kernel %1 "simple_kernel"
|
||||
%2 = OpTypeInt 32 0
|
||||
%3 = OpTypeVector %2 2
|
||||
%4 = OpConstant %2 2
|
||||
%5 = OpTypeArray %2 %4
|
||||
%6 = OpTypeVoid
|
||||
%7 = OpTypeFunction %6
|
||||
%1 = OpFunction %6 None %7
|
||||
%8 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, WithSwitch) {
|
||||
TestEncodeDecode(GetParam(), R"(
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Int64
|
||||
OpMemoryModel Physical32 OpenCL
|
||||
%u64 = OpTypeInt 64 0
|
||||
%void = OpTypeVoid
|
||||
%void_func = OpTypeFunction %void
|
||||
%val = OpConstant %u64 1
|
||||
%main = OpFunction %void None %void_func
|
||||
%entry_main = OpLabel
|
||||
OpSwitch %val %default 1 %case1 1000000000000 %case2
|
||||
%case1 = OpLabel
|
||||
OpNop
|
||||
OpBranch %after_switch
|
||||
%case2 = OpLabel
|
||||
OpNop
|
||||
OpBranch %after_switch
|
||||
%default = OpLabel
|
||||
OpNop
|
||||
OpBranch %after_switch
|
||||
%after_switch = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, WithLoop) {
|
||||
TestEncodeDecode(GetParam(), R"(
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpMemoryModel Physical32 OpenCL
|
||||
%void = OpTypeVoid
|
||||
%void_func = OpTypeFunction %void
|
||||
%main = OpFunction %void None %void_func
|
||||
%entry_main = OpLabel
|
||||
OpLoopMerge %merge %continue DontUnroll|DependencyLength 10
|
||||
OpBranch %begin_loop
|
||||
%begin_loop = OpLabel
|
||||
OpNop
|
||||
OpBranch %continue
|
||||
%continue = OpLabel
|
||||
OpNop
|
||||
OpBranch %begin_loop
|
||||
%merge = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, WithDecorate) {
|
||||
TestEncodeDecode(GetParam(), R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpDecorate %1 ArrayStride 4
|
||||
OpDecorate %1 Uniform
|
||||
%2 = OpTypeFloat 32
|
||||
%1 = OpTypeRuntimeArray %2
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, WithExtInst) {
|
||||
TestEncodeDecode(GetParam(), R"(
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
%opencl = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical32 OpenCL
|
||||
%f32 = OpTypeFloat 32
|
||||
%void = OpTypeVoid
|
||||
%void_func = OpTypeFunction %void
|
||||
%100 = OpConstant %f32 1.1
|
||||
%main = OpFunction %void None %void_func
|
||||
%entry_main = OpLabel
|
||||
%200 = OpExtInst %f32 %opencl cos %100
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, F32Mul) {
|
||||
TestEncodeDecodeShaderMainBody(GetParam(), R"(
|
||||
%val1 = OpFMul %f32 %f32_0 %f32_1
|
||||
%val2 = OpFMul %f32 %f32_2 %f32_0
|
||||
%val3 = OpFMul %f32 %f32_pi %f32_2
|
||||
%val4 = OpFMul %f32 %f32_1 %f32_1
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, U32Mul) {
|
||||
TestEncodeDecodeShaderMainBody(GetParam(), R"(
|
||||
%val1 = OpIMul %u32 %u32_0 %u32_1
|
||||
%val2 = OpIMul %u32 %u32_2 %u32_0
|
||||
%val3 = OpIMul %u32 %u32_3 %u32_2
|
||||
%val4 = OpIMul %u32 %u32_1 %u32_1
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, S32Mul) {
|
||||
TestEncodeDecodeShaderMainBody(GetParam(), R"(
|
||||
%val1 = OpIMul %s32 %s32_0 %s32_1
|
||||
%val2 = OpIMul %s32 %s32_2 %s32_0
|
||||
%val3 = OpIMul %s32 %s32_m1 %s32_2
|
||||
%val4 = OpIMul %s32 %s32_1 %s32_1
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, F32Add) {
|
||||
TestEncodeDecodeShaderMainBody(GetParam(), R"(
|
||||
%val1 = OpFAdd %f32 %f32_0 %f32_1
|
||||
%val2 = OpFAdd %f32 %f32_2 %f32_0
|
||||
%val3 = OpFAdd %f32 %f32_pi %f32_2
|
||||
%val4 = OpFAdd %f32 %f32_1 %f32_1
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, U32Add) {
|
||||
TestEncodeDecodeShaderMainBody(GetParam(), R"(
|
||||
%val1 = OpIAdd %u32 %u32_0 %u32_1
|
||||
%val2 = OpIAdd %u32 %u32_2 %u32_0
|
||||
%val3 = OpIAdd %u32 %u32_3 %u32_2
|
||||
%val4 = OpIAdd %u32 %u32_1 %u32_1
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, S32Add) {
|
||||
TestEncodeDecodeShaderMainBody(GetParam(), R"(
|
||||
%val1 = OpIAdd %s32 %s32_0 %s32_1
|
||||
%val2 = OpIAdd %s32 %s32_2 %s32_0
|
||||
%val3 = OpIAdd %s32 %s32_m1 %s32_2
|
||||
%val4 = OpIAdd %s32 %s32_1 %s32_1
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, F32Dot) {
|
||||
TestEncodeDecodeShaderMainBody(GetParam(), R"(
|
||||
%dot2_1 = OpDot %f32 %f32vec2_01 %f32vec2_12
|
||||
%dot2_2 = OpDot %f32 %f32vec2_01 %f32vec2_01
|
||||
%dot2_3 = OpDot %f32 %f32vec2_12 %f32vec2_12
|
||||
%dot3_1 = OpDot %f32 %f32vec3_012 %f32vec3_123
|
||||
%dot3_2 = OpDot %f32 %f32vec3_012 %f32vec3_012
|
||||
%dot3_3 = OpDot %f32 %f32vec3_123 %f32vec3_123
|
||||
%dot4_1 = OpDot %f32 %f32vec4_0123 %f32vec4_1234
|
||||
%dot4_2 = OpDot %f32 %f32vec4_0123 %f32vec4_0123
|
||||
%dot4_3 = OpDot %f32 %f32vec4_1234 %f32vec4_1234
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, F32VectorCompositeConstruct) {
|
||||
TestEncodeDecodeShaderMainBody(GetParam(), R"(
|
||||
%cc1 = OpCompositeConstruct %f32vec4 %f32vec2_01 %f32vec2_12
|
||||
%cc2 = OpCompositeConstruct %f32vec3 %f32vec2_01 %f32_2
|
||||
%cc3 = OpCompositeConstruct %f32vec2 %f32_1 %f32_2
|
||||
%cc4 = OpCompositeConstruct %f32vec4 %f32_1 %f32_2 %cc3
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, U32VectorCompositeConstruct) {
|
||||
TestEncodeDecodeShaderMainBody(GetParam(), R"(
|
||||
%cc1 = OpCompositeConstruct %u32vec4 %u32vec2_01 %u32vec2_12
|
||||
%cc2 = OpCompositeConstruct %u32vec3 %u32vec2_01 %u32_2
|
||||
%cc3 = OpCompositeConstruct %u32vec2 %u32_1 %u32_2
|
||||
%cc4 = OpCompositeConstruct %u32vec4 %u32_1 %u32_2 %cc3
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, S32VectorCompositeConstruct) {
|
||||
TestEncodeDecodeShaderMainBody(GetParam(), R"(
|
||||
%cc1 = OpCompositeConstruct %u32vec4 %u32vec2_01 %u32vec2_12
|
||||
%cc2 = OpCompositeConstruct %u32vec3 %u32vec2_01 %u32_2
|
||||
%cc3 = OpCompositeConstruct %u32vec2 %u32_1 %u32_2
|
||||
%cc4 = OpCompositeConstruct %u32vec4 %u32_1 %u32_2 %cc3
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, F32VectorCompositeExtract) {
|
||||
TestEncodeDecodeShaderMainBody(GetParam(), R"(
|
||||
%f32vec4_3210 = OpCompositeConstruct %f32vec4 %f32_3 %f32_2 %f32_1 %f32_0
|
||||
%f32vec3_013 = OpCompositeExtract %f32vec3 %f32vec4_0123 0 1 3
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, F32VectorComparison) {
|
||||
TestEncodeDecodeShaderMainBody(GetParam(), R"(
|
||||
%f32vec4_3210 = OpCompositeConstruct %f32vec4 %f32_3 %f32_2 %f32_1 %f32_0
|
||||
%c1 = OpFOrdEqual %boolvec4 %f32vec4_0123 %f32vec4_3210
|
||||
%c2 = OpFUnordEqual %boolvec4 %f32vec4_0123 %f32vec4_3210
|
||||
%c3 = OpFOrdNotEqual %boolvec4 %f32vec4_0123 %f32vec4_3210
|
||||
%c4 = OpFUnordNotEqual %boolvec4 %f32vec4_0123 %f32vec4_3210
|
||||
%c5 = OpFOrdLessThan %boolvec4 %f32vec4_0123 %f32vec4_3210
|
||||
%c6 = OpFUnordLessThan %boolvec4 %f32vec4_0123 %f32vec4_3210
|
||||
%c7 = OpFOrdGreaterThan %boolvec4 %f32vec4_0123 %f32vec4_3210
|
||||
%c8 = OpFUnordGreaterThan %boolvec4 %f32vec4_0123 %f32vec4_3210
|
||||
%c9 = OpFOrdLessThanEqual %boolvec4 %f32vec4_0123 %f32vec4_3210
|
||||
%c10 = OpFUnordLessThanEqual %boolvec4 %f32vec4_0123 %f32vec4_3210
|
||||
%c11 = OpFOrdGreaterThanEqual %boolvec4 %f32vec4_0123 %f32vec4_3210
|
||||
%c12 = OpFUnordGreaterThanEqual %boolvec4 %f32vec4_0123 %f32vec4_3210
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, VectorShuffle) {
|
||||
TestEncodeDecodeShaderMainBody(GetParam(), R"(
|
||||
%f32vec4_3210 = OpCompositeConstruct %f32vec4 %f32_3 %f32_2 %f32_1 %f32_0
|
||||
%sh1 = OpVectorShuffle %f32vec2 %f32vec4_0123 %f32vec4_3210 3 6
|
||||
%sh2 = OpVectorShuffle %f32vec3 %f32vec2_01 %f32vec4_3210 0 3 4
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, VectorTimesScalar) {
|
||||
TestEncodeDecodeShaderMainBody(GetParam(), R"(
|
||||
%f32vec4_3210 = OpCompositeConstruct %f32vec4 %f32_3 %f32_2 %f32_1 %f32_0
|
||||
%res1 = OpVectorTimesScalar %f32vec4 %f32vec4_0123 %f32_2
|
||||
%res2 = OpVectorTimesScalar %f32vec4 %f32vec4_3210 %f32_2
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, SpirvSpecSample) {
|
||||
TestEncodeDecode(GetParam(), R"(
|
||||
OpCapability Shader
|
||||
%1 = OpExtInstImport "GLSL.std.450"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint Fragment %4 "main" %31 %33 %42 %57
|
||||
OpExecutionMode %4 OriginLowerLeft
|
||||
|
||||
; Debug information
|
||||
OpSource GLSL 450
|
||||
OpName %4 "main"
|
||||
OpName %9 "scale"
|
||||
OpName %17 "S"
|
||||
OpMemberName %17 0 "b"
|
||||
OpMemberName %17 1 "v"
|
||||
OpMemberName %17 2 "i"
|
||||
OpName %18 "blockName"
|
||||
OpMemberName %18 0 "s"
|
||||
OpMemberName %18 1 "cond"
|
||||
OpName %20 ""
|
||||
OpName %31 "color"
|
||||
OpName %33 "color1"
|
||||
OpName %42 "color2"
|
||||
OpName %48 "i"
|
||||
OpName %57 "multiplier"
|
||||
|
||||
; Annotations (non-debug)
|
||||
OpDecorate %15 ArrayStride 16
|
||||
OpMemberDecorate %17 0 Offset 0
|
||||
OpMemberDecorate %17 1 Offset 16
|
||||
OpMemberDecorate %17 2 Offset 96
|
||||
OpMemberDecorate %18 0 Offset 0
|
||||
OpMemberDecorate %18 1 Offset 112
|
||||
OpDecorate %18 Block
|
||||
OpDecorate %20 DescriptorSet 0
|
||||
OpDecorate %42 NoPerspective
|
||||
|
||||
; All types, variables, and constants
|
||||
%2 = OpTypeVoid
|
||||
%3 = OpTypeFunction %2 ; void ()
|
||||
%6 = OpTypeFloat 32 ; 32-bit float
|
||||
%7 = OpTypeVector %6 4 ; vec4
|
||||
%8 = OpTypePointer Function %7 ; function-local vec4*
|
||||
%10 = OpConstant %6 1
|
||||
%11 = OpConstant %6 2
|
||||
%12 = OpConstantComposite %7 %10 %10 %11 %10 ; vec4(1.0, 1.0, 2.0, 1.0)
|
||||
%13 = OpTypeInt 32 0 ; 32-bit int, sign-less
|
||||
%14 = OpConstant %13 5
|
||||
%15 = OpTypeArray %7 %14
|
||||
%16 = OpTypeInt 32 1
|
||||
%17 = OpTypeStruct %13 %15 %16
|
||||
%18 = OpTypeStruct %17 %13
|
||||
%19 = OpTypePointer Uniform %18
|
||||
%20 = OpVariable %19 Uniform
|
||||
%21 = OpConstant %16 1
|
||||
%22 = OpTypePointer Uniform %13
|
||||
%25 = OpTypeBool
|
||||
%26 = OpConstant %13 0
|
||||
%30 = OpTypePointer Output %7
|
||||
%31 = OpVariable %30 Output
|
||||
%32 = OpTypePointer Input %7
|
||||
%33 = OpVariable %32 Input
|
||||
%35 = OpConstant %16 0
|
||||
%36 = OpConstant %16 2
|
||||
%37 = OpTypePointer Uniform %7
|
||||
%42 = OpVariable %32 Input
|
||||
%47 = OpTypePointer Function %16
|
||||
%55 = OpConstant %16 4
|
||||
%57 = OpVariable %32 Input
|
||||
|
||||
; All functions
|
||||
%4 = OpFunction %2 None %3 ; main()
|
||||
%5 = OpLabel
|
||||
%9 = OpVariable %8 Function
|
||||
%48 = OpVariable %47 Function
|
||||
OpStore %9 %12
|
||||
%23 = OpAccessChain %22 %20 %21 ; location of cond
|
||||
%24 = OpLoad %13 %23 ; load 32-bit int from cond
|
||||
%27 = OpINotEqual %25 %24 %26 ; convert to bool
|
||||
OpSelectionMerge %29 None ; structured if
|
||||
OpBranchConditional %27 %28 %41 ; if cond
|
||||
%28 = OpLabel ; then
|
||||
%34 = OpLoad %7 %33
|
||||
%38 = OpAccessChain %37 %20 %35 %21 %36 ; s.v[2]
|
||||
%39 = OpLoad %7 %38
|
||||
%40 = OpFAdd %7 %34 %39
|
||||
OpStore %31 %40
|
||||
OpBranch %29
|
||||
%41 = OpLabel ; else
|
||||
%43 = OpLoad %7 %42
|
||||
%44 = OpExtInst %7 %1 Sqrt %43 ; extended instruction sqrt
|
||||
%45 = OpLoad %7 %9
|
||||
%46 = OpFMul %7 %44 %45
|
||||
OpStore %31 %46
|
||||
OpBranch %29
|
||||
%29 = OpLabel ; endif
|
||||
OpStore %48 %35
|
||||
OpBranch %49
|
||||
%49 = OpLabel
|
||||
OpLoopMerge %51 %52 None ; structured loop
|
||||
OpBranch %53
|
||||
%53 = OpLabel
|
||||
%54 = OpLoad %16 %48
|
||||
%56 = OpSLessThan %25 %54 %55 ; i < 4 ?
|
||||
OpBranchConditional %56 %50 %51 ; body or break
|
||||
%50 = OpLabel ; body
|
||||
%58 = OpLoad %7 %57
|
||||
%59 = OpLoad %7 %31
|
||||
%60 = OpFMul %7 %59 %58
|
||||
OpStore %31 %60
|
||||
OpBranch %52
|
||||
%52 = OpLabel ; continue target
|
||||
%61 = OpLoad %16 %48
|
||||
%62 = OpIAdd %16 %61 %21 ; ++i
|
||||
OpStore %48 %62
|
||||
OpBranch %49 ; loop back
|
||||
%51 = OpLabel ; loop merge point
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MarkvTest, SampleFromDeadBranchEliminationTest) {
|
||||
TestEncodeDecode(GetParam(), R"(
|
||||
OpCapability Shader
|
||||
%1 = OpExtInstImport "GLSL.std.450"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint Fragment %main "main" %gl_FragColor
|
||||
OpExecutionMode %main OriginUpperLeft
|
||||
OpSource GLSL 140
|
||||
OpName %main "main"
|
||||
OpName %gl_FragColor "gl_FragColor"
|
||||
%void = OpTypeVoid
|
||||
%5 = OpTypeFunction %void
|
||||
%bool = OpTypeBool
|
||||
%true = OpConstantTrue %bool
|
||||
%float = OpTypeFloat 32
|
||||
%v4float = OpTypeVector %float 4
|
||||
%_ptr_Function_v4float = OpTypePointer Function %v4float
|
||||
%float_0 = OpConstant %float 0
|
||||
%12 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
|
||||
%float_1 = OpConstant %float 1
|
||||
%14 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1
|
||||
%_ptr_Output_v4float = OpTypePointer Output %v4float
|
||||
%gl_FragColor = OpVariable %_ptr_Output_v4float Output
|
||||
%_ptr_Input_v4float = OpTypePointer Input %v4float
|
||||
%main = OpFunction %void None %5
|
||||
%17 = OpLabel
|
||||
OpSelectionMerge %18 None
|
||||
OpBranchConditional %true %19 %20
|
||||
%19 = OpLabel
|
||||
OpBranch %18
|
||||
%20 = OpLabel
|
||||
OpBranch %18
|
||||
%18 = OpLabel
|
||||
%21 = OpPhi %v4float %12 %19 %14 %20
|
||||
OpStore %gl_FragColor %21
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)");
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(AllMarkvModels, MarkvTest,
|
||||
::testing::ValuesIn(std::vector<MarkvModelType>{
|
||||
kMarkvModelShaderLite,
|
||||
kMarkvModelShaderMid,
|
||||
kMarkvModelShaderMax,
|
||||
}));
|
||||
|
||||
} // namespace
|
||||
} // namespace comp
|
||||
} // namespace spvtools
|
@ -1,317 +0,0 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
#include "source/comp/bit_stream.h"
|
||||
#include "source/comp/huffman_codec.h"
|
||||
|
||||
namespace spvtools {
|
||||
namespace comp {
|
||||
namespace {
|
||||
|
||||
const std::map<std::string, uint32_t>& GetTestSet() {
|
||||
static const std::map<std::string, uint32_t> hist = {
|
||||
{"a", 4}, {"e", 7}, {"f", 3}, {"h", 2}, {"i", 3},
|
||||
{"m", 2}, {"n", 2}, {"s", 2}, {"t", 2}, {"l", 1},
|
||||
{"o", 2}, {"p", 1}, {"r", 1}, {"u", 1}, {"x", 1},
|
||||
};
|
||||
|
||||
return hist;
|
||||
}
|
||||
|
||||
class TestBitReader {
|
||||
public:
|
||||
TestBitReader(const std::string& bits) : bits_(bits) {}
|
||||
|
||||
bool ReadBit(bool* bit) {
|
||||
if (pos_ < bits_.length()) {
|
||||
*bit = bits_[pos_++] == '0' ? false : true;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string bits_;
|
||||
size_t pos_ = 0;
|
||||
};
|
||||
|
||||
TEST(Huffman, PrintTree) {
|
||||
HuffmanCodec<std::string> huffman(GetTestSet());
|
||||
std::stringstream ss;
|
||||
huffman.PrintTree(ss);
|
||||
|
||||
// clang-format off
|
||||
const std::string expected = std::string(R"(
|
||||
15-----7------e
|
||||
8------4------a
|
||||
4------2------m
|
||||
2------n
|
||||
19-----8------4------2------o
|
||||
2------s
|
||||
4------2------t
|
||||
2------1------l
|
||||
1------p
|
||||
11-----5------2------1------r
|
||||
1------u
|
||||
3------f
|
||||
6------3------i
|
||||
3------1------x
|
||||
2------h
|
||||
)").substr(1);
|
||||
// clang-format on
|
||||
|
||||
EXPECT_EQ(expected, ss.str());
|
||||
}
|
||||
|
||||
TEST(Huffman, PrintTable) {
|
||||
HuffmanCodec<std::string> huffman(GetTestSet());
|
||||
std::stringstream ss;
|
||||
huffman.PrintTable(ss);
|
||||
|
||||
const std::string expected = std::string(R"(
|
||||
e 7 11
|
||||
a 4 101
|
||||
i 3 0001
|
||||
f 3 0010
|
||||
t 2 0101
|
||||
s 2 0110
|
||||
o 2 0111
|
||||
n 2 1000
|
||||
m 2 1001
|
||||
h 2 00000
|
||||
x 1 00001
|
||||
u 1 00110
|
||||
r 1 00111
|
||||
p 1 01000
|
||||
l 1 01001
|
||||
)")
|
||||
.substr(1);
|
||||
|
||||
EXPECT_EQ(expected, ss.str());
|
||||
}
|
||||
|
||||
TEST(Huffman, TestValidity) {
|
||||
HuffmanCodec<std::string> huffman(GetTestSet());
|
||||
const auto& encoding_table = huffman.GetEncodingTable();
|
||||
std::vector<std::string> codes;
|
||||
for (const auto& entry : encoding_table) {
|
||||
codes.push_back(BitsToStream(entry.second.first, entry.second.second));
|
||||
}
|
||||
|
||||
std::sort(codes.begin(), codes.end());
|
||||
|
||||
ASSERT_LT(codes.size(), 20u) << "Inefficient test ahead";
|
||||
|
||||
for (size_t i = 0; i < codes.size(); ++i) {
|
||||
for (size_t j = i + 1; j < codes.size(); ++j) {
|
||||
ASSERT_FALSE(codes[i] == codes[j].substr(0, codes[i].length()))
|
||||
<< codes[i] << " is prefix of " << codes[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Huffman, TestEncode) {
|
||||
HuffmanCodec<std::string> huffman(GetTestSet());
|
||||
|
||||
uint64_t bits = 0;
|
||||
size_t num_bits = 0;
|
||||
|
||||
EXPECT_TRUE(huffman.Encode("e", &bits, &num_bits));
|
||||
EXPECT_EQ(2u, num_bits);
|
||||
EXPECT_EQ("11", BitsToStream(bits, num_bits));
|
||||
|
||||
EXPECT_TRUE(huffman.Encode("a", &bits, &num_bits));
|
||||
EXPECT_EQ(3u, num_bits);
|
||||
EXPECT_EQ("101", BitsToStream(bits, num_bits));
|
||||
|
||||
EXPECT_TRUE(huffman.Encode("x", &bits, &num_bits));
|
||||
EXPECT_EQ(5u, num_bits);
|
||||
EXPECT_EQ("00001", BitsToStream(bits, num_bits));
|
||||
|
||||
EXPECT_FALSE(huffman.Encode("y", &bits, &num_bits));
|
||||
}
|
||||
|
||||
TEST(Huffman, TestDecode) {
|
||||
HuffmanCodec<std::string> huffman(GetTestSet());
|
||||
TestBitReader bit_reader(
|
||||
"01001"
|
||||
"0001"
|
||||
"1000"
|
||||
"00110"
|
||||
"00001"
|
||||
"00");
|
||||
auto read_bit = [&bit_reader](bool* bit) { return bit_reader.ReadBit(bit); };
|
||||
|
||||
std::string decoded;
|
||||
|
||||
ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
|
||||
EXPECT_EQ("l", decoded);
|
||||
|
||||
ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
|
||||
EXPECT_EQ("i", decoded);
|
||||
|
||||
ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
|
||||
EXPECT_EQ("n", decoded);
|
||||
|
||||
ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
|
||||
EXPECT_EQ("u", decoded);
|
||||
|
||||
ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
|
||||
EXPECT_EQ("x", decoded);
|
||||
|
||||
ASSERT_FALSE(huffman.DecodeFromStream(read_bit, &decoded));
|
||||
}
|
||||
|
||||
TEST(Huffman, TestDecodeNumbers) {
|
||||
const std::map<uint32_t, uint32_t> hist = {{1, 10}, {2, 5}, {3, 15}};
|
||||
HuffmanCodec<uint32_t> huffman(hist);
|
||||
|
||||
TestBitReader bit_reader(
|
||||
"1"
|
||||
"1"
|
||||
"01"
|
||||
"00"
|
||||
"01"
|
||||
"1");
|
||||
auto read_bit = [&bit_reader](bool* bit) { return bit_reader.ReadBit(bit); };
|
||||
|
||||
uint32_t decoded;
|
||||
|
||||
ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
|
||||
EXPECT_EQ(3u, decoded);
|
||||
|
||||
ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
|
||||
EXPECT_EQ(3u, decoded);
|
||||
|
||||
ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
|
||||
EXPECT_EQ(2u, decoded);
|
||||
|
||||
ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
|
||||
EXPECT_EQ(1u, decoded);
|
||||
|
||||
ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
|
||||
EXPECT_EQ(2u, decoded);
|
||||
|
||||
ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
|
||||
EXPECT_EQ(3u, decoded);
|
||||
}
|
||||
|
||||
TEST(Huffman, SerializeToTextU64) {
|
||||
const std::map<uint64_t, uint32_t> hist = {{1001, 10}, {1002, 5}, {1003, 15}};
|
||||
HuffmanCodec<uint64_t> huffman(hist);
|
||||
|
||||
const std::string code = huffman.SerializeToText(2);
|
||||
|
||||
const std::string expected = R"((5, {
|
||||
{0, 0, 0},
|
||||
{1001, 0, 0},
|
||||
{1002, 0, 0},
|
||||
{1003, 0, 0},
|
||||
{0, 1, 2},
|
||||
{0, 4, 3},
|
||||
}))";
|
||||
|
||||
ASSERT_EQ(expected, code);
|
||||
}
|
||||
|
||||
TEST(Huffman, SerializeToTextString) {
|
||||
const std::map<std::string, uint32_t> hist = {
|
||||
{"aaa", 10}, {"bbb", 20}, {"ccc", 15}};
|
||||
HuffmanCodec<std::string> huffman(hist);
|
||||
|
||||
const std::string code = huffman.SerializeToText(4);
|
||||
|
||||
const std::string expected = R"((5, {
|
||||
{"", 0, 0},
|
||||
{"aaa", 0, 0},
|
||||
{"bbb", 0, 0},
|
||||
{"ccc", 0, 0},
|
||||
{"", 3, 1},
|
||||
{"", 4, 2},
|
||||
}))";
|
||||
|
||||
ASSERT_EQ(expected, code);
|
||||
}
|
||||
|
||||
TEST(Huffman, CreateFromTextString) {
|
||||
std::vector<HuffmanCodec<std::string>::Node> nodes = {
|
||||
{},
|
||||
{"root", 2, 3},
|
||||
{"left", 0, 0},
|
||||
{"right", 0, 0},
|
||||
};
|
||||
|
||||
HuffmanCodec<std::string> huffman(1, std::move(nodes));
|
||||
|
||||
std::stringstream ss;
|
||||
huffman.PrintTree(ss);
|
||||
|
||||
const std::string expected = std::string(R"(
|
||||
0------right
|
||||
0------left
|
||||
)")
|
||||
.substr(1);
|
||||
|
||||
EXPECT_EQ(expected, ss.str());
|
||||
}
|
||||
|
||||
TEST(Huffman, CreateFromTextU64) {
|
||||
HuffmanCodec<uint64_t> huffman(5, {
|
||||
{0, 0, 0},
|
||||
{1001, 0, 0},
|
||||
{1002, 0, 0},
|
||||
{1003, 0, 0},
|
||||
{0, 1, 2},
|
||||
{0, 4, 3},
|
||||
});
|
||||
|
||||
std::stringstream ss;
|
||||
huffman.PrintTree(ss);
|
||||
|
||||
const std::string expected = std::string(R"(
|
||||
0------1003
|
||||
0------0------1002
|
||||
0------1001
|
||||
)")
|
||||
.substr(1);
|
||||
|
||||
EXPECT_EQ(expected, ss.str());
|
||||
|
||||
TestBitReader bit_reader("01");
|
||||
auto read_bit = [&bit_reader](bool* bit) { return bit_reader.ReadBit(bit); };
|
||||
|
||||
uint64_t decoded = 0;
|
||||
ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
|
||||
EXPECT_EQ(1002u, decoded);
|
||||
|
||||
uint64_t bits = 0;
|
||||
size_t num_bits = 0;
|
||||
|
||||
EXPECT_TRUE(huffman.Encode(1001, &bits, &num_bits));
|
||||
EXPECT_EQ(2u, num_bits);
|
||||
EXPECT_EQ("00", BitsToStream(bits, num_bits));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace comp
|
||||
} // namespace spvtools
|
@ -1,828 +0,0 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
#include "source/comp/move_to_front.h"
|
||||
|
||||
namespace spvtools {
|
||||
namespace comp {
|
||||
namespace {
|
||||
|
||||
// Class used to test the inner workings of MoveToFront.
|
||||
class MoveToFrontTester : public MoveToFront {
|
||||
public:
|
||||
// Inserts the value in the internal tree data structure. For testing only.
|
||||
void TestInsert(uint32_t val) { InsertNode(CreateNode(val, val)); }
|
||||
|
||||
// Removes the value from the internal tree data structure. For testing only.
|
||||
void TestRemove(uint32_t val) {
|
||||
const auto it = value_to_node_.find(val);
|
||||
assert(it != value_to_node_.end());
|
||||
RemoveNode(it->second);
|
||||
}
|
||||
|
||||
// Prints the internal tree data structure to |out|. For testing only.
|
||||
void PrintTree(std::ostream& out, bool print_timestamp = false) const {
|
||||
if (root_) PrintTreeInternal(out, root_, 1, print_timestamp);
|
||||
}
|
||||
|
||||
// Returns node handle corresponding to the value. The value may not be in the
|
||||
// tree.
|
||||
uint32_t GetNodeHandle(uint32_t value) const {
|
||||
const auto it = value_to_node_.find(value);
|
||||
if (it == value_to_node_.end()) return 0;
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Returns total node count (both those in the tree and removed,
|
||||
// but not the NIL singleton).
|
||||
size_t GetTotalNodeCount() const {
|
||||
assert(nodes_.size());
|
||||
return nodes_.size() - 1;
|
||||
}
|
||||
|
||||
uint32_t GetLastAccessedValue() const { return last_accessed_value_; }
|
||||
|
||||
private:
|
||||
// Prints the internal tree data structure for debug purposes in the following
|
||||
// format:
|
||||
// 10H3S4----5H1S1-----D2
|
||||
// 15H2S2----12H1S1----D3
|
||||
// Right links are horizontal, left links step down one line.
|
||||
// 5H1S1 is read as value 5, height 1, size 1. Optionally node label can also
|
||||
// contain timestamp (5H1S1T15). D3 stands for depth 3.
|
||||
void PrintTreeInternal(std::ostream& out, uint32_t node, size_t depth,
|
||||
bool print_timestamp) const;
|
||||
};
|
||||
|
||||
void MoveToFrontTester::PrintTreeInternal(std::ostream& out, uint32_t node,
|
||||
size_t depth,
|
||||
bool print_timestamp) const {
|
||||
if (!node) {
|
||||
out << "D" << depth - 1 << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t kTextFieldWvaluethWithoutTimestamp = 10;
|
||||
const size_t kTextFieldWvaluethWithTimestamp = 14;
|
||||
const size_t text_field_wvalueth = print_timestamp
|
||||
? kTextFieldWvaluethWithTimestamp
|
||||
: kTextFieldWvaluethWithoutTimestamp;
|
||||
|
||||
std::stringstream label;
|
||||
label << ValueOf(node) << "H" << HeightOf(node) << "S" << SizeOf(node);
|
||||
if (print_timestamp) label << "T" << TimestampOf(node);
|
||||
const size_t label_length = label.str().length();
|
||||
if (label_length < text_field_wvalueth)
|
||||
label << std::string(text_field_wvalueth - label_length, '-');
|
||||
|
||||
out << label.str();
|
||||
|
||||
PrintTreeInternal(out, RightOf(node), depth + 1, print_timestamp);
|
||||
|
||||
if (LeftOf(node)) {
|
||||
out << std::string(depth * text_field_wvalueth, ' ');
|
||||
PrintTreeInternal(out, LeftOf(node), depth + 1, print_timestamp);
|
||||
}
|
||||
}
|
||||
|
||||
void CheckTree(const MoveToFrontTester& mtf, const std::string& expected,
|
||||
bool print_timestamp = false) {
|
||||
std::stringstream ss;
|
||||
mtf.PrintTree(ss, print_timestamp);
|
||||
EXPECT_EQ(expected, ss.str());
|
||||
}
|
||||
|
||||
TEST(MoveToFront, EmptyTree) {
|
||||
MoveToFrontTester mtf;
|
||||
CheckTree(mtf, std::string());
|
||||
}
|
||||
|
||||
TEST(MoveToFront, InsertLeftRotation) {
|
||||
MoveToFrontTester mtf;
|
||||
|
||||
mtf.TestInsert(30);
|
||||
mtf.TestInsert(20);
|
||||
|
||||
CheckTree(mtf, std::string(R"(
|
||||
30H2S2----20H1S1----D2
|
||||
)")
|
||||
.substr(1));
|
||||
|
||||
mtf.TestInsert(10);
|
||||
CheckTree(mtf, std::string(R"(
|
||||
20H2S3----10H1S1----D2
|
||||
30H1S1----D2
|
||||
)")
|
||||
.substr(1));
|
||||
}
|
||||
|
||||
TEST(MoveToFront, InsertRightRotation) {
|
||||
MoveToFrontTester mtf;
|
||||
|
||||
mtf.TestInsert(10);
|
||||
mtf.TestInsert(20);
|
||||
|
||||
CheckTree(mtf, std::string(R"(
|
||||
10H2S2----D1
|
||||
20H1S1----D2
|
||||
)")
|
||||
.substr(1));
|
||||
|
||||
mtf.TestInsert(30);
|
||||
CheckTree(mtf, std::string(R"(
|
||||
20H2S3----10H1S1----D2
|
||||
30H1S1----D2
|
||||
)")
|
||||
.substr(1));
|
||||
}
|
||||
|
||||
TEST(MoveToFront, InsertRightLeftRotation) {
|
||||
MoveToFrontTester mtf;
|
||||
|
||||
mtf.TestInsert(30);
|
||||
mtf.TestInsert(20);
|
||||
|
||||
CheckTree(mtf, std::string(R"(
|
||||
30H2S2----20H1S1----D2
|
||||
)")
|
||||
.substr(1));
|
||||
|
||||
mtf.TestInsert(25);
|
||||
CheckTree(mtf, std::string(R"(
|
||||
25H2S3----20H1S1----D2
|
||||
30H1S1----D2
|
||||
)")
|
||||
.substr(1));
|
||||
}
|
||||
|
||||
TEST(MoveToFront, InsertLeftRightRotation) {
|
||||
MoveToFrontTester mtf;
|
||||
|
||||
mtf.TestInsert(10);
|
||||
mtf.TestInsert(20);
|
||||
|
||||
CheckTree(mtf, std::string(R"(
|
||||
10H2S2----D1
|
||||
20H1S1----D2
|
||||
)")
|
||||
.substr(1));
|
||||
|
||||
mtf.TestInsert(15);
|
||||
CheckTree(mtf, std::string(R"(
|
||||
15H2S3----10H1S1----D2
|
||||
20H1S1----D2
|
||||
)")
|
||||
.substr(1));
|
||||
}
|
||||
|
||||
TEST(MoveToFront, RemoveSingleton) {
|
||||
MoveToFrontTester mtf;
|
||||
|
||||
mtf.TestInsert(10);
|
||||
CheckTree(mtf, std::string(R"(
|
||||
10H1S1----D1
|
||||
)")
|
||||
.substr(1));
|
||||
|
||||
mtf.TestRemove(10);
|
||||
CheckTree(mtf, "");
|
||||
}
|
||||
|
||||
TEST(MoveToFront, RemoveRootWithScapegoat) {
|
||||
MoveToFrontTester mtf;
|
||||
|
||||
mtf.TestInsert(10);
|
||||
mtf.TestInsert(5);
|
||||
mtf.TestInsert(15);
|
||||
CheckTree(mtf, std::string(R"(
|
||||
10H2S3----5H1S1-----D2
|
||||
15H1S1----D2
|
||||
)")
|
||||
.substr(1));
|
||||
|
||||
mtf.TestRemove(10);
|
||||
CheckTree(mtf, std::string(R"(
|
||||
15H2S2----5H1S1-----D2
|
||||
)")
|
||||
.substr(1));
|
||||
}
|
||||
|
||||
TEST(MoveToFront, RemoveRightRotation) {
|
||||
MoveToFrontTester mtf;
|
||||
|
||||
mtf.TestInsert(10);
|
||||
mtf.TestInsert(5);
|
||||
mtf.TestInsert(15);
|
||||
mtf.TestInsert(20);
|
||||
CheckTree(mtf, std::string(R"(
|
||||
10H3S4----5H1S1-----D2
|
||||
15H2S2----D2
|
||||
20H1S1----D3
|
||||
)")
|
||||
.substr(1));
|
||||
|
||||
mtf.TestRemove(5);
|
||||
|
||||
CheckTree(mtf, std::string(R"(
|
||||
15H2S3----10H1S1----D2
|
||||
20H1S1----D2
|
||||
)")
|
||||
.substr(1));
|
||||
}
|
||||
|
||||
TEST(MoveToFront, RemoveLeftRotation) {
|
||||
MoveToFrontTester mtf;
|
||||
|
||||
mtf.TestInsert(10);
|
||||
mtf.TestInsert(15);
|
||||
mtf.TestInsert(5);
|
||||
mtf.TestInsert(1);
|
||||
CheckTree(mtf, std::string(R"(
|
||||
10H3S4----5H2S2-----1H1S1-----D3
|
||||
15H1S1----D2
|
||||
)")
|
||||
.substr(1));
|
||||
|
||||
mtf.TestRemove(15);
|
||||
|
||||
CheckTree(mtf, std::string(R"(
|
||||
5H2S3-----1H1S1-----D2
|
||||
10H1S1----D2
|
||||
)")
|
||||
.substr(1));
|
||||
}
|
||||
|
||||
TEST(MoveToFront, RemoveLeftRightRotation) {
|
||||
MoveToFrontTester mtf;
|
||||
|
||||
mtf.TestInsert(10);
|
||||
mtf.TestInsert(15);
|
||||
mtf.TestInsert(5);
|
||||
mtf.TestInsert(12);
|
||||
CheckTree(mtf, std::string(R"(
|
||||
10H3S4----5H1S1-----D2
|
||||
15H2S2----12H1S1----D3
|
||||
)")
|
||||
.substr(1));
|
||||
|
||||
mtf.TestRemove(5);
|
||||
|
||||
CheckTree(mtf, std::string(R"(
|
||||
12H2S3----10H1S1----D2
|
||||
15H1S1----D2
|
||||
)")
|
||||
.substr(1));
|
||||
}
|
||||
|
||||
TEST(MoveToFront, RemoveRightLeftRotation) {
|
||||
MoveToFrontTester mtf;
|
||||
|
||||
mtf.TestInsert(10);
|
||||
mtf.TestInsert(15);
|
||||
mtf.TestInsert(5);
|
||||
mtf.TestInsert(8);
|
||||
CheckTree(mtf, std::string(R"(
|
||||
10H3S4----5H2S2-----D2
|
||||
8H1S1-----D3
|
||||
15H1S1----D2
|
||||
)")
|
||||
.substr(1));
|
||||
|
||||
mtf.TestRemove(15);
|
||||
|
||||
CheckTree(mtf, std::string(R"(
|
||||
8H2S3-----5H1S1-----D2
|
||||
10H1S1----D2
|
||||
)")
|
||||
.substr(1));
|
||||
}
|
||||
|
||||
TEST(MoveToFront, MultipleOperations) {
|
||||
MoveToFrontTester mtf;
|
||||
std::vector<uint32_t> vals = {5, 11, 12, 16, 15, 6, 14, 2,
|
||||
7, 10, 4, 8, 9, 3, 1, 13};
|
||||
|
||||
for (uint32_t i : vals) {
|
||||
mtf.TestInsert(i);
|
||||
}
|
||||
|
||||
CheckTree(mtf, std::string(R"(
|
||||
11H5S16---5H4S10----3H3S4-----2H2S2-----1H1S1-----D5
|
||||
4H1S1-----D4
|
||||
7H3S5-----6H1S1-----D4
|
||||
9H2S3-----8H1S1-----D5
|
||||
10H1S1----D5
|
||||
15H3S5----13H2S3----12H1S1----D4
|
||||
14H1S1----D4
|
||||
16H1S1----D3
|
||||
)")
|
||||
.substr(1));
|
||||
|
||||
mtf.TestRemove(11);
|
||||
|
||||
CheckTree(mtf, std::string(R"(
|
||||
10H5S15---5H4S9-----3H3S4-----2H2S2-----1H1S1-----D5
|
||||
4H1S1-----D4
|
||||
7H3S4-----6H1S1-----D4
|
||||
9H2S2-----8H1S1-----D5
|
||||
15H3S5----13H2S3----12H1S1----D4
|
||||
14H1S1----D4
|
||||
16H1S1----D3
|
||||
)")
|
||||
.substr(1));
|
||||
|
||||
mtf.TestInsert(11);
|
||||
|
||||
CheckTree(mtf, std::string(R"(
|
||||
10H5S16---5H4S9-----3H3S4-----2H2S2-----1H1S1-----D5
|
||||
4H1S1-----D4
|
||||
7H3S4-----6H1S1-----D4
|
||||
9H2S2-----8H1S1-----D5
|
||||
13H3S6----12H2S2----11H1S1----D4
|
||||
15H2S3----14H1S1----D4
|
||||
16H1S1----D4
|
||||
)")
|
||||
.substr(1));
|
||||
|
||||
mtf.TestRemove(5);
|
||||
|
||||
CheckTree(mtf, std::string(R"(
|
||||
10H5S15---6H4S8-----3H3S4-----2H2S2-----1H1S1-----D5
|
||||
4H1S1-----D4
|
||||
8H2S3-----7H1S1-----D4
|
||||
9H1S1-----D4
|
||||
13H3S6----12H2S2----11H1S1----D4
|
||||
15H2S3----14H1S1----D4
|
||||
16H1S1----D4
|
||||
)")
|
||||
.substr(1));
|
||||
|
||||
mtf.TestInsert(5);
|
||||
|
||||
CheckTree(mtf, std::string(R"(
|
||||
10H5S16---6H4S9-----3H3S5-----2H2S2-----1H1S1-----D5
|
||||
4H2S2-----D4
|
||||
5H1S1-----D5
|
||||
8H2S3-----7H1S1-----D4
|
||||
9H1S1-----D4
|
||||
13H3S6----12H2S2----11H1S1----D4
|
||||
15H2S3----14H1S1----D4
|
||||
16H1S1----D4
|
||||
)")
|
||||
.substr(1));
|
||||
|
||||
mtf.TestRemove(2);
|
||||
mtf.TestRemove(1);
|
||||
mtf.TestRemove(4);
|
||||
mtf.TestRemove(3);
|
||||
mtf.TestRemove(6);
|
||||
mtf.TestRemove(5);
|
||||
mtf.TestRemove(7);
|
||||
mtf.TestRemove(9);
|
||||
|
||||
CheckTree(mtf, std::string(R"(
|
||||
13H4S8----10H3S4----8H1S1-----D3
|
||||
12H2S2----11H1S1----D4
|
||||
15H2S3----14H1S1----D3
|
||||
16H1S1----D3
|
||||
)")
|
||||
.substr(1));
|
||||
}
|
||||
|
||||
TEST(MoveToFront, BiggerScaleTreeTest) {
|
||||
MoveToFrontTester mtf;
|
||||
std::set<uint32_t> all_vals;
|
||||
|
||||
const uint32_t kMagic1 = 2654435761;
|
||||
const uint32_t kMagic2 = 10000;
|
||||
|
||||
for (uint32_t i = 1; i < 1000; ++i) {
|
||||
const uint32_t val = (i * kMagic1) % kMagic2;
|
||||
if (!all_vals.count(val)) {
|
||||
mtf.TestInsert(val);
|
||||
all_vals.insert(val);
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t i = 1; i < 1000; ++i) {
|
||||
const uint32_t val = (i * kMagic1) % kMagic2;
|
||||
if (val % 2 == 0) {
|
||||
mtf.TestRemove(val);
|
||||
all_vals.erase(val);
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t i = 1000; i < 2000; ++i) {
|
||||
const uint32_t val = (i * kMagic1) % kMagic2;
|
||||
if (!all_vals.count(val)) {
|
||||
mtf.TestInsert(val);
|
||||
all_vals.insert(val);
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t i = 1; i < 2000; ++i) {
|
||||
const uint32_t val = (i * kMagic1) % kMagic2;
|
||||
if (val > 50) {
|
||||
mtf.TestRemove(val);
|
||||
all_vals.erase(val);
|
||||
}
|
||||
}
|
||||
|
||||
EXPECT_EQ(all_vals, std::set<uint32_t>({2, 4, 11, 13, 24, 33, 35, 37, 46}));
|
||||
|
||||
CheckTree(mtf, std::string(R"(
|
||||
33H4S9----11H3S5----2H2S2-----D3
|
||||
4H1S1-----D4
|
||||
13H2S2----D3
|
||||
24H1S1----D4
|
||||
37H2S3----35H1S1----D3
|
||||
46H1S1----D3
|
||||
)")
|
||||
.substr(1));
|
||||
}
|
||||
|
||||
TEST(MoveToFront, RankFromValue) {
|
||||
MoveToFrontTester mtf;
|
||||
|
||||
uint32_t rank = 0;
|
||||
EXPECT_FALSE(mtf.RankFromValue(1, &rank));
|
||||
|
||||
EXPECT_TRUE(mtf.Insert(1));
|
||||
EXPECT_TRUE(mtf.Insert(2));
|
||||
EXPECT_TRUE(mtf.Insert(3));
|
||||
EXPECT_FALSE(mtf.Insert(2));
|
||||
CheckTree(mtf,
|
||||
std::string(R"(
|
||||
2H2S3T2-------1H1S1T1-------D2
|
||||
3H1S1T3-------D2
|
||||
)")
|
||||
.substr(1),
|
||||
/* print_timestamp = */ true);
|
||||
|
||||
EXPECT_FALSE(mtf.RankFromValue(4, &rank));
|
||||
|
||||
EXPECT_TRUE(mtf.RankFromValue(1, &rank));
|
||||
EXPECT_EQ(3u, rank);
|
||||
|
||||
CheckTree(mtf,
|
||||
std::string(R"(
|
||||
3H2S3T3-------2H1S1T2-------D2
|
||||
1H1S1T4-------D2
|
||||
)")
|
||||
.substr(1),
|
||||
/* print_timestamp = */ true);
|
||||
|
||||
EXPECT_TRUE(mtf.RankFromValue(1, &rank));
|
||||
EXPECT_EQ(1u, rank);
|
||||
|
||||
EXPECT_TRUE(mtf.RankFromValue(3, &rank));
|
||||
EXPECT_EQ(2u, rank);
|
||||
|
||||
EXPECT_TRUE(mtf.RankFromValue(2, &rank));
|
||||
EXPECT_EQ(3u, rank);
|
||||
|
||||
EXPECT_TRUE(mtf.Insert(40));
|
||||
|
||||
EXPECT_TRUE(mtf.RankFromValue(1, &rank));
|
||||
EXPECT_EQ(4u, rank);
|
||||
|
||||
EXPECT_TRUE(mtf.Insert(50));
|
||||
|
||||
EXPECT_TRUE(mtf.RankFromValue(1, &rank));
|
||||
EXPECT_EQ(2u, rank);
|
||||
|
||||
CheckTree(mtf,
|
||||
std::string(R"(
|
||||
2H3S5T6-------3H1S1T5-------D2
|
||||
50H2S3T9------40H1S1T7------D3
|
||||
1H1S1T10------D3
|
||||
)")
|
||||
.substr(1),
|
||||
/* print_timestamp = */ true);
|
||||
|
||||
EXPECT_TRUE(mtf.RankFromValue(50, &rank));
|
||||
EXPECT_EQ(2u, rank);
|
||||
|
||||
EXPECT_EQ(5u, mtf.GetSize());
|
||||
CheckTree(mtf,
|
||||
std::string(R"(
|
||||
2H3S5T6-------3H1S1T5-------D2
|
||||
1H2S3T10------40H1S1T7------D3
|
||||
50H1S1T11-----D3
|
||||
)")
|
||||
.substr(1),
|
||||
/* print_timestamp = */ true);
|
||||
|
||||
EXPECT_FALSE(mtf.RankFromValue(0, &rank));
|
||||
EXPECT_FALSE(mtf.RankFromValue(20, &rank));
|
||||
}
|
||||
|
||||
TEST(MoveToFront, ValueFromRank) {
|
||||
MoveToFrontTester mtf;
|
||||
|
||||
uint32_t value = 0;
|
||||
EXPECT_FALSE(mtf.ValueFromRank(0, &value));
|
||||
EXPECT_FALSE(mtf.ValueFromRank(1, &value));
|
||||
|
||||
EXPECT_TRUE(mtf.Insert(1));
|
||||
EXPECT_EQ(1u, mtf.GetLastAccessedValue());
|
||||
EXPECT_TRUE(mtf.Insert(2));
|
||||
EXPECT_EQ(2u, mtf.GetLastAccessedValue());
|
||||
EXPECT_TRUE(mtf.Insert(3));
|
||||
EXPECT_EQ(3u, mtf.GetLastAccessedValue());
|
||||
|
||||
EXPECT_TRUE(mtf.ValueFromRank(3, &value));
|
||||
EXPECT_EQ(1u, value);
|
||||
EXPECT_EQ(1u, mtf.GetLastAccessedValue());
|
||||
|
||||
EXPECT_TRUE(mtf.ValueFromRank(1, &value));
|
||||
EXPECT_EQ(1u, value);
|
||||
EXPECT_EQ(1u, mtf.GetLastAccessedValue());
|
||||
|
||||
CheckTree(mtf,
|
||||
std::string(R"(
|
||||
3H2S3T3-------2H1S1T2-------D2
|
||||
1H1S1T4-------D2
|
||||
)")
|
||||
.substr(1),
|
||||
/* print_timestamp = */ true);
|
||||
|
||||
EXPECT_TRUE(mtf.ValueFromRank(2, &value));
|
||||
EXPECT_EQ(3u, value);
|
||||
|
||||
EXPECT_EQ(3u, mtf.GetSize());
|
||||
|
||||
CheckTree(mtf,
|
||||
std::string(R"(
|
||||
1H2S3T4-------2H1S1T2-------D2
|
||||
3H1S1T5-------D2
|
||||
)")
|
||||
.substr(1),
|
||||
/* print_timestamp = */ true);
|
||||
|
||||
EXPECT_TRUE(mtf.ValueFromRank(3, &value));
|
||||
EXPECT_EQ(2u, value);
|
||||
|
||||
CheckTree(mtf,
|
||||
std::string(R"(
|
||||
3H2S3T5-------1H1S1T4-------D2
|
||||
2H1S1T6-------D2
|
||||
)")
|
||||
.substr(1),
|
||||
/* print_timestamp = */ true);
|
||||
|
||||
EXPECT_TRUE(mtf.Insert(10));
|
||||
CheckTree(mtf,
|
||||
std::string(R"(
|
||||
3H3S4T5-------1H1S1T4-------D2
|
||||
2H2S2T6-------D2
|
||||
10H1S1T7------D3
|
||||
)")
|
||||
.substr(1),
|
||||
/* print_timestamp = */ true);
|
||||
|
||||
EXPECT_TRUE(mtf.ValueFromRank(1, &value));
|
||||
EXPECT_EQ(10u, value);
|
||||
}
|
||||
|
||||
TEST(MoveToFront, Remove) {
|
||||
MoveToFrontTester mtf;
|
||||
|
||||
EXPECT_FALSE(mtf.Remove(1));
|
||||
EXPECT_EQ(0u, mtf.GetTotalNodeCount());
|
||||
|
||||
EXPECT_TRUE(mtf.Insert(1));
|
||||
EXPECT_TRUE(mtf.Insert(2));
|
||||
EXPECT_TRUE(mtf.Insert(3));
|
||||
|
||||
CheckTree(mtf,
|
||||
std::string(R"(
|
||||
2H2S3T2-------1H1S1T1-------D2
|
||||
3H1S1T3-------D2
|
||||
)")
|
||||
.substr(1),
|
||||
/* print_timestamp = */ true);
|
||||
|
||||
EXPECT_EQ(1u, mtf.GetNodeHandle(1));
|
||||
EXPECT_EQ(3u, mtf.GetTotalNodeCount());
|
||||
EXPECT_TRUE(mtf.Remove(1));
|
||||
EXPECT_EQ(3u, mtf.GetTotalNodeCount());
|
||||
|
||||
CheckTree(mtf,
|
||||
std::string(R"(
|
||||
2H2S2T2-------D1
|
||||
3H1S1T3-------D2
|
||||
)")
|
||||
.substr(1),
|
||||
/* print_timestamp = */ true);
|
||||
|
||||
uint32_t value = 0;
|
||||
EXPECT_TRUE(mtf.ValueFromRank(2, &value));
|
||||
EXPECT_EQ(2u, value);
|
||||
|
||||
CheckTree(mtf,
|
||||
std::string(R"(
|
||||
3H2S2T3-------D1
|
||||
2H1S1T4-------D2
|
||||
)")
|
||||
.substr(1),
|
||||
/* print_timestamp = */ true);
|
||||
|
||||
EXPECT_TRUE(mtf.Insert(1));
|
||||
EXPECT_EQ(1u, mtf.GetNodeHandle(1));
|
||||
EXPECT_EQ(3u, mtf.GetTotalNodeCount());
|
||||
}
|
||||
|
||||
TEST(MoveToFront, LargerScale) {
|
||||
MoveToFrontTester mtf;
|
||||
uint32_t value = 0;
|
||||
uint32_t rank = 0;
|
||||
|
||||
for (uint32_t i = 1; i < 1000; ++i) {
|
||||
ASSERT_TRUE(mtf.Insert(i));
|
||||
ASSERT_EQ(i, mtf.GetSize());
|
||||
|
||||
ASSERT_TRUE(mtf.RankFromValue(i, &rank));
|
||||
ASSERT_EQ(1u, rank);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(1, &value));
|
||||
ASSERT_EQ(i, value);
|
||||
}
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(999, &value));
|
||||
ASSERT_EQ(1u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(999, &value));
|
||||
ASSERT_EQ(2u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(999, &value));
|
||||
ASSERT_EQ(3u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(999, &value));
|
||||
ASSERT_EQ(4u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(999, &value));
|
||||
ASSERT_EQ(5u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(999, &value));
|
||||
ASSERT_EQ(6u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(101, &value));
|
||||
ASSERT_EQ(905u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(101, &value));
|
||||
ASSERT_EQ(906u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(101, &value));
|
||||
ASSERT_EQ(907u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(201, &value));
|
||||
ASSERT_EQ(805u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(201, &value));
|
||||
ASSERT_EQ(806u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(201, &value));
|
||||
ASSERT_EQ(807u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(301, &value));
|
||||
ASSERT_EQ(705u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(301, &value));
|
||||
ASSERT_EQ(706u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(301, &value));
|
||||
ASSERT_EQ(707u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.RankFromValue(605, &rank));
|
||||
ASSERT_EQ(401u, rank);
|
||||
|
||||
ASSERT_TRUE(mtf.RankFromValue(606, &rank));
|
||||
ASSERT_EQ(401u, rank);
|
||||
|
||||
ASSERT_TRUE(mtf.RankFromValue(607, &rank));
|
||||
ASSERT_EQ(401u, rank);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(1, &value));
|
||||
ASSERT_EQ(607u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(2, &value));
|
||||
ASSERT_EQ(606u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(3, &value));
|
||||
ASSERT_EQ(605u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(4, &value));
|
||||
ASSERT_EQ(707u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(5, &value));
|
||||
ASSERT_EQ(706u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(6, &value));
|
||||
ASSERT_EQ(705u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(7, &value));
|
||||
ASSERT_EQ(807u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(8, &value));
|
||||
ASSERT_EQ(806u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(9, &value));
|
||||
ASSERT_EQ(805u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(10, &value));
|
||||
ASSERT_EQ(907u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(11, &value));
|
||||
ASSERT_EQ(906u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(12, &value));
|
||||
ASSERT_EQ(905u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(13, &value));
|
||||
ASSERT_EQ(6u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(14, &value));
|
||||
ASSERT_EQ(5u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(15, &value));
|
||||
ASSERT_EQ(4u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(16, &value));
|
||||
ASSERT_EQ(3u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(17, &value));
|
||||
ASSERT_EQ(2u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(18, &value));
|
||||
ASSERT_EQ(1u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(19, &value));
|
||||
ASSERT_EQ(999u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(20, &value));
|
||||
ASSERT_EQ(998u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.ValueFromRank(21, &value));
|
||||
ASSERT_EQ(997u, value);
|
||||
|
||||
ASSERT_TRUE(mtf.RankFromValue(997, &rank));
|
||||
ASSERT_EQ(1u, rank);
|
||||
|
||||
ASSERT_TRUE(mtf.RankFromValue(998, &rank));
|
||||
ASSERT_EQ(2u, rank);
|
||||
|
||||
ASSERT_TRUE(mtf.RankFromValue(996, &rank));
|
||||
ASSERT_EQ(22u, rank);
|
||||
|
||||
ASSERT_TRUE(mtf.Remove(995));
|
||||
|
||||
ASSERT_TRUE(mtf.RankFromValue(994, &rank));
|
||||
ASSERT_EQ(23u, rank);
|
||||
|
||||
for (uint32_t i = 10; i < 1000; ++i) {
|
||||
if (i != 995) {
|
||||
ASSERT_TRUE(mtf.Remove(i));
|
||||
} else {
|
||||
ASSERT_FALSE(mtf.Remove(i));
|
||||
}
|
||||
}
|
||||
|
||||
CheckTree(mtf,
|
||||
std::string(R"(
|
||||
6H4S9T1029----8H2S3T8-------7H1S1T7-------D3
|
||||
9H1S1T9-------D3
|
||||
2H3S5T1033----4H2S3T1031----5H1S1T1030----D4
|
||||
3H1S1T1032----D4
|
||||
1H1S1T1034----D3
|
||||
)")
|
||||
.substr(1),
|
||||
/* print_timestamp = */ true);
|
||||
|
||||
ASSERT_TRUE(mtf.Insert(1000));
|
||||
ASSERT_TRUE(mtf.ValueFromRank(1, &value));
|
||||
ASSERT_EQ(1000u, value);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace comp
|
||||
} // namespace spvtools
|
@ -1,27 +0,0 @@
|
||||
# Copyright (c) 2017 Google Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
set(VAL_TEST_COMMON_SRCS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../test_fixture.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../unit_spirv.h
|
||||
)
|
||||
|
||||
add_spvtools_unittest(TARGET stats
|
||||
SRCS stats_aggregate_test.cpp
|
||||
stats_analyzer_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../tools/stats/spirv_stats.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../tools/stats/stats_analyzer.cpp
|
||||
${VAL_TEST_COMMON_SRCS}
|
||||
LIBS ${SPIRV_TOOLS}
|
||||
)
|
@ -1,438 +0,0 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Tests for unique type declaration rules validator.
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "test/test_fixture.h"
|
||||
#include "test/unit_spirv.h"
|
||||
#include "tools/stats/spirv_stats.h"
|
||||
|
||||
namespace spvtools {
|
||||
namespace stats {
|
||||
namespace {
|
||||
|
||||
using spvtest::ScopedContext;
|
||||
|
||||
void DiagnosticsMessageHandler(spv_message_level_t level, const char*,
|
||||
const spv_position_t& position,
|
||||
const char* message) {
|
||||
switch (level) {
|
||||
case SPV_MSG_FATAL:
|
||||
case SPV_MSG_INTERNAL_ERROR:
|
||||
case SPV_MSG_ERROR:
|
||||
std::cerr << "error: " << position.index << ": " << message << std::endl;
|
||||
break;
|
||||
case SPV_MSG_WARNING:
|
||||
std::cout << "warning: " << position.index << ": " << message
|
||||
<< std::endl;
|
||||
break;
|
||||
case SPV_MSG_INFO:
|
||||
std::cout << "info: " << position.index << ": " << message << std::endl;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Calls AggregateStats for binary compiled from |code|.
|
||||
void CompileAndAggregateStats(const std::string& code, SpirvStats* stats,
|
||||
spv_target_env env = SPV_ENV_UNIVERSAL_1_1) {
|
||||
spvtools::Context ctx(env);
|
||||
ctx.SetMessageConsumer(DiagnosticsMessageHandler);
|
||||
spv_binary binary;
|
||||
ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(ctx.CContext(), code.c_str(),
|
||||
code.size(), &binary, nullptr));
|
||||
|
||||
ASSERT_EQ(SPV_SUCCESS, AggregateStats(ctx.CContext(), binary->code,
|
||||
binary->wordCount, nullptr, stats));
|
||||
spvBinaryDestroy(binary);
|
||||
}
|
||||
|
||||
TEST(AggregateStats, CapabilityHistogram) {
|
||||
const std::string code1 = R"(
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpMemoryModel Physical32 OpenCL
|
||||
)";
|
||||
|
||||
const std::string code2 = R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpMemoryModel Logical GLSL450
|
||||
)";
|
||||
|
||||
SpirvStats stats;
|
||||
|
||||
CompileAndAggregateStats(code1, &stats);
|
||||
EXPECT_EQ(4u, stats.capability_hist.size());
|
||||
EXPECT_EQ(0u, stats.capability_hist.count(SpvCapabilityShader));
|
||||
EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityAddresses));
|
||||
EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityKernel));
|
||||
EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityGenericPointer));
|
||||
EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityLinkage));
|
||||
|
||||
CompileAndAggregateStats(code2, &stats);
|
||||
EXPECT_EQ(5u, stats.capability_hist.size());
|
||||
EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityShader));
|
||||
EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityAddresses));
|
||||
EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityKernel));
|
||||
EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityGenericPointer));
|
||||
EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityLinkage));
|
||||
|
||||
CompileAndAggregateStats(code1, &stats);
|
||||
EXPECT_EQ(5u, stats.capability_hist.size());
|
||||
EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityShader));
|
||||
EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityAddresses));
|
||||
EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityKernel));
|
||||
EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityGenericPointer));
|
||||
EXPECT_EQ(3u, stats.capability_hist.at(SpvCapabilityLinkage));
|
||||
|
||||
CompileAndAggregateStats(code2, &stats);
|
||||
EXPECT_EQ(5u, stats.capability_hist.size());
|
||||
EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityShader));
|
||||
EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityAddresses));
|
||||
EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityKernel));
|
||||
EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityGenericPointer));
|
||||
EXPECT_EQ(4u, stats.capability_hist.at(SpvCapabilityLinkage));
|
||||
}
|
||||
|
||||
TEST(AggregateStats, ExtensionHistogram) {
|
||||
const std::string code1 = R"(
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpExtension "SPV_KHR_16bit_storage"
|
||||
OpMemoryModel Physical32 OpenCL
|
||||
)";
|
||||
|
||||
const std::string code2 = R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpExtension "SPV_NV_viewport_array2"
|
||||
OpExtension "greatest_extension_ever"
|
||||
OpMemoryModel Logical GLSL450
|
||||
)";
|
||||
|
||||
SpirvStats stats;
|
||||
|
||||
CompileAndAggregateStats(code1, &stats);
|
||||
EXPECT_EQ(1u, stats.extension_hist.size());
|
||||
EXPECT_EQ(0u, stats.extension_hist.count("SPV_NV_viewport_array2"));
|
||||
EXPECT_EQ(1u, stats.extension_hist.at("SPV_KHR_16bit_storage"));
|
||||
|
||||
CompileAndAggregateStats(code2, &stats);
|
||||
EXPECT_EQ(3u, stats.extension_hist.size());
|
||||
EXPECT_EQ(1u, stats.extension_hist.at("SPV_NV_viewport_array2"));
|
||||
EXPECT_EQ(1u, stats.extension_hist.at("SPV_KHR_16bit_storage"));
|
||||
EXPECT_EQ(1u, stats.extension_hist.at("greatest_extension_ever"));
|
||||
|
||||
CompileAndAggregateStats(code1, &stats);
|
||||
EXPECT_EQ(3u, stats.extension_hist.size());
|
||||
EXPECT_EQ(1u, stats.extension_hist.at("SPV_NV_viewport_array2"));
|
||||
EXPECT_EQ(2u, stats.extension_hist.at("SPV_KHR_16bit_storage"));
|
||||
EXPECT_EQ(1u, stats.extension_hist.at("greatest_extension_ever"));
|
||||
|
||||
CompileAndAggregateStats(code2, &stats);
|
||||
EXPECT_EQ(3u, stats.extension_hist.size());
|
||||
EXPECT_EQ(2u, stats.extension_hist.at("SPV_NV_viewport_array2"));
|
||||
EXPECT_EQ(2u, stats.extension_hist.at("SPV_KHR_16bit_storage"));
|
||||
EXPECT_EQ(2u, stats.extension_hist.at("greatest_extension_ever"));
|
||||
}
|
||||
|
||||
TEST(AggregateStats, VersionHistogram) {
|
||||
const std::string code1 = R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpMemoryModel Logical GLSL450
|
||||
)";
|
||||
|
||||
SpirvStats stats;
|
||||
|
||||
CompileAndAggregateStats(code1, &stats);
|
||||
EXPECT_EQ(1u, stats.version_hist.size());
|
||||
EXPECT_EQ(1u, stats.version_hist.at(0x00010100));
|
||||
|
||||
CompileAndAggregateStats(code1, &stats, SPV_ENV_UNIVERSAL_1_0);
|
||||
EXPECT_EQ(2u, stats.version_hist.size());
|
||||
EXPECT_EQ(1u, stats.version_hist.at(0x00010100));
|
||||
EXPECT_EQ(1u, stats.version_hist.at(0x00010000));
|
||||
|
||||
CompileAndAggregateStats(code1, &stats);
|
||||
EXPECT_EQ(2u, stats.version_hist.size());
|
||||
EXPECT_EQ(2u, stats.version_hist.at(0x00010100));
|
||||
EXPECT_EQ(1u, stats.version_hist.at(0x00010000));
|
||||
|
||||
CompileAndAggregateStats(code1, &stats, SPV_ENV_UNIVERSAL_1_0);
|
||||
EXPECT_EQ(2u, stats.version_hist.size());
|
||||
EXPECT_EQ(2u, stats.version_hist.at(0x00010100));
|
||||
EXPECT_EQ(2u, stats.version_hist.at(0x00010000));
|
||||
}
|
||||
|
||||
TEST(AggregateStats, GeneratorHistogram) {
|
||||
const std::string code1 = R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpMemoryModel Logical GLSL450
|
||||
)";
|
||||
|
||||
const uint32_t kGeneratorKhronosAssembler = SPV_GENERATOR_KHRONOS_ASSEMBLER
|
||||
<< 16;
|
||||
|
||||
SpirvStats stats;
|
||||
|
||||
CompileAndAggregateStats(code1, &stats);
|
||||
EXPECT_EQ(1u, stats.generator_hist.size());
|
||||
EXPECT_EQ(1u, stats.generator_hist.at(kGeneratorKhronosAssembler));
|
||||
|
||||
CompileAndAggregateStats(code1, &stats);
|
||||
EXPECT_EQ(1u, stats.generator_hist.size());
|
||||
EXPECT_EQ(2u, stats.generator_hist.at(kGeneratorKhronosAssembler));
|
||||
}
|
||||
|
||||
TEST(AggregateStats, OpcodeHistogram) {
|
||||
const std::string code1 = R"(
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int64
|
||||
OpCapability Linkage
|
||||
OpMemoryModel Physical32 OpenCL
|
||||
%u64 = OpTypeInt 64 0
|
||||
%u32 = OpTypeInt 32 0
|
||||
%f32 = OpTypeFloat 32
|
||||
)";
|
||||
|
||||
const std::string code2 = R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpExtension "SPV_NV_viewport_array2"
|
||||
OpMemoryModel Logical GLSL450
|
||||
)";
|
||||
|
||||
SpirvStats stats;
|
||||
|
||||
CompileAndAggregateStats(code1, &stats);
|
||||
EXPECT_EQ(4u, stats.opcode_hist.size());
|
||||
EXPECT_EQ(4u, stats.opcode_hist.at(SpvOpCapability));
|
||||
EXPECT_EQ(1u, stats.opcode_hist.at(SpvOpMemoryModel));
|
||||
EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpTypeInt));
|
||||
EXPECT_EQ(1u, stats.opcode_hist.at(SpvOpTypeFloat));
|
||||
|
||||
CompileAndAggregateStats(code2, &stats);
|
||||
EXPECT_EQ(5u, stats.opcode_hist.size());
|
||||
EXPECT_EQ(6u, stats.opcode_hist.at(SpvOpCapability));
|
||||
EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpMemoryModel));
|
||||
EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpTypeInt));
|
||||
EXPECT_EQ(1u, stats.opcode_hist.at(SpvOpTypeFloat));
|
||||
EXPECT_EQ(1u, stats.opcode_hist.at(SpvOpExtension));
|
||||
|
||||
CompileAndAggregateStats(code1, &stats);
|
||||
EXPECT_EQ(5u, stats.opcode_hist.size());
|
||||
EXPECT_EQ(10u, stats.opcode_hist.at(SpvOpCapability));
|
||||
EXPECT_EQ(3u, stats.opcode_hist.at(SpvOpMemoryModel));
|
||||
EXPECT_EQ(4u, stats.opcode_hist.at(SpvOpTypeInt));
|
||||
EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpTypeFloat));
|
||||
EXPECT_EQ(1u, stats.opcode_hist.at(SpvOpExtension));
|
||||
|
||||
CompileAndAggregateStats(code2, &stats);
|
||||
EXPECT_EQ(5u, stats.opcode_hist.size());
|
||||
EXPECT_EQ(12u, stats.opcode_hist.at(SpvOpCapability));
|
||||
EXPECT_EQ(4u, stats.opcode_hist.at(SpvOpMemoryModel));
|
||||
EXPECT_EQ(4u, stats.opcode_hist.at(SpvOpTypeInt));
|
||||
EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpTypeFloat));
|
||||
EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpExtension));
|
||||
}
|
||||
|
||||
TEST(AggregateStats, OpcodeMarkovHistogram) {
|
||||
const std::string code1 = R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpExtension "SPV_NV_viewport_array2"
|
||||
OpMemoryModel Logical GLSL450
|
||||
)";
|
||||
|
||||
const std::string code2 = R"(
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int64
|
||||
OpCapability Linkage
|
||||
OpMemoryModel Physical32 OpenCL
|
||||
%u64 = OpTypeInt 64 0
|
||||
%u32 = OpTypeInt 32 0
|
||||
%f32 = OpTypeFloat 32
|
||||
)";
|
||||
|
||||
SpirvStats stats;
|
||||
stats.opcode_markov_hist.resize(2);
|
||||
|
||||
CompileAndAggregateStats(code1, &stats);
|
||||
ASSERT_EQ(2u, stats.opcode_markov_hist.size());
|
||||
EXPECT_EQ(2u, stats.opcode_markov_hist[0].size());
|
||||
EXPECT_EQ(2u, stats.opcode_markov_hist[0].at(SpvOpCapability).size());
|
||||
EXPECT_EQ(1u, stats.opcode_markov_hist[0].at(SpvOpExtension).size());
|
||||
EXPECT_EQ(
|
||||
1u, stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpCapability));
|
||||
EXPECT_EQ(1u,
|
||||
stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpExtension));
|
||||
EXPECT_EQ(
|
||||
1u, stats.opcode_markov_hist[0].at(SpvOpExtension).at(SpvOpMemoryModel));
|
||||
|
||||
EXPECT_EQ(1u, stats.opcode_markov_hist[1].size());
|
||||
EXPECT_EQ(2u, stats.opcode_markov_hist[1].at(SpvOpCapability).size());
|
||||
EXPECT_EQ(1u,
|
||||
stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpExtension));
|
||||
EXPECT_EQ(
|
||||
1u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpMemoryModel));
|
||||
|
||||
CompileAndAggregateStats(code2, &stats);
|
||||
ASSERT_EQ(2u, stats.opcode_markov_hist.size());
|
||||
EXPECT_EQ(4u, stats.opcode_markov_hist[0].size());
|
||||
EXPECT_EQ(3u, stats.opcode_markov_hist[0].at(SpvOpCapability).size());
|
||||
EXPECT_EQ(1u, stats.opcode_markov_hist[0].at(SpvOpExtension).size());
|
||||
EXPECT_EQ(1u, stats.opcode_markov_hist[0].at(SpvOpMemoryModel).size());
|
||||
EXPECT_EQ(2u, stats.opcode_markov_hist[0].at(SpvOpTypeInt).size());
|
||||
EXPECT_EQ(
|
||||
4u, stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpCapability));
|
||||
EXPECT_EQ(1u,
|
||||
stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpExtension));
|
||||
EXPECT_EQ(
|
||||
1u, stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpMemoryModel));
|
||||
EXPECT_EQ(
|
||||
1u, stats.opcode_markov_hist[0].at(SpvOpExtension).at(SpvOpMemoryModel));
|
||||
EXPECT_EQ(1u,
|
||||
stats.opcode_markov_hist[0].at(SpvOpMemoryModel).at(SpvOpTypeInt));
|
||||
EXPECT_EQ(1u, stats.opcode_markov_hist[0].at(SpvOpTypeInt).at(SpvOpTypeInt));
|
||||
EXPECT_EQ(1u,
|
||||
stats.opcode_markov_hist[0].at(SpvOpTypeInt).at(SpvOpTypeFloat));
|
||||
|
||||
EXPECT_EQ(3u, stats.opcode_markov_hist[1].size());
|
||||
EXPECT_EQ(4u, stats.opcode_markov_hist[1].at(SpvOpCapability).size());
|
||||
EXPECT_EQ(1u, stats.opcode_markov_hist[1].at(SpvOpMemoryModel).size());
|
||||
EXPECT_EQ(1u, stats.opcode_markov_hist[1].at(SpvOpTypeInt).size());
|
||||
EXPECT_EQ(
|
||||
2u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpCapability));
|
||||
EXPECT_EQ(1u,
|
||||
stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpExtension));
|
||||
EXPECT_EQ(
|
||||
2u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpMemoryModel));
|
||||
EXPECT_EQ(1u,
|
||||
stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpTypeInt));
|
||||
EXPECT_EQ(1u,
|
||||
stats.opcode_markov_hist[1].at(SpvOpMemoryModel).at(SpvOpTypeInt));
|
||||
EXPECT_EQ(1u,
|
||||
stats.opcode_markov_hist[1].at(SpvOpTypeInt).at(SpvOpTypeFloat));
|
||||
}
|
||||
|
||||
TEST(AggregateStats, ConstantLiteralsHistogram) {
|
||||
const std::string code1 = R"(
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Float64
|
||||
OpCapability Int16
|
||||
OpCapability Int64
|
||||
OpMemoryModel Physical32 OpenCL
|
||||
%u16 = OpTypeInt 16 0
|
||||
%u32 = OpTypeInt 32 0
|
||||
%u64 = OpTypeInt 64 0
|
||||
%f32 = OpTypeFloat 32
|
||||
%f64 = OpTypeFloat 64
|
||||
%1 = OpConstant %f32 0.1
|
||||
%2 = OpConstant %f32 -2
|
||||
%3 = OpConstant %f64 -2
|
||||
%4 = OpConstant %u16 16
|
||||
%5 = OpConstant %u16 2
|
||||
%6 = OpConstant %u32 32
|
||||
%7 = OpConstant %u64 64
|
||||
)";
|
||||
|
||||
const std::string code2 = R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpCapability Int16
|
||||
OpCapability Int64
|
||||
OpMemoryModel Logical GLSL450
|
||||
%f32 = OpTypeFloat 32
|
||||
%u16 = OpTypeInt 16 0
|
||||
%s16 = OpTypeInt 16 1
|
||||
%u32 = OpTypeInt 32 0
|
||||
%s32 = OpTypeInt 32 1
|
||||
%u64 = OpTypeInt 64 0
|
||||
%s64 = OpTypeInt 64 1
|
||||
%1 = OpConstant %f32 0.1
|
||||
%2 = OpConstant %f32 -2
|
||||
%3 = OpConstant %u16 1
|
||||
%4 = OpConstant %u16 16
|
||||
%5 = OpConstant %u16 2
|
||||
%6 = OpConstant %s16 -16
|
||||
%7 = OpConstant %u32 32
|
||||
%8 = OpConstant %s32 2
|
||||
%9 = OpConstant %s32 -32
|
||||
%10 = OpConstant %u64 64
|
||||
%11 = OpConstant %s64 -64
|
||||
)";
|
||||
|
||||
SpirvStats stats;
|
||||
|
||||
CompileAndAggregateStats(code1, &stats);
|
||||
EXPECT_EQ(2u, stats.f32_constant_hist.size());
|
||||
EXPECT_EQ(1u, stats.f64_constant_hist.size());
|
||||
EXPECT_EQ(1u, stats.f32_constant_hist.at(0.1f));
|
||||
EXPECT_EQ(1u, stats.f32_constant_hist.at(-2.f));
|
||||
EXPECT_EQ(1u, stats.f64_constant_hist.at(-2));
|
||||
|
||||
EXPECT_EQ(2u, stats.u16_constant_hist.size());
|
||||
EXPECT_EQ(0u, stats.s16_constant_hist.size());
|
||||
EXPECT_EQ(1u, stats.u32_constant_hist.size());
|
||||
EXPECT_EQ(0u, stats.s32_constant_hist.size());
|
||||
EXPECT_EQ(1u, stats.u64_constant_hist.size());
|
||||
EXPECT_EQ(0u, stats.s64_constant_hist.size());
|
||||
EXPECT_EQ(1u, stats.u16_constant_hist.at(16));
|
||||
EXPECT_EQ(1u, stats.u16_constant_hist.at(2));
|
||||
EXPECT_EQ(1u, stats.u32_constant_hist.at(32));
|
||||
EXPECT_EQ(1u, stats.u64_constant_hist.at(64));
|
||||
|
||||
CompileAndAggregateStats(code2, &stats);
|
||||
EXPECT_EQ(2u, stats.f32_constant_hist.size());
|
||||
EXPECT_EQ(1u, stats.f64_constant_hist.size());
|
||||
EXPECT_EQ(2u, stats.f32_constant_hist.at(0.1f));
|
||||
EXPECT_EQ(2u, stats.f32_constant_hist.at(-2.f));
|
||||
EXPECT_EQ(1u, stats.f64_constant_hist.at(-2));
|
||||
|
||||
EXPECT_EQ(3u, stats.u16_constant_hist.size());
|
||||
EXPECT_EQ(1u, stats.s16_constant_hist.size());
|
||||
EXPECT_EQ(1u, stats.u32_constant_hist.size());
|
||||
EXPECT_EQ(2u, stats.s32_constant_hist.size());
|
||||
EXPECT_EQ(1u, stats.u64_constant_hist.size());
|
||||
EXPECT_EQ(1u, stats.s64_constant_hist.size());
|
||||
EXPECT_EQ(2u, stats.u16_constant_hist.at(16));
|
||||
EXPECT_EQ(2u, stats.u16_constant_hist.at(2));
|
||||
EXPECT_EQ(1u, stats.u16_constant_hist.at(1));
|
||||
EXPECT_EQ(1u, stats.s16_constant_hist.at(-16));
|
||||
EXPECT_EQ(2u, stats.u32_constant_hist.at(32));
|
||||
EXPECT_EQ(1u, stats.s32_constant_hist.at(2));
|
||||
EXPECT_EQ(1u, stats.s32_constant_hist.at(-32));
|
||||
EXPECT_EQ(2u, stats.u64_constant_hist.at(64));
|
||||
EXPECT_EQ(1u, stats.s64_constant_hist.at(-64));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace stats
|
||||
} // namespace spvtools
|
@ -1,174 +0,0 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Tests for unique type declaration rules validator.
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "source/latest_version_spirv_header.h"
|
||||
#include "test/test_fixture.h"
|
||||
#include "tools/stats/stats_analyzer.h"
|
||||
|
||||
namespace spvtools {
|
||||
namespace stats {
|
||||
namespace {
|
||||
|
||||
// Fills |stats| with some synthetic header stats, as if aggregated from 100
|
||||
// modules (100 used for simpler percentage evaluation).
|
||||
void FillDefaultStats(SpirvStats* stats) {
|
||||
*stats = SpirvStats();
|
||||
stats->version_hist[0x00010000] = 40;
|
||||
stats->version_hist[0x00010100] = 60;
|
||||
stats->generator_hist[0x00000000] = 64;
|
||||
stats->generator_hist[0x00010000] = 1;
|
||||
stats->generator_hist[0x00020000] = 2;
|
||||
stats->generator_hist[0x00030000] = 3;
|
||||
stats->generator_hist[0x00040000] = 4;
|
||||
stats->generator_hist[0x00050000] = 5;
|
||||
stats->generator_hist[0x00060000] = 6;
|
||||
stats->generator_hist[0x00070000] = 7;
|
||||
stats->generator_hist[0x00080000] = 8;
|
||||
|
||||
int num_version_entries = 0;
|
||||
for (const auto& pair : stats->version_hist) {
|
||||
num_version_entries += pair.second;
|
||||
}
|
||||
|
||||
int num_generator_entries = 0;
|
||||
for (const auto& pair : stats->generator_hist) {
|
||||
num_generator_entries += pair.second;
|
||||
}
|
||||
|
||||
EXPECT_EQ(num_version_entries, num_generator_entries);
|
||||
}
|
||||
|
||||
TEST(StatsAnalyzer, Version) {
|
||||
SpirvStats stats;
|
||||
FillDefaultStats(&stats);
|
||||
|
||||
StatsAnalyzer analyzer(stats);
|
||||
|
||||
std::stringstream ss;
|
||||
analyzer.WriteVersion(ss);
|
||||
const std::string output = ss.str();
|
||||
const std::string expected_output = "Version 1.1 60%\nVersion 1.0 40%\n";
|
||||
|
||||
EXPECT_EQ(expected_output, output);
|
||||
}
|
||||
|
||||
TEST(StatsAnalyzer, Generator) {
|
||||
SpirvStats stats;
|
||||
FillDefaultStats(&stats);
|
||||
|
||||
StatsAnalyzer analyzer(stats);
|
||||
|
||||
std::stringstream ss;
|
||||
analyzer.WriteGenerator(ss);
|
||||
const std::string output = ss.str();
|
||||
const std::string expected_output =
|
||||
"Khronos 64%\nKhronos Glslang Reference Front End 8%\n"
|
||||
"Khronos SPIR-V Tools Assembler 7%\nKhronos LLVM/SPIR-V Translator 6%"
|
||||
"\nARM 5%\nNVIDIA 4%\nCodeplay 3%\nValve 2%\nLunarG 1%\n";
|
||||
|
||||
EXPECT_EQ(expected_output, output);
|
||||
}
|
||||
|
||||
TEST(StatsAnalyzer, Capability) {
|
||||
SpirvStats stats;
|
||||
FillDefaultStats(&stats);
|
||||
|
||||
stats.capability_hist[SpvCapabilityShader] = 25;
|
||||
stats.capability_hist[SpvCapabilityKernel] = 75;
|
||||
|
||||
StatsAnalyzer analyzer(stats);
|
||||
|
||||
std::stringstream ss;
|
||||
analyzer.WriteCapability(ss);
|
||||
const std::string output = ss.str();
|
||||
const std::string expected_output = "Kernel 75%\nShader 25%\n";
|
||||
|
||||
EXPECT_EQ(expected_output, output);
|
||||
}
|
||||
|
||||
TEST(StatsAnalyzer, Extension) {
|
||||
SpirvStats stats;
|
||||
FillDefaultStats(&stats);
|
||||
|
||||
stats.extension_hist["greatest_extension_ever"] = 1;
|
||||
stats.extension_hist["worst_extension_ever"] = 10;
|
||||
|
||||
StatsAnalyzer analyzer(stats);
|
||||
|
||||
std::stringstream ss;
|
||||
analyzer.WriteExtension(ss);
|
||||
const std::string output = ss.str();
|
||||
const std::string expected_output =
|
||||
"worst_extension_ever 10%\ngreatest_extension_ever 1%\n";
|
||||
|
||||
EXPECT_EQ(expected_output, output);
|
||||
}
|
||||
|
||||
TEST(StatsAnalyzer, Opcode) {
|
||||
SpirvStats stats;
|
||||
FillDefaultStats(&stats);
|
||||
|
||||
stats.opcode_hist[SpvOpCapability] = 20;
|
||||
stats.opcode_hist[SpvOpConstant] = 80;
|
||||
stats.opcode_hist[SpvOpDecorate] = 100;
|
||||
|
||||
StatsAnalyzer analyzer(stats);
|
||||
|
||||
std::stringstream ss;
|
||||
analyzer.WriteOpcode(ss);
|
||||
const std::string output = ss.str();
|
||||
const std::string expected_output =
|
||||
"Total unique opcodes used: 3\nDecorate 50%\n"
|
||||
"Constant 40%\nCapability 10%\n";
|
||||
|
||||
EXPECT_EQ(expected_output, output);
|
||||
}
|
||||
|
||||
TEST(StatsAnalyzer, OpcodeMarkov) {
|
||||
SpirvStats stats;
|
||||
FillDefaultStats(&stats);
|
||||
|
||||
stats.opcode_hist[SpvOpFMul] = 400;
|
||||
stats.opcode_hist[SpvOpFAdd] = 200;
|
||||
stats.opcode_hist[SpvOpFSub] = 400;
|
||||
|
||||
stats.opcode_markov_hist.resize(1);
|
||||
auto& hist = stats.opcode_markov_hist[0];
|
||||
hist[SpvOpFMul][SpvOpFAdd] = 100;
|
||||
hist[SpvOpFMul][SpvOpFSub] = 300;
|
||||
hist[SpvOpFAdd][SpvOpFMul] = 100;
|
||||
hist[SpvOpFAdd][SpvOpFAdd] = 100;
|
||||
|
||||
StatsAnalyzer analyzer(stats);
|
||||
|
||||
std::stringstream ss;
|
||||
analyzer.WriteOpcodeMarkov(ss);
|
||||
const std::string output = ss.str();
|
||||
const std::string expected_output =
|
||||
"FMul -> FSub 75% (base rate 40%, pair occurrences 300)\n"
|
||||
"FMul -> FAdd 25% (base rate 20%, pair occurrences 100)\n"
|
||||
"FAdd -> FAdd 50% (base rate 20%, pair occurrences 100)\n"
|
||||
"FAdd -> FMul 50% (base rate 40%, pair occurrences 100)\n";
|
||||
|
||||
EXPECT_EQ(expected_output, output);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace stats
|
||||
} // namespace spvtools
|
@ -48,13 +48,6 @@ if (NOT ${SPIRV_SKIP_EXECUTABLES})
|
||||
add_spvtools_tool(TARGET spirv-reduce SRCS reduce/reduce.cpp util/cli_consumer.cpp LIBS SPIRV-Tools-reduce ${SPIRV_TOOLS})
|
||||
endif()
|
||||
add_spvtools_tool(TARGET spirv-link SRCS link/linker.cpp LIBS SPIRV-Tools-link ${SPIRV_TOOLS})
|
||||
add_spvtools_tool(TARGET spirv-stats
|
||||
SRCS stats/stats.cpp
|
||||
stats/stats_analyzer.cpp
|
||||
stats/stats_analyzer.h
|
||||
stats/spirv_stats.cpp
|
||||
stats/spirv_stats.h
|
||||
LIBS ${SPIRV_TOOLS})
|
||||
add_spvtools_tool(TARGET spirv-cfg
|
||||
SRCS cfg/cfg.cpp
|
||||
cfg/bin_to_dot.h
|
||||
@ -62,26 +55,12 @@ if (NOT ${SPIRV_SKIP_EXECUTABLES})
|
||||
LIBS ${SPIRV_TOOLS})
|
||||
target_include_directories(spirv-cfg PRIVATE ${spirv-tools_SOURCE_DIR}
|
||||
${SPIRV_HEADER_INCLUDE_DIR})
|
||||
target_include_directories(spirv-stats PRIVATE ${spirv-tools_SOURCE_DIR}
|
||||
${SPIRV_HEADER_INCLUDE_DIR})
|
||||
|
||||
set(SPIRV_INSTALL_TARGETS spirv-as spirv-dis spirv-val spirv-opt spirv-stats
|
||||
set(SPIRV_INSTALL_TARGETS spirv-as spirv-dis spirv-val spirv-opt
|
||||
spirv-cfg spirv-link)
|
||||
if(NOT DEFINED IOS_PLATFORM)
|
||||
set(SPIRV_INSTALL_TARGETS ${SPIRV_INSTALL_TARGETS} spirv-reduce)
|
||||
endif()
|
||||
|
||||
if(SPIRV_BUILD_COMPRESSION)
|
||||
add_spvtools_tool(TARGET spirv-markv
|
||||
SRCS comp/markv.cpp
|
||||
comp/markv_model_factory.cpp
|
||||
comp/markv_model_shader.cpp
|
||||
LIBS SPIRV-Tools-comp SPIRV-Tools-opt ${SPIRV_TOOLS})
|
||||
target_include_directories(spirv-markv PRIVATE ${spirv-tools_SOURCE_DIR}
|
||||
${SPIRV_HEADER_INCLUDE_DIR})
|
||||
set(SPIRV_INSTALL_TARGETS ${SPIRV_INSTALL_TARGETS} spirv-markv)
|
||||
endif(SPIRV_BUILD_COMPRESSION)
|
||||
|
||||
if(ENABLE_SPIRV_TOOLS_INSTALL)
|
||||
install(TARGETS ${SPIRV_INSTALL_TARGETS}
|
||||
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
|
||||
|
@ -1,385 +0,0 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "source/comp/markv.h"
|
||||
#include "source/spirv_target_env.h"
|
||||
#include "source/table.h"
|
||||
#include "spirv-tools/optimizer.hpp"
|
||||
#include "tools/comp/markv_model_factory.h"
|
||||
#include "tools/io.h"
|
||||
|
||||
namespace {
|
||||
|
||||
const auto kSpvEnv = SPV_ENV_UNIVERSAL_1_2;
|
||||
|
||||
enum Task {
|
||||
kNoTask = 0,
|
||||
kEncode,
|
||||
kDecode,
|
||||
kTest,
|
||||
};
|
||||
|
||||
struct ScopedContext {
|
||||
ScopedContext(spv_target_env env) : context(spvContextCreate(env)) {}
|
||||
~ScopedContext() { spvContextDestroy(context); }
|
||||
spv_context context;
|
||||
};
|
||||
|
||||
void print_usage(char* argv0) {
|
||||
printf(
|
||||
R"(%s - Encodes or decodes a SPIR-V binary to or from a MARK-V binary.
|
||||
|
||||
USAGE: %s [e|d|t] [options] [<filename>]
|
||||
|
||||
The input binary is read from <filename>. If no file is specified,
|
||||
or if the filename is "-", then the binary is read from standard input.
|
||||
|
||||
If no output is specified then the output is printed to stdout in a human
|
||||
readable format.
|
||||
|
||||
WIP: MARK-V codec is in early stages of development. At the moment it only
|
||||
can encode and decode some SPIR-V files and only if exacly the same build of
|
||||
software is used (is doesn't write or handle version numbers yet).
|
||||
|
||||
Tasks:
|
||||
e Encode SPIR-V to MARK-V.
|
||||
d Decode MARK-V to SPIR-V.
|
||||
t Test the codec by first encoding the given SPIR-V file to
|
||||
MARK-V, then decoding it back to SPIR-V and comparing results.
|
||||
|
||||
Options:
|
||||
-h, --help Print this help.
|
||||
--comments Write codec comments to stderr.
|
||||
--version Display MARK-V codec version.
|
||||
--validate Validate SPIR-V while encoding or decoding.
|
||||
--model=<model-name>
|
||||
Compression model, possible values:
|
||||
shader_lite - fast, poor compression ratio
|
||||
shader_mid - balanced
|
||||
shader_max - best compression ratio
|
||||
Default: shader_lite
|
||||
|
||||
-o <filename> Set the output filename.
|
||||
Output goes to standard output if this option is
|
||||
not specified, or if the filename is "-".
|
||||
Not needed for 't' task (testing).
|
||||
)",
|
||||
argv0, argv0);
|
||||
}
|
||||
|
||||
void DiagnosticsMessageHandler(spv_message_level_t level, const char*,
|
||||
const spv_position_t& position,
|
||||
const char* message) {
|
||||
switch (level) {
|
||||
case SPV_MSG_FATAL:
|
||||
case SPV_MSG_INTERNAL_ERROR:
|
||||
case SPV_MSG_ERROR:
|
||||
std::cerr << "error: " << position.index << ": " << message << std::endl;
|
||||
break;
|
||||
case SPV_MSG_WARNING:
|
||||
std::cerr << "warning: " << position.index << ": " << message
|
||||
<< std::endl;
|
||||
break;
|
||||
case SPV_MSG_INFO:
|
||||
std::cerr << "info: " << position.index << ": " << message << std::endl;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
const char* input_filename = nullptr;
|
||||
const char* output_filename = nullptr;
|
||||
|
||||
Task task = kNoTask;
|
||||
|
||||
if (argc < 3) {
|
||||
print_usage(argv[0]);
|
||||
return 0;
|
||||
}
|
||||
|
||||
const char* task_char = argv[1];
|
||||
if (0 == strcmp("e", task_char)) {
|
||||
task = kEncode;
|
||||
} else if (0 == strcmp("d", task_char)) {
|
||||
task = kDecode;
|
||||
} else if (0 == strcmp("t", task_char)) {
|
||||
task = kTest;
|
||||
}
|
||||
|
||||
if (task == kNoTask) {
|
||||
print_usage(argv[0]);
|
||||
return 1;
|
||||
}
|
||||
|
||||
bool want_comments = false;
|
||||
bool validate_spirv_binary = false;
|
||||
|
||||
spvtools::comp::MarkvModelType model_type =
|
||||
spvtools::comp::kMarkvModelUnknown;
|
||||
|
||||
for (int argi = 2; argi < argc; ++argi) {
|
||||
if ('-' == argv[argi][0]) {
|
||||
switch (argv[argi][1]) {
|
||||
case 'h':
|
||||
print_usage(argv[0]);
|
||||
return 0;
|
||||
case 'o': {
|
||||
if (!output_filename && argi + 1 < argc &&
|
||||
(task == kEncode || task == kDecode)) {
|
||||
output_filename = argv[++argi];
|
||||
} else {
|
||||
print_usage(argv[0]);
|
||||
return 1;
|
||||
}
|
||||
} break;
|
||||
case '-': {
|
||||
if (0 == strcmp(argv[argi], "--help")) {
|
||||
print_usage(argv[0]);
|
||||
return 0;
|
||||
} else if (0 == strcmp(argv[argi], "--comments")) {
|
||||
want_comments = true;
|
||||
} else if (0 == strcmp(argv[argi], "--version")) {
|
||||
fprintf(stderr, "error: Not implemented\n");
|
||||
return 1;
|
||||
} else if (0 == strcmp(argv[argi], "--validate")) {
|
||||
validate_spirv_binary = true;
|
||||
} else if (0 == strcmp(argv[argi], "--model=shader_lite")) {
|
||||
if (model_type != spvtools::comp::kMarkvModelUnknown)
|
||||
fprintf(stderr, "error: More than one model specified\n");
|
||||
model_type = spvtools::comp::kMarkvModelShaderLite;
|
||||
} else if (0 == strcmp(argv[argi], "--model=shader_mid")) {
|
||||
if (model_type != spvtools::comp::kMarkvModelUnknown)
|
||||
fprintf(stderr, "error: More than one model specified\n");
|
||||
model_type = spvtools::comp::kMarkvModelShaderMid;
|
||||
} else if (0 == strcmp(argv[argi], "--model=shader_max")) {
|
||||
if (model_type != spvtools::comp::kMarkvModelUnknown)
|
||||
fprintf(stderr, "error: More than one model specified\n");
|
||||
model_type = spvtools::comp::kMarkvModelShaderMax;
|
||||
} else {
|
||||
print_usage(argv[0]);
|
||||
return 1;
|
||||
}
|
||||
} break;
|
||||
case '\0': {
|
||||
// Setting a filename of "-" to indicate stdin.
|
||||
if (!input_filename) {
|
||||
input_filename = argv[argi];
|
||||
} else {
|
||||
fprintf(stderr, "error: More than one input file specified\n");
|
||||
return 1;
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
print_usage(argv[0]);
|
||||
return 1;
|
||||
}
|
||||
} else {
|
||||
if (!input_filename) {
|
||||
input_filename = argv[argi];
|
||||
} else {
|
||||
fprintf(stderr, "error: More than one input file specified\n");
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (model_type == spvtools::comp::kMarkvModelUnknown)
|
||||
model_type = spvtools::comp::kMarkvModelShaderLite;
|
||||
|
||||
const auto no_comments = spvtools::comp::MarkvLogConsumer();
|
||||
const auto output_to_stderr = [](const std::string& str) {
|
||||
std::cerr << str;
|
||||
};
|
||||
|
||||
ScopedContext ctx(kSpvEnv);
|
||||
|
||||
std::unique_ptr<spvtools::comp::MarkvModel> model =
|
||||
spvtools::comp::CreateMarkvModel(model_type);
|
||||
|
||||
std::vector<uint32_t> spirv;
|
||||
std::vector<uint8_t> markv;
|
||||
|
||||
spvtools::comp::MarkvCodecOptions options;
|
||||
options.validate_spirv_binary = validate_spirv_binary;
|
||||
|
||||
if (task == kEncode) {
|
||||
if (!ReadFile<uint32_t>(input_filename, "rb", &spirv)) return 1;
|
||||
assert(!spirv.empty());
|
||||
|
||||
if (SPV_SUCCESS != spvtools::comp::SpirvToMarkv(
|
||||
ctx.context, spirv, options, *model,
|
||||
DiagnosticsMessageHandler,
|
||||
want_comments ? output_to_stderr : no_comments,
|
||||
spvtools::comp::MarkvDebugConsumer(), &markv)) {
|
||||
std::cerr << "error: Failed to encode " << input_filename << " to MARK-V "
|
||||
<< std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (!WriteFile<uint8_t>(output_filename, "wb", markv.data(), markv.size()))
|
||||
return 1;
|
||||
} else if (task == kDecode) {
|
||||
if (!ReadFile<uint8_t>(input_filename, "rb", &markv)) return 1;
|
||||
assert(!markv.empty());
|
||||
|
||||
if (SPV_SUCCESS != spvtools::comp::MarkvToSpirv(
|
||||
ctx.context, markv, options, *model,
|
||||
DiagnosticsMessageHandler,
|
||||
want_comments ? output_to_stderr : no_comments,
|
||||
spvtools::comp::MarkvDebugConsumer(), &spirv)) {
|
||||
std::cerr << "error: Failed to decode " << input_filename << " to SPIR-V "
|
||||
<< std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (!WriteFile<uint32_t>(output_filename, "wb", spirv.data(), spirv.size()))
|
||||
return 1;
|
||||
} else if (task == kTest) {
|
||||
if (!ReadFile<uint32_t>(input_filename, "rb", &spirv)) return 1;
|
||||
assert(!spirv.empty());
|
||||
|
||||
std::vector<uint32_t> spirv_before;
|
||||
spvtools::Optimizer optimizer(kSpvEnv);
|
||||
optimizer.RegisterPass(spvtools::CreateCompactIdsPass());
|
||||
if (!optimizer.Run(spirv.data(), spirv.size(), &spirv_before)) {
|
||||
std::cerr << "error: Optimizer failure on: " << input_filename
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
std::vector<std::string> encoder_instruction_bits;
|
||||
std::vector<std::string> encoder_instruction_comments;
|
||||
std::vector<std::vector<uint32_t>> encoder_instruction_words;
|
||||
std::vector<std::string> decoder_instruction_bits;
|
||||
std::vector<std::string> decoder_instruction_comments;
|
||||
std::vector<std::vector<uint32_t>> decoder_instruction_words;
|
||||
|
||||
const auto encoder_debug_consumer = [&](const std::vector<uint32_t>& words,
|
||||
const std::string& bits,
|
||||
const std::string& comment) {
|
||||
encoder_instruction_words.push_back(words);
|
||||
encoder_instruction_bits.push_back(bits);
|
||||
encoder_instruction_comments.push_back(comment);
|
||||
return true;
|
||||
};
|
||||
|
||||
if (SPV_SUCCESS != spvtools::comp::SpirvToMarkv(
|
||||
ctx.context, spirv_before, options, *model,
|
||||
DiagnosticsMessageHandler,
|
||||
want_comments ? output_to_stderr : no_comments,
|
||||
encoder_debug_consumer, &markv)) {
|
||||
std::cerr << "error: Failed to encode " << input_filename << " to MARK-V "
|
||||
<< std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
const auto write_bug_report = [&]() {
|
||||
for (size_t inst_index = 0; inst_index < decoder_instruction_words.size();
|
||||
++inst_index) {
|
||||
std::cerr << "\nInstruction #" << inst_index << std::endl;
|
||||
std::cerr << "\nEncoder words: ";
|
||||
for (uint32_t word : encoder_instruction_words[inst_index])
|
||||
std::cerr << word << " ";
|
||||
std::cerr << "\nDecoder words: ";
|
||||
for (uint32_t word : decoder_instruction_words[inst_index])
|
||||
std::cerr << word << " ";
|
||||
std::cerr << std::endl;
|
||||
|
||||
std::cerr << "\nEncoder bits: " << encoder_instruction_bits[inst_index];
|
||||
std::cerr << "\nDecoder bits: " << decoder_instruction_bits[inst_index];
|
||||
std::cerr << std::endl;
|
||||
|
||||
std::cerr << "\nEncoder comments:\n"
|
||||
<< encoder_instruction_comments[inst_index];
|
||||
std::cerr << "Decoder comments:\n"
|
||||
<< decoder_instruction_comments[inst_index];
|
||||
std::cerr << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
const auto decoder_debug_consumer = [&](const std::vector<uint32_t>& words,
|
||||
const std::string& bits,
|
||||
const std::string& comment) {
|
||||
const size_t inst_index = decoder_instruction_words.size();
|
||||
if (inst_index >= encoder_instruction_words.size()) {
|
||||
write_bug_report();
|
||||
std::cerr << "error: Decoder has more instructions than encoder: "
|
||||
<< input_filename << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
decoder_instruction_words.push_back(words);
|
||||
decoder_instruction_bits.push_back(bits);
|
||||
decoder_instruction_comments.push_back(comment);
|
||||
|
||||
if (encoder_instruction_words[inst_index] !=
|
||||
decoder_instruction_words[inst_index]) {
|
||||
write_bug_report();
|
||||
std::cerr << "error: Words of the last decoded instruction differ from "
|
||||
"reference: "
|
||||
<< input_filename << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (encoder_instruction_bits[inst_index] !=
|
||||
decoder_instruction_bits[inst_index]) {
|
||||
write_bug_report();
|
||||
std::cerr << "error: Bits of the last decoded instruction differ from "
|
||||
"reference: "
|
||||
<< input_filename << std::endl;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
std::vector<uint32_t> spirv_after;
|
||||
const spv_result_t decoding_result = spvtools::comp::MarkvToSpirv(
|
||||
ctx.context, markv, options, *model, DiagnosticsMessageHandler,
|
||||
want_comments ? output_to_stderr : no_comments, decoder_debug_consumer,
|
||||
&spirv_after);
|
||||
|
||||
if (decoding_result == SPV_REQUESTED_TERMINATION) {
|
||||
std::cerr << "error: Decoding interrupted by the debugger: "
|
||||
<< input_filename << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (decoding_result != SPV_SUCCESS) {
|
||||
std::cerr << "error: Failed to decode encoded " << input_filename
|
||||
<< " back to SPIR-V " << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
assert(spirv_before.size() == spirv_after.size());
|
||||
assert(std::mismatch(std::next(spirv_before.begin(), 5), spirv_before.end(),
|
||||
std::next(spirv_after.begin(), 5)) ==
|
||||
std::make_pair(spirv_before.end(), spirv_after.end()));
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
@ -1,50 +0,0 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "tools/comp/markv_model_factory.h"
|
||||
|
||||
#include "source/util/make_unique.h"
|
||||
#include "tools/comp/markv_model_shader.h"
|
||||
|
||||
namespace spvtools {
|
||||
namespace comp {
|
||||
|
||||
std::unique_ptr<MarkvModel> CreateMarkvModel(MarkvModelType type) {
|
||||
std::unique_ptr<MarkvModel> model;
|
||||
switch (type) {
|
||||
case kMarkvModelShaderLite: {
|
||||
model = MakeUnique<MarkvModelShaderLite>();
|
||||
break;
|
||||
}
|
||||
case kMarkvModelShaderMid: {
|
||||
model = MakeUnique<MarkvModelShaderMid>();
|
||||
break;
|
||||
}
|
||||
case kMarkvModelShaderMax: {
|
||||
model = MakeUnique<MarkvModelShaderMax>();
|
||||
break;
|
||||
}
|
||||
case kMarkvModelUnknown: {
|
||||
assert(0 && "kMarkvModelUnknown supplied to CreateMarkvModel");
|
||||
return model;
|
||||
}
|
||||
}
|
||||
|
||||
model->SetModelType(static_cast<uint32_t>(type));
|
||||
|
||||
return model;
|
||||
}
|
||||
|
||||
} // namespace comp
|
||||
} // namespace spvtools
|
@ -1,37 +0,0 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef TOOLS_COMP_MARKV_MODEL_FACTORY_H_
|
||||
#define TOOLS_COMP_MARKV_MODEL_FACTORY_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "source/comp/markv_model.h"
|
||||
|
||||
namespace spvtools {
|
||||
namespace comp {
|
||||
|
||||
enum MarkvModelType {
|
||||
kMarkvModelUnknown = 0,
|
||||
kMarkvModelShaderLite,
|
||||
kMarkvModelShaderMid,
|
||||
kMarkvModelShaderMax,
|
||||
};
|
||||
|
||||
std::unique_ptr<MarkvModel> CreateMarkvModel(MarkvModelType type);
|
||||
|
||||
} // namespace comp
|
||||
} // namespace spvtools
|
||||
|
||||
#endif // TOOLS_COMP_MARKV_MODEL_FACTORY_H_
|
@ -1,84 +0,0 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "tools/comp/markv_model_shader.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "source/util/make_unique.h"
|
||||
|
||||
namespace spvtools {
|
||||
namespace comp {
|
||||
namespace {
|
||||
|
||||
// Signals that the value is not in the coding scheme and a fallback method
|
||||
// needs to be used.
|
||||
const uint64_t kMarkvNoneOfTheAbove = MarkvModel::GetMarkvNoneOfTheAbove();
|
||||
|
||||
inline uint32_t CombineOpcodeAndNumOperands(uint32_t opcode,
|
||||
uint32_t num_operands) {
|
||||
return opcode | (num_operands << 16);
|
||||
}
|
||||
|
||||
#include "tools/comp/markv_model_shader_default_autogen.inc"
|
||||
|
||||
} // namespace
|
||||
|
||||
MarkvModelShaderLite::MarkvModelShaderLite() {
|
||||
const uint16_t kVersionNumber = 1;
|
||||
SetModelVersion(kVersionNumber);
|
||||
|
||||
opcode_and_num_operands_huffman_codec_ =
|
||||
MakeUnique<HuffmanCodec<uint64_t>>(GetOpcodeAndNumOperandsHist());
|
||||
|
||||
id_fallback_strategy_ = IdFallbackStrategy::kShortDescriptor;
|
||||
}
|
||||
|
||||
MarkvModelShaderMid::MarkvModelShaderMid() {
|
||||
const uint16_t kVersionNumber = 1;
|
||||
SetModelVersion(kVersionNumber);
|
||||
|
||||
opcode_and_num_operands_huffman_codec_ =
|
||||
MakeUnique<HuffmanCodec<uint64_t>>(GetOpcodeAndNumOperandsHist());
|
||||
non_id_word_huffman_codecs_ = GetNonIdWordHuffmanCodecs();
|
||||
id_descriptor_huffman_codecs_ = GetIdDescriptorHuffmanCodecs();
|
||||
descriptors_with_coding_scheme_ = GetDescriptorsWithCodingScheme();
|
||||
literal_string_huffman_codecs_ = GetLiteralStringHuffmanCodecs();
|
||||
|
||||
id_fallback_strategy_ = IdFallbackStrategy::kShortDescriptor;
|
||||
}
|
||||
|
||||
MarkvModelShaderMax::MarkvModelShaderMax() {
|
||||
const uint16_t kVersionNumber = 1;
|
||||
SetModelVersion(kVersionNumber);
|
||||
|
||||
opcode_and_num_operands_huffman_codec_ =
|
||||
MakeUnique<HuffmanCodec<uint64_t>>(GetOpcodeAndNumOperandsHist());
|
||||
opcode_and_num_operands_markov_huffman_codecs_ =
|
||||
GetOpcodeAndNumOperandsMarkovHuffmanCodecs();
|
||||
non_id_word_huffman_codecs_ = GetNonIdWordHuffmanCodecs();
|
||||
id_descriptor_huffman_codecs_ = GetIdDescriptorHuffmanCodecs();
|
||||
descriptors_with_coding_scheme_ = GetDescriptorsWithCodingScheme();
|
||||
literal_string_huffman_codecs_ = GetLiteralStringHuffmanCodecs();
|
||||
|
||||
id_fallback_strategy_ = IdFallbackStrategy::kRuleBased;
|
||||
}
|
||||
|
||||
} // namespace comp
|
||||
} // namespace spvtools
|
@ -1,47 +0,0 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef TOOLS_COMP_MARKV_MODEL_SHADER_H_
|
||||
#define TOOLS_COMP_MARKV_MODEL_SHADER_H_
|
||||
|
||||
#include "source/comp/markv_model.h"
|
||||
|
||||
namespace spvtools {
|
||||
namespace comp {
|
||||
|
||||
// MARK-V shader compression model, which only uses fast and lightweight
|
||||
// algorithms, which do not require training and are not heavily dependent on
|
||||
// SPIR-V grammar. Compression ratio is worse than by other models.
|
||||
class MarkvModelShaderLite : public MarkvModel {
|
||||
public:
|
||||
MarkvModelShaderLite();
|
||||
};
|
||||
|
||||
// MARK-V shader compression model with balanced compression ratio and runtime
|
||||
// performance.
|
||||
class MarkvModelShaderMid : public MarkvModel {
|
||||
public:
|
||||
MarkvModelShaderMid();
|
||||
};
|
||||
|
||||
// MARK-V shader compression model designed for maximum compression.
|
||||
class MarkvModelShaderMax : public MarkvModel {
|
||||
public:
|
||||
MarkvModelShaderMax();
|
||||
};
|
||||
|
||||
} // namespace comp
|
||||
} // namespace spvtools
|
||||
|
||||
#endif // TOOLS_COMP_MARKV_MODEL_SHADER_H_
|
File diff suppressed because it is too large
Load Diff
@ -1,165 +0,0 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "tools/stats/spirv_stats.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "source/diagnostic.h"
|
||||
#include "source/enum_string_mapping.h"
|
||||
#include "source/extensions.h"
|
||||
#include "source/id_descriptor.h"
|
||||
#include "source/instruction.h"
|
||||
#include "source/opcode.h"
|
||||
#include "source/operand.h"
|
||||
#include "source/val/instruction.h"
|
||||
#include "source/val/validate.h"
|
||||
#include "source/val/validation_state.h"
|
||||
#include "spirv-tools/libspirv.h"
|
||||
|
||||
namespace spvtools {
|
||||
namespace stats {
|
||||
namespace {
|
||||
|
||||
// Helper class for stats aggregation. Receives as in/out parameter.
|
||||
// Constructs ValidationState and updates it by running validator for each
|
||||
// instruction.
|
||||
class StatsAggregator {
|
||||
public:
|
||||
StatsAggregator(SpirvStats* in_out_stats, const val::ValidationState_t* state)
|
||||
: stats_(in_out_stats), vstate_(state) {}
|
||||
|
||||
// Processes the instructions to collect stats.
|
||||
void aggregate() {
|
||||
const auto& instructions = vstate_->ordered_instructions();
|
||||
|
||||
++stats_->version_hist[vstate_->version()];
|
||||
++stats_->generator_hist[vstate_->generator()];
|
||||
|
||||
for (size_t i = 0; i < instructions.size(); ++i) {
|
||||
const auto& inst = instructions[i];
|
||||
|
||||
ProcessOpcode(&inst, i);
|
||||
ProcessCapability(&inst);
|
||||
ProcessExtension(&inst);
|
||||
ProcessConstant(&inst);
|
||||
}
|
||||
}
|
||||
|
||||
// Collects OpCapability statistics.
|
||||
void ProcessCapability(const val::Instruction* inst) {
|
||||
if (inst->opcode() != SpvOpCapability) return;
|
||||
const uint32_t capability = inst->word(inst->operands()[0].offset);
|
||||
++stats_->capability_hist[capability];
|
||||
}
|
||||
|
||||
// Collects OpExtension statistics.
|
||||
void ProcessExtension(const val::Instruction* inst) {
|
||||
if (inst->opcode() != SpvOpExtension) return;
|
||||
const std::string extension = GetExtensionString(&inst->c_inst());
|
||||
++stats_->extension_hist[extension];
|
||||
}
|
||||
|
||||
// Collects OpCode statistics.
|
||||
void ProcessOpcode(const val::Instruction* inst, size_t idx) {
|
||||
const SpvOp opcode = inst->opcode();
|
||||
++stats_->opcode_hist[opcode];
|
||||
|
||||
if (idx == 0) return;
|
||||
|
||||
--idx;
|
||||
|
||||
const auto& instructions = vstate_->ordered_instructions();
|
||||
|
||||
auto step_it = stats_->opcode_markov_hist.begin();
|
||||
for (; step_it != stats_->opcode_markov_hist.end(); --idx, ++step_it) {
|
||||
auto& hist = (*step_it)[instructions[idx].opcode()];
|
||||
++hist[opcode];
|
||||
|
||||
if (idx == 0) break;
|
||||
}
|
||||
}
|
||||
|
||||
// Collects OpConstant statistics.
|
||||
void ProcessConstant(const val::Instruction* inst) {
|
||||
if (inst->opcode() != SpvOpConstant) return;
|
||||
|
||||
const uint32_t type_id = inst->GetOperandAs<uint32_t>(0);
|
||||
const auto type_decl_it = vstate_->all_definitions().find(type_id);
|
||||
assert(type_decl_it != vstate_->all_definitions().end());
|
||||
|
||||
const val::Instruction& type_decl_inst = *type_decl_it->second;
|
||||
const SpvOp type_op = type_decl_inst.opcode();
|
||||
if (type_op == SpvOpTypeInt) {
|
||||
const uint32_t bit_width = type_decl_inst.GetOperandAs<uint32_t>(1);
|
||||
const uint32_t is_signed = type_decl_inst.GetOperandAs<uint32_t>(2);
|
||||
assert(is_signed == 0 || is_signed == 1);
|
||||
if (bit_width == 16) {
|
||||
if (is_signed)
|
||||
++stats_->s16_constant_hist[inst->GetOperandAs<int16_t>(2)];
|
||||
else
|
||||
++stats_->u16_constant_hist[inst->GetOperandAs<uint16_t>(2)];
|
||||
} else if (bit_width == 32) {
|
||||
if (is_signed)
|
||||
++stats_->s32_constant_hist[inst->GetOperandAs<int32_t>(2)];
|
||||
else
|
||||
++stats_->u32_constant_hist[inst->GetOperandAs<uint32_t>(2)];
|
||||
} else if (bit_width == 64) {
|
||||
if (is_signed)
|
||||
++stats_->s64_constant_hist[inst->GetOperandAs<int64_t>(2)];
|
||||
else
|
||||
++stats_->u64_constant_hist[inst->GetOperandAs<uint64_t>(2)];
|
||||
} else {
|
||||
assert(false && "TypeInt bit width is not 16, 32 or 64");
|
||||
}
|
||||
} else if (type_op == SpvOpTypeFloat) {
|
||||
const uint32_t bit_width = type_decl_inst.GetOperandAs<uint32_t>(1);
|
||||
if (bit_width == 32) {
|
||||
++stats_->f32_constant_hist[inst->GetOperandAs<float>(2)];
|
||||
} else if (bit_width == 64) {
|
||||
++stats_->f64_constant_hist[inst->GetOperandAs<double>(2)];
|
||||
} else {
|
||||
assert(bit_width == 16);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
SpirvStats* stats_;
|
||||
const val::ValidationState_t* vstate_;
|
||||
IdDescriptorCollection id_descriptors_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
spv_result_t AggregateStats(const spv_context context, const uint32_t* words,
|
||||
const size_t num_words, spv_diagnostic* pDiagnostic,
|
||||
SpirvStats* stats) {
|
||||
std::unique_ptr<val::ValidationState_t> vstate;
|
||||
spv_validator_options_t options;
|
||||
spv_result_t result = ValidateBinaryAndKeepValidationState(
|
||||
context, &options, words, num_words, pDiagnostic, &vstate);
|
||||
if (result != SPV_SUCCESS) return result;
|
||||
|
||||
StatsAggregator stats_aggregator(stats, vstate.get());
|
||||
stats_aggregator.aggregate();
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
} // namespace stats
|
||||
} // namespace spvtools
|
@ -1,93 +0,0 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef TOOLS_STATS_SPIRV_STATS_H_
|
||||
#define TOOLS_STATS_SPIRV_STATS_H_
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "spirv-tools/libspirv.hpp"
|
||||
|
||||
namespace spvtools {
|
||||
namespace stats {
|
||||
|
||||
struct SpirvStats {
|
||||
// Version histogram, version_word -> count.
|
||||
std::unordered_map<uint32_t, uint32_t> version_hist;
|
||||
|
||||
// Generator histogram, generator_word -> count.
|
||||
std::unordered_map<uint32_t, uint32_t> generator_hist;
|
||||
|
||||
// Capability histogram, SpvCapabilityXXX -> count.
|
||||
std::unordered_map<uint32_t, uint32_t> capability_hist;
|
||||
|
||||
// Extension histogram, extension_string -> count.
|
||||
std::unordered_map<std::string, uint32_t> extension_hist;
|
||||
|
||||
// Opcode histogram, SpvOpXXX -> count.
|
||||
std::unordered_map<uint32_t, uint32_t> opcode_hist;
|
||||
|
||||
// OpConstant u16 histogram, value -> count.
|
||||
std::unordered_map<uint16_t, uint32_t> u16_constant_hist;
|
||||
|
||||
// OpConstant u32 histogram, value -> count.
|
||||
std::unordered_map<uint32_t, uint32_t> u32_constant_hist;
|
||||
|
||||
// OpConstant u64 histogram, value -> count.
|
||||
std::unordered_map<uint64_t, uint32_t> u64_constant_hist;
|
||||
|
||||
// OpConstant s16 histogram, value -> count.
|
||||
std::unordered_map<int16_t, uint32_t> s16_constant_hist;
|
||||
|
||||
// OpConstant s32 histogram, value -> count.
|
||||
std::unordered_map<int32_t, uint32_t> s32_constant_hist;
|
||||
|
||||
// OpConstant s64 histogram, value -> count.
|
||||
std::unordered_map<int64_t, uint32_t> s64_constant_hist;
|
||||
|
||||
// OpConstant f32 histogram, value -> count.
|
||||
std::unordered_map<float, uint32_t> f32_constant_hist;
|
||||
|
||||
// OpConstant f64 histogram, value -> count.
|
||||
std::unordered_map<double, uint32_t> f64_constant_hist;
|
||||
|
||||
// Used to collect statistics on opcodes triggering other opcodes.
|
||||
// Container scheme: gap between instructions -> cue opcode -> later opcode
|
||||
// -> count.
|
||||
// For example opcode_markov_hist[2][OpFMul][OpFAdd] corresponds to
|
||||
// the number of times an OpMul appears, followed by 2 other instructions,
|
||||
// followed by OpFAdd.
|
||||
// opcode_markov_hist[0][OpFMul][OpFAdd] corresponds to how many times
|
||||
// OpFMul appears, directly followed by OpFAdd.
|
||||
// The size of the outer std::vector also serves as an input parameter,
|
||||
// determining how many steps will be collected.
|
||||
// I.e. do opcode_markov_hist.resize(1) to collect data for one step only.
|
||||
std::vector<
|
||||
std::unordered_map<uint32_t, std::unordered_map<uint32_t, uint32_t>>>
|
||||
opcode_markov_hist;
|
||||
};
|
||||
|
||||
// Aggregates existing |stats| with new stats extracted from |binary|.
|
||||
spv_result_t AggregateStats(const spv_context context, const uint32_t* words,
|
||||
const size_t num_words, spv_diagnostic* pDiagnostic,
|
||||
SpirvStats* stats);
|
||||
|
||||
} // namespace stats
|
||||
} // namespace spvtools
|
||||
|
||||
#endif // TOOLS_STATS_SPIRV_STATS_H_
|
@ -1,173 +0,0 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cassert>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "spirv-tools/libspirv.h"
|
||||
#include "tools/io.h"
|
||||
#include "tools/stats/spirv_stats.h"
|
||||
#include "tools/stats/stats_analyzer.h"
|
||||
|
||||
namespace {
|
||||
|
||||
void PrintUsage(char* argv0) {
|
||||
printf(
|
||||
R"(%s - Collect statistics from one or more SPIR-V binary file(s).
|
||||
|
||||
USAGE: %s [options] [<filepaths>]
|
||||
|
||||
TIP: In order to collect statistics from all .spv files under current dir use
|
||||
find . -name "*.spv" -print0 | xargs -0 -s 2000000 %s
|
||||
|
||||
Options:
|
||||
-h, --help
|
||||
Print this help.
|
||||
|
||||
-v, --verbose
|
||||
Print additional info to stderr.
|
||||
)",
|
||||
argv0, argv0, argv0);
|
||||
}
|
||||
|
||||
void DiagnosticsMessageHandler(spv_message_level_t level, const char*,
|
||||
const spv_position_t& position,
|
||||
const char* message) {
|
||||
switch (level) {
|
||||
case SPV_MSG_FATAL:
|
||||
case SPV_MSG_INTERNAL_ERROR:
|
||||
case SPV_MSG_ERROR:
|
||||
std::cerr << "error: " << position.index << ": " << message << std::endl;
|
||||
break;
|
||||
case SPV_MSG_WARNING:
|
||||
std::cout << "warning: " << position.index << ": " << message
|
||||
<< std::endl;
|
||||
break;
|
||||
case SPV_MSG_INFO:
|
||||
std::cout << "info: " << position.index << ": " << message << std::endl;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
bool continue_processing = true;
|
||||
int return_code = 0;
|
||||
|
||||
bool expect_output_path = false;
|
||||
bool verbose = false;
|
||||
|
||||
std::vector<const char*> paths;
|
||||
const char* output_path = nullptr;
|
||||
|
||||
for (int argi = 1; continue_processing && argi < argc; ++argi) {
|
||||
const char* cur_arg = argv[argi];
|
||||
if ('-' == cur_arg[0]) {
|
||||
if (0 == strcmp(cur_arg, "--help") || 0 == strcmp(cur_arg, "-h")) {
|
||||
PrintUsage(argv[0]);
|
||||
continue_processing = false;
|
||||
return_code = 0;
|
||||
} else if (0 == strcmp(cur_arg, "--verbose") ||
|
||||
0 == strcmp(cur_arg, "-v")) {
|
||||
verbose = true;
|
||||
} else if (0 == strcmp(cur_arg, "--output") ||
|
||||
0 == strcmp(cur_arg, "-o")) {
|
||||
expect_output_path = true;
|
||||
} else {
|
||||
PrintUsage(argv[0]);
|
||||
continue_processing = false;
|
||||
return_code = 1;
|
||||
}
|
||||
} else {
|
||||
if (expect_output_path) {
|
||||
output_path = cur_arg;
|
||||
expect_output_path = false;
|
||||
} else {
|
||||
paths.push_back(cur_arg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Exit if command line parsing was not successful.
|
||||
if (!continue_processing) {
|
||||
return return_code;
|
||||
}
|
||||
|
||||
std::cerr << "Processing " << paths.size() << " files..." << std::endl;
|
||||
|
||||
spvtools::Context ctx(SPV_ENV_UNIVERSAL_1_1);
|
||||
ctx.SetMessageConsumer(DiagnosticsMessageHandler);
|
||||
|
||||
spvtools::stats::SpirvStats stats;
|
||||
stats.opcode_markov_hist.resize(1);
|
||||
|
||||
for (size_t index = 0; index < paths.size(); ++index) {
|
||||
const size_t kMilestonePeriod = 1000;
|
||||
if (verbose) {
|
||||
if (index % kMilestonePeriod == kMilestonePeriod - 1)
|
||||
std::cerr << "Processed " << index + 1 << " files..." << std::endl;
|
||||
}
|
||||
|
||||
const char* path = paths[index];
|
||||
std::vector<uint32_t> contents;
|
||||
if (!ReadFile<uint32_t>(path, "rb", &contents)) return 1;
|
||||
|
||||
if (SPV_SUCCESS !=
|
||||
spvtools::stats::AggregateStats(ctx.CContext(), contents.data(),
|
||||
contents.size(), nullptr, &stats)) {
|
||||
std::cerr << "error: Failed to aggregate stats for " << path << std::endl;
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
spvtools::stats::StatsAnalyzer analyzer(stats);
|
||||
|
||||
std::ofstream fout;
|
||||
if (output_path) {
|
||||
fout.open(output_path);
|
||||
if (!fout.is_open()) {
|
||||
std::cerr << "error: Failed to open " << output_path << std::endl;
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
std::ostream& out = fout.is_open() ? fout : std::cout;
|
||||
out << std::endl;
|
||||
analyzer.WriteVersion(out);
|
||||
analyzer.WriteGenerator(out);
|
||||
|
||||
out << std::endl;
|
||||
analyzer.WriteCapability(out);
|
||||
|
||||
out << std::endl;
|
||||
analyzer.WriteExtension(out);
|
||||
|
||||
out << std::endl;
|
||||
analyzer.WriteOpcode(out);
|
||||
|
||||
out << std::endl;
|
||||
analyzer.WriteOpcodeMarkov(out);
|
||||
|
||||
out << std::endl;
|
||||
analyzer.WriteConstantLiterals(out);
|
||||
|
||||
return 0;
|
||||
}
|
@ -1,235 +0,0 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "tools/stats/stats_analyzer.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "source/comp/markv_model.h"
|
||||
#include "source/enum_string_mapping.h"
|
||||
#include "source/latest_version_spirv_header.h"
|
||||
#include "source/opcode.h"
|
||||
#include "source/operand.h"
|
||||
#include "source/spirv_constant.h"
|
||||
|
||||
namespace spvtools {
|
||||
namespace stats {
|
||||
namespace {
|
||||
|
||||
// Signals that the value is not in the coding scheme and a fallback method
|
||||
// needs to be used.
|
||||
const uint64_t kMarkvNoneOfTheAbove =
|
||||
comp::MarkvModel::GetMarkvNoneOfTheAbove();
|
||||
|
||||
std::string GetVersionString(uint32_t word) {
|
||||
std::stringstream ss;
|
||||
ss << "Version " << SPV_SPIRV_VERSION_MAJOR_PART(word) << "."
|
||||
<< SPV_SPIRV_VERSION_MINOR_PART(word);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
std::string GetGeneratorString(uint32_t word) {
|
||||
return spvGeneratorStr(SPV_GENERATOR_TOOL_PART(word));
|
||||
}
|
||||
|
||||
std::string GetOpcodeString(uint32_t word) {
|
||||
return spvOpcodeString(static_cast<SpvOp>(word));
|
||||
}
|
||||
|
||||
std::string GetCapabilityString(uint32_t word) {
|
||||
return CapabilityToString(static_cast<SpvCapability>(word));
|
||||
}
|
||||
|
||||
template <class T>
|
||||
std::string KeyIsLabel(T key) {
|
||||
std::stringstream ss;
|
||||
ss << key;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
template <class Key>
|
||||
std::unordered_map<Key, double> GetRecall(
|
||||
const std::unordered_map<Key, uint32_t>& hist, uint64_t total) {
|
||||
std::unordered_map<Key, double> freq;
|
||||
for (const auto& pair : hist) {
|
||||
const double frequency =
|
||||
static_cast<double>(pair.second) / static_cast<double>(total);
|
||||
freq.emplace(pair.first, frequency);
|
||||
}
|
||||
return freq;
|
||||
}
|
||||
|
||||
template <class Key>
|
||||
std::unordered_map<Key, double> GetPrevalence(
|
||||
const std::unordered_map<Key, uint32_t>& hist) {
|
||||
uint64_t total = 0;
|
||||
for (const auto& pair : hist) {
|
||||
total += pair.second;
|
||||
}
|
||||
|
||||
return GetRecall(hist, total);
|
||||
}
|
||||
|
||||
// Writes |freq| to |out| sorted by frequency in the following format:
|
||||
// LABEL3 70%
|
||||
// LABEL1 20%
|
||||
// LABEL2 10%
|
||||
// |label_from_key| is used to convert |Key| to label.
|
||||
template <class Key>
|
||||
void WriteFreq(std::ostream& out, const std::unordered_map<Key, double>& freq,
|
||||
std::string (*label_from_key)(Key)) {
|
||||
std::vector<std::pair<Key, double>> sorted_freq(freq.begin(), freq.end());
|
||||
std::sort(sorted_freq.begin(), sorted_freq.end(),
|
||||
[](const std::pair<Key, double>& left,
|
||||
const std::pair<Key, double>& right) {
|
||||
return left.second > right.second;
|
||||
});
|
||||
|
||||
for (const auto& pair : sorted_freq) {
|
||||
if (pair.second < 0.001) break;
|
||||
out << label_from_key(pair.first) << " " << pair.second * 100.0 << "%"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
StatsAnalyzer::StatsAnalyzer(const SpirvStats& stats) : stats_(stats) {
|
||||
num_modules_ = 0;
|
||||
for (const auto& pair : stats_.version_hist) {
|
||||
num_modules_ += pair.second;
|
||||
}
|
||||
|
||||
version_freq_ = GetRecall(stats_.version_hist, num_modules_);
|
||||
generator_freq_ = GetRecall(stats_.generator_hist, num_modules_);
|
||||
capability_freq_ = GetRecall(stats_.capability_hist, num_modules_);
|
||||
extension_freq_ = GetRecall(stats_.extension_hist, num_modules_);
|
||||
opcode_freq_ = GetPrevalence(stats_.opcode_hist);
|
||||
}
|
||||
|
||||
void StatsAnalyzer::WriteVersion(std::ostream& out) {
|
||||
WriteFreq(out, version_freq_, GetVersionString);
|
||||
}
|
||||
|
||||
void StatsAnalyzer::WriteGenerator(std::ostream& out) {
|
||||
WriteFreq(out, generator_freq_, GetGeneratorString);
|
||||
}
|
||||
|
||||
void StatsAnalyzer::WriteCapability(std::ostream& out) {
|
||||
WriteFreq(out, capability_freq_, GetCapabilityString);
|
||||
}
|
||||
|
||||
void StatsAnalyzer::WriteExtension(std::ostream& out) {
|
||||
WriteFreq(out, extension_freq_, KeyIsLabel);
|
||||
}
|
||||
|
||||
void StatsAnalyzer::WriteOpcode(std::ostream& out) {
|
||||
out << "Total unique opcodes used: " << opcode_freq_.size() << std::endl;
|
||||
WriteFreq(out, opcode_freq_, GetOpcodeString);
|
||||
}
|
||||
|
||||
void StatsAnalyzer::WriteConstantLiterals(std::ostream& out) {
|
||||
out << "Constant literals" << std::endl;
|
||||
|
||||
out << "Float 32" << std::endl;
|
||||
WriteFreq(out, GetPrevalence(stats_.f32_constant_hist), KeyIsLabel);
|
||||
|
||||
out << std::endl << "Float 64" << std::endl;
|
||||
WriteFreq(out, GetPrevalence(stats_.f64_constant_hist), KeyIsLabel);
|
||||
|
||||
out << std::endl << "Unsigned int 16" << std::endl;
|
||||
WriteFreq(out, GetPrevalence(stats_.u16_constant_hist), KeyIsLabel);
|
||||
|
||||
out << std::endl << "Signed int 16" << std::endl;
|
||||
WriteFreq(out, GetPrevalence(stats_.s16_constant_hist), KeyIsLabel);
|
||||
|
||||
out << std::endl << "Unsigned int 32" << std::endl;
|
||||
WriteFreq(out, GetPrevalence(stats_.u32_constant_hist), KeyIsLabel);
|
||||
|
||||
out << std::endl << "Signed int 32" << std::endl;
|
||||
WriteFreq(out, GetPrevalence(stats_.s32_constant_hist), KeyIsLabel);
|
||||
|
||||
out << std::endl << "Unsigned int 64" << std::endl;
|
||||
WriteFreq(out, GetPrevalence(stats_.u64_constant_hist), KeyIsLabel);
|
||||
|
||||
out << std::endl << "Signed int 64" << std::endl;
|
||||
WriteFreq(out, GetPrevalence(stats_.s64_constant_hist), KeyIsLabel);
|
||||
}
|
||||
|
||||
void StatsAnalyzer::WriteOpcodeMarkov(std::ostream& out) {
|
||||
if (stats_.opcode_markov_hist.empty()) return;
|
||||
|
||||
const std::unordered_map<uint32_t, std::unordered_map<uint32_t, uint32_t>>&
|
||||
cue_to_hist = stats_.opcode_markov_hist[0];
|
||||
|
||||
// Sort by prevalence of the opcodes in opcode_freq_ (descending).
|
||||
std::vector<std::pair<uint32_t, std::unordered_map<uint32_t, uint32_t>>>
|
||||
sorted_cue_to_hist(cue_to_hist.begin(), cue_to_hist.end());
|
||||
std::sort(
|
||||
sorted_cue_to_hist.begin(), sorted_cue_to_hist.end(),
|
||||
[this](const std::pair<uint32_t, std::unordered_map<uint32_t, uint32_t>>&
|
||||
left,
|
||||
const std::pair<uint32_t, std::unordered_map<uint32_t, uint32_t>>&
|
||||
right) {
|
||||
const double lf = opcode_freq_[left.first];
|
||||
const double rf = opcode_freq_[right.first];
|
||||
if (lf == rf) return right.first > left.first;
|
||||
return lf > rf;
|
||||
});
|
||||
|
||||
for (const auto& kv : sorted_cue_to_hist) {
|
||||
const uint32_t cue = kv.first;
|
||||
const double kFrequentEnoughToAnalyze = 0.0001;
|
||||
if (opcode_freq_[cue] < kFrequentEnoughToAnalyze) continue;
|
||||
|
||||
const std::unordered_map<uint32_t, uint32_t>& hist = kv.second;
|
||||
|
||||
uint32_t total = 0;
|
||||
for (const auto& pair : hist) {
|
||||
total += pair.second;
|
||||
}
|
||||
|
||||
std::vector<std::pair<uint32_t, uint32_t>> sorted_hist(hist.begin(),
|
||||
hist.end());
|
||||
std::sort(sorted_hist.begin(), sorted_hist.end(),
|
||||
[](const std::pair<uint32_t, uint32_t>& left,
|
||||
const std::pair<uint32_t, uint32_t>& right) {
|
||||
if (left.second == right.second)
|
||||
return right.first > left.first;
|
||||
return left.second > right.second;
|
||||
});
|
||||
|
||||
for (const auto& pair : sorted_hist) {
|
||||
const double prior = opcode_freq_[pair.first];
|
||||
const double posterior =
|
||||
static_cast<double>(pair.second) / static_cast<double>(total);
|
||||
out << GetOpcodeString(cue) << " -> " << GetOpcodeString(pair.first)
|
||||
<< " " << posterior * 100 << "% (base rate " << prior * 100
|
||||
<< "%, pair occurrences " << pair.second << ")" << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace stats
|
||||
} // namespace spvtools
|
@ -1,58 +0,0 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef TOOLS_STATS_STATS_ANALYZER_H_
|
||||
#define TOOLS_STATS_STATS_ANALYZER_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tools/stats/spirv_stats.h"
|
||||
|
||||
namespace spvtools {
|
||||
namespace stats {
|
||||
|
||||
class StatsAnalyzer {
|
||||
public:
|
||||
explicit StatsAnalyzer(const SpirvStats& stats);
|
||||
|
||||
// Writes respective histograms to |out|.
|
||||
void WriteVersion(std::ostream& out);
|
||||
void WriteGenerator(std::ostream& out);
|
||||
void WriteCapability(std::ostream& out);
|
||||
void WriteExtension(std::ostream& out);
|
||||
void WriteOpcode(std::ostream& out);
|
||||
void WriteConstantLiterals(std::ostream& out);
|
||||
|
||||
// Writes first order Markov analysis to |out|.
|
||||
// stats_.opcode_markov_hist needs to contain raw data for at least one
|
||||
// level.
|
||||
void WriteOpcodeMarkov(std::ostream& out);
|
||||
|
||||
private:
|
||||
const SpirvStats& stats_;
|
||||
|
||||
uint32_t num_modules_;
|
||||
|
||||
std::unordered_map<uint32_t, double> version_freq_;
|
||||
std::unordered_map<uint32_t, double> generator_freq_;
|
||||
std::unordered_map<uint32_t, double> capability_freq_;
|
||||
std::unordered_map<std::string, double> extension_freq_;
|
||||
std::unordered_map<uint32_t, double> opcode_freq_;
|
||||
};
|
||||
|
||||
} // namespace stats
|
||||
} // namespace spvtools
|
||||
|
||||
#endif // TOOLS_STATS_STATS_ANALYZER_H_
|
Loading…
Reference in New Issue
Block a user