WOW64: Keep track of all created threads on the frontend

This is necessary so that code can be invalidated across all threads
rather than just the initiator on any event that triggers invalidation.
This commit is contained in:
Billy Laws 2024-02-17 23:19:05 +00:00
parent 5cb11aed3d
commit d92580bccf

View File

@ -36,7 +36,7 @@ $end_info$
#include <atomic> #include <atomic>
#include <mutex> #include <mutex>
#include <utility> #include <utility>
#include <unordered_set> #include <unordered_map>
#include <ntstatus.h> #include <ntstatus.h>
#include <windef.h> #include <windef.h>
#include <winternl.h> #include <winternl.h>
@ -97,8 +97,9 @@ fextl::unique_ptr<WowSyscallHandler> SyscallHandler;
FEX::Windows::InvalidationTracker InvalidationTracker; FEX::Windows::InvalidationTracker InvalidationTracker;
std::optional<FEX::Windows::CPUFeatures> CPUFeatures; std::optional<FEX::Windows::CPUFeatures> CPUFeatures;
std::mutex ThreadSuspendLock; std::mutex ThreadCreationMutex;
std::unordered_set<DWORD> InitializedWOWThreads; // Set of TIDs, `ThreadSuspendLock` must be locked when accessing // Map of TIDs to their FEX thread state, `ThreadCreationMutex` must be locked when accessing
std::unordered_map<DWORD, FEXCore::Core::InternalThreadState*> Threads;
std::pair<NTSTATUS, TLS> GetThreadTLS(HANDLE Thread) { std::pair<NTSTATUS, TLS> GetThreadTLS(HANDLE Thread) {
THREAD_BASIC_INFORMATION Info; THREAD_BASIC_INFORMATION Info;
@ -426,10 +427,11 @@ void BTCpuProcessInit() {
} }
NTSTATUS BTCpuThreadInit() { NTSTATUS BTCpuThreadInit() {
GetTLS().ThreadState() = CTX->CreateThread(0, 0); auto* Thread = CTX->CreateThread(0, 0);
GetTLS().ThreadState() = Thread;
std::scoped_lock Lock(ThreadSuspendLock); std::scoped_lock Lock(ThreadCreationMutex);
InitializedWOWThreads.emplace(GetCurrentThreadId()); Threads.emplace(GetCurrentThreadId(), Thread);
return STATUS_SUCCESS; return STATUS_SUCCESS;
} }
@ -446,8 +448,8 @@ NTSTATUS BTCpuThreadTerm(HANDLE Thread) {
} }
const auto ThreadTID = reinterpret_cast<uint64_t>(Info.ClientId.UniqueThread); const auto ThreadTID = reinterpret_cast<uint64_t>(Info.ClientId.UniqueThread);
std::scoped_lock Lock(ThreadSuspendLock); std::scoped_lock Lock(ThreadCreationMutex);
InitializedWOWThreads.erase(ThreadTID); Threads.erase(ThreadTID);
} }
CTX->DestroyThread(TLS.ThreadState()); CTX->DestroyThread(TLS.ThreadState());
@ -550,10 +552,10 @@ NTSTATUS BTCpuSuspendLocalThread(HANDLE Thread, ULONG* Count) {
return Err; return Err;
} }
std::scoped_lock Lock(ThreadSuspendLock); std::scoped_lock Lock(ThreadCreationMutex);
// If the thread hasn't yet been initialized, suspend it without special handling as it wont yet have entered the JIT // If the thread hasn't yet been initialized, suspend it without special handling as it wont yet have entered the JIT
if (!InitializedWOWThreads.contains(ThreadTID)) { if (!Threads.contains(ThreadTID)) {
return NtSuspendThread(Thread, Count); return NtSuspendThread(Thread, Count);
} }