Savestate: Prepare some sanity checks

This commit is contained in:
Henrik Rydgård 2022-12-02 10:55:03 +01:00
parent 116bc9d59a
commit 237fbca979
2 changed files with 78 additions and 11 deletions

View File

@ -33,8 +33,25 @@ enum class SerializeCompressType {
static constexpr SerializeCompressType SAVE_TYPE = SerializeCompressType::ZSTD;
PointerWrapSection PointerWrap::Section(const char *title, int ver) {
return Section(title, ver, ver);
void PointerWrap::RewindForWrite(u8 *writePtr) {
_assert_(mode == MODE_MEASURE);
// Switch to writing mode and
mode = MODE_WRITE;
*ptr = writePtr;
ptrStart_ = writePtr;
}
bool PointerWrap::CheckAfterWrite() {
_assert_(mode == MODE_WRITE);
if (measuredSize_ != 0 && Offset() != measuredSize_) {
WARN_LOG(SAVESTATE, "CheckAfterWrite: Size mismatch! %d vs %d", (int)Offset(), (int)measuredSize_);
return false;
}
if (!checkpoints_.empty() && curCheckpoint_ != checkpoints_.size()) {
WARN_LOG(SAVESTATE, "Checkpoint count mismatch!");
return false;
}
return true;
}
PointerWrapSection PointerWrap::Section(const char *title, int minVer, int ver) {
@ -44,6 +61,30 @@ PointerWrapSection PointerWrap::Section(const char *title, int minVer, int ver)
// This is strncpy because we rely on its weird non-null-terminating zero-filling truncation behaviour.
// Can't replace it with the more sensible truncate_cpy because that would break savestates.
strncpy(marker, title, sizeof(marker));
// Compare the measure and write passes. Sanity check to catch bugs, doesn't do anything for output.
if (mode == MODE_MEASURE) {
checkpoints_.emplace_back(marker, Offset());
} else if (mode == MODE_WRITE) {
if (!checkpoints_.empty()) {
if (checkpoints_.size() <= curCheckpoint_) {
WARN_LOG(SAVESTATE, "Write: Not enough checkpoints from measure pass (%d). cur section: %s", (int)checkpoints_.size(), title);
SetError(ERROR_FAILURE);
return PointerWrapSection(*this, -1, title);
}
if (!checkpoints_[curCheckpoint_].Matches(marker, Offset())) {
WARN_LOG(SAVESTATE, "Checkpoint mismatch during write! Section %s vs %s, offset %d vs %d", title, marker, (int)Offset(), (int)checkpoints_[curCheckpoint_].offset);
if (curCheckpoint_ > 1) {
WARN_LOG(SAVESTATE, "Previous checkpoint: %s (%d)", checkpoints_[curCheckpoint_ - 1].title, (int)checkpoints_[curCheckpoint_ - 1].offset);
}
SetError(ERROR_FAILURE);
return PointerWrapSection(*this, -1, title);
}
} else {
WARN_LOG(SAVESTATE, "Writing savestate without checkpoints. This is OK but should be fixed.");
}
}
if (!ExpectVoid(marker, sizeof(marker))) {
// Might be before we added name markers for safety.
if (foundVersion == 1 && ExpectVoid(&foundVersion, sizeof(foundVersion))) {

View File

@ -30,6 +30,7 @@
// - Serialization code for anything complex has to be manually written.
#include <string>
#include <cstring>
#include <vector>
#include <cstdlib>
@ -52,8 +53,7 @@ class PointerWrap;
class PointerWrapSection
{
public:
PointerWrapSection(PointerWrap &p, int ver, const char *title) : p_(p), ver_(ver), title_(title) {
}
PointerWrapSection(PointerWrap &p, int ver, const char *title) : p_(p), ver_(ver), title_(title) {}
~PointerWrapSection();
bool operator == (const int &v) const { return ver_ == v; }
@ -73,6 +73,22 @@ private:
const char *title_;
};
// For measure vs write detailed verification
struct SerializeCheckpoint {
char title[17]; // 16-byte section header, plus a zero terminator for debug printing.
size_t offset;
SerializeCheckpoint(char _title[16], size_t off) {
memcpy(title, _title, 16);
title[16] = 0;
offset = off;
}
bool Matches(const char *_title, size_t off) const {
return memcmp(title, _title, 16) == 0 && off == offset;
}
};
// Wrapper class
class PointerWrap
{
@ -94,19 +110,23 @@ public:
Mode mode;
Error error = ERROR_NONE;
PointerWrap(u8 **ptr_, Mode mode_) : ptr(ptr_), mode(mode_) {}
PointerWrap(unsigned char **ptr_, int mode_) : ptr((u8**)ptr_), mode((Mode)mode_) {}
PointerWrap(u8 **ptr_, Mode mode_) : ptr(ptr_), ptrStart_(*ptr), mode(mode_) {}
PointerWrap(unsigned char **ptr_, int mode_) : ptr((u8**)ptr_), ptrStart_(*ptr), mode((Mode)mode_) {}
PointerWrapSection Section(const char *title, int ver);
void RewindForWrite(u8 *writePtr);
bool CheckAfterWrite();
// The returned object can be compared against the version that was loaded.
// This can be used to support versions as old as minVer.
// Version = 0 means the section was not found.
PointerWrapSection Section(const char *title, int minVer, int ver);
PointerWrapSection Section(const char *title, int ver) {
return Section(title, ver, ver);
}
void SetMode(Mode mode_) {mode = mode_;}
Mode GetMode() const {return mode;}
u8 **GetPPtr() {return ptr;}
void SetMode(Mode mode_) { mode = mode_; }
Mode GetMode() const { return mode; }
u8 **GetPPtr() { return ptr; }
void SetError(Error error_);
const char *GetBadSectionTitle() const {
@ -120,7 +140,13 @@ public:
void DoMarker(const char *prevName, u32 arbitraryNumber = 0x42);
private:
size_t Offset() const { return *ptr - ptrStart_; }
const char *firstBadSectionTitle_ = nullptr;
u8 *ptrStart_;
std::vector<SerializeCheckpoint> checkpoints_;
size_t curCheckpoint_ = 0;
size_t measuredSize_ = 0;
};
class CChunkFileReader
@ -204,7 +230,7 @@ public:
return ERROR_BAD_ALLOC;
Error error = SavePtr(buffer, _class, sz);
// SaveFile takes ownership of buffer
// SaveFile takes ownership of buffer (malloc/free)
if (error == ERROR_NONE)
error = SaveFile(filename, title, gitVersion, buffer, sz);
return error;