Address feedback

This commit is contained in:
Henrik Rydgård 2023-02-14 11:11:34 +01:00
parent 7f6e6ae985
commit a1e1386e0b
2 changed files with 25 additions and 15 deletions

View File

@ -255,7 +255,6 @@ public:
}
}
// Load file template
template<class T>
static Error Load(const Path &filename, std::string *gitVersion, T& _class, std::string *failureReason)

View File

@ -183,13 +183,22 @@ namespace SaveState
return CChunkFileReader::ERROR_BAD_FILE;
static std::vector<u8> buffer;
LockedDecompress(buffer, states_[n], bases_[baseMapping_[n]]);
CChunkFileReader::Error error = LoadFromRam(buffer, errorString);
rewindLastTime_ = time_now_d();
return error;
if (LockedDecompress(buffer, states_[n], bases_[baseMapping_[n]])) {
CChunkFileReader::Error error = LoadFromRam(buffer, errorString);
if (error == CChunkFileReader::ERROR_NONE) {
INFO_LOG(SAVESTATE, "Rewinding to recent savestate snapshot (%d bytes compressed)", states_[n].zstd_compressed.size());
rewindLastTime_ = time_now_d();
}
return error;
} else {
WARN_LOG(SAVESTATE, "Failed to load rewind savestate");
// Unclear what CChunkFileReader error code we should pass in this case, which I'm not sure will
// happen in practice barring memory corruption.
}
return CChunkFileReader::ERROR_NONE;
}
void ScheduleCompress(StateBuffer *result, std::vector<u8> *state, const std::vector<u8> *base)
void ScheduleCompress(StateBuffer *result, const std::vector<u8> *state, const std::vector<u8> *base)
{
if (compressThread_.joinable())
compressThread_.join();
@ -203,7 +212,7 @@ namespace SaveState
const bool USE_XOR = false;
void Compress(StateBuffer *result, std::vector<u8> &state, const std::vector<u8> &base)
void Compress(StateBuffer *result, const std::vector<u8> &state, const std::vector<u8> &base)
{
std::lock_guard<std::mutex> guard(lock_);
// Bail if we were cleared before locking.
@ -241,20 +250,18 @@ namespace SaveState
// Temporarily allocate a buffer to do compression in.
size_t compressCapacity = ZSTD_compressBound(compressed.size());
u8 *compress_buf = (u8 *)malloc(compressCapacity);
result->compressed_size = ZSTD_compress(compress_buf, compressCapacity, compressed.data(), compressed.size(), 0);
result->zstd_compressed.resize(compressCapacity);
result->compressed_size = ZSTD_compress(&result->zstd_compressed[0], compressCapacity, compressed.data(), compressed.size(), 1);
if (result->compressed_size) {
result->zstd_compressed = std::vector<u8>(result->compressed_size, 0);
memcpy(&result->zstd_compressed[0], compress_buf, result->compressed_size);
result->zstd_compressed.resize(result->compressed_size);
result->decompressed_size = compressed.size();
}
free(compress_buf);
double zstd_s = time_now_d() - start_time - taken_s;
DEBUG_LOG(SAVESTATE, "Rewind: ZSTD compressed to %d in %0.2f ms.", (int)result->compressed_size, zstd_s * 1000.0);
}
void LockedDecompress(std::vector<u8> &result, const StateBuffer &buffer, const std::vector<u8> &base)
bool LockedDecompress(std::vector<u8> &result, const StateBuffer &buffer, const std::vector<u8> &base)
{
result.clear();
result.reserve(base.size());
@ -262,7 +269,11 @@ namespace SaveState
// OK, zstd decompress first.
std::vector<u8> compressed = std::vector<u8>(buffer.decompressed_size, 0);
ZSTD_decompress(&compressed[0], compressed.size(), buffer.zstd_compressed.data(), buffer.zstd_compressed.size());
size_t retval = ZSTD_decompress(&compressed[0], compressed.size(), buffer.zstd_compressed.data(), buffer.zstd_compressed.size());
if (ZSTD_isError(retval)) {
WARN_LOG(SAVESTATE, "Failed to decompress zstd-compressed rewind savestate");
return false;
}
if (USE_XOR) {
result.resize(compressed.size());
@ -296,6 +307,7 @@ namespace SaveState
}
}
}
return true;
}
void Clear()
@ -1072,7 +1084,6 @@ namespace SaveState
break;
case SAVESTATE_REWIND:
INFO_LOG(SAVESTATE, "Rewinding to recent savestate snapshot");
result = rewindStates.Restore(&errorString);
if (result == CChunkFileReader::ERROR_NONE) {
callbackMessage = sc->T("Loaded State");