mirror of https://gitee.com/openharmony/ai_neural_network_runtime
synced 2024-12-04 14:11:01 +00:00

modify cache

Signed-off-by: w30052974 <wangyifan94@huawei.com>

parent 24d94d5f56
commit 60ed86ae6d
@@ -81,7 +81,7 @@ NNRT_API OH_NN_ReturnCode OH_NNDevice_GetName(size_t deviceID, const char **name
     BackendManager& backendManager = BackendManager::GetInstance();
     const std::string& backendName = backendManager.GetBackendName(deviceID);
     if (backendName.empty()) {
-        LOGE("OH_NNDevice_GetName failed, error happened when getting name of deviceID %{public}zu.", deviceID);
+        LOGE("OH_NNDevice_GetName failed, error happened when getting name of deviceID.");
         *name = nullptr;
         return OH_NN_FAILED;
     }
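Note: this hunk and the four that follow share one pattern: the %{public}zu-formatted deviceID argument is dropped from the error message, so the identifier no longer appears in log output. An illustrative fragment of the HiLog-style formatting involved; the log.h include path and the LogDemo helper are assumptions, not part of this commit:

#include "log.h"  // assumed: the repo's LOGE wrapper over OHOS HiLog

// Illustrative only: HiLog format strings mark each parameter %{public}
// (printed as-is) or %{private} (redacted, the default). Dropping the
// parameter, as these hunks do, keeps the ID out of the log entirely.
void LogDemo(size_t deviceID)
{
    LOGE("query failed for device %{public}zu.", deviceID);   // ID visible in log
    LOGE("query failed for device %{private}zu.", deviceID);  // ID shown as <private>
    LOGE("query failed for device.");                         // ID not logged at all
}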
@@ -106,7 +106,7 @@ NNRT_API OH_NN_ReturnCode OH_NNDevice_GetType(size_t deviceID, OH_NN_DeviceType*
 
     OH_NN_ReturnCode ret = backend->GetBackendType(*deviceType);
     if (ret != OH_NN_SUCCESS) {
-        LOGE("OH_NNDevice_GetType failed, device id: %{public}zu.", deviceID);
+        LOGE("OH_NNDevice_GetType failed.");
         return ret;
     }
     return OH_NN_SUCCESS;
@@ -978,7 +978,7 @@ NNRT_API NN_Tensor* OH_NNTensor_Create(size_t deviceID, NN_TensorDesc *tensorDes
     BackendManager& backendManager = BackendManager::GetInstance();
     std::shared_ptr<Backend> backend = backendManager.GetBackend(deviceID);
     if (backend == nullptr) {
-        LOGE("OH_NNTensor_Create failed, passed invalid backend name %{public}zu.", deviceID);
+        LOGE("OH_NNTensor_Create failed, passed invalid backend name.");
         return nullptr;
     }
 
@@ -1010,7 +1010,7 @@ NNRT_API NN_Tensor* OH_NNTensor_CreateWithSize(size_t deviceID, NN_TensorDesc *t
     BackendManager& backendManager = BackendManager::GetInstance();
     std::shared_ptr<Backend> backend = backendManager.GetBackend(deviceID);
     if (backend == nullptr) {
-        LOGE("OH_NNTensor_CreateWithSize failed, passed invalid backend name %{public}zu.", deviceID);
+        LOGE("OH_NNTensor_CreateWithSize failed, passed invalid backend name.");
         return nullptr;
     }
 
@@ -1068,7 +1068,7 @@ NNRT_API NN_Tensor* OH_NNTensor_CreateWithFd(size_t deviceID,
     BackendManager& backendManager = BackendManager::GetInstance();
     std::shared_ptr<Backend> backend = backendManager.GetBackend(deviceID);
    if (backend == nullptr) {
-        LOGE("OH_NNTensor_CreateWithFd failed, passed invalid backend name %{public}zu.", deviceID);
+        LOGE("OH_NNTensor_CreateWithFd failed, passed invalid backend name.");
         return nullptr;
     }
 
@@ -503,7 +503,7 @@ NNRT_API OH_NN_ReturnCode OH_NNModel_BuildFromLiteGraph(OH_NNModel *model, const
     return innerModel->BuildFromLiteGraph(pLiteGraph, extensionConfig);
 }
 
-NNRT_API bool OH_NNModel_HasCache(const char *cacheDir, const char *modelName)
+NNRT_API bool OH_NNModel_HasCache(const char *cacheDir, const char *modelName, uint32_t version)
 {
     if (cacheDir == nullptr) {
         LOGI("OH_NNModel_HasCache get empty cache directory.");
@@ -543,8 +543,15 @@ NNRT_API bool OH_NNModel_HasCache(const char *cacheDir, const char *modelName)
     }
 
     int64_t fileNumber{0};
+    int64_t cacheVersion{0};
     if (!ifs.read(reinterpret_cast<char*>(&(fileNumber)), sizeof(fileNumber))) {
-        LOGI("OH_NNModel_HasCache read cache info file failed.");
+        LOGI("OH_NNModel_HasCache read fileNumber cache info file failed.");
         ifs.close();
         return false;
     }
+
+    if (!ifs.read(reinterpret_cast<char*>(&(cacheVersion)), sizeof(cacheVersion))) {
+        LOGI("OH_NNModel_HasCache read cacheVersion cache info file failed.");
+        ifs.close();
+        return false;
+    }
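For context, a minimal standalone sketch of the read sequence this hunk extends, assuming the cache info file simply begins with two consecutive int64_t fields (file count, then cache version), which is what the reads above imply; the helper name is illustrative:

#include <cstdint>
#include <fstream>
#include <string>

// Hypothetical helper mirroring the reads above: pulls the two leading
// int64_t fields out of a cache info file written by the runtime.
bool ReadCacheHeader(const std::string& infoPath, int64_t& fileNumber, int64_t& cacheVersion)
{
    std::ifstream ifs(infoPath, std::ios::binary);
    if (!ifs) {
        return false;
    }
    // The fields sit back to back at the start of the file, in the same
    // order the diff reads them: file count first, then version.
    if (!ifs.read(reinterpret_cast<char*>(&fileNumber), sizeof(fileNumber))) {
        return false;
    }
    if (!ifs.read(reinterpret_cast<char*>(&cacheVersion), sizeof(cacheVersion))) {
        return false;
    }
    return true;
}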
@@ -557,6 +564,11 @@ NNRT_API bool OH_NNModel_HasCache(const char *cacheDir, const char *modelName)
         exist = (exist && (stat(cacheModelPath.c_str(), &buffer) == 0));
     }
 
+    if (cacheVersion != version) {
+        LOGW("OH_NNModel_HasCache version is not match.");
+        exist = false;
+    }
+
     return exist;
 }
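With the new parameter, a caller passes the version it expects the cache to have been written with, and a mismatch now makes the function report no usable cache. A hedged usage sketch; the include path, directory, model name, and version value are assumptions:

#include "neural_network_runtime/neural_network_runtime.h"  // assumed include path

// Illustrative only: kModelVersion must match the version the cache was
// generated with, otherwise the new check makes HasCache return false.
bool IsModelCached()
{
    constexpr uint32_t kModelVersion = 1;  // hypothetical version number
    return OH_NNModel_HasCache("/data/local/tmp/nncache", "demo_model", kModelVersion);
}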
@@ -213,8 +213,9 @@ OH_NN_ReturnCode NNCompiledCache::GenerateCacheModel(const std::vector<OHOS::Neu
         return ret;
     }
 
+    std::string cachePath = path;
     for (size_t i = 0; i < cacheNumber; ++i) {
-        std::string cacheModelFile = cacheDir + "/" + m_modelName + std::to_string(i) + ".nncache";
+        std::string cacheModelFile = cachePath + "/" + m_modelName + std::to_string(i) + ".nncache";
         std::ofstream cacheModelStream(cacheModelFile, std::ios::binary | std::ios::out | std::ios::trunc);
         if (cacheModelStream.fail()) {
             LOGE("[NNCompiledCache] GenerateCacheModel failed, model cache file is invalid.");
@@ -265,7 +266,8 @@ OH_NN_ReturnCode NNCompiledCache::WriteCacheInfo(uint32_t cacheSize,
         return ret;
     }
 
-    std::string cacheInfoPath = cacheDir + "/" + m_modelName + "cache_info.nncache";
+    std::string cachePath = path;
+    std::string cacheInfoPath = cachePath + "/" + m_modelName + "cache_info.nncache";
     std::ofstream cacheInfoStream(cacheInfoPath, std::ios::binary | std::ios::out | std::ios::trunc);
     if (cacheInfoStream.fail()) {
         LOGE("[NNCompiledCache] WriteCacheInfo failed, model cache info file is invalid.");
@@ -303,11 +305,9 @@ OH_NN_ReturnCode NNCompiledCache::CheckCacheInfo(NNCompiledCacheInfo& modelCache
     // it is transformed from size_t value, so the transform here will not truncate value.
     size_t deviceId = static_cast<size_t>(modelCacheInfo.deviceId);
     if (deviceId != m_backendID) {
-        LOGE("[NNCompiledCache] CheckCacheInfo failed. The deviceId=%{public}zu in the cache files "
-             "is different from current deviceId=%{public}zu,"
-             "please change the cache directory or current deviceId.",
-             deviceId,
-             m_backendID);
+        LOGE("[NNCompiledCache] CheckCacheInfo failed. The deviceId in the cache files "
+             "is different from current deviceId,"
+             "please change the cache directory or current deviceId.");
         infoCacheFile.close();
         return OH_NN_INVALID_PARAMETER;
     }
@@ -450,6 +450,32 @@ OH_NN_ReturnCode NNCompiler::OnlineBuild()
 {
     // If the cache exists, restore prepareModel and the input/output TensorDesc directly from the cache
     OH_NN_ReturnCode ret = RestoreFromCacheFile();
+    if (ret != OH_NN_SUCCESS) {
+        LOGE("[NNCompiler] cache file is failed, to delete cache.");
+        char path[PATH_MAX];
+        if (realpath(m_cachePath.c_str(), path) == nullptr) {
+            LOGE("[NNCompiledCache] WriteCacheInfo failed, fail to get the real path of cache Dir.");
+            return OH_NN_INVALID_PARAMETER;
+        }
+
+        std::string cachePath = path;
+        std::string firstCache = cachePath + "/" + m_extensionConfig.modelName + "0.nncache";
+        std::string secondCache = cachePath + "/" + m_extensionConfig.modelName + "1.nncache";
+        std::string thirdCache = cachePath + "/" + m_extensionConfig.modelName + "2.nncache";
+        std::string cacheInfo = cachePath + "/" + m_extensionConfig.modelName + "cache_info.nncache";
+        if (std::filesystem::exists(firstCache)) {
+            std::filesystem::remove_all(firstCache);
+        }
+        if (std::filesystem::exists(secondCache)) {
+            std::filesystem::remove_all(secondCache);
+        }
+        if (std::filesystem::exists(thirdCache)) {
+            std::filesystem::remove_all(thirdCache);
+        }
+        if (std::filesystem::exists(cacheInfo)) {
+            std::filesystem::remove_all(cacheInfo);
+        }
+    }
     if (ret == OH_NN_OPERATION_FORBIDDEN) {
         LOGE("[NNCompiler] Build failed, operation is forbidden.");
         return ret;
@@ -139,7 +139,7 @@ OH_NN_ReturnCode OH_NNModel_BuildFromMetaGraph(OH_NNModel *model, const void *me
  * @since 11
  * @version 1.0
  */
-bool OH_NNModel_HasCache(const char *cacheDir, const char *modelName);
+bool OH_NNModel_HasCache(const char *cacheDir, const char *modelName, uint32_t version);
 
 #ifdef __cplusplus
 }