CommonWindows: Split out code invalidation logic from WOW64

This will also be used by FEX ARM64EC module.
This commit is contained in:
Billy Laws 2023-11-17 23:59:37 +00:00
parent 1115ce4a95
commit 9f311cd97e
5 changed files with 137 additions and 103 deletions

View File

@ -1,4 +1,4 @@
add_library(CommonWindows STATIC CPUFeatures.cpp)
add_library(CommonWindows STATIC CPUFeatures.cpp InvalidationTracker.cpp)
target_link_libraries(CommonWindows FEXCore_Base)
target_include_directories(CommonWindows PRIVATE

View File

@ -0,0 +1,98 @@
// SPDX-License-Identifier: MIT
#include <FEXCore/Utils/LogManager.h>
#include <FEXCore/Core/Context.h>
#include <FEXCore/Debug/InternalThreadState.h>
#include "InvalidationTracker.h"
#include <windef.h>
#include <winternl.h>
namespace FEX::Windows {
void InvalidationTracker::HandleMemoryProtectionNotification(FEXCore::Core::InternalThreadState *Thread, uint64_t Address, uint64_t Size, ULONG Prot) {
const auto AlignedBase = Address & FHU::FEX_PAGE_MASK;
const auto AlignedSize = (Address - AlignedBase + Size + FHU::FEX_PAGE_SIZE - 1) & FHU::FEX_PAGE_MASK;
if (Prot & (PAGE_EXECUTE | PAGE_EXECUTE_READ | PAGE_EXECUTE_READWRITE)) {
Thread->CTX->InvalidateGuestCodeRange(Thread, AlignedBase, AlignedSize);
}
if (Prot & PAGE_EXECUTE_READWRITE) {
LogMan::Msg::DFmt("Add SMC interval: {:X} - {:X}", AlignedBase, AlignedBase + AlignedSize);
std::scoped_lock Lock(RWXIntervalsLock);
RWXIntervals.Insert({AlignedBase, AlignedBase + AlignedSize});
} else {
std::scoped_lock Lock(RWXIntervalsLock);
RWXIntervals.Remove({AlignedBase, AlignedBase + AlignedSize});
}
}
void InvalidationTracker::InvalidateContainingSection(FEXCore::Core::InternalThreadState *Thread, uint64_t Address, bool Free) {
MEMORY_BASIC_INFORMATION Info;
if (NtQueryVirtualMemory(NtCurrentProcess(), reinterpret_cast<void *>(Address), MemoryBasicInformation, &Info, sizeof(Info), nullptr))
return;
const auto SectionBase = reinterpret_cast<uint64_t>(Info.AllocationBase);
const auto SectionSize = reinterpret_cast<uint64_t>(Info.BaseAddress) + Info.RegionSize
- reinterpret_cast<uint64_t>(Info.AllocationBase);
Thread->CTX->InvalidateGuestCodeRange(Thread, SectionBase, SectionSize);
if (Free) {
std::scoped_lock Lock(RWXIntervalsLock);
RWXIntervals.Remove({SectionBase, SectionBase + SectionSize});
}
}
void InvalidationTracker::InvalidateAlignedInterval(FEXCore::Core::InternalThreadState *Thread, uint64_t Address, uint64_t Size, bool Free) {
const auto AlignedBase = Address & FHU::FEX_PAGE_MASK;
const auto AlignedSize = (Address - AlignedBase + Size + FHU::FEX_PAGE_SIZE - 1) & FHU::FEX_PAGE_MASK;
Thread->CTX->InvalidateGuestCodeRange(Thread, AlignedBase, AlignedSize);
if (Free) {
std::scoped_lock Lock(RWXIntervalsLock);
RWXIntervals.Remove({AlignedBase, AlignedBase + AlignedSize});
}
}
void InvalidationTracker::ReprotectRWXIntervals(uint64_t Address, uint64_t Size) {
const auto End = Address + Size;
std::scoped_lock Lock(RWXIntervalsLock);
do {
const auto Query = RWXIntervals.Query(Address);
if (Query.Enclosed) {
void *TmpAddress = reinterpret_cast<void *>(Address);
SIZE_T TmpSize = static_cast<SIZE_T>(std::min(End, Address + Query.Size) - Address);
ULONG TmpProt;
NtProtectVirtualMemory(NtCurrentProcess(), &TmpAddress, &TmpSize, PAGE_EXECUTE_READ, &TmpProt);
} else if (!Query.Size) {
// No more regions past `Address` in the interval list
break;
}
Address += Query.Size;
} while (Address < End);
}
bool InvalidationTracker::HandleRWXAccessViolation(FEXCore::Core::InternalThreadState *Thread, uint64_t FaultAddress) {
const bool NeedsInvalidate = [&](uint64_t Address) {
std::unique_lock Lock(RWXIntervalsLock);
const bool Enclosed = RWXIntervals.Query(Address).Enclosed;
// Invalidate just the single faulting page
if (!Enclosed)
return false;
ULONG TmpProt;
void *TmpAddress = reinterpret_cast<void *>(Address);
SIZE_T TmpSize = 1;
NtProtectVirtualMemory(NtCurrentProcess(), &TmpAddress, &TmpSize, PAGE_EXECUTE_READWRITE, &TmpProt);
return true;
}(FaultAddress);
if (NeedsInvalidate) {
// RWXIntervalsLock cannot be held during invalidation
Thread->CTX->InvalidateGuestCodeRange(Thread, FaultAddress & FHU::FEX_PAGE_MASK, FHU::FEX_PAGE_SIZE);
return true;
}
return false;
}
}

View File

@ -0,0 +1,28 @@
// SPDX-License-Identifier: MIT
// FIXME TODO put in cpp
#pragma once
#include "IntervalList.h"
#include <mutex>
namespace FEXCore::Core {
struct InternalThreadState;
}
namespace FEX::Windows {
/**
* @brief Handles SMC and regular code invalidation
*/
class InvalidationTracker {
public:
void HandleMemoryProtectionNotification(FEXCore::Core::InternalThreadState *Thread, uint64_t Address, uint64_t Size, ULONG Prot);
void InvalidateContainingSection(FEXCore::Core::InternalThreadState *Thread, uint64_t Address, bool Free);
void InvalidateAlignedInterval(FEXCore::Core::InternalThreadState *Thread, uint64_t Address, uint64_t Size, bool Free);
void ReprotectRWXIntervals(uint64_t Address, uint64_t Size);
bool HandleRWXAccessViolation(FEXCore::Core::InternalThreadState *Thread, uint64_t FaultAddress);
private:
IntervalList<uint64_t> RWXIntervals;
std::mutex RWXIntervalsLock;
};
}

View File

@ -26,10 +26,10 @@ $end_info$
#include <FEXHeaderUtils/TypeDefines.h>
#include "Common/Config.h"
#include "Common/InvalidationTracker.h"
#include "Common/CPUFeatures.h"
#include "DummyHandlers.h"
#include "BTInterface.h"
#include "IntervalList.h"
#include <cstdint>
#include <type_traits>
@ -93,6 +93,7 @@ namespace {
fextl::unique_ptr<FEX::DummyHandlers::DummySignalDelegator> SignalDelegator;
fextl::unique_ptr<WowSyscallHandler> SyscallHandler;
FEX::Windows::InvalidationTracker InvalidationTracker;
std::optional<FEX::Windows::CPUFeatures> CPUFeatures;
std::mutex ThreadSuspendLock;
@ -327,99 +328,6 @@ namespace Context {
}
}
namespace Invalidation {
static IntervalList<uint64_t> RWXIntervals;
static std::mutex RWXIntervalsLock;
void HandleMemoryProtectionNotification(uint64_t Address, uint64_t Size, ULONG Prot) {
const auto AlignedBase = Address & FHU::FEX_PAGE_MASK;
const auto AlignedSize = (Address - AlignedBase + Size + FHU::FEX_PAGE_SIZE - 1) & FHU::FEX_PAGE_MASK;
if (Prot & (PAGE_EXECUTE | PAGE_EXECUTE_READ | PAGE_EXECUTE_READWRITE)) {
CTX->InvalidateGuestCodeRange(GetTLS().ThreadState(), AlignedBase, AlignedSize);
}
if (Prot & PAGE_EXECUTE_READWRITE) {
LogMan::Msg::DFmt("Add SMC interval: {:X} - {:X}", AlignedBase, AlignedBase + AlignedSize);
std::scoped_lock Lock(RWXIntervalsLock);
RWXIntervals.Insert({AlignedBase, AlignedBase + AlignedSize});
} else {
std::scoped_lock Lock(RWXIntervalsLock);
RWXIntervals.Remove({AlignedBase, AlignedBase + AlignedSize});
}
}
void InvalidateContainingSection(uint64_t Address, bool Free) {
MEMORY_BASIC_INFORMATION Info;
if (NtQueryVirtualMemory(NtCurrentProcess(), reinterpret_cast<void *>(Address), MemoryBasicInformation, &Info, sizeof(Info), nullptr))
return;
const auto SectionBase = reinterpret_cast<uint64_t>(Info.AllocationBase);
const auto SectionSize = reinterpret_cast<uint64_t>(Info.BaseAddress) + Info.RegionSize
- reinterpret_cast<uint64_t>(Info.AllocationBase);
CTX->InvalidateGuestCodeRange(GetTLS().ThreadState(), SectionBase, SectionSize);
if (Free) {
std::scoped_lock Lock(RWXIntervalsLock);
RWXIntervals.Remove({SectionBase, SectionBase + SectionSize});
}
}
void InvalidateAlignedInterval(uint64_t Address, uint64_t Size, bool Free) {
const auto AlignedBase = Address & FHU::FEX_PAGE_MASK;
const auto AlignedSize = (Address - AlignedBase + Size + FHU::FEX_PAGE_SIZE - 1) & FHU::FEX_PAGE_MASK;
CTX->InvalidateGuestCodeRange(GetTLS().ThreadState(), AlignedBase, AlignedSize);
if (Free) {
std::scoped_lock Lock(RWXIntervalsLock);
RWXIntervals.Remove({AlignedBase, AlignedBase + AlignedSize});
}
}
void ReprotectRWXIntervals(uint64_t Address, uint64_t Size) {
const auto End = Address + Size;
std::scoped_lock Lock(RWXIntervalsLock);
do {
const auto Query = RWXIntervals.Query(Address);
if (Query.Enclosed) {
void *TmpAddress = reinterpret_cast<void *>(Address);
SIZE_T TmpSize = static_cast<SIZE_T>(std::min(End, Address + Query.Size) - Address);
ULONG TmpProt;
NtProtectVirtualMemory(NtCurrentProcess(), &TmpAddress, &TmpSize, PAGE_EXECUTE_READ, &TmpProt);
} else if (!Query.Size) {
// No more regions past `Address` in the interval list
break;
}
Address += Query.Size;
} while (Address < End);
}
bool HandleRWXAccessViolation(uint64_t FaultAddress) {
const bool NeedsInvalidate = [](uint64_t Address) {
std::unique_lock Lock(RWXIntervalsLock);
const bool Enclosed = RWXIntervals.Query(Address).Enclosed;
// Invalidate just the single faulting page
if (!Enclosed)
return false;
ULONG TmpProt;
void *TmpAddress = reinterpret_cast<void *>(Address);
SIZE_T TmpSize = 1;
NtProtectVirtualMemory(NtCurrentProcess(), &TmpAddress, &TmpSize, PAGE_EXECUTE_READWRITE, &TmpProt);
return true;
}(FaultAddress);
if (NeedsInvalidate) {
// RWXIntervalsLock cannot be held during invalidation
CTX->InvalidateGuestCodeRange(GetTLS().ThreadState(), FaultAddress & FHU::FEX_PAGE_MASK, FHU::FEX_PAGE_SIZE);
return true;
}
return false;
}
}
namespace Logging {
void MsgHandler(LogMan::DebugLevels Level, char const *Message) {
const auto Output = fextl::fmt::format("[{}][{:X}] {}\n", LogMan::DebugLevelStr(Level), GetCurrentThreadId(), Message);
@ -491,7 +399,7 @@ public:
}
void MarkGuestExecutableRange(FEXCore::Core::InternalThreadState *Thread, uint64_t Start, uint64_t Length) override {
Invalidation::ReprotectRWXIntervals(Start, Length);
InvalidationTracker.ReprotectRWXIntervals(Start, Length);
}
};
@ -712,7 +620,7 @@ NTSTATUS BTCpuResetToConsistentState(EXCEPTION_POINTERS *Ptrs) {
if (Exception->ExceptionCode == EXCEPTION_ACCESS_VIOLATION) {
const auto FaultAddress = static_cast<uint64_t>(Exception->ExceptionInformation[1]);
if (Invalidation::HandleRWXAccessViolation(FaultAddress)) {
if (InvalidationTracker.HandleRWXAccessViolation(GetTLS().ThreadState(), FaultAddress)) {
LogMan::Msg::DFmt("Handled self-modifying code: pc: {:X} fault: {:X}", Context->Pc, FaultAddress);
NtContinue(Context, FALSE);
}
@ -742,29 +650,29 @@ NTSTATUS BTCpuResetToConsistentState(EXCEPTION_POINTERS *Ptrs) {
}
void BTCpuFlushInstructionCache2(const void *Address, SIZE_T Size) {
Invalidation::InvalidateAlignedInterval(reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size), false);
InvalidationTracker.InvalidateAlignedInterval(GetTLS().ThreadState(), reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size), false);
}
void BTCpuNotifyMemoryAlloc(void *Address, SIZE_T Size, ULONG Type, ULONG Prot) {
Invalidation::HandleMemoryProtectionNotification(reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size),
InvalidationTracker.HandleMemoryProtectionNotification(GetTLS().ThreadState(), reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size),
Prot);
}
void BTCpuNotifyMemoryProtect(void *Address, SIZE_T Size, ULONG NewProt) {
Invalidation::HandleMemoryProtectionNotification(reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size),
InvalidationTracker.HandleMemoryProtectionNotification(GetTLS().ThreadState(), reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size),
NewProt);
}
void BTCpuNotifyMemoryFree(void *Address, SIZE_T Size, ULONG FreeType) {
if (!Size) {
Invalidation::InvalidateContainingSection(reinterpret_cast<uint64_t>(Address), true);
InvalidationTracker.InvalidateContainingSection(GetTLS().ThreadState(), reinterpret_cast<uint64_t>(Address), true);
} else if (FreeType & MEM_DECOMMIT) {
Invalidation::InvalidateAlignedInterval(reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size), true);
InvalidationTracker.InvalidateAlignedInterval(GetTLS().ThreadState(), reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size), true);
}
}
void BTCpuNotifyUnmapViewOfSection(void *Address, ULONG Flags) {
Invalidation::InvalidateContainingSection(reinterpret_cast<uint64_t>(Address), true);
InvalidationTracker.InvalidateContainingSection(GetTLS().ThreadState(), reinterpret_cast<uint64_t>(Address), true);
}
BOOLEAN WINAPI BTCpuIsProcessorFeaturePresent(UINT Feature) {