diff --git a/Source/Windows/Common/CMakeLists.txt b/Source/Windows/Common/CMakeLists.txt index 680d8f747..4b5b47b96 100644 --- a/Source/Windows/Common/CMakeLists.txt +++ b/Source/Windows/Common/CMakeLists.txt @@ -1,4 +1,4 @@ -add_library(CommonWindows STATIC CPUFeatures.cpp) +add_library(CommonWindows STATIC CPUFeatures.cpp InvalidationTracker.cpp) target_link_libraries(CommonWindows FEXCore_Base) target_include_directories(CommonWindows PRIVATE diff --git a/Source/Windows/WOW64/IntervalList.h b/Source/Windows/Common/IntervalList.h similarity index 100% rename from Source/Windows/WOW64/IntervalList.h rename to Source/Windows/Common/IntervalList.h diff --git a/Source/Windows/Common/InvalidationTracker.cpp b/Source/Windows/Common/InvalidationTracker.cpp new file mode 100644 index 000000000..429352c2d --- /dev/null +++ b/Source/Windows/Common/InvalidationTracker.cpp @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include "InvalidationTracker.h" +#include +#include + +namespace FEX::Windows { +void InvalidationTracker::HandleMemoryProtectionNotification(FEXCore::Core::InternalThreadState *Thread, uint64_t Address, uint64_t Size, ULONG Prot) { + const auto AlignedBase = Address & FHU::FEX_PAGE_MASK; + const auto AlignedSize = (Address - AlignedBase + Size + FHU::FEX_PAGE_SIZE - 1) & FHU::FEX_PAGE_MASK; + + if (Prot & (PAGE_EXECUTE | PAGE_EXECUTE_READ | PAGE_EXECUTE_READWRITE)) { + Thread->CTX->InvalidateGuestCodeRange(Thread, AlignedBase, AlignedSize); + } + + if (Prot & PAGE_EXECUTE_READWRITE) { + LogMan::Msg::DFmt("Add SMC interval: {:X} - {:X}", AlignedBase, AlignedBase + AlignedSize); + std::scoped_lock Lock(RWXIntervalsLock); + RWXIntervals.Insert({AlignedBase, AlignedBase + AlignedSize}); + } else { + std::scoped_lock Lock(RWXIntervalsLock); + RWXIntervals.Remove({AlignedBase, AlignedBase + AlignedSize}); + } +} + +void InvalidationTracker::InvalidateContainingSection(FEXCore::Core::InternalThreadState *Thread, uint64_t Address, bool Free) { + MEMORY_BASIC_INFORMATION Info; + if (NtQueryVirtualMemory(NtCurrentProcess(), reinterpret_cast(Address), MemoryBasicInformation, &Info, sizeof(Info), nullptr)) + return; + + const auto SectionBase = reinterpret_cast(Info.AllocationBase); + const auto SectionSize = reinterpret_cast(Info.BaseAddress) + Info.RegionSize + - reinterpret_cast(Info.AllocationBase); + Thread->CTX->InvalidateGuestCodeRange(Thread, SectionBase, SectionSize); + + if (Free) { + std::scoped_lock Lock(RWXIntervalsLock); + RWXIntervals.Remove({SectionBase, SectionBase + SectionSize}); + } +} + +void InvalidationTracker::InvalidateAlignedInterval(FEXCore::Core::InternalThreadState *Thread, uint64_t Address, uint64_t Size, bool Free) { + const auto AlignedBase = Address & FHU::FEX_PAGE_MASK; + const auto AlignedSize = (Address - AlignedBase + Size + FHU::FEX_PAGE_SIZE - 1) & FHU::FEX_PAGE_MASK; + Thread->CTX->InvalidateGuestCodeRange(Thread, AlignedBase, AlignedSize); + + if (Free) { + std::scoped_lock Lock(RWXIntervalsLock); + RWXIntervals.Remove({AlignedBase, AlignedBase + AlignedSize}); + } +} + +void InvalidationTracker::ReprotectRWXIntervals(uint64_t Address, uint64_t Size) { + const auto End = Address + Size; + std::scoped_lock Lock(RWXIntervalsLock); + + do { + const auto Query = RWXIntervals.Query(Address); + if (Query.Enclosed) { + void *TmpAddress = reinterpret_cast(Address); + SIZE_T TmpSize = static_cast(std::min(End, Address + Query.Size) - Address); + ULONG TmpProt; + NtProtectVirtualMemory(NtCurrentProcess(), &TmpAddress, &TmpSize, PAGE_EXECUTE_READ, &TmpProt); + } else if (!Query.Size) { + // No more regions past `Address` in the interval list + break; + } + + Address += Query.Size; + } while (Address < End); +} + +bool InvalidationTracker::HandleRWXAccessViolation(FEXCore::Core::InternalThreadState *Thread, uint64_t FaultAddress) { + const bool NeedsInvalidate = [&](uint64_t Address) { + std::unique_lock Lock(RWXIntervalsLock); + const bool Enclosed = RWXIntervals.Query(Address).Enclosed; + // Invalidate just the single faulting page + if (!Enclosed) + return false; + + ULONG TmpProt; + void *TmpAddress = reinterpret_cast(Address); + SIZE_T TmpSize = 1; + NtProtectVirtualMemory(NtCurrentProcess(), &TmpAddress, &TmpSize, PAGE_EXECUTE_READWRITE, &TmpProt); + return true; + }(FaultAddress); + + if (NeedsInvalidate) { + // RWXIntervalsLock cannot be held during invalidation + Thread->CTX->InvalidateGuestCodeRange(Thread, FaultAddress & FHU::FEX_PAGE_MASK, FHU::FEX_PAGE_SIZE); + return true; + } + return false; +} +} diff --git a/Source/Windows/Common/InvalidationTracker.h b/Source/Windows/Common/InvalidationTracker.h new file mode 100644 index 000000000..86aa8afce --- /dev/null +++ b/Source/Windows/Common/InvalidationTracker.h @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// FIXME TODO put in cpp +#pragma once + +#include "IntervalList.h" +#include + +namespace FEXCore::Core { + struct InternalThreadState; +} + +namespace FEX::Windows { +/** + * @brief Handles SMC and regular code invalidation + */ +class InvalidationTracker { +public: + void HandleMemoryProtectionNotification(FEXCore::Core::InternalThreadState *Thread, uint64_t Address, uint64_t Size, ULONG Prot); + void InvalidateContainingSection(FEXCore::Core::InternalThreadState *Thread, uint64_t Address, bool Free); + void InvalidateAlignedInterval(FEXCore::Core::InternalThreadState *Thread, uint64_t Address, uint64_t Size, bool Free); + void ReprotectRWXIntervals(uint64_t Address, uint64_t Size); + bool HandleRWXAccessViolation(FEXCore::Core::InternalThreadState *Thread, uint64_t FaultAddress); + +private: + IntervalList RWXIntervals; + std::mutex RWXIntervalsLock; +}; +} diff --git a/Source/Windows/WOW64/Module.cpp b/Source/Windows/WOW64/Module.cpp index 384f6ae07..cdd4b8e40 100644 --- a/Source/Windows/WOW64/Module.cpp +++ b/Source/Windows/WOW64/Module.cpp @@ -26,10 +26,10 @@ $end_info$ #include #include "Common/Config.h" +#include "Common/InvalidationTracker.h" #include "Common/CPUFeatures.h" #include "DummyHandlers.h" #include "BTInterface.h" -#include "IntervalList.h" #include #include @@ -93,6 +93,7 @@ namespace { fextl::unique_ptr SignalDelegator; fextl::unique_ptr SyscallHandler; + FEX::Windows::InvalidationTracker InvalidationTracker; std::optional CPUFeatures; std::mutex ThreadSuspendLock; @@ -327,99 +328,6 @@ namespace Context { } } -namespace Invalidation { - static IntervalList RWXIntervals; - static std::mutex RWXIntervalsLock; - - void HandleMemoryProtectionNotification(uint64_t Address, uint64_t Size, ULONG Prot) { - const auto AlignedBase = Address & FHU::FEX_PAGE_MASK; - const auto AlignedSize = (Address - AlignedBase + Size + FHU::FEX_PAGE_SIZE - 1) & FHU::FEX_PAGE_MASK; - - if (Prot & (PAGE_EXECUTE | PAGE_EXECUTE_READ | PAGE_EXECUTE_READWRITE)) { - CTX->InvalidateGuestCodeRange(GetTLS().ThreadState(), AlignedBase, AlignedSize); - } - - if (Prot & PAGE_EXECUTE_READWRITE) { - LogMan::Msg::DFmt("Add SMC interval: {:X} - {:X}", AlignedBase, AlignedBase + AlignedSize); - std::scoped_lock Lock(RWXIntervalsLock); - RWXIntervals.Insert({AlignedBase, AlignedBase + AlignedSize}); - } else { - std::scoped_lock Lock(RWXIntervalsLock); - RWXIntervals.Remove({AlignedBase, AlignedBase + AlignedSize}); - } - } - - void InvalidateContainingSection(uint64_t Address, bool Free) { - MEMORY_BASIC_INFORMATION Info; - if (NtQueryVirtualMemory(NtCurrentProcess(), reinterpret_cast(Address), MemoryBasicInformation, &Info, sizeof(Info), nullptr)) - return; - - const auto SectionBase = reinterpret_cast(Info.AllocationBase); - const auto SectionSize = reinterpret_cast(Info.BaseAddress) + Info.RegionSize - - reinterpret_cast(Info.AllocationBase); - CTX->InvalidateGuestCodeRange(GetTLS().ThreadState(), SectionBase, SectionSize); - - if (Free) { - std::scoped_lock Lock(RWXIntervalsLock); - RWXIntervals.Remove({SectionBase, SectionBase + SectionSize}); - } - } - - void InvalidateAlignedInterval(uint64_t Address, uint64_t Size, bool Free) { - const auto AlignedBase = Address & FHU::FEX_PAGE_MASK; - const auto AlignedSize = (Address - AlignedBase + Size + FHU::FEX_PAGE_SIZE - 1) & FHU::FEX_PAGE_MASK; - CTX->InvalidateGuestCodeRange(GetTLS().ThreadState(), AlignedBase, AlignedSize); - - if (Free) { - std::scoped_lock Lock(RWXIntervalsLock); - RWXIntervals.Remove({AlignedBase, AlignedBase + AlignedSize}); - } - } - - void ReprotectRWXIntervals(uint64_t Address, uint64_t Size) { - const auto End = Address + Size; - std::scoped_lock Lock(RWXIntervalsLock); - - do { - const auto Query = RWXIntervals.Query(Address); - if (Query.Enclosed) { - void *TmpAddress = reinterpret_cast(Address); - SIZE_T TmpSize = static_cast(std::min(End, Address + Query.Size) - Address); - ULONG TmpProt; - NtProtectVirtualMemory(NtCurrentProcess(), &TmpAddress, &TmpSize, PAGE_EXECUTE_READ, &TmpProt); - } else if (!Query.Size) { - // No more regions past `Address` in the interval list - break; - } - - Address += Query.Size; - } while (Address < End); - } - - bool HandleRWXAccessViolation(uint64_t FaultAddress) { - const bool NeedsInvalidate = [](uint64_t Address) { - std::unique_lock Lock(RWXIntervalsLock); - const bool Enclosed = RWXIntervals.Query(Address).Enclosed; - // Invalidate just the single faulting page - if (!Enclosed) - return false; - - ULONG TmpProt; - void *TmpAddress = reinterpret_cast(Address); - SIZE_T TmpSize = 1; - NtProtectVirtualMemory(NtCurrentProcess(), &TmpAddress, &TmpSize, PAGE_EXECUTE_READWRITE, &TmpProt); - return true; - }(FaultAddress); - - if (NeedsInvalidate) { - // RWXIntervalsLock cannot be held during invalidation - CTX->InvalidateGuestCodeRange(GetTLS().ThreadState(), FaultAddress & FHU::FEX_PAGE_MASK, FHU::FEX_PAGE_SIZE); - return true; - } - return false; - } -} - namespace Logging { void MsgHandler(LogMan::DebugLevels Level, char const *Message) { const auto Output = fextl::fmt::format("[{}][{:X}] {}\n", LogMan::DebugLevelStr(Level), GetCurrentThreadId(), Message); @@ -491,7 +399,7 @@ public: } void MarkGuestExecutableRange(FEXCore::Core::InternalThreadState *Thread, uint64_t Start, uint64_t Length) override { - Invalidation::ReprotectRWXIntervals(Start, Length); + InvalidationTracker.ReprotectRWXIntervals(Start, Length); } }; @@ -712,7 +620,7 @@ NTSTATUS BTCpuResetToConsistentState(EXCEPTION_POINTERS *Ptrs) { if (Exception->ExceptionCode == EXCEPTION_ACCESS_VIOLATION) { const auto FaultAddress = static_cast(Exception->ExceptionInformation[1]); - if (Invalidation::HandleRWXAccessViolation(FaultAddress)) { + if (InvalidationTracker.HandleRWXAccessViolation(GetTLS().ThreadState(), FaultAddress)) { LogMan::Msg::DFmt("Handled self-modifying code: pc: {:X} fault: {:X}", Context->Pc, FaultAddress); NtContinue(Context, FALSE); } @@ -742,29 +650,29 @@ NTSTATUS BTCpuResetToConsistentState(EXCEPTION_POINTERS *Ptrs) { } void BTCpuFlushInstructionCache2(const void *Address, SIZE_T Size) { - Invalidation::InvalidateAlignedInterval(reinterpret_cast(Address), static_cast(Size), false); + InvalidationTracker.InvalidateAlignedInterval(GetTLS().ThreadState(), reinterpret_cast(Address), static_cast(Size), false); } void BTCpuNotifyMemoryAlloc(void *Address, SIZE_T Size, ULONG Type, ULONG Prot) { - Invalidation::HandleMemoryProtectionNotification(reinterpret_cast(Address), static_cast(Size), + InvalidationTracker.HandleMemoryProtectionNotification(GetTLS().ThreadState(), reinterpret_cast(Address), static_cast(Size), Prot); } void BTCpuNotifyMemoryProtect(void *Address, SIZE_T Size, ULONG NewProt) { - Invalidation::HandleMemoryProtectionNotification(reinterpret_cast(Address), static_cast(Size), + InvalidationTracker.HandleMemoryProtectionNotification(GetTLS().ThreadState(), reinterpret_cast(Address), static_cast(Size), NewProt); } void BTCpuNotifyMemoryFree(void *Address, SIZE_T Size, ULONG FreeType) { if (!Size) { - Invalidation::InvalidateContainingSection(reinterpret_cast(Address), true); + InvalidationTracker.InvalidateContainingSection(GetTLS().ThreadState(), reinterpret_cast(Address), true); } else if (FreeType & MEM_DECOMMIT) { - Invalidation::InvalidateAlignedInterval(reinterpret_cast(Address), static_cast(Size), true); + InvalidationTracker.InvalidateAlignedInterval(GetTLS().ThreadState(), reinterpret_cast(Address), static_cast(Size), true); } } void BTCpuNotifyUnmapViewOfSection(void *Address, ULONG Flags) { - Invalidation::InvalidateContainingSection(reinterpret_cast(Address), true); + InvalidationTracker.InvalidateContainingSection(GetTLS().ThreadState(), reinterpret_cast(Address), true); } BOOLEAN WINAPI BTCpuIsProcessorFeaturePresent(UINT Feature) {