[Libomptarget] Remove RPCHandleTy indirection

The 'RPCHandleTy' was intended to express that a specific device owns its
slot in the RPC server. However, this required creating a temporary store
to hold these pointers, which caused spurious failures due to undefined
behaviour in the order of library teardown. For example, the x64 plugin
would be torn down first, leaving the stored handle pointing at invalid
memory, and the CUDA plugin would then crash when it used it. Rather than
spend the time to fully diagnose this problem, I found it prudent to simply
remove the failure mode.
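
The failure mode is the classic static-teardown-order hazard. A minimal
sketch, with hypothetical names standing in for the real plugin machinery
(illustrative only, not the actual libomptarget code):

  #include <vector>

  struct HandleTy { /* per-device RPC state */ };

  // A process-wide store holding pointers into memory owned by individual
  // plugin libraries; it is destroyed at some unspecified point during
  // library teardown.
  static std::vector<HandleTy *> Handles;

  struct PluginTy {
    HandleTy Handle;
    PluginTy() { Handles.push_back(&Handle); }
    // Once one plugin's statics are destroyed, Handles keeps a dangling
    // pointer; any later use from another plugin's teardown path is UB.
    ~PluginTy() = default;
  };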

This patch removes the indirection, so the RPC server must now always be
used with the intended device passed explicitly. The only extra handling is
in the AMDGPU stream implementation, which now stores a reference to its
device.
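
The shape of the change, as a minimal self-contained sketch (stub types and
bool return values stand in for the real classes and llvm::Error; the
actual diffs follow below):

  struct GenericDeviceTy { int DeviceId; };

  struct RPCServerTy {
    // The device is now named explicitly at every call site.
    bool runServer(GenericDeviceTy &Device) { return true; }
    bool deinitDevice(GenericDeviceTy &Device) { return true; }
  };

  int main() {
    GenericDeviceTy Device{0};
    RPCServerTy Server;

    // Before: RPCHandle = *Server.getDevice(Device);
    //         RPCHandle->runServer();        // Device was implicit.
    // After: keep a plain pointer and pass the device every time.
    RPCServerTy *RPCServer = &Server;
    if (!RPCServer->runServer(Device))
      return 1;
    return RPCServer->deinitDevice(Device) ? 0 : 1;
  }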

Reviewed By: JonChesterfield

Differential Revision: https://reviews.llvm.org/D154971
Joseph Huber, 2023-07-11 09:27:22 -05:00
commit 8a0763f19c (parent 14742f2a68)
7 changed files with 27 additions and 70 deletions

openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp

@@ -520,9 +520,9 @@ struct AMDGPUSignalTy {
   }
 
   /// Wait until the signal gets a zero value.
-  Error wait(const uint64_t ActiveTimeout = 0,
-             RPCHandleTy *RPCHandle = nullptr) const {
-    if (ActiveTimeout && !RPCHandle) {
+  Error wait(const uint64_t ActiveTimeout = 0, RPCServerTy *RPCServer = nullptr,
+             GenericDeviceTy *Device = nullptr) const {
+    if (ActiveTimeout && !RPCServer) {
       hsa_signal_value_t Got = 1;
       Got = hsa_signal_wait_scacquire(Signal, HSA_SIGNAL_CONDITION_EQ, 0,
                                       ActiveTimeout, HSA_WAIT_STATE_ACTIVE);
@@ -531,12 +531,12 @@ struct AMDGPUSignalTy {
     }
 
     // If there is an RPC device attached to this stream we run it as a server.
-    uint64_t Timeout = RPCHandle ? 8192 : UINT64_MAX;
-    auto WaitState = RPCHandle ? HSA_WAIT_STATE_ACTIVE : HSA_WAIT_STATE_BLOCKED;
+    uint64_t Timeout = RPCServer ? 8192 : UINT64_MAX;
+    auto WaitState = RPCServer ? HSA_WAIT_STATE_ACTIVE : HSA_WAIT_STATE_BLOCKED;
     while (hsa_signal_wait_scacquire(Signal, HSA_SIGNAL_CONDITION_EQ, 0,
                                      Timeout, WaitState) != 0) {
-      if (RPCHandle)
-        if (auto Err = RPCHandle->runServer())
+      if (RPCServer && Device)
+        if (auto Err = RPCServer->runServer(*Device))
           return Err;
     }
     return Plugin::success();
@@ -888,6 +888,9 @@ private:
   /// The manager of signals to reuse signals.
   AMDGPUSignalManagerTy &SignalManager;
 
+  /// A reference to the associated device.
+  GenericDeviceTy &Device;
+
   /// Array of stream slots. Use std::deque because it can dynamically grow
   /// without invalidating the already inserted elements. For instance, the
   /// std::vector may invalidate the elements by reallocating the internal
@@ -907,7 +910,7 @@ private:
   /// A pointer associated with an RPC server running on the given device. If
   /// RPC is not being used this will be a null pointer. Otherwise, this
   /// indicates that an RPC server is expected to be run on this stream.
-  RPCHandleTy *RPCHandle;
+  RPCServerTy *RPCServer;
 
   /// Mutex to protect stream's management.
   mutable std::mutex Mutex;
@@ -1064,8 +1067,8 @@ public:
   /// Deinitialize the stream's signals.
   Error deinit() { return Plugin::success(); }
 
-  /// Attach an RPC handle to this stream.
-  void setRPCHandle(RPCHandleTy *Handle) { RPCHandle = Handle; }
+  /// Attach an RPC server to this stream.
+  void setRPCServer(RPCServerTy *Server) { RPCServer = Server; }
 
   /// Push a asynchronous kernel to the stream. The kernel arguments must be
   /// placed in a special allocation for kernel args and must keep alive until
@@ -1281,8 +1284,8 @@ public:
       return Plugin::success();
 
     // Wait until all previous operations on the stream have completed.
-    if (auto Err =
-            Slots[last()].Signal->wait(StreamBusyWaitMicroseconds, RPCHandle))
+    if (auto Err = Slots[last()].Signal->wait(StreamBusyWaitMicroseconds,
+                                              RPCServer, &Device))
       return Err;
 
     // Reset the stream and perform all pending post actions.
@@ -2529,9 +2532,9 @@ Error AMDGPUResourceRef<ResourceTy>::create(GenericDeviceTy &Device) {
 
 AMDGPUStreamTy::AMDGPUStreamTy(AMDGPUDeviceTy &Device)
     : Agent(Device.getAgent()), Queue(Device.getNextQueue()),
-      SignalManager(Device.getSignalManager()),
+      SignalManager(Device.getSignalManager()), Device(Device),
       // Initialize the std::deque with some empty positions.
-      Slots(32), NextSlot(0), SyncCycle(0), RPCHandle(nullptr),
+      Slots(32), NextSlot(0), SyncCycle(0), RPCServer(nullptr),
       StreamBusyWaitMicroseconds(Device.getStreamBusyWaitMicroseconds()) {}
 
 /// Class implementing the AMDGPU-specific functionalities of the global
@@ -2866,8 +2869,8 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
   AMDGPUStreamTy &Stream = AMDGPUDevice.getStream(AsyncInfoWrapper);
 
   // If this kernel requires an RPC server we attach its pointer to the stream.
-  if (GenericDevice.getRPCHandle())
-    Stream.setRPCHandle(GenericDevice.getRPCHandle());
+  if (GenericDevice.getRPCServer())
+    Stream.setRPCServer(GenericDevice.getRPCServer());
 
   // Push the kernel launch into the stream.
   return Stream.pushKernelLaunch(*this, AllArgs, NumThreads, NumBlocks,

openmp/libomptarget/plugins-nextgen/common/PluginInterface/CMakeLists.txt

@@ -70,7 +70,6 @@ elseif(${LIBOMPTARGET_GPU_LIBC_SUPPORT})
   find_library(llvmlibc_rpc_server NAMES llvmlibc_rpc_server
                PATHS ${LIBOMPTARGET_LLVM_LIBRARY_DIR} NO_DEFAULT_PATH)
   if(llvmlibc_rpc_server)
-    message(WARNING ${llvmlibc_rpc_server})
     target_link_libraries(PluginInterface PRIVATE llvmlibc_rpc_server)
     target_compile_definitions(PluginInterface PRIVATE LIBOMPTARGET_RPC_SUPPORT)
   endif()

openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp

@@ -401,7 +401,7 @@ GenericDeviceTy::GenericDeviceTy(int32_t DeviceId, int32_t NumDevices,
       OMPX_InitialNumEvents("LIBOMPTARGET_NUM_INITIAL_EVENTS", 32),
       DeviceId(DeviceId), GridValues(OMPGridValues),
       PeerAccesses(NumDevices, PeerAccessState::PENDING), PeerAccessesLock(),
-      PinnedAllocs(*this), RPCHandle(nullptr) {
+      PinnedAllocs(*this), RPCServer(nullptr) {
 #ifdef OMPT_SUPPORT
   OmptInitialized.store(false);
   // Bind the callbacks to this device's member functions
@@ -483,8 +483,8 @@ Error GenericDeviceTy::deinit() {
   if (RecordReplay.isRecordingOrReplaying())
     RecordReplay.deinit();
 
-  if (RPCHandle)
-    if (auto Err = RPCHandle->deinitDevice())
+  if (RPCServer)
+    if (auto Err = RPCServer->deinitDevice(*this))
       return Err;
 
 #ifdef OMPT_SUPPORT
@@ -599,10 +599,7 @@ Error GenericDeviceTy::setupRPCServer(GenericPluginTy &Plugin,
   if (auto Err = Server.initDevice(*this, Plugin.getGlobalHandler(), Image))
     return Err;
 
-  auto DeviceOrErr = Server.getDevice(*this);
-  if (!DeviceOrErr)
-    return DeviceOrErr.takeError();
-  RPCHandle = *DeviceOrErr;
+  RPCServer = &Server;
   DP("Running an RPC server on device %d\n", getDeviceId());
   return Plugin::success();
 }

openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h

@@ -762,7 +762,7 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
   }
 
   /// Get the RPC server running on this device.
-  RPCHandleTy *getRPCHandle() const { return RPCHandle; }
+  RPCServerTy *getRPCServer() const { return RPCServer; }
 
 private:
   /// Register offload entry for global variable.
@@ -857,7 +857,7 @@ protected:
 
   /// A pointer to an RPC server instance attached to this device if present.
   /// This is used to run the RPC server during task synchronization.
-  RPCHandleTy *RPCHandle;
+  RPCServerTy *RPCServer;
 
 #ifdef OMPT_SUPPORT
   /// OMPT callback functions
openmp/libomptarget/plugins-nextgen/common/PluginInterface/RPC.cpp

@@ -28,7 +28,6 @@ RPCServerTy::RPCServerTy(uint32_t NumDevices) {
   // If this fails then something is catastrophically wrong, just exit.
   if (rpc_status_t Err = rpc_init(NumDevices))
     FATAL_MESSAGE(1, "Error initializing the RPC server: %d\n", Err);
-  Handles.resize(NumDevices);
 #endif
 }
@@ -118,28 +117,10 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
   if (auto Err = Device.dataSubmit(ClientPtr, ClientBuffer,
                                    rpc_get_client_size(), nullptr))
     return Err;
-  Handles[DeviceId] = std::make_unique<RPCHandleTy>(*this, Device);
 #endif
   return Error::success();
 }
 
-llvm::Expected<RPCHandleTy *>
-RPCServerTy::getDevice(plugin::GenericDeviceTy &Device) {
-#ifdef LIBOMPTARGET_RPC_SUPPORT
-  uint32_t DeviceId = Device.getDeviceId();
-  if (!Handles[DeviceId] || !rpc_get_buffer(DeviceId) ||
-      !rpc_get_client_buffer(DeviceId))
-    return plugin::Plugin::error(
-        "Attempt to get an RPC device while not initialized");
-  return Handles[DeviceId].get();
-#else
-  return plugin::Plugin::error(
-      "Attempt to get an RPC device while not available");
-#endif
-}
-
 Error RPCServerTy::runServer(plugin::GenericDeviceTy &Device) {
 #ifdef LIBOMPTARGET_RPC_SUPPORT
   if (rpc_status_t Err = rpc_handle_server(Device.getDeviceId()))

openmp/libomptarget/plugins-nextgen/common/PluginInterface/RPC.h

@@ -32,21 +32,6 @@ class DeviceImageTy;
 /// these routines will perform no action.
 struct RPCServerTy {
 public:
-  /// A wrapper around a single instance of the RPC server for a given device.
-  /// This is provided to simplify ownership of the underlying device.
-  struct RPCHandleTy {
-    RPCHandleTy(RPCServerTy &Server, plugin::GenericDeviceTy &Device)
-        : Server(Server), Device(Device) {}
-
-    llvm::Error runServer() { return Server.runServer(Device); }
-    llvm::Error deinitDevice() { return Server.deinitDevice(Device); }
-
-  private:
-    RPCServerTy &Server;
-    plugin::GenericDeviceTy &Device;
-  };
-
   RPCServerTy(uint32_t NumDevices);
 
   /// Check if this device image is using an RPC server. This checks for the
@@ -63,9 +48,6 @@ public:
                    plugin::GenericGlobalHandlerTy &Handler,
                    plugin::DeviceImageTy &Image);
 
-  /// Gets a reference to this server for a specific device.
-  llvm::Expected<RPCHandleTy *> getDevice(plugin::GenericDeviceTy &Device);
-
   /// Runs the RPC server associated with the \p Device until the pending work
   /// is cleared.
   llvm::Error runServer(plugin::GenericDeviceTy &Device);
@@ -75,13 +57,8 @@ public:
   llvm::Error deinitDevice(plugin::GenericDeviceTy &Device);
 
   ~RPCServerTy();
-
-private:
-  llvm::SmallVector<std::unique_ptr<RPCHandleTy>> Handles;
 };
 
-using RPCHandleTy = RPCServerTy::RPCHandleTy;
-
 } // namespace llvm::omp::target
 
 #endif

openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp

@@ -474,12 +474,12 @@ struct CUDADeviceTy : public GenericDeviceTy {
     CUresult Res;
     // If we have an RPC server running on this device we will continuously
     // query it for work rather than blocking.
-    if (!getRPCHandle()) {
+    if (!getRPCServer()) {
       Res = cuStreamSynchronize(Stream);
     } else {
       do {
         Res = cuStreamQuery(Stream);
-        if (auto Err = getRPCHandle()->runServer())
+        if (auto Err = getRPCServer()->runServer(*this))
           return Err;
       } while (Res == CUDA_ERROR_NOT_READY);
     }