simplified MemoryTable utility

This commit is contained in:
DH 2024-09-03 10:09:35 +03:00
parent bd39f9a070
commit 86e2d8b129
4 changed files with 153 additions and 113 deletions

View File

@ -2215,22 +2215,29 @@ struct CacheOverlayBase {
virtual void release(std::uint64_t tag) {}
std::optional<decltype(syncState)::AreaInfo> getSyncTag(std::uint64_t address,
std::uint64_t size) {
// Value-type snapshot of one contiguous sync-state area: the payload value
// together with the address range it covers. Returned by getSyncTag() so
// callers do not have to keep a table iterator (and the table lock) alive.
struct SyncTag {
std::uint64_t beginAddress; // inclusive start of the area
std::uint64_t endAddress;   // exclusive end of the area
std::uint64_t value;        // the area's sync payload
};
std::optional<SyncTag> getSyncTag(std::uint64_t address, std::uint64_t size) {
std::lock_guard lock(mtx);
auto it = syncState.queryArea(address);
if (it == syncState.end()) {
return {};
}
auto state = *it;
if (state.endAddress < address + size || state.beginAddress > address) {
if (it.endAddress() < address + size || it.beginAddress() > address) {
// has no single sync state
return {};
}
return state;
return SyncTag{
.beginAddress = it.beginAddress(),
.endAddress = it.endAddress(),
.value = it.get(),
};
}
bool isInSync(util::MemoryTableWithPayload<CacheSyncEntry> &table,
@ -2250,14 +2257,12 @@ struct CacheOverlayBase {
return false;
}
auto tableTag = *tableArea;
if (tableTag.beginAddress > address ||
tableTag.endAddress < address + size) {
if (tableArea.beginAddress() > address ||
tableArea.endAddress() < address + size) {
return false;
}
return tableTag.payload.tag == syncTag.payload;
return tableArea->tag == syncTag.value;
}
virtual void writeBuffer(TaskChain &taskChain,
@ -2277,6 +2282,13 @@ struct CacheOverlayBase {
}
};
// Value-type snapshot of one entry in a cache table: the address range it
// spans, its tag, and the overlay backing that range. Materialized from a
// table iterator so it remains usable after the table mutex is released.
struct CacheEntry {
std::uint64_t beginAddress; // inclusive start of the cached range
std::uint64_t endAddress;   // exclusive end of the cached range
std::uint64_t tag;          // sync tag of this range — presumably compared against syncState entries; confirm
Ref<CacheOverlayBase> overlay; // owning reference to the backing overlay (Ref assumed refcounted)
};
struct CacheBufferOverlay : CacheOverlayBase {
vk::Buffer buffer;
std::uint64_t bufferAddress;
@ -2305,7 +2317,12 @@ struct CacheBufferOverlay : CacheOverlayBase {
util::unreachable();
}
return *it;
return CacheEntry{
.beginAddress = it.beginAddress(),
.endAddress = it.endAddress(),
.tag = it->tag,
.overlay = it->overlay,
};
}
std::lock_guard lock(tableMtx);
@ -2313,7 +2330,12 @@ struct CacheBufferOverlay : CacheOverlayBase {
if (it == table.end()) {
util::unreachable();
}
return *it;
return CacheEntry{
.beginAddress = it.beginAddress(),
.endAddress = it.endAddress(),
.tag = it->tag,
.overlay = it->overlay,
};
};
while (size > 0) {
@ -2324,8 +2346,7 @@ struct CacheBufferOverlay : CacheOverlayBase {
auto areaSize = origAreaSize;
if (!cache) {
state.payload.overlay->readBuffer(taskChain, this, address, areaSize,
waitTask);
state.overlay->readBuffer(taskChain, this, address, areaSize, waitTask);
size -= areaSize;
address += areaSize;
continue;
@ -2335,17 +2356,16 @@ struct CacheBufferOverlay : CacheOverlayBase {
auto blockSyncStateIt = syncState.queryArea(address);
if (blockSyncStateIt == syncState.end()) {
doRead(address, areaSize, state.payload.tag, state.payload.overlay);
doRead(address, areaSize, state.tag, state.overlay);
address += areaSize;
break;
}
auto blockSyncState = *blockSyncStateIt;
auto blockSize =
std::min(blockSyncState.endAddress - address, areaSize);
std::min(blockSyncStateIt.endAddress() - address, areaSize);
if (blockSyncState.payload != state.payload.tag) {
doRead(address, areaSize, state.payload.tag, state.payload.overlay);
if (blockSyncStateIt.get() != state.tag) {
doRead(address, areaSize, state.tag, state.overlay);
}
areaSize -= blockSize;
@ -2445,7 +2465,7 @@ struct CacheImageOverlay : CacheOverlayBase {
VK_IMAGE_LAYOUT_GENERAL, 1, &region);
auto tag = *srcBuffer->getSyncTag(address, size);
std::lock_guard lock(self->mtx);
self->syncState.map(address, address + size, tag.payload);
self->syncState.map(address, address + size, tag.value);
});
return;
@ -2469,7 +2489,7 @@ struct CacheImageOverlay : CacheOverlayBase {
auto tag = *srcBuffer->getSyncTag(address, size);
std::lock_guard lock(self->mtx);
self->syncState.map(address, address + size, tag.payload);
self->syncState.map(address, address + size, tag.value);
});
}
@ -2666,8 +2686,9 @@ struct CacheLine {
std::mutex writeBackTableMtx;
util::MemoryTableWithPayload<Ref<AsyncTaskCtl>> writeBackTable;
CacheLine(RemoteMemory memory, std::uint64_t areaAddress, std::uint64_t areaSize)
:memory(memory), areaAddress(areaAddress), areaSize(areaSize) {
CacheLine(RemoteMemory memory, std::uint64_t areaAddress,
std::uint64_t areaSize)
: memory(memory), areaAddress(areaAddress), areaSize(areaSize) {
memoryOverlay = new MemoryOverlay();
memoryOverlay->memory = memory;
hostSyncTable.map(areaAddress, areaAddress + areaSize, {1, memoryOverlay});
@ -2720,25 +2741,24 @@ struct CacheLine {
auto it = writeBackTable.queryArea(address);
while (it != writeBackTable.end()) {
auto taskInfo = *it;
if (taskInfo.beginAddress >= address + size) {
if (it.beginAddress() >= address + size) {
break;
}
if (taskInfo.beginAddress >= address &&
taskInfo.endAddress <= address + size) {
if (taskInfo.payload != nullptr) {
auto task = it.get();
if (it.beginAddress() >= address && it.endAddress() <= address + size) {
if (task != nullptr) {
// another task with smaller range already in progress, we can
// cancel it
// std::printf("prev upload task cancelation\n");
taskInfo.payload->cancel();
task->cancel();
}
}
if (taskInfo.payload != nullptr) {
taskInfo.payload->wait();
if (task != nullptr) {
task->wait();
}
++it;
@ -2751,8 +2771,9 @@ struct CacheLine {
void lazyMemoryUpdate(std::uint64_t tag, std::uint64_t address) {
// std::printf("memory lazy update, address %lx\n", address);
decltype(hostSyncTable)::AreaInfo area;
std::size_t beginAddress;
std::size_t areaSize;
{
std::lock_guard lock(hostSyncMtx);
auto it = hostSyncTable.queryArea(address);
@ -2761,20 +2782,18 @@ struct CacheLine {
util::unreachable();
}
area = *it;
beginAddress = it.beginAddress();
areaSize = it.size();
}
auto areaSize = area.endAddress - area.beginAddress;
auto updateTaskChain = TaskChain::Create();
auto uploadBuffer =
getBuffer(tag, *updateTaskChain.get(), area.beginAddress, areaSize, 1,
1, shader::AccessOp::Load);
auto uploadBuffer = getBuffer(tag, *updateTaskChain.get(), beginAddress,
areaSize, 1, 1, shader::AccessOp::Load);
memoryOverlay->writeBuffer(*updateTaskChain.get(), uploadBuffer,
area.beginAddress, areaSize);
beginAddress, areaSize);
updateTaskChain->wait();
uploadBuffer->unlock(tag);
unlockReadWrite(memory.vmId, area.beginAddress, areaSize);
unlockReadWrite(memory.vmId, beginAddress, areaSize);
// std::printf("memory lazy update, %lx finish\n", address);
}
@ -3020,32 +3039,27 @@ private:
auto &table = bufferTable[offset];
if (auto it = table.queryArea(address); it != table.end()) {
auto bufferInfo = *it;
if (bufferInfo.beginAddress <= address &&
bufferInfo.endAddress >= address + size) {
if (!isAligned(address - bufferInfo.beginAddress, alignment)) {
if (it.beginAddress() <= address && it.endAddress() >= address + size) {
if (!isAligned(address - it.beginAddress(), alignment)) {
util::unreachable();
}
return bufferInfo.payload;
return it.get();
}
assert(bufferInfo.beginAddress <= address);
assert(it.beginAddress() <= address);
auto endAddress = std::max(bufferInfo.endAddress, address + size);
address = bufferInfo.beginAddress;
auto endAddress = std::max(it.endAddress(), address + size);
address = it.beginAddress();
while (it != table.end()) {
bufferInfo = *it;
if (endAddress > bufferInfo.endAddress) {
if (endAddress > it.endAddress()) {
auto nextIt = it;
if (++nextIt != table.end()) {
auto nextInfo = *nextIt;
if (nextInfo.beginAddress >= endAddress) {
if (nextIt.beginAddress() >= endAddress) {
break;
}
endAddress = nextInfo.endAddress;
endAddress = nextIt.endAddress();
}
}
++it;
@ -4817,8 +4831,8 @@ void amdgpu::device::AmdgpuDevice::handleProtectMemory(RemoteMemory memory,
protStr = "unknown";
break;
}
std::fprintf(stderr, "Allocated area at %zx, size %lx, prot %s, vmid %u\n", address,
size, protStr, memory.vmId);
std::fprintf(stderr, "Allocated area at %zx, size %lx, prot %s, vmid %u\n",
address, size, protStr, memory.vmId);
} else {
memoryAreaTable[memory.vmId].unmap(beginPage, endPage);
std::fprintf(stderr, "Unmapped area at %zx, size %lx\n", address, size);
@ -5069,8 +5083,8 @@ bool amdgpu::device::AmdgpuDevice::handleFlip(
g_bridge->flipBuffer[memory.vmId] = bufferIndex;
g_bridge->flipArg[memory.vmId] = arg;
g_bridge->flipCount[memory.vmId] = g_bridge->flipCount[memory.vmId] + 1;
auto bufferInUse =
memory.getPointer<std::uint64_t>(g_bridge->bufferInUseAddress[memory.vmId]);
auto bufferInUse = memory.getPointer<std::uint64_t>(
g_bridge->bufferInUseAddress[memory.vmId]);
if (bufferInUse != nullptr) {
bufferInUse[bufferIndex] = 0;
}

View File

@ -47,8 +47,7 @@ orbis::ErrorCode DmemDevice::mmap(void **address, std::uint64_t len,
int memoryType = 0;
if (auto allocationInfoIt = allocations.queryArea(directMemoryStart);
allocationInfoIt != allocations.end()) {
auto allocationInfo = *allocationInfoIt;
memoryType = allocationInfo.payload.memoryType;
memoryType = allocationInfoIt->memoryType;
}
auto result =
@ -183,25 +182,24 @@ static orbis::ErrorCode dmem_ioctl(orbis::File *file, std::uint64_t request,
auto queryInfo = *it;
if (queryInfo.payload.memoryType == -1u) {
if (it->memoryType == -1u) {
return orbis::ErrorCode::ACCES;
}
if ((args->flags & 1) == 0) {
if (queryInfo.endAddress <= args->offset) {
if (it.endAddress() <= args->offset) {
return orbis::ErrorCode::ACCES;
}
} else {
if (queryInfo.beginAddress > args->offset ||
queryInfo.endAddress <= args->offset) {
if (it.beginAddress() > args->offset || it.endAddress() <= args->offset) {
return orbis::ErrorCode::ACCES;
}
}
DirectMemoryQueryInfo info{
.start = queryInfo.beginAddress,
.end = queryInfo.endAddress,
.memoryType = queryInfo.payload.memoryType,
.start = it.beginAddress(),
.end = it.endAddress(),
.memoryType = it->memoryType,
};
ORBIS_LOG_WARNING("dmem directMemoryQuery", device->index, args->devIndex,
@ -255,20 +253,19 @@ orbis::ErrorCode DmemDevice::allocate(std::uint64_t *start,
auto it = allocations.lowerBound(offset);
if (it != allocations.end()) {
auto allocation = *it;
if (allocation.payload.memoryType == -1u) {
if (offset < allocation.beginAddress) {
offset = allocation.beginAddress + alignment - 1;
if (it->memoryType == -1u) {
if (offset < it.beginAddress()) {
offset = it.beginAddress() + alignment - 1;
offset &= ~(alignment - 1);
}
if (offset + len >= allocation.endAddress) {
offset = allocation.endAddress;
if (offset + len >= it.endAddress()) {
offset = it.endAddress();
continue;
}
} else {
if (offset + len > allocation.beginAddress) {
offset = allocation.endAddress;
if (offset + len > it.beginAddress()) {
offset = it.endAddress();
continue;
}
}
@ -315,25 +312,23 @@ orbis::ErrorCode DmemDevice::queryMaxFreeChunkSize(std::uint64_t *start,
break;
}
auto allocation = *it;
if (allocation.payload.memoryType == -1u) {
if (offset < allocation.beginAddress) {
offset = allocation.beginAddress + alignment - 1;
if (it->memoryType == -1u) {
if (offset < it.beginAddress()) {
offset = it.beginAddress() + alignment - 1;
offset &= ~(alignment - 1);
}
if (allocation.endAddress > offset &&
resultSize < allocation.endAddress - offset) {
resultSize = allocation.endAddress - offset;
if (it.endAddress() > offset && resultSize < it.endAddress() - offset) {
resultSize = it.endAddress() - offset;
resultOffset = offset;
}
} else if (offset > allocation.beginAddress &&
resultSize < offset - allocation.beginAddress) {
resultSize = offset - allocation.beginAddress;
} else if (offset > it.beginAddress() &&
resultSize < offset - it.beginAddress()) {
resultSize = offset - it.beginAddress();
resultOffset = offset;
}
offset = allocation.endAddress;
offset = it.endAddress();
}
*start = resultOffset;

View File

@ -929,7 +929,7 @@ void *rx::vm::map(void *addr, std::uint64_t len, std::int32_t prot,
{
MapInfo info;
if (auto it = gMapInfo.queryArea(address); it != gMapInfo.end()) {
info = (*it).payload;
info = it.get();
}
info.device = device;
info.flags = flags;
@ -1124,29 +1124,27 @@ bool rx::vm::virtualQuery(const void *addr, std::int32_t flags,
return false;
}
auto queryInfo = *it;
if ((flags & 1) == 0) {
if (queryInfo.endAddress <= address) {
if (it.endAddress() <= address) {
return false;
}
} else {
if (queryInfo.beginAddress > address || queryInfo.endAddress <= address) {
if (it.beginAddress() > address || it.endAddress() <= address) {
return false;
}
}
std::int32_t memoryType = 0;
std::uint32_t blockFlags = 0;
if (queryInfo.payload.device != nullptr) {
if (it->device != nullptr) {
if (auto dmem =
dynamic_cast<DmemDevice *>(queryInfo.payload.device.get())) {
auto dmemIt = dmem->allocations.queryArea(queryInfo.payload.offset);
dynamic_cast<DmemDevice *>(it->device.get())) {
auto dmemIt = dmem->allocations.queryArea(it->offset);
if (dmemIt == dmem->allocations.end()) {
return false;
}
auto alloc = *dmemIt;
memoryType = alloc.payload.memoryType;
memoryType = dmemIt->memoryType;
blockFlags = kBlockFlagDirectMemory;
std::fprintf(stderr, "virtual query %p", addr);
std::fprintf(stderr, "memory type: %u\n", memoryType);
@ -1154,11 +1152,11 @@ bool rx::vm::virtualQuery(const void *addr, std::int32_t flags,
// TODO
}
std::int32_t prot = getPageProtectionImpl(queryInfo.beginAddress);
std::int32_t prot = getPageProtectionImpl(it.beginAddress());
*info = {
.start = queryInfo.beginAddress,
.end = queryInfo.endAddress,
.start = it.beginAddress(),
.end = it.endAddress(),
.protection = prot,
.memoryType = memoryType,
.flags = blockFlags,
@ -1167,7 +1165,7 @@ bool rx::vm::virtualQuery(const void *addr, std::int32_t flags,
ORBIS_LOG_ERROR("virtualQuery", addr, flags, info->start, info->end,
info->protection, info->memoryType, info->flags);
std::memcpy(info->name, queryInfo.payload.name, sizeof(info->name));
std::memcpy(info->name, it->name, sizeof(info->name));
return true;
}
@ -1177,7 +1175,7 @@ void rx::vm::setName(std::uint64_t start, std::uint64_t size,
MapInfo info;
if (auto it = gMapInfo.queryArea(start); it != gMapInfo.end()) {
info = (*it).payload;
info = it.get();
}
std::strncpy(info.name, name, sizeof(info.name));

View File

@ -214,7 +214,9 @@ public:
struct AreaInfo {
std::uint64_t beginAddress;
std::uint64_t endAddress;
PayloadT payload;
PayloadT &payload;
std::size_t size() const { return endAddress - beginAddress; }
};
class iterator {
@ -230,6 +232,12 @@ public:
return {it->first, std::next(it)->first, it->second.second};
}
std::uint64_t beginAddress() const { return it->first; }
std::uint64_t endAddress() const { return std::next(it)->first; }
std::uint64_t size() const { return endAddress() - beginAddress(); }
PayloadT &get() const { return it->second.second; }
PayloadT *operator->() const { return &it->second.second; }
iterator &operator++() {
++it;
@ -242,6 +250,8 @@ public:
bool operator==(iterator other) const { return it == other.it; }
bool operator!=(iterator other) const { return it != other.it; }
friend MemoryTableWithPayload;
};
iterator begin() { return iterator(mAreas.begin()); }
@ -252,18 +262,14 @@ public:
iterator lowerBound(std::uint64_t address) {
auto it = mAreas.lower_bound(address);
if (it == mAreas.end()) {
if (it == mAreas.end() || it->second.first != Kind::X) {
return it;
}
if (it->first == address) {
if (it->second.first == Kind::X) {
++it;
}
++it;
} else {
if (it->second.first != Kind::O) {
--it;
}
--it;
}
return it;
@ -296,8 +302,8 @@ public:
return endAddress < address ? mAreas.end() : it;
}
void map(std::uint64_t beginAddress, std::uint64_t endAddress,
PayloadT payload, bool merge = true) {
iterator map(std::uint64_t beginAddress, std::uint64_t endAddress,
PayloadT payload, bool merge = true) {
assert(beginAddress < endAddress);
auto [beginIt, beginInserted] =
mAreas.emplace(beginAddress, std::pair{Kind::O, payload});
@ -370,7 +376,7 @@ public:
}
if (!merge) {
return;
return origBegin;
}
if (origBegin->second.first == Kind::XO) {
@ -378,6 +384,7 @@ public:
if (prevBegin->second.second == origBegin->second.second) {
mAreas.erase(origBegin);
origBegin = prevBegin;
}
}
@ -386,6 +393,32 @@ public:
mAreas.erase(endIt);
}
}
return origBegin;
}
// Removes exactly the single mapped area denoted by `it` from the table.
// An area is delimited by two adjacent map nodes: the node that opens it
// and the node that closes it. Each boundary node is either erased outright
// or, when it also serves as a boundary of the neighboring area
// (Kind::XO — appears to mean "closes one area and opens the next";
// TODO confirm against the Kind enum definition), demoted so the neighbor's
// boundary is preserved.
// NOTE(review): `it` must reference a valid mapped area; no validation is
// performed here.
void unmap(iterator it) {
auto openIt = it.it;
auto closeIt = openIt;
++closeIt; // the very next node closes the area
if (openIt->second.first == Kind::XO) {
// Opening node also closes the previous area: keep it as a pure
// close (Kind::X) and drop this area's payload.
openIt->second.first = Kind::X;
openIt->second.second = {};
} else {
mAreas.erase(openIt);
}
if (closeIt->second.first == Kind::XO) {
// Closing node also opens the next area: keep it as a pure open.
closeIt->second.first = Kind::O;
} else {
mAreas.erase(closeIt);
}
}
// Removes the range [beginAddress, endAddress) from the table.
// Implemented by first mapping a throwaway default-payload area over the
// range with merge disabled — which splits any overlapping areas exactly at
// the requested boundaries and yields an iterator to the new area — then
// erasing that area via unmap(iterator).
void unmap(std::uint64_t beginAddress, std::uint64_t endAddress) {
// FIXME: can be optimized
unmap(map(beginAddress, endAddress, PayloadT{}, false));
}
};
} // namespace rx